Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions sdks/python/apache_beam/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,42 @@
__all__ = ['Pipeline', 'transform_annotations']


def _descendant_applied_ptransforms(
current: 'AppliedPTransform') -> set['AppliedPTransform']:
descendants = set()
pending = list(current.parts)
while pending:
part = pending.pop()
if part in descendants:
continue
descendants.add(part)
pending.extend(part.parts)
return descendants


def _register_top_level_side_outputs(
    current: 'AppliedPTransform', result: pvalue.PCollection) -> None:
  """Attaches ``result``'s declared side outputs as outputs of ``current``.

  Each side output must be produced by ``current`` or one of its descendant
  transforms, and its tag must not collide with a different existing output
  of ``current``. Re-registering the same PCollection under the same tag is
  a no-op.
  """
  side_outputs = result._side_outputs
  if not side_outputs:
    return

  allowed_producers = _descendant_applied_ptransforms(current)
  for side_tag, side_pcoll in side_outputs.items():
    producer = side_pcoll.producer
    if producer is not current and producer not in allowed_producers:
      raise ValueError(
          f"Side output {side_tag!r} must be produced by "
          f"{current.full_label!r} or one of its descendant transforms.")

    existing = current.outputs.get(side_tag)
    if existing is None:
      current.add_output(side_pcoll, side_tag)
    elif existing is not side_pcoll:
      raise ValueError(
          f"Side output tag {side_tag!r} conflicts with an existing "
          'output of the same transform.')


