Skip to content

Commit 107b18f

Browse files
committed
manual state change should not use fork-execute model on scheduler
Signed-off-by: Maciej Obuchowski <maciej.obuchowski@datadoghq.com>
1 parent 578ab8e commit 107b18f

2 files changed

Lines changed: 152 additions & 31 deletions

File tree

providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py

Lines changed: 77 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,19 @@ def _executor_initializer():
8181
settings.configure_orm()
8282

8383

84+
def _emit_manual_state_change_event(adapter_method, stats_key, **kwargs):
85+
"""
86+
Emit an OL event via the given adapter method and record its serialized size.
87+
88+
Defined at module level so it is picklable across the ProcessPoolExecutor boundary used by
89+
`_on_task_instance_manual_state_change` for scheduler-side "task state changed
90+
externally" emissions.
91+
"""
92+
event = adapter_method(**kwargs)
93+
Stats.gauge(stats_key, len(Serde.to_json(event).encode("utf-8")))
94+
return event
95+
96+
8497
class OpenLineageListener:
8598
"""OpenLineage listener sends events on task instance and dag run starts, completes and failures."""
8699

@@ -653,6 +666,17 @@ def _on_task_instance_manual_state_change(
653666
ti_state: TaskInstanceState,
654667
error: None | str | BaseException = None,
655668
) -> None:
669+
"""
670+
Emit an OL event from the scheduler when a TI transitions externally.
671+
672+
This path is only reached on the scheduler (``process_executor_events ->
673+
handle_failure``, or manual UI/API state changes). Emission is routed through
674+
the same ``ProcessPoolExecutor`` the DAG-run listeners use rather than through
675+
``_fork_execute``: the pool's ``_executor_initializer`` rebuilds the ORM once
676+
per worker, so the child never shares a pooled Postgres SSL connection with
677+
the scheduler, and bursts of external-state-change events no longer produce a
678+
fork per event.
679+
"""
656680
self.log.debug("`_on_task_instance_manual_state_change` was called with state: `%s`.", ti_state)
657681
end_date = timezone.utcnow()
658682

@@ -674,22 +698,37 @@ def _on_task_instance_manual_state_change(
674698
)
675699
return
676700

677-
@print_warning(self.log)
678-
def on_state_change():
679-
date = dagrun.logical_date or dagrun.run_after
680-
parent_run_id = self.adapter.build_dag_run_id(
681-
dag_id=ti.dag_id,
682-
logical_date=date,
683-
clear_number=dagrun.clear_number,
684-
)
701+
try:
702+
if not self.executor:
703+
self.log.debug("Executor has not started before `_on_task_instance_manual_state_change`")
704+
return
705+
706+
if ti_state == TaskInstanceState.FAILED:
707+
adapter_method = self.adapter.fail_task
708+
event_type = RunState.FAIL.value.lower()
709+
elif ti_state in (TaskInstanceState.SUCCESS, TaskInstanceState.SKIPPED):
710+
adapter_method = self.adapter.complete_task
711+
event_type = RunState.COMPLETE.value.lower()
712+
else:
713+
raise ValueError(f"Unsupported ti_state: `{ti_state}`.")
685714

715+
# Extract primitives from live ORM objects in the parent (scheduler)
716+
# before crossing the pool boundary. Passing ORM objects through the pool
717+
# pickler loses TaskGroup attributes and crashes event emission -- see
718+
# the equivalent note in `on_dag_run_running` (listener.py ~868).
719+
date = dagrun.logical_date or dagrun.run_after
686720
task_uuid = self.adapter.build_task_instance_run_id(
687721
dag_id=ti.dag_id,
688722
task_id=ti.task_id,
689723
try_number=ti.try_number,
690724
logical_date=date,
691725
map_index=ti.map_index,
692726
)
727+
parent_run_id = self.adapter.build_dag_run_id(
728+
dag_id=ti.dag_id,
729+
logical_date=date,
730+
clear_number=dagrun.clear_number,
731+
)
693732

