Skip to content

Commit

Permalink
remove double init, replace with update_state
Browse files Browse the repository at this point in the history
Signed-off-by: Kyle Sayers <[email protected]>
  • Loading branch information
kylesayrs committed Feb 19, 2025
1 parent bf9a8cd commit 438eae5
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/llmcompressor/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
finalize,
initialize,
reset_session,
update_state,
)
from llmcompressor.core.state import Data, Hardware, ModifiedState, State

Expand All @@ -37,6 +38,7 @@
"active_session",
"reset_session",
"initialize",
"update_state",
"finalize",
"apply",
"callbacks",
Expand Down
7 changes: 7 additions & 0 deletions src/llmcompressor/core/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ def initialize(

return mod_data

def update_state(self, **kwargs):
"""
TODO
"""
logger.info(f"Updated state with {kwargs}")
self.state.update(**kwargs)

def finalize(self, **kwargs) -> List[Any]:
"""
Finalize the compression lifecycle.
Expand Down
15 changes: 15 additions & 0 deletions src/llmcompressor/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,21 @@ def initialize(
modifier_data=mod_data,
)

def update_state(self, **kwargs) -> ModifiedState:
"""
TODO
"""

self._lifecycle.update_state(
**kwargs,
)

return ModifiedState(
model=self.state.model,
optimizer=self.state.optimizer,
loss=self.state.loss,
)

def finalize(self, **kwargs) -> ModifiedState:
"""
Finalize the session for compression. This will run the finalize method
Expand Down
8 changes: 8 additions & 0 deletions src/llmcompressor/core/session_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"active_session",
"reset_session",
"initialize",
"update_state",
"finalize",
"apply",
"callbacks",
Expand Down Expand Up @@ -122,6 +123,13 @@ def initialize(
)


def update_state(**kwargs) -> ModifiedState:
"""
TODO
"""
return active_session().update_state(**kwargs)


def finalize(**kwargs) -> ModifiedState:
"""
Method to finalize the active session for sparsification
Expand Down
5 changes: 4 additions & 1 deletion src/llmcompressor/transformers/finetune/session_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
create_session,
finalize,
initialize,
update_state,
)
from llmcompressor.metrics import LoggerManager
from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import (
Expand Down Expand Up @@ -224,7 +225,9 @@ def create_optimizer(self):
len(self.train_dataset) / total_batch_size
)

initialize(optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch)
update_state(
optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch
)

return self.optimizer

Expand Down

0 comments on commit 438eae5

Please sign in to comment.