|
19 | 19 |
|
20 | 20 | from __future__ import annotations |
21 | 21 |
|
| 22 | +import dataclasses |
22 | 23 | from typing import Any |
23 | 24 | from typing import Dict |
24 | 25 | from typing import Optional |
25 | 26 | from typing import Sequence |
26 | 27 | from typing import Tuple |
27 | 28 |
|
28 | 29 | from apache_beam.typehints import typehints |
| 30 | +from apache_beam.typehints.native_type_compatibility import match_is_dataclass |
29 | 31 | from apache_beam.typehints.native_type_compatibility import match_is_named_tuple |
30 | 32 | from apache_beam.typehints.schema_registry import SchemaTypeRegistry |
31 | 33 |
|
@@ -56,18 +58,14 @@ def __init__( |
56 | 58 | for guidance on creating PCollections with inferred schemas. |
57 | 59 |
|
58 | 60 | Note RowTypeConstraint does not currently store arbitrary functions for |
59 | | - converting to/from the user type. Instead, we only support ``NamedTuple`` |
60 | | - user types and make the follow assumptions: |
| 61 | + converting to/from the user type. Instead, we support ``NamedTuple`` and |
| 62 | + ``dataclasses`` user types and make the following assumptions: |
61 | 63 |
|
62 | 64 | - The user type can be constructed with field values as arguments in order |
63 | 65 | (i.e. ``constructor(*field_values)``). |
64 | 66 | - Field values can be accessed from instances of the user type by attribute |
65 | 67 | (i.e. with ``getattr(obj, field_name)``). |
66 | 68 |
|
67 | | - In the future we will add support for dataclasses |
68 | | - ([#22085](https://github.com/apache/beam/issues/22085)) which also satisfy |
69 | | - these assumptions. |
70 | | -
|
71 | 69 | The RowTypeConstraint constructor should not be called directly (even |
72 | 70 | internally to Beam). Prefer static methods ``from_user_type`` or |
73 | 71 | ``from_fields``. |
@@ -107,27 +105,30 @@ def from_user_type( |
107 | 105 | if match_is_named_tuple(user_type): |
108 | 106 | fields = [(name, user_type.__annotations__[name]) |
109 | 107 | for name in user_type._fields] |
110 | | - |
111 | | - field_descriptions = getattr(user_type, '_field_descriptions', None) |
112 | | - |
113 | | - if _user_type_is_generated(user_type): |
114 | | - return RowTypeConstraint.from_fields( |
115 | | - fields, |
116 | | - schema_id=getattr(user_type, _BEAM_SCHEMA_ID), |
117 | | - schema_options=schema_options, |
118 | | - field_options=field_options, |
119 | | - field_descriptions=field_descriptions) |
120 | | - |
121 | | - # TODO(https://github.com/apache/beam/issues/22125): Add user API for |
122 | | - # specifying schema/field options |
123 | | - return RowTypeConstraint( |
124 | | - fields=fields, |
125 | | - user_type=user_type, |
| 108 | + elif match_is_dataclass(user_type): |
| 109 | + fields = [(field.name, field.type) |
| 110 | + for field in dataclasses.fields(user_type)] |
| 111 | + else: |
| 112 | + return None |
| 113 | + |
| 114 | + field_descriptions = getattr(user_type, '_field_descriptions', None) |
| 115 | + |
| 116 | + if _user_type_is_generated(user_type): |
| 117 | + return RowTypeConstraint.from_fields( |
| 118 | + fields, |
| 119 | + schema_id=getattr(user_type, _BEAM_SCHEMA_ID), |
126 | 120 | schema_options=schema_options, |
127 | 121 | field_options=field_options, |
128 | 122 | field_descriptions=field_descriptions) |
129 | 123 |
|
130 | | - return None |
| 124 | + # TODO(https://github.com/apache/beam/issues/22125): Add user API for |
| 125 | + # specifying schema/field options |
| 126 | + return RowTypeConstraint( |
| 127 | + fields=fields, |
| 128 | + user_type=user_type, |
| 129 | + schema_options=schema_options, |
| 130 | + field_options=field_options, |
| 131 | + field_descriptions=field_descriptions) |
131 | 132 |
|
132 | 133 | @staticmethod |
133 | 134 | def from_fields( |
|
0 commit comments