29 changes: 21 additions & 8 deletions dlt/common/data_writers/buffered.py
@@ -9,6 +9,7 @@
DestinationCapabilitiesRequired,
FileImportNotFound,
InvalidFileNameTemplateException,
SchemaEvolutionRequired,
)
from dlt.common.data_writers.writers import TWriter, DataWriter, FileWriterSpec, count_rows_in_items
from dlt.common.schema.typing import TTableSchemaColumns
@@ -239,25 +240,37 @@ def _flush_items(self, allow_empty_file: bool = False) -> None:
if self._buffered_items or allow_empty_file:
# we only open a writer when there are any items in the buffer and first flush is requested
if not self._writer:
self._open_writer()
# swap out buffer before writing so batch references are released
# as soon as write_data returns, without waiting for the next
# write_data_item call.
if self._buffered_items:
items = self._buffered_items
self._buffered_items = []
try:
self._writer.write_data(items)
except SchemaEvolutionRequired as sc:
# cross-batch schema widening - rotate to a new file.
# the items list was already cleared inside write_data;
# the materialized table is carried on the exception.
self._rotate_file()
self._open_writer()
table = sc.table.cast(sc.unified_schema)
self._writer.write_data([table])
items.clear()
else:
self._buffered_items_count = 0

def _open_writer(self) -> None:
"""Open the current file and create a writer."""
if self.writer_spec.is_binary_format:
self._file = self.open(self._file_name, "wb") # type: ignore
else:
self._file = self.open(self._file_name, "wt", encoding="utf-8", newline="")
self._writer = self.writer_cls(self._file, caps=self._caps) # type: ignore[assignment]
self._writer.write_header(self._current_columns)

def _flush_and_close_file(
self, allow_empty_file: bool = False, skip_flush: bool = False
) -> DataWriterMetrics:
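For orientation, the rotate-on-widening flow above can be reproduced outside dlt with plain pyarrow. This is a minimal sketch, not dlt code: write_with_rotation and path_template are made-up names, and the current Arrow schema is tracked in a local variable rather than in the writer's private state.

import pyarrow as pa
import pyarrow.parquet as pq

def write_with_rotation(tables, path_template="part-{}.parquet"):
    # rotate to a new parquet file whenever a batch widens the schema;
    # narrower batches are cast up and stay in the current file
    writer, current_schema, part = None, None, 0
    for table in tables:
        if writer is None:
            current_schema = table.schema
            writer = pq.ParquetWriter(path_template.format(part), current_schema)
        elif not table.schema.equals(current_schema, check_metadata=False):
            # requires a pyarrow version that supports promote_options
            unified = pa.unify_schemas(
                [current_schema.remove_metadata(), table.schema.remove_metadata()],
                promote_options="permissive",
            )
            if unified == current_schema.remove_metadata():
                # current schema already covers the incoming types: cast up
                table = table.cast(current_schema)
            else:
                # incoming batch is wider: close this file and start a new one
                writer.close()
                part += 1
                current_schema = unified
                writer = pq.ParquetWriter(path_template.format(part), current_schema)
                table = table.cast(current_schema)
        writer.write_table(table)
    if writer is not None:
        writer.close()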
19 changes: 18 additions & 1 deletion dlt/common/data_writers/exceptions.py
@@ -1,8 +1,11 @@
from typing import TYPE_CHECKING, Any, NamedTuple, Sequence

from dlt.common.destination import TLoaderFileFormat
from dlt.common.exceptions import DltException

if TYPE_CHECKING:
from dlt.common.libs.pyarrow import pyarrow


class DataWriterException(DltException):
pass
@@ -89,3 +92,17 @@ def __init__(self, file_format: TLoaderFileFormat, data_item_format: str, detail
f"A data item of type {data_item_format=:} cannot be written as `{file_format}:"
f" {details}`"
)


class SchemaEvolutionRequired(DataWriterException):
"""Cross-batch schema widened; signals file rotation."""

def __init__(
self, unified_schema: "pyarrow.Schema", table: "pyarrow.Table"
) -> None:
self.unified_schema = unified_schema
self.table = table
super().__init__(
"Schema evolved across flush batches, rotating to new file with"
" unified schema."
)
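How a caller consumes the exception payload, condensed from the buffered writer above; rotate_file and open_writer are stand-ins for the private helpers, not real dlt API:

try:
    writer.write_data(items)
except SchemaEvolutionRequired as sc:
    rotate_file()  # close the current file and pick a new name
    open_writer()  # re-open with a fresh parquet writer
    # cast the already-materialized batch to the unified schema and replay it
    writer.write_data([sc.table.cast(sc.unified_schema)])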
40 changes: 37 additions & 3 deletions dlt/common/data_writers/writers.py
@@ -25,6 +25,7 @@
FileFormatForItemFormatNotFound,
FileSpecNotFound,
InvalidDataItem,
SchemaEvolutionRequired,
)
from dlt.common.destination.configuration import (
CsvFormatConfiguration,
@@ -514,24 +515,57 @@ def write_data(self, items: Sequence[TDataItem]) -> None:

if not items:
return

promote_options = self.parquet_format.arrow_concat_promote_options

# concat batches and tables into a single one, preserving order
# pyarrow writer starts a row group for each item it writes (even with 0 rows)
# it also converts batches into tables internally. by creating a single table
# we allow the user rudimentary control over row group size via max buffered items
table = concat_batches_and_tables_in_order(
items, promote_options=promote_options
)
# release batch references - concat is zero-copy so table shares the
# underlying buffers via Arrow refcounting. clearing the input list
# drops the Python-level RecordBatch/Table references so only the
# concatenated table keeps the buffers alive
if isinstance(items, list):
items.clear()

if not self.writer:
self.writer = self._create_writer(table.schema)
elif (
promote_options != "none"
and not table.schema.equals(self.writer.schema, check_metadata=False)
):
# cross-batch schema mismatch: cast or rotate
table = self._reconcile_schema(table, promote_options)

self.writer.write_table(table, row_group_size=self.parquet_format.row_group_size)
# increment after successful write so metrics are correct when
# SchemaEvolutionRequired triggers file rotation mid-batch
self.items_count += table.num_rows

def _reconcile_schema(
self, table: "pa.Table", promote_options: str
) -> "pa.Table":
"""Reconcile table schema with writer schema across flush batches."""
from dlt.common.libs.pyarrow import pyarrow

writer_schema = self.writer.schema.remove_metadata()
table_schema = table.schema.remove_metadata()

# incompatible schemas (e.g. string vs int) raise ArrowInvalid/ArrowTypeError
unified = pyarrow.unify_schemas(
[writer_schema, table_schema],
promote_options=promote_options,
)

if unified == writer_schema:
# writer schema already covers incoming types - safe cast up
return table.cast(self.writer.schema)
# incoming has wider types - need a new file
raise SchemaEvolutionRequired(unified, table)

def write_footer(self) -> None:
if not self.writer:
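The cast-or-rotate decision in _reconcile_schema hinges entirely on pyarrow.unify_schemas. A standalone illustration of the two promote modes (schema names are arbitrary):

import pyarrow as pa

s_narrow = pa.schema([pa.field("val", pa.int32())])
s_wide = pa.schema([pa.field("val", pa.float64())])

# "permissive" widens numeric types across schemas
unified = pa.unify_schemas([s_narrow, s_wide], promote_options="permissive")
assert unified.field("val").type == pa.float64()

# "default" only promotes null fields; mismatched concrete types raise
try:
    pa.unify_schemas([s_narrow, s_wide], promote_options="default")
except (pa.ArrowInvalid, pa.ArrowTypeError):
    pass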
161 changes: 161 additions & 0 deletions tests/libs/test_parquet_writer.py
@@ -489,3 +489,164 @@ def test_empty_tables_get_flushed() -> None:
assert len(writer._buffered_items) == 1
writer.write_data_item(single_elem_table, columns=c1)
assert len(writer._buffered_items) == 0


def test_cross_batch_schema_promotion_narrower_cast() -> None:
c1 = {"val": new_column("val", "double")}

with get_writer(
ArrowToParquetWriter,
buffer_max_items=1,
caps=_caps_with_promote("permissive"),
) as writer:
t1 = pa.Table.from_pydict({"val": pa.array([1.5], type=pa.float64())})
writer.write_data_item(t1, columns=c1)
writer._flush_items()

t2 = pa.Table.from_pydict({"val": pa.array([2.5], type=pa.float32())})
writer.write_data_item(t2, columns=c1)

assert len(writer.closed_files) == 1
table = pq.read_table(writer.closed_files[0].file_path)
assert table.schema.field("val").type == pa.float64()
assert table.column("val").to_pylist() == [1.5, 2.5]


def test_cross_batch_schema_promotion_wider_rotates() -> None:
c1 = {"val": new_column("val", "double")}

with get_writer(
ArrowToParquetWriter,
buffer_max_items=1,
caps=_caps_with_promote("permissive"),
) as writer:
t1 = pa.Table.from_pydict({"val": pa.array([1.5], type=pa.float32())})
writer.write_data_item(t1, columns=c1)
writer._flush_items()

t2 = pa.Table.from_pydict({"val": pa.array([2.5], type=pa.float64())})
writer.write_data_item(t2, columns=c1)

assert len(writer.closed_files) == 2
t_file1 = pq.read_table(writer.closed_files[0].file_path)
assert t_file1.schema.field("val").type == pa.float32()
t_file2 = pq.read_table(writer.closed_files[1].file_path)
assert t_file2.schema.field("val").type == pa.float64()


def test_cross_batch_schema_promotion_int_widening() -> None:
c1 = {"val": new_column("val", "bigint")}

with get_writer(
ArrowToParquetWriter,
buffer_max_items=1,
caps=_caps_with_promote("permissive"),
) as writer:
t1 = pa.Table.from_pydict({"val": pa.array([1000], type=pa.int32())})
writer.write_data_item(t1, columns=c1)
writer._flush_items()

t2 = pa.Table.from_pydict({"val": pa.array([1], type=pa.int8())})
writer.write_data_item(t2, columns=c1)

assert len(writer.closed_files) == 1
table = pq.read_table(writer.closed_files[0].file_path)
assert table.schema.field("val").type == pa.int32()
assert table.column("val").to_pylist() == [1000, 1]


def test_cross_batch_schema_mixed_wider_and_narrower() -> None:
c1 = {
"a": new_column("a", "double"),
"b": new_column("b", "bigint"),
}

with get_writer(
ArrowToParquetWriter,
buffer_max_items=1,
caps=_caps_with_promote("permissive"),
) as writer:
t1 = pa.Table.from_pydict({
"a": pa.array([1.5], type=pa.float64()),
"b": pa.array([1], type=pa.int8()),
})
writer.write_data_item(t1, columns=c1)
writer._flush_items()

t2 = pa.Table.from_pydict({
"a": pa.array([2.5], type=pa.float32()),
"b": pa.array([1000], type=pa.int32()),
})
writer.write_data_item(t2, columns=c1)

assert len(writer.closed_files) == 2
t_file1 = pq.read_table(writer.closed_files[0].file_path)
assert t_file1.schema.field("a").type == pa.float64()
assert t_file1.schema.field("b").type == pa.int8()
t_file2 = pq.read_table(writer.closed_files[1].file_path)
assert t_file2.schema.field("a").type == pa.float64()
assert t_file2.schema.field("b").type == pa.int32()


def test_cross_batch_schema_metadata_only_diff_no_rotation() -> None:
c1 = {"val": new_column("val", "double")}

with get_writer(
ArrowToParquetWriter,
buffer_max_items=1,
caps=_caps_with_promote("permissive"),
) as writer:
schema1 = pa.schema([
pa.field("val", pa.float64(), metadata={b"source": b"file1"})
])
t1 = pa.Table.from_pydict(
{"val": pa.array([1.5], type=pa.float64())}, schema=schema1
)
writer.write_data_item(t1, columns=c1)
writer._flush_items()

schema2 = pa.schema([
pa.field("val", pa.float64(), metadata={b"source": b"file2"})
])
t2 = pa.Table.from_pydict(
{"val": pa.array([2.5], type=pa.float64())}, schema=schema2
)
writer.write_data_item(t2, columns=c1)

assert len(writer.closed_files) == 1
table = pq.read_table(writer.closed_files[0].file_path)
assert table.column("val").to_pylist() == [1.5, 2.5]


def test_cross_batch_schema_no_promotion_when_none() -> None:
c1 = {"val": new_column("val", "double")}

with get_writer(
ArrowToParquetWriter,
buffer_max_items=1,
caps=_caps_with_promote("none"),
) as writer:
t1 = pa.Table.from_pydict({"val": pa.array([1.5], type=pa.float64())})
writer.write_data_item(t1, columns=c1)
writer._flush_items()

t2 = pa.Table.from_pydict({"val": pa.array([2.5], type=pa.float32())})
with pytest.raises((pa.lib.ArrowInvalid, ValueError)):
writer.write_data_item(t2, columns=c1)


def test_cross_batch_schema_incompatible_types() -> None:
c1 = {"val": new_column("val", "text")}

with pytest.raises((pa.lib.ArrowInvalid, pa.lib.ArrowTypeError)):
with get_writer(
ArrowToParquetWriter,
buffer_max_items=1,
caps=_caps_with_promote("permissive"),
) as writer:
t1 = pa.Table.from_pydict({"val": pa.array(["hello"], type=pa.string())})
writer.write_data_item(t1, columns=c1)
writer._flush_items()

t2 = pa.Table.from_pydict({"val": pa.array([1.5], type=pa.float64())})
writer.write_data_item(t2, columns=c1)
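Note that _caps_with_promote is a test helper defined elsewhere in this PR. The cast-vs-rotate asymmetry the first two tests exercise follows directly from unify_schemas ordering; a pure-pyarrow restatement, independent of the writer machinery:

import pyarrow as pa

f32 = pa.schema([pa.field("val", pa.float32())])
f64 = pa.schema([pa.field("val", pa.float64())])

# wider writer schema, narrower batch: unified == writer schema -> cast, same file
assert pa.unify_schemas([f64, f32], promote_options="permissive") == f64

# narrower writer schema, wider batch: unified != writer schema -> rotate
assert pa.unify_schemas([f32, f64], promote_options="permissive") == f64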