
Commit

merge, add todos
Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs committed Feb 19, 2025
2 parents a68bc1e + bf9a8cd commit e16836b
Showing 29 changed files with 262 additions and 704 deletions.
12 changes: 7 additions & 5 deletions examples/trl_mixin/ex_trl_constant.py
@@ -3,7 +3,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

from llmcompressor.args import TrainingArguments
from llmcompressor.args import ModelArguments

model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data"
@@ -39,21 +39,23 @@ def formatting_prompts_func(example):
response_template = "Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

training_args = TrainingArguments(
trl_sft_config_args = dict(
output_dir=output_dir,
num_train_epochs=0.6,
logging_steps=50,
gradient_checkpointing=True,
max_seq_length=512,
)
model_args = ModelArguments(model=model)

trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
processing_class=tokenizer,
recipe=recipe,
train_dataset=dataset,
formatting_func=formatting_prompts_func,
data_collator=collator,
args=training_args,
max_seq_length=512,
trl_sft_config_args=trl_sft_config_args,
model_args=model_args,
)
trainer.train()
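For reference, the trl_sft_config_args dict above is unpacked into trl's SFTConfig inside the SFTTrainer mixin (see examples/trl_mixin/sft_trainer.py further down). A rough sketch of the equivalent explicit construction, assuming these fields are accepted directly by SFTConfig:

from trl import SFTConfig

# Sketch only: roughly what the mixin builds from trl_sft_config_args.
sft_config = SFTConfig(
    output_dir="./output_trl_sft_test_7b_gsm8k_sft_data",
    num_train_epochs=0.6,
    logging_steps=50,
    gradient_checkpointing=True,
    max_seq_length=512,
)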
24 changes: 14 additions & 10 deletions examples/trl_mixin/ex_trl_distillation.py
@@ -1,7 +1,7 @@
from sft_trainer import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator

from llmcompressor.args import DatasetArguments, TrainingArguments
from llmcompressor.args import DatasetArguments, ModelArguments
from llmcompressor.transformers import TextGenerationDataset

model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
@@ -16,18 +16,19 @@
)

tokenizer = AutoTokenizer.from_pretrained(model_path)
max_seq_length = 512

# Load gsm8k using SparseML dataset tools
data_args = DatasetArguments(
dataset="gsm8k", dataset_config_name="main", max_seq_length=512
dataset="gsm8k", dataset_config_name="main", max_seq_length=max_seq_length
)
dataset_manager = TextGenerationDataset.load_from_registry(
data_args.dataset,
data_args=data_args,
split="train",
tokenizer=tokenizer,
processor=tokenizer,
)
train_dataset = dataset_manager.tokenize_and_process()
train_dataset = dataset_manager()
print(f"--> Training Set Length = {len(train_dataset)}")

# recipe for maintaining model sparsity during finetuning
@@ -48,25 +49,28 @@
"""

data_collator = DefaultDataCollator()
training_args = TrainingArguments(
trl_sft_config_args = dict(
output_dir=output_dir,
num_train_epochs=0.6,
logging_steps=50,
gradient_checkpointing=True,
bf16=True,
save_safetensors=False, # workaround for shared tensors
max_seq_length=max_seq_length,
packing=True,
)
model_args = ModelArguments(model=model, distill_teacher=teacher)

trainer = SFTTrainer(
model=model,
teacher=teacher,
tokenizer=tokenizer,
processing_class=tokenizer,
recipe=recipe,
train_dataset=train_dataset,
data_collator=data_collator,
args=training_args,
trl_sft_config_args=trl_sft_config_args,
data_args=data_args,
max_seq_length=data_args.max_seq_length,
packing=True,
model_args=model_args,
)
trainer.train()
trainer.save_model()
trainer.save_model(output_dir)
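The dataset-tooling changes above come down to two renames: the manager now takes processor= instead of tokenizer=, and is called directly instead of via tokenize_and_process(). A minimal standalone sketch, assuming the manager's __call__ matches the old method's behavior:

from transformers import AutoTokenizer

from llmcompressor.args import DatasetArguments
from llmcompressor.transformers import TextGenerationDataset

tokenizer = AutoTokenizer.from_pretrained("neuralmagic/Llama-2-7b-pruned50-retrained")
data_args = DatasetArguments(
    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
)
dataset_manager = TextGenerationDataset.load_from_registry(
    data_args.dataset,
    data_args=data_args,
    split="train",
    processor=tokenizer,  # renamed from tokenizer=
)
train_dataset = dataset_manager()  # replaces dataset_manager.tokenize_and_process()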
24 changes: 5 additions & 19 deletions examples/trl_mixin/sft_trainer.py
@@ -1,20 +1,17 @@
from typing import Dict, Optional

from trl import SFTConfig as TRLSFTConfig
from trl import SFTTrainer as TRLSFTTrainer

from llmcompressor.args import TrainingArguments
from llmcompressor.transformers.finetune.session_mixin import SessionManagerMixIn

__all__ = ["SFTTrainer"]


class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
def __init__(self, *args, **kwargs):
sft_config_args = kwargs.get("args")
if (
sft_config_args is not None
and sft_config_args.__class__.__name__ == "TrainingArguments"
):
kwargs["args"] = SFTConfig(**sft_config_args.to_dict())
def __init__(self, trl_sft_config_args: Optional[Dict] = None, *args, **kwargs):
if trl_sft_config_args is not None:
kwargs["args"] = TRLSFTConfig(**trl_sft_config_args)
super().__init__(*args, **kwargs)

def _prepare_dataset(self, dataset, *args, **kwargs):
@@ -23,14 +20,3 @@ def _prepare_dataset(self, dataset, *args, **kwargs):
return dataset

return super()._prepare_dataset(dataset, *args, **kwargs)


class SFTConfig(TrainingArguments, TRLSFTConfig):
"""
This class is needed to wrap the llmcompressor.transformers.TrainingArguments
and TRLSFTConfig classes. This allows for the use of arguments and
configurations from both classes when training a model.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
2 changes: 1 addition & 1 deletion setup.py
@@ -60,7 +60,7 @@
"datasets",
"accelerate>=0.20.3,!=1.1.0",
"pynvml",
"compressed-tensors"
"compressed-tensors==0.9.2"
if version_info.build_type == "release"
else "compressed-tensors-nightly",
],
1 change: 0 additions & 1 deletion src/llmcompressor/__init__.py
@@ -41,6 +41,5 @@
create_session,
finalize,
initialize,
pre_initialize_structure,
reset_session,
)
2 changes: 0 additions & 2 deletions src/llmcompressor/core/__init__.py
@@ -16,7 +16,6 @@
create_session,
finalize,
initialize,
pre_initialize_structure,
reset_session,
)
from llmcompressor.core.state import Data, Hardware, ModifiedState, State
@@ -37,7 +36,6 @@
"create_session",
"active_session",
"reset_session",
"pre_initialize_structure",
"initialize",
"finalize",
"apply",
31 changes: 0 additions & 31 deletions src/llmcompressor/core/events/event.py
@@ -27,7 +27,6 @@ class EventType(Enum):
The purpose of each EventType is to trigger the corresponding
modifier callback during training or post training pipelines.
:param PRE_INIT: Event type for pre-initialization.
:param INITIALIZE: Event type for initialization.
:param FINALIZE: Event type for finalization.
:param BATCH_START: Event type for the start of a batch.
@@ -38,7 +37,6 @@
"""

# training lifecycle
PRE_INIT = "pre_init"
INITIALIZE = "initialize"
FINALIZE = "finalize"

@@ -51,35 +49,6 @@
OPTIM_PRE_STEP = "optim_pre_step"
OPTIM_POST_STEP = "optim_post_step"

def order(self) -> int:
"""
Returns the priority order of the current EventType.
Lower values have higher priority.
:raises ValueError: if the event type is invalid.
:return: The order of the event type, lower has higher priority.
:rtype: int
"""
if self == EventType.PRE_INIT:
return 0
elif self == EventType.INITIALIZE:
return 10
elif self == EventType.FINALIZE:
return 20
elif self == EventType.BATCH_START:
return 100
elif self == EventType.LOSS_CALCULATED:
return 110
elif self == EventType.OPTIM_PRE_STEP:
return 120
elif self == EventType.OPTIM_POST_STEP:
return 130
elif self == EventType.BATCH_END:
return 140
else:
logger.error("Invalid event type: {}", self)
raise ValueError(f"Invalid event type {self}")


@dataclass
class Event:
92 changes: 26 additions & 66 deletions src/llmcompressor/core/lifecycle.py
@@ -13,7 +13,12 @@
from llmcompressor.core.events import Event, EventType
from llmcompressor.core.state import State
from llmcompressor.modifiers import StageModifiers
from llmcompressor.recipe import RecipeContainer
from llmcompressor.recipe import (
RecipeArgsInput,
RecipeContainer,
RecipeInput,
RecipeStageInput,
)

__all__ = ["CompressionLifecycle"]

@@ -31,11 +36,11 @@ class CompressionLifecycle:
:type modifiers: List[StageModifiers]
"""

state: Optional[State] = field(default_factory=lambda: State)
state: State = field(default_factory=State)
recipe_container: RecipeContainer = field(default_factory=RecipeContainer)
modifiers: List[StageModifiers] = field(default_factory=list)
current_index: int = 0

initialized_structure: bool = False
initialized_: bool = False
finalized: bool = False
event_called: bool = False
@@ -59,55 +64,32 @@ def reset(self):
self.__init__()
logger.info("Compression lifecycle reset")

def pre_initialize_structure(self, **kwargs) -> List[Any]:
"""
Pre-initialize the structure of the compression lifecycle.
:param kwargs: Additional arguments to update the state with
:return: List of data returned from pre-initialization of modifiers
:rtype: List[Any]
"""
logger.debug("Pre-initializing structure")
self._check_create_state()
extras = self.state.update(**kwargs)
extras = self.recipe_container.update(**extras)

self._check_compile_recipe()
mod_data = []
for mod in self.modifiers:
data = mod.pre_initialize_structure(state=self.state, **extras)
logger.debug("Pre-initialized modifier: {}", mod)
if data is not None:
mod_data.append(data)

self.initialized_structure = True
applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
self.recipe_container.update_applied_stages(applied_stage_names)
logger.info(
"Compression lifecycle structure pre-initialized for {} modifiers",
len(self.modifiers),
)

return mod_data

def initialize(self, **kwargs) -> List[Any]:
def initialize(
self,
recipe: Optional[RecipeInput] = None,
recipe_stage: Optional[RecipeStageInput] = None,
recipe_args: Optional[RecipeArgsInput] = None,
**kwargs,
) -> List[Any]:
"""
Initialize the compression lifecycle.
:param kwargs: Additional arguments to update the state with
:return: List of data returned from initialization of modifiers
:rtype: List[Any]
"""
logger.debug("Initializing compression lifecycle")
self._check_create_state()
extras = self.state.update(**kwargs)
extras = self.recipe_container.update(**extras)
self.state.update(**kwargs)
if self.initialized_: # TODO: do not initialize twice
return

self._check_compile_recipe()
logger.debug("Initializing compression lifecycle")
self.recipe_container.append(recipe, recipe_stage, recipe_args)
self.modifiers = self.recipe_container.get_modifiers()
self._set_model_layer_prefix()

mod_data = []
for mod in self.modifiers:
data = mod.initialize(state=self.state, **extras)
data = mod.initialize(state=self.state, **kwargs)
logger.debug("Initialized modifier: {}", mod)
if data is not None:
mod_data.append(data)
@@ -174,7 +156,7 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
logger.error("Cannot invoke event after finalizing")
raise ValueError("Cannot invoke event after finalizing")

if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]:
if event_type in [EventType.INITIALIZE, EventType.FINALIZE]:
logger.error(
"Cannot invoke {} event. Use the corresponding method instead.",
event_type,
@@ -184,6 +166,8 @@
f"Use the corresponding method instead."
)

# TODO: populate current_index with event

if event_type == EventType.LOSS_CALCULATED and (
"loss" not in kwargs or kwargs["loss"] is None
):
@@ -208,30 +192,6 @@

return mod_data

def _check_create_state(self):
if self.state is not None:
return

logger.debug("Creating new State instance for compression lifecycle")
self.state = State()
logger.info("State created for compression lifecycle")

def _check_compile_recipe(self):
if not self.recipe_container.check_compile_recipe():
return

logger.debug(
"Compiling recipe and creating modifiers for compression lifecycle"
)
self.modifiers = self.recipe_container.compiled_recipe.create_modifier()
for mod in self.modifiers:
if mod.unique_id in self.recipe_container.applied_stages:
mod.applied = True
logger.info(
"Recipe compiled and {} modifiers created",
len(self.modifiers),
)

def _set_model_layer_prefix(self):
compiled_recipe = self.recipe_container.compiled_recipe
if (
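For callers of the core API, the net effect of the lifecycle changes is that pre_initialize_structure is gone and recipe inputs are passed straight to initialize(), which now returns early on repeat calls. A minimal sketch, assuming State.update still accepts a model kwarg and that a recipe path string is a valid RecipeInput ("recipe.yaml" is a hypothetical path):

from transformers import AutoModelForCausalLM

from llmcompressor.core.lifecycle import CompressionLifecycle

model = AutoModelForCausalLM.from_pretrained("neuralmagic/Llama-2-7b-pruned50-retrained")
lifecycle = CompressionLifecycle()

# recipe / recipe_stage / recipe_args are now explicit parameters of initialize();
# they are appended to the RecipeContainer and the modifiers are built immediately.
lifecycle.initialize(recipe="recipe.yaml", model=model)

# A second call is a no-op: initialize() returns early once initialized_ is set.
lifecycle.initialize(recipe="recipe.yaml", model=model)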
(diffs for the remaining changed files were not loaded on this page)
