[RLlib] Split AddStates... connectors into 2 pieces (`AddTimeDimToBatchAndZeroPad` and `AddStatesFromEpisodesToBatch`) (ray-project#49835)
sven1977 authored Jan 20, 2025
1 parent 243927c commit 663e325
Showing 15 changed files with 347 additions and 115 deletions.
14 changes: 10 additions & 4 deletions rllib/algorithms/algorithm.py
@@ -2991,6 +2991,16 @@ def log_result(self, result: ResultDict) -> None:

@override(Trainable)
def cleanup(self) -> None:
+        # Stop all Learners.
+        if hasattr(self, "learner_group") and self.learner_group is not None:
+            self.learner_group.shutdown()
+
+        # Stop all aggregation actors.
+        if hasattr(self, "_aggregator_actor_manager") and (
+            self._aggregator_actor_manager is not None
+        ):
+            self._aggregator_actor_manager.clear()
+
# Stop all EnvRunners.
if hasattr(self, "env_runner_group") and self.env_runner_group is not None:
self.env_runner_group.stop()
@@ -3000,10 +3010,6 @@ def cleanup(self) -> None:
):
self.eval_env_runner_group.stop()

-        # Stop all Learners.
-        if hasattr(self, "learner_group") and self.learner_group is not None:
-            self.learner_group.shutdown()

@OverrideToImplementCustomLogic
@classmethod
@override(Trainable)
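With this change, `Algorithm.cleanup()` shuts down the LearnerGroup and clears any aggregation actors before stopping the EnvRunner groups, so subclasses like IMPALA no longer need their own override (see the impala.py diff below). A minimal usage sketch, assuming a working Ray/RLlib setup on the new API stack:

```python
from ray.rllib.algorithms.ppo import PPOConfig

# Build, train one iteration, then release all actors. `cleanup()` now
# stops Learners (and aggregation actors) first, then all EnvRunners.
algo = PPOConfig().environment("CartPole-v1").build()
algo.train()
algo.stop()  # invokes `cleanup()` via Tune's Trainable API
```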
10 changes: 8 additions & 2 deletions rllib/algorithms/algorithm_config.py
@@ -963,6 +963,7 @@ def build_env_to_module_connector(self, env, device=None):
from ray.rllib.connectors.env_to_module import (
AddObservationsFromEpisodesToBatch,
AddStatesFromEpisodesToBatch,
+            AddTimeDimToBatchAndZeroPad,
AgentToModuleMapping,
BatchIndividualItems,
EnvToModulePipeline,
@@ -1016,7 +1017,9 @@ def build_env_to_module_connector(self, env, device=None):
if self.add_default_connectors_to_env_to_module_pipeline:
# Append OBS handling.
pipeline.append(AddObservationsFromEpisodesToBatch())
-            # Append STATE_IN/STATE_OUT (and time-rank) handler.
+            # Append time-rank handler.
+            pipeline.append(AddTimeDimToBatchAndZeroPad())
+            # Append STATE_IN/STATE_OUT handler.
pipeline.append(AddStatesFromEpisodesToBatch())
# If multi-agent -> Map from AgentID-based data to ModuleID based data.
if self.is_multi_agent:
@@ -1138,6 +1141,7 @@ def build_learner_connector(
AddColumnsFromEpisodesToTrainBatch,
AddObservationsFromEpisodesToBatch,
AddStatesFromEpisodesToBatch,
+            AddTimeDimToBatchAndZeroPad,
AgentToModuleMapping,
BatchIndividualItems,
LearnerConnectorPipeline,
@@ -1182,7 +1186,9 @@ def build_learner_connector(
)
# Append all other columns handling.
pipeline.append(AddColumnsFromEpisodesToTrainBatch())
-            # Append STATE_IN/STATE_OUT (and time-rank) handler.
+            # Append time-rank handler.
+            pipeline.append(AddTimeDimToBatchAndZeroPad(as_learner_connector=True))
+            # Append STATE_IN/STATE_OUT handler.
pipeline.append(AddStatesFromEpisodesToBatch(as_learner_connector=True))
# If multi-agent -> Map from AgentID-based data to ModuleID based data.
if self.is_multi_agent:
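Both default pipelines now carry the new time-rank piece between the observation/column handlers and the state handler. A sketch for inspecting the assembled order, assuming `gymnasium` is installed and that `pipeline.connectors` holds the pipeline's list of pieces:

```python
import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().environment("CartPole-v1")
env = gym.make("CartPole-v1")

# Build the default env-to-module pipeline and print the piece order:
# ... AddObservationsFromEpisodesToBatch, AddTimeDimToBatchAndZeroPad,
# AddStatesFromEpisodesToBatch, BatchIndividualItems, NumpyToTensor ...
pipeline = config.build_env_to_module_connector(env)
for piece in pipeline.connectors:
    print(type(piece).__name__)
```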
11 changes: 0 additions & 11 deletions rllib/algorithms/impala/impala.py
@@ -130,7 +130,6 @@ def __init__(self, algo_class=None):
self.vtrace_clip_rho_threshold = 1.0
self.vtrace_clip_pg_rho_threshold = 1.0
self.learner_queue_size = 3
-        self.max_requests_in_flight_per_env_runner = 1
self.timeout_s_sampler_manager = 0.0
self.timeout_s_aggregator_manager = 0.0
self.broadcast_interval = 1
@@ -758,16 +757,6 @@ def _func(actor, p):

time.sleep(0.01)

-    @override(Algorithm)
-    def cleanup(self) -> None:
-        super().cleanup()
-
-        # Stop all aggregation actors.
-        if hasattr(self, "_aggregator_actor_manager") and (
-            self._aggregator_actor_manager is not None
-        ):
-            self._aggregator_actor_manager.clear()

def _sample_and_get_connector_states(self):
def _remote_sample_get_state_and_metrics(_worker):
_episodes = _worker.sample()
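IMPALA's `cleanup()` override disappears because the base `Algorithm.cleanup()` now clears the aggregation actors itself, and the IMPALA-specific default of `max_requests_in_flight_per_env_runner = 1` is dropped in favor of the base config's default. A hedged sketch for pinning the old value explicitly, assuming the setting is exposed through `config.env_runners()`:

```python
from ray.rllib.algorithms.impala import IMPALAConfig

# Restore the previous IMPALA default of at most one in-flight sample
# request per EnvRunner (the value is now inherited from AlgorithmConfig).
config = (
    IMPALAConfig()
    .environment("CartPole-v1")
    .env_runners(max_requests_in_flight_per_env_runner=1)
)
```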
4 changes: 4 additions & 0 deletions rllib/connectors/common/__init__.py
@@ -4,6 +4,9 @@
from ray.rllib.connectors.common.add_states_from_episodes_to_batch import (
AddStatesFromEpisodesToBatch,
)
+from ray.rllib.connectors.common.add_time_dim_to_batch_and_zero_pad import (
+    AddTimeDimToBatchAndZeroPad,
+)
from ray.rllib.connectors.common.agent_to_module_mapping import AgentToModuleMapping
from ray.rllib.connectors.common.batch_individual_items import BatchIndividualItems
from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor
@@ -12,6 +15,7 @@
__all__ = [
"AddObservationsFromEpisodesToBatch",
"AddStatesFromEpisodesToBatch",
"AddTimeDimToBatchAndZeroPad",
"AgentToModuleMapping",
"BatchIndividualItems",
"NumpyToTensor",
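Both the split-out piece and the slimmed-down state connector stay importable from the common connectors package, so user code can compose them directly; a minimal sketch:

```python
# Build a custom prefix of connector pieces out of the public API.
from ray.rllib.connectors.common import (
    AddStatesFromEpisodesToBatch,
    AddTimeDimToBatchAndZeroPad,
)

pieces = [AddTimeDimToBatchAndZeroPad(), AddStatesFromEpisodesToBatch()]
```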
rllib/connectors/common/add_observations_from_episodes_to_batch.py
@@ -24,6 +24,7 @@ class AddObservationsFromEpisodesToBatch(ConnectorV2):
[
[0 or more user defined ConnectorV2 pieces],
AddObservationsFromEpisodesToBatch,
+        AddTimeDimToBatchAndZeroPad,
AddStatesFromEpisodesToBatch,
AgentToModuleMapping, # only in multi-agent setups!
BatchIndividualItems,
@@ -34,6 +35,7 @@
[0 or more user defined ConnectorV2 pieces],
AddObservationsFromEpisodesToBatch,
AddColumnsFromEpisodesToTrainBatch,
+        AddTimeDimToBatchAndZeroPad,
AddStatesFromEpisodesToBatch,
AgentToModuleMapping, # only in multi-agent setups!
BatchIndividualItems,
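As the updated docstring ordering shows, user-defined pieces run before the defaults, which RLlib appends as long as `add_default_connectors_to_env_to_module_pipeline` stays True. A sketch of hooking in a custom piece; the `MyObsScaler` class and the single-argument callable signature are illustrative assumptions:

```python
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.connectors.connector_v2 import ConnectorV2


class MyObsScaler(ConnectorV2):  # hypothetical user piece
    def __call__(self, *, rl_module, batch, episodes, **kwargs):
        # ... transform observations in `batch` here ...
        return batch


config = PPOConfig().env_runners(
    # User pieces are prepended; the default pieces listed above follow.
    env_to_module_connector=lambda env: MyObsScaler(),
)
```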
102 changes: 6 additions & 96 deletions rllib/connectors/common/add_states_from_episodes_to_batch.py
@@ -12,11 +12,6 @@
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.numpy import convert_to_numpy
-from ray.rllib.utils.postprocessing.zero_padding import (
-    create_mask_and_seq_lens,
-    split_and_zero_pad,
-)
-from ray.rllib.utils.spaces.space_utils import BatchedNdArray
from ray.rllib.utils.typing import EpisodeType
from ray.util.annotations import PublicAPI

@@ -35,6 +30,7 @@ class AddStatesFromEpisodesToBatch(ConnectorV2):
[
[0 or more user defined ConnectorV2 pieces],
AddObservationsFromEpisodesToBatch,
+        AddTimeDimToBatchAndZeroPad,
AddStatesFromEpisodesToBatch,
AgentToModuleMapping, # only in multi-agent setups!
BatchIndividualItems,
@@ -45,6 +41,7 @@
[0 or more user defined ConnectorV2 pieces],
AddObservationsFromEpisodesToBatch,
AddColumnsFromEpisodesToTrainBatch,
+        AddTimeDimToBatchAndZeroPad,
AddStatesFromEpisodesToBatch,
AgentToModuleMapping, # only in multi-agent setups!
BatchIndividualItems,
@@ -160,7 +157,7 @@ def get_initial_state(self):
output_batch = connector(
rl_module=rl_module,
batch={},
-        episodes=[episode.to_numpy()],
+        episodes=[episode],
shared_data={},
)
check(
@@ -173,7 +170,7 @@
# predictions).
# Also note that the different STATE_IN timesteps are already present
# as one batched item per episode in the list.
-            (episode.id_,): [[rl_module_init_state, -3.0]],
+            (episode.id_,): [rl_module_init_state, -3.0],
},
)
"""
@@ -217,61 +214,6 @@ def __call__(
if not rl_module.is_stateful() or Columns.STATE_IN in batch:
return batch

-        # Make all inputs (other than STATE_IN) have an additional T-axis.
-        # Since data has not been batched yet (we are still operating on lists in the
-        # batch), we add this time axis as 0 (not 1). When we batch, the batch axis will
-        # be 0 and the time axis will be 1.
-        # Also, let module-to-env pipeline know that we had added a single timestep
-        # time rank to the data (to remove it again).
-        if not self._as_learner_connector:
-            for column in batch.keys():
-                self.foreach_batch_item_change_in_place(
-                    batch=batch,
-                    column=column,
-                    func=lambda item, eps_id, aid, mid: (
-                        item
-                        if mid is not None and not rl_module[mid].is_stateful()
-                        # Expand on axis 0 (the to-be-time-dim) if item has not been
-                        # batched yet, otherwise axis=1 (the time-dim).
-                        else tree.map_structure(
-                            lambda s: np.expand_dims(
-                                s, axis=(1 if isinstance(s, BatchedNdArray) else 0)
-                            ),
-                            item,
-                        )
-                    ),
-                )
-            shared_data["_added_single_ts_time_rank"] = True
-        else:
-            # Before adding STATE_IN to the `data`, zero-pad existing data and batch
-            # into max_seq_len chunks.
-            for column, column_data in batch.copy().items():
-                # Do not zero-pad INFOS column.
-                if column == Columns.INFOS:
-                    continue
-                for key, item_list in column_data.items():
-                    # Multi-agent case AND RLModule is not stateful -> Do not zero-pad
-                    # for this model.
-                    assert isinstance(key, tuple)
-                    mid = None
-                    if len(key) == 3:
-                        eps_id, aid, mid = key
-                        if not rl_module[mid].is_stateful():
-                            continue
-                    column_data[key] = split_and_zero_pad(
-                        item_list,
-                        max_seq_len=self._get_max_seq_len(rl_module, module_id=mid),
-                    )
-                    # TODO (sven): Remove this hint/hack once we are not relying on
-                    # SampleBatch anymore (which has to set its property
-                    # zero_padded=True when shuffling).
-                    shared_data[
-                        (
-                            "_zero_padded_for_mid="
-                            f"{mid if mid is not None else DEFAULT_MODULE_ID}"
-                        )
-                    ] = True
-
for sa_episode in self.single_agent_episode_iterator(
episodes,
# If Learner connector, get all episodes (for train batch).
@@ -280,8 +222,8 @@ def __call__(
agents_that_stepped_only=not self._as_learner_connector,
):
if self._as_learner_connector:
-                # Multi-agent case: Extract correct single agent RLModule (to get the
-                # state for individually).
+                # Multi-agent case: Extract correct single agent RLModule (to get its
+                # individual state).
if sa_episode.module_id is not None:
sa_module = rl_module[sa_episode.module_id]
else:
@@ -372,24 +314,6 @@ def __call__(
single_agent_episode=sa_episode,
)

-                # Also, create the loss mask (b/c of our now possibly zero-padded data)
-                # as well as the seq_lens array and add these to `data` as well.
-                mask, seq_lens = create_mask_and_seq_lens(len(sa_episode), max_seq_len)
-                self.add_n_batch_items(
-                    batch=batch,
-                    column=Columns.SEQ_LENS,
-                    items_to_add=seq_lens,
-                    num_items=len(seq_lens),
-                    single_agent_episode=sa_episode,
-                )
-                if not shared_data.get("_added_loss_mask_for_valid_episode_ts"):
-                    self.add_n_batch_items(
-                        batch=batch,
-                        column=Columns.LOSS_MASK,
-                        items_to_add=mask,
-                        num_items=len(mask),
-                        single_agent_episode=sa_episode,
-                    )
else:
assert not sa_episode.is_numpy

@@ -422,17 +346,3 @@ def __call__(
)

return batch

-    def _get_max_seq_len(self, rl_module, module_id=None):
-        if module_id:
-            mod = rl_module[module_id]
-        else:
-            mod = next(iter(rl_module.values()))
-        if "max_seq_len" not in mod.model_config:
-            raise ValueError(
-                "You are using a stateful RLModule and are not providing a "
-                "'max_seq_len' key inside your `model_config`. You can set this "
-                "dict and/or override keys in it via `config.rl_module("
-                "model_config={'max_seq_len': [some int]})`."
-            )
-        return mod.model_config["max_seq_len"]
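The zero-padding, loss-mask, and `max_seq_len` logic removed above moves into the new `AddTimeDimToBatchAndZeroPad` piece. A rough sketch of what the moved helpers do, using the utilities un-imported at the top of this file (the printed shapes are illustrative, not exact return types):

```python
from ray.rllib.utils.postprocessing.zero_padding import (
    create_mask_and_seq_lens,
    split_and_zero_pad,
)

# Split a 7-step column into max_seq_len=5 chunks, zero-padding the
# remainder: conceptually [[1, 2, 3, 4, 5], [6, 7, 0, 0, 0]].
chunks = split_and_zero_pad([1, 2, 3, 4, 5, 6, 7], max_seq_len=5)

# Matching loss mask and sequence lengths: the mask marks the 5 + 2 valid
# timesteps; seq_lens comes out as [5, 2].
mask, seq_lens = create_mask_and_seq_lens(7, 5)
```

Stateful RLModules still require a `max_seq_len` entry in their `model_config`; per the removed error message above, it can be set like so:

```python
from ray.rllib.algorithms.ppo import PPOConfig

# `max_seq_len` bounds the zero-padded chunk length for recurrent models.
config = PPOConfig().rl_module(model_config={"max_seq_len": 20})
```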