class Pipeline(HasDisplayData):
"""A pipeline object that manages a DAG of
:class:`~apache_beam.transforms.ptransform.PTransform` s
Expand Down Expand Up @@ -411,6 +447,9 @@ def _replace_if_needed(
new_output, new_output._main_tag)
else:
replacement_transform_node.add_output(new_output, new_output.tag)
if isinstance(new_output, pvalue.PCollection):
_register_top_level_side_outputs(
replacement_transform_node, new_output)

# Recording updated outputs. This cannot be done in the same
# visitor since if we dynamically update output type here, we'll
Expand Down Expand Up @@ -844,6 +883,9 @@ def _apply_internal(
tag = '%s_%d' % (base, counter)

current.add_output(result, tag)
if result is pvalueish_result and isinstance(result,
pvalue.PCollection):
_register_top_level_side_outputs(current, result)

if (type_options is not None and
type_options.type_check_strictness == 'ALL_REQUIRED' and
Expand Down
157 changes: 157 additions & 0 deletions sdks/python/apache_beam/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,31 @@ def expand(self, pvalues):
AddThenMultiplyDoFn(), AsSingleton(pvalues[1]), AsSingleton(pvalues[2]))


def _all_applied_transforms(pipeline):
all_applied_transforms = {}
current_transforms = list(pipeline.transforms_stack)
while current_transforms:
applied_transform = current_transforms.pop()
all_applied_transforms[applied_transform.full_label] = applied_transform
current_transforms.extend(applied_transform.parts)
return all_applied_transforms


class _RemoveEvensDoFn(beam.DoFn):
  """Emits odd elements on the main output; evens go to the 'dropped' tag."""
  def process(self, element):
    if element % 2:
      yield element
    else:
      yield TaggedOutput('dropped', element)


class RemoveEvens(beam.PTransform):
  """Keeps odd elements, exposing the removed evens as a 'dropped' side output."""
  def expand(self, pcoll):
    tagged = pcoll | 'Split' >> beam.ParDo(_RemoveEvensDoFn()).with_outputs(
        'dropped', main='main')
    return tagged.main.with_side_outputs(dropped=tagged.dropped)


class PipelineTest(unittest.TestCase):
@staticmethod
def custom_callable(pcoll):
Expand Down Expand Up @@ -642,6 +667,128 @@ def mux_input(x):
self.assertNotIn(multi.letters, visitor.visited)
self.assertNotIn(multi.numbers, visitor.visited)

def test_pcollection_side_outputs_end_to_end(self):
  # Runs a real pipeline (TestPipeline executes on context exit) and checks
  # that the main output, the 'dropped' side output, and a transform chained
  # onto the main output all produce the expected elements.
  with TestPipeline() as pipeline:
    out = (
        pipeline
        | beam.Create([1, 2, 3, 4])
        | 'RemoveEvens' >> RemoveEvens())
    # Chaining off `out` must use the main stream only (odds).
    chained = out | 'ChainMainOutput' >> beam.Map(lambda x: x * 10)

    self.assertIsInstance(out.side_outputs.dropped, beam.pvalue.PCollection)
    assert_that(out, equal_to([1, 3]), label='assert_main_output')
    assert_that(
        out.side_outputs.dropped,
        equal_to([2, 4]),
        label='assert_side_output')
    assert_that(chained, equal_to([10, 30]), label='assert_chained_output')

  # After construction, the composite's AppliedPTransform should expose the
  # main output under the None tag and the side output under 'dropped'.
  applied_transform = _all_applied_transforms(pipeline)['RemoveEvens']
  self.assertIs(applied_transform.outputs[None], out)
  self.assertIs(
      applied_transform.outputs['dropped'], out.side_outputs.dropped)

def test_pcollection_side_outputs_rejects_foreign_pcollection(self):
  # A composite may only declare side outputs produced by itself or its
  # descendants; passing a PCollection from an unrelated branch must raise.
  class ExposeForeignSideOutput(beam.PTransform):
    def __init__(self, foreign):
      self._foreign = foreign

    def expand(self, pcoll):
      main = pcoll | 'Main' >> beam.Map(lambda x: x)
      # 'other' is produced outside this composite -- expected to be rejected.
      return main.with_side_outputs(other=self._foreign)

  pipeline = beam.Pipeline()
  source = pipeline | 'Source' >> beam.Create([1, 2, 3])
  foreign = pipeline | 'Foreign' >> beam.Create([10])

  with self.assertRaisesRegex(ValueError,
                              r"Side output 'other' must be produced by"):
    _ = source | 'ExposeForeignSideOutput' >> ExposeForeignSideOutput(foreign)

def test_pcollection_side_outputs_rejects_tag_collision(self):
  # During replace_all(), a replacement transform whose side-output tag
  # collides with a different output already registered under the same tag
  # must be rejected with a ValueError.
  class OriginalDroppedOutput(beam.PTransform):
    def expand(self, pcoll):
      # The dict return registers an output under the 'dropped' tag.
      return {'dropped': pcoll | 'Inner' >> beam.Filter(lambda x: x % 2)}

  class ConflictingSideOutput(beam.PTransform):
    def expand(self, pcoll):
      split = pcoll | 'Split' >> beam.ParDo(_RemoveEvensDoFn()).with_outputs(
          'dropped', main='main')
      # Streams are deliberately swapped: the returned main stream is the
      # 'dropped'-tagged one, so the 'dropped' side-output tag collides.
      return split.dropped.with_side_outputs(dropped=split.main)

  class CollisionOverride(PTransformOverride):
    def matches(self, applied_ptransform):
      return applied_ptransform.full_label == 'NeedsCollisionReplacement'

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      return ConflictingSideOutput()

  pipeline = beam.Pipeline()
  _ = (
      pipeline
      | beam.Create([1, 2, 3, 4])
      | 'NeedsCollisionReplacement' >> OriginalDroppedOutput())

  with self.assertRaisesRegex(
      ValueError,
      r"Side output tag 'dropped' conflicts with an existing output"):
    pipeline.replace_all([CollisionOverride()])

def test_ptransform_override_registers_side_outputs(self):
  # When a PTransformOverride replaces a transform with one that declares
  # side outputs, those side outputs must be registered as outputs of the
  # replaced transform node alongside the main (None-tagged) output.
  class IdentityComposite(beam.PTransform):
    def expand(self, pcoll):
      return pcoll | 'Inner' >> beam.Map(lambda x: x)

  class ReplacementWithSideOutputs(beam.PTransform):
    def expand(self, pcoll):
      split = pcoll | 'Split' >> beam.ParDo(_RemoveEvensDoFn()).with_outputs(
          'dropped', main='main')
      return split.main.with_side_outputs(dropped=split.dropped)

  class SideOutputOverride(PTransformOverride):
    def matches(self, applied_ptransform):
      return applied_ptransform.full_label == 'NeedsReplacement'

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      return ReplacementWithSideOutputs()

  pipeline = beam.Pipeline()
  _ = (
      pipeline
      | beam.Create([1, 2, 3, 4])
      | 'NeedsReplacement' >> IdentityComposite())

  pipeline.replace_all([SideOutputOverride()])

  applied_transform = _all_applied_transforms(pipeline)['NeedsReplacement']
  self.assertEqual({None, 'dropped'}, set(applied_transform.outputs))

def test_pcollection_side_outputs_not_registered_for_nested_return_values(
    self):
  # Side outputs are only auto-registered when the annotated PCollection is
  # the top-level return of expand(). When it is nested inside a dict, the
  # annotation survives on the PCollection itself but no 'dropped' output is
  # added to the composite's AppliedPTransform.
  class NestedReturnWithSideOutputs(beam.PTransform):
    def expand(self, pcoll):
      split = pcoll | 'Split' >> beam.ParDo(_RemoveEvensDoFn()).with_outputs(
          'dropped', main='main')
      return {
          'main': split.main.with_side_outputs(dropped=split.dropped),
      }

  pipeline = beam.Pipeline()
  result = (
      pipeline
      | beam.Create([1, 2, 3, 4])
      | 'NestedReturnWithSideOutputs' >> NestedReturnWithSideOutputs())

  applied_transform = _all_applied_transforms(
      pipeline)['NestedReturnWithSideOutputs']
  self.assertEqual({'main'}, set(applied_transform.outputs))
  self.assertNotIn('dropped', applied_transform.outputs)
  # The annotation itself is still reachable on the nested PCollection.
  self.assertIs(
      result['main'].side_outputs.dropped,
      result['main']._side_outputs['dropped'])

def test_filter_typehint(self):
# Check input type hint and output type hint are both specified.
def always_true_with_all_typehints(x: int) -> bool:
Expand Down Expand Up @@ -1068,6 +1215,16 @@ def expand(self, p):
self.assertEqual(
p.transforms_stack[0].parts[0].parent, p.transforms_stack[0])

def test_side_outputs_survive_runner_api_round_trip(self):
  # Side outputs registered on the composite node must still be present as
  # tagged outputs after serializing to and deserializing from runner API.
  pipeline = beam.Pipeline()
  _ = (pipeline | beam.Create([1, 2, 3, 4]) | 'RemoveEvens' >> RemoveEvens())

  proto = pipeline.to_runner_api(use_fake_coders=True)
  round_tripped = Pipeline.from_runner_api(proto, None, None)
  applied_transform = _all_applied_transforms(round_tripped)['RemoveEvens']

  self.assertEqual({None, 'dropped'}, set(applied_transform.outputs))

def test_requirements(self):
p = beam.Pipeline()
_ = (
Expand Down
87 changes: 87 additions & 0 deletions sdks/python/apache_beam/pvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# pytype: skip-file

import collections
import copy
import itertools
from typing import TYPE_CHECKING
from typing import Any
Expand Down Expand Up @@ -139,19 +140,105 @@ def __or__(self, ptransform):
return self.pipeline.apply(ptransform, self)


class _SideOutputsContainer:
"""Lightweight accessor over named side-output PCollections.

