[Feature] Re-enable cache for specs
ghstack-source-id: 6a005e3e6e5a16a7d17a7b6977709fff5d43d4d0
Pull Request resolved: #2730
vmoens committed Jan 29, 2025
1 parent d44757d commit 2fe8bf4
Showing 8 changed files with 374 additions and 212 deletions.
22 changes: 19 additions & 3 deletions docs/source/reference/envs.rst
@@ -41,19 +41,35 @@ Each env will have the following attributes:
the done-flag spec. See the section on trajectory termination below.
- :obj:`env.input_spec`: a :class:`~torchrl.data.Composite` object containing
all the input keys (:obj:`"full_action_spec"` and :obj:`"full_state_spec"`).
It is locked and should not be modified directly.
- :obj:`env.output_spec`: a :class:`~torchrl.data.Composite` object containing
all the output keys (:obj:`"full_observation_spec"`, :obj:`"full_reward_spec"` and :obj:`"full_done_spec"`).
It is locked and should not be modified directly.

If the environment carries non-tensor data, a :class:`~torchrl.data.NonTensorSpec`
If the environment carries non-tensor data, a :class:`~torchrl.data.NonTensor`
instance can be used.
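For illustration only (not part of this diff), a minimal sketch of declaring a non-tensor entry in a spec; the exact ``NonTensor`` constructor arguments used here are an assumption::

    from torchrl.data import Composite, NonTensor

    # Hypothetical observation spec carrying a non-tensor "metadata" entry.
    obs_spec = Composite(
        metadata=NonTensor(shape=()),
        shape=(),
    )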

Env specs: locks and batch size
-------------------------------

.. _Environment-lock:

Environment specs are locked by default (through a ``spec_locked`` argument passed to the env constructor).
Locking specs means that any modification of the spec (or of its children if it is a :class:`~torchrl.data.Composite`
instance) requires unlocking it first. This can be done via the :meth:`~torchrl.envs.EnvBase.set_spec_lock_` method.
Specs are locked by default because this makes it easy to cache values such as action or reset keys and the like.
An env should only be unlocked if its specs are expected to be modified often (which, in principle, should be avoided).
Direct assignments such as ``env.observation_spec = new_spec`` remain allowed: under the hood, TorchRL will erase
the cache, unlock the specs, make the modification and relock the specs if the env was previously locked.
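For illustration only (not part of this diff), a minimal sketch of this locking behaviour; ``GymEnv("Pendulum-v1")`` is just a placeholder environment and passing ``False`` to ``set_spec_lock_`` is an assumption::

    from torchrl.envs import GymEnv

    env = GymEnv("Pendulum-v1")
    assert env.is_spec_locked  # specs are locked by default

    # Direct assignment remains allowed: TorchRL clears the key cache,
    # unlocks the specs, applies the change and relocks them.
    env.observation_spec = env.observation_spec.clone()
    assert env.is_spec_locked

    # Explicit unlocking is possible but discouraged unless specs change often.
    env.set_spec_lock_(False)
    assert not env.is_spec_locked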

Importantly, the environment spec shapes should contain the batch size, e.g.
an environment with :obj:`env.batch_size == torch.Size([4])` should have
an :obj:`env.action_spec` with shape :obj:`torch.Size([4, action_size])`.
This is helpful when preallocating tensors, checking shape consistency, etc.
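For illustration only (not part of this diff), a sketch of how spec shapes track the batch size; ``GymEnv`` is again just a placeholder env constructor::

    import torch
    from torchrl.envs import GymEnv, SerialEnv

    env = SerialEnv(4, lambda: GymEnv("Pendulum-v1"))
    assert env.batch_size == torch.Size([4])
    # Every spec carries the batch size as its leading dimension.
    assert env.action_spec.shape[0] == 4
    assert env.full_observation_spec.shape == torch.Size([4])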

Env methods
-----------

With these, the following methods are implemented:

- :meth:`env.reset`: a reset method that may (but not necessarily requires to) take
3 changes: 2 additions & 1 deletion test/test_env.py
@@ -2294,8 +2294,9 @@ def test_multi_purpose_env(self, serial):
env = SerialEnv(2, ContinuousActionVecMockEnv)
else:
env = ContinuousActionVecMockEnv()
env.set_spec_lock_()
env.rollout(10)
assert env._step_mdp.validate(None)
assert env._step_mdp.validated
c = SyncDataCollector(
env, env.rand_action, frames_per_batch=10, total_frames=20
)
136 changes: 92 additions & 44 deletions test/test_transforms.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions torchrl/data/tensor_specs.py
@@ -4961,6 +4961,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Composite:
return self.__class__(**kwargs, device=_device, shape=self.shape)

def clone(self) -> Composite:
"""Clones the Composite spec.
Locked specs will not produce locked clones.
"""
try:
device = self.device
except RuntimeError:
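For illustration only (not part of this diff), a sketch of the ``clone()`` behaviour documented above; the ``lock_()``/``is_locked`` API on ``Composite`` is an assumption here::

    from torchrl.data import Composite, Unbounded

    spec = Composite(obs=Unbounded(shape=(3,)), shape=())
    spec.lock_()

    clone = spec.clone()
    assert spec.is_locked
    assert not clone.is_locked  # clones come back unlocked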
7 changes: 6 additions & 1 deletion torchrl/envs/batched_envs.py
@@ -371,6 +371,8 @@ def __init__(
)
self._mp_start_method = mp_start_method

is_spec_locked = EnvBase.is_spec_locked

@property
def non_blocking(self):
nb = self._non_blocking
@@ -933,8 +935,9 @@ def _start_workers(self) -> None:
"environments!"
)
weakref_set.add(wr)
self._envs.append(env)
self._envs.append(env.set_spec_lock_())
self.is_closed = False
self.set_spec_lock_()

@_check_start
def state_dict(self) -> OrderedDict:
@@ -1458,6 +1461,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
for channel in self.parent_channels:
channel.send(("init", None))
self.is_closed = False
self.set_spec_lock_()

@_check_start
def state_dict(self) -> OrderedDict:
@@ -2164,6 +2168,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
)
env = env_fun
del env_fun
env.set_spec_lock_()

i = -1
import torchrl
