updated to be compatible with latest, unit tests passing
Signed-off-by: Brian Dellabetta <[email protected]>
brian-dellabetta committed Feb 18, 2025
1 parent 98a5b73 commit 2611966
Showing 4 changed files with 23 additions and 56 deletions.
26 changes: 7 additions & 19 deletions src/llmcompressor/modifiers/awq/base.py
@@ -10,7 +10,6 @@
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.pytorch.utils import (
-    clear_memory,
    pseudo_quantize_tensor,
    tensor_forward_with_input_args,
)
@@ -83,12 +82,12 @@ class AWQModifier(Modifier):
example recipe:
```yaml
AWQModifier:
-      bits: 4
-      mappings: [
+  bits: 4
+  mappings: [
    [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"],
    [["re:.*fc1"], "re:.*final_layer_norm"]
-      ]
-      ignore: ["model.decoder.final_layer_norm"]
+  ]
+  ignore: ["model.decoder.final_layer_norm"]
```
:param mappings: list activation layers to smooth, and which layers to
@@ -166,18 +165,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:

return True

-    def on_start(self, state: State, event: Event, **kwargs):
-        pass
-
-    def on_update(self, state: State, event: Event, **kwargs):
-        pass
-
-    def on_end(self, state: State, event: Event, **kwargs):
-        pass
-
-    def on_event(self, state: State, event: Event, **kwargs):
-        pass
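These four stubs can be dropped because, presumably, the updated `Modifier` base class ships no-op defaults for its event hooks, so subclasses override only what they need. That is an inference from this commit, not stated in it; a minimal sketch of the assumed pattern (`AWQLikeModifier` is a hypothetical name):

```python
from abc import ABC


class Modifier(ABC):
    """Assumed shape of the updated base class: event hooks are no-ops by default."""

    def on_start(self, state, event, **kwargs):
        pass

    def on_event(self, state, event, **kwargs):
        pass


class AWQLikeModifier(Modifier):
    """Subclasses override only the hooks they actually need."""

    def on_initialize(self, state, **kwargs) -> bool:
        # real setup work (resolving mappings, registering hooks) goes here
        return True
```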

def on_finalize(self, state: State, **kwargs) -> bool:
"""
Clean up by clearing the scale and mapping data
@@ -694,8 +681,9 @@ def _compute_best_clip(

        best_max_val = torch.cat(best_max_val_all, dim=0)

-        clear_memory(input_feat)
-        clear_memory(org_out)
+        # TODO this appears unneeded, clear_memory removed
+        # clear_memory(input_feat)
+        # clear_memory(org_out)

        return best_max_val.squeeze(1)

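For reference, a minimal sketch of applying a recipe like the docstring example above. The entrypoint and kwargs are assumptions based on llm-compressor's `oneshot` API (at the time of this commit the import may have been `from llmcompressor.transformers import oneshot`), and `facebook/opt-125m` is simply a small model whose module names match the mappings:

```python
from llmcompressor import oneshot

recipe = """
AWQModifier:
  bits: 4
  mappings: [
    [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"],
    [["re:.*fc1"], "re:.*final_layer_norm"]
  ]
  ignore: ["model.decoder.final_layer_norm"]
"""

# one-shot (no finetuning) calibration run over the Pile eval split
oneshot(
    model="facebook/opt-125m",
    dataset="pile_eval",
    recipe=recipe,
    max_seq_length=512,
    num_calibration_samples=256,
)
```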
7 changes: 0 additions & 7 deletions src/llmcompressor/pytorch/utils/helpers.py
@@ -92,7 +92,6 @@
"pseudo_dequantize_linear",
"tensor_forward_with_input_args",
"sanitize_kwargs_for_module",
"clear_memory",
]


@@ -1299,9 +1298,3 @@ def pseudo_dequantize_linear(

return w


-def clear_memory(value: Optional[Any] = None):
-    if value is not None:
-        del value
-    gc.collect()
-    torch.cuda.empty_cache()
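A side note on why the helper was safe to drop (my reading of the TODO above, not stated in the commit): `del value` inside `clear_memory` only unbinds the function-local name, so it never released the caller's tensor; any real effect came from `gc.collect()` and `torch.cuda.empty_cache()`, which call sites can invoke directly. A minimal sketch of the inline equivalent:

```python
import gc

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
input_feat = torch.randn(16, 16, device=device)
org_out = input_feat @ input_feat

# the caller drops its own references -- the one thing clear_memory(value)
# could never do from inside the function body
del input_feat, org_out
gc.collect()  # reclaim unreachable Python-side objects
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # return cached CUDA blocks to the allocator
```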
42 changes: 14 additions & 28 deletions src/llmcompressor/transformers/finetune/data/pile.py
@@ -1,7 +1,11 @@
from copy import deepcopy
-from typing import Optional
+from typing import TYPE_CHECKING

from llmcompressor.transformers.finetune.data import TextGenerationDataset
+from llmcompressor.typing import Processor

+if TYPE_CHECKING:
+    from llmcompressor.args import DatasetArguments


@TextGenerationDataset.register(name="pile_eval")
@@ -13,33 +17,15 @@ class PileEvalDataset(TextGenerationDataset):
:param tokenizer: tokenizer to use on dataset
"""

-    def __init__(self, data_args, split, tokenizer):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
        data_args = deepcopy(data_args)
        data_args.text_column = "text"
        data_args.dataset = "mit-han-lab/pile-val-backup"
-        super().__init__(
-            text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
-        )

-    def get_raw_dataset(self, cache_dir: Optional[str] = None):
-        """
-        Load the raw dataset from Hugging Face, using cached copy if available.
-        Additionally reformats the entries to fit the template.
-        :param cache_dir: disk location to search for cached dataset
-        :return: the requested dataset
-        """
-        raw_dataset = super().get_raw_dataset(cache_dir=cache_dir)
-
-        def restructure_fn(sample):
-            sample["text"] = sample["text"].strip()
-            return sample
+        super().__init__(data_args=data_args, split=split, processor=processor)

-        raw_dataset = self.map(
-            raw_dataset,
-            function=restructure_fn,
-            batched=False,
-            remove_columns=["meta"],
-            num_proc=self.data_args.preprocessing_num_workers,
-            load_from_cache_file=not self.data_args.overwrite_cache,
-            desc="Restructuring Pile Dataset",
-        )
-        return raw_dataset
+
+    def dataset_template(self, sample):
+        return {
+            "text": self.processor.apply_chat_template(
+                sample["text"].strip(),
+            ),
+        }
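For context, a minimal sketch of how the new `dataset_template` is presumably consumed: the updated `TextGenerationDataset` base class is assumed to map it over raw rows, replacing the hand-rolled `get_raw_dataset` override removed above. The helper below is illustrative, not the actual base-class implementation; its `map` kwargs mirror the removed code:

```python
from datasets import Dataset


def restructure(raw_dataset: Dataset, ds: "PileEvalDataset") -> Dataset:
    # illustrative stand-in for what the base class is assumed to do:
    # apply the registered template row by row and drop the unused column
    return raw_dataset.map(
        ds.dataset_template,
        remove_columns=["meta"],
        desc="Restructuring Pile Dataset",
    )
```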
4 changes: 2 additions & 2 deletions
@@ -67,10 +67,10 @@ def test_pile_eval_initializes(tiny_llama_tokenizer):
data_args.dataset,
data_args=data_args,
split=None,
-        tokenizer=tiny_llama_tokenizer,
+        processor=tiny_llama_tokenizer,
)
assert isinstance(pile_eval_manager, TextGenerationDataset)
assert isinstance(pile_eval_manager, PileEvalDataset)
-    assert pile_eval_manager.text_column == "text"
+    assert pile_eval_manager.data_args.text_column == "text"
assert not pile_eval_manager.padding
assert pile_eval_manager.max_seq_length == data_args.max_seq_length
