diff --git a/lib/crewai-core/src/crewai_core/lock_store.py b/lib/crewai-core/src/crewai_core/lock_store.py index 0f09fa7f66..be1d08faaf 100644 --- a/lib/crewai-core/src/crewai_core/lock_store.py +++ b/lib/crewai-core/src/crewai_core/lock_store.py @@ -1,14 +1,18 @@ """Centralised lock factory. -If ``REDIS_URL`` is set and the ``redis`` package is installed, locks are -distributed via ``portalocker.RedisLock``. Otherwise, falls back to the -standard file-based ``portalocker.Lock`` in the system temp dir. +By default, if ``REDIS_URL`` is set and the ``redis`` package is installed, +locks are distributed via ``portalocker.RedisLock``. Otherwise, falls back to +the standard file-based ``portalocker.Lock`` in the system temp dir. + +The backend can be replaced via :func:`set_lock_backend` to plug in a custom +locking strategy (e.g. a different distributed lock service, or an in-process +lock for tests). """ from __future__ import annotations -from collections.abc import Iterator -from contextlib import contextmanager +from collections.abc import Callable, Iterator +from contextlib import AbstractContextManager, contextmanager from functools import lru_cache from hashlib import md5 import logging @@ -30,6 +34,25 @@ _DEFAULT_TIMEOUT: Final[int] = 120 +# A backend is called as ``backend(name, timeout=...)`` and returns a context +# manager that holds the lock while the ``with`` block runs. +LockBackend = Callable[..., AbstractContextManager[None]] + +# ``None`` means use the built-in Redis/file selection. +_backend: LockBackend | None = None + + +def set_lock_backend(backend: LockBackend | None) -> None: + """Replace the process-wide locking backend used by :func:`lock`. + + Intended for one-time setup at startup. Pass ``None`` to restore the + built-in Redis/file default. In-flight :func:`lock` calls keep the backend + they started with, but swapping backends while other threads acquire locks + is otherwise unsynchronised. + """ + global _backend + _backend = backend + def _redis_available() -> bool: """Return True if redis is installed and REDIS_URL is set.""" @@ -58,10 +81,19 @@ def lock(name: str, *, timeout: float = _DEFAULT_TIMEOUT) -> Iterator[None]: """Acquire a named lock, yielding while it is held. Args: - name: A human-readable lock name (e.g. ``"chromadb_init"``). - Automatically namespaced to avoid collisions. + name: A human-readable lock name (e.g. ``"chromadb_init"``). The + built-in default namespaces it to avoid collisions; a custom + backend receives it verbatim. timeout: Maximum seconds to wait for the lock before raising. """ + # Snapshot the global once: a concurrent set_lock_backend() must not turn + # the check-then-call into calling ``None``. + backend = _backend + if backend is not None: + with backend(name, timeout=timeout): + yield + return + channel = f"crewai:{md5(name.encode(), usedforsecurity=False).hexdigest()}" if _redis_available(): diff --git a/lib/crewai/tests/utilities/test_lock_store.py b/lib/crewai/tests/utilities/test_lock_store.py index 1baa0169a6..baad049d8a 100644 --- a/lib/crewai/tests/utilities/test_lock_store.py +++ b/lib/crewai/tests/utilities/test_lock_store.py @@ -1,11 +1,13 @@ """Tests for lock_store. -We verify our own logic: the _redis_available guard and which portalocker -backend is selected. We trust portalocker to handle actual locking mechanics. +We verify our own logic: the _redis_available guard, which portalocker +backend is selected, and that a custom backend can be plugged in. We trust +portalocker to handle actual locking mechanics. """ from __future__ import annotations +from contextlib import contextmanager import sys from unittest import mock @@ -20,6 +22,14 @@ def no_redis_url(monkeypatch): monkeypatch.setattr(lock_store, "_REDIS_URL", None) +@pytest.fixture(autouse=True) +def reset_backend(): + """Ensure a custom backend never leaks across tests.""" + lock_store.set_lock_backend(None) + yield + lock_store.set_lock_backend(None) + + # _redis_available @@ -64,3 +74,40 @@ def test_uses_redis_lock_when_redis_available(monkeypatch): kwargs = mock_redis_lock.call_args.kwargs assert kwargs["channel"].startswith("crewai:") assert kwargs["connection"] is fake_conn + + +# custom backend + + +def test_custom_backend_is_used(): + calls = [] + + @contextmanager + def fake_backend(name, *, timeout): + calls.append((name, timeout)) + yield + + lock_store.set_lock_backend(fake_backend) + + # The default file/redis path must not be touched when overridden. + with mock.patch("portalocker.Lock") as mock_lock: + with lock("custom_test", timeout=5): + pass + + mock_lock.assert_not_called() + assert calls == [("custom_test", 5)] + + +def test_clearing_backend_restores_default(): + @contextmanager + def fake_backend(name, *, timeout): + yield + + lock_store.set_lock_backend(fake_backend) + lock_store.set_lock_backend(None) + + with mock.patch("portalocker.Lock") as mock_lock: + with lock("after_clear"): + pass + + mock_lock.assert_called_once()