Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Callbacks] Remove pre_initialize_structure #1160

Open
wants to merge 16 commits into
base: kylesayrs/consolidate-saving
Choose a base branch
from
1 change: 0 additions & 1 deletion src/llmcompressor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,5 @@
create_session,
finalize,
initialize,
pre_initialize_structure,
reset_session,
)
2 changes: 0 additions & 2 deletions src/llmcompressor/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
create_session,
finalize,
initialize,
pre_initialize_structure,
reset_session,
)
from llmcompressor.core.state import Data, Hardware, ModifiedState, State
Expand All @@ -37,7 +36,6 @@
"create_session",
"active_session",
"reset_session",
"pre_initialize_structure",
"initialize",
"finalize",
"apply",
Expand Down
31 changes: 0 additions & 31 deletions src/llmcompressor/core/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class EventType(Enum):
The purpose of each EventType is to trigger the corresponding
modifier callback during training or post training pipelines.

:param PRE_INIT: Event type for pre-initialization.
:param INITIALIZE: Event type for initialization.
:param FINALIZE: Event type for finalization.
:param BATCH_START: Event type for the start of a batch.
Expand All @@ -38,7 +37,6 @@ class EventType(Enum):
"""

# training lifecycle
PRE_INIT = "pre_init"
INITIALIZE = "initialize"
FINALIZE = "finalize"

Expand All @@ -51,35 +49,6 @@ class EventType(Enum):
OPTIM_PRE_STEP = "optim_pre_step"
OPTIM_POST_STEP = "optim_post_step"

def order(self) -> int:
"""
Returns the priority order of the current EventType.
Lower values have higher priority.

:raises ValueError: if the event type is invalid.
:return: The order of the event type, lower has higher priority.
:rtype: int
"""
if self == EventType.PRE_INIT:
return 0
elif self == EventType.INITIALIZE:
return 10
elif self == EventType.FINALIZE:
return 20
elif self == EventType.BATCH_START:
return 100
elif self == EventType.LOSS_CALCULATED:
return 110
elif self == EventType.OPTIM_PRE_STEP:
return 120
elif self == EventType.OPTIM_POST_STEP:
return 130
elif self == EventType.BATCH_END:
return 140
else:
logger.error("Invalid event type: {}", self)
raise ValueError(f"Invalid event type {self}")


@dataclass
class Event:
Expand Down
99 changes: 24 additions & 75 deletions src/llmcompressor/core/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
)
from llmcompressor.core.state import State
from llmcompressor.modifiers import StageModifiers
from llmcompressor.recipe import RecipeContainer
from llmcompressor.recipe import (
RecipeArgsInput,
RecipeContainer,
RecipeInput,
RecipeStageInput,
)

__all__ = ["CompressionLifecycle"]

Expand All @@ -38,12 +43,11 @@ class CompressionLifecycle:
:type event_lifecycle: Optional[EventLifecycle]
"""

state: Optional[State] = None
state: State = field(default_factory=State)
recipe_container: RecipeContainer = field(default_factory=RecipeContainer)
modifiers: List[StageModifiers] = field(default_factory=list)
event_lifecycle: Optional[EventLifecycle] = None

initialized_structure: bool = False
initialized_: bool = False
finalized: bool = False
event_called: bool = False
Expand All @@ -64,66 +68,35 @@ def reset(self):
except Exception as e:
logger.warning(f"Exception during finalizing modifier: {e}")

self.state = None
self.recipe_container = RecipeContainer()
self.modifiers = []
self.event_lifecycle = None

self.initialized_structure = False
self.initialized_ = False
self.finalized = False
self.event_called = False
self.__init__()
logger.info("Compression lifecycle reset")

def pre_initialize_structure(self, **kwargs) -> List[Any]:
"""
Pre-initialize the structure of the compression lifecycle.

:param kwargs: Additional arguments to update the state with
:return: List of data returned from pre-initialization of modifiers
:rtype: List[Any]
"""
logger.debug("Pre-initializing structure")
self._check_create_state()
extras = self.state.update(**kwargs)
extras = self.recipe_container.update(**extras)

