diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 9b072cc9664..ada7bdfeed4 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -188,6 +188,9 @@ were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/be | :class:`LazyMemmapStorage` | 3.44x | +-------------------------------+-----------+ +You can also read more about distributed replay buffers in https://github.com/pytorch/rl/tree/main/knowledge_base/DISTRIBUTED_BUFFER.md +and find examples of dummy training loops in https://github.com/pytorch/rl/tree/main/examples/replay-buffers/ + Sharing replay buffers across processes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/examples/distributed/replay_buffers/distributed_replay_buffer.py b/examples/distributed/replay_buffers/distributed_replay_buffer.py deleted file mode 100644 index f25ea0bdc8b..00000000000 --- a/examples/distributed/replay_buffers/distributed_replay_buffer.py +++ /dev/null @@ -1,221 +0,0 @@ -""" -Example use of a distributed replay buffer -=========================== - -This example illustrates how a skeleton reinforcement learning algorithm can be implemented in a distributed fashion with communication between nodes/workers handled using `torch.rpc`. -It focusses on how to set up a replay buffer worker that accepts remote operation requests efficiently, and so omits any learning component such as parameter updates that may be required for a complete distributed reinforcement learning algorithm implementation. -In this model, >= 1 data collectors workers are responsible for collecting experiences in an environment, the replay buffer worker receives all of these experiences and exposes them to a trainer that is responsible for making parameter updates to any required models. -""" - -import argparse -import os -import random -import sys -import time - -import torch -import torch.distributed.rpc as rpc -from tensordict import TensorDict -from torchrl._utils import accept_remote_rref_invocation, logger as torchrl_logger -from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer -from torchrl.data.replay_buffers.samplers import RandomSampler -from torchrl.data.replay_buffers.storages import LazyMemmapStorage -from torchrl.data.replay_buffers.writers import RoundRobinWriter - -RETRY_LIMIT = 2 -RETRY_DELAY_SECS = 3 -REPLAY_BUFFER_NODE = "ReplayBuffer" -TRAINER_NODE = "Trainer" - -parser = argparse.ArgumentParser( - description="RPC Replay Buffer Example", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -) - -parser.add_argument( - "--rank", - type=int, - default=-1, - help="Node Rank [0 = Replay Buffer, 1 = Dummy Trainer, 2+ = Dummy Data Collector]", -) - - -class DummyDataCollectorNode: - """Data collector node responsible for collecting experiences used for learning. - - Args: - replay_buffer (rpc.RRef): the RRef associated with the construction of the replay buffer - """ - - def __init__(self, replay_buffer: rpc.RRef) -> None: - self.id = rpc.get_worker_info().id - self.replay_buffer = replay_buffer - torchrl_logger.info("Data Collector Node constructed") - - def _submit_random_item_async(self) -> rpc.RRef: - td = TensorDict({"a": torch.randint(100, (1,))}, []) - return rpc.remote( - self.replay_buffer.owner(), - ReplayBufferNode.add, - args=( - self.replay_buffer, - td, - ), - ) - - @accept_remote_rref_invocation - def collect(self): - """Method that begins experience collection (we just generate random TensorDicts in this example). 
`accept_remote_rref_invocation` enables this method to be invoked remotely provided the class instantiation `rpc.RRef` is provided in place of the object reference.""" - for elem in range(50): - time.sleep(random.randint(1, 4)) - torchrl_logger.info( - f"Collector [{self.id}] submission {elem}: {self._submit_random_item_async().to_here()}" - ) - - -class DummyTrainerNode: - """Trainer node responsible for learning from experiences sampled from an experience replay buffer.""" - - def __init__(self) -> None: - torchrl_logger.info("DummyTrainerNode") - self.id = rpc.get_worker_info().id - self.replay_buffer = self._create_replay_buffer() - self._create_and_launch_data_collectors() - - def train(self, iterations: int) -> None: - for iteration in range(iterations): - torchrl_logger.info(f"[{self.id}] Training Iteration: {iteration}") - time.sleep(3) - batch = rpc.rpc_sync( - self.replay_buffer.owner(), - ReplayBufferNode.sample, - args=(self.replay_buffer, 16), - ) - torchrl_logger.info(f"[{self.id}] Sample Obtained Iteration: {iteration}") - torchrl_logger.info(f"{batch}") - - def _create_replay_buffer(self) -> rpc.RRef: - while True: - try: - replay_buffer_info = rpc.get_worker_info(REPLAY_BUFFER_NODE) - buffer_rref = rpc.remote( - replay_buffer_info, ReplayBufferNode, args=(10000,) - ) - torchrl_logger.info(f"Connected to replay buffer {replay_buffer_info}") - return buffer_rref - except Exception as e: - torchrl_logger.info(f"Failed to connect to replay buffer: {e}") - time.sleep(RETRY_DELAY_SECS) - - def _create_and_launch_data_collectors(self) -> None: - data_collector_number = 2 - retries = 0 - data_collectors = [] - data_collector_infos = [] - # discover launched data collector nodes (with retry to allow collectors to dynamically join) - while True: - try: - data_collector_info = rpc.get_worker_info( - f"DataCollector{data_collector_number}" - ) - torchrl_logger.info(f"Data collector info: {data_collector_info}") - dc_ref = rpc.remote( - data_collector_info, - DummyDataCollectorNode, - args=(self.replay_buffer,), - ) - data_collectors.append(dc_ref) - data_collector_infos.append(data_collector_info) - data_collector_number += 1 - retries = 0 - except Exception: - retries += 1 - torchrl_logger.info( - f"Failed to connect to DataCollector{data_collector_number} with {retries} retries" - ) - if retries >= RETRY_LIMIT: - torchrl_logger.info(f"{len(data_collectors)} data collectors") - for data_collector_info, data_collector in zip( - data_collector_infos, data_collectors - ): - rpc.remote( - data_collector_info, - DummyDataCollectorNode.collect, - args=(data_collector,), - ) - break - else: - time.sleep(RETRY_DELAY_SECS) - - -class ReplayBufferNode(RemoteTensorDictReplayBuffer): - """Experience replay buffer node that is capable of accepting remote connections. Being a `RemoteTensorDictReplayBuffer` - means all of its public methods are remotely invokable using `torch.rpc`. - Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialization - cost of MemoryMappedTensors as well as the ability to specify file storage locations which can improve ability to recover from node failures. - - Args: - capacity (int): the maximum number of elements that can be stored in the replay buffer. 
- """ - - def __init__(self, capacity: int): - super().__init__( - storage=LazyMemmapStorage( - max_size=capacity, scratch_dir="/tmp/", device=torch.device("cpu") - ), - sampler=RandomSampler(), - writer=RoundRobinWriter(), - collate_fn=lambda x: x, - ) - - -if __name__ == "__main__": - args = parser.parse_args() - rank = args.rank - torchrl_logger.info(f"Rank: {rank}") - - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29500" - os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" - str_init_method = "tcp://localhost:10000" - options = rpc.TensorPipeRpcBackendOptions( - num_worker_threads=16, init_method=str_init_method - ) - if rank == 0: - # rank 0 is the trainer - rpc.init_rpc( - TRAINER_NODE, - rank=rank, - backend=rpc.BackendType.TENSORPIPE, - rpc_backend_options=options, - ) - torchrl_logger.info(f"Initialised Trainer Node {rank}") - trainer = DummyTrainerNode() - trainer.train(100) - breakpoint() - elif rank == 1: - # rank 1 is the replay buffer - # replay buffer waits passively for construction instructions from trainer node - torchrl_logger.info(REPLAY_BUFFER_NODE) - rpc.init_rpc( - REPLAY_BUFFER_NODE, - rank=rank, - backend=rpc.BackendType.TENSORPIPE, - rpc_backend_options=options, - ) - torchrl_logger.info(f"Initialised RB Node {rank}") - breakpoint() - elif rank >= 2: - # rank 2+ is a new data collector node - # data collectors also wait passively for construction instructions from trainer node - rpc.init_rpc( - f"DataCollector{rank}", - rank=rank, - backend=rpc.BackendType.TENSORPIPE, - rpc_backend_options=options, - ) - torchrl_logger.info(f"Initialised DC Node {rank}") - breakpoint() - else: - sys.exit(1) - rpc.shutdown() diff --git a/examples/replay-buffers/distributed_rb_utils.py b/examples/replay-buffers/distributed_rb_utils.py new file mode 100644 index 00000000000..d79c7a78c34 --- /dev/null +++ b/examples/replay-buffers/distributed_rb_utils.py @@ -0,0 +1,273 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import random +import time + +import torch +import torch.distributed.rpc as rpc +import tqdm +from tensordict import TensorDict + +from torchrl._utils import accept_remote_rref_invocation, logger as torchrl_logger +from torchrl.data.replay_buffers import RemoteReplayBuffer +from torchrl.data.replay_buffers.samplers import SliceSampler +from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.data.replay_buffers.writers import RoundRobinWriter + +RETRY_LIMIT = 2 +RETRY_DELAY_SECS = 3 + +REPLAY_BUFFER_NODE = "ReplayBuffer" +TRAINER_NODE = "Trainer" + + +class CollectorNode: + """Data collector node responsible for collecting experiences used for learning. + + Args: + replay_buffer (rpc.RRef): the RRef associated with the construction of the replay buffer + frames_per_batch (int): the ``frames_per_batch`` of the collector. This serves as an example of hyperparameters + to be passed to the collector. + + """ + + def __init__(self, replay_buffer: rpc.RRef, frames_per_batch: int = 128) -> None: + self.id = rpc.get_worker_info().id + self.replay_buffer = replay_buffer + # Write your collector here + # self.collector = SyncDataCollector(...) 
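+        # For instance (a hypothetical sketch, assuming a Gym environment and a
+        # `policy` module are available on this node):
+        #
+        #     from torchrl.collectors import SyncDataCollector
+        #     from torchrl.envs.libs.gym import GymEnv
+        #
+        #     self.collector = SyncDataCollector(
+        #         lambda: GymEnv("CartPole-v1"),
+        #         policy,
+        #         frames_per_batch=frames_per_batch,
+        #         total_frames=-1,
+        #     )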
+ assert frames_per_batch > 0 + self.frames_per_batch = frames_per_batch + torchrl_logger.info("Data Collector Node constructed") + + def _submit_item_async(self) -> rpc.RRef: + """Function that collects data and populates the replay buffer.""" + # Replace this by a call to next() over the data collector + done = torch.zeros(self.frames_per_batch, 1, dtype=torch.bool) + done[..., -1, 0] = True + td = TensorDict( + { + "action": torch.randint( + 100, + ( + self.frames_per_batch, + 1, + ), + ), + "done": torch.zeros(self.frames_per_batch, dtype=torch.bool), + "observation": torch.randn(self.frames_per_batch, 4), + "step_count": torch.arange(self.frames_per_batch), + "terminated": torch.zeros(self.frames_per_batch, dtype=torch.bool), + "truncated": torch.zeros(self.frames_per_batch, dtype=torch.bool), + "next": { + "done": done, + "observation": torch.randn(self.frames_per_batch, 4), + "reward": torch.randn(self.frames_per_batch, 1), + "step_count": torch.arange(1, self.frames_per_batch + 1), + "terminated": torch.zeros_like(done), + "truncated": done, + }, + }, + [self.frames_per_batch], + ) + return rpc.remote( + self.replay_buffer.owner(), + ReplayBufferNode.extend, + args=( + self.replay_buffer, + td, + ), + ) + + @accept_remote_rref_invocation + def collect(self): + """Method that begins experience collection (we just generate random TensorDicts in this example). + + `accept_remote_rref_invocation` enables this method to be invoked remotely provided the class instantiation + `rpc.RRef` is provided in place of the object reference. + """ + for elem in range(50): + time.sleep(random.randint(1, 4)) + item = self._submit_item_async() + torchrl_logger.info( + f"Collector [{self.id}] submission {elem}: {item.to_here()}" + ) + + +class TrainerNode: + """Trainer node responsible for learning from experiences sampled from an experience replay buffer.""" + + def __init__(self, replay_buffer_node="ReplayBuffer", world_size=3) -> None: + self.replay_buffer_node = replay_buffer_node + self.world_size = world_size + torchrl_logger.info("TrainerNode") + self.id = rpc.get_worker_info().id + self.replay_buffer = self._create_replay_buffer() + self._create_and_launch_data_collectors() + + def train(self, iterations: int) -> None: + """Write your training loop here.""" + for iteration in tqdm.tqdm(range(iterations)): + torchrl_logger.info(f"[{self.id}] Training Iteration: {iteration}") + # # Wait until the buffer has elements + while not rpc.rpc_sync( + self.replay_buffer.owner(), + ReplayBufferNode.__len__, + args=(self.replay_buffer,), + ): + continue + + batch = rpc.rpc_sync( + self.replay_buffer.owner(), + ReplayBufferNode.sample, + args=(self.replay_buffer, 16), + ) + + torchrl_logger.info(f"[{self.id}] Sample Obtained Iteration: {iteration}") + torchrl_logger.info(f"{batch}") + # Process the sample here: forward, backward, ... 
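+            # For instance (a hypothetical sketch, assuming a `loss_module` and an
+            # `optimizer` have been built for the models being trained):
+            #
+            #     loss_vals = loss_module(batch)
+            #     loss = sum(v for k, v in loss_vals.items() if k.startswith("loss"))
+            #     loss.backward()
+            #     optimizer.step()
+            #     optimizer.zero_grad()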
+ + def _create_replay_buffer(self) -> rpc.RRef: + def connect(): + replay_buffer_info = rpc.get_worker_info(self.replay_buffer_node) + buffer_rref = rpc.remote( + replay_buffer_info, ReplayBufferNode, args=(10000,) + ) + torchrl_logger.info(f"Connected to replay buffer {replay_buffer_info}") + return buffer_rref + + while True: + try: + return connect() + except Exception as e: + torchrl_logger.info(f"Failed to connect to replay buffer: {e}") + time.sleep(RETRY_DELAY_SECS) + + def _create_and_launch_data_collectors(self) -> None: + data_collector_number = self.world_size - 2 + self.data_collectors = [] + self.data_collector_infos = [] + # discover launched data collector nodes (with retry to allow collectors to dynamically join) + def connect(n, retry): + data_collector_info = rpc.get_worker_info( + f"DataCollector{n + 2}" # 2, 3, 4, ... + ) + torchrl_logger.info( + f"Data collector info: {data_collector_info}-retry={retry}" + ) + dc_ref = rpc.remote( + data_collector_info, + CollectorNode, + args=(self.replay_buffer,), + ) + self.data_collectors.append(dc_ref) + self.data_collector_infos.append(data_collector_info) + + for n in range(data_collector_number): + for retry in range(RETRY_LIMIT): + try: + connect(n, retry) + break + except Exception as e: + torchrl_logger.info( + f"Failed to connect to DataCollector{n} with {retry} retries (err={e})" + ) + time.sleep(RETRY_DELAY_SECS) + else: + raise Exception + for collector, data_collector_info in zip( + self.data_collectors, self.data_collector_infos + ): + rpc.remote( + data_collector_info, + CollectorNode.collect, + args=(collector,), + ) + + +class ReplayBufferNode(RemoteReplayBuffer): + """Experience replay buffer node that is capable of accepting remote connections. Being a `RemoteReplayBuffer` + means all of its public methods are remotely invokable using `torch.rpc`. + Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialisation + cost of MemoryMappedTensors as well as the ability to specify file storage locations which can improve ability to recover from node failures. + + Args: + capacity (int): the maximum number of elements that can be stored in the replay buffer. + """ + + def __init__(self, capacity: int): + super().__init__( + storage=LazyMemmapStorage( + max_size=capacity, scratch_dir="/tmp/", device=torch.device("cpu") + ), + sampler=SliceSampler(num_slices=4), + writer=RoundRobinWriter(), + batch_size=32, + ) + + +def main(rank, world_size, **tensorpipe_kwargs): + """Dispatcher for the distributed workflow. + + rank 0 will be assigned the TRAINER job, + rank 1 will be assigned the REPLAY BUFFER job, + rank 2 to world_size-1 will be assigned the COLLECTOR jobs. 
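+
+    Every worker initialises `torch.rpc` with the same rendezvous settings (the
+    `MASTER_ADDR`/`MASTER_PORT` set below) and the same `world_size`; once all workers
+    have joined, the trainer builds the buffer and collector objects remotely.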
+ + """ + torchrl_logger.info(f"Rank: {rank}") + + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + # + + options = rpc.TensorPipeRpcBackendOptions( + num_worker_threads=16, **tensorpipe_kwargs + ) + + if rank == 0: + # rank 0 is the trainer + torchrl_logger.info(f"Init RPC on {TRAINER_NODE}...") + rpc.init_rpc( + TRAINER_NODE, + rank=rank, + backend=rpc.BackendType.TENSORPIPE, + rpc_backend_options=options, + world_size=world_size, + ) + torchrl_logger.info(f"Initialised {TRAINER_NODE}") + trainer = TrainerNode(replay_buffer_node=REPLAY_BUFFER_NODE) + trainer.train(100) + rpc.shutdown() + elif rank == 1: + # rank 1 is the replay buffer + # replay buffer waits passively for construction instructions from trainer node + torchrl_logger.info(f"Init RPC on {REPLAY_BUFFER_NODE}...") + rpc.init_rpc( + REPLAY_BUFFER_NODE, + rank=rank, + backend=rpc.BackendType.TENSORPIPE, + rpc_backend_options=options, + world_size=world_size, + ) + torchrl_logger.info(f"Initialised {REPLAY_BUFFER_NODE}") + rpc.shutdown() + else: + # rank 2+ is a new data collector node + # data collectors also wait passively for construction instructions from trainer node + torchrl_logger.info(f"Init RPC on DataCollector{rank}") + rpc.init_rpc( + f"DataCollector{rank}", + rank=rank, + backend=rpc.BackendType.TENSORPIPE, + rpc_backend_options=options, + world_size=world_size, + ) + torchrl_logger.info(f"Initialised DataCollector{rank}") + rpc.shutdown() + print("exiting", rank) diff --git a/examples/replay-buffers/distributed_replay_buffer.py b/examples/replay-buffers/distributed_replay_buffer.py new file mode 100644 index 00000000000..325ceb6ea9a --- /dev/null +++ b/examples/replay-buffers/distributed_replay_buffer.py @@ -0,0 +1,55 @@ +""" +Example use of a distributed replay buffer (custom) +=================================================== + +This example illustrates how a skeleton reinforcement learning algorithm can be implemented in a distributed fashion +with communication between nodes/workers handled using `torch.rpc`. +It focusses on how to set up a replay buffer worker that accepts remote operation requests efficiently, and so omits +any learning component such as parameter updates that may be required for a complete distributed reinforcement learning +algorithm implementation. + +In this model, >= 1 data collectors workers are responsible for collecting experiences in an environment, the replay +buffer worker receives all of these experiences and exposes them to a trainer that is responsible for making parameter +updates to any required models. 
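+
+The worker logic itself lives in `distributed_rb_utils.main`; this script only parses the node rank and world size
+before dispatching to it, so every node runs the same file with a different `--rank`.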
+
+To launch this script, run
+
+```bash
+$ # In terminal0: Trainer node
+$ python examples/replay-buffers/distributed_replay_buffer.py --rank=0
+$ # In terminal1: Replay buffer node
+$ python examples/replay-buffers/distributed_replay_buffer.py --rank=1
+$ # In terminal2 to N: Collector nodes
+$ python examples/replay-buffers/distributed_replay_buffer.py --rank=2
+
+```
+"""
+
+import argparse
+
+from distributed_rb_utils import main
+
+REPLAY_BUFFER_NODE = "ReplayBuffer"
+TRAINER_NODE = "Trainer"
+
+parser = argparse.ArgumentParser(
+    description="RPC Replay Buffer Example",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+)
+
+parser.add_argument(
+    "--rank",
+    type=int,
+    default=-1,
+    help="Node Rank [0 = Trainer, 1 = Replay Buffer, 2+ = Data Collector]",
+)
+parser.add_argument("--world_size", type=int, default=3, help="Number of nodes/workers")
+
+
+if __name__ == "__main__":
+
+    args = parser.parse_args()
+    rank = args.rank
+    world_size = args.world_size
+
+    main(rank, world_size, str_init_method="tcp://localhost:10000")
diff --git a/examples/replay-buffers/distributed_replay_buffer_multiproc.py b/examples/replay-buffers/distributed_replay_buffer_multiproc.py
new file mode 100644
index 00000000000..79967340bd0
--- /dev/null
+++ b/examples/replay-buffers/distributed_replay_buffer_multiproc.py
@@ -0,0 +1,38 @@
+"""
+Example use of a distributed replay buffer (single node)
+========================================================
+
+This example illustrates how a skeleton reinforcement learning algorithm can be implemented in a distributed fashion
+with communication between nodes/workers handled using `torch.rpc`.
+It focusses on how to set up a replay buffer worker that accepts remote operation requests efficiently, and so omits
+any learning component such as parameter updates that may be required for a complete distributed reinforcement learning
+algorithm implementation.
+
+In this model, one or more data collector workers are responsible for collecting experiences in an environment; the
+replay buffer worker receives all of these experiences and exposes them to a trainer that is responsible for making
+parameter updates to any required models.
+
+To launch this script, run
+
+```bash
+python examples/replay-buffers/distributed_replay_buffer_multiproc.py
+```
+
+"""
+
+from distributed_rb_utils import main
+from torch import multiprocessing as mp
+
+REPLAY_BUFFER_NODE = "ReplayBuffer"
+TRAINER_NODE = "Trainer"
+
+if __name__ == "__main__":
+    ctx = mp.get_context("spawn")
+    procs = []
+    world_size = 3
+    for i in range(world_size):
+        procs.append(ctx.Process(target=main, args=(i, world_size)))
+        procs[-1].start()
+
+    for p in reversed(procs):
+        p.join()
diff --git a/examples/replay-buffers/distributed_replay_buffer_submitit.py b/examples/replay-buffers/distributed_replay_buffer_submitit.py
new file mode 100644
index 00000000000..5bb2c943815
--- /dev/null
+++ b/examples/replay-buffers/distributed_replay_buffer_submitit.py
@@ -0,0 +1,48 @@
+"""
+Example use of a distributed replay buffer (submitit)
+=====================================================
+
+This example illustrates how a skeleton reinforcement learning algorithm can be implemented in a distributed fashion
+with communication between nodes/workers handled using `torch.rpc`.
+It focusses on how to set up a replay buffer worker that accepts remote operation requests efficiently, and so omits
+any learning component such as parameter updates that may be required for a complete distributed reinforcement learning
+algorithm implementation.
+
+In this model, one or more data collector workers are responsible for collecting experiences in an environment; the
+replay buffer worker receives all of these experiences and exposes them to a trainer that is responsible for making
+parameter updates to any required models.
+
+To launch this script, run
+
+```bash
+python examples/replay-buffers/distributed_replay_buffer_submitit.py
+```
+
+"""
+
+import submitit
+
+from distributed_rb_utils import main
+from torch import multiprocessing as mp
+
+DEFAULT_SLURM_CONF = {
+    "timeout_min": 10,
+    "slurm_partition": "train",
+    "slurm_cpus_per_task": 32,
+    "slurm_gpus_per_node": 0,
+}  #: Default value of the SLURM jobs
+
+if __name__ == "__main__":
+
+    executor = submitit.AutoExecutor(folder="log_test")
+    executor.update_parameters(**DEFAULT_SLURM_CONF)
+
+    ctx = mp.get_context("spawn")
+    jobs = []
+    world_size = 3
+    for i in range(world_size):
+        job = executor.submit(main, i, world_size)
+        jobs.append(job)
+
+    for i in range(world_size):
+        jobs[i].result()
diff --git a/knowledge_base/DISTRIBUTED_BUFFER.md b/knowledge_base/DISTRIBUTED_BUFFER.md
new file mode 100644
index 00000000000..01d5865773b
--- /dev/null
+++ b/knowledge_base/DISTRIBUTED_BUFFER.md
@@ -0,0 +1,43 @@
+# Distributed replay buffers in TorchRL
+
+This document gives an overview of the various ways multiple nodes can be used to collect
+data in TorchRL.
+
+
+## Sharing a buffer between nodes using torch RPC
+
+TorchRL provides an API to call replay buffer methods remotely on a dedicated node.
+This can be used as illustrated in the following drawing:
+
+![distributed-rb.png](distributed-rb.png)
+
+Three node categories are instantiated: a trainer, a buffer node and a set of collector nodes.
+
+When all nodes are ready, the collector nodes start sending data to the buffer node by calling
+`buffer.extend(tensordict)` remotely.
+
+The buffer node passively receives these calls and writes the data it receives to the storage.
+
+Once enough data has been written, the trainer node starts asking the buffer node for data to process.
+The `buffer` sends that data over the wire.
+
+In some cases, all nodes have access to a shared physical storage (check with your administrator if this
+is the case and if they will allow you to do frequent read and write operations).
+If this is the case, a `LazyMemmapStorage` can be instantiated with the shared path as `scratch_dir`.
+In this case, node-to-node communications will be drastically reduced as each node will be able to directly
+read from and write to the storage. The only data passing over the wire from node to node will be the metadata
+(shape, number of elements, size of the buffer, etc.).
+
+Have a look at the examples [here](https://github.com/pytorch/rl/tree/main/examples/replay-buffers/).
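+
+As a minimal sketch (assuming a `buffer_rref` handle to a remote `ReplayBufferNode`, as in the examples above,
+and a `tensordict` batch to store), a collector pushes data and the trainer samples it with plain `torch.rpc` calls:
+
+```python
+import torch.distributed.rpc as rpc
+
+# On a collector node: append a batch of experience to the remote buffer.
+rpc.rpc_sync(buffer_rref.owner(), ReplayBufferNode.extend, args=(buffer_rref, tensordict))
+
+# On the trainer node: pull a batch of 16 transitions from the remote buffer.
+batch = rpc.rpc_sync(buffer_rref.owner(), ReplayBufferNode.sample, args=(buffer_rref, 16))
+```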
+ +## Distributed collectors + +![distributed-collectors.png](distributed-collectors.png) + +### Backends + +#### Ray + +#### torch.distributed + +#### RPC diff --git a/knowledge_base/distributed-collectors.png b/knowledge_base/distributed-collectors.png new file mode 100644 index 00000000000..c9d52c7212e Binary files /dev/null and b/knowledge_base/distributed-collectors.png differ diff --git a/knowledge_base/distributed-rb.png b/knowledge_base/distributed-rb.png new file mode 100644 index 00000000000..5b868b7db91 Binary files /dev/null and b/knowledge_base/distributed-rb.png differ diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 3af44ee0ed7..17d8f103809 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -486,9 +486,47 @@ def unpack_rref_and_invoke_function(self, *args, **kwargs): def accept_remote_rref_udf_invocation(decorated_class): """Class decorator that applies `accept_remote_rref_invocation` to all public methods.""" # ignores private methods - for name in dir(decorated_class): - method = getattr(decorated_class, name) - if callable(method) and not name.startswith("_"): + __allowed_methods__ = { + "__add__", + "__sub__", + "__mul__", + "__truediv__", + "__floordiv__", + "__mod__", + "__pow__", + "__iadd__", + "__isub__", + "__imul__", + "__itruediv__", + "__ifloordiv__", + "__imod__", + "__ipow__", + "__and__", + "__or__", + "__xor__", + "__lshift__", + "__rshift__", + "__invert__", + "__iand__", + "__ior__", + "__ixor__", + "__ilshift__", + "__irshift__", + "__eq__", + "__ne__", + "__lt__", + "__le__", + "__gt__", + "__ge__", + "__bool__", + "__iter__", + "__len__", + } + + for name, method in inspect.getmembers(decorated_class): + if callable(method) and ( + not name.startswith("_") or name in __allowed_methods__ + ): setattr(decorated_class, name, accept_remote_rref_invocation(method)) return decorated_class diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 25822dcfe4c..6e33e72dd1d 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -14,6 +14,7 @@ ) from .replay_buffers import ( PrioritizedReplayBuffer, + RemoteReplayBuffer, RemoteTensorDictReplayBuffer, ReplayBuffer, ReplayBufferEnsemble, diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 2e0eeb80705..ed55b8b7447 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -1542,6 +1542,47 @@ def update_priority( def update_tensordict_priority(self, data: TensorDictBase) -> None: return super().update_tensordict_priority(data) + def __len__(self): + return super().__len__() + + def __iter__(self): + return super().__iter__() + + +@accept_remote_rref_udf_invocation +class RemoteReplayBuffer(ReplayBuffer): + """A remote invocation friendly ReplayBuffer class. 
Public methods can be invoked by remote agents using `torch.rpc` or called locally as normal.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def sample( + self, + batch_size: int | None = None, + return_info: bool = False, + ) -> TensorDictBase: + return super().sample(batch_size=batch_size, return_info=return_info) + + def add(self, data: TensorDictBase) -> int: + return super().add(data) + + def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: + return super().extend(tensordicts) + + def update_priority( + self, index: Union[int, torch.Tensor], priority: Union[int, torch.Tensor] + ) -> None: + return super().update_priority(index, priority) + + def update_tensordict_priority(self, data: TensorDictBase) -> None: + return super().update_tensordict_priority(data) + + def __len__(self): + return super().__len__() + + def __iter__(self): + return super().__iter__() + class InPlaceSampler: """A sampler to write tennsordicts in-place.