Skip to content

Commit 9f5904b

Browse files
authored
Support inferring schemas from Python dataclasses (#37728)
* Support inferring schemas from Python dataclasses * Address comments; Revert native_type_compatibility _TypeMapEntry change * Add unit test for named tuple and dataclasses encoded by RowCoder and passing through GBK * Fix lint
1 parent 2ebe33d commit 9f5904b

6 files changed

Lines changed: 183 additions & 29 deletions

File tree

sdks/python/apache_beam/coders/coder_impl.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"""
3131
# pytype: skip-file
3232

33+
import dataclasses
3334
import decimal
3435
import enum
3536
import itertools
@@ -67,11 +68,6 @@
6768
from apache_beam.utils.timestamp import MIN_TIMESTAMP
6869
from apache_beam.utils.timestamp import Timestamp
6970

70-
try:
71-
import dataclasses
72-
except ImportError:
73-
dataclasses = None # type: ignore
74-
7571
try:
7672
import dill
7773
except ImportError:

sdks/python/apache_beam/typehints/native_type_compatibility.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import collections
2323
import collections.abc
24+
import dataclasses
2425
import logging
2526
import sys
2627
import types
@@ -175,6 +176,10 @@ def match_is_named_tuple(user_type):
175176
hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))
176177

177178

179+
def match_is_dataclass(user_type):
180+
return dataclasses.is_dataclass(user_type) and isinstance(user_type, type)
181+
182+
178183
def _match_is_optional(user_type):
179184
return _match_is_union(user_type) and sum(
180185
tp is type(None) for tp in _get_args(user_type)) == 1

sdks/python/apache_beam/typehints/row_type.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919

2020
from __future__ import annotations
2121

22+
import dataclasses
2223
from typing import Any
2324
from typing import Dict
2425
from typing import Optional
2526
from typing import Sequence
2627
from typing import Tuple
2728

2829
from apache_beam.typehints import typehints
30+
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
2931
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
3032
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
3133

@@ -56,18 +58,14 @@ def __init__(
5658
for guidance on creating PCollections with inferred schemas.
5759
5860
Note RowTypeConstraint does not currently store arbitrary functions for
59-
converting to/from the user type. Instead, we only support ``NamedTuple``
60-
user types and make the follow assumptions:
61+
converting to/from the user type. Instead, we support ``NamedTuple`` and
62+
``dataclasses`` user types and make the follow assumptions:
6163
6264
- The user type can be constructed with field values as arguments in order
6365
(i.e. ``constructor(*field_values)``).
6466
- Field values can be accessed from instances of the user type by attribute
6567
(i.e. with ``getattr(obj, field_name)``).
6668
67-
In the future we will add support for dataclasses
68-
([#22085](https://github.com/apache/beam/issues/22085)) which also satisfy
69-
these assumptions.
70-
7169
The RowTypeConstraint constructor should not be called directly (even
7270
internally to Beam). Prefer static methods ``from_user_type`` or
7371
``from_fields``.
@@ -107,27 +105,30 @@ def from_user_type(
107105
if match_is_named_tuple(user_type):
108106
fields = [(name, user_type.__annotations__[name])
109107
for name in user_type._fields]
110-
111-
field_descriptions = getattr(user_type, '_field_descriptions', None)
112-
113-
if _user_type_is_generated(user_type):
114-
return RowTypeConstraint.from_fields(
115-
fields,
116-
schema_id=getattr(user_type, _BEAM_SCHEMA_ID),
117-
schema_options=schema_options,
118-
field_options=field_options,
119-
field_descriptions=field_descriptions)
120-
121-
# TODO(https://github.com/apache/beam/issues/22125): Add user API for
122-
# specifying schema/field options
123-
return RowTypeConstraint(
124-
fields=fields,
125-
user_type=user_type,
108+
elif match_is_dataclass(user_type):
109+
fields = [(field.name, field.type)
110+
for field in dataclasses.fields(user_type)]
111+
else:
112+
return None
113+
114+
field_descriptions = getattr(user_type, '_field_descriptions', None)
115+
116+
if _user_type_is_generated(user_type):
117+
return RowTypeConstraint.from_fields(
118+
fields,
119+
schema_id=getattr(user_type, _BEAM_SCHEMA_ID),
126120
schema_options=schema_options,
127121
field_options=field_options,
128122
field_descriptions=field_descriptions)
129123

130-
return None
124+
# TODO(https://github.com/apache/beam/issues/22125): Add user API for
125+
# specifying schema/field options
126+
return RowTypeConstraint(
127+
fields=fields,
128+
user_type=user_type,
129+
schema_options=schema_options,
130+
field_options=field_options,
131+
field_descriptions=field_descriptions)
131132

132133
@staticmethod
133134
def from_fields(
sdks/python/apache_beam/typehints/row_type_test.py (new file; filename missing in this capture — inferred from the module docstring "Unit tests for the Beam Row typing functionality")

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""Unit tests for the Beam Row typing functionality."""
19+
20+
import typing
21+
import unittest
22+
from dataclasses import dataclass
23+
24+
import apache_beam as beam
25+
from apache_beam.testing.test_pipeline import TestPipeline
26+
from apache_beam.testing.util import assert_that
27+
from apache_beam.testing.util import equal_to
28+
from apache_beam.typehints import row_type
29+
30+
31+
class RowTypeTest(unittest.TestCase):
32+
@staticmethod
33+
def _check_key_type_and_count(x) -> int:
34+
key_type = type(x[0])
35+
if not row_type._user_type_is_generated(key_type):
36+
raise RuntimeError("Expect type after GBK to be generated user type")
37+
38+
return len(x[1])
39+
40+
def test_group_by_key_namedtuple(self):
41+
MyNamedTuple = typing.NamedTuple(
42+
"MyNamedTuple", [("id", int), ("name", str)])
43+
44+
beam.coders.typecoders.registry.register_coder(
45+
MyNamedTuple, beam.coders.RowCoder)
46+
47+
def generate(num: int):
48+
for i in range(100):
49+
yield (MyNamedTuple(i, 'a'), num)
50+
51+
pipeline = TestPipeline(is_integration_test=False)
52+
53+
with pipeline as p:
54+
result = (
55+
p
56+
| 'Create' >> beam.Create([i for i in range(10)])
57+
| 'Generate' >> beam.ParDo(generate).with_output_types(
58+
tuple[MyNamedTuple, int])
59+
| 'GBK' >> beam.GroupByKey()
60+
| 'Count Elements' >> beam.Map(self._check_key_type_and_count))
61+
assert_that(result, equal_to([10] * 100))
62+
63+
def test_group_by_key_dataclass(self):
64+
@dataclass
65+
class MyDataClass:
66+
id: int
67+
name: str
68+
69+
beam.coders.typecoders.registry.register_coder(
70+
MyDataClass, beam.coders.RowCoder)
71+
72+
def generate(num: int):
73+
for i in range(100):
74+
yield (MyDataClass(i, 'a'), num)
75+
76+
pipeline = TestPipeline(is_integration_test=False)
77+
78+
with pipeline as p:
79+
result = (
80+
p
81+
| 'Create' >> beam.Create([i for i in range(10)])
82+
| 'Generate' >> beam.ParDo(generate).with_output_types(
83+
tuple[MyDataClass, int])
84+
| 'GBK' >> beam.GroupByKey()
85+
| 'Count Elements' >> beam.Map(self._check_key_type_and_count))
86+
assert_that(result, equal_to([10] * 100))
87+
88+
89+
if __name__ == '__main__':
90+
unittest.main()

sdks/python/apache_beam/typehints/schemas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
from apache_beam.typehints.native_type_compatibility import _safe_issubclass
9797
from apache_beam.typehints.native_type_compatibility import convert_to_python_type
9898
from apache_beam.typehints.native_type_compatibility import extract_optional_type
99+
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
99100
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
100101
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
101102
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
@@ -629,7 +630,7 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema:
629630
Returns schema as a list of (name, python_type) tuples"""
630631
if isinstance(element_type, row_type.RowTypeConstraint):
631632
return named_fields_to_schema(element_type._fields)
632-
elif match_is_named_tuple(element_type):
633+
elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
633634
if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
634635
# if the named tuple's schema is in registry, we just use it instead of
635636
# regenerating one.

sdks/python/apache_beam/typehints/schemas_test.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
# pytype: skip-file
2121

22+
import dataclasses
2223
import itertools
2324
import pickle
2425
import unittest
@@ -388,6 +389,24 @@ def test_namedtuple_roundtrip(self, user_type):
388389
self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
389390
self.assert_namedtuple_equivalent(roundtripped.user_type, user_type)
390391

392+
def test_dataclass_roundtrip(self):
393+
@dataclasses.dataclass
394+
class SimpleDataclass:
395+
id: np.int64
396+
name: str
397+
398+
roundtripped = typing_from_runner_api(
399+
typing_to_runner_api(
400+
SimpleDataclass, schema_registry=SchemaTypeRegistry()),
401+
schema_registry=SchemaTypeRegistry())
402+
403+
self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
404+
# The roundtripped user_type is generated as a NamedTuple, so we can't test
405+
# equivalence directly with the dataclass.
406+
# Instead, let's verify annotations.
407+
self.assertEqual(
408+
roundtripped.user_type.__annotations__, SimpleDataclass.__annotations__)
409+
391410
def test_row_type_constraint_to_schema(self):
392411
result_type = typing_to_runner_api(
393412
row_type.RowTypeConstraint.from_fields([
@@ -646,6 +665,48 @@ def test_trivial_example(self):
646665
expected.row_type.schema.fields,
647666
typing_to_runner_api(MyCuteClass).row_type.schema.fields)
648667

668+
def test_trivial_example_dataclass(self):
669+
@dataclasses.dataclass
670+
class MyCuteDataclass:
671+
name: str
672+
age: Optional[int]
673+
interests: List[str]
674+
height: float
675+
blob: ByteString
676+
677+
expected = schema_pb2.FieldType(
678+
row_type=schema_pb2.RowType(
679+
schema=schema_pb2.Schema(
680+
fields=[
681+
schema_pb2.Field(
682+
name='name',
683+
type=schema_pb2.FieldType(
684+
atomic_type=schema_pb2.STRING),
685+
),
686+
schema_pb2.Field(
687+
name='age',
688+
type=schema_pb2.FieldType(
689+
nullable=True, atomic_type=schema_pb2.INT64)),
690+
schema_pb2.Field(
691+
name='interests',
692+
type=schema_pb2.FieldType(
693+
array_type=schema_pb2.ArrayType(
694+
element_type=schema_pb2.FieldType(
695+
atomic_type=schema_pb2.STRING)))),
696+
schema_pb2.Field(
697+
name='height',
698+
type=schema_pb2.FieldType(
699+
atomic_type=schema_pb2.DOUBLE)),
700+
schema_pb2.Field(
701+
name='blob',
702+
type=schema_pb2.FieldType(
703+
atomic_type=schema_pb2.BYTES)),
704+
])))
705+
706+
self.assertEqual(
707+
expected.row_type.schema.fields,
708+
typing_to_runner_api(MyCuteDataclass).row_type.schema.fields)
709+
649710
def test_user_type_annotated_with_id_after_conversion(self):
650711
MyCuteClass = NamedTuple('MyCuteClass', [
651712
('name', str),

0 commit comments

Comments (0)