Skip to content

Commit

Permalink
[BugFix] Fix device transfer for collectors with init_random_frames m…
Browse files Browse the repository at this point in the history
…ixed devices

ghstack-source-id: 1684399a7c84dd19b396db6c903fbf68c971c73d
Pull Request resolved: #2704
  • Loading branch information
vmoens committed Jan 20, 2025
1 parent afb81de commit 1d45117
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,6 +1148,21 @@ def rollout(self) -> TensorDictBase:
and self._frames < self.init_random_frames
):
self.env.rand_action(self._shuttle)
if (
self.policy_device is not None
and self.policy_device != self.env_device
):
# TODO: This may break with exclusive / ragged lazy stacks
self._shuttle.apply(
lambda name, val: val.to(
device=self.policy_device, non_blocking=True
)
if name in self._policy_output_keys
else val,
out=self._shuttle,
named=True,
nested_keys=True,
)
else:
if self._cast_to_policy_device:
if self.policy_device is not None:
Expand Down

0 comments on commit 1d45117

Please sign in to comment.