Skip to content
Merged
46 changes: 39 additions & 7 deletions lib/crewai-core/src/crewai_core/lock_store.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
"""Centralised lock factory.

If ``REDIS_URL`` is set and the ``redis`` package is installed, locks are
distributed via ``portalocker.RedisLock``. Otherwise, falls back to the
standard file-based ``portalocker.Lock`` in the system temp dir.
By default, if ``REDIS_URL`` is set and the ``redis`` package is installed,
locks are distributed via ``portalocker.RedisLock``. Otherwise, falls back to
the standard file-based ``portalocker.Lock`` in the system temp dir.

The backend can be replaced via :func:`set_lock_backend` to plug in a custom
locking strategy (e.g. a different distributed lock service, or an in-process
lock for tests).
"""

from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from collections.abc import Callable, Iterator
from contextlib import AbstractContextManager, contextmanager
from functools import lru_cache
from hashlib import md5
import logging
Expand All @@ -30,6 +34,25 @@

_DEFAULT_TIMEOUT: Final[int] = 120

# A backend is called as ``backend(name, timeout=...)`` and returns a context
# manager that holds the lock while the ``with`` block runs.
LockBackend = Callable[..., AbstractContextManager[None]]

# ``None`` means use the built-in Redis/file selection.
_backend: LockBackend | None = None


def set_lock_backend(backend: LockBackend | None) -> None:
"""Replace the process-wide locking backend used by :func:`lock`.

Intended for one-time setup at startup. Pass ``None`` to restore the
built-in Redis/file default. In-flight :func:`lock` calls keep the backend
they started with, but swapping backends while other threads acquire locks
is otherwise unsynchronised.
"""
global _backend
_backend = backend
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def _redis_available() -> bool:
"""Return True if redis is installed and REDIS_URL is set."""
Expand Down Expand Up @@ -58,10 +81,19 @@ def lock(name: str, *, timeout: float = _DEFAULT_TIMEOUT) -> Iterator[None]:
"""Acquire a named lock, yielding while it is held.

Args:
name: A human-readable lock name (e.g. ``"chromadb_init"``).
Automatically namespaced to avoid collisions.
name: A human-readable lock name (e.g. ``"chromadb_init"``). The
built-in default namespaces it to avoid collisions; a custom
backend receives it verbatim.
timeout: Maximum seconds to wait for the lock before raising.
"""
# Snapshot the global once: a concurrent set_lock_backend() must not turn
# the check-then-call into calling ``None``.
backend = _backend
if backend is not None:
with backend(name, timeout=timeout):
yield
return
Comment thread
cursor[bot] marked this conversation as resolved.

channel = f"crewai:{md5(name.encode(), usedforsecurity=False).hexdigest()}"

if _redis_available():
Expand Down
51 changes: 49 additions & 2 deletions lib/crewai/tests/utilities/test_lock_store.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Tests for lock_store.

We verify our own logic: the _redis_available guard and which portalocker
backend is selected. We trust portalocker to handle actual locking mechanics.
We verify our own logic: the _redis_available guard, which portalocker
backend is selected, and that a custom backend can be plugged in. We trust
portalocker to handle actual locking mechanics.
"""

from __future__ import annotations

from contextlib import contextmanager
import sys
from unittest import mock

Expand All @@ -20,6 +22,14 @@ def no_redis_url(monkeypatch):
monkeypatch.setattr(lock_store, "_REDIS_URL", None)


@pytest.fixture(autouse=True)
def reset_backend():
"""Ensure a custom backend never leaks across tests."""
lock_store.set_lock_backend(None)
yield
lock_store.set_lock_backend(None)


# _redis_available


Expand Down Expand Up @@ -64,3 +74,40 @@ def test_uses_redis_lock_when_redis_available(monkeypatch):
kwargs = mock_redis_lock.call_args.kwargs
assert kwargs["channel"].startswith("crewai:")
assert kwargs["connection"] is fake_conn


# custom backend


def test_custom_backend_is_used():
calls = []

@contextmanager
def fake_backend(name, *, timeout):
calls.append((name, timeout))
yield

lock_store.set_lock_backend(fake_backend)

# The default file/redis path must not be touched when overridden.
with mock.patch("portalocker.Lock") as mock_lock:
with lock("custom_test", timeout=5):
pass

mock_lock.assert_not_called()
assert calls == [("custom_test", 5)]


def test_clearing_backend_restores_default():
@contextmanager
def fake_backend(name, *, timeout):
yield

lock_store.set_lock_backend(fake_backend)
lock_store.set_lock_backend(None)

with mock.patch("portalocker.Lock") as mock_lock:
with lock("after_clear"):
pass

mock_lock.assert_called_once()
Loading