From 3b9db6c3e84e867d8c7929b92a14ec49c5f73f68 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 15 May 2026 08:20:33 +0100
Subject: [PATCH] Update

[ghstack-poisoned]
---
 test/llm/test_llm_transforms.py               | 31 +++++++++++++++++++
 torchrl/envs/llm/transforms/policy_version.py | 11 +++++--
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/test/llm/test_llm_transforms.py b/test/llm/test_llm_transforms.py
index b6d0c1a6853..6c83ebe1b0c 100644
--- a/test/llm/test_llm_transforms.py
+++ b/test/llm/test_llm_transforms.py
@@ -18,6 +18,7 @@
     ExecuteToolsInOrder,
     IncrementalTokenizer,
     JSONCallParser,
+    PolicyVersion,
     ToolCall,
     ToolRegistry,
     XMLBlockParser,
@@ -718,3 +719,33 @@ def test_empty_history_handling(self, tokenizer):
         assert ("tokens", "prompt") in result.keys(True, True)
         tokens = result.get(("tokens", "prompt"), as_list=True)
         assert tokens[0].numel() > 0
+
+
+class TestPolicyVersion:
+    def test_int_version_dtype_and_device(self):
+        """Integer policy version must stay int64 and follow the tensordict device.
+
+        Regression for a bug where ``version_type="int"`` was cast to float
+        and dropped the device, producing CPU float tensors that mismatched
+        the surrounding tensordict.
+        """
+        import torch
+
+        transform = PolicyVersion(version_type="int")
+        transform.version = 7
+
+        td = TensorDict(batch_size=(4,))
+        out = transform._step(td, td.copy())
+        version = out.get("policy_version")
+        assert version.dtype == torch.int64
+        assert version.shape == (4,)
+        assert torch.equal(
+            version, torch.full((4,), 7, dtype=torch.int64)
+        )
+
+        if torch.cuda.is_available():
+            td_cuda = TensorDict(batch_size=(4,), device="cuda")
+            out_cuda = transform._step(td_cuda, td_cuda.copy())
+            version_cuda = out_cuda.get("policy_version")
+            assert version_cuda.dtype == torch.int64
+            assert version_cuda.device.type == "cuda"
diff --git a/torchrl/envs/llm/transforms/policy_version.py b/torchrl/envs/llm/transforms/policy_version.py
index 493b630780c..430ce515b58 100644
--- a/torchrl/envs/llm/transforms/policy_version.py
+++ b/torchrl/envs/llm/transforms/policy_version.py
@@ -159,8 +159,15 @@ def _step(
         if self.version_type in (str, "uuid"):
             version = NonTensorData(self.version).expand(next_tensordict.shape)
         elif self.version_type in (int, "int"):
-            # Cast to float for torch.full
-            version = torch.full(next_tensordict.shape, float(cast(int, self.version)))
+            device = next_tensordict.device
+            kwargs = {"dtype": torch.int64}
+            if device is not None:
+                kwargs["device"] = device
+            version = torch.full(
+                next_tensordict.shape,
+                cast(int, self.version),
+                **kwargs,
+            )
         else:
             raise ValueError(f"Invalid version type: {self.version_type}")
 