Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions test/objectives/test_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,83 @@ def test_gae_recurrent(self, module, vectorized):
a1 = r1["advantage"]
torch.testing.assert_close(a0, a1)

def _build_shifted_test_td(self, *, with_internal_done: bool):
"""Build a rollout-shaped tensordict for shifted-mode tests."""
B, T, obs_dim = 4, 8, 6
all_obs = torch.randn(B, T + 1, obs_dim)
obs = all_obs[:, :T].clone()
next_obs = all_obs[:, 1:].clone()
reward = torch.randn(B, T, 1)
done = torch.zeros(B, T, 1, dtype=torch.bool)
terminated = torch.zeros(B, T, 1, dtype=torch.bool)
if with_internal_done:
done[0, 3, 0] = True
next_obs[0, 3] = torch.randn(obs_dim)
done[:, -1, 0] = True
td = TensorDict(
{
"observation": obs,
"next": TensorDict(
{
"observation": next_obs,
"reward": reward,
"done": done,
"terminated": terminated,
"truncated": done & ~terminated,
},
batch_size=[B, T],
),
},
batch_size=[B, T],
)
td.refine_names(..., "time")
return td, obs_dim

@pytest.mark.parametrize("with_internal_done", [False, True])
def test_gae_shifted_compact_and_legacy(self, with_internal_done):
# Both shifted='compact' and shifted='legacy' must produce a valid
# advantage. 'legacy' must match shifted=False exactly. 'compact'
# is allowed a small boundary bias from copying V(obs[T-1]) at
# the rollout boundary; not asserted.
torch.manual_seed(0)
td, obs_dim = self._build_shifted_test_td(with_internal_done=with_internal_done)
value_net = TensorDictModule(
nn.Linear(obs_dim, 1),
in_keys=["observation"],
out_keys=["state_value"],
)
gae_compact = GAE(
gamma=0.9, lmbda=0.95, value_network=value_net, shifted="compact"
)
gae_legacy = GAE(
gamma=0.9, lmbda=0.95, value_network=value_net, shifted="legacy"
)
gae_unshifted = GAE(
gamma=0.9, lmbda=0.95, value_network=value_net, shifted=False
)
gae_compact(td.copy())
adv_legacy = gae_legacy(td.copy())["advantage"]
adv_unshifted = gae_unshifted(td.copy())["advantage"]
torch.testing.assert_close(adv_legacy, adv_unshifted)

def test_gae_shifted_true_deprecation_aliases_legacy(self):
torch.manual_seed(0)
td, obs_dim = self._build_shifted_test_td(with_internal_done=True)
value_net = TensorDictModule(
nn.Linear(obs_dim, 1),
in_keys=["observation"],
out_keys=["state_value"],
)
with pytest.warns(DeprecationWarning, match="shifted=True is deprecated"):
gae_true = GAE(gamma=0.9, lmbda=0.95, value_network=value_net, shifted=True)
gae_legacy = GAE(
gamma=0.9, lmbda=0.95, value_network=value_net, shifted="legacy"
)
assert gae_true.shifted == "legacy"
adv_true = gae_true(td.copy())["advantage"]
adv_legacy = gae_legacy(td.copy())["advantage"]
torch.testing.assert_close(adv_true, adv_legacy)

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("gamma", [0.1, 0.5, 0.99])
@pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99])
Expand Down
Loading
Loading