[TRL_SFT_Trainer] Fix TRL-SFT Distillation Training (#1163)
SUMMARY:
* Fix the example script failure in
https://github.com/neuralmagic/llm-compressor-testing/actions/runs/13350457472/job/37286313648
for `llm-compressor/examples/trl_mixin/ex_trl_distillation.py`
* Update code with respect to
#1161

PROBLEM:
1. 
```bash
TypeError: GSM8KDataset.__init__() got an unexpected keyword argument 'tokenizer'
```
2.
```bash
AttributeError: 'GSM8KDataset' object has no attribute 'tokenize_and_process'
```

3. 
```bash
TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'packing'
```

4. 
```bash
ValueError: Found common keys in `training_args` and `data args`. This is prohibitive and may lead to undesired behavior.
```

5. 
```bash
TypeError: SessionManagerMixIn.save_model() missing 1 required positional argument: 'output_dir'
```
SOLUTION:
1. `TextGenerationDataset.load_from_registry` takes `processor`, not
`tokenizer`
2. Obtain the training dataset from `__call__` of `TextGenerationDataset`,
not from `dataset_manager.tokenize_and_process()`
3. Pass `max_seq_length` and `packing` as part of `TRLSFTConfig`, not
`TrainingArguments`
4. Resolve the collision on `max_seq_length` at
https://github.com/vllm-project/llm-compressor/blob/9258eb3e5d143b3bb38fa9abceb8da12e1e9cc08/src/llmcompressor/transformers/finetune/session_mixin.py#L583-L587:
when the TRL SFT trainer is used, `max_seq_length` is present in both
`training_args` and `data_args`. Rename the `training_args_dict` key
`max_seq_length` to `training_args_max_seq_length`; this dict is only used
to populate the metadata, which in turn populates the state for
bookkeeping (a simplified sketch follows this list)
5. Pass `output_dir` to `trainer.save_model`
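
For reference, a minimal standalone sketch of the key rename in (4), using plain dicts and a hypothetical `extract_metadata` helper in place of the real `SessionManagerMixIn._extract_metadata` and `METADATA_ARGS`; the actual change is in the `session_mixin.py` diff below.

```python
def extract_metadata(metadata_args, training_args_dict, data_args_dict):
    # Hypothetical stand-in for `_extract_metadata`: it refuses to merge
    # dicts that share keys, which is the source of the ValueError in (4).
    common = set(training_args_dict) & set(data_args_dict)
    if common:
        raise ValueError(f"Found common keys: {common}")
    merged = {**training_args_dict, **data_args_dict}
    return {key: merged[key] for key in metadata_args if key in merged}


training_args_dict = {"per_device_train_batch_size": 8, "max_seq_length": 512}
data_args_dict = {"dataset": "gsm8k", "max_seq_length": 512}

# The fix: move the colliding training-args key to a distinct name before
# extracting metadata, so only data_args contributes `max_seq_length`.
if "max_seq_length" in training_args_dict:
    training_args_dict["training_args_max_seq_length"] = training_args_dict.pop(
        "max_seq_length"
    )

metadata = extract_metadata(
    metadata_args=["dataset", "max_seq_length", "per_device_train_batch_size"],
    training_args_dict=training_args_dict,
    data_args_dict=data_args_dict,
)
# {'dataset': 'gsm8k', 'max_seq_length': 512, 'per_device_train_batch_size': 8}
```

Only the dict used for metadata is renamed; the actual `TRLSFTConfig` keeps `max_seq_length`, as shown in the `session_mixin.py` diff below.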


TEST PLAN:
* Run `llm-compressor/examples/trl_mixin/ex_trl_distillation.py` to
completion and check the outputs
* Pass existing tests


OUTPUT:
```bash
(.venv) gohashi@janice:~/llm-compressor$ cpy 2,3 '/home/gohashi/llm-compressor/examples/trl_mixin/ex_trl_distillation.py'
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████| 3/3 [00:08<00:00,  2.89s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████| 3/3 [00:12<00:00,  4.24s/it]
Tokenizing: 100%|█████████████████████████████████████████████████████████████████████████████████| 7473/7473 [00:04<00:00, 1689.60 examples/s]
Adding labels: 100%|████████████████████████████| 7473/7473 [00:03<00:00, 2193.90 examples/s]
--> Training Set Length = 7473
2025-02-17T18:00:02.961389-0500 | _calculate_checkpoint_info | WARNING - resume_from_checkpoint not passed into LLM Compressor Trainer.train. This will cause issues with restoring recipes when running from a checkpoint.
2025-02-17T18:00:02.964752-0500 | _check_create_state | INFO - State created for compression lifecycle
2025-02-17T18:00:03.015515-0500 | _check_compile_recipe | INFO - Recipe compiled and 1 modifiers created
manager stage: Modifiers initialized
2025-02-17T18:00:03.650149-0500 | initialize | INFO - Compression lifecycle initialized for 1 modifiers
manager stage: Modifiers initialized
2025-02-17T18:00:03.824371-0500 | initialize | INFO - Compression lifecycle initialized for 1 modifiers
  0%|                                                                 | 0/94 [00:00<?, ?it/s]
2025-02-17T18:00:03.876159-0500 | _check_setup_event_lifecycle | INFO - Event lifecycle for compression lifecycle created: CallbacksEventLifecycle(type_=None, steps_per_epoch=935, batches_per_step=None, invocations_per_step=1, global_step=0, global_batch=0) with start event type: EventType.BATCH_START
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
{'step_loss': 1.0210824012756348, 'perplexity': 2.776197910308838, 'distill_step_loss': 6.337211608886719, 'epoch': 0}
{'loss': 3.0398, 'grad_norm': 11.9375, 'learning_rate': 9.361702127659576e-06, 'epoch': 0.05}
{'step_loss': 0.664872407913208, 'perplexity': 1.9442424774169922, 'distill_step_loss': 1.768540620803833, 'epoch': 0.05}
100%|████████████████████████████████████████████████████████| 94/94 [02:41<00:00,  1.65s/it]
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
2025-02-17T18:03:08.193956-0500 | save_model | INFO - Saved LLM Compressor recipe with model state to ./output_trl_sft_test_7b_gsm8k/checkpoint-94/recipe.yaml
{'train_runtime': 226.8346, 'train_samples_per_second': 3.294, 'train_steps_per_second': 0.414, 'train_loss': 2.703931686726022, 'epoch': 0.1}
100%|████████████████████████████████████████████████████████| 94/94 [03:46<00:00,  2.41s/it]
manager stage: Modifiers finalized
2025-02-17T18:03:50.979758-0500 | finalize | INFO - Compression lifecycle finalized for 1 modifiers
2025-02-17T18:03:50.979908-0500 | finalize_session | INFO - Finalized LLM Compressor session
2025-02-17T18:04:36.878043-0500 | log_model_sparsification | INFO - Sparsification info for LlamaForCausalLM: 6738415616 total params. 
Calculating model sparsity: 100%|██████████████████████████| 291/291 [00:07<00:00, 37.93it/s]
2025-02-17T18:04:44.552073-0500 | log_model_sparsification | INFO - There are 6738415616 prunable params which have 48.05% avg sparsity.
2025-02-17T18:04:44.553706-0500 | log_model_sparsification | INFO - There are 6738415616 quantizable params, with a quantization percentage of 0.00%.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
2025-02-17T18:05:08.985045-0500 | save_model | INFO - Saved LLM Compressor recipe with model state to ./output_trl_sft_test_7b_gsm8k/recipe.yaml
```

```bash
(.venv) gohashi@janice:~/llm-compressor/output_trl_sft_test_7b_gsm8k$ ls
checkpoint-94                     pytorch_model-00003-of-00003.bin  tokenizer.json
config.json                       pytorch_model.bin.index.json      tokenizer.model
generation_config.json            recipe.yaml                       trainer_state.json
pytorch_model-00001-of-00003.bin  special_tokens_map.json
pytorch_model-00002-of-00003.bin  tokenizer_config.json
```

---------

Signed-off-by: George Ohashi <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
horheynm and kylesayrs authored Feb 18, 2025
1 parent 9258eb3 commit 32dd30d
Showing 2 changed files with 34 additions and 16 deletions.
examples/trl_mixin/ex_trl_distillation.py: 24 changes (14 additions & 10 deletions)
```diff
@@ -1,7 +1,7 @@
 from sft_trainer import SFTTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
 
-from llmcompressor.args import DatasetArguments, TrainingArguments
+from llmcompressor.args import DatasetArguments, ModelArguments
 from llmcompressor.transformers import TextGenerationDataset
 
 model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
@@ -16,18 +16,19 @@
 )
 
 tokenizer = AutoTokenizer.from_pretrained(model_path)
+max_seq_length = 512
 
 # Load gsm8k using SparseML dataset tools
 data_args = DatasetArguments(
-    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
+    dataset="gsm8k", dataset_config_name="main", max_seq_length=max_seq_length
 )
 dataset_manager = TextGenerationDataset.load_from_registry(
     data_args.dataset,
     data_args=data_args,
     split="train",
-    tokenizer=tokenizer,
+    processor=tokenizer,
 )
-train_dataset = dataset_manager.tokenize_and_process()
+train_dataset = dataset_manager()
 print(f"--> Training Set Length = {len(train_dataset)}")
 
 # recipe for maintaining model sparsity during finetuning
@@ -48,25 +49,28 @@
 """
 
 data_collator = DefaultDataCollator()
-training_args = TrainingArguments(
+trl_sft_config_args = dict(
     output_dir=output_dir,
     num_train_epochs=0.6,
     logging_steps=50,
     gradient_checkpointing=True,
     bf16=True,
     save_safetensors=False,  # workaround for shared tensors
+    max_seq_length=max_seq_length,
+    packing=True,
 )
+model_args = ModelArguments(model=model, distill_teacher=teacher)
+
 trainer = SFTTrainer(
     model=model,
     teacher=teacher,
-    tokenizer=tokenizer,
+    processing_class=tokenizer,
     recipe=recipe,
     train_dataset=train_dataset,
     data_collator=data_collator,
-    args=training_args,
+    trl_sft_config_args=trl_sft_config_args,
     data_args=data_args,
-    max_seq_length=data_args.max_seq_length,
-    packing=True,
+    model_args=model_args,
 )
 trainer.train()
-trainer.save_model()
+trainer.save_model(output_dir)
```
src/llmcompressor/transformers/finetune/session_mixin.py: 26 changes (20 additions & 6 deletions)
```diff
@@ -79,15 +79,29 @@ def __init__(
 
         # parse training and metadata args
         training_args = kwargs.get("args")
-        self.metadata = (
-            self._extract_metadata(
+
+        self.metadata = None
+        if training_args is not None:
+            # trl_sft_trainer pathway. Both training_args and data_args
+            # have `max_seq_length` which causes collision error. This is the
+            # only shared parameter, where training arg is `TRLSFTConfig` that
+            # inherits HuggingFace's `TrainingArguments`
+            training_args_dict = training_args.to_dict()
+            if "max_seq_length" in training_args_dict:
+                training_args_dict["training_args_max_seq_length"] = (
+                    training_args_dict.pop("max_seq_length")
+                )
+                logger.warning(
+                    "Detected `max_seq_length` in both data_args ",
+                    "and training_args. This is expected for TRL in distillation. ",
+                    "Updating metadata to `training_args_max_seq_length`",
+                )
+
+            self.metadata = self._extract_metadata(
                 metadata_args=METADATA_ARGS,
-                training_args_dict=training_args.to_dict(),
+                training_args_dict=training_args_dict,
                 data_args_dict=asdict(data_args) if data_args else {},
             )
-            if training_args and METADATA_ARGS
-            else None
-        )
 
         # setup metrics and session
         self.logger_manager = LoggerManager(log_python=False)
```
