diff --git a/CHANGES.md b/CHANGES.md index bdcbd3451c7b..adb7198537f2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -86,6 +86,7 @@ ## Bugfixes * Fixed BigQueryEnrichmentHandler batch mode dropping earlier requests when multiple requests share the same enrichment key (Python) ([#38035](https://github.com/apache/beam/issues/38035)). +* Added `max_batch_duration_secs` passthrough support in Python Enrichment BigQuery and CloudSQL handlers so batching duration can be forwarded to `BatchElements` ([#38243](https://github.com/apache/beam/issues/38243)). ## Security Fixes diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py index 2306d7d97a57..3881e49c769b 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -73,8 +73,9 @@ class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, list[Row]], names to fetch. This handler pulls data from BigQuery per element by default. To change this - behavior, set the `min_batch_size` and `max_batch_size` parameters. - These min and max values for batch size are sent to the + behavior, set the `min_batch_size`, `max_batch_size`, and + `max_batch_duration_secs` parameters. + These batching values are sent to the :class:`apache_beam.transforms.utils.BatchElements` transform. NOTE: Elements cannot be batched when using the `query_fn` parameter. @@ -91,6 +92,7 @@ def __init__( query_fn: Optional[QueryFn] = None, min_batch_size: int = 1, max_batch_size: int = 10000, + max_batch_duration_secs: Optional[float] = None, throw_exception_on_empty_results: bool = True, **kwargs, ): @@ -124,11 +126,14 @@ def __init__( querying BigQuery. Defaults to 1 if `query_fn` is not specified. max_batch_size (int): Maximum number of rows to batch together. Defaults to 10,000 if `query_fn` is not specified. 
+ max_batch_duration_secs (Optional[float]): Maximum time in seconds to + buffer a batch before emitting it. If None (the default), the value is + not forwarded and `BatchElements` applies its own default. **kwargs: Additional keyword arguments to pass to `bigquery.Client`. Note: - * `min_batch_size` and `max_batch_size` cannot be defined if the - `query_fn` is provided. + * `min_batch_size`, `max_batch_size`, and `max_batch_duration_secs` + are not used if `query_fn` is provided. * Either `fields` or `condition_value_fn` must be provided for query construction if `query_fn` is not provided. * Ensure appropriate permissions are granted for BigQuery access. @@ -156,6 +161,9 @@ def __init__( if not query_fn: self._batching_kwargs['min_batch_size'] = min_batch_size self._batching_kwargs['max_batch_size'] = max_batch_size + if max_batch_duration_secs is not None: + self._batching_kwargs[ + 'max_batch_duration_secs'] = max_batch_duration_secs def __enter__(self): self.client = bigquery.Client(project=self.project, **self.kwargs) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py index 98508baf6619..5b1c08b0e8ff 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py @@ -112,6 +112,41 @@ def test_batch_mode_emits_empty_rows_for_all_unmatched_duplicate_keys(self): [(requests[0], beam.Row()), (requests[1], beam.Row())], ) + def test_batch_elements_kwargs_include_max_batch_duration_secs(self): + handler = BigQueryEnrichmentHandler( + project=self.project, + table_name='project.dataset.table', + row_restriction_template="id='{}'", + fields=['id'], + min_batch_size=2, + max_batch_size=10, + max_batch_duration_secs=0.75, + ) + + self.assertEqual( + handler.batch_elements_kwargs(), + { + 'min_batch_size': 2, + 'max_batch_size': 10, + 'max_batch_duration_secs': 0.75, + }) + + def 
test_batch_elements_kwargs_omit_max_batch_duration_secs_by_default(self): + handler = BigQueryEnrichmentHandler( + project=self.project, + table_name='project.dataset.table', + row_restriction_template="id='{}'", + fields=['id'], + min_batch_size=2, + max_batch_size=10, + ) + + self.assertEqual( + handler.batch_elements_kwargs(), { + 'min_batch_size': 2, + 'max_batch_size': 10, + }) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py index ba0b8617f67b..d2ddd598209e 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py @@ -243,8 +243,9 @@ class CloudSQLEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]): the desired column names. This handler queries the Cloud SQL database per element by default. - To enable batching, set the `min_batch_size` and `max_batch_size` parameters. - These values control the batching behavior in the + To enable batching, set the `min_batch_size`, `max_batch_size`, and + `max_batch_duration_secs` parameters. These values control batching behavior + in the :class:`apache_beam.transforms.utils.BatchElements` transform. NOTE: Batching is not supported when using the CustomQueryConfig. @@ -257,6 +258,7 @@ def __init__( column_names: Optional[list[str]] = None, min_batch_size: int = 1, max_batch_size: int = 10000, + max_batch_duration_secs: Optional[float] = None, **kwargs, ): """ @@ -290,11 +292,15 @@ def __init__( querying the database. Defaults to 1 if `query_fn` is not used. max_batch_size (int): Maximum number of rows to batch together. Defaults to 10,000 if `query_fn` is not used. + max_batch_duration_secs (Optional[float]): Maximum time in seconds to + buffer a batch before emitting it. If None (the default), the value is + not forwarded and `BatchElements` applies its own default. 
**kwargs: Additional keyword arguments for database connection or query handling. Note: - * Cannot use `min_batch_size` or `max_batch_size` with `query_fn`. + * `min_batch_size`, `max_batch_size`, and `max_batch_duration_secs` + are not used with `query_fn`. * Either `where_clause_fields` or `where_clause_value_fn` must be provided for query construction if `query_fn` is not provided. * Ensure that the database user has the necessary permissions to query the @@ -313,6 +319,9 @@ def __init__( f"WHERE {query_config.where_clause_template}") self._batching_kwargs['min_batch_size'] = min_batch_size self._batching_kwargs['max_batch_size'] = max_batch_size + if max_batch_duration_secs is not None: + self._batching_kwargs[ + 'max_batch_duration_secs'] = max_batch_duration_secs def __enter__(self): connector = self._connection_config.get_connector_handler() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py index 98f1acfa53cf..2530a49714ec 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py @@ -213,6 +213,59 @@ def test_custom_query_config_cache_key_error(self): with self.assertRaises(NotImplementedError): handler.get_cache_key(request) + def test_batch_elements_kwargs_include_max_batch_duration_secs(self): + connection_config = ExternalSQLDBConnectionConfig( + db_adapter=DatabaseTypeAdapter.POSTGRESQL, + host='localhost', + port=5432, + user='user', + password='password', + db_id='db') + query_config = TableFieldsQueryConfig( + table_id='my_table', + where_clause_template='id = :id', + where_clause_fields=['id']) + + handler = CloudSQLEnrichmentHandler( + connection_config=connection_config, + query_config=query_config, + min_batch_size=2, + max_batch_size=10, + max_batch_duration_secs=0.5) + + self.assertEqual( + handler.batch_elements_kwargs(), + { + 
'min_batch_size': 2, + 'max_batch_size': 10, + 'max_batch_duration_secs': 0.5, + }) + + def test_batch_elements_kwargs_omit_max_batch_duration_secs_by_default(self): + connection_config = ExternalSQLDBConnectionConfig( + db_adapter=DatabaseTypeAdapter.POSTGRESQL, + host='localhost', + port=5432, + user='user', + password='password', + db_id='db') + query_config = TableFieldsQueryConfig( + table_id='my_table', + where_clause_template='id = :id', + where_clause_fields=['id']) + + handler = CloudSQLEnrichmentHandler( + connection_config=connection_config, + query_config=query_config, + min_batch_size=2, + max_batch_size=10) + + self.assertEqual( + handler.batch_elements_kwargs(), { + 'min_batch_size': 2, + 'max_batch_size': 10, + }) + def test_extract_parameter_names(self): """Test parameter extraction from SQL templates.""" connection_config = ExternalSQLDBConnectionConfig(