Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions dlt/common/destination/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,27 +445,28 @@ def run_managed(
self._job_client = job_client
self._done_event = done_event

# filepath is now moved to running
# Track terminal state locally; publish to self._state only after
# _on_completed() finishes to prevent the race in #3849.
terminal_state: TLoadJobState = "running"
try:
self._state = "running"
self._job_client.prepare_load_job_execution(self)
self.run()
self._state = "completed"
terminal_state = "completed"
except (TerminalException, AssertionError) as e:
self._state = "failed"
terminal_state = "failed"
self._exception = e
logger.exception(f"Terminal exception in job {self.job_id()} in file {self._file_path}")
except (DestinationTransientException, Exception) as e:
self._state = "retry"
terminal_state = "retry"
self._exception = e
logger.exception(
f"Transient exception in job {self.job_id()} in file {self._file_path}"
)
finally:
# sanity check
assert self._state in ("completed", "retry", "failed")
if self._state != "retry":
# persist terminal state so resume can skip re-execution
assert terminal_state in ("completed", "retry", "failed")
on_completed_exc = None
if terminal_state != "retry":
if self._on_completed:
if self._exception:
failed_message = "".join(
Expand All @@ -477,12 +478,21 @@ def run_managed(
)
else:
failed_message = None
self._on_completed(self._state, failed_message)
try:
self._on_completed(terminal_state, failed_message)
except Exception as exc:
terminal_state = "failed"
if self._exception is None:
self._exception = exc
on_completed_exc = exc
self._finished_at = pendulum.now()
# wake up waiting threads
if self._done_event:
with contextlib.suppress(ValueError):
self._done_event.release()
# Publish only after callback and timestamp are done.
self._state = terminal_state
if terminal_state != "retry" and self._done_event:
with contextlib.suppress(ValueError):
self._done_event.release()
if on_completed_exc is not None:
raise on_completed_exc

@abstractmethod
def run(self) -> None:
Expand Down
133 changes: 133 additions & 0 deletions tests/load/test_jobs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import threading
from typing import List, Optional, Tuple

import pytest
Expand Down Expand Up @@ -182,6 +183,10 @@ def bad_callback(state: TLoadJobState, msg: Optional[str]) -> None:
with pytest.raises(OSError, match="disk full"):
j.run_managed(MockClient(), None) # type: ignore

# job must still reach a terminal state
assert j.state() == "failed"
assert j._finished_at is not None


def test_on_completed_fires_before_semaphore_release() -> None:
"""_on_completed fires before _finished_at is set and before the
Expand Down Expand Up @@ -214,6 +219,134 @@ def run(self) -> None:
assert finished_at_during_callback[0] is None


def test_state_not_published_before_on_completed_finishes() -> None:
"""Regression test for #3849: state() must not return a terminal value
while _on_completed is still executing."""
file_path = "/table.1234.0.jsonl"
barrier = threading.Barrier(2, timeout=5)
state_during_callback: List[TLoadJobState] = []

class SuccessfulJob(RunnableLoadJob):
def run(self) -> None:
pass

j = SuccessfulJob(file_path)

def blocking_callback(state: TLoadJobState, msg: Optional[str]) -> None:
barrier.wait() # sync 1: signal that callback is running
barrier.wait() # sync 2: wait for observer to finish checking

j.set_run_vars(
load_id="1",
schema=None,
load_table=None,
on_completed=blocking_callback,
)

def worker() -> None:
j.run_managed(MockClient(), None) # type: ignore

t = threading.Thread(target=worker, daemon=True)
t.start()

barrier.wait() # sync 1: callback is executing
state_during_callback.append(j.state())
barrier.wait() # sync 2: let callback finish

t.join(timeout=5)
assert not t.is_alive()

# while callback was running, state must still be "running"
assert state_during_callback[0] == "running"
# after run_managed returns, state is terminal
assert j.state() == "completed"


def test_on_completed_exception_sets_terminal_state() -> None:
"""When _on_completed raises, the job must reach 'failed' with
_finished_at set so execution does not halt."""
file_path = "/table.1234.0.jsonl"

class SuccessfulJob(RunnableLoadJob):
def run(self) -> None:
pass

def bad_callback(state: TLoadJobState, msg: Optional[str]) -> None:
raise OSError(".dlt folder unavailable")

j = SuccessfulJob(file_path)
j.set_run_vars(
load_id="1",
schema=None,
load_table=None,
on_completed=bad_callback,
)
with pytest.raises(OSError, match=".dlt folder unavailable"):
j.run_managed(MockClient(), None) # type: ignore

assert j.state() == "failed"
assert j._finished_at is not None
assert isinstance(j.exception(), OSError)


def test_on_completed_exception_releases_semaphore() -> None:
"""Semaphore must be released even when _on_completed raises,
otherwise the loader hangs."""
from threading import BoundedSemaphore

file_path = "/table.1234.0.jsonl"

class SuccessfulJob(RunnableLoadJob):
def run(self) -> None:
pass

def bad_callback(state: TLoadJobState, msg: Optional[str]) -> None:
raise OSError("disk full")

sem = BoundedSemaphore()
sem.acquire()

j = SuccessfulJob(file_path)
j.set_run_vars(
load_id="1",
schema=None,
load_table=None,
on_completed=bad_callback,
)
with pytest.raises(OSError):
j.run_managed(MockClient(), sem) # type: ignore

# semaphore must have been released despite the exception
assert sem.acquire(blocking=False), "semaphore was not released"


def test_on_completed_exception_preserves_terminal_job_exception() -> None:
"""Keep the destination failure on the job when both run() and _on_completed fail."""
file_path = "/table.1234.0.jsonl"

class TerminalJob(RunnableLoadJob):
def run(self) -> None:
raise DestinationTerminalException("destination broke")

def bad_callback(state: TLoadJobState, msg: Optional[str]) -> None:
raise OSError("state persistence broke")

j = TerminalJob(file_path)
j.set_run_vars(
load_id="1",
schema=None,
load_table=None,
on_completed=bad_callback,
)
with pytest.raises(OSError, match="state persistence broke"):
j.run_managed(MockClient(), None) # type: ignore

assert j.state() == "failed"
assert j._finished_at is not None
assert isinstance(j.exception(), DestinationTerminalException)
assert str(j.exception()) == "destination broke"


def test_set_final_state_completed() -> None:
"""set_final_state puts a job into completed state."""
file_path = "/table.1234.0.jsonl"
Expand Down