Commit 88aeab8

switch to using HooksMixin api

Signed-off-by: Brian Dellabetta <[email protected]>
brian-dellabetta committed Feb 18, 2025
1 parent 2611966 commit 88aeab8
Showing 3 changed files with 27 additions and 23 deletions.
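The gist of the change, shown in the AWQ diff below, is that the modifier no longer keeps its own hooks_ list of handles from register_forward_hook and tears them down by hand; it now calls self.register_hook(...) and self.remove_hooks() provided by the HooksMixin API. As rough orientation, the following is a minimal, self-contained sketch of that bookkeeping pattern; it is illustrative only (the HooksMixinSketch name and exact signatures are assumptions, not the llmcompressor implementation):

# Minimal sketch of centralized hook bookkeeping, loosely mirroring the
# register_hook / remove_hooks calls that appear in the diff. Illustrative only.
from typing import Callable, List

import torch
from torch.utils.hooks import RemovableHandle


class HooksMixinSketch:
    def __init__(self) -> None:
        self._hook_handles: List[RemovableHandle] = []

    def register_hook(
        self, module: torch.nn.Module, func: Callable, hook_type: str
    ) -> RemovableHandle:
        # Only forward hooks are needed for AWQ's activation capture.
        if hook_type != "forward":
            raise ValueError(f"unsupported hook type: {hook_type}")
        handle = module.register_forward_hook(func)
        self._hook_handles.append(handle)
        return handle

    def remove_hooks(self) -> None:
        # One call replaces the manual `for hook in self.hooks_: hook.remove()` loop.
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()


if __name__ == "__main__":
    layer = torch.nn.Linear(4, 4)
    hooks = HooksMixinSketch()
    hooks.register_hook(layer, lambda mod, inp, out: print("captured", out.shape), "forward")
    layer(torch.randn(2, 4))  # hook fires
    hooks.remove_hooks()
    layer(torch.randn(2, 4))  # hook no longer fires

With the handle bookkeeping owned by the mixin, the per-modifier hooks_ attribute and its cleanup code can simply be deleted, which is what the diff below does.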
44 changes: 25 additions & 19 deletions src/llmcompressor/modifiers/awq/base.py
@@ -2,18 +2,20 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
+from compressed_tensors.utils.offload import is_module_offloaded
 from loguru import logger
 from torch.nn import Module
 from tqdm import tqdm

-from llmcompressor.core import Event, State
+from llmcompressor.core import State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
 from llmcompressor.pytorch.utils import (
     pseudo_quantize_tensor,
     tensor_forward_with_input_args,
 )
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
+from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import (
     get_layer,
     get_layers,
@@ -124,14 +126,10 @@ class AWQModifier(Modifier):
     duo_scaling: bool = True
     apply_clip: bool = True

-    hooks_: Optional[List] = None
-    resolved_mappings_: Optional[List] = None
+    resolved_mappings_: Optional[List[AWQMapping]] = None
     scales_: Optional[Dict] = None
     module_kwargs_: Optional[Dict] = None

-    def on_initialize_structure(self, state: State, **kwargs):
-        pass # nothing needed for this modifier
-
     def on_initialize(self, state: State, **kwargs) -> bool:
         """
         Initialize and run AWQ on the given state
@@ -155,7 +153,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         self.scales_ = {}

         calibration_dataloader = state.data.calib
-        self.hooks_ = []

         self._get_module_kwargs(state.model, calibration_dataloader)
         self._setup_scale_hooks()
@@ -179,7 +176,7 @@ def on_finalize(self, state: State, **kwargs) -> bool:

         return True

-    def _resolve_mappings(self, model: Module) -> List:
+    def _resolve_mappings(self, model: Module) -> List[AWQMapping]:
         """
         Transforms the list of activations to smooth and their corresponding weights
         into AWQMapping objects, resolving regular expressions.
@@ -252,7 +249,7 @@ def hook_fn(module, inp, out):
             # is enough, as other balance layers
             # get the same input
             layer = mapping.balance_layers[0]
-            self.hooks_.append(layer.register_forward_hook(create_hook_fn(name)))
+            self.register_hook(layer, create_hook_fn(name), "forward")

     @torch.no_grad()
     def _calibrate(self, model: Module, calibration_dataloader: List):
@@ -271,17 +268,16 @@ def _calibrate(self, model: Module, calibration_dataloader: List):
                 " CompressionSession to run the AWQ modifier"
             )

-        run_calibration_forward(
-            model,
-            calibration_dataloader,
-            self.num_calibration_steps,
-            self.calibration_function,
-        )
+        with calibration_forward_context(model):
+            run_calibration_forward(
+                model,
+                calibration_dataloader,
+                self.num_calibration_steps,
+                self.calibration_function,
+            )

         # remove the hooks now that we are done calibrating
-        for hook in self.hooks_:
-            hook.remove()
-        del self.hooks_
+        self.remove_hooks()

     def _concat_collected_activations(self):
         """
@@ -370,6 +366,13 @@ def _apply_smoothing(self, model: Module):

             @torch.no_grad()
             def smooth(module):
+                # TODO calls to module._hf_hook.pre_forward(module) and
+                # module._hf_hook.post_forward(module, None) appear a couple places
+                # in SmoothQuantModifier, do we need them anywhere else?
+                offloaded = is_module_offloaded(module)
+                if offloaded:
+                    module._hf_hook.pre_forward(module)
+
                 if module in balance_layers:
                     module.weight.mul_(scales.view(1, -1).to(module.weight.device))
                 elif module == smooth_layer:
@@ -380,6 +383,9 @@ def smooth(module):
                     if hasattr(module, "bias") and module.bias is not None:
                         module.bias.div_(scales.to(module.bias.device))

+                if offloaded:
+                    module._hf_hook.post_forward(module, None)
+
             parent = get_fsdp_parent(mapping.smooth_name, model)
             if parent is not None:
                 parent.apply(smooth)
@@ -681,7 +687,7 @@ def _compute_best_clip(

         best_max_val = torch.cat(best_max_val_all, dim=0)

-        #TODO this appears unneeded, clear_memory removed
+        # TODO this appears unneeded, clear_memory removed
         # clear_memory(input_feat)
         # clear_memory(org_out)

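For context on the offload handling added to smooth() above: the pattern is the accelerate-style onload/offload dance, where an offloaded module's _hf_hook moves its parameters onto the execution device before an in-place update and back afterwards. A hedged sketch of that pattern follows (the scale_weight_inplace helper is hypothetical, not llmcompressor or compressed-tensors API):

# Hypothetical helper illustrating the onload -> mutate -> offload pattern.
import torch
from torch.nn import Module


def scale_weight_inplace(module: Module, scales: torch.Tensor) -> None:
    # If an accelerate hook has offloaded this module's parameters, bring them
    # onto the execution device before mutating them in place.
    hook = getattr(module, "_hf_hook", None)
    offloaded = hook is not None and getattr(hook, "offload", False)

    if offloaded:
        hook.pre_forward(module)

    module.weight.mul_(scales.view(1, -1).to(module.weight.device))

    if offloaded:
        # Return the (now updated) parameters to the offload store.
        hook.post_forward(module, None)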
4 changes: 2 additions & 2 deletions src/llmcompressor/modifiers/smoothquant/base.py
@@ -105,7 +105,7 @@ class SmoothQuantModifier(Modifier):
     num_calibration_steps: Optional[int] = None
     calibration_function: Optional[Callable] = None

-    resolved_mappings_: Optional[List] = None
+    resolved_mappings_: Optional[List[SmoothQuantMapping]] = None
     scales_: Optional[Dict] = None

     def on_initialize(self, state: State, **kwargs) -> bool:
@@ -166,7 +166,7 @@ def _infer_mappings_from_model(
         )

     @handle_mapping_resolution_errors
-    def _resolve_mappings(self, model: Module) -> List:
+    def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         """
         Transforms the list of activations to smooth and their corresponding weights
         into SmoothQuantMapping objects, resolving regular expressions.
2 changes: 0 additions & 2 deletions src/llmcompressor/pytorch/utils/helpers.py
@@ -3,7 +3,6 @@
 """

 import functools
-import gc
 import inspect
 import os
 import random
@@ -1297,4 +1296,3 @@ def pseudo_dequantize_linear(
     w = w.weight.data * scales

     return w
-