[Feature] example_data for NonTensor spec #2698

Open · wants to merge 2 commits into base: gh/vmoens/67/base
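In short, this PR gives `NonTensor` an `example_data` field that is stamped onto the `NonTensorData` instances returned by `rand`, `zero` and `one`, and that survives spec transformations such as `expand` and `clone`. A minimal usage sketch based on the tests below, assuming `NonTensor` is importable from `torchrl.data` and using an illustrative string payload:

from torchrl.data import NonTensor

spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")

# Sampling methods now return NonTensorData carrying the example payload
# instead of data=None.
sample = spec.rand((2,))
assert sample.shape == (2, 3, 4)
assert sample.data == "example_data"

# The payload is propagated through spec transformations.
assert spec.expand(2, 3, 4).example_data == "example_data"
assert spec.clone().example_data == "example_data"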
test/test_specs.py: 29 changes (21 additions, 8 deletions)
@@ -1402,12 +1402,13 @@ def test_multionehot(self, shape1, shape2):
         assert spec2.zero().shape == spec2.shape

     def test_non_tensor(self):
-        spec = NonTensor((3, 4), device="cpu")
+        spec = NonTensor((3, 4), device="cpu", example_data="example_data")
         assert (
             spec.expand(2, 3, 4)
             == spec.expand((2, 3, 4))
-            == NonTensor((2, 3, 4), device="cpu")
+            == NonTensor((2, 3, 4), device="cpu", example_data="example_data")
         )
+        assert spec.expand(2, 3, 4).example_data == "example_data"

     @pytest.mark.parametrize("shape1", [None, (), (5,)])
     @pytest.mark.parametrize("shape2", [(), (10,)])
@@ -1607,9 +1608,10 @@ def test_multionehot(
         assert spec is not spec.clone()

     def test_non_tensor(self):
-        spec = NonTensor(shape=(3, 4), device="cpu")
+        spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
         assert spec.clone() == spec
         assert spec.clone() is not spec
+        assert spec.clone().example_data == "example_data"

     @pytest.mark.parametrize("shape1", [None, (), (5,)])
     def test_onehot(
@@ -1840,9 +1842,10 @@ def test_multionehot(
         spec.unbind(-1)

     def test_non_tensor(self):
-        spec = NonTensor(shape=(3, 4), device="cpu")
+        spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
         assert spec.unbind(1)[0] == spec[:, 0]
         assert spec.unbind(1)[0] is not spec[:, 0]
+        assert spec.unbind(1)[0].example_data == "example_data"

     @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
     def test_onehot(
@@ -2001,8 +2004,9 @@ def test_multionehot(self, shape1, device):
         assert spec.to(device).device == device

     def test_non_tensor(self, device):
-        spec = NonTensor(shape=(3, 4), device="cpu")
+        spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
         assert spec.to(device).device == device
+        assert spec.to(device).example_data == "example_data"

     @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
     def test_onehot(self, shape1, device):
@@ -2262,13 +2266,14 @@ def test_stack_multionehot_zero(self, shape, stack_dim):
         assert r.shape == c.shape

     def test_stack_non_tensor(self, shape, stack_dim):
-        spec0 = NonTensor(shape=shape, device="cpu")
-        spec1 = NonTensor(shape=shape, device="cpu")
+        spec0 = NonTensor(shape=shape, device="cpu", example_data="example_data")
+        spec1 = NonTensor(shape=shape, device="cpu", example_data="example_data")
         new_spec = torch.stack([spec0, spec1], stack_dim)
         shape_insert = list(shape)
         shape_insert.insert(stack_dim, 2)
         assert new_spec.shape == torch.Size(shape_insert)
         assert new_spec.device == torch.device("cpu")
+        assert new_spec.example_data == "example_data"

     def test_stack_onehot(self, shape, stack_dim):
         n = 5
@@ -3642,10 +3647,18 @@ def test_expand(self):

 class TestNonTensorSpec:
     def test_sample(self):
-        nts = NonTensor(shape=(3, 4))
+        nts = NonTensor(shape=(3, 4), example_data="example_data")
         assert nts.one((2,)).shape == (2, 3, 4)
         assert nts.rand((2,)).shape == (2, 3, 4)
         assert nts.zero((2,)).shape == (2, 3, 4)
+        assert nts.one((2,)).data == "example_data"
+        assert nts.rand((2,)).data == "example_data"
+        assert nts.zero((2,)).data == "example_data"
+
+    def test_example_data_ineq(self):
+        nts0 = NonTensor(shape=(3, 4), example_data="example_data")
+        nts1 = NonTensor(shape=(3, 4), example_data="example_data 2")
+        assert nts0 != nts1


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device")
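One behavioral consequence of the `__eq__` change below: `example_data` participates in spec equality, so two otherwise identical specs with different payloads no longer compare equal. A small sketch mirroring `test_example_data_ineq` (import path assumed as above):

from torchrl.data import NonTensor

nts0 = NonTensor(shape=(3, 4), example_data="example_data")
nts1 = NonTensor(shape=(3, 4), example_data="example_data 2")

# Same shape, device and dtype, but different example_data:
# the specs are considered unequal.
assert nts0 != nts1
assert nts0 == NonTensor(shape=(3, 4), example_data="example_data")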
torchrl/data/tensor_specs.py: 65 changes (57 additions, 8 deletions)
@@ -2452,11 +2452,14 @@ class NonTensor(TensorSpec):
         (same will go for :meth:`.zero` and :meth:`.one`).
     """

+    example_data: Any = None
+
     def __init__(
         self,
         shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
         device: Optional[DEVICE_TYPING] = None,
         dtype: torch.dtype | None = None,
+        example_data: Any = None,
         **kwargs,
     ):
         if isinstance(shape, int):
@@ -2467,6 +2470,12 @@ def __init__(
         super().__init__(
             shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
         )
+        self.example_data = example_data
+
+    def __eq__(self, other):
+        eq = super().__eq__(other)
+        eq = eq & (self.example_data == getattr(other, "example_data", None))
+        return eq

     def cardinality(self) -> Any:
         raise RuntimeError("Cannot enumerate a NonTensorSpec.")
@@ -2485,30 +2494,46 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
             dest_device = torch.device(dest)
         if dest_device == self.device and dest_dtype == self.dtype:
             return self
-        return self.__class__(shape=self.shape, device=dest_device, dtype=None)
+        return self.__class__(
+            shape=self.shape,
+            device=dest_device,
+            dtype=None,
+            example_data=self.example_data,
+        )

     def clone(self) -> NonTensor:
-        return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)
+        return self.__class__(
+            shape=self.shape,
+            device=self.device,
+            dtype=self.dtype,
+            example_data=self.example_data,
+        )

     def rand(self, shape=None):
         if shape is None:
             shape = ()
         return NonTensorData(
-            data=None, batch_size=(*shape, *self._safe_shape), device=self.device
+            data=self.example_data,
+            batch_size=(*shape, *self._safe_shape),
+            device=self.device,
         )

     def zero(self, shape=None):
         if shape is None:
             shape = ()
         return NonTensorData(
-            data=None, batch_size=(*shape, *self._safe_shape), device=self.device
+            data=self.example_data,
+            batch_size=(*shape, *self._safe_shape),
+            device=self.device,
         )

     def one(self, shape=None):
         if shape is None:
             shape = ()
         return NonTensorData(
-            data=None, batch_size=(*shape, *self._safe_shape), device=self.device
+            data=self.example_data,
+            batch_size=(*shape, *self._safe_shape),
+            device=self.device,
         )

     def is_in(self, val: Any) -> bool:
@@ -2533,23 +2558,46 @@ def expand(self, *shape):
             raise ValueError(
                 f"The last elements of the expanded shape must match the current one. Got shape={shape} while self.shape={self.shape}."
             )
-        return self.__class__(shape=shape, device=self.device, dtype=None)
+        return self.__class__(
+            shape=shape, device=self.device, dtype=None, example_data=self.example_data
+        )
+
+    def unsqueeze(self, dim: int) -> NonTensor:
+        unsq = super().unsqueeze(dim=dim)
+        unsq.example_data = self.example_data
+        return unsq
+
+    def squeeze(self, dim: int | None = None) -> NonTensor:
+        sq = super().squeeze(dim=dim)
+        sq.example_data = self.example_data
+        return sq

     def _reshape(self, shape):
-        return self.__class__(shape=shape, device=self.device, dtype=self.dtype)
+        return self.__class__(
+            shape=shape,
+            device=self.device,
+            dtype=self.dtype,
+            example_data=self.example_data,
+        )
+
+    def _unflatten(self, dim, sizes):
+        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
+        return self.__class__(
+            shape=shape,
+            device=self.device,
+            dtype=self.dtype,
+            example_data=self.example_data,
+        )

     def __getitem__(self, idx: SHAPE_INDEX_TYPING):
         """Indexes the current TensorSpec based on the provided index."""
         indexed_shape = _size(_shape_indexing(self.shape, idx))
-        return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)
+        return self.__class__(
+            shape=indexed_shape,
+            device=self.device,
+            dtype=self.dtype,
+            example_data=self.example_data,
+        )

     def unbind(self, dim: int = 0):
         orig_dim = dim
@@ -2565,6 +2613,7 @@ def unbind(self, dim: int = 0):
                 shape=shape,
                 device=self.device,
                 dtype=self.dtype,
+                example_data=self.example_data,
             )
             for i in range(self.shape[dim])
         )
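Taken together, the changes thread `example_data` through every path that constructs a new `NonTensor` (`to`, `clone`, `expand`, `unsqueeze`, `squeeze`, `_reshape`, `_unflatten`, `__getitem__`, `unbind`) as well as the sampling methods. A short propagation sketch based on the tests above, with the same assumed import and illustrative payload:

import torch
from torchrl.data import NonTensor

spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")

# New specs produced by transformations keep the payload.
assert spec.to("cpu").example_data == "example_data"
assert spec[:, 0].example_data == "example_data"
assert spec.unbind(1)[0].example_data == "example_data"

# Stacking equal specs preserves it as well (see test_stack_non_tensor).
stacked = torch.stack([spec, spec.clone()], 0)
assert stacked.example_data == "example_data"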