From 40c7979f1775bdb6d39e1adf0d618c2e60a4c208 Mon Sep 17 00:00:00 2001 From: Erica Lin Date: Wed, 20 May 2026 16:25:48 -0400 Subject: [PATCH 1/5] test --- test/envs/test_auto_reset.py | 2 +- torchrl/envs/common.py | 5 ++++- torchrl/envs/libs/gym.py | 4 ++-- torchrl/envs/transforms/_misc.py | 4 ++-- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/test/envs/test_auto_reset.py b/test/envs/test_auto_reset.py index edc56efbb81..50fb56d8a07 100644 --- a/test/envs/test_auto_reset.py +++ b/test/envs/test_auto_reset.py @@ -99,7 +99,7 @@ def test_native_auto_reset_step_and_maybe_reset(self): ) assert tensordict["next", "done"].all() - assert (tensordict["next", "observation"] == 0).all() + assert (tensordict["next", "observation"] == tensordict["observation"]).all() assert not tensordict_["done"].any() assert tensordict_["is_init"].all() diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 117953ac1de..8142f74c3a6 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -3693,7 +3693,10 @@ def _native_autoreset_set_invalid_next_observation( if obs is None or not isinstance(obs, torch.Tensor): continue reset_observations[obs_key] = obs.clone() - if obs.is_floating_point(): + current_obs = tensordict.get(obs_key, None) + if current_obs is not None: + obs[done] = current_obs[done] + elif obs.is_floating_point(): obs[done] = torch.nan else: obs[done] = 0 diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index b4ab20f40d4..bcf29404211 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -955,8 +955,8 @@ class GymWrapper(GymLikeEnv, metaclass=_GymAsyncMeta): observation returned by the environment as the next root observation instead of calling reset from :meth:`~torchrl.envs.EnvBase.step_and_maybe_reset`. The terminal - ``"next"`` observation is still invalid and filled with ``NaN`` for - floating point observations. Defaults to ``False``. + ``"next"`` observation is replaced with the current (root) + observation. Defaults to ``False``. Attributes: available_envs (List[str]): a list of environments to build. diff --git a/torchrl/envs/transforms/_misc.py b/torchrl/envs/transforms/_misc.py index 296a8b35ad1..b0588dd9a50 100644 --- a/torchrl/envs/transforms/_misc.py +++ b/torchrl/envs/transforms/_misc.py @@ -408,8 +408,8 @@ class VecGymEnvTransform(Transform): native_autoreset (bool, optional): if ``True``, leaves the native auto-reset observation available to the environment wrapper so it can be cloned into the next root observation, while the terminal - floating point ``"next"`` observation is marked with ``NaN``. - Defaults to ``False``. + ``"next"`` observation is replaced with the current (root) + observation. Defaults to ``False``. .. note:: In general, this class should not be handled directly. It is created whenever a vectorized environment is placed within a :class:`GymWrapper`. From f7051824665cb59e3902dbb742b7940fb407fe1a Mon Sep 17 00:00:00 2001 From: Erica Lin Date: Wed, 20 May 2026 17:13:38 -0400 Subject: [PATCH 2/5] fix --- torchrl/envs/common.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 8142f74c3a6..7a0cc8362d1 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -3694,10 +3694,8 @@ def _native_autoreset_set_invalid_next_observation( continue reset_observations[obs_key] = obs.clone() current_obs = tensordict.get(obs_key, None) - if current_obs is not None: + if obs.is_floating_point(): obs[done] = current_obs[done] - elif obs.is_floating_point(): - obs[done] = torch.nan else: obs[done] = 0 return reset_observations From a2794701fa91b391bc08d4f9a1292f79c3926b0e Mon Sep 17 00:00:00 2001 From: Erica Lin Date: Wed, 20 May 2026 17:29:20 -0400 Subject: [PATCH 3/5] fix native autoreset --- torchrl/envs/libs/isaac_lab.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrl/envs/libs/isaac_lab.py b/torchrl/envs/libs/isaac_lab.py index b92d883e3cf..b229c198035 100644 --- a/torchrl/envs/libs/isaac_lab.py +++ b/torchrl/envs/libs/isaac_lab.py @@ -72,6 +72,7 @@ def __init__( categorical_action_encoding=categorical_action_encoding, allow_done_after_reset=allow_done_after_reset, convert_actions_to_numpy=convert_actions_to_numpy, + native_autoreset=native_autoreset, **kwargs, ) From b5ea779a8452cd1603a5dc6b2b331e433e5f2da1 Mon Sep 17 00:00:00 2001 From: Erica Lin Date: Thu, 21 May 2026 12:37:18 -0400 Subject: [PATCH 4/5] fix native autoreset --- torchrl/envs/libs/isaac_lab.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchrl/envs/libs/isaac_lab.py b/torchrl/envs/libs/isaac_lab.py index b229c198035..fa289c56418 100644 --- a/torchrl/envs/libs/isaac_lab.py +++ b/torchrl/envs/libs/isaac_lab.py @@ -61,7 +61,6 @@ def __init__( allow_done_after_reset: bool = True, convert_actions_to_numpy: bool = False, device: torch.device | None = None, - native_autoreset: bool = False, **kwargs, ): if device is None: @@ -72,7 +71,6 @@ def __init__( categorical_action_encoding=categorical_action_encoding, allow_done_after_reset=allow_done_after_reset, convert_actions_to_numpy=convert_actions_to_numpy, - native_autoreset=native_autoreset, **kwargs, ) From 6c4605aa8a2e40a3210b95a2cdd0897aae043278 Mon Sep 17 00:00:00 2001 From: Erica Lin Date: Thu, 21 May 2026 13:34:43 -0400 Subject: [PATCH 5/5] fix native autoreset --- torchrl/envs/libs/isaac_lab.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrl/envs/libs/isaac_lab.py b/torchrl/envs/libs/isaac_lab.py index fa289c56418..b92d883e3cf 100644 --- a/torchrl/envs/libs/isaac_lab.py +++ b/torchrl/envs/libs/isaac_lab.py @@ -61,6 +61,7 @@ def __init__( allow_done_after_reset: bool = True, convert_actions_to_numpy: bool = False, device: torch.device | None = None, + native_autoreset: bool = False, **kwargs, ): if device is None: