Skip to content

Commit

Permalink
[RLlib] Old API stack IMPALA/APPO: Re-introduce mixin-replay-buffer p…
Browse files Browse the repository at this point in the history
…ass, even if `replay-ratio=0` (fixes a memory leak). (ray-project#49964)
  • Loading branch information
sven1977 authored Jan 20, 2025
1 parent 663e325 commit 84b2b0e
Showing 1 changed file with 12 additions and 16 deletions.
28 changes: 12 additions & 16 deletions rllib/algorithms/impala/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,6 @@ def setup(self, config: AlgorithmConfig):

# Queue of data to be sent to the Learner.
self.data_to_place_on_learner = []
# The local mixin buffer (if required).
self.local_mixin_buffer = None
self._batch_being_built = [] # @OldAPIStack

# Create extra aggregation workers and assign each rollout worker to
Expand All @@ -565,18 +563,17 @@ def setup(self, config: AlgorithmConfig):
i: [] for i in range(self.config.num_learners or 1)
}

# Create our local mixin buffer if the num of aggregation workers is 0.
# Create our local mixin buffer.
if not self.config.enable_rl_module_and_learner:
if self.config.replay_proportion > 0.0:
self.local_mixin_buffer = MixInMultiAgentReplayBuffer(
capacity=(
self.config.replay_buffer_num_slots
if self.config.replay_buffer_num_slots > 0
else 1
),
replay_ratio=self.config.replay_ratio,
replay_mode=ReplayMode.LOCKSTEP,
)
self.local_mixin_buffer = MixInMultiAgentReplayBuffer(
capacity=(
self.config.replay_buffer_num_slots
if self.config.replay_buffer_num_slots > 0
else 1
),
replay_ratio=self.config.replay_ratio,
replay_mode=ReplayMode.LOCKSTEP,
)

# This variable is used to keep track of the statistics from the most recent
# update of the learner group
Expand Down Expand Up @@ -1081,9 +1078,8 @@ def _process_experiences_old_api_stack(
batch = batch.decompress_if_needed()
# Only make a pass through the buffer, if replay proportion is > 0.0 (and
# we actually have one).
if self.local_mixin_buffer:
self.local_mixin_buffer.add(batch)
batch = self.local_mixin_buffer.replay(_ALL_POLICIES)
self.local_mixin_buffer.add(batch)
batch = self.local_mixin_buffer.replay(_ALL_POLICIES)
if batch:
processed_batches.append(batch)

Expand Down

0 comments on commit 84b2b0e

Please sign in to comment.