
[TRL_SFT_Trainer] Fix TRL-SFT Distillation Training #1163

Merged 7 commits from fix-trl-distillation into main on Feb 18, 2025

Conversation

@horheynm (Collaborator) commented Feb 17, 2025

SUMMARY:

Fix the TRL SFT distillation example (examples/trl_mixin/ex_trl_distillation.py), which currently fails with the errors below.

PROBLEM:

  1. TypeError: GSM8KDataset.__init__() got an unexpected keyword argument 'tokenizer'
  2. AttributeError: 'GSM8KDataset' object has no attribute 'tokenize_and_process'
  3. TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'packing'
  4. ValueError: Found common keys in `training_args` and `data args`. This is prohibitive and may lead to undesired behavior.
  5. TypeError: SessionManagerMixIn.save_model() missing 1 required positional argument: 'output_dir'
SOLUTION:

  1. Pass processor, not tokenizer, to TextGenerationDataset.load_from_registry (see the first sketch below).
  2. Obtain the training dataset by calling the TextGenerationDataset instance (its __call__), not dataset_manager.tokenize_and_process().
  3. Move max_seq_length and packing onto TRLSFTConfig instead of TrainingArguments.
  4. Resolve the "max_seq_length" collision raised by

         if not training_args_dict.keys().isdisjoint(data_args_dict.keys()):
             raise ValueError(
                 "Found common keys in `training_args` and `data args`. "
                 "This is prohibitive and may lead to undesired behavior."
             )

     When the TRL SFT trainer is used, max_seq_length appears in both training_args and data_args. Rename training_args_dict's max_seq_length key to training_args_max_seq_length; this dict only populates the metadata, which in turn populates the state for bookkeeping (see the second sketch below).
  5. Pass output_dir to trainer.save_model.
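For reference, here is a minimal sketch of how the fixed example wires items 1, 2, 3, and 5 together. The import paths, the TRLSFTConfig alias, the teacher/recipe keyword arguments, and the exact load_from_registry and __call__ signatures are assumptions inferred from the fix list and examples/trl_mixin/, not a verbatim copy of the merged script:

    # Sketch only -- names marked "assumed" are not verified against the merged code.
    from sft_trainer import SFTTrainer  # assumed: the SessionManagerMixIn + TRL SFTTrainer mixin in examples/trl_mixin/
    from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
    from trl import SFTConfig as TRLSFTConfig  # max_seq_length / packing live here

    from llmcompressor.transformers import (  # assumed import path
        DataTrainingArguments,
        TextGenerationDataset,
    )

    model_path = "path/to/student-model"    # placeholders; the example uses 7B GSM8K checkpoints
    teacher_path = "path/to/teacher-model"
    model = AutoModelForCausalLM.from_pretrained(model_path)
    teacher = AutoModelForCausalLM.from_pretrained(teacher_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Fixes 1 and 2: pass processor= (not tokenizer=) to load_from_registry, and
    # build the tokenized split by calling the dataset object instead of
    # dataset_manager.tokenize_and_process().
    data_args = DataTrainingArguments(
        dataset="gsm8k", dataset_config_name="main", max_seq_length=512
    )
    dataset_manager = TextGenerationDataset.load_from_registry(
        data_args.dataset,
        data_args=data_args,
        split="train",
        processor=tokenizer,
    )
    train_dataset = dataset_manager()

    # Fix 3: max_seq_length and packing belong on the TRL SFT config, not on
    # TrainingArguments (SFTTrainer.__init__ no longer accepts packing=).
    training_args = TRLSFTConfig(
        output_dir="./output_trl_sft_test_7b_gsm8k",
        max_seq_length=data_args.max_seq_length,
        packing=True,
    )

    trainer = SFTTrainer(
        model=model,
        teacher=teacher,                 # assumed kwarg for the distillation teacher
        tokenizer=tokenizer,             # triggers the Trainer.tokenizer deprecation warning seen in the output below
        args=training_args,
        train_dataset=train_dataset,
        data_collator=DefaultDataCollator(),
        recipe="path/to/recipe.yaml",    # placeholder recipe path
    )
    trainer.train()

    # Fix 5: save_model now receives an explicit output directory.
    trainer.save_model(output_dir=training_args.output_dir)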

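And a sketch of fix 4 in isolation, showing the key rename ahead of the disjointness check quoted in item 4. The function wrapper and the to_dict()/asdict() handling are illustrative; the real change lives inside SessionManagerMixIn's metadata setup:

    from dataclasses import asdict

    def build_metadata(training_args, data_args):
        """Illustrative sketch of fix 4: merge args into metadata without key clashes."""
        training_args_dict = training_args.to_dict()  # HF TrainingArguments-style config
        data_args_dict = asdict(data_args)            # dataclass of dataset arguments

        # With the TRL SFT trainer, both dicts carry "max_seq_length". Rename the
        # training-args copy so both values survive; this dict only feeds the
        # metadata used for state bookkeeping.
        if "max_seq_length" in training_args_dict:
            training_args_dict["training_args_max_seq_length"] = training_args_dict.pop(
                "max_seq_length"
            )

        if not training_args_dict.keys().isdisjoint(data_args_dict.keys()):
            raise ValueError(
                "Found common keys in `training_args` and `data args`. "
                "This is prohibitive and may lead to undesired behavior."
            )

        return {**training_args_dict, **data_args_dict}
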
TEST PLAN:

  • Run llm-compressor/examples/trl_mixin/ex_trl_distillation.py to completion and check the outputs
  • Pass existing tests

OUTPUT:

(.venv) gohashi@janice:~/llm-compressor$ cpy 2,3 '/home/gohashi/llm-compressor/examples/trl_mixin/ex_trl_distillation.py'
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████| 3/3 [00:08<00:00,  2.89s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████| 3/3 [00:12<00:00,  4.24s/it]
Tokenizing: 100%|██████████| 7473/7473 [00:04<00:00, 1689.60 examples/s]
Adding labels: 100%|████████████████████████████| 7473/7473 [00:03<00:00, 2193.90 examples/s]
--> Training Set Length = 7473
2025-02-17T18:00:02.961389-0500 | _calculate_checkpoint_info | WARNING - resume_from_checkpoint not passed into LLM Compressor Trainer.train. This will cause issues with restoring recipes when running from a checkpoint.
2025-02-17T18:00:02.964752-0500 | _check_create_state | INFO - State created for compression lifecycle
2025-02-17T18:00:03.015515-0500 | _check_compile_recipe | INFO - Recipe compiled and 1 modifiers created
manager stage: Modifiers initialized
2025-02-17T18:00:03.650149-0500 | initialize | INFO - Compression lifecycle initialized for 1 modifiers
manager stage: Modifiers initialized
2025-02-17T18:00:03.824371-0500 | initialize | INFO - Compression lifecycle initialized for 1 modifiers
  0%|                                                                 | 0/94 [00:00<?, ?it/s]2025-02-17T18:00:03.876159-0500 | _check_setup_event_lifecycle | INFO - Event lifecycle for compression lifecycle created: CallbacksEventLifecycle(type_=None, steps_per_epoch=935, batches_per_step=None, invocations_per_step=1, global_step=0, global_batch=0) with start event type: EventType.BATCH_START
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
{'step_loss': 1.0210824012756348, 'perplexity': 2.776197910308838, 'distill_step_loss': 6.337211608886719, 'epoch': 0}
{'loss': 3.0398, 'grad_norm': 11.9375, 'learning_rate': 9.361702127659576e-06, 'epoch': 0.05}
{'step_loss': 0.664872407913208, 'perplexity': 1.9442424774169922, 'distill_step_loss': 1.768540620803833, 'epoch': 0.05}
100%|████████████████████████████████████████████████████████| 94/94 [02:41<00:00,  1.65s/it]Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
2025-02-17T18:03:08.193956-0500 | save_model | INFO - Saved LLM Compressor recipe with model state to ./output_trl_sft_test_7b_gsm8k/checkpoint-94/recipe.yaml
{'train_runtime': 226.8346, 'train_samples_per_second': 3.294, 'train_steps_per_second': 0.414, 'train_loss': 2.703931686726022, 'epoch': 0.1}
100%|████████████████████████████████████████████████████████| 94/94 [03:46<00:00,  2.41s/it]
manager stage: Modifiers finalized
2025-02-17T18:03:50.979758-0500 | finalize | INFO - Compression lifecycle finalized for 1 modifiers
2025-02-17T18:03:50.979908-0500 | finalize_session | INFO - Finalized LLM Compressor session
2025-02-17T18:04:36.878043-0500 | log_model_sparsification | INFO - Sparsification info for LlamaForCausalLM: 6738415616 total params. 
Calculating model sparsity: 100%|██████████████████████████| 291/291 [00:07<00:00, 37.93it/s]
2025-02-17T18:04:44.552073-0500 | log_model_sparsification | INFO - There are 6738415616 prunable params which have 48.05% avg sparsity.
2025-02-17T18:04:44.553706-0500 | log_model_sparsification | INFO - There are 6738415616 quantizable params, with a quantization percentage of 0.00%.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
2025-02-17T18:05:08.985045-0500 | save_model | INFO - Saved LLM Compressor recipe with model state to ./output_trl_sft_test_7b_gsm8k/recipe.yaml
(.venv) gohashi@janice:~/llm-compressor/output_trl_sft_test_7b_gsm8k$ ls
checkpoint-94                     pytorch_model-00003-of-00003.bin  tokenizer.json
config.json                       pytorch_model.bin.index.json      tokenizer.model
generation_config.json            recipe.yaml                       trainer_state.json
pytorch_model-00001-of-00003.bin  special_tokens_map.json
pytorch_model-00002-of-00003.bin  tokenizer_config.json

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

@horheynm marked this pull request as ready for review February 17, 2025 23:14
@horheynm added the ready (When a PR is ready for review) label Feb 18, 2025
@brian-dellabetta (Collaborator) left a comment:

I had a similar question to Dipika's here, but if you two are in agreement, everything else LGTM!

rahul-tuli previously approved these changes Feb 18, 2025

@rahul-tuli (Collaborator) left a comment:

LGTM pending comments!

@dsikka (Collaborator) left a comment:

Just small comments about adding comments; otherwise LGTM.

@horheynm dismissed stale reviews from rahul-tuli and brian-dellabetta via 83255e9 February 18, 2025 15:25
Signed-off-by: George Ohashi <[email protected]>
@dsikka enabled auto-merge (squash) February 18, 2025 18:09
@dsikka merged commit 32dd30d into main Feb 18, 2025
7 checks passed
@dsikka deleted the fix-trl-distillation branch February 18, 2025 18:22