[TRL_SFT_Trainer] Fix TRL-SFT Distillation Training (#1163)
SUMMARY:
* Fix the example script failure https://github.com/neuralmagic/llm-compressor-testing/actions/runs/13350457472/job/37286313648 for `llm-compressor/examples/trl_mixin/ex_trl_distillation.py`
* Update the code with respect to #1161

PROBLEM:
1. ```bash
   TypeError: GSM8KDataset.__init__() got an unexpected keyword argument 'tokenizer'
   ```
2. ```bash
   AttributeError: 'GSM8KDataset' object has no attribute 'tokenize_and_process'
   ```
3. ```bash
   TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'packing'
   ```
4. ```bash
   ValueError: Found common keys in `training_args` and `data args`. This is prohibitive and may lead to undesired behavior.
   ```
5. ```bash
   TypeError: SessionManagerMixIn.save_model() missing 1 required positional argument: 'output_dir'
   ```

SOLUTION:
1. `TextGenerationDataset.load_from_registry` takes `processor`, not `tokenizer`.
2. Obtain the training dataset from `__call__` of `TextGenerationDataset`, not from `dataset_manager.tokenize_and_process()`.
3. Make `max_seq_length` and `packing` part of `TRLSFTConfig`, not `TrainingArguments`.
4. Fix the collision on `max_seq_length` at https://github.com/vllm-project/llm-compressor/blob/9258eb3e5d143b3bb38fa9abceb8da12e1e9cc08/src/llmcompressor/transformers/finetune/session_mixin.py#L583-L587: when the TRL SFT trainer is used, `max_seq_length` appears in both `training_args` and `data_args`. Rename the `max_seq_length` key in `training_args_dict` to `training_args_max_seq_length`; this dict only populates the metadata that the state keeps for bookkeeping.
5. Pass `output_dir` to `trainer.save_model`.

A sketch of the corrected call pattern (fixes 1-3 and 5) follows this list.
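A minimal sketch of the corrected usage, assuming the example-local `SFTTrainer`/`TRLSFTConfig` mixins from `examples/trl_mixin/sft_trainer.py`; import paths, model names, the recipe, and the remaining argument values are illustrative and may not match the example script exactly:

```python
# Sketch only: assumes the example-local SFTTrainer / TRLSFTConfig mixins.
# Import locations and argument values are illustrative and may differ by version.
from sft_trainer import SFTTrainer, TRLSFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator

from llmcompressor.args import DatasetArguments  # dataset-argument dataclass; name/location may vary by version
from llmcompressor.transformers import TextGenerationDataset

model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
teacher_path = "neuralmagic/Llama-2-7b-gsm8k"
output_dir = "./output_trl_sft_test_7b_gsm8k"

model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
teacher = AutoModelForCausalLM.from_pretrained(teacher_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

data_args = DatasetArguments(
    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
)

# (1) load_from_registry expects `processor`, not `tokenizer`
dataset_manager = TextGenerationDataset.load_from_registry(
    data_args.dataset,
    data_args=data_args,
    split="train",
    processor=tokenizer,
)
# (2) the processed dataset comes from __call__, not tokenize_and_process()
train_dataset = dataset_manager(add_labels=True)

# (3) `max_seq_length` and `packing` now live on the TRL SFT config,
#     not on llm-compressor's TrainingArguments
training_args = TRLSFTConfig(
    output_dir=output_dir,
    max_seq_length=data_args.max_seq_length,
    packing=True,
    num_train_epochs=0.1,
    logging_steps=50,
)

trainer = SFTTrainer(
    model=model,
    teacher=teacher,
    processing_class=tokenizer,
    recipe="distillation_recipe.yaml",  # placeholder; see the example script for the actual recipe
    args=training_args,
    data_args=data_args,
    train_dataset=train_dataset,
    data_collator=DefaultDataCollator(),
)
trainer.train()

# (5) save_model now requires an explicit output directory
trainer.save_model(output_dir=output_dir)
```

Keeping `max_seq_length`/`packing` on the TRL config rather than on llm-compressor's `TrainingArguments` matches recent trl releases, where these options moved from `SFTTrainer.__init__` onto `SFTConfig`.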
TEST PLAN:
* Run `llm-compressor/examples/trl_mixin/ex_trl_distillation.py` to completion, check the outputs
* Pass existing tests

OUTPUT:
```bash
(.venv) gohashi@janice:~/llm-compressor$ cpy 2,3 '/home/gohashi/llm-compressor/examples/trl_mixin/ex_trl_distillation.py'
Loading checkpoint shards: 100%|██████████| 3/3 [00:08<00:00,  2.89s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:12<00:00,  4.24s/it]
Tokenizing: 100%|██████████| 7473/7473 [00:04<00:00, 1689.60 examples/s]
Adding labels: 100%|██████████| 7473/7473 [00:03<00:00, 2193.90 examples/s]
--> Training Set Length = 7473
2025-02-17T18:00:02.961389-0500 | _calculate_checkpoint_info | WARNING - resume_from_checkpoint not passed into LLM Compressor Trainer.train. This will cause issues with restoring recipes when running from a checkpoint.
2025-02-17T18:00:02.964752-0500 | _check_create_state | INFO - State created for compression lifecycle
2025-02-17T18:00:03.015515-0500 | _check_compile_recipe | INFO - Recipe compiled and 1 modifiers created
manager stage: Modifiers initialized
2025-02-17T18:00:03.650149-0500 | initialize | INFO - Compression lifecycle initialized for 1 modifiers
manager stage: Modifiers initialized
2025-02-17T18:00:03.824371-0500 | initialize | INFO - Compression lifecycle initialized for 1 modifiers
  0%|          | 0/94 [00:00<?, ?it/s]
2025-02-17T18:00:03.876159-0500 | _check_setup_event_lifecycle | INFO - Event lifecycle for compression lifecycle created: CallbacksEventLifecycle(type_=None, steps_per_epoch=935, batches_per_step=None, invocations_per_step=1, global_step=0, global_batch=0) with start event type: EventType.BATCH_START
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
{'step_loss': 1.0210824012756348, 'perplexity': 2.776197910308838, 'distill_step_loss': 6.337211608886719, 'epoch': 0}
{'loss': 3.0398, 'grad_norm': 11.9375, 'learning_rate': 9.361702127659576e-06, 'epoch': 0.05}
{'step_loss': 0.664872407913208, 'perplexity': 1.9442424774169922, 'distill_step_loss': 1.768540620803833, 'epoch': 0.05}
100%|██████████| 94/94 [02:41<00:00,  1.65s/it]
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
2025-02-17T18:03:08.193956-0500 | save_model | INFO - Saved LLM Compressor recipe with model state to ./output_trl_sft_test_7b_gsm8k/checkpoint-94/recipe.yaml
{'train_runtime': 226.8346, 'train_samples_per_second': 3.294, 'train_steps_per_second': 0.414, 'train_loss': 2.703931686726022, 'epoch': 0.1}
100%|██████████| 94/94 [03:46<00:00,  2.41s/it]
manager stage: Modifiers finalized
2025-02-17T18:03:50.979758-0500 | finalize | INFO - Compression lifecycle finalized for 1 modifiers
2025-02-17T18:03:50.979908-0500 | finalize_session | INFO - Finalized LLM Compressor session
2025-02-17T18:04:36.878043-0500 | log_model_sparsification | INFO - Sparsification info for LlamaForCausalLM: 6738415616 total params.
Calculating model sparsity: 100%|██████████| 291/291 [00:07<00:00, 37.93it/s]
2025-02-17T18:04:44.552073-0500 | log_model_sparsification | INFO - There are 6738415616 prunable params which have 48.05% avg sparsity.
2025-02-17T18:04:44.553706-0500 | log_model_sparsification | INFO - There are 6738415616 quantizable params, with a quantization percentage of 0.00%.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
2025-02-17T18:05:08.985045-0500 | save_model | INFO - Saved LLM Compressor recipe with model state to ./output_trl_sft_test_7b_gsm8k/recipe.yaml
```

```bash
(.venv) gohashi@janice:~/llm-compressor/output_trl_sft_test_7b_gsm8k$ ls
checkpoint-94                     pytorch_model-00003-of-00003.bin  tokenizer.json
config.json                       pytorch_model.bin.index.json      tokenizer.model
generation_config.json            recipe.yaml                       trainer_state.json
pytorch_model-00001-of-00003.bin  special_tokens_map.json
pytorch_model-00002-of-00003.bin  tokenizer_config.json
```

---------

Signed-off-by: George Ohashi <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>