diff --git a/dlt/common/destination/client.py b/dlt/common/destination/client.py index dda81c0711..5fb0d319d9 100644 --- a/dlt/common/destination/client.py +++ b/dlt/common/destination/client.py @@ -445,27 +445,28 @@ def run_managed( self._job_client = job_client self._done_event = done_event - # filepath is now moved to running + # Track terminal state locally; publish to self._state only after + # _on_completed() finishes to prevent the race in #3849. + terminal_state: TLoadJobState = "running" try: self._state = "running" self._job_client.prepare_load_job_execution(self) self.run() - self._state = "completed" + terminal_state = "completed" except (TerminalException, AssertionError) as e: - self._state = "failed" + terminal_state = "failed" self._exception = e logger.exception(f"Terminal exception in job {self.job_id()} in file {self._file_path}") except (DestinationTransientException, Exception) as e: - self._state = "retry" + terminal_state = "retry" self._exception = e logger.exception( f"Transient exception in job {self.job_id()} in file {self._file_path}" ) finally: - # sanity check - assert self._state in ("completed", "retry", "failed") - if self._state != "retry": - # persist terminal state so resume can skip re-execution + assert terminal_state in ("completed", "retry", "failed") + on_completed_exc = None + if terminal_state != "retry": if self._on_completed: if self._exception: failed_message = "".join( @@ -477,12 +478,21 @@ def run_managed( ) else: failed_message = None - self._on_completed(self._state, failed_message) + try: + self._on_completed(terminal_state, failed_message) + except Exception as exc: + terminal_state = "failed" + if self._exception is None: + self._exception = exc + on_completed_exc = exc self._finished_at = pendulum.now() - # wake up waiting threads - if self._done_event: - with contextlib.suppress(ValueError): - self._done_event.release() + # Publish only after callback and timestamp are done. + self._state = terminal_state + if terminal_state != "retry" and self._done_event: + with contextlib.suppress(ValueError): + self._done_event.release() + if on_completed_exc is not None: + raise on_completed_exc @abstractmethod def run(self) -> None: diff --git a/tests/load/test_jobs.py b/tests/load/test_jobs.py index f6f06a52a1..5862a27931 100644 --- a/tests/load/test_jobs.py +++ b/tests/load/test_jobs.py @@ -1,3 +1,4 @@ +import threading from typing import List, Optional, Tuple import pytest @@ -182,6 +183,10 @@ def bad_callback(state: TLoadJobState, msg: Optional[str]) -> None: with pytest.raises(OSError, match="disk full"): j.run_managed(MockClient(), None) # type: ignore + # job must still reach a terminal state + assert j.state() == "failed" + assert j._finished_at is not None + def test_on_completed_fires_before_semaphore_release() -> None: """_on_completed fires before _finished_at is set and before the @@ -214,6 +219,134 @@ def run(self) -> None: assert finished_at_during_callback[0] is None +def test_state_not_published_before_on_completed_finishes() -> None: + """Regression test for #3849: state() must not return a terminal value + while _on_completed is still executing.""" + file_path = "/table.1234.0.jsonl" + barrier = threading.Barrier(2, timeout=5) + state_during_callback: List[TLoadJobState] = [] + + class SuccessfulJob(RunnableLoadJob): + def run(self) -> None: + pass + + j = SuccessfulJob(file_path) + + def blocking_callback(state: TLoadJobState, msg: Optional[str]) -> None: + barrier.wait() # sync 1: signal that callback is running + barrier.wait() # sync 2: wait for observer to finish checking + + j.set_run_vars( + load_id="1", + schema=None, + load_table=None, + on_completed=blocking_callback, + ) + + def worker() -> None: + j.run_managed(MockClient(), None) # type: ignore + + t = threading.Thread(target=worker, daemon=True) + t.start() + + barrier.wait() # sync 1: callback is executing + state_during_callback.append(j.state()) + barrier.wait() # sync 2: let callback finish + + t.join(timeout=5) + assert not t.is_alive() + + # while callback was running, state must still be "running" + assert state_during_callback[0] == "running" + # after run_managed returns, state is terminal + assert j.state() == "completed" + + +def test_on_completed_exception_sets_terminal_state() -> None: + """When _on_completed raises, the job must reach 'failed' with + _finished_at set so execution does not halt.""" + file_path = "/table.1234.0.jsonl" + + class SuccessfulJob(RunnableLoadJob): + def run(self) -> None: + pass + + def bad_callback(state: TLoadJobState, msg: Optional[str]) -> None: + raise OSError(".dlt folder unavailable") + + j = SuccessfulJob(file_path) + j.set_run_vars( + load_id="1", + schema=None, + load_table=None, + on_completed=bad_callback, + ) + with pytest.raises(OSError, match=".dlt folder unavailable"): + j.run_managed(MockClient(), None) # type: ignore + + assert j.state() == "failed" + assert j._finished_at is not None + assert isinstance(j.exception(), OSError) + + +def test_on_completed_exception_releases_semaphore() -> None: + """Semaphore must be released even when _on_completed raises, + otherwise the loader hangs.""" + from threading import BoundedSemaphore + + file_path = "/table.1234.0.jsonl" + + class SuccessfulJob(RunnableLoadJob): + def run(self) -> None: + pass + + def bad_callback(state: TLoadJobState, msg: Optional[str]) -> None: + raise OSError("disk full") + + sem = BoundedSemaphore() + sem.acquire() + + j = SuccessfulJob(file_path) + j.set_run_vars( + load_id="1", + schema=None, + load_table=None, + on_completed=bad_callback, + ) + with pytest.raises(OSError): + j.run_managed(MockClient(), sem) # type: ignore + + # semaphore must have been released despite the exception + assert sem.acquire(blocking=False), "semaphore was not released" + + +def test_on_completed_exception_preserves_terminal_job_exception() -> None: + """Keep the destination failure on the job when both run() and _on_completed fail.""" + file_path = "/table.1234.0.jsonl" + + class TerminalJob(RunnableLoadJob): + def run(self) -> None: + raise DestinationTerminalException("destination broke") + + def bad_callback(state: TLoadJobState, msg: Optional[str]) -> None: + raise OSError("state persistence broke") + + j = TerminalJob(file_path) + j.set_run_vars( + load_id="1", + schema=None, + load_table=None, + on_completed=bad_callback, + ) + with pytest.raises(OSError, match="state persistence broke"): + j.run_managed(MockClient(), None) # type: ignore + + assert j.state() == "failed" + assert j._finished_at is not None + assert isinstance(j.exception(), DestinationTerminalException) + assert str(j.exception()) == "destination broke" + + def test_set_final_state_completed() -> None: """set_final_state puts a job into completed state.""" file_path = "/table.1234.0.jsonl"