self._check_compile_recipe()
mod_data = []
for mod in self.modifiers:
data = mod.pre_initialize_structure(state=self.state, **extras)
logger.debug("Pre-initialized modifier: {}", mod)
if data is not None:
mod_data.append(data)

self.initialized_structure = True
applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
self.recipe_container.update_applied_stages(applied_stage_names)
logger.info(
"Compression lifecycle structure pre-initialized for {} modifiers",
len(self.modifiers),
)

return mod_data

def initialize(self, **kwargs) -> List[Any]:
def initialize(
self,
recipe: Optional[RecipeInput] = None,
recipe_stage: Optional[RecipeStageInput] = None,
recipe_args: Optional[RecipeArgsInput] = None,
**kwargs,
) -> List[Any]:
"""
Initialize the compression lifecycle.

:param kwargs: Additional arguments to update the state with
:return: List of data returned from initialization of modifiers
:rtype: List[Any]
"""
logger.debug("Initializing compression lifecycle")
self._check_create_state()
extras = self.state.update(**kwargs)
extras = self.recipe_container.update(**extras)
self.state.update(**kwargs)
if self.initialized_: # TODO: do not initialize twice
return

self._check_compile_recipe()
logger.debug("Initializing compression lifecycle")
self.recipe_container.append(recipe, recipe_stage, recipe_args)
self.modifiers = self.recipe_container.get_modifiers()
self._set_model_layer_prefix()

mod_data = []
for mod in self.modifiers:
data = mod.initialize(state=self.state, **extras)
data = mod.initialize(state=self.state, **kwargs)
logger.debug("Initialized modifier: {}", mod)
if data is not None:
mod_data.append(data)
Expand Down Expand Up @@ -190,7 +163,7 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
logger.error("Cannot invoke event after finalizing")
raise ValueError("Cannot invoke event after finalizing")

