Skip to content

Commit

Permalink
[BC]Weighted bc training (#416)
Browse files Browse the repository at this point in the history
Training code for the BC-Max algorithm which includes the tensorflow
code to train a new policy and to save it as a tf-policy and code to
compute the re-weighting for the supervised learning problem. This
required updates to SequenceExampleFeatureNames from
generate_bc_trajectories_lib.
  • Loading branch information
tvmarino authored Jan 17, 2025
1 parent e7b7e1c commit 9915a6d
Show file tree
Hide file tree
Showing 5 changed files with 771 additions and 8 deletions.
18 changes: 12 additions & 6 deletions compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ class SequenceExampleFeatureNames:
"""Feature names for features that are always added to seq example."""
action: str = 'action'
reward: str = 'reward'
loss: str = 'loss'
regret: str = 'regret'
module_name: str = 'module_name'
horizon: str = 'horizon'
label_name: str = 'label'


def get_loss(seq_example: tf.train.SequenceExample,
Expand Down Expand Up @@ -631,7 +635,8 @@ def _partition_for_loss(self, seq_example: tf.train.SequenceExample,
seq_loss = get_loss(seq_example)

label = bisect.bisect_right(partitions, seq_loss)
horizon = len(seq_example.feature_lists.feature_list['action'].feature)
horizon = len(seq_example.feature_lists.feature_list[
SequenceExampleFeatureNames.action].feature)
label_list = [label for _ in range(horizon)]
add_feature_list(seq_example, label_list, label_name)

Expand All @@ -640,7 +645,7 @@ def process_succeeded(
succeeded: List[Tuple[List, List[str], int, float]],
spec_name: str,
partitions: List[float],
label_name: str = 'label'
label_name: str = SequenceExampleFeatureNames.label_name
) -> Tuple[tf.train.SequenceExample, ProfilingDictValueType,
ProfilingDictValueType]:
seq_example_list = [exploration_res[0] for exploration_res in succeeded]
Expand Down Expand Up @@ -691,12 +696,13 @@ def _profiling_dict(
"""

per_module_dict = {
'module_name':
SequenceExampleFeatureNames.module_name:
module_name,
'loss':
SequenceExampleFeatureNames.loss:
float(get_loss(feature_list)),
'horizon':
len(feature_list.feature_lists.feature_list['action'].feature),
SequenceExampleFeatureNames.horizon:
len(feature_list.feature_lists.feature_list[
SequenceExampleFeatureNames.action].feature),
}
return per_module_dict

Expand Down
86 changes: 86 additions & 0 deletions compiler_opt/rl/imitation_learning/weighted_bc_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for training an inlining policy with imitation learning."""

from absl import app
from absl import flags
from absl import logging

import gin
import json
from compiler_opt.rl import policy_saver

from compiler_opt.rl.inlining import imitation_learning_config as config

from compiler_opt.rl.imitation_learning.weighted_bc_trainer_lib import TrainingWeights
from compiler_opt.rl.imitation_learning.weighted_bc_trainer_lib import ImitationLearningTrainer
from compiler_opt.rl.imitation_learning.weighted_bc_trainer_lib import WrapKerasModel

# Command-line flags for one step of BC-Max training.
_TRAINING_DATA = flags.DEFINE_multi_string(
    'training_data', None, 'Training data for one step of BC-Max')
_PROFILING_DATA = flags.DEFINE_multi_string(
    'profiling_data', None,
    # Each segment ends with a space: adjacent string literals are
    # concatenated verbatim, so without them words at the joins fuse
    # together in the rendered help text.
    ('Paths to profile files for computing the TrainingWeights. '
     'If specified the order for each pair of json files is '
     'comparator.json followed by eval.json and the number of '
     'files should always be even.'))
_SAVE_MODEL_DIR = flags.DEFINE_string(
    'save_model_dir', None, 'Location to save the keras and TFAgents policies.')
_GIN_FILES = flags.DEFINE_multi_string(
    'gin_files', [], 'List of paths to gin configuration files.')
_GIN_BINDINGS = flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files.')


def train():
  """Runs one step of BC-Max training and optionally saves the policy.

  When --profiling_data is given, consumes the file paths pairwise
  (comparator profile first, eval profile second) to build the
  TrainingWeights used to re-weight the supervised learning problem.
  Trains an ImitationLearningTrainer on --training_data and, when
  --save_model_dir is set, wraps the trained keras policy as a
  TF-Agents policy and saves it there.

  Raises:
    ValueError: if the number of profiling file paths is odd.
  """
  profile_paths = _PROFILING_DATA.value
  training_weights = None
  if profile_paths:
    if len(profile_paths) % 2 != 0:
      raise ValueError('Profiling file paths should always be an even number.')
    training_weights = TrainingWeights()
    # zip over one shared iterator yields consecutive (comparator, eval)
    # path pairs in the order they were passed on the command line.
    path_iter = iter(profile_paths)
    for comp_path, eval_path in zip(path_iter, path_iter):
      with open(comp_path, encoding='utf-8') as comp_f, open(
          eval_path, encoding='utf-8') as eval_f:
        comparator_prof = json.load(comp_f)
        eval_prof = json.load(eval_f)
        training_weights.update_weights(
            comparator_profile=comparator_prof, policy_profile=eval_prof)
  trainer = ImitationLearningTrainer(
      save_model_dir=_SAVE_MODEL_DIR.value, training_weights=training_weights)
  trainer.train(filepaths=_TRAINING_DATA.value)
  if _SAVE_MODEL_DIR.value:
    keras_policy = trainer.get_policy()
    expected_signature, action_spec = config.get_input_signature()
    wrapped_keras_model = WrapKerasModel(
        keras_policy=keras_policy,
        time_step_spec=expected_signature,
        action_spec=action_spec)
    saver = policy_saver.PolicySaver(
        policy_dict={'tf_agents_policy': wrapped_keras_model})
    saver.save(_SAVE_MODEL_DIR.value)


def main(_):
  """Entry point: applies the gin configuration, then runs training."""
  config_files = _GIN_FILES.value
  config_bindings = _GIN_BINDINGS.value
  gin.parse_config_files_and_bindings(
      config_files, config_bindings, skip_unknown=False)
  # Log the fully-resolved gin config so runs are reproducible.
  logging.info(gin.config_str())
  train()


if __name__ == '__main__':
  app.run(main)
Loading

0 comments on commit 9915a6d

Please sign in to comment.