Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/test_inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def test_forward_as_tensordict_module(self):
transport,
in_keys=["observation"],
out_keys=["action"],
max_inflight=1,
)
td = TensorDict({"observation": torch.randn(4)})
result = remote_policy(td)
Expand Down
5 changes: 5 additions & 0 deletions torchrl/collectors/_async_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ class AsyncBatchedCollector(BaseCollector):
policy_version_key (NestedKey or None, optional): TensorDict key used
for behavior-policy version annotations. ``None`` disables
annotations. Defaults to ``"policy_version"``.
max_inflight_per_env (int, optional): maximum unresolved remote-policy
requests per environment coordinator. Defaults to ``1``.
backend (str, optional): global default backend for both
environments and policy inference. Specific overrides
``env_backend`` and ``policy_backend`` take precedence when set.
Expand Down Expand Up @@ -294,6 +296,7 @@ def __init__(
device_config: InferenceDeviceConfig | None = None,
policy_version: int = 0,
policy_version_key: NestedKey | None = "policy_version",
max_inflight_per_env: int | None = 1,
server_backend: Literal["thread", "process"] = "thread",
):
if policy is not None and policy_factory is not None:
Expand Down Expand Up @@ -423,6 +426,7 @@ def __init__(
policy_version_key=policy_version_key,
)
self._policy_version_key = policy_version_key
self._max_inflight_per_env = max_inflight_per_env

# ---- collector settings -----------------------------------------------
self.requested_frames_per_batch = frames_per_batch
Expand Down Expand Up @@ -472,6 +476,7 @@ def _ensure_started(self) -> None:
self._clients = [
PolicyClientModule(
self._transport.client(),
max_inflight=self._max_inflight_per_env,
policy_version_key=self._policy_version_key or "policy_version",
)
for _ in range(self._num_envs)
Expand Down
55 changes: 52 additions & 3 deletions torchrl/modules/inference_server/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import threading
from collections.abc import Callable, Sequence
from typing import Any

import torch
from tensordict.base import TensorDictBase
Expand All @@ -27,6 +29,32 @@ def result(self, timeout: float | None = None) -> TensorDictBase:
return self._result


class _ReleaseOnResultFuture:
def __init__(self, future, release: Callable[[], None]):
self._future = future
self._release = release
self._released = False
self._lock = threading.Lock()

def _release_once(self) -> None:
with self._lock:
if not self._released:
self._released = True
self._release()

def done(self) -> bool:
return self._future.done()

def result(self, timeout: float | None = None) -> TensorDictBase:
try:
return self._future.result(timeout=timeout)
finally:
self._release_once()

def __getattr__(self, name: str) -> Any:
return getattr(self._future, name)


class PolicyClientModule(TensorDictModuleBase):
"""TensorDict policy wrapper for remote inference-server clients.

Expand All @@ -45,6 +73,8 @@ class PolicyClientModule(TensorDictModuleBase):
module. The full input TensorDict is still sent to the server.
out_keys (sequence of NestedKey, optional): output keys advertised by
the module.
max_inflight (int, optional): maximum number of unresolved asynchronous
requests submitted through this module. ``None`` means unbounded.
target_policy_version (int, optional): expected latest policy version
used for bounded-staleness checks.
max_policy_lag (int, optional): maximum allowed
Expand Down Expand Up @@ -83,6 +113,7 @@ def __init__(
*,
in_keys: Sequence[NestedKey] | None = None,
out_keys: Sequence[NestedKey] | None = None,
max_inflight: int | None = None,
target_policy_version: int | None = None,
max_policy_lag: int | None = None,
policy_version_key: NestedKey = "policy_version",
Expand All @@ -93,9 +124,21 @@ def __init__(
self.client = client
self.in_keys = list(in_keys or [])
self.out_keys = list(out_keys or [])
self.max_inflight = max_inflight
self.target_policy_version = target_policy_version
self.max_policy_lag = max_policy_lag
self.policy_version_key = policy_version_key
self._inflight_sem = (
threading.BoundedSemaphore(max_inflight)
if max_inflight is not None
else None
)

def _acquire_inflight(self) -> Callable[[], None]:
if self._inflight_sem is None:
return lambda: None
self._inflight_sem.acquire()
return self._inflight_sem.release

def _check_policy_lag(self, tensordict: TensorDictBase) -> None:
if self.target_policy_version is None or self.max_policy_lag is None:
Expand Down Expand Up @@ -125,14 +168,20 @@ def submit(self, tensordict: TensorDictBase):
Returns:
Future-like object whose ``result()`` method returns a TensorDict.
"""
release = self._acquire_inflight()
submit = getattr(self.client, "submit", None)
if submit is None:
try:
result = self.client(tensordict)
return _ImmediateFuture(result)
return _ReleaseOnResultFuture(_ImmediateFuture(result), release)
except BaseException as exc:
return _ImmediateFuture(exc)
return submit(tensordict)
return _ReleaseOnResultFuture(_ImmediateFuture(exc), release)
try:
future = submit(tensordict)
except BaseException:
release()
raise
return _ReleaseOnResultFuture(future, release)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
result = self.submit(tensordict).result()
Expand Down
Loading