[Callbacks] Consolidate Saving Methods #1168

Open · wants to merge 5 commits into base: main
43 changes: 11 additions & 32 deletions src/llmcompressor/pytorch/model_load/helpers.py
@@ -6,6 +6,7 @@
from loguru import logger
from safetensors import safe_open
from torch.nn import Module
from transformers import PreTrainedModel

from llmcompressor.core import active_session, create_session, pre_initialize_structure
from llmcompressor.typing import Processor
@@ -14,20 +15,19 @@

__all__ = [
"initialize_recipe",
"save_model_and_recipe",
"copy_python_files_from_model_cache",
"fallback_to_cpu",
"parse_dtype",
"get_session_model",
"get_completed_stages",
"save_completed_stages",
"save_checkpoint",
]


def initialize_recipe(model: Module, recipe_path: str):
"""
Initializes a recipe that has been previously applied to the model

:param model: PyTorch model to apply structure to
:param recipe_path: path to recipe to apply to the model
"""
@@ -49,43 +49,22 @@ def initialize_recipe(model: Module, recipe_path: str):
logger.info(f"Applied {msg} to the model")


def save_model_and_recipe(
model: Module,
def save_checkpoint(
save_path: str,
processor: Optional[Processor] = None,
save_safetensors: bool = False,
save_compressed: bool = False,
model: PreTrainedModel,
processor: Processor,
save_safetensors: bool = True,
save_compressed: bool = True,
):
"""
Save a model, processor and the currently loaded recipe to file

:param model: pytorch model to save
:param save_path: path to save output to
:param processor: model processor or tokenizer to save
:param save_safetensors: whether to save as safetensors or pickle (bin)
:param save_compressed: whether to compress sparse weights on disk
"""
# avoid circular import
from llmcompressor.transformers.utils.helpers import RECIPE_FILE_NAME

# saving the model also saves the recipe
model.save_pretrained(
save_path, save_compressed=save_compressed, safe_serialization=save_safetensors
save_path,
save_safetensors=save_safetensors,
save_compressed=save_compressed,
)

if processor is not None:
processor.save_pretrained(save_path)

logger.info("Saving output to {}".format(os.path.abspath(save_path)))

recipe_path = os.path.join(save_path, RECIPE_FILE_NAME)
session = active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)

# copy python files from cache dir to save_path if any
copy_python_files_from_model_cache(model, save_path)


def fallback_to_cpu(device: str) -> str:
"""
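For orientation, here is a minimal usage sketch of the consolidated helper, based only on the signature shown in this file's diff; the `model`, `processor`, and output path are placeholders rather than values from this PR.

```python
# Hypothetical usage of the consolidated helper (placeholders: `model` is an
# llm-compressor-wrapped PreTrainedModel, `processor` its tokenizer/processor).
from llmcompressor.pytorch.model_load.helpers import save_checkpoint

save_checkpoint(
    save_path="./stage_checkpoint",  # directory for model, processor, and recipe
    model=model,
    processor=processor,
    save_safetensors=True,   # safetensors rather than pickled .bin shards
    save_compressed=True,    # compress sparse/quantized weights on disk
)
# Per the lines removed above, model.save_pretrained() on the wrapped model now
# also serializes the recipe, so the explicit recipe-writing block is gone.
```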
10 changes: 7 additions & 3 deletions src/llmcompressor/transformers/finetune/runner.py
@@ -17,6 +17,7 @@
from llmcompressor.pytorch.model_load.helpers import (
get_completed_stages,
get_session_model,
save_checkpoint,
save_completed_stages,
)
from llmcompressor.pytorch.utils import tensors_to_device
@@ -27,7 +28,7 @@
make_dataset_splits,
)
from llmcompressor.typing import Processor
from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe
from llmcompressor.utils.fsdp.helpers import is_fsdp_model


class StageRunner:
@@ -261,17 +262,20 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
self.train(checkpoint=checkpoint, stage=stage_name)
checkpoint = None

# save model between stages
if (
self._training_args.output_dir
!= TrainingArguments.__dataclass_fields__["output_dir"].default
and self.trainer.accelerator.is_main_process
):
save_model_and_recipe(
model=self.trainer.model,
save_checkpoint(
save_path=self._output_dir,
model=self.trainer.model,
processor=self.processor,
save_safetensors=self._training_args.save_safetensors,
save_compressed=self._model_args.save_compressed,
)
self.trainer.accelerator.wait_for_everyone()

# save stage to checkpoint dir
if self.trainer.accelerator.is_main_process:
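Restated as straight-line code, the between-stage save above fires only when the user set an explicit output directory and only on the main process, while the barrier is reached by every rank. A simplified sketch of that guard, using the same names as the diff (nothing new is introduced):

```python
# Simplified restatement of the between-stage save guard in run_sequential_stages.
default_output_dir = TrainingArguments.__dataclass_fields__["output_dir"].default

if (
    self._training_args.output_dir != default_output_dir  # user chose an output dir
    and self.trainer.accelerator.is_main_process          # only rank 0 writes files
):
    save_checkpoint(
        save_path=self._output_dir,
        model=self.trainer.model,
        processor=self.processor,
        save_safetensors=self._training_args.save_safetensors,
        save_compressed=self._model_args.save_compressed,
    )
self.trainer.accelerator.wait_for_everyone()  # every rank hits the barrier
```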
51 changes: 11 additions & 40 deletions src/llmcompressor/transformers/finetune/session_mixin.py
@@ -24,15 +24,13 @@
from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import (
KDModelWrapper,
)
from llmcompressor.pytorch.model_load.helpers import get_session_model
from llmcompressor.pytorch.model_load.helpers import get_session_model, save_checkpoint
from llmcompressor.pytorch.utils import ModuleSparsificationInfo
from llmcompressor.transformers import RECIPE_FILE_NAME
from llmcompressor.transformers.finetune.callbacks import (
DisableHalfPrecisionCallback,
TrainingLoopCallbacks,
)
from llmcompressor.utils.fsdp.context import summon_full_params_context
from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_pretrained_fsdp
from llmcompressor.utils.pytorch import qat_active

if TYPE_CHECKING:
@@ -66,8 +64,8 @@ class SessionManagerMixIn:
def __init__(
self,
recipe: str,
data_args: "DatasetArguments",
model_args: "ModelArguments",
data_args: Optional["DatasetArguments"] = None,
teacher: Optional[Union[Module, str]] = None,
recipe_args: Optional[Union[Dict[str, Any], str]] = None,
**kwargs,
@@ -185,7 +183,6 @@ def initialize_structure(self, stage: Optional[str] = None):
"""
Initialize any recipe structural changes such as quantization on the model,
return immediately if session has already been initialized

:param stage: Optional stage of recipe to run, or None to run all stages
"""
session = active_session()
@@ -415,7 +412,6 @@ def evaluate(self, *args, **kwargs):
Run a sparsification evaluation cycle.
Runs initialize_structure for the sparse session before calling
super().evaluate() and finalization of the session after.

:param args: positional args to pass to super().evaluate()
:param kwargs: keyword args to pass to super().evaluate()
:return: the output from super.evaluate()
@@ -432,12 +428,12 @@ def predict(self, *args, **kwargs):
Run a sparsification prediction cycle.
Runs initialize_structure for the sparse session before calling
super().predict() and finalization of the session after.

:param args: positional args to pass to super().predict()
:param kwargs: keyword args to pass to super().predict()
:return: the output from super.predict()
"""
self.initialize_structure()

output = super().predict(*args, **kwargs)
self.finalize_session()

@@ -483,44 +479,19 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):

# knowledge distillation requires making wrappers transparent during saving
if isinstance(self.model, KDModelWrapper):
self.model.prepare_for_save()
self.model.prepare_for_save() # TODO: move to finalize

if not is_fsdp_model(self.model):
self.model.save_pretrained(
# save checkpoint
self.save_state()
if self.accelerator.is_main_process:
processor = getattr(self, "processing_class", self.tokenizer)
save_checkpoint(
output_dir,
save_compressed=self.model_args.save_compressed,
safe_serialization=self.args.save_safetensors,
)
else: # FSDP model
save_pretrained_fsdp(
model=self.model,
accelerator=self.accelerator,
output_dir=output_dir,
processor=processor,
save_safetensors=self.args.save_safetensors,
save_compressed=self.model_args.save_compressed,
save_safetensors=self.metadata.get("save_safetensors", False),
)

self.save_state()
processor = getattr(self, "processing_class", self.tokenizer)
if processor is not None:
processor.save_pretrained(output_dir)

if not self.recipe:
return

if self.accelerator.is_main_process:
# save recipe, will contain modifiers from the model's original recipe as
# well as those added from self.recipe
recipe_path = os.path.join(output_dir, RECIPE_FILE_NAME)
session = active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)

logger.info(
f"Saved LLM Compressor recipe with model state to {recipe_path}"
)

self.accelerator.wait_for_everyone()

if isinstance(self.model, KDModelWrapper):
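Taken together, the rewritten `save_model` boils down to: save trainer state on every rank, let only the main process write the consolidated checkpoint, then synchronize. A hedged sketch of that flow (KD-wrapper toggling and the early-return paths above the hunk are omitted; the `processing_class` lookup falls back to `tokenizer`, presumably to cover transformers versions that only expose the latter):

```python
# Sketch of the consolidated SessionManagerMixIn.save_model flow after this PR
# (simplified from the diff above; not the verbatim implementation).
def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
    self.save_state()  # trainer state (optimizer/scheduler/RNG) on every rank

    if self.accelerator.is_main_process:
        # newer transformers exposes processing_class; fall back to tokenizer
        processor = getattr(self, "processing_class", self.tokenizer)
        save_checkpoint(
            output_dir,
            model=self.model,
            processor=processor,
            save_safetensors=self.args.save_safetensors,
            save_compressed=self.model_args.save_compressed,
        )

    self.accelerator.wait_for_everyone()  # keep non-main ranks in lockstep
```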
20 changes: 13 additions & 7 deletions src/llmcompressor/transformers/finetune/text_generation.py
@@ -45,12 +45,12 @@
get_session_model,
initialize_recipe,
parse_dtype,
save_checkpoint,
)
from llmcompressor.recipe import Recipe, StageRunType
from llmcompressor.transformers.finetune.runner import StageRunner
from llmcompressor.transformers.finetune.trainer import Trainer
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_fsdp_model_save_pretrained,
modify_save_pretrained,
patch_tied_tensors_bug,
)
@@ -418,7 +418,10 @@ def main(

# wrap model.save_pretrained
if is_fsdp_model(model):
modify_fsdp_model_save_pretrained(trainer, processor)
raise NotImplementedError(
[Collaborator review comment] We should put this in after the oneshot / stage runner refactor. Currently this is an ok pipeline.

"FSDP models are not supported in the current release but will be "
"suported in future releases of LLM Compressor"
)
else:
modify_save_pretrained(model)

@@ -455,16 +458,19 @@
stage_runner.predict()

# save if model was provided as a string or custom output_dir was set

if isinstance(model_args.model, str) or (
training_args.output_dir
!= TrainingArguments.__dataclass_fields__["output_dir"].default
and trainer.accelerator.is_main_process
):
model.save_pretrained(
training_args.output_dir, save_compressed=model_args.save_compressed
save_checkpoint(
save_path=training_args.output_dir,
model=model,
processor=processor,
save_safetensors=True,
save_compressed=model_args.save_compressed,
)
if processor is not None:
processor.save_pretrained(training_args.output_dir)
trainer.accelerator.wait_for_everyone()

# Clean up the CompressionSession before exit if requested
if recipe_args.clear_sparse_session:
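Because the parenthesization of the final-save condition is easy to misread, here it is unpacked with explicit names. This mirrors the logic in main() above and introduces no new behavior; the helper variables are illustrative only.

```python
# The final save runs if the model was passed as a string (path or hub stub), OR
# if a non-default output_dir was set and this is the main process.
default_output_dir = TrainingArguments.__dataclass_fields__["output_dir"].default

model_given_as_string = isinstance(model_args.model, str)
custom_output_dir_on_main = (
    training_args.output_dir != default_output_dir
    and trainer.accelerator.is_main_process
)

if model_given_as_string or custom_output_dir_on_main:
    save_checkpoint(
        save_path=training_args.output_dir,
        model=model,
        processor=processor,
        save_safetensors=True,                      # final save always uses safetensors
        save_compressed=model_args.save_compressed,
    )
trainer.accelerator.wait_for_everyone()
```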