Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion swift/template/vision_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,47 @@ def uniform_sample(_l, _n):
return frames


def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = False):
def _decode_worker(queue, func, args, kwargs):
try:
queue.put((True, func(*args, **kwargs)))
except Exception as e:
queue.put((False, e))


def _decode_with_timeout(func: Callable[..., _T], *args, **kwargs) -> _T:
# Native media decoders (audioread/ffmpeg, decord) can deadlock in C while holding the GIL on a
# corrupt/unsupported clip, silently hanging a DataLoader worker forever; a signal-based timeout
# can't interrupt them. When `MEDIA_DECODE_TIMEOUT` (seconds) > 0, decode in a killable subprocess.
timeout = get_env_args('media_decode_timeout', float, 0)
if not timeout or timeout <= 0:
return func(*args, **kwargs)
import multiprocessing as mp

# Fork the decode worker: load_audio runs inside the data pipeline where fork is already the
# norm (PyTorch DataLoader), and unlike forkserver/spawn it does not re-import the training
# entrypoint per call. Fall back to the default context where fork is unavailable.
try:
ctx = mp.get_context('fork')
except ValueError:
ctx = mp.get_context()
queue = ctx.SimpleQueue()
process = ctx.Process(target=_decode_worker, args=(queue, func, args, kwargs))
process.start()
process.join(timeout)
if process.is_alive():
process.terminate()
process.join()
raise TimeoutError(f'Media decode exceeded MEDIA_DECODE_TIMEOUT={timeout}s and was killed '
'(likely a corrupt or unsupported clip).')
if process.exitcode != 0:
raise RuntimeError(f'Media decode subprocess exited abnormally (exitcode={process.exitcode}).')
ok, payload = queue.get()
if not ok:
raise payload
return payload

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Using multiprocessing.SimpleQueue can lead to a silent deadlock when the decoded payload is large.

The Issue

SimpleQueue is backed by an OS pipe. If the serialized payload (e.g., decoded audio arrays or video frames) exceeds the OS pipe buffer limit (typically 64KB on Linux), the child process's queue.put() call will block indefinitely until the parent process reads from the queue. However, the parent process calls process.join(timeout) before reading from the queue. This creates a classic deadlock: the child is blocked waiting for the parent to read, and the parent is blocked waiting for the child to finish. This will cause a false TimeoutError and terminate the child process for any large media file.

The Solution

Use multiprocessing.Pipe and poll(timeout) instead. This allows the parent process to detect if data is available to read (or if the child has exited/errored) and read from the pipe, which unblocks the child process if the buffer fills up.

def _decode_worker(conn, func, args, kwargs):
    try:
        conn.send((True, func(*args, **kwargs)))
    except Exception as e:
        conn.send((False, e))
    finally:
        conn.close()


def _decode_with_timeout(func: Callable[..., _T], *args, **kwargs) -> _T:
    # Native media decoders (audioread/ffmpeg, decord) can deadlock in C while holding the GIL on a
    # corrupt/unsupported clip, silently hanging a DataLoader worker forever; a signal-based timeout
    # can't interrupt them. When `MEDIA_DECODE_TIMEOUT` (seconds) > 0, decode in a killable subprocess.
    timeout = get_env_args('media_decode_timeout', float, 0)
    if not timeout or timeout <= 0:
        return func(*args, **kwargs)
    import multiprocessing as mp

    # Fork the decode worker: load_audio runs inside the data pipeline where fork is already the
    # norm (PyTorch DataLoader), and unlike forkserver/spawn it does not re-import the training
    # entrypoint per call. Fall back to the default context where fork is unavailable.
    try:
        ctx = mp.get_context('fork')
    except ValueError:
        ctx = mp.get_context()
    parent_conn, child_conn = ctx.Pipe(duplex=False)
    process = ctx.Process(target=_decode_worker, args=(child_conn, func, args, kwargs))
    process.start()
    child_conn.close()
    if parent_conn.poll(timeout):
        try:
            ok, payload = parent_conn.recv()
        except EOFError:
            process.join()
            raise RuntimeError(f'Media decode subprocess exited abnormally (exitcode={process.exitcode}).')
        process.join()
        if not ok:
            raise payload
        return payload
    else:
        process.terminate()
        process.join()
        raise TimeoutError(f'Media decode exceeded MEDIA_DECODE_TIMEOUT={timeout}s and was killed '
                           '(likely a corrupt or unsupported clip).')



def _load_audio(audio: Union[str, bytes], sampling_rate: int):
import librosa
try:
audio_io = load_file(audio)
Expand All @@ -308,6 +348,11 @@ def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = F
else:
audio_io = _check_path(audio) or audio
res = librosa.load(audio_io, sr=sampling_rate)
return res


def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = False):
res = _decode_with_timeout(_load_audio, audio, sampling_rate)
return res if return_sr else res[0]


Expand Down
42 changes: 42 additions & 0 deletions tests/utils/test_vision_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
import time

from swift.template.vision_utils import _decode_with_timeout


def _sleep_forever(*args, **kwargs):
time.sleep(3600)


def _echo(value):
return value


def _raise_value_error(*args, **kwargs):
raise ValueError('boom')


def test_decode_with_timeout_kills_hung_decode(monkeypatch):
# A decode that never returns must be killed and surface a TimeoutError rather than hang.
monkeypatch.setenv('MEDIA_DECODE_TIMEOUT', '2')
start = time.time()
with pytest.raises(TimeoutError):
_decode_with_timeout(_sleep_forever)
assert time.time() - start < 30


def test_decode_with_timeout_returns_result_when_fast(monkeypatch):
monkeypatch.setenv('MEDIA_DECODE_TIMEOUT', '10')
assert _decode_with_timeout(_echo, 'ok') == 'ok'


def test_decode_with_timeout_propagates_decode_error(monkeypatch):
monkeypatch.setenv('MEDIA_DECODE_TIMEOUT', '10')
with pytest.raises(ValueError, match='boom'):
_decode_with_timeout(_raise_value_error)


def test_decode_with_timeout_disabled_calls_directly(monkeypatch):
# Default (unset / 0): no subprocess, original behavior and zero overhead.
monkeypatch.delenv('MEDIA_DECODE_TIMEOUT', raising=False)
assert _decode_with_timeout(_echo, 'direct') == 'direct'