diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 3cce2c5bb773..b529c6be39fe 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -113,6 +113,42 @@ __all__ = ['Pipeline', 'transform_annotations'] +def _descendant_applied_ptransforms( + current: 'AppliedPTransform') -> set['AppliedPTransform']: + descendants = set() + pending = list(current.parts) + while pending: + part = pending.pop() + if part in descendants: + continue + descendants.add(part) + pending.extend(part.parts) + return descendants + + +def _register_top_level_side_outputs( + current: 'AppliedPTransform', result: pvalue.PCollection) -> None: + if not result._side_outputs: + return + + descendants = _descendant_applied_ptransforms(current) + for side_tag, side_pcoll in result._side_outputs.items(): + if side_pcoll.producer is not current and side_pcoll.producer not in descendants: + raise ValueError( + f"Side output {side_tag!r} must be produced by " + f"{current.full_label!r} or one of its descendant transforms.") + + existing = current.outputs.get(side_tag) + if existing is not None: + if existing is not side_pcoll: + raise ValueError( + f"Side output tag {side_tag!r} conflicts with an existing " + 'output of the same transform.') + continue + + current.add_output(side_pcoll, side_tag) + + class Pipeline(HasDisplayData): """A pipeline object that manages a DAG of :class:`~apache_beam.transforms.ptransform.PTransform` s @@ -411,6 +447,9 @@ def _replace_if_needed( new_output, new_output._main_tag) else: replacement_transform_node.add_output(new_output, new_output.tag) + if isinstance(new_output, pvalue.PCollection): + _register_top_level_side_outputs( + replacement_transform_node, new_output) # Recording updated outputs. This cannot be done in the same # visitor since if we dynamically update output type here, we'll @@ -844,6 +883,9 @@ def _apply_internal( tag = '%s_%d' % (base, counter) current.add_output(result, tag) + if result is pvalueish_result and isinstance(result, + pvalue.PCollection): + _register_top_level_side_outputs(current, result) if (type_options is not None and type_options.type_check_strictness == 'ALL_REQUIRED' and diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index b28fe3c3d14e..e19859eb49eb 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -121,6 +121,31 @@ def expand(self, pvalues): AddThenMultiplyDoFn(), AsSingleton(pvalues[1]), AsSingleton(pvalues[2])) +def _all_applied_transforms(pipeline): + all_applied_transforms = {} + current_transforms = list(pipeline.transforms_stack) + while current_transforms: + applied_transform = current_transforms.pop() + all_applied_transforms[applied_transform.full_label] = applied_transform + current_transforms.extend(applied_transform.parts) + return all_applied_transforms + + +class _RemoveEvensDoFn(beam.DoFn): + def process(self, element): + if element % 2 == 0: + yield TaggedOutput('dropped', element) + else: + yield element + + +class RemoveEvens(beam.PTransform): + def expand(self, pcoll): + split = pcoll | 'Split' >> beam.ParDo(_RemoveEvensDoFn()).with_outputs( + 'dropped', main='main') + return split.main.with_side_outputs(dropped=split.dropped) + + class PipelineTest(unittest.TestCase): @staticmethod def custom_callable(pcoll): @@ -642,6 +667,128 @@ def mux_input(x): self.assertNotIn(multi.letters, visitor.visited) self.assertNotIn(multi.numbers, visitor.visited) + def test_pcollection_side_outputs_end_to_end(self): + with TestPipeline() as pipeline: + out = ( + pipeline + | beam.Create([1, 2, 3, 4]) + | 'RemoveEvens' >> RemoveEvens()) + chained = out | 'ChainMainOutput' >> beam.Map(lambda x: x * 10) + + self.assertIsInstance(out.side_outputs.dropped, beam.pvalue.PCollection) + assert_that(out, equal_to([1, 3]), label='assert_main_output') + assert_that( + out.side_outputs.dropped, + equal_to([2, 4]), + label='assert_side_output') + assert_that(chained, equal_to([10, 30]), label='assert_chained_output') + + applied_transform = _all_applied_transforms(pipeline)['RemoveEvens'] + self.assertIs(applied_transform.outputs[None], out) + self.assertIs( + applied_transform.outputs['dropped'], out.side_outputs.dropped) + + def test_pcollection_side_outputs_rejects_foreign_pcollection(self): + class ExposeForeignSideOutput(beam.PTransform): + def __init__(self, foreign): + self._foreign = foreign + + def expand(self, pcoll): + main = pcoll | 'Main' >> beam.Map(lambda x: x) + return main.with_side_outputs(other=self._foreign) + + pipeline = beam.Pipeline() + source = pipeline | 'Source' >> beam.Create([1, 2, 3]) + foreign = pipeline | 'Foreign' >> beam.Create([10]) + + with self.assertRaisesRegex(ValueError, + r"Side output 'other' must be produced by"): + _ = source | 'ExposeForeignSideOutput' >> ExposeForeignSideOutput(foreign) + + def test_pcollection_side_outputs_rejects_tag_collision(self): + class OriginalDroppedOutput(beam.PTransform): + def expand(self, pcoll): + return {'dropped': pcoll | 'Inner' >> beam.Filter(lambda x: x % 2)} + + class ConflictingSideOutput(beam.PTransform): + def expand(self, pcoll): + split = pcoll | 'Split' >> beam.ParDo(_RemoveEvensDoFn()).with_outputs( + 'dropped', main='main') + return split.dropped.with_side_outputs(dropped=split.main) + + class CollisionOverride(PTransformOverride): + def matches(self, applied_ptransform): + return applied_ptransform.full_label == 'NeedsCollisionReplacement' + + def get_replacement_transform_for_applied_ptransform( + self, applied_ptransform): + return ConflictingSideOutput() + + pipeline = beam.Pipeline() + _ = ( + pipeline + | beam.Create([1, 2, 3, 4]) + | 'NeedsCollisionReplacement' >> OriginalDroppedOutput()) + + with self.assertRaisesRegex( + ValueError, + r"Side output tag 'dropped' conflicts with an existing output"): + pipeline.replace_all([CollisionOverride()]) + + def test_ptransform_override_registers_side_outputs(self): + class IdentityComposite(beam.PTransform): + def expand(self, pcoll): + return pcoll | 'Inner' >> beam.Map(lambda x: x) + + class ReplacementWithSideOutputs(beam.PTransform): + def expand(self, pcoll): + split = pcoll | 'Split' >> beam.ParDo(_RemoveEvensDoFn()).with_outputs( + 'dropped', main='main') + return split.main.with_side_outputs(dropped=split.dropped) + + class SideOutputOverride(PTransformOverride): + def matches(self, applied_ptransform): + return applied_ptransform.full_label == 'NeedsReplacement' + + def get_replacement_transform_for_applied_ptransform( + self, applied_ptransform): + return ReplacementWithSideOutputs() + + pipeline = beam.Pipeline() + _ = ( + pipeline + | beam.Create([1, 2, 3, 4]) + | 'NeedsReplacement' >> IdentityComposite()) + + pipeline.replace_all([SideOutputOverride()]) + + applied_transform = _all_applied_transforms(pipeline)['NeedsReplacement'] + self.assertEqual({None, 'dropped'}, set(applied_transform.outputs)) + + def test_pcollection_side_outputs_not_registered_for_nested_return_values( + self): + class NestedReturnWithSideOutputs(beam.PTransform): + def expand(self, pcoll): + split = pcoll | 'Split' >> beam.ParDo(_RemoveEvensDoFn()).with_outputs( + 'dropped', main='main') + return { + 'main': split.main.with_side_outputs(dropped=split.dropped), + } + + pipeline = beam.Pipeline() + result = ( + pipeline + | beam.Create([1, 2, 3, 4]) + | 'NestedReturnWithSideOutputs' >> NestedReturnWithSideOutputs()) + + applied_transform = _all_applied_transforms( + pipeline)['NestedReturnWithSideOutputs'] + self.assertEqual({'main'}, set(applied_transform.outputs)) + self.assertNotIn('dropped', applied_transform.outputs) + self.assertIs( + result['main'].side_outputs.dropped, + result['main']._side_outputs['dropped']) + def test_filter_typehint(self): # Check input type hint and output type hint are both specified. def always_true_with_all_typehints(x: int) -> bool: @@ -1068,6 +1215,16 @@ def expand(self, p): self.assertEqual( p.transforms_stack[0].parts[0].parent, p.transforms_stack[0]) + def test_side_outputs_survive_runner_api_round_trip(self): + pipeline = beam.Pipeline() + _ = (pipeline | beam.Create([1, 2, 3, 4]) | 'RemoveEvens' >> RemoveEvens()) + + round_tripped = Pipeline.from_runner_api( + pipeline.to_runner_api(use_fake_coders=True), None, None) + applied_transform = _all_applied_transforms(round_tripped)['RemoveEvens'] + + self.assertEqual({None, 'dropped'}, set(applied_transform.outputs)) + def test_requirements(self): p = beam.Pipeline() _ = ( diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 1cd220cc2566..169f78e89a12 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -27,6 +27,7 @@ # pytype: skip-file import collections +import copy import itertools from typing import TYPE_CHECKING from typing import Any @@ -139,12 +140,61 @@ def __or__(self, ptransform): return self.pipeline.apply(ptransform, self) +class _SideOutputsContainer: + """Lightweight accessor over named side-output PCollections. + + Supports attribute access (``container.dropped``), indexing + (``container["dropped"]``), iteration over tag names, ``len()``, and + ``in``. ``__getattr__`` on a missing tag raises ``AttributeError`` with + a message listing available tags. + """ + def __init__(self, side_outputs: dict[str, 'PCollection']): + object.__setattr__(self, '_side_outputs', dict(side_outputs)) + + def __getattr__(self, tag: str) -> 'PCollection': + if tag.startswith('__'): + raise AttributeError(tag) + try: + return self._side_outputs[tag] + except KeyError as exc: + available = sorted(self._side_outputs) + raise AttributeError( + f"No side output named {tag!r}. Available: {available}") from exc + + def __getitem__(self, tag: str) -> 'PCollection': + return self._side_outputs[tag] + + def __iter__(self): + return iter(self._side_outputs) + + def __len__(self): + return len(self._side_outputs) + + def __contains__(self, tag): + return tag in self._side_outputs + + class PCollection(PValue, Generic[T]): """A multiple values (potentially huge) container. Dataflow users should not construct PCollection objects directly in their pipelines. """ + def __init__( + self, + pipeline: 'Pipeline', + tag: Optional[str] = None, + element_type: Optional[Union[type, 'typehints.TypeConstraint']] = None, + windowing: Optional['Windowing'] = None, + is_bounded=True): + super().__init__( + pipeline, + tag=tag, + element_type=element_type, + windowing=windowing, + is_bounded=is_bounded) + self._side_outputs: Optional[dict[str, 'PCollection']] = None + def __eq__(self, other): if isinstance(other, PCollection): return self.tag == other.tag and self.producer == other.producer @@ -152,6 +202,43 @@ def __eq__(self, other): def __hash__(self): return hash((self.tag, self.producer)) + def __copy__(self): + result = type(self).__new__(type(self)) + result.__dict__ = self.__dict__.copy() + return result + + @property + def side_outputs(self) -> _SideOutputsContainer: + return _SideOutputsContainer(self._side_outputs or {}) + + def with_side_outputs(self, **side_outputs: 'PCollection') -> 'PCollection': + """Return a copy of this PCollection carrying the given side outputs. + + Each kwarg becomes accessible as ``result.side_outputs.``. Tags + must be valid Python identifiers (enforced by ``**`` syntax). + + Calling ``with_side_outputs`` again replaces any previously-set side + outputs on the new copy; the original is unchanged. + + This annotation is not preserved across runner API round-trips; inspect + ``producer.outputs`` on the deserialized pipeline instead. + + Only the main output participates in composite-boundary type checking. + """ + for side_tag, side_pcoll in side_outputs.items(): + if not isinstance(side_pcoll, PCollection): + raise TypeError( + 'Side output %r must be a PCollection. Got %r.' % + (side_tag, side_pcoll)) + if side_pcoll.pipeline != self.pipeline: + raise ValueError( + 'Side output %r must belong to the same pipeline as %r.' % + (side_tag, self)) + + result = copy.copy(self) + result._side_outputs = dict(side_outputs) + return result + @property def windowing(self) -> 'Windowing': if not hasattr(self, '_windowing'): diff --git a/sdks/python/apache_beam/pvalue_test.py b/sdks/python/apache_beam/pvalue_test.py index 447d2327dc4f..b4e1a470344e 100644 --- a/sdks/python/apache_beam/pvalue_test.py +++ b/sdks/python/apache_beam/pvalue_test.py @@ -19,9 +19,12 @@ # pytype: skip-file +import copy import unittest +import apache_beam as beam from apache_beam.pvalue import AsSingleton +from apache_beam.pvalue import PCollection from apache_beam.pvalue import PValue from apache_beam.pvalue import Row from apache_beam.pvalue import TaggedOutput @@ -50,6 +53,79 @@ def test_passed_tuple_as_tag(self): TaggedOutput((1, 2, 3), 'value') +class PCollectionSideOutputsTest(unittest.TestCase): + def test_with_side_outputs_returns_copy_and_preserves_original(self): + pipeline = beam.Pipeline() + pcoll = PCollection(pipeline) + dropped = PCollection(pipeline) + + result = pcoll.with_side_outputs(dropped=dropped) + copied = copy.copy(result) + + self.assertIsNot(result, pcoll) + self.assertIsNone(pcoll._side_outputs) + self.assertEqual({'dropped': dropped}, result._side_outputs) + self.assertEqual({'dropped': dropped}, copied._side_outputs) + + def test_side_outputs_attribute_and_index_access(self): + pipeline = beam.Pipeline() + dropped = PCollection(pipeline) + kept = PCollection(pipeline) + pcoll = PCollection(pipeline).with_side_outputs(dropped=dropped, kept=kept) + + self.assertIs(pcoll.side_outputs.dropped, dropped) + self.assertIs(pcoll.side_outputs['kept'], kept) + self.assertEqual(['dropped', 'kept'], sorted(pcoll.side_outputs)) + self.assertEqual(2, len(pcoll.side_outputs)) + self.assertIn('dropped', pcoll.side_outputs) + + def test_missing_side_output_lists_available_tags(self): + pipeline = beam.Pipeline() + dropped = PCollection(pipeline) + kept = PCollection(pipeline) + pcoll = PCollection(pipeline).with_side_outputs(dropped=dropped, kept=kept) + + with self.assertRaisesRegex( + AttributeError, + r"No side output named 'missing'\. Available: \['dropped', 'kept'\]"): + _ = pcoll.side_outputs.missing + + def test_with_side_outputs_validation(self): + pcoll = PCollection(beam.Pipeline()) + foreign = PCollection(beam.Pipeline()) + + with self.assertRaisesRegex(TypeError, + r"Side output 'dropped' must be a PCollection"): + pcoll.with_side_outputs(dropped='not a PCollection') + + with self.assertRaisesRegex( + ValueError, r"Side output 'dropped' must belong to the same pipeline"): + pcoll.with_side_outputs(dropped=foreign) + + def test_side_outputs_empty_container_behavior(self): + pcoll = PCollection(beam.Pipeline()) + + self.assertEqual([], list(pcoll.side_outputs)) + self.assertEqual(0, len(pcoll.side_outputs)) + self.assertNotIn('missing', pcoll.side_outputs) + with self.assertRaisesRegex( + AttributeError, r"No side output named 'missing'\. Available: \[\]"): + _ = pcoll.side_outputs.missing + + def test_with_side_outputs_second_call_replaces_existing_side_outputs(self): + pipeline = beam.Pipeline() + dropped = PCollection(pipeline) + kept = PCollection(pipeline) + first = PCollection(pipeline).with_side_outputs(dropped=dropped) + + second = first.with_side_outputs(kept=kept) + + self.assertEqual({'dropped': dropped}, first._side_outputs) + self.assertEqual({'kept': kept}, second._side_outputs) + with self.assertRaises(AttributeError): + _ = second.side_outputs.dropped + + class RowTest(unittest.TestCase): def test_row_eq(self): row = Row(a=1, b=2)