diff --git a/examples/trl_mixin/ex_trl_distillation.py b/examples/trl_mixin/ex_trl_distillation.py
index 96cc78846..ebd14c5d2 100644
--- a/examples/trl_mixin/ex_trl_distillation.py
+++ b/examples/trl_mixin/ex_trl_distillation.py
@@ -1,7 +1,7 @@
 from sft_trainer import SFTTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
 
-from llmcompressor.args import DatasetArguments, TrainingArguments
+from llmcompressor.args import DatasetArguments, ModelArguments
 from llmcompressor.transformers import TextGenerationDataset
 
 model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
@@ -16,18 +16,19 @@
 )
 
 tokenizer = AutoTokenizer.from_pretrained(model_path)
+max_seq_length = 512
 
 # Load gsm8k using SparseML dataset tools
 data_args = DatasetArguments(
-    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
+    dataset="gsm8k", dataset_config_name="main", max_seq_length=max_seq_length
 )
 dataset_manager = TextGenerationDataset.load_from_registry(
     data_args.dataset,
     data_args=data_args,
     split="train",
-    tokenizer=tokenizer,
+    processor=tokenizer,
 )
-train_dataset = dataset_manager.tokenize_and_process()
+train_dataset = dataset_manager()
 print(f"--> Training Set Length = {len(train_dataset)}")
 
 # recipe for maintaining model sparsity during finetuning
@@ -48,25 +49,28 @@
 """
 
 data_collator = DefaultDataCollator()
-training_args = TrainingArguments(
+trl_sft_config_args = dict(
     output_dir=output_dir,
     num_train_epochs=0.6,
     logging_steps=50,
     gradient_checkpointing=True,
     bf16=True,
     save_safetensors=False,  # workaround for shared tensors
+    max_seq_length=max_seq_length,
+    packing=True,
 )
+model_args = ModelArguments(model=model, distill_teacher=teacher)
+
 trainer = SFTTrainer(
     model=model,
     teacher=teacher,
-    tokenizer=tokenizer,
+    processing_class=tokenizer,
     recipe=recipe,
     train_dataset=train_dataset,
     data_collator=data_collator,
-    args=training_args,
+    trl_sft_config_args=trl_sft_config_args,
     data_args=data_args,
-    max_seq_length=data_args.max_seq_length,
-    packing=True,
+    model_args=model_args,
 )
 trainer.train()
-trainer.save_model()
+trainer.save_model(output_dir)
diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py
index e32c64f62..dc78385e9 100644
--- a/src/llmcompressor/transformers/finetune/session_mixin.py
+++ b/src/llmcompressor/transformers/finetune/session_mixin.py
@@ -79,15 +79,29 @@ def __init__(
 
         # parse training and metadata args
        training_args = kwargs.get("args")
-        self.metadata = (
-            self._extract_metadata(
+
+        self.metadata = None
+        if training_args is not None:
+            # trl_sft_trainer pathway. Both training_args and data_args
+            # have `max_seq_length` which causes collision error. This is the
+            # only shared parameter, where training arg is `TRLSFTConfig` that
+            # inherits HuggingFace's `TrainingArguments`
+            training_args_dict = training_args.to_dict()
+            if "max_seq_length" in training_args_dict:
+                training_args_dict["training_args_max_seq_length"] = (
+                    training_args_dict.pop("max_seq_length")
+                )
+                logger.warning(
+                    "Detected `max_seq_length` in both data_args "
+                    "and training_args. This is expected for TRL in distillation. "
+                    "Updating metadata to `training_args_max_seq_length`"
+                )
+
+            self.metadata = self._extract_metadata(
                 metadata_args=METADATA_ARGS,
-                training_args_dict=training_args.to_dict(),
+                training_args_dict=training_args_dict,
                 data_args_dict=asdict(data_args) if data_args else {},
             )
-            if training_args and METADATA_ARGS
-            else None
-        )
 
         # setup metrics and session
         self.logger_manager = LoggerManager(log_python=False)