Commit

Update

[ghstack-poisoned]
vmoens committed Jan 20, 2025
1 parent 1e30884 commit 6f98ca3
Showing 1 changed file with 215 additions and 0 deletions.
215 changes: 215 additions & 0 deletions examples/collectors/collector_device.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Using the SyncDataCollector with Different Device Combinations
==============================================================
TorchRL's SyncDataCollector allows you to specify the devices on which different components of the data collection
process are executed. This example demonstrates how to use the collector with various device combinations.

Understanding Device Precedence
-------------------------------
When creating a SyncDataCollector, you can specify the devices for the environment (env_device), the policy
(policy_device), and the data collection (device). The device argument serves as a default value for any unspecified
devices. However, if you provide env_device or policy_device, they take precedence over the device argument for their
respective components.

For example:

- If you set device="cuda", all components will be executed on the CUDA device unless you specify otherwise.
- If you set env_device="cpu" and device="cuda", the environment will be executed on the CPU, while the policy and
  data collection will be executed on the CUDA device.
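
As a minimal sketch (reusing names defined in the script below), a collector that runs the policy on CUDA while
leaving the environment on the CPU could be created like this::

    collector = SyncDataCollector(
        make_env(ENV_NAME),
        agent_explore,
        frames_per_batch=FRAMES_PER_BATCH,
        init_random_frames=INIT_RND_STEPS,
        env_device="cpu",      # the env and its data stay on the CPU
        policy_device="cuda",  # the policy expects and returns CUDA tensors
    )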

Keeping Policy Parameters in Sync
---------------------------------
When a policy has buffers or other attributes that are not automatically updated when its parameters are moved to a
different device, it is essential to keep the policy's parameters in sync between the main workspace and the
collector.

To do this, call update_policy_weights_() whenever the policy's parameters (and buffers!) are updated. This ensures
that the policy used by the collector has the same parameters as the policy in the main workspace.
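
In practice, this means calling collector.update_policy_weights_() right after the optimization steps. A minimal
sketch, reusing names defined in the script below::

    for data in collector:
        exp_buffer.extend(data)
        if len(exp_buffer) > INIT_RND_STEPS:
            for _ in range(OPTIM_STEPS):
                optimizer.zero_grad()
                sample = exp_buffer.sample(batch_size=BATCH_SIZE)
                loss_vals = loss(sample)
                loss_vals["loss"].backward()
                optimizer.step()
            target_updater.step()
        # Push the freshly updated parameters (and buffers) to the collector's copy of the policy
        collector.update_policy_weights_()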

Example Use Cases
-----------------
This script demonstrates the SyncDataCollector with the following device combinations:

- Collector on CUDA
- Collector on CPU
- Mixed collector: policy on CUDA, env untouched (i.e., unmarked CPU, env.device is None)
- Mixed collector: policy on CUDA, env on CPU (env.device == "cpu")
- Mixed collector: everything on CUDA, except the env, which runs on CPU

For each configuration, we run a DQN algorithm and check that it converges.
By following this example, you can learn how to use the SyncDataCollector with different device combinations and
ensure that your policy's parameters are kept in sync.
"""

import logging
import time

import torch.cuda
import torch.nn as nn
import torch.optim as optim

from tensordict.nn import TensorDictSequential as TDSeq

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import Compose, GymEnv, RewardSum, StepCounter, TransformedEnv
from torchrl.modules import EGreedyModule, QValueActor
from torchrl.objectives import DQNLoss, SoftUpdate


logging.basicConfig(level=logging.INFO)
my_logger = logging.getLogger(__name__)

ENV_NAME = "CartPole-v1"

INIT_RND_STEPS = 5_120
FRAMES_PER_BATCH = 128
BUFFER_SIZE = 100_000

GAMMA = 0.98
OPTIM_STEPS = 10
BATCH_SIZE = 128

SOFTU_EPS = 0.99
LR = 0.02


class Net(nn.Module):
def __init__(self, obs_size: int, n_actions: int) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Linear(obs_size, 128),
nn.ReLU(),
nn.Linear(128, n_actions),
)

def forward(self, x):
orig_shape_unbatched = len(x.shape) == 1
if orig_shape_unbatched:
x = x.unsqueeze(0)

out = self.net(x)

if orig_shape_unbatched:
out = out.squeeze(0)
return out


def make_env(env_name: str):
return TransformedEnv(GymEnv(env_name), Compose(StepCounter(), RewardSum()))


if __name__ == "__main__":

for env_device, policy_device, device in (
(None, None, "cuda"),
(None, None, "cpu"),
(None, "cuda", None),
("cpu", "cuda", None),
("cpu", None, "cuda"),
        # These configs don't run because the collector needs to be told explicitly that the policy is on CUDA.
        # The env, by contrast, has specs that are associated with a device, so its data can be transferred
        # automatically. The policy does not, in general, have a spec indicating its input and output devices,
        # so this must be communicated to the collector.
# (None, None, None),
# ("cpu", None, None),
):
torch.manual_seed(0)
torch.cuda.manual_seed(0)

env = make_env(ENV_NAME)
env.set_seed(0)

n_obs = env.observation_spec["observation"].shape[-1]
n_act = env.action_spec.shape[-1]

net = Net(n_obs, n_act).to(device="cuda:0")
agent = QValueActor(net, spec=env.action_spec.to("cuda:0"))

        # policy_explore has buffers on CPU, so we will need to call collector.update_policy_weights_()
        # to sync them during data collection.
policy_explore = EGreedyModule(env.action_spec)
agent_explore = TDSeq(agent, policy_explore)

collector = SyncDataCollector(
env,
agent_explore,
frames_per_batch=FRAMES_PER_BATCH,
init_random_frames=INIT_RND_STEPS,
device=device,
env_device=env_device,
policy_device=policy_device,
)
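        # The replay-buffer storage lives on CUDA so that sampled batches are already on the training device.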
exp_buffer = ReplayBuffer(
storage=LazyTensorStorage(BUFFER_SIZE, device="cuda:0")
)

loss = DQNLoss(
value_network=agent, action_space=env.action_spec, delay_value=True
)
loss.make_value_estimator(gamma=GAMMA)
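        # The target network is tracked with a soft (Polyak) update of the value-network parameters.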
target_updater = SoftUpdate(loss, eps=SOFTU_EPS)
optimizer = optim.Adam(loss.parameters(), lr=LR)

total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
# Check the data devices
if device is None:
assert data["action"].device == torch.device("cuda:0")
assert data["observation"].device == torch.device("cpu")
assert data["done"].device == torch.device("cpu")
elif device == "cpu":
assert data["action"].device == torch.device("cpu")
assert data["observation"].device == torch.device("cpu")
assert data["done"].device == torch.device("cpu")
else:
assert data["action"].device == torch.device("cuda:0")
assert data["observation"].device == torch.device("cuda:0")
assert data["done"].device == torch.device("cuda:0")

exp_buffer.extend(data)
max_length = exp_buffer["next", "step_count"].max()
max_reward = exp_buffer["next", "episode_reward"].max()
if len(exp_buffer) > INIT_RND_STEPS:
for _ in range(OPTIM_STEPS):
optimizer.zero_grad()
sample = exp_buffer.sample(batch_size=BATCH_SIZE)

loss_vals = loss(sample)
loss_vals["loss"].backward()
optimizer.step()

                # Anneal the exploration rate (the EGreedyModule is the second module of agent_explore)
                agent_explore[1].step(data.numel())
target_updater.step()

total_count += data.numel()
total_episodes += data["next", "done"].sum()

if i % 10 == 0:
my_logger.info(
f"Step: {i}, max. count / epi reward: {max_length} / {max_reward}."
)
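            # Sync the collector's copy of the policy with the updated parameters and buffers.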
collector.update_policy_weights_()
if max_length > 200:
t1 = time.time()
my_logger.info(f"SOLVED in {t1 - t0}s!! MaxLen: {max_length}!")
my_logger.info(f"With {max_reward} Reward!")
my_logger.info(f"In {total_episodes} Episodes!")
my_logger.info(f"Using devices {(env_device, policy_device, device)}")
break
else:
raise RuntimeError(
f"Failed to converge with config {(env_device, policy_device, device)}"
)
