diff --git a/test/llm/test_llm_transforms.py b/test/llm/test_llm_transforms.py index b6d0c1a6853..5f3f11133fd 100644 --- a/test/llm/test_llm_transforms.py +++ b/test/llm/test_llm_transforms.py @@ -18,6 +18,7 @@ ExecuteToolsInOrder, IncrementalTokenizer, JSONCallParser, + PolicyVersion, ToolCall, ToolRegistry, XMLBlockParser, @@ -718,3 +719,31 @@ def test_empty_history_handling(self, tokenizer): assert ("tokens", "prompt") in result.keys(True, True) tokens = result.get(("tokens", "prompt"), as_list=True) assert tokens[0].numel() > 0 + + +class TestPolicyVersion: + def test_int_version_dtype_and_device(self): + """Integer policy version must stay int64 and follow the tensordict device. + + Regression for a bug where ``version_type="int"`` was cast to float + and dropped the device, producing CPU float tensors that mismatched + the surrounding tensordict. + """ + import torch + + transform = PolicyVersion(version_type="int") + transform.version = 7 + + td = TensorDict(batch_size=(4,)) + out = transform._step(td, td.copy()) + version = out.get("policy_version") + assert version.dtype == torch.int64 + assert version.shape == (4,) + assert torch.equal(version, torch.full((4,), 7, dtype=torch.int64)) + + if torch.cuda.is_available(): + td_cuda = TensorDict(batch_size=(4,), device="cuda") + out_cuda = transform._step(td_cuda, td_cuda.copy()) + version_cuda = out_cuda.get("policy_version") + assert version_cuda.dtype == torch.int64 + assert version_cuda.device.type == "cuda" diff --git a/torchrl/envs/llm/transforms/policy_version.py b/torchrl/envs/llm/transforms/policy_version.py index 493b630780c..430ce515b58 100644 --- a/torchrl/envs/llm/transforms/policy_version.py +++ b/torchrl/envs/llm/transforms/policy_version.py @@ -159,8 +159,15 @@ def _step( if self.version_type in (str, "uuid"): version = NonTensorData(self.version).expand(next_tensordict.shape) elif self.version_type in (int, "int"): - # Cast to float for torch.full - version = torch.full(next_tensordict.shape, float(cast(int, self.version))) + device = next_tensordict.device + kwargs = {"dtype": torch.int64} + if device is not None: + kwargs["device"] = device + version = torch.full( + next_tensordict.shape, + cast(int, self.version), + **kwargs, + ) else: raise ValueError(f"Invalid version type: {self.version_type}")