Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchrl/collectors/_multi_base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1335,6 +1335,7 @@ def _run_processes(self) -> None:
"trajs_per_write": self.trajs_per_write,
"init_fn": self._worker_init_fn,
"auto_register_policy_transforms": self._auto_register_policy_transforms,
"track_policy_version": self.policy_version_tracker is not None,
"pre_collect_hook": self._worker_pre_collect_hook,
"post_collect_hook": self._worker_post_collect_hook,
"compact_obs": self.compact_obs,
Expand Down
7 changes: 7 additions & 0 deletions torchrl/collectors/_runner.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -68,6 +68,7 @@ def _main_async_collector(
trajs_per_write: int | None = None,
init_fn: Callable[[], None] | None = None,
auto_register_policy_transforms: bool | None = None,
track_policy_version: bool = False,
pre_collect_hook: Callable[[], None] | None = None,
post_collect_hook: Callable[[TensorDictBase], None] | None = None,
compact_obs: bool = False,
Expand Down Expand Up @@ -140,6 +141,7 @@ def _main_async_collector(
trajs_per_batch=trajs_per_batch,
trajs_per_write=trajs_per_write,
auto_register_policy_transforms=auto_register_policy_transforms,
track_policy_version=track_policy_version,
pre_collect_hook=pre_collect_hook,
post_collect_hook=post_collect_hook,
compact_obs=compact_obs,
Expand Down Expand Up @@ -189,11 +191,13 @@ def _main_async_collector(
counter = 0
run_free = False
while True:
fresh_command = False
_timeout = _TIMEOUT if not has_timed_out else 1e-3
if not run_free and pipe_child.poll(_timeout):
counter = 0
try:
data_in, msg = pipe_child.recv()
fresh_command = True
if verbose:
torchrl_logger.debug(f"mp worker {idx} received {msg}")
except EOFError:
Expand Down Expand Up @@ -241,6 +245,7 @@ def _main_async_collector(
# Capture shutdown / update / seed signal, but continue should not be expected
if pipe_child.poll(1e-4):
data_in, msg = pipe_child.recv()
fresh_command = True
if msg == "continue":
# Switch back to run_free = False
run_free = False
Expand Down Expand Up @@ -285,6 +290,8 @@ def _main_async_collector(
# applies weights automatically. No explicit message handling needed here.

if msg in ("continue", "continue_random"):
if track_policy_version and fresh_command and not run_free:
inner_collector.increment_version()
# When in run_free mode with a replay_buffer, the inner collector uses
# _should_use_random_frames() which checks replay_buffer.write_count.
# So we don't override init_random_frames. Otherwise, we use the message
Expand Down
Loading