Supports attribute access (``container.dropped``), indexing
(``container["dropped"]``), iteration over tag names, ``len()``, and
``in``. ``__getattr__`` on a missing tag raises ``AttributeError`` with
a message listing available tags.
"""
def __init__(self, side_outputs: dict[str, 'PCollection']):
object.__setattr__(self, '_side_outputs', dict(side_outputs))

def __getattr__(self, tag: str) -> 'PCollection':
if tag.startswith('__'):
raise AttributeError(tag)
try:
return self._side_outputs[tag]
except KeyError as exc:
available = sorted(self._side_outputs)
raise AttributeError(
f"No side output named {tag!r}. Available: {available}") from exc

def __getitem__(self, tag: str) -> 'PCollection':
return self._side_outputs[tag]

def __iter__(self):
return iter(self._side_outputs)

def __len__(self):
return len(self._side_outputs)

def __contains__(self, tag):
return tag in self._side_outputs


class PCollection(PValue, Generic[T]):
"""A multiple values (potentially huge) container.

Dataflow users should not construct PCollection objects directly in their
pipelines.
"""
def __init__(
    self,
    pipeline: 'Pipeline',
    tag: Optional[str] = None,
    element_type: Optional[Union[type, 'typehints.TypeConstraint']] = None,
    windowing: Optional['Windowing'] = None,
    is_bounded=True):
  """Initializes a PCollection belonging to ``pipeline``.

  All arguments are forwarded unchanged to ``PValue.__init__``.
  """
  super().__init__(
      pipeline,
      tag=tag,
      element_type=element_type,
      windowing=windowing,
      is_bounded=is_bounded)
  # Side outputs attached via with_side_outputs(); stays None until then.
  self._side_outputs: Optional[dict[str, 'PCollection']] = None

def __eq__(self, other):
  """Two PCollections are equal iff they share the same tag and producer.

  Returns ``NotImplemented`` for non-PCollection operands (the original
  implicitly returned ``None``), letting Python try the reflected
  comparison and fall back to identity, per the data-model contract.
  """
  if isinstance(other, PCollection):
    return self.tag == other.tag and self.producer == other.producer
  return NotImplemented

def __hash__(self):
  # Must stay consistent with __eq__, which compares (tag, producer).
  identity = (self.tag, self.producer)
  return hash(identity)

def __copy__(self):
  # Shallow copy: a fresh instance sharing all attribute values, so
  # with_side_outputs() can annotate a clone without touching the original.
  clone = type(self).__new__(type(self))
  clone.__dict__.update(self.__dict__)
  return clone

@property
def side_outputs(self) -> _SideOutputsContainer:
  """Accessor over side outputs attached via with_side_outputs(); may be empty."""
  tagged = self._side_outputs
  if not tagged:
    tagged = {}
  return _SideOutputsContainer(tagged)

def with_side_outputs(self, **side_outputs: 'PCollection') -> 'PCollection':
  """Return a copy of this PCollection carrying the given side outputs.

  Each keyword argument becomes reachable as ``result.side_outputs.<tag>``;
  tags are necessarily valid Python identifiers because they arrive via
  ``**`` syntax. The returned copy replaces any previously attached side
  outputs, and this PCollection itself is left unmodified.

  The annotation does not survive runner API round-trips; inspect
  ``producer.outputs`` on a deserialized pipeline instead. Composite
  boundary type checking only considers the main output.

  Raises:
    TypeError: if any value is not a PCollection.
    ValueError: if any value belongs to a different pipeline.
  """
  # Validate everything up front so a failure leaves no partial state.
  for tag, pcoll in side_outputs.items():
    if not isinstance(pcoll, PCollection):
      raise TypeError(
          'Side output %r must be a PCollection. Got %r.' % (tag, pcoll))
    if pcoll.pipeline != self.pipeline:
      raise ValueError(
          'Side output %r must belong to the same pipeline as %r.' %
          (tag, self))

  annotated = copy.copy(self)
  annotated._side_outputs = dict(side_outputs)
  return annotated

@property
def windowing(self) -> 'Windowing':
if not hasattr(self, '_windowing'):
Expand Down
Loading
Loading