diff --git a/src/llmcompressor/core/__init__.py b/src/llmcompressor/core/__init__.py index ed4134af7..77be21fa7 100644 --- a/src/llmcompressor/core/__init__.py +++ b/src/llmcompressor/core/__init__.py @@ -17,6 +17,7 @@ finalize, initialize, reset_session, + update_state, ) from llmcompressor.core.state import Data, Hardware, ModifiedState, State @@ -37,6 +38,7 @@ "active_session", "reset_session", "initialize", + "update_state", "finalize", "apply", "callbacks", diff --git a/src/llmcompressor/core/lifecycle.py b/src/llmcompressor/core/lifecycle.py index e7274b21a..4af9d3d16 100644 --- a/src/llmcompressor/core/lifecycle.py +++ b/src/llmcompressor/core/lifecycle.py @@ -108,6 +108,13 @@ def initialize( return mod_data + def update_state(self, **kwargs): + """ + TODO + """ + logger.info(f"Updated state with {kwargs}") + self.state.update(**kwargs) + def finalize(self, **kwargs) -> List[Any]: """ Finalize the compression lifecycle. diff --git a/src/llmcompressor/core/session.py b/src/llmcompressor/core/session.py index 07eb2dc57..ae23b5736 100644 --- a/src/llmcompressor/core/session.py +++ b/src/llmcompressor/core/session.py @@ -143,6 +143,21 @@ def initialize( modifier_data=mod_data, ) + def update_state(self, **kwargs) -> ModifiedState: + """ + TODO + """ + + self._lifecycle.update_state( + **kwargs, + ) + + return ModifiedState( + model=self.state.model, + optimizer=self.state.optimizer, + loss=self.state.loss, + ) + def finalize(self, **kwargs) -> ModifiedState: """ Finalize the session for compression. This will run the finalize method diff --git a/src/llmcompressor/core/session_functions.py b/src/llmcompressor/core/session_functions.py index 5f8fd1a0c..fedea7ed7 100644 --- a/src/llmcompressor/core/session_functions.py +++ b/src/llmcompressor/core/session_functions.py @@ -12,6 +12,7 @@ "active_session", "reset_session", "initialize", + "update_state", "finalize", "apply", "callbacks", @@ -122,6 +123,13 @@ def initialize( ) +def update_state(**kwargs) -> ModifiedState: + """ + TODO + """ + return active_session().update_state(**kwargs) + + def finalize(**kwargs) -> ModifiedState: """ Method to finalize the active session for sparsification diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index fa7e138ac..6ed2c81f2 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -18,6 +18,7 @@ create_session, finalize, initialize, + update_state, ) from llmcompressor.metrics import LoggerManager from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import ( @@ -224,7 +225,9 @@ def create_optimizer(self): len(self.train_dataset) / total_batch_size ) - initialize(optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch) + update_state( + optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch + ) return self.optimizer