From 075e82bfeb6a94f4158354577a69dc3f44978325 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 17 Jan 2025 13:29:40 +0000 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- torchrl/data/tensor_specs.py | 48 ++++++++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 5f724577ddd..bb8bebf2db8 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -2457,6 +2457,7 @@ def __init__( shape: Union[torch.Size, int] = _DEFAULT_SHAPE, device: Optional[DEVICE_TYPING] = None, dtype: torch.dtype | None = None, + example_data: Any = None, **kwargs, ): if isinstance(shape, int): @@ -2467,6 +2468,7 @@ def __init__( super().__init__( shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs ) + self.example_data = example_data def cardinality(self) -> Any: raise RuntimeError("Cannot enumerate a NonTensorSpec.") @@ -2485,30 +2487,46 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor: dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self - return self.__class__(shape=self.shape, device=dest_device, dtype=None) + return self.__class__( + shape=self.shape, + device=dest_device, + dtype=None, + example_data=self.example_data, + ) def clone(self) -> NonTensor: - return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype) + return self.__class__( + shape=self.shape, + device=self.device, + dtype=self.dtype, + example_data=self.example_data, + ) def rand(self, shape=None): if shape is None: shape = () return NonTensorData( - data=None, batch_size=(*shape, *self._safe_shape), device=self.device + data=self.example_data, + batch_size=(*shape, *self._safe_shape), + device=self.device, ) def zero(self, shape=None): if shape is None: shape = () return NonTensorData( - data=None, batch_size=(*shape, *self._safe_shape), device=self.device + data=self.example_data, + batch_size=(*shape, *self._safe_shape), + device=self.device, ) def one(self, shape=None): if shape is None: shape = () return NonTensorData( - data=None, batch_size=(*shape, *self._safe_shape), device=self.device + data=self.example_data, + batch_size=(*shape, *self._safe_shape), + device=self.device, ) def is_in(self, val: Any) -> bool: @@ -2533,10 +2551,17 @@ def expand(self, *shape): raise ValueError( f"The last elements of the expanded shape must match the current one. Got shape={shape} while self.shape={self.shape}." ) - return self.__class__(shape=shape, device=self.device, dtype=None) + return self.__class__( + shape=shape, device=self.device, dtype=None, example_data=self.example_data + ) def _reshape(self, shape): - return self.__class__(shape=shape, device=self.device, dtype=self.dtype) + return self.__class__( + shape=shape, + device=self.device, + dtype=self.dtype, + example_data=self.example_data, + ) def _unflatten(self, dim, sizes): shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape @@ -2544,12 +2569,18 @@ def _unflatten(self, dim, sizes): shape=shape, device=self.device, dtype=self.dtype, + example_data=self.example_data, ) def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index.""" indexed_shape = _size(_shape_indexing(self.shape, idx)) - return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype) + return self.__class__( + shape=indexed_shape, + device=self.device, + dtype=self.dtype, + example_data=self.example_data, + ) def unbind(self, dim: int = 0): orig_dim = dim @@ -2565,6 +2596,7 @@ def unbind(self, dim: int = 0): shape=shape, device=self.device, dtype=self.dtype, + example_data=self.example_data, ) for i in range(self.shape[dim]) ) From 69c612242539bfa401cd344e29b64b59f2f7cd14 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 17 Jan 2025 18:12:58 +0000 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- test/test_specs.py | 29 +++++++++++++++++++++-------- torchrl/data/tensor_specs.py | 17 +++++++++++++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/test/test_specs.py b/test/test_specs.py index a75ff0352c7..07762c7ad30 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -1402,12 +1402,13 @@ def test_multionehot(self, shape1, shape2): assert spec2.zero().shape == spec2.shape def test_non_tensor(self): - spec = NonTensor((3, 4), device="cpu") + spec = NonTensor((3, 4), device="cpu", example_data="example_data") assert ( spec.expand(2, 3, 4) == spec.expand((2, 3, 4)) - == NonTensor((2, 3, 4), device="cpu") + == NonTensor((2, 3, 4), device="cpu", example_data="example_data") ) + assert spec.expand(2, 3, 4).example_data == "example_data" @pytest.mark.parametrize("shape1", [None, (), (5,)]) @pytest.mark.parametrize("shape2", [(), (10,)]) @@ -1607,9 +1608,10 @@ def test_multionehot( assert spec is not spec.clone() def test_non_tensor(self): - spec = NonTensor(shape=(3, 4), device="cpu") + spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data") assert spec.clone() == spec assert spec.clone() is not spec + assert spec.clone().example_data == "example_data" @pytest.mark.parametrize("shape1", [None, (), (5,)]) def test_onehot( @@ -1840,9 +1842,10 @@ def test_multionehot( spec.unbind(-1) def test_non_tensor(self): - spec = NonTensor(shape=(3, 4), device="cpu") + spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data") assert spec.unbind(1)[0] == spec[:, 0] assert spec.unbind(1)[0] is not spec[:, 0] + assert spec.unbind(1)[0].example_data == "example_data" @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) def test_onehot( @@ -2001,8 +2004,9 @@ def test_multionehot(self, shape1, device): assert spec.to(device).device == device def test_non_tensor(self, device): - spec = NonTensor(shape=(3, 4), device="cpu") + spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data") assert spec.to(device).device == device + assert spec.to(device).example_data == "example_data" @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) def test_onehot(self, shape1, device): @@ -2262,13 +2266,14 @@ def test_stack_multionehot_zero(self, shape, stack_dim): assert r.shape == c.shape def test_stack_non_tensor(self, shape, stack_dim): - spec0 = NonTensor(shape=shape, device="cpu") - spec1 = NonTensor(shape=shape, device="cpu") + spec0 = NonTensor(shape=shape, device="cpu", example_data="example_data") + spec1 = NonTensor(shape=shape, device="cpu", example_data="example_data") new_spec = torch.stack([spec0, spec1], stack_dim) shape_insert = list(shape) shape_insert.insert(stack_dim, 2) assert new_spec.shape == torch.Size(shape_insert) assert new_spec.device == torch.device("cpu") + assert new_spec.example_data == "example_data" def test_stack_onehot(self, shape, stack_dim): n = 5 @@ -3642,10 +3647,18 @@ def test_expand(self): class TestNonTensorSpec: def test_sample(self): - nts = NonTensor(shape=(3, 4)) + nts = NonTensor(shape=(3, 4), example_data="example_data") assert nts.one((2,)).shape == (2, 3, 4) assert nts.rand((2,)).shape == (2, 3, 4) assert nts.zero((2,)).shape == (2, 3, 4) + assert nts.one((2,)).data == "example_data" + assert nts.rand((2,)).data == "example_data" + assert nts.zero((2,)).data == "example_data" + + def test_example_data_ineq(self): + nts0 = NonTensor(shape=(3, 4), example_data="example_data") + nts1 = NonTensor(shape=(3, 4), example_data="example_data 2") + assert nts0 != nts1 @pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device") diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index bb8bebf2db8..3d4198ae234 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -2452,6 +2452,8 @@ class NonTensor(TensorSpec): (same will go for :meth:`.zero` and :meth:`.one`). """ + example_data: Any = None + def __init__( self, shape: Union[torch.Size, int] = _DEFAULT_SHAPE, @@ -2470,6 +2472,11 @@ def __init__( ) self.example_data = example_data + def __eq__(self, other): + eq = super().__eq__(other) + eq = eq & (self.example_data == getattr(other, "example_data", None)) + return eq + def cardinality(self) -> Any: raise RuntimeError("Cannot enumerate a NonTensorSpec.") @@ -2555,6 +2562,16 @@ def expand(self, *shape): shape=shape, device=self.device, dtype=None, example_data=self.example_data ) + def unsqueeze(self, dim: int) -> NonTensor: + unsq = super().unsqueeze(dim=dim) + unsq.example_data = self.example_data + return unsq + + def squeeze(self, dim: int | None = None) -> NonTensor: + sq = super().squeeze(dim=dim) + sq.example_data = self.example_data + return sq + def _reshape(self, shape): return self.__class__( shape=shape,