diff --git a/test/test_cost.py b/test/test_cost.py
index 8bbb7edce05..c8e45624580 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -40,6 +40,7 @@
 from tensordict.nn.utils import Buffer
 from tensordict.utils import unravel_key
 from torch import autograd, nn
+from torchrl._utils import _standardize
 from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
 from torchrl.data.postprocs.postprocs import MultiStep
 from torchrl.envs.model_based.dreamer import DreamerEnv
@@ -16044,6 +16045,18 @@ def _composite_log_prob(self):
         yield
         setter.unset()
 
+    def test_standardization(self):
+        t = torch.arange(3 * 4 * 5 * 6, dtype=torch.float32).view(3, 4, 5, 6)
+        std_t0 = _standardize(t, exclude_dims=(1, 3))
+        std_t1 = (t - t.mean((0, 2), keepdim=True)) / t.std((0, 2), keepdim=True).clamp(
+            1e-6
+        )
+        torch.testing.assert_close(std_t0, std_t1)
+        std_t = _standardize(t, (), -1, 2)
+        torch.testing.assert_close(std_t, (t + 1) / 2)
+        std_t = _standardize(t, ())
+        torch.testing.assert_close(std_t, (t - t.mean()) / t.std())
+
 @pytest.mark.parametrize("B", [None, (1, ), (4, ), (2, 2, ), (1, 2, 8, )])  # fmt: skip
 @pytest.mark.parametrize("T", [1, 10])
 @pytest.mark.parametrize("device", get_default_devices())
diff --git a/torchrl/_utils.py b/torchrl/_utils.py
index 6a2f80aeffb..f999fa96c1d 100644
--- a/torchrl/_utils.py
+++ b/torchrl/_utils.py
@@ -24,7 +24,7 @@
 from distutils.util import strtobool
 from functools import wraps
 from importlib import import_module
-from typing import Any, Callable, cast, Dict, TypeVar, Union
+from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union
 
 import numpy as np
 import torch
@@ -32,7 +32,7 @@
 from tensordict import unravel_key
 from tensordict.utils import NestedKey
 
-from torch import multiprocessing as mp
+from torch import multiprocessing as mp, Tensor
 
 try:
     from torch.compiler import is_compiling
@@ -872,6 +872,70 @@ def set_mode(self, type: Any | None) -> None:
         self._mode = type
 
 
+def _standardize(
+    input: Tensor,
+    exclude_dims: Tuple[int, ...] = (),
+    mean: Tensor | None = None,
+    std: Tensor | None = None,
+    eps: float | None = None,
+):
+    """Standardizes the input tensor, optionally excluding specific dims from the statistics.
+
+    Useful when processing multi-agent data to keep the agent dimensions independent.
+
+    Args:
+        input (Tensor): the input tensor to be standardized.
+        exclude_dims (Tuple[int, ...]): dimensions to exclude from the statistics; can be negative. Default: ().
+        mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
+        std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
+        eps (float): epsilon to be used for numerical stability. Default: float32 resolution for floating-point inputs, 1e-6 otherwise.
+
+    """
+    if eps is None:
+        if input.dtype.is_floating_point:
+            eps = torch.finfo(torch.float).resolution
+        else:
+            eps = 1e-6
+
+    len_exclude_dims = len(exclude_dims)
+    if not len_exclude_dims:
+        if mean is None:
+            mean = input.mean()
+        else:
+            # Assume dtypes are compatible
+            mean = torch.as_tensor(mean, device=input.device)
+        if std is None:
+            std = input.std()
+        else:
+            # Assume dtypes are compatible
+            std = torch.as_tensor(std, device=input.device)
+        return (input - mean) / std.clamp_min(eps)
+
+    input_shape = input.shape
+    exclude_dims = [
+        d if d >= 0 else d + len(input_shape) for d in exclude_dims
+    ]  # Make negative dims positive
+
+    if len(set(exclude_dims)) != len_exclude_dims:
+        raise ValueError("exclude_dims has repeating elements")
+    if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims):
+        raise ValueError(
+            f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}"
+        )
+    if len_exclude_dims == len(input_shape):
+        warnings.warn(
+            "_standardize called but all dims were excluded from the statistics, returning unprocessed input"
+        )
+        return input
+
+    included_dims = tuple(d for d in range(len(input_shape)) if d not in exclude_dims)
+    if mean is None:
+        mean = torch.mean(input, keepdim=True, dim=included_dims)
+    if std is None:
+        std = torch.std(input, keepdim=True, dim=included_dims)
+    return (input - mean) / std.clamp_min(eps)
+
+
 @wraps(torch.compile)
 def compile_with_warmup(*args, warmup: int = 1, **kwargs):
     """Compile a model with warm-up.
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 5b6763f6910..2a70f70a3e2 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -1217,6 +1217,10 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
     __doc__ += BatchedEnvBase.__doc__
     __doc__ += """
 
+    .. note:: ParallelEnv will time out if one of its workers stays idle for a set amount of time.
+        This can be controlled via the BATCHED_PIPE_TIMEOUT environment variable, which in turn modifies
+        the torchrl._utils.BATCHED_PIPE_TIMEOUT integer. The default timeout value is 10000 seconds.
+
     .. warning::
       TorchRL's ParallelEnv is quite stringent when it comes to env specs, since
      these are used to build shared memory buffers for inter-process communication.
@@ -1353,7 +1357,10 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
     """
 
     def _start_workers(self) -> None:
+        import torchrl
+
         self._timeout = 10.0
+        self.BATCHED_PIPE_TIMEOUT = torchrl._utils.BATCHED_PIPE_TIMEOUT
 
         from torchrl.envs.env_creator import EnvCreator
@@ -1606,7 +1613,7 @@ def step_and_maybe_reset(
         for i in workers_range:
             event = self._events[i]
-            event.wait(self._timeout)
+            event.wait(self.BATCHED_PIPE_TIMEOUT)
             event.clear()
 
         if self._non_tensor_keys:
@@ -1796,7 +1803,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         for i in workers_range:
             event = self._events[i]
-            event.wait(self._timeout)
+            event.wait(self.BATCHED_PIPE_TIMEOUT)
             event.clear()
 
         if self._non_tensor_keys:
@@ -1965,7 +1972,7 @@ def tentative_update(val, other):
         for i, _ in outs:
             event = self._events[i]
-            event.wait(self._timeout)
+            event.wait(self.BATCHED_PIPE_TIMEOUT)
             event.clear()
 
         workers_nontensor = []
@@ -2023,7 +2030,7 @@ def _shutdown_workers(self) -> None:
             for channel in self.parent_channels:
                 channel.close()
             for proc in self._workers:
-                proc.join(timeout=1.0)
+                proc.join(timeout=self._timeout)
         finally:
             for proc in self._workers:
                 if proc.is_alive():
diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py
index 4e8c2c3b87b..392b7291df9 100644
--- a/torchrl/modules/tensordict_module/probabilistic.py
+++ b/torchrl/modules/tensordict_module/probabilistic.py
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
 
 import warnings
 from typing import Dict, List, Optional, Type, Union
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index e54f71bba7d..7701c1a662f 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -257,6 +257,7 @@ def __post_init__(self):
             self.sample_log_prob = "action_log_prob"
 
     default_keys = _AcceptedKeys
+    tensor_keys: _AcceptedKeys
     default_value_estimator: ValueEstimators = ValueEstimators.GAE
 
     actor_network: TensorDictModule
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index 5a6893412a4..ab5a564abcf 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -128,6 +128,7 @@ class _AcceptedKeys:
 
         pass
 
+    tensor_keys: _AcceptedKeys
     _vmap_randomness = None
     default_value_estimator: ValueEstimators = None
diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py
index df6ebb18571..f10dcab1e0e 100644
--- a/torchrl/objectives/cql.py
+++ b/torchrl/objectives/cql.py
@@ -260,6 +260,7 @@ class _AcceptedKeys:
         done: NestedKey = "done"
         terminated: NestedKey = "terminated"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
 
@@ -1024,6 +1025,7 @@ class _AcceptedKeys:
         terminated: NestedKey = "terminated"
         pred_val: NestedKey = "pred_val"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
     out_keys = [
diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py
index 162bce09c0d..d7dc5976308 100644
--- a/torchrl/objectives/crossq.py
+++ b/torchrl/objectives/crossq.py
@@ -242,6 +242,7 @@ class _AcceptedKeys:
         terminated: NestedKey = "terminated"
         log_prob: NestedKey = "_log_prob"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py
index b9617d8d6b3..fde8df0a93a 100644
--- a/torchrl/objectives/ddpg.py
+++ b/torchrl/objectives/ddpg.py
@@ -173,6 +173,7 @@ class _AcceptedKeys:
         done: NestedKey = "done"
         terminated: NestedKey = "terminated"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator: ValueEstimators = ValueEstimators.TD0
     out_keys = [
diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py
index e7df82f21a9..acb7156ad6e 100644
--- a/torchrl/objectives/decision_transformer.py
+++ b/torchrl/objectives/decision_transformer.py
@@ -70,6 +70,7 @@ class _AcceptedKeys:
         # the "action" output from the model
         action_pred: NestedKey = "action"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
 
     actor_network: TensorDictModule
@@ -280,6 +281,7 @@ class _AcceptedKeys:
         # the "action" output from the model
         action_pred: NestedKey = "action"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
 
     actor_network: TensorDictModule
diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py
index 79e5b8de58c..47fc0508397 100644
--- a/torchrl/objectives/dqn.py
+++ b/torchrl/objectives/dqn.py
@@ -164,6 +164,7 @@ class _AcceptedKeys:
         done: NestedKey = "done"
         terminated: NestedKey = "terminated"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
     out_keys = ["loss"]
@@ -435,6 +436,7 @@ class _AcceptedKeys:
         terminated: NestedKey = "terminated"
         steps_to_next_obs: NestedKey = "steps_to_next_obs"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py
index b183e94ddd1..6c03461ce76 100644
--- a/torchrl/objectives/dreamer.py
+++ b/torchrl/objectives/dreamer.py
@@ -89,8 +89,13 @@ class _AcceptedKeys:
         pixels: NestedKey = "pixels"
         reco_pixels: NestedKey = "reco_pixels"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
 
+    decoder: TensorDictModule
+    reward_model: TensorDictModule
+    world_model: TensorDictModule
+
     def __init__(
         self,
         world_model: TensorDictModule,
@@ -238,9 +243,13 @@ class _AcceptedKeys:
         done: NestedKey = "done"
         terminated: NestedKey = "terminated"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TDLambda
 
+    value_model: TensorDictModule
+    actor_model: TensorDictModule
+
     def __init__(
         self,
         actor_model: TensorDictModule,
@@ -392,8 +401,11 @@ class _AcceptedKeys:
 
         value: NestedKey = "state_value"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
 
+    value_model: TensorDictModule
+
     def __init__(
         self,
         value_model: TensorDictModule,
diff --git a/torchrl/objectives/gail.py b/torchrl/objectives/gail.py
index dc26be70c54..bece855ad62 100644
--- a/torchrl/objectives/gail.py
+++ b/torchrl/objectives/gail.py
@@ -59,6 +59,7 @@ class _AcceptedKeys:
         collector_observation: NestedKey = "collector_observation"
         discriminator_pred: NestedKey = "d_logits"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
 
     discriminator_network: TensorDictModule
diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py
index 774949407c1..ea84d59939a 100644
--- a/torchrl/objectives/iql.py
+++ b/torchrl/objectives/iql.py
@@ -233,6 +233,7 @@ class _AcceptedKeys:
         done: NestedKey = "done"
         terminated: NestedKey = "terminated"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
     out_keys = [
@@ -709,6 +710,7 @@ class _AcceptedKeys:
         done: NestedKey = "done"
         terminated: NestedKey = "terminated"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
     out_keys = [
diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py
index fa9c0ad97fb..93b50f3e76b 100644
--- a/torchrl/objectives/multiagent/qmixer.py
+++ b/torchrl/objectives/multiagent/qmixer.py
@@ -179,6 +179,7 @@ class _AcceptedKeys:
         done: NestedKey = "done"
         terminated: NestedKey = "terminated"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
     out_keys = ["loss"]
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index b5b45393a06..9b41afd9afa 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -30,6 +30,7 @@
 from tensordict.utils import NestedKey
 from torch import distributions as d
 
+from torchrl._utils import _standardize
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
@@ -95,6 +96,9 @@ class PPOLoss(LossModule):
            Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
        normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
            before being used. Defaults to ``False``.
+        normalize_advantage_exclude_dims (Tuple[int, ...], optional): dimensions to exclude from the advantage
+            standardization. Negative dimensions are valid. This is useful in multi-agent (or multi-objective) settings
+            where the agent (or objective) dimension may be excluded from the reductions. Default: ().
        separate_losses (bool, optional): if ``True``, shared parameters between
            policy and critic will only be trained on the policy loss.
            Defaults to ``False``, i.e., gradients are propagated to shared
@@ -308,6 +312,7 @@ def __post_init__(self):
             self.sample_log_prob = "action_log_prob"
 
     default_keys = _AcceptedKeys
+    tensor_keys: _AcceptedKeys
     default_value_estimator = ValueEstimators.GAE
 
     actor_network: ProbabilisticTensorDictModule
@@ -328,6 +333,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
+        normalize_advantage_exclude_dims: Tuple[int, ...] = (),
         gamma: float = None,
         separate_losses: bool = False,
         advantage_key: str = None,
@@ -398,6 +404,8 @@ def __init__(
             self.critic_coef = None
         self.loss_critic_type = loss_critic_type
         self.normalize_advantage = normalize_advantage
+        self.normalize_advantage_exclude_dims = normalize_advantage_exclude_dims
+
         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
         self._set_deprecated_ctor_keys(
@@ -656,9 +664,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             )
         advantage = tensordict.get(self.tensor_keys.advantage)
         if self.normalize_advantage and advantage.numel() > 1:
-            loc = advantage.mean()
-            scale = advantage.std().clamp_min(1e-6)
-            advantage = (advantage - loc) / scale
+            if advantage.numel() > tensordict.batch_size.numel() and not len(
+                self.normalize_advantage_exclude_dims
+            ):
+                warnings.warn(
+                    "You requested advantage normalization and the advantage key has more dimensions"
+                    " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
+                    "if you want to keep any dimension independent while computing normalization statistics. "
+                    "If you are working in multi-agent/multi-objective settings this is strongly recommended."
+                )
+            advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
 
         log_weight, dist, kl_approx = self._log_weight(tensordict)
         if is_tensor_collection(log_weight):
@@ -761,6 +776,9 @@ class ClipPPOLoss(PPOLoss):
            Can be one of "l1", "l2" or "smooth_l1".
            Defaults to ``"smooth_l1"``.
        normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
            before being used. Defaults to ``False``.
+        normalize_advantage_exclude_dims (Tuple[int, ...], optional): dimensions to exclude from the advantage
+            standardization. Negative dimensions are valid. This is useful in multi-agent (or multi-objective) settings
+            where the agent (or objective) dimension may be excluded from the reductions. Default: ().
        separate_losses (bool, optional): if ``True``, shared parameters between
            policy and critic will only be trained on the policy loss.
            Defaults to ``False``, i.e., gradients are propagated to shared
@@ -852,6 +870,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
+        normalize_advantage_exclude_dims: Tuple[int, ...] = (),
         gamma: float = None,
         separate_losses: bool = False,
         reduction: str = None,
@@ -871,6 +890,7 @@ def __init__(
             critic_coef=critic_coef,
             loss_critic_type=loss_critic_type,
             normalize_advantage=normalize_advantage,
+            normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
             gamma=gamma,
             separate_losses=separate_losses,
             reduction=reduction,
@@ -921,9 +941,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             )
         advantage = tensordict.get(self.tensor_keys.advantage)
         if self.normalize_advantage and advantage.numel() > 1:
-            loc = advantage.mean()
-            scale = advantage.std().clamp_min(1e-6)
-            advantage = (advantage - loc) / scale
+            if advantage.numel() > tensordict.batch_size.numel() and not len(
+                self.normalize_advantage_exclude_dims
+            ):
+                warnings.warn(
+                    "You requested advantage normalization and the advantage key has more dimensions"
+                    " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
+                    "if you want to keep any dimension independent while computing normalization statistics. "
+                    "If you are working in multi-agent/multi-objective settings this is strongly recommended."
+                )
+            advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
 
         log_weight, dist, kl_approx = self._log_weight(tensordict)
         # ESS for logging
@@ -1007,6 +1034,9 @@ class KLPENPPOLoss(PPOLoss):
            Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
        normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
            before being used. Defaults to ``False``.
+        normalize_advantage_exclude_dims (Tuple[int, ...], optional): dimensions to exclude from the advantage
+            standardization. Negative dimensions are valid. This is useful in multi-agent (or multi-objective) settings
+            where the agent (or objective) dimension may be excluded from the reductions. Default: ().
        separate_losses (bool, optional): if ``True``, shared parameters between
            policy and critic will only be trained on the policy loss.
            Defaults to ``False``, i.e., gradients are propagated to shared
@@ -1100,6 +1130,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
+        normalize_advantage_exclude_dims: Tuple[int, ...] = (),
         gamma: float = None,
         separate_losses: bool = False,
         reduction: str = None,
@@ -1115,6 +1146,7 @@ def __init__(
             critic_coef=critic_coef,
             loss_critic_type=loss_critic_type,
             normalize_advantage=normalize_advantage,
+            normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
             gamma=gamma,
             separate_losses=separate_losses,
             reduction=reduction,
@@ -1203,9 +1235,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             )
         advantage = tensordict_copy.get(self.tensor_keys.advantage)
         if self.normalize_advantage and advantage.numel() > 1:
-            loc = advantage.mean()
-            scale = advantage.std().clamp_min(1e-6)
-            advantage = (advantage - loc) / scale
+            if advantage.numel() > tensordict.batch_size.numel() and not len(
+                self.normalize_advantage_exclude_dims
+            ):
+                warnings.warn(
+                    "You requested advantage normalization and the advantage key has more dimensions"
+                    " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
+                    "if you want to keep any dimension independent while computing normalization statistics. "
+                    "If you are working in multi-agent/multi-objective settings this is strongly recommended."
+                )
+            advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
+
         log_weight, dist, kl_approx = self._log_weight(tensordict_copy)
         neg_loss = log_weight.exp() * advantage
         if is_tensor_collection(neg_loss):
diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py
index 3a9a9bac791..0cac502b347 100644
--- a/torchrl/objectives/reinforce.py
+++ b/torchrl/objectives/reinforce.py
@@ -224,6 +224,7 @@ def __post_init__(self):
         else:
             self.sample_log_prob = "action_log_prob"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.GAE
     out_keys = ["loss_actor", "loss_value"]
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index 9b5d9b35d64..9d60d51334a 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -301,6 +301,7 @@ def __post_init__(self):
             self.log_prob = "action_log_prob"
 
     default_keys = _AcceptedKeys
+    tensor_keys: _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
 
     actor_network: TensorDictModule
@@ -1046,6 +1047,7 @@ class _AcceptedKeys:
         terminated: NestedKey = "terminated"
         log_prob: NestedKey = "log_prob"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
     delay_actor: bool = False
diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py
index 0fcfa5e8351..f5d67eea164 100644
--- a/torchrl/objectives/td3.py
+++ b/torchrl/objectives/td3.py
@@ -204,6 +204,7 @@ class _AcceptedKeys:
         done: NestedKey = "done"
         terminated: NestedKey = "terminated"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
     out_keys = [
diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py
index a217240d96f..deb5b844500 100644
--- a/torchrl/objectives/td3_bc.py
+++ b/torchrl/objectives/td3_bc.py
@@ -217,6 +217,7 @@ class _AcceptedKeys:
         done: NestedKey = "done"
         terminated: NestedKey = "terminated"
 
+    tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
     out_keys = [
diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index dbb269e2e8b..dd253a6d908 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -152,6 +152,7 @@ def __post_init__(self):
             self.sample_log_prob = "action_log_prob"
 
     default_keys = _AcceptedKeys
+    tensor_keys: _AcceptedKeys
     value_network: Union[TensorDictModule, Callable]
     _vmap_randomness = None
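
Example usage of the new helper (a minimal sketch, assuming the patch above is applied; the advantage shape and per-agent scales below are invented for illustration, and only `_standardize` and `normalize_advantage_exclude_dims` come from the diff):

import torch

from torchrl._utils import _standardize  # private helper introduced by this patch

# Toy multi-agent advantage tensor: [batch=8, time=16, n_agents=3], with very different per-agent scales.
adv = torch.randn(8, 16, 3) * torch.tensor([1.0, 10.0, 100.0])

# Exclude the agent dim (-1) from the statistics so each agent is normalized
# with its own mean/std -- the multi-agent use case motivating this change.
adv_norm = _standardize(adv, exclude_dims=(-1,))

# Equivalent manual computation over the included dims (0, 1).
mean = adv.mean(dim=(0, 1), keepdim=True)
std = adv.std(dim=(0, 1), keepdim=True).clamp_min(torch.finfo(torch.float).resolution)
torch.testing.assert_close(adv_norm, (adv - mean) / std)

The same behavior is reachable from the PPO losses by passing `normalize_advantage=True` together with `normalize_advantage_exclude_dims=(-1,)`.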