Commit 20c6155
squashed for easier rebase
Signed-off-by: Brian Dellabetta <[email protected]>
brian-dellabetta committed Feb 17, 2025
1 parent 9258eb3 commit 20c6155
Showing 17 changed files with 183 additions and 96 deletions.
6 changes: 4 additions & 2 deletions src/llmcompressor/transformers/tracing/__init__.py
@@ -8,13 +8,15 @@
     Qwen2VLForConditionalGeneration as TraceableQwen2VLForConditionalGeneration,
 )
 from .idefics3 import (
-    Idefics3ForConditionalGeneration as TraceableIdefics3ForConditionalGeneration
+    Idefics3ForConditionalGeneration as TraceableIdefics3ForConditionalGeneration,
 )
 from .whisper import (
-    WhisperForConditionalGeneration as TraceableWhisperForConditionalGeneration
+    WhisperForConditionalGeneration as TraceableWhisperForConditionalGeneration,
 )
+from .debug import get_model_class

 __all__ = [
+    "get_model_class",
     "TraceableLlavaForConditionalGeneration",
     "TraceableMllamaForConditionalGeneration",
     "TraceableQwen2VLForConditionalGeneration",
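The newly exported get_model_class lets the test suites pick a model class by name from a YAML config, covering both stock transformers classes and the Traceable* variants above. A minimal sketch of such a resolver, assuming it simply looks the name up on the tracing module and falls back to transformers (the actual implementation in .debug may differ):

import transformers
from llmcompressor.transformers import tracing

def get_model_class(model_class_name: str):
    # Hypothetical sketch, not the actual implementation in tracing/debug.py:
    # prefer a Traceable* class exported by the tracing module, then fall back
    # to the class of the same name in transformers (e.g. AutoModelForCausalLM).
    return getattr(tracing, model_class_name, None) or getattr(
        transformers, model_class_name
    )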
38 changes: 28 additions & 10 deletions tests/e2e/e2e_utils.py
@@ -1,23 +1,31 @@
 import torch
 from datasets import load_dataset
 from loguru import logger
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoProcessor

 from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
 from llmcompressor.transformers import oneshot
+from llmcompressor.transformers.tracing import get_model_class
 from tests.test_timer.timer_utils import log_time
-from tests.testing_utils import preprocess_tokenize_dataset
+from tests.testing_utils import process_dataset


 @log_time
-def _load_model_and_tokenizer(
+def _load_model_and_processor(
     model: str,
+    model_class: str,
     device: str,
 ):
-    loaded_model = AutoModelForCausalLM.from_pretrained(
-        model, device_map=device, torch_dtype="auto"
-    )
-    tokenizer = AutoTokenizer.from_pretrained(model)
-    return loaded_model, tokenizer
+    pretrained_model_class = get_model_class(model_class)
+    loaded_model = pretrained_model_class.from_pretrained(
+        model,
+        device_map=device,
+        torch_dtype="auto",
+        trust_remote_code=True,
+        _attn_implementation="eager",
+    )
+    processor = AutoProcessor.from_pretrained(model, trust_remote_code=True)
+    return loaded_model, processor


 @log_time
@@ -30,6 +38,7 @@ def _run_oneshot(device: str, **oneshot_kwargs):

 def run_oneshot_for_e2e_testing(
     model: str,
+    model_class: str,
     device: str,
     num_calibration_samples: int,
     max_seq_length: int,
@@ -43,16 +52,25 @@ def run_oneshot_for_e2e_testing(
     # Load model.
     oneshot_kwargs = {}

-    loaded_model, tokenizer = _load_model_and_tokenizer(model=model, device=device)
+    loaded_model, processor = _load_model_and_processor(model=model, model_class=model_class, device=device)

     if dataset_id:
         ds = load_dataset(dataset_id, name=dataset_config, split=dataset_split)
         ds = ds.shuffle(seed=42).select(range(num_calibration_samples))
-        ds = preprocess_tokenize_dataset(ds, tokenizer, max_seq_length)
+        ds = process_dataset(ds, processor, max_seq_length)
         oneshot_kwargs["dataset"] = ds
         oneshot_kwargs["max_seq_length"] = max_seq_length
         oneshot_kwargs["num_calibration_samples"] = num_calibration_samples

+        # TODO better conditional on when multimodal data-collator should be added
+        if "flickr30k" in dataset_id:
+            # Define a oneshot data collator for multimodal inputs.
+            def data_collator(batch):
+                assert len(batch) == 1
+                return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+            oneshot_kwargs["data_collator"] = data_collator
+
     oneshot_kwargs["model"] = loaded_model
     if recipe:
         oneshot_kwargs["recipe"] = recipe
@@ -72,4 +90,4 @@ def run_oneshot_for_e2e_testing(
     logger.info("ONESHOT KWARGS", oneshot_kwargs)
     _run_oneshot(device=device, **oneshot_kwargs)

-    return oneshot_kwargs["model"], tokenizer
+    return oneshot_kwargs["model"], processor
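The data collator added above relies on oneshot calibration using a batch size of one: multimodal processors return nested Python lists per field, so each field of the single sample is converted to a tensor directly rather than padded and stacked. A self-contained sketch of that behavior (the sample dict is illustrative, not real flickr30k processor output):

import torch

def data_collator(batch):
    assert len(batch) == 1  # oneshot calibration feeds one sample at a time
    return {key: torch.tensor(value) for key, value in batch[0].items()}

# Illustrative processor-style output: nested lists with a leading batch dim.
sample = {"input_ids": [[1, 2, 3]], "pixel_values": [[[0.1, 0.2], [0.3, 0.4]]]}
collated = data_collator([sample])
print({k: tuple(v.shape) for k, v in collated.items()})
# {'input_ids': (1, 3), 'pixel_values': (1, 2, 2)}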
8 changes: 0 additions & 8 deletions tests/e2e/vLLM/lm_eval_configs/fp8_dynamic_per_token.yaml

This file was deleted.

@@ -3,7 +3,7 @@ quant_stage:
     SmoothQuantModifier:
       smoothing_strength: 0.8
     GPTQModifier:
-      ignore: [lm_head]
+      ignore: ["lm_head", "re:vision_tower.*", "re:multi_modal_projector.*", "re:visual.*", "re:vision_model.*"]
       config_groups:
         group_0:
           weights: {num_bits: 8, type: int, symmetric: true, strategy: channel}
@@ -1,7 +1,7 @@
 quant_stage:
   quant_modifiers:
     GPTQModifier:
-      ignore: ["lm_head"]
+      ignore: ["lm_head", "re:vision_tower.*", "re:multi_modal_projector.*", "re:visual.*", "re:vision_model.*"]
       config_groups:
         group_0:
           weights:
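Both recipe hunks widen the GPTQModifier ignore list so quantization skips the vision stack across the supported multimodal architectures. llm-compressor resolves "re:"-prefixed entries as regular expressions matched against submodule names, while bare entries like lm_head match literally. A simplified sketch of that matching rule (the real resolution logic lives in compressed-tensors and is more involved):

import re

def is_ignored(module_name: str, ignore: list[str]) -> bool:
    # "re:foo.*" entries are regex patterns; anything else is an exact name.
    for pattern in ignore:
        if pattern.startswith("re:"):
            if re.match(pattern[3:], module_name):
                return True
        elif module_name == pattern:
            return True
    return False

ignore = ["lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"]
print(is_ignored("vision_tower.layers.0.self_attn.q_proj", ignore))  # True
print(is_ignored("model.layers.0.mlp.down_proj", ignore))            # False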
2 changes: 2 additions & 0 deletions tests/e2e/vLLM/test_vllm.py
@@ -71,6 +71,7 @@ def set_up(self):
             pytest.skip("Skipping test; cadence mismatch")

         self.model = eval_config["model"]
+        self.model_class = eval_config.get("model_class", "AutoModelForCausalLM")
         self.scheme = eval_config.get("scheme")
         self.dataset_id = eval_config.get("dataset_id")
         self.dataset_config = eval_config.get("dataset_config")
@@ -104,6 +105,7 @@ def test_vllm(self):
         self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
         oneshot_model, tokenizer = run_oneshot_for_e2e_testing(
             model=self.model,
+            model_class=self.model_class,
             device=self.device,
             num_calibration_samples=self.num_calibration_samples,
             max_seq_length=self.max_seq_length,
File renamed without changes.
13 changes: 13 additions & 0 deletions tests/lmeval/configs/fp8_dynamic_per_token.yaml
@@ -0,0 +1,13 @@
+cadence: "weekly"
+model: meta-llama/Meta-Llama-3-8B-Instruct
+scheme: FP8_DYNAMIC
+lmeval:
+  model_args:
+    dtype: bfloat16
+    add_bos_token: True
+  task: "gsm8k"
+  num_fewshot: 5
+  limit: 1000
+  metrics:
+    exact_match,flexible-extract: 0.75
+    exact_match,strict-match: 0.75
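The lmeval block of configs like this one is parsed into the LmEvalConfig pydantic model introduced in tests/lmeval/test_lmeval.py below, with model and batch_size filled from defaults when omitted. A minimal sketch of that round trip, inlining this file's lmeval block:

import yaml
from pydantic import BaseModel

class LmEvalConfig(BaseModel):  # mirrors the model defined in test_lmeval.py
    model: str = "hf"
    model_args: dict = {"add_bos_token": True, "dtype": "bfloat16"}
    task: str
    num_fewshot: int
    limit: int
    metrics: dict
    batch_size: int = 100

raw = yaml.safe_load("""
model_args:
  dtype: bfloat16
  add_bos_token: True
task: gsm8k
num_fewshot: 5
limit: 1000
metrics:
  exact_match,flexible-extract: 0.75
  exact_match,strict-match: 0.75
""")
cfg = LmEvalConfig(**raw)
print(cfg.model, cfg.batch_size)  # -> hf 100 (defaults applied)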
@@ -8,4 +8,4 @@ task: "gsm8k"
 dataset_id: HuggingFaceH4/ultrachat_200k
 dataset_split: train_sft
 exact_match,flexible-extract: 0.77
-exact_match,strict-match: 0.76
\ No newline at end of file
+exact_match,strict-match: 0.76
16 changes: 16 additions & 0 deletions tests/lmeval/configs/vl_fp8_dynamic_per_token.yaml
@@ -0,0 +1,16 @@
+cadence: weekly
+model: Qwen/Qwen2-VL-2B-Instruct
+model_class: TraceableQwen2VLForConditionalGeneration
+scheme: FP8_DYNAMIC
+lmeval:
+  model: "hf-multimodal"
+  model_args:
+    dtype: bfloat16
+    add_bos_token: True
+    convert_img_format: True
+  task: mmmu_val_economics
+  num_fewshot: 0
+  limit: 1000
+  batch_size: 8
+  metrics:
+    acc,none: 0.333
19 changes: 19 additions & 0 deletions tests/lmeval/configs/vl_int8_w8a8_dynamic_per_token.yaml
@@ -0,0 +1,19 @@
+cadence: "weekly"
+model: llava-hf/llava-1.5-7b-hf
+model_class: TraceableLlavaForConditionalGeneration
+scheme: INT8_dyn_per_token
+recipe: tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml
+dataset_id: lmms-lab/flickr30k
+dataset_split: "test[:512]"
+lmeval:
+  model: "hf-multimodal"
+  model_args:
+    dtype: bfloat16
+    add_bos_token: True
+    convert_img_format: True
+  task: mmmu_val_economics
+  num_fewshot: 0
+  limit: 1000
+  metrics:
+    acc,none: 0.233
+  batch_size: 8
19 changes: 19 additions & 0 deletions tests/lmeval/configs/vl_w4a16_actorder_weight.yaml
@@ -0,0 +1,19 @@
+cadence: "weekly"
+model: Qwen/Qwen2-VL-2B-Instruct
+model_class: TraceableQwen2VLForConditionalGeneration
+recipe: tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_weight.yaml
+dataset_id: lmms-lab/flickr30k
+dataset_split: "test[:512]"
+scheme: W4A16_actorder_group
+lmeval:
+  model: "hf-multimodal"
+  model_args:
+    dtype: bfloat16
+    add_bos_token: True
+    convert_img_format: True
+  task: mmmu_val_economics
+  num_fewshot: 0
+  limit: 1000
+  metrics:
+    acc,none: 0.333
+  batch_size: 4
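The three vl_* configs switch lm-eval to its hf-multimodal backend. For reference, a config like this one drives a call roughly equivalent to the following sketch (the pretrained path is illustrative; the test builds it from the quantized model's save_dir):

import lm_eval

results = lm_eval.simple_evaluate(
    model="hf-multimodal",
    model_args={
        "pretrained": "Qwen2-VL-2B-Instruct-W4A16_actorder_group",  # illustrative
        "dtype": "bfloat16",
        "add_bos_token": True,
        "convert_img_format": True,
    },
    tasks=["mmmu_val_economics"],
    num_fewshot=0,
    limit=1000,
    device="cuda:0",
    batch_size=4,
)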
54 changes: 32 additions & 22 deletions tests/e2e/vLLM/test_lmeval.py → tests/lmeval/test_lmeval.py
@@ -6,11 +6,23 @@
 import pytest
 import yaml
 from loguru import logger
+from pydantic import BaseModel

 from llmcompressor.core import active_session
 from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing
 from tests.examples.utils import requires_gpu_count


+class LmEvalConfig(BaseModel):
+    model: str = "hf"
+    model_args: dict = {"add_bos_token": True, "dtype": "bfloat16"}
+    task: str
+    num_fewshot: int
+    limit: int
+    metrics: dict
+    batch_size: int = 100
+
+
 try:
     import lm_eval

@@ -51,18 +63,15 @@ def set_up(self):
             pytest.skip("Skipping test; cadence mismatch")

         self.model = eval_config["model"]
+        self.model_class = eval_config.get("model_class", "AutoModelForCausalLM")
+        self.lmeval = LmEvalConfig(**eval_config.get("lmeval"))
         self.scheme = eval_config.get("scheme")
         self.dataset_id = eval_config.get("dataset_id")
         self.dataset_config = eval_config.get("dataset_config")
         self.dataset_split = eval_config.get("dataset_split")
         self.recipe = eval_config.get("recipe")
         self.quant_type = eval_config.get("quant_type")
         self.save_dir = eval_config.get("save_dir")
-        self.task = eval_config.get("task")
-        self.num_fewshot = eval_config.get("num_fewshot")
-        self.limit = eval_config.get("limit")
-        self.exact_flex = eval_config.get("exact_match,flexible-extract")
-        self.exact_strict = eval_config.get("exact_match,strict-match")

         logger.info("========== RUNNING ==============")
         logger.info(self.scheme)
@@ -76,8 +85,9 @@ def test_lm_eval(self):
         self.set_up()
         if not self.save_dir:
             self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
-        oneshot_model, tokenizer = run_oneshot_for_e2e_testing(
+        oneshot_model, processor = run_oneshot_for_e2e_testing(
             model=self.model,
+            model_class=self.model_class,
             device=self.device,
             num_calibration_samples=self.num_calibration_samples,
             max_seq_length=self.max_seq_length,
@@ -91,7 +101,7 @@

         logger.info("================= SAVING TO DISK ======================")
         oneshot_model.save_pretrained(self.save_dir)
-        tokenizer.save_pretrained(self.save_dir)
+        processor.save_pretrained(self.save_dir)
         recipe_path = os.path.join(self.save_dir, "recipe.yaml")

         # Use the session to fetch the recipe;
@@ -104,26 +114,26 @@

         logger.info("================= Running LM Eval ======================")

-        model_args = f"pretrained={self.save_dir},add_bos_token=True"
+        model_args = {"pretrained": self.save_dir}
+        model_args.update(self.lmeval.model_args)
         results = lm_eval.simple_evaluate(
-            model="hf",
+            model=self.lmeval.model,
             model_args=model_args,
-            tasks=[self.task],
-            num_fewshot=self.num_fewshot,
-            limit=self.limit,
+            tasks=[self.lmeval.task],
+            num_fewshot=self.lmeval.num_fewshot,
+            limit=self.lmeval.limit,
             device="cuda:0",
-            batch_size=100,
+            batch_size=self.lmeval.batch_size,
         )

-        metrics = results["results"][self.task]
-        exact_match_strict = metrics.get("exact_match,strict-match")
-        exact_match_flex = metrics.get("exact_match,flexible-extract")
-        logger.info("Exact Match, Strict")
-        logger.info(exact_match_strict)
-        logger.info("Exact Match, Flex")
-        logger.info(exact_match_flex)
-        assert numpy.isclose(exact_match_strict, self.exact_strict, rtol=0.05)
-        assert numpy.isclose(exact_match_flex, self.exact_flex, rtol=0.05)
+        metrics = results["results"][self.lmeval.task]
+        for metric, expected_val in self.lmeval.metrics.items():
+            actual_val = metrics.get(metric)
+            logger.info(
+                f"Comparing {metric}: Expected {expected_val}, Got {actual_val}"
+            )
+            assert numpy.isclose(expected_val, actual_val, rtol=0.05)

         self.tear_down()

     def tear_down(self):
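One detail of the new metrics loop worth keeping in mind: numpy.isclose(a, b, rtol) tests |a - b| <= atol + rtol * |b|, so with expected_val passed first, the 5% tolerance is taken relative to the measured value. A quick check:

import numpy

expected, actual = 0.75, 0.72
# |0.75 - 0.72| = 0.03 <= 1e-08 + 0.05 * 0.72 = 0.036, so this passes
print(numpy.isclose(expected, actual, rtol=0.05))  # True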