if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]:
if event_type in [EventType.INITIALIZE, EventType.FINALIZE]:
logger.error(
"Cannot invoke {} event. Use the corresponding method instead.",
event_type,
Expand Down Expand Up @@ -229,30 +202,6 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:

return mod_data

def _check_create_state(self):
if self.state is not None:
return

logger.debug("Creating new State instance for compression lifecycle")
self.state = State()
logger.info("State created for compression lifecycle")

def _check_compile_recipe(self):
if not self.recipe_container.check_compile_recipe():
return

logger.debug(
"Compiling recipe and creating modifiers for compression lifecycle"
)
self.modifiers = self.recipe_container.compiled_recipe.create_modifier()
for mod in self.modifiers:
if mod.unique_id in self.recipe_container.applied_stages:
mod.applied = True
logger.info(
"Recipe compiled and {} modifiers created",
len(self.modifiers),
)

def _check_setup_event_lifecycle(self, event_type: EventType):
if self.event_lifecycle is not None:
return
Expand Down
39 changes: 0 additions & 39 deletions src/llmcompressor/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,45 +65,6 @@ def state(self) -> State:
"""
return self._lifecycle.state

def pre_initialize_structure(
self,
model: Any,
recipe: Union[str, List[str], Recipe, List[Recipe], None] = None,
recipe_stage: Union[str, List[str], None] = None,
recipe_args: Union[Dict[str, Any], List[Dict[str, Any]], None] = None,
**kwargs,
) -> ModifiedState:
"""
A method to pre-initialize the structure of the model for compression.
This will run the pre-initialize structure method for each modifier in the
session's lifecycle. This will also set the session's state to the
pre-initialized state. Takes care of cases when the model(s) structure
has been previously modified by a modifier.

:param model: the model to pre-initialize the structure for
:param recipe: the recipe to use for the compression, can be a path to a
recipe file, a raw recipe string, a recipe object, or a list
of recipe objects.
:param recipe_stage: the stage to use for the compression
:param recipe_args: the args to use for overriding the recipe defaults
:return: A ModifiedState instance holding the modified model and modifier_data
after pre-initializing the structure
"""
mod_data = self._lifecycle.pre_initialize_structure(
model=model,
recipe=recipe,
recipe_stage=recipe_stage,
recipe_args=recipe_args,
**kwargs,
)

return ModifiedState(
model=self.state.model,
optimizer=None,
loss=None,
modifier_data=mod_data,
)

def initialize(
self,
recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
Expand Down
13 changes: 1 addition & 12 deletions src/llmcompressor/core/session_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
"create_session",
"active_session",
"reset_session",
"pre_initialize_structure",
"initialize",
"finalize",
"apply",
Expand Down Expand Up @@ -60,16 +59,6 @@ def reset_session():
session._lifecycle.reset()


def pre_initialize_structure(**kwargs):
"""
A method to pre-initialize the structure of the model for the active session

:param kwargs: the kwargs to pass to the active session's pre-initialize-structure
method
"""
active_session().pre_initialize_structure(**kwargs)


def initialize(
recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
recipe_stage: Union[str, List[str], None] = None,
Expand Down Expand Up @@ -213,7 +202,7 @@ def event(cls, event_type: EventType, **kwargs) -> ModifiedState:
:param kwargs: additional kwargs to pass to the current session's event method
:return: the modified state of the active session after invoking the event
"""
if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]:
if event_type in [EventType.INITIALIZE, EventType.FINALIZE]:
raise ValueError(
f"Cannot invoke {event_type} event. "
f"Use the corresponding method instead."
Expand Down
19 changes: 0 additions & 19 deletions src/llmcompressor/modifiers/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,6 @@ class ModifierInterface(ABC):
Defines the contract that all modifiers must implement
"""

@property
@abstractmethod
def initialized_structure(self) -> bool:
"""
:return: True if the modifier structure has been
applied to the model
"""
raise NotImplementedError()

@property
@abstractmethod
def initialized(self) -> bool:
Expand Down Expand Up @@ -58,16 +49,6 @@ def calculate_end(self) -> float:
"""
raise NotImplementedError()

@abstractmethod
def pre_initialize_structure(self, state: State, **kwargs):
"""
Apply the modifier structure to the model

:param state: The current state of the model
:param kwargs: Additional arguments for the modifier
"""
raise NotImplementedError()

@abstractmethod
def initialize(self, state: State, **kwargs):
"""
Expand Down
31 changes: 0 additions & 31 deletions src/llmcompressor/modifiers/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,11 @@ class Modifier(ModifierInterface, HooksMixin):
end: Optional[float] = None
update: Optional[float] = None

initialized_structure_: bool = False
initialized_: bool = False
finalized_: bool = False
started_: bool = False
ended_: bool = False

@property
def initialized_structure(self) -> bool:
"""
:return: True if the modifier structure has been
applied to the model
"""
return self.initialized_structure_

@property
def initialized(self) -> bool:
"""
Expand Down Expand Up @@ -78,15 +69,6 @@ def calculate_end(self) -> float:
"""
return self.end if self.end is not None else -1

def pre_initialize_structure(self, state: State, **kwargs):
"""
:param state: The current state of the model
:param kwargs: Additional arguments for initializing the structure
of the model in question
"""
self.on_initialize_structure(state, **kwargs)
self.initialized_structure_ = True

def initialize(self, state: State, **kwargs):
"""
Initialize the modifier for the given model and state.
Expand Down Expand Up @@ -221,19 +203,6 @@ def should_end(self, event: Event):

return self.end is not None and current >= self.end

def on_initialize_structure(self, state: State, **kwargs):
"""
on_initialize_structure is called before the model is initialized
with the modifier structure.

TODO: Depreciate this function as part of the lifecycle

:param state: The current state of the model
:param kwargs: Additional arguments for initializing the structure
of the model in question
"""
pass

@abstractmethod
def on_initialize(self, state: State, **kwargs) -> bool:
"""
Expand Down
Loading