diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 93f935a49fd..50d2762ad0b 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4426,10 +4426,12 @@ class UnaryTransform(Transform): Args: in_keys (sequence of NestedKey): the keys of inputs to the unary operation. out_keys (sequence of NestedKey): the keys of the outputs of the unary operation. - fn (Callable): the function to use as the unary operation. If it accepts - a non-tensor input, it must also accept ``None``. + in_keys_inv (sequence of NestedKey): the keys of inputs to the unary operation during inverse call. + out_keys_inv (sequence of NestedKey): the keys of the outputs of the unary operation durin inverse call. Keyword Args: + fn (Callable): the function to use as the unary operation. If it accepts + a non-tensor input, it must also accept ``None``. use_raw_nontensor (bool, optional): if ``False``, data is extracted from :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` @@ -4500,11 +4502,18 @@ def __init__( self, in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], - fn: Callable, + in_keys_inv: Sequence[NestedKey] | None = None, + out_keys_inv: Sequence[NestedKey] | None = None, *, + fn: Callable, use_raw_nontensor: bool = False, ): - super().__init__(in_keys=in_keys, out_keys=out_keys) + super().__init__( + in_keys=in_keys, + out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, + ) self._fn = fn self._use_raw_nontensor = use_raw_nontensor @@ -4519,6 +4528,17 @@ def _apply_transform(self, value): value = value.tolist() return self._fn(value) + def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: + if not self._use_raw_nontensor: + if isinstance(state, NonTensorData): + if state.dim() == 0: + state = state.get("data") + else: + state = state.tolist() + elif isinstance(state, NonTensorStack): + state = state.tolist() + return self._fn(state) + def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase ) -> TensorDictBase: @@ -4526,6 +4546,32 @@ def _reset( tensordict_reset = self._call(tensordict_reset) return tensordict_reset + def transform_input_spec(self, input_spec: Composite) -> Composite: + input_spec = input_spec.clone() + + # Make a generic input from the spec, call the transform with that + # input, and then generate the output spec from the output. + zero_input_ = input_spec.zero() + test_input = zero_input_["full_action_spec"].update( + zero_input_["full_state_spec"] + ) + test_output = self.inv(test_input) + test_input_spec = make_composite_from_td( + test_output, unsqueeze_null_shapes=False + ) + + input_spec["full_action_spec"] = self.transform_action_spec( + input_spec["full_action_spec"], + test_input_spec, + ) + if "full_state_spec" in input_spec.keys(): + input_spec["full_state_spec"] = self.transform_state_spec( + input_spec["full_state_spec"], + test_input_spec, + ) + print(input_spec) + return input_spec + def transform_output_spec(self, output_spec: Composite) -> Composite: output_spec = output_spec.clone() @@ -4586,6 +4632,16 @@ def transform_done_spec( ) -> TensorSpec: return self._transform_spec(done_spec, test_output_spec) + def transform_action_spec( + self, action_spec: TensorSpec, test_input_spec: TensorSpec + ) -> TensorSpec: + return self._transform_spec(action_spec, test_input_spec) + + def transform_state_spec( + self, state_spec: TensorSpec, test_input_spec: TensorSpec + ) -> TensorSpec: + return self._transform_spec(state_spec, test_input_spec) + class Hash(UnaryTransform): r"""Adds a hash value to a tensordict. @@ -4593,12 +4649,14 @@ class Hash(UnaryTransform): Args: in_keys (sequence of NestedKey): the keys of the values to hash. out_keys (sequence of NestedKey): the keys of the resulting hashes. + in_keys_inv (sequence of NestedKey): the keys of the values to hash during inv call. + out_keys_inv (sequence of NestedKey): the keys of the resulting hashes during inv call. + + Keyword Args: hash_fn (Callable, optional): the hash function to use. If ``seed`` is given, the hash function must accept it as its second argument. Default is ``Hash.reproducible_hash``. seed (optional): seed to use for the hash function, if it requires one. - - Keyword Args: use_raw_nontensor (bool, optional): if ``False``, data is extracted from :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` @@ -4684,9 +4742,11 @@ def __init__( self, in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], + in_keys_inv: Sequence[NestedKey] | None = None, + out_keys_inv: Sequence[NestedKey] | None = None, + *, hash_fn: Callable = None, seed: Any | None = None, - *, use_raw_nontensor: bool = False, ): if hash_fn is None: @@ -4697,6 +4757,8 @@ def __init__( super().__init__( in_keys=in_keys, out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, fn=self.call_hash_fn, use_raw_nontensor=use_raw_nontensor, ) @@ -4725,7 +4787,7 @@ def reproducible_hash(cls, string, seed=None): if seed is not None: seeded_string = seed + string else: - seeded_string = string + seeded_string = str(string) # Create a new SHA-256 hash object hash_object = hashlib.sha256()