|
22 | 22 | import javax.annotation.Nullable; |
23 | 23 | import org.apache.beam.model.pipeline.v1.SchemaApi; |
24 | 24 | import org.apache.beam.runners.core.construction.SdkComponents; |
| 25 | +import org.apache.beam.sdk.coders.RowCoder; |
25 | 26 | import org.apache.beam.sdk.schemas.Schema; |
26 | 27 | import org.apache.beam.sdk.schemas.SchemaCoder; |
27 | 28 | import org.apache.beam.sdk.schemas.SchemaTranslation; |
28 | 29 | import org.apache.beam.sdk.transforms.SerializableFunction; |
| 30 | +import org.apache.beam.sdk.util.Preconditions; |
29 | 31 | import org.apache.beam.sdk.util.SerializableUtils; |
30 | 32 | import org.apache.beam.sdk.util.StringUtils; |
31 | 33 | import org.apache.beam.sdk.values.TypeDescriptor; |
@@ -100,16 +102,50 @@ public SchemaCoder fromCloudObject(CloudObject cloudObject) { |
100 | 102 | SchemaApi.Schema.Builder schemaBuilder = SchemaApi.Schema.newBuilder(); |
101 | 103 | JsonFormat.parser().merge(Structs.getString(cloudObject, SCHEMA), schemaBuilder); |
102 | 104 | Schema schema = SchemaTranslation.schemaFromProto(schemaBuilder.build()); |
103 | | - @Nullable UUID uuid = schema.getUUID(); |
104 | | - if (schema.isEncodingPositionsOverridden() && uuid != null) { |
105 | | - SchemaCoder.overrideEncodingPositions(uuid, schema.getEncodingPositions()); |
106 | | - } |
| 105 | + overrideEncodingPositions(schema); |
107 | 106 | return SchemaCoder.of(schema, typeDescriptor, toRowFunction, fromRowFunction); |
108 | 107 | } catch (IOException e) { |
109 | 108 | throw new RuntimeException(e); |
110 | 109 | } |
111 | 110 | } |
112 | 111 |
|
| 112 | + static void overrideEncodingPositions(Schema schema) { |
| 113 | + @Nullable UUID uuid = schema.getUUID(); |
| 114 | + if (schema.isEncodingPositionsOverridden() && uuid != null) { |
| 115 | + RowCoder.overrideEncodingPositions(uuid, schema.getEncodingPositions()); |
| 116 | + } |
| 117 | + schema.getFields().stream() |
| 118 | + .map(Schema.Field::getType) |
| 119 | + .forEach(SchemaCoderCloudObjectTranslator::overrideEncodingPositions); |
| 120 | + } |
| 121 | + |
| 122 | + private static void overrideEncodingPositions(Schema.FieldType fieldType) { |
| 123 | + switch (fieldType.getTypeName()) { |
| 124 | + case ROW: |
| 125 | + overrideEncodingPositions(Preconditions.checkArgumentNotNull(fieldType.getRowSchema())); |
| 126 | + break; |
| 127 | + case ARRAY: |
| 128 | + case ITERABLE: |
| 129 | + overrideEncodingPositions( |
| 130 | + Preconditions.checkArgumentNotNull(fieldType.getCollectionElementType())); |
| 131 | + break; |
| 132 | + case MAP: |
| 133 | + overrideEncodingPositions(Preconditions.checkArgumentNotNull(fieldType.getMapKeyType())); |
| 134 | + overrideEncodingPositions(Preconditions.checkArgumentNotNull(fieldType.getMapValueType())); |
| 135 | + break; |
| 136 | + case LOGICAL_TYPE: |
| 137 | + Schema.LogicalType logicalType = |
| 138 | + Preconditions.checkArgumentNotNull(fieldType.getLogicalType()); |
| 139 | + @Nullable Schema.FieldType argumentType = logicalType.getArgumentType(); |
| 140 | + if (argumentType != null) { |
| 141 | + overrideEncodingPositions(argumentType); |
| 142 | + } |
| 143 | + overrideEncodingPositions(logicalType.getBaseType()); |
| 144 | + break; |
| 145 | + default: |
| 146 | + } |
| 147 | + } |
| 148 | + |
113 | 149 | @Override |
114 | 150 | public Class<? extends SchemaCoder> getSupportedClass() { |
115 | 151 | return SchemaCoder.class; |
|
0 commit comments