694733
data_interval_start = dagrun.data_interval_start
695734
if isinstance(data_interval_start, datetime):
@@ -698,21 +737,22 @@ def on_state_change():
698737
if isinstance(data_interval_end, datetime):
699738
data_interval_end = data_interval_end.isoformat()
700739

701-
dag_tags, owners, doc, doc_type = None, None, None, None
702-
airflow_run_facet = {}
740+
dag_tags: list | None = None
741+
owners: list[str] | None = None
742+
doc: str | None = None
743+
doc_type: str | None = None
744+
airflow_run_facet: dict = {}
703745
if task: # on scheduler, we should have access to task
704746
doc, doc_type = get_task_documentation(task)
705747
dag = getattr(task, "dag")
706748
if dag:
707749
if not doc:
708750
doc, doc_type = get_dag_documentation(dag)
709-
710751
dag_tags = dag.tags
711752
owners = [x.strip() for x in (task if task.owner != "airflow" else dag).owner.split(",")]
712-
713753
airflow_run_facet = get_airflow_run_facet(dagrun, dag, ti, task, task_uuid)
714754

715-
adapter_kwargs = {
755+
adapter_kwargs: dict = {
716756
"run_id": task_uuid,
717757
"job_name": get_job_name(ti),
718758
"end_time": end_date.isoformat(),
@@ -733,23 +773,20 @@ def on_state_change():
733773
**get_airflow_debug_facet(),
734774
},
735775
}
736-
737776
if ti_state == TaskInstanceState.FAILED:
738-
event_type = RunState.FAIL.value.lower()
739-
redacted_event = self.adapter.fail_task(**adapter_kwargs, error=error)
740-
elif ti_state in (TaskInstanceState.SUCCESS, TaskInstanceState.SKIPPED):
741-
event_type = RunState.COMPLETE.value.lower()
742-
redacted_event = self.adapter.complete_task(**adapter_kwargs)
743-
else:
744-
raise ValueError(f"Unsupported ti_state: `{ti_state}`.")
777+
adapter_kwargs["error"] = error
745778

