Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
## Bugfixes

* Fixed BigQueryEnrichmentHandler batch mode dropping earlier requests when multiple requests share the same enrichment key (Python) ([#38035](https://github.com/apache/beam/issues/38035)).
* Added `max_batch_duration_secs` passthrough support in Python Enrichment BigQuery and CloudSQL handlers so batching duration can be forwarded to `BatchElements` ([#38243](https://github.com/apache/beam/issues/38243)).

## Security Fixes

Expand Down
16 changes: 12 additions & 4 deletions sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, list[Row]],
names to fetch.

This handler pulls data from BigQuery per element by default. To change this
behavior, set the `min_batch_size` and `max_batch_size` parameters.
These min and max values for batch size are sent to the
behavior, set the `min_batch_size`, `max_batch_size`, and
`max_batch_duration_secs` parameters.
These batching values are sent to the
:class:`apache_beam.transforms.util.BatchElements` transform.

NOTE: Elements cannot be batched when using the `query_fn` parameter.
Expand All @@ -91,6 +92,7 @@ def __init__(
query_fn: Optional[QueryFn] = None,
min_batch_size: int = 1,
max_batch_size: int = 10000,
max_batch_duration_secs: Optional[float] = None,
throw_exception_on_empty_results: bool = True,
**kwargs,
):
Expand Down Expand Up @@ -124,11 +126,14 @@ def __init__(
querying BigQuery. Defaults to 1 if `query_fn` is not specified.
max_batch_size (int): Maximum number of rows to batch together.
Defaults to 10,000 if `query_fn` is not specified.
max_batch_duration_secs (float): Maximum amount of time in seconds to
buffer a batch before emitting it. If not provided, batching duration
is determined by `BatchElements` defaults.
**kwargs: Additional keyword arguments to pass to `bigquery.Client`.

Note:
* `min_batch_size` and `max_batch_size` cannot be defined if the
`query_fn` is provided.
* `min_batch_size`, `max_batch_size`, and `max_batch_duration_secs`
are not used if `query_fn` is provided.
* Either `fields` or `condition_value_fn` must be provided for query
construction if `query_fn` is not provided.
* Ensure appropriate permissions are granted for BigQuery access.
Expand Down Expand Up @@ -156,6 +161,9 @@ def __init__(
if not query_fn:
self._batching_kwargs['min_batch_size'] = min_batch_size
self._batching_kwargs['max_batch_size'] = max_batch_size
if max_batch_duration_secs is not None:
self._batching_kwargs[
'max_batch_duration_secs'] = max_batch_duration_secs

def __enter__(self):
self.client = bigquery.Client(project=self.project, **self.kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,41 @@ def test_batch_mode_emits_empty_rows_for_all_unmatched_duplicate_keys(self):
[(requests[0], beam.Row()), (requests[1], beam.Row())],
)

def test_batch_elements_kwargs_include_max_batch_duration_secs(self):
  """A handler built with max_batch_duration_secs must forward it, along
  with the min/max batch sizes, through batch_elements_kwargs()."""
  duration_secs = 0.75
  bq_handler = BigQueryEnrichmentHandler(
      project=self.project,
      table_name='project.dataset.table',
      row_restriction_template="id='{}'",
      fields=['id'],
      min_batch_size=2,
      max_batch_size=10,
      max_batch_duration_secs=duration_secs)

  expected_kwargs = {
      'min_batch_size': 2,
      'max_batch_size': 10,
      'max_batch_duration_secs': duration_secs,
  }
  self.assertEqual(expected_kwargs, bq_handler.batch_elements_kwargs())

def test_batch_elements_kwargs_omit_max_batch_duration_secs_by_default(self):
  """When max_batch_duration_secs is not supplied, batch_elements_kwargs()
  must contain only the min/max batch sizes — no duration key at all."""
  bq_handler = BigQueryEnrichmentHandler(
      project=self.project,
      table_name='project.dataset.table',
      row_restriction_template="id='{}'",
      fields=['id'],
      min_batch_size=2,
      max_batch_size=10)

  observed_kwargs = bq_handler.batch_elements_kwargs()
  self.assertEqual(
      observed_kwargs, {
          'min_batch_size': 2,
          'max_batch_size': 10,
      })


# Entry point: allows running this test module directly (e.g. `python <file>`).
if __name__ == '__main__':
  unittest.main()
15 changes: 12 additions & 3 deletions sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,9 @@ class CloudSQLEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]):
the desired column names.

This handler queries the Cloud SQL database per element by default.
To enable batching, set the `min_batch_size` and `max_batch_size` parameters.
These values control the batching behavior in the
To enable batching, set the `min_batch_size`, `max_batch_size`, and
`max_batch_duration_secs` parameters. These values control batching behavior
in the
:class:`apache_beam.transforms.util.BatchElements` transform.

NOTE: Batching is not supported when using the CustomQueryConfig.
Expand All @@ -257,6 +258,7 @@ def __init__(
column_names: Optional[list[str]] = None,
min_batch_size: int = 1,
max_batch_size: int = 10000,
max_batch_duration_secs: Optional[float] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -290,11 +292,15 @@ def __init__(
querying the database. Defaults to 1 if `query_fn` is not used.
max_batch_size (int): Maximum number of rows to batch together. Defaults
to 10,000 if `query_fn` is not used.
max_batch_duration_secs (float): Maximum amount of time in seconds to
buffer a batch before emitting it. If not provided, batching duration
is determined by `BatchElements` defaults.
**kwargs: Additional keyword arguments for database connection or query
handling.

Note:
* Cannot use `min_batch_size` or `max_batch_size` with `query_fn`.
* `min_batch_size`, `max_batch_size`, and `max_batch_duration_secs`
are not used with `query_fn`.
* Either `where_clause_fields` or `where_clause_value_fn` must be provided
for query construction if `query_fn` is not provided.
* Ensure that the database user has the necessary permissions to query the
Expand All @@ -313,6 +319,9 @@ def __init__(
f"WHERE {query_config.where_clause_template}")
self._batching_kwargs['min_batch_size'] = min_batch_size
self._batching_kwargs['max_batch_size'] = max_batch_size
if max_batch_duration_secs is not None:
self._batching_kwargs[
'max_batch_duration_secs'] = max_batch_duration_secs

def __enter__(self):
connector = self._connection_config.get_connector_handler()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,59 @@ def test_custom_query_config_cache_key_error(self):
with self.assertRaises(NotImplementedError):
handler.get_cache_key(request)

def test_batch_elements_kwargs_include_max_batch_duration_secs(self):
  """max_batch_duration_secs passed to the handler must be forwarded,
  together with min/max batch sizes, via batch_elements_kwargs()."""
  db_connection = ExternalSQLDBConnectionConfig(
      db_adapter=DatabaseTypeAdapter.POSTGRESQL,
      host='localhost',
      port=5432,
      user='user',
      password='password',
      db_id='db')
  table_query = TableFieldsQueryConfig(
      table_id='my_table',
      where_clause_template='id = :id',
      where_clause_fields=['id'])

  sql_handler = CloudSQLEnrichmentHandler(
      connection_config=db_connection,
      query_config=table_query,
      min_batch_size=2,
      max_batch_size=10,
      max_batch_duration_secs=0.5)

  expected_kwargs = {
      'min_batch_size': 2,
      'max_batch_size': 10,
      'max_batch_duration_secs': 0.5,
  }
  self.assertEqual(expected_kwargs, sql_handler.batch_elements_kwargs())

def test_batch_elements_kwargs_omit_max_batch_duration_secs_by_default(self):
  """Without max_batch_duration_secs, batch_elements_kwargs() must hold
  only the min/max batch sizes — the duration key must be absent."""
  db_connection = ExternalSQLDBConnectionConfig(
      db_adapter=DatabaseTypeAdapter.POSTGRESQL,
      host='localhost',
      port=5432,
      user='user',
      password='password',
      db_id='db')
  table_query = TableFieldsQueryConfig(
      table_id='my_table',
      where_clause_template='id = :id',
      where_clause_fields=['id'])

  sql_handler = CloudSQLEnrichmentHandler(
      connection_config=db_connection,
      query_config=table_query,
      min_batch_size=2,
      max_batch_size=10)

  observed_kwargs = sql_handler.batch_elements_kwargs()
  self.assertEqual(
      observed_kwargs, {
          'min_batch_size': 2,
          'max_batch_size': 10,
      })

def test_extract_parameter_names(self):
"""Test parameter extraction from SQL templates."""
connection_config = ExternalSQLDBConnectionConfig(
Expand Down
Loading