From bf71148c614bcc01d53a4c01de6d473dc453ed98 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 21 Jun 2026 07:35:30 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- test/test_inference_server.py | 1 + torchrl/collectors/_async_batched.py | 5 ++ torchrl/modules/inference_server/_client.py | 55 +++++++++++++++++++-- 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/test/test_inference_server.py b/test/test_inference_server.py index 04b87a744f1..949d6eed583 100644 --- a/test/test_inference_server.py +++ b/test/test_inference_server.py @@ -344,6 +344,7 @@ def test_forward_as_tensordict_module(self): transport, in_keys=["observation"], out_keys=["action"], + max_inflight=1, ) td = TensorDict({"observation": torch.randn(4)}) result = remote_policy(td) diff --git a/torchrl/collectors/_async_batched.py b/torchrl/collectors/_async_batched.py index abdc8186339..e2b02e825f9 100644 --- a/torchrl/collectors/_async_batched.py +++ b/torchrl/collectors/_async_batched.py @@ -196,6 +196,8 @@ class AsyncBatchedCollector(BaseCollector): policy_version_key (NestedKey or None, optional): TensorDict key used for behavior-policy version annotations. ``None`` disables annotations. Defaults to ``"policy_version"``. + max_inflight_per_env (int, optional): maximum unresolved remote-policy + requests per environment coordinator. Defaults to ``1``. backend (str, optional): global default backend for both environments and policy inference. Specific overrides ``env_backend`` and ``policy_backend`` take precedence when set. @@ -291,6 +293,7 @@ def __init__( device_config: InferenceDeviceConfig | None = None, policy_version: int = 0, policy_version_key: NestedKey | None = "policy_version", + max_inflight_per_env: int | None = 1, server_backend: Literal["thread", "process"] = "thread", ): if policy is not None and policy_factory is not None: @@ -420,6 +423,7 @@ def __init__( policy_version_key=policy_version_key, ) self._policy_version_key = policy_version_key + self._max_inflight_per_env = max_inflight_per_env # ---- collector settings ----------------------------------------------- self.requested_frames_per_batch = frames_per_batch @@ -469,6 +473,7 @@ def _ensure_started(self) -> None: self._clients = [ PolicyClientModule( self._transport.client(), + max_inflight=self._max_inflight_per_env, policy_version_key=self._policy_version_key or "policy_version", ) for _ in range(self._num_envs) diff --git a/torchrl/modules/inference_server/_client.py b/torchrl/modules/inference_server/_client.py index a749e030447..7af3a5329a5 100644 --- a/torchrl/modules/inference_server/_client.py +++ b/torchrl/modules/inference_server/_client.py @@ -4,7 +4,9 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import threading from collections.abc import Callable, Sequence +from typing import Any import torch from tensordict.base import TensorDictBase @@ -27,6 +29,32 @@ def result(self, timeout: float | None = None) -> TensorDictBase: return self._result +class _ReleaseOnResultFuture: + def __init__(self, future, release: Callable[[], None]): + self._future = future + self._release = release + self._released = False + self._lock = threading.Lock() + + def _release_once(self) -> None: + with self._lock: + if not self._released: + self._released = True + self._release() + + def done(self) -> bool: + return self._future.done() + + def result(self, timeout: float | None = None) -> TensorDictBase: + try: + return self._future.result(timeout=timeout) + finally: + self._release_once() + + def __getattr__(self, name: str) -> Any: + return getattr(self._future, name) + + class PolicyClientModule(TensorDictModuleBase): """TensorDict policy wrapper for remote inference-server clients. @@ -45,6 +73,8 @@ class PolicyClientModule(TensorDictModuleBase): module. The full input TensorDict is still sent to the server. out_keys (sequence of NestedKey, optional): output keys advertised by the module. + max_inflight (int, optional): maximum number of unresolved asynchronous + requests submitted through this module. ``None`` means unbounded. target_policy_version (int, optional): expected latest policy version used for bounded-staleness checks. max_policy_lag (int, optional): maximum allowed @@ -83,6 +113,7 @@ def __init__( *, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, + max_inflight: int | None = None, target_policy_version: int | None = None, max_policy_lag: int | None = None, policy_version_key: NestedKey = "policy_version", @@ -93,9 +124,21 @@ def __init__( self.client = client self.in_keys = list(in_keys or []) self.out_keys = list(out_keys or []) + self.max_inflight = max_inflight self.target_policy_version = target_policy_version self.max_policy_lag = max_policy_lag self.policy_version_key = policy_version_key + self._inflight_sem = ( + threading.BoundedSemaphore(max_inflight) + if max_inflight is not None + else None + ) + + def _acquire_inflight(self) -> Callable[[], None]: + if self._inflight_sem is None: + return lambda: None + self._inflight_sem.acquire() + return self._inflight_sem.release def _check_policy_lag(self, tensordict: TensorDictBase) -> None: if self.target_policy_version is None or self.max_policy_lag is None: @@ -125,14 +168,20 @@ def submit(self, tensordict: TensorDictBase): Returns: Future-like object whose ``result()`` method returns a TensorDict. """ + release = self._acquire_inflight() submit = getattr(self.client, "submit", None) if submit is None: try: result = self.client(tensordict) - return _ImmediateFuture(result) + return _ReleaseOnResultFuture(_ImmediateFuture(result), release) except BaseException as exc: - return _ImmediateFuture(exc) - return submit(tensordict) + return _ReleaseOnResultFuture(_ImmediateFuture(exc), release) + try: + future = submit(tensordict) + except BaseException: + release() + raise + return _ReleaseOnResultFuture(future, release) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: result = self.submit(tensordict).result()