746-
operator_name = ti.operator.lower()
747-
Stats.gauge(
748-
f"ol.event.size.{event_type}.{operator_name}",
749-
len(Serde.to_json(redacted_event).encode("utf-8")),
779+
self.submit_callable(
780+
_emit_manual_state_change_event,
781+
adapter_method,
782+
f"ol.event.size.{event_type}.{ti.operator.lower()}",
783+
**adapter_kwargs,
784+
)
785+
except BaseException as e:
786+
self.log.warning(
787+
"OpenLineage received exception in method `_on_task_instance_manual_state_change`",
788+
exc_info=e,
750789
)
751-
752-
self._execute(on_state_change, "on_state_change", use_fork=True)
753790

754791
def _execute(self, callable, callable_name: str, use_fork: bool = False):
755792
if use_fork:
@@ -787,7 +824,19 @@ def _fork_execute(self, callable, callable_name: str):
787824
self.log.debug("Process with pid %s finished - parent", pid)
788825
else:
789826
setproctitle(getproctitle() + " - OpenLineage - " + callable_name)
790-
if not AIRFLOW_V_3_0_PLUS:
827+
# Rebuild the ORM in the forked child so it does not share pooled
828+
# Postgres connections (and in-flight SSL session state) with the
829+
# parent. Without this, the child and parent can both write on an
830+
# inherited SSL socket, desynchronising the TLS sequence counter
831+
# and producing "SSL error: decryption failed or bad record mac"
832+
# on the parent's next query (see #47580 and the DB-access crash
833+
# it introduced: on AF3+ workers the Task SDK sets SQL_ALCHEMY_CONN
834+
# to "airflow-db-not-allowed:///", which would make configure_orm
835+
# fail). settings.engine is only populated in processes where the
836+
# ORM was successfully configured -- i.e., the scheduler on AF3+
837+
# and both scheduler and workers on AF2 -- so gating on it gives
838+
# us fork-safety where we can have it and a no-op where we can't.
839+
if settings.engine is not None:
791840
configure_orm(disable_connection_pool=True)
792841
self.log.debug("Executing OpenLineage process - %s - pid %s", callable_name, os.getpid())
793842
callable()

providers/openlineage/tests/unit/openlineage/plugins/test_listener.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,23 @@ def regular_call(self, callable, callable_name, use_fork):
8585
callable()
8686

8787

88+
def direct_submit_call(self, callable, *args, **kwargs):
89+
"""Synchronous stand-in for ``OpenLineageListener.submit_callable``.
90+
91+
Bypasses the ``ProcessPoolExecutor`` so tests can assert against mocked
92+
adapter methods without hitting pickling of ``unittest.mock.Mock``.
93+
When the submitted callable is ``_emit_manual_state_change_event``, skip
94+
its ``Stats.gauge`` side effect (which would try to ``Serde.to_json`` a
95+
``MagicMock`` return value) and invoke the adapter method directly.
96+
"""
97+
from airflow.providers.openlineage.plugins.listener import _emit_manual_state_change_event
98+
99+
if callable is _emit_manual_state_change_event:
100+
adapter_method, _stats_key, *_ = args
101+
return adapter_method(**kwargs)
102+
return callable(*args, **kwargs)
103+
104+
88105
class MockExecutor:
89106
def __init__(self, *args, **kwargs):
90107
self.submitted = False
@@ -1463,7 +1480,8 @@ def test_adapter_fail_task_is_called_with_dag_description_when_task_doc_is_empty
14631480
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
14641481
@mock.patch("airflow.providers.openlineage.plugins.listener.get_task_parent_run_facet")
14651482
@mock.patch(
1466-
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call
1483+
"airflow.providers.openlineage.plugins.listener.OpenLineageListener.submit_callable",
1484+
new=direct_submit_call,
14671485
)
14681486
def test_adapter_fail_task_is_called_with_proper_arguments_for_db_task_instance_model(
14691487
self,
@@ -1482,6 +1500,7 @@ def test_adapter_fail_task_is_called_with_proper_arguments_for_db_task_instance_
14821500
time_machine.move_to(timezone.datetime(2023, 1, 3, 13, 1, 1), tick=False)
14831501

14841502
listener, task_instance = self._create_listener_and_task_instance(runtime_ti=False)
1503+
listener._executor = MagicMock() # satisfy `if not self.executor` guard
14851504
mock_get_airflow_run_facet.return_value = {"airflow": 3}
14861505
mock_get_task_parent_run_facet.return_value = {"parent": 4}
14871506
mock_debug_facet.return_value = {"debug": "packages"}
@@ -1649,7 +1668,8 @@ def test_adapter_complete_task_is_called_with_dag_description_when_task_doc_is_e
16491668
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet")
16501669
@mock.patch("airflow.providers.openlineage.plugins.listener.get_task_parent_run_facet")
16511670
@mock.patch(
1652-
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call
1671+
"airflow.providers.openlineage.plugins.listener.OpenLineageListener.submit_callable",
1672+
new=direct_submit_call,
16531673
)
16541674
def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_instance_model(
16551675
self, mock_get_task_parent_run_facet, mock_debug_facet, mock_debug_mode, mock_emit, time_machine
@@ -1662,6 +1682,7 @@ def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_insta
16621682
time_machine.move_to(timezone.datetime(2023, 1, 3, 13, 1, 1), tick=False)
16631683

16641684
listener, task_instance = self._create_listener_and_task_instance(runtime_ti=False)
1685+
listener._executor = MagicMock() # satisfy `if not self.executor` guard
16651686
delattr(task_instance, "task") # Test api server path, where task is not available
16661687
mock_get_task_parent_run_facet.return_value = {"parent": 4}
16671688
mock_debug_facet.return_value = {"debug": "packages"}
@@ -1856,7 +1877,8 @@ def test_listener_on_task_instance_skipped_do_not_call_adapter_when_disabled_ope
18561877
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet")
18571878
@mock.patch("airflow.providers.openlineage.plugins.listener.get_task_parent_run_facet")
18581879
@mock.patch(
1859-
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call
1880+
"airflow.providers.openlineage.plugins.listener.OpenLineageListener.submit_callable",
1881+
new=direct_submit_call,
18601882
)
18611883
def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_instance_model_on_skip(
18621884
self, mock_get_task_parent_run_facet, mock_debug_facet, mock_debug_mode, mock_emit, time_machine
@@ -1869,6 +1891,7 @@ def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_insta
18691891
time_machine.move_to(timezone.datetime(2023, 1, 3, 13, 1, 1), tick=False)
18701892

18711893
listener, task_instance = self._create_listener_and_task_instance(runtime_ti=False)
1894+
listener._executor = MagicMock() # satisfy `if not self.executor` guard
18721895
delattr(task_instance, "task") # Test api server path, where task is not available
18731896
mock_get_task_parent_run_facet.return_value = {"parent": 4}
18741897
mock_debug_facet.return_value = {"debug": "packages"}
@@ -1980,6 +2003,55 @@ def set_result(*args, **kwargs):
19802003
listener.log.warning.assert_called_once()
19812004

19822005

2006+
class TestOpenLineageListenerForkExecute:
2007+
"""Regression tests for `OpenLineageListener._fork_execute`.
2008+
2009+
On processes where the ORM is configured (scheduler; workers on AF2), the
2010+
forked child must rebuild the engine so it does not share pooled Postgres
2011+
connections with the parent -- otherwise an inherited SSL socket gets
2012+
written by both processes and the parent's next query dies with
2013+
``SSL error: decryption failed or bad record mac``.
2014+
2015+
On AF3+ workers the Task SDK sets SQL_ALCHEMY_CONN to
2016+
``airflow-db-not-allowed:///``; ``configure_orm`` would raise there, so
2017+
the child must skip the rebuild when ``settings.engine`` is ``None``.
2018+
"""
2019+
2020+
@staticmethod
2021+
def _run_child(engine_value):
2022+
listener = OpenLineageListener()
2023+
called = MagicMock()
2024+
with (
2025+
patch("airflow.providers.openlineage.plugins.listener.os.fork", return_value=0),
2026+
patch("airflow.providers.openlineage.plugins.listener.os._exit") as mock_exit,
2027+
patch("airflow.providers.openlineage.plugins.listener.configure_orm") as mock_configure_orm,
2028+
patch("airflow.providers.openlineage.plugins.listener.setproctitle"),
2029+
patch(
2030+
"airflow.providers.openlineage.plugins.listener.getproctitle",
2031+
return_value="test",
2032+
),
2033+
patch("airflow.providers.openlineage.plugins.listener.settings") as mock_settings,
2034+
):
2035+
mock_settings.engine = engine_value
2036+
listener._fork_execute(called, "on_failure")
2037+
return mock_configure_orm, called, mock_exit
2038+
2039+
def test_child_rebuilds_orm_when_engine_is_configured(self):
2040+
mock_configure_orm, called, mock_exit = self._run_child(engine_value=MagicMock())
2041+
mock_configure_orm.assert_called_once_with(disable_connection_pool=True)
2042+
called.assert_called_once()
2043+
mock_exit.assert_called_once_with(0)
2044+
2045+
def test_child_skips_orm_rebuild_when_engine_is_none(self):
2046+
# On AF3+ workers the metadata DB is intentionally unreachable and
2047+
# configure_orm would raise on the sentinel URL. The callable must
2048+
# still run and the child must still exit cleanly.
2049+
mock_configure_orm, called, mock_exit = self._run_child(engine_value=None)
2050+
mock_configure_orm.assert_not_called()
2051+
called.assert_called_once()
2052+
mock_exit.assert_called_once_with(0)
2053+
2054+
19832055
@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Airflow 2 tests")
19842056
class TestOpenLineageSelectiveEnableAirflow2:
19852057
def setup_method(self):

0 commit comments

Comments
 (0)