Skip to content

Commit fed1b55

Browse files
authored
Add support for tagged output type hints. (#37434)
* Add tagged typehint support. * Just warn when bare tagged output * Remove contains tagged output check. * Mapped bare TaggedOutput to Any * Extract tagged outputs after strip_iterable.
1 parent 99e4868 commit fed1b55

7 files changed

Lines changed: 712 additions & 34 deletions

File tree

sdks/python/apache_beam/pvalue.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,8 @@ def __init__(
265265
self._tags = tags
266266
self._main_tag = main_tag
267267
self._transform = transform
268+
self._tagged_output_types = (
269+
transform.get_type_hints().tagged_output_types() if transform else {})
268270
self._allow_unknown_tags = (
269271
not tags if allow_unknown_tags is None else allow_unknown_tags)
270272
# The ApplyPTransform instance for the application of the multi FlatMap
@@ -322,7 +324,7 @@ def __getitem__(self, tag: Union[int, str, None]) -> PCollection:
322324
pcoll = PCollection(
323325
self._pipeline,
324326
tag=tag,
325-
element_type=typehints.Any,
327+
element_type=self._tagged_output_types.get(tag, typehints.Any),
326328
is_bounded=is_bounded)
327329
# Transfer the producer from the DoOutputsTuple to the resulting
328330
# PCollection.
@@ -342,15 +344,19 @@ def __getitem__(self, tag: Union[int, str, None]) -> PCollection:
342344
return pcoll
343345

344346

345-
class TaggedOutput(object):
347+
TagType = TypeVar('TagType', bound=str)
348+
ValueType = TypeVar('ValueType')
349+
350+
351+
class TaggedOutput(Generic[TagType, ValueType]):
346352
"""An object representing a tagged value.
347353
348354
ParDo, Map, and FlatMap transforms can emit values on multiple outputs which
349355
are distinguished by string tags. The DoFn will return plain values
350356
if it wants to emit on the main output and TaggedOutput objects
351357
if it wants to emit a value on a specific tagged output.
352358
"""
353-
def __init__(self, tag: str, value: Any) -> None:
359+
def __init__(self, tag: TagType, value: ValueType) -> None:
354360
if not isinstance(tag, str):
355361
raise TypeError(
356362
'Attempting to create a TaggedOutput with non-string tag %s' %

sdks/python/apache_beam/transforms/core.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,7 @@ def default_type_hints(self):
824824
process_type_hints = process_type_hints.strip_iterable()
825825
except ValueError as e:
826826
raise ValueError('Return value not iterable: %s: %s' % (self, e))
827+
process_type_hints = process_type_hints.extract_tagged_outputs()
827828

828829
# Prefer class decorator type hints for backwards compatibility.
829830
return get_type_hints(self.__class__).with_defaults(process_type_hints)
@@ -1039,6 +1040,7 @@ def default_type_hints(self):
10391040
raise TypeCheckError(
10401041
'Return value not iterable: %s: %s' %
10411042
(self.display_data()['fn'].value, e))
1043+
type_hints = type_hints.extract_tagged_outputs()
10421044
return type_hints
10431045

10441046
def infer_output_type(self, input_type):
@@ -1834,6 +1836,17 @@ def with_outputs(self, *tags, main=None, allow_unknown_tags=None):
18341836
raise ValueError(
18351837
'Main output tag %r must be different from side output tags %r.' %
18361838
(main, tags))
1839+
type_hints = self.get_type_hints()
1840+
declared_tags = set(type_hints.tagged_output_types().keys())
1841+
requested_tags = set(tags)
1842+
1843+
unknown = requested_tags - declared_tags
1844+
if unknown and declared_tags: # Only warn if type hints exist
1845+
logging.warning(
1846+
"Tags %s requested in with_outputs() but not declared "
1847+
"in type hints. Declared tags: %s",
1848+
unknown,
1849+
declared_tags)
18371850
return _MultiParDo(self, tags, main, allow_unknown_tags)
18381851

18391852
def _do_fn_info(self):
@@ -2120,8 +2133,14 @@ def Map(fn, *args, **kwargs): # pylint: disable=invalid-name
21202133
wrapper)
21212134
output_hint = type_hints.simple_output_type(label)
21222135
if output_hint:
2136+
tagged = {
2137+
k: typehints.Iterable[v]
2138+
for k, v in type_hints.tagged_output_types().items()
2139+
}
21232140
wrapper = with_output_types(
2124-
typehints.Iterable[_strip_output_annotations(output_hint)])(
2141+
typehints.Iterable[_strip_output_annotations(
2142+
output_hint, strip_tagged_output=False)],
2143+
**tagged)(
21252144
wrapper)
21262145
# pylint: disable=protected-access
21272146
wrapper._argspec_fn = fn
@@ -2189,8 +2208,14 @@ def MapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name
21892208
pass
21902209
output_hint = type_hints.simple_output_type(label)
21912210
if output_hint:
2211+
tagged = {
2212+
k: typehints.Iterable[v]
2213+
for k, v in type_hints.tagged_output_types().items()
2214+
}
21922215
wrapper = with_output_types(
2193-
typehints.Iterable[_strip_output_annotations(output_hint)])(
2216+
typehints.Iterable[_strip_output_annotations(
2217+
output_hint, strip_tagged_output=False)],
2218+
**tagged)(
21942219
wrapper)
21952220

21962221
# Replace the first (args) component.
@@ -2261,7 +2286,10 @@ def FlatMapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name
22612286
pass
22622287
output_hint = type_hints.simple_output_type(label)
22632288
if output_hint:
2264-
wrapper = with_output_types(_strip_output_annotations(output_hint))(wrapper)
2289+
wrapper = with_output_types(
2290+
_strip_output_annotations(output_hint, strip_tagged_output=False),
2291+
**type_hints.tagged_output_types())(
2292+
wrapper)
22652293

22662294
# Replace the first (args) component.
22672295
modified_arg_names = ['tuple_element'] + arg_names[-num_defaults:]
@@ -4222,12 +4250,15 @@ def from_runner_api_parameter(
42224250
return Impulse()
42234251

42244252

4225-
def _strip_output_annotations(type_hint):
4253+
def _strip_output_annotations(type_hint, strip_tagged_output=True):
42264254
# TODO(robertwb): These should be parameterized types that the
42274255
# type inferencer understands.
42284256
# Then we can replace them with the correct element types instead of
42294257
# using Any. Refer to typehints.WindowedValue when doing this.
4230-
annotations = (TimestampedValue, WindowedValue, pvalue.TaggedOutput)
4258+
annotations = [TimestampedValue, WindowedValue]
4259+
if strip_tagged_output:
4260+
annotations.append(pvalue.TaggedOutput)
4261+
annotations = tuple(annotations)
42314262

42324263
contains_annotation = False
42334264

sdks/python/apache_beam/transforms/ptransform.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -414,12 +414,15 @@ def with_input_types(self, input_type_hint):
414414
input_type_hint, 'Type hints for a PTransform')
415415
return super().with_input_types(input_type_hint)
416416

417-
def with_output_types(self, type_hint):
417+
def with_output_types(self, type_hint, **tagged_type_hints):
418418
"""Annotates the output type of a :class:`PTransform` with a type-hint.
419419
420420
Args:
421421
type_hint (type): An instance of an allowed built-in type, a custom class,
422-
or a :class:`~apache_beam.typehints.typehints.TypeConstraint`.
422+
or a :class:`~apache_beam.typehints.typehints.TypeConstraint`. This is
423+
the type hint for the main output.
424+
**tagged_type_hints: Type hints for tagged outputs. Each keyword argument
425+
specifies the type for a tagged output e.g., ``errors=str``.
423426
424427
Raises:
425428
TypeError: If **type_hint** is not a valid type-hint. See
@@ -430,10 +433,22 @@ def with_output_types(self, type_hint):
430433
PTransform: A reference to the instance of this particular
431434
:class:`PTransform` object. This allows chaining type-hinting related
432435
methods.
436+
437+
Example::
438+
result = pcoll | beam.ParDo(MyDoFn()).with_output_types(
439+
int, # main output type
440+
errors=str, # 'errors' tagged output type
441+
warnings=str # 'warnings' tagged output type
442+
).with_outputs('errors', 'warnings', main='main')
433443
"""
434444
type_hint = native_type_compatibility.convert_to_beam_type(type_hint)
435445
validate_composite_type_param(type_hint, 'Type hints for a PTransform')
436-
return super().with_output_types(type_hint)
446+
for tag, hint in tagged_type_hints.items():
447+
tagged_type_hints[tag] = native_type_compatibility.convert_to_beam_type(
448+
hint)
449+
validate_composite_type_param(
450+
tagged_type_hints[tag], f'Tagged output type hint for {tag!r}')
451+
return super().with_output_types(type_hint, **tagged_type_hints)
437452

438453
def with_resource_hints(self, **kwargs): # type: (...) -> PTransform
439454
"""Adds resource hints to the :class:`PTransform`.
@@ -479,10 +494,11 @@ def type_check_inputs_or_outputs(self, pvalueish, input_or_output):
479494
if hints is None or not any(hints):
480495
return
481496
arg_hints, kwarg_hints = hints
482-
if arg_hints and kwarg_hints:
497+
# Output types can have kwargs for tagged output types.
498+
if arg_hints and kwarg_hints and input_or_output != 'output':
483499
raise TypeCheckError(
484-
'PTransform cannot have both positional and keyword type hints '
485-
'without overriding %s._type_check_%s()' %
500+
'PTransform cannot have both positional and keyword input type hints'
501+
' without overriding %s._type_check_%s()' %
486502
(self.__class__, input_or_output))
487503
root_hint = (
488504
arg_hints[0] if len(arg_hints) == 1 else arg_hints or kwarg_hints)

0 commit comments

Comments
 (0)