diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py
index d5643efb8..29a976663 100644
--- a/src/llmcompressor/transformers/tracing/__init__.py
+++ b/src/llmcompressor/transformers/tracing/__init__.py
@@ -8,16 +8,18 @@
     Qwen2VLForConditionalGeneration as TraceableQwen2VLForConditionalGeneration,
 )
 from .idefics3 import (
-    Idefics3ForConditionalGeneration as TraceableIdefics3ForConditionalGeneration
+    Idefics3ForConditionalGeneration as TraceableIdefics3ForConditionalGeneration,
 )
 from .whisper import (
-    WhisperForConditionalGeneration as TraceableWhisperForConditionalGeneration
+    WhisperForConditionalGeneration as TraceableWhisperForConditionalGeneration,
 )
 from .qwen2_5_vl import (
     Qwen2_5_VLForConditionalGeneration as TraceableQwen2_5_VLForConditionalGeneration
 )
+from .debug import get_model_class
 
 __all__ = [
+    "get_model_class",
     "TraceableLlavaForConditionalGeneration",
     "TraceableMllamaForConditionalGeneration",
     "TraceableQwen2VLForConditionalGeneration",
diff --git a/src/llmcompressor/transformers/tracing/debug.py b/src/llmcompressor/transformers/tracing/debug.py
index 3b31366b1..ccce917a7 100644
--- a/src/llmcompressor/transformers/tracing/debug.py
+++ b/src/llmcompressor/transformers/tracing/debug.py
@@ -12,6 +12,10 @@
 from llmcompressor.transformers import TextGenerationDataset
 from llmcompressor.args import DatasetArguments
 
+__all__ = [
+    "get_model_class"
+]
+
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Trace a model into subgraphs")
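Note: the body of `get_model_class` lives in `debug.py` outside the hunks shown above; only its export is visible in this diff. A minimal sketch consistent with how the tests call it — resolving a class-name string such as `"TraceableQwen2VLForConditionalGeneration"` or `"AutoModelForCausalLM"` to a class exposing `from_pretrained` — might look like the following; the lookup order and error handling are assumptions:

```python
# Hypothetical sketch; the real definition is not part of this diff.
import transformers

import llmcompressor.transformers.tracing as tracing


def get_model_class(model_class: str) -> type:
    # Prefer the traceable wrappers, fall back to stock transformers classes.
    model_cls = getattr(tracing, model_class, None) or getattr(
        transformers, model_class, None
    )
    if model_cls is None:
        raise ValueError(f"Could not resolve model class: {model_class}")
    return model_cls
```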
diff --git a/tests/e2e/e2e_utils.py b/tests/e2e/e2e_utils.py
index 41e2434ab..c77921bab 100644
--- a/tests/e2e/e2e_utils.py
+++ b/tests/e2e/e2e_utils.py
@@ -1,23 +1,27 @@
+import torch
 from datasets import load_dataset
 from loguru import logger
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoProcessor
 
 from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
 from llmcompressor.transformers import oneshot
+from llmcompressor.transformers.tracing import get_model_class
 from tests.test_timer.timer_utils import log_time
-from tests.testing_utils import preprocess_tokenize_dataset
+from tests.testing_utils import process_dataset
 
 
 @log_time
-def _load_model_and_tokenizer(
+def _load_model_and_processor(
     model: str,
+    model_class: str,
     device: str,
 ):
-    loaded_model = AutoModelForCausalLM.from_pretrained(
+    pretrained_model_class = get_model_class(model_class)
+    loaded_model = pretrained_model_class.from_pretrained(
         model, device_map=device, torch_dtype="auto"
     )
-    tokenizer = AutoTokenizer.from_pretrained(model)
-    return loaded_model, tokenizer
+    processor = AutoProcessor.from_pretrained(model)
+    return loaded_model, processor
 
 
 @log_time
@@ -30,6 +34,7 @@ def _run_oneshot(device: str, **oneshot_kwargs):
 
 def run_oneshot_for_e2e_testing(
     model: str,
+    model_class: str,
     device: str,
     num_calibration_samples: int,
     max_seq_length: int,
@@ -43,16 +48,27 @@
     # Load model.
     oneshot_kwargs = {}
-    loaded_model, tokenizer = _load_model_and_tokenizer(model=model, device=device)
+    loaded_model, processor = _load_model_and_processor(
+        model=model, model_class=model_class, device=device
+    )
 
     if dataset_id:
         ds = load_dataset(dataset_id, name=dataset_config, split=dataset_split)
         ds = ds.shuffle(seed=42).select(range(num_calibration_samples))
-        ds = preprocess_tokenize_dataset(ds, tokenizer, max_seq_length)
+        ds = process_dataset(ds, processor, max_seq_length)
        oneshot_kwargs["dataset"] = ds
         oneshot_kwargs["max_seq_length"] = max_seq_length
         oneshot_kwargs["num_calibration_samples"] = num_calibration_samples
 
+        # Define a data collator for multimodal inputs.
+        if "flickr30k" in dataset_id:
+
+            def data_collator(batch):
+                assert len(batch) == 1
+                return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+            oneshot_kwargs["data_collator"] = data_collator
+
     oneshot_kwargs["model"] = loaded_model
     if recipe:
         oneshot_kwargs["recipe"] = recipe
@@ -72,4 +88,4 @@
     logger.info("ONESHOT KWARGS", oneshot_kwargs)
     _run_oneshot(device=device, **oneshot_kwargs)
 
-    return oneshot_kwargs["model"], tokenizer
+    return oneshot_kwargs["model"], processor
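Note: the flickr30k collator above relies on single-sample calibration batches; each mapped example arrives as a plain dict of (nested) lists, which the collator tensorizes. A self-contained illustration of that behavior (the sample contents are invented):

```python
import torch


def data_collator(batch):
    # Mirrors the collator in the diff: exactly one sample per batch.
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}


# Hypothetical processed sample: token ids plus a dummy pixel grid.
sample = {"input_ids": [[1, 2, 3]], "pixel_values": [[[0.1, 0.2], [0.3, 0.4]]]}
collated = data_collator([sample])
print(collated["input_ids"].shape)  # torch.Size([1, 3])
```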
diff --git a/tests/e2e/vLLM/lm_eval_configs/fp8_dynamic_per_token.yaml b/tests/e2e/vLLM/lm_eval_configs/fp8_dynamic_per_token.yaml
deleted file mode 100644
index fc610bae9..000000000
--- a/tests/e2e/vLLM/lm_eval_configs/fp8_dynamic_per_token.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-cadence: "weekly"
-model: meta-llama/Meta-Llama-3-8B-Instruct
-scheme: FP8_DYNAMIC
-num_fewshot: 5
-limit: 1000
-task: "gsm8k"
-exact_match,flexible-extract: 0.75
-exact_match,strict-match: 0.75
diff --git a/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml b/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml
index 367437e5a..c3ecdea86 100644
--- a/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml
+++ b/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml
@@ -3,7 +3,7 @@ quant_stage:
     SmoothQuantModifier:
       smoothing_strength: 0.8
     GPTQModifier:
-      ignore: [lm_head]
+      ignore: ["lm_head", "re:vision_tower.*", "re:multi_modal_projector.*", "re:visual.*", "re:vision_model.*"]
       config_groups:
         group_0:
           weights: {num_bits: 8, type: int, symmetric: true, strategy: channel}
diff --git a/tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_weight.yaml b/tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_weight.yaml
index 0c8476883..4efa211a2 100644
--- a/tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_weight.yaml
+++ b/tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_weight.yaml
@@ -1,7 +1,7 @@
 quant_stage:
   quant_modifiers:
     GPTQModifier:
-      ignore: ["lm_head"]
+      ignore: ["lm_head", "re:vision_tower.*", "re:multi_modal_projector.*", "re:visual.*", "re:vision_model.*"]
       config_groups:
         group_0:
           weights:
diff --git a/tests/e2e/vLLM/test_vllm.py b/tests/e2e/vLLM/test_vllm.py
index 5ea0be43c..5b419cdc2 100644
--- a/tests/e2e/vLLM/test_vllm.py
+++ b/tests/e2e/vLLM/test_vllm.py
@@ -71,6 +71,7 @@ def set_up(self):
             pytest.skip("Skipping test; cadence mismatch")
 
         self.model = eval_config["model"]
+        self.model_class = eval_config.get("model_class", "AutoModelForCausalLM")
         self.scheme = eval_config.get("scheme")
         self.dataset_id = eval_config.get("dataset_id")
         self.dataset_config = eval_config.get("dataset_config")
@@ -104,6 +105,7 @@ def test_vllm(self):
             self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
         oneshot_model, tokenizer = run_oneshot_for_e2e_testing(
             model=self.model,
+            model_class=self.model_class,
             device=self.device,
             num_calibration_samples=self.num_calibration_samples,
             max_seq_length=self.max_seq_length,
diff --git a/tests/integration/__init__.py b/tests/lmeval/__init__.py
similarity index 100%
rename from tests/integration/__init__.py
rename to tests/lmeval/__init__.py
diff --git a/tests/lmeval/configs/fp8_dynamic_per_token.yaml b/tests/lmeval/configs/fp8_dynamic_per_token.yaml
new file mode 100644
index 000000000..b89bb4552
--- /dev/null
+++ b/tests/lmeval/configs/fp8_dynamic_per_token.yaml
@@ -0,0 +1,7 @@
+cadence: "weekly"
+model: meta-llama/Meta-Llama-3-8B-Instruct
+scheme: FP8_DYNAMIC
+lmeval:
+  metrics:
+    exact_match,flexible-extract: 0.75
+    exact_match,strict-match: 0.75
diff --git a/tests/e2e/vLLM/lm_eval_configs/fp8_static_per_tensor.yaml b/tests/lmeval/configs/fp8_static_per_tensor.yaml
similarity index 56%
rename from tests/e2e/vLLM/lm_eval_configs/fp8_static_per_tensor.yaml
rename to tests/lmeval/configs/fp8_static_per_tensor.yaml
index 0b6d42a46..e4d31cef2 100644
--- a/tests/e2e/vLLM/lm_eval_configs/fp8_static_per_tensor.yaml
+++ b/tests/lmeval/configs/fp8_static_per_tensor.yaml
@@ -1,10 +1,9 @@
 cadence: "weekly"
 model: meta-llama/Meta-Llama-3-8B-Instruct
 scheme: FP8
-num_fewshot: 5
-limit: 1000
-task: "gsm8k"
 dataset_id: HuggingFaceH4/ultrachat_200k
 dataset_split: train_sft
-exact_match,flexible-extract: 0.75
-exact_match,strict-match: 0.75
+lmeval:
+  metrics:
+    exact_match,flexible-extract: 0.75
+    exact_match,strict-match: 0.75
diff --git a/tests/e2e/vLLM/lm_eval_configs/int8_w8a8_dynamic_per_token.yaml b/tests/lmeval/configs/int8_w8a8_dynamic_per_token.yaml
similarity index 69%
rename from tests/e2e/vLLM/lm_eval_configs/int8_w8a8_dynamic_per_token.yaml
rename to tests/lmeval/configs/int8_w8a8_dynamic_per_token.yaml
index 446ca1e7f..3e6c364e0 100644
--- a/tests/e2e/vLLM/lm_eval_configs/int8_w8a8_dynamic_per_token.yaml
+++ b/tests/lmeval/configs/int8_w8a8_dynamic_per_token.yaml
@@ -2,10 +2,9 @@ cadence: "weekly"
 model: meta-llama/Meta-Llama-3-8B-Instruct
 scheme: INT8_dyn_per_token
 recipe: tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml
-num_fewshot: 5
-limit: 1000
-task: "gsm8k"
 dataset_id: HuggingFaceH4/ultrachat_200k
 dataset_split: train_sft
-exact_match,flexible-extract: 0.77
-exact_match,strict-match: 0.76
+lmeval:
+  metrics:
+    exact_match,flexible-extract: 0.77
+    exact_match,strict-match: 0.76
\ No newline at end of file
diff --git a/tests/lmeval/configs/vl_fp8_dynamic_per_token.yaml b/tests/lmeval/configs/vl_fp8_dynamic_per_token.yaml
new file mode 100644
index 000000000..3ae64f093
--- /dev/null
+++ b/tests/lmeval/configs/vl_fp8_dynamic_per_token.yaml
@@ -0,0 +1,16 @@
+cadence: weekly
+model: Qwen/Qwen2-VL-2B-Instruct
+model_class: TraceableQwen2VLForConditionalGeneration
+scheme: FP8_DYNAMIC
+lmeval:
+  model: "hf-multimodal"
+  model_args:
+    dtype: bfloat16
+    add_bos_token: True
+    convert_img_format: True
+  task: mmmu_val_economics
+  num_fewshot: 0
+  limit: 1000
+  batch_size: 8
+  metrics:
+    acc,none: 0.333
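Note: the `lmeval` block in configs like the one above is consumed by `test_lmeval.py` (further down in this diff): the saved checkpoint directory is injected as `pretrained` and the YAML `model_args` are merged on top. For the FP8 Qwen2-VL config, the resulting call is roughly the following sketch (the directory name is illustrative):

```python
import lm_eval

# save_dir supplies "pretrained"; the rest mirrors vl_fp8_dynamic_per_token.yaml.
model_args = {"pretrained": "Qwen2-VL-2B-Instruct-FP8_DYNAMIC"}  # illustrative path
model_args.update(
    {"dtype": "bfloat16", "add_bos_token": True, "convert_img_format": True}
)

results = lm_eval.simple_evaluate(
    model="hf-multimodal",
    model_args=model_args,
    tasks=["mmmu_val_economics"],
    num_fewshot=0,
    limit=1000,
    device="cuda:0",
    batch_size=8,
)
```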
diff --git a/tests/lmeval/configs/vl_int8_w8a8_dynamic_per_token.yaml b/tests/lmeval/configs/vl_int8_w8a8_dynamic_per_token.yaml
new file mode 100644
index 000000000..22b5d8419
--- /dev/null
+++ b/tests/lmeval/configs/vl_int8_w8a8_dynamic_per_token.yaml
@@ -0,0 +1,19 @@
+cadence: "weekly"
+model: llava-hf/llava-1.5-7b-hf
+model_class: TraceableLlavaForConditionalGeneration
+scheme: INT8_dyn_per_token
+recipe: tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml
+dataset_id: lmms-lab/flickr30k
+dataset_split: "test[:512]"
+lmeval:
+  model: "hf-multimodal"
+  model_args:
+    dtype: bfloat16
+    add_bos_token: True
+    convert_img_format: True
+  task: mmmu_val_economics
+  num_fewshot: 0
+  limit: 1000
+  metrics:
+    acc,none: 0.233
+  batch_size: 8
\ No newline at end of file
diff --git a/tests/lmeval/configs/vl_w4a16_actorder_weight.yaml b/tests/lmeval/configs/vl_w4a16_actorder_weight.yaml
new file mode 100644
index 000000000..b7fa161c8
--- /dev/null
+++ b/tests/lmeval/configs/vl_w4a16_actorder_weight.yaml
@@ -0,0 +1,19 @@
+cadence: "weekly"
+model: Qwen/Qwen2-VL-2B-Instruct
+model_class: TraceableQwen2VLForConditionalGeneration
+recipe: tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_weight.yaml
+dataset_id: lmms-lab/flickr30k
+dataset_split: "test[:512]"
+scheme: W4A16_actorder_group
+lmeval:
+  model: "hf-multimodal"
+  model_args:
+    dtype: bfloat16
+    add_bos_token: True
+    convert_img_format: True
+  task: mmmu_val_economics
+  num_fewshot: 0
+  limit: 1000
+  metrics:
+    acc,none: 0.4
+  batch_size: 4
\ No newline at end of file
diff --git a/tests/e2e/vLLM/lm_eval_configs/w4a16_actorder_weight.yaml b/tests/lmeval/configs/w4a16_actorder_weight.yaml
similarity index 59%
rename from tests/e2e/vLLM/lm_eval_configs/w4a16_actorder_weight.yaml
rename to tests/lmeval/configs/w4a16_actorder_weight.yaml
index ca82bb44f..612274218 100644
--- a/tests/e2e/vLLM/lm_eval_configs/w4a16_actorder_weight.yaml
+++ b/tests/lmeval/configs/w4a16_actorder_weight.yaml
@@ -1,11 +1,10 @@
 cadence: "weekly"
 model: meta-llama/Meta-Llama-3-8B-Instruct
+scheme: W4A16_actorder_group
 recipe: tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_weight.yaml
-num_fewshot: 5
-limit: 1000
-task: "gsm8k"
 dataset_id: HuggingFaceH4/ultrachat_200k
 dataset_split: train_sft
-exact_match,flexible-extract: 0.72
-exact_match,strict-match: 0.72
-scheme: W4A16_actorder_group
\ No newline at end of file
+lmeval:
+  metrics:
+    exact_match,flexible-extract: 0.72
+    exact_match,strict-match: 0.72
diff --git a/tests/e2e/vLLM/lm_eval_configs/w4a16_grouped_quant.yaml b/tests/lmeval/configs/w4a16_grouped_quant.yaml
similarity index 53%
rename from tests/e2e/vLLM/lm_eval_configs/w4a16_grouped_quant.yaml
rename to tests/lmeval/configs/w4a16_grouped_quant.yaml
index a4c7b6244..45728a5b6 100644
--- a/tests/e2e/vLLM/lm_eval_configs/w4a16_grouped_quant.yaml
+++ b/tests/lmeval/configs/w4a16_grouped_quant.yaml
@@ -1,11 +1,10 @@
 cadence: "weekly"
 model: meta-llama/Meta-Llama-3-8B-Instruct
-num_fewshot: 5
-limit: 1000
-task: "gsm8k"
-exact_match,flexible-extract: 0.72
-exact_match,strict-match: 0.72
 scheme: W4A16
 dataset_id: HuggingFaceH4/ultrachat_200k
 dataset_split: train_sft
-quant_type: "GPTQ"
\ No newline at end of file
+quant_type: "GPTQ"
+lmeval:
+  metrics:
+    exact_match,flexible-extract: 0.72
+    exact_match,strict-match: 0.72
diff --git a/tests/e2e/vLLM/test_lmeval.py b/tests/lmeval/test_lmeval.py
similarity index 76%
rename from tests/e2e/vLLM/test_lmeval.py
rename to tests/lmeval/test_lmeval.py
index 4e11123a5..e5b9efcef 100644
--- a/tests/e2e/vLLM/test_lmeval.py
+++ b/tests/lmeval/test_lmeval.py
@@ -6,11 +6,23 @@
 import pytest
 import yaml
 from loguru import logger
+from pydantic import BaseModel
 
 from llmcompressor.core import active_session
 from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing
 from tests.examples.utils import requires_gpu_count
 
+
+class LmEvalConfig(BaseModel):
+    model: str = "hf"
{"add_bos_token": True, "dtype": "bfloat16"} + task: str = "gsm8k" + num_fewshot: int = 5 + limit: int = 1000 + metrics: dict + batch_size: int = 100 + + try: import lm_eval @@ -51,6 +63,8 @@ def set_up(self): pytest.skip("Skipping test; cadence mismatch") self.model = eval_config["model"] + self.model_class = eval_config.get("model_class", "AutoModelForCausalLM") + self.lmeval = LmEvalConfig(**eval_config.get("lmeval", {})) self.scheme = eval_config.get("scheme") self.dataset_id = eval_config.get("dataset_id") self.dataset_config = eval_config.get("dataset_config") @@ -58,11 +72,6 @@ def set_up(self): self.recipe = eval_config.get("recipe") self.quant_type = eval_config.get("quant_type") self.save_dir = eval_config.get("save_dir") - self.task = eval_config.get("task") - self.num_fewshot = eval_config.get("num_fewshot") - self.limit = eval_config.get("limit") - self.exact_flex = eval_config.get("exact_match,flexible-extract") - self.exact_strict = eval_config.get("exact_match,strict-match") logger.info("========== RUNNING ==============") logger.info(self.scheme) @@ -76,8 +85,9 @@ def test_lm_eval(self): self.set_up() if not self.save_dir: self.save_dir = self.model.split("/")[1] + f"-{self.scheme}" - oneshot_model, tokenizer = run_oneshot_for_e2e_testing( + oneshot_model, processor = run_oneshot_for_e2e_testing( model=self.model, + model_class=self.model_class, device=self.device, num_calibration_samples=self.num_calibration_samples, max_seq_length=self.max_seq_length, @@ -91,7 +101,7 @@ def test_lm_eval(self): logger.info("================= SAVING TO DISK ======================") oneshot_model.save_pretrained(self.save_dir) - tokenizer.save_pretrained(self.save_dir) + processor.save_pretrained(self.save_dir) recipe_path = os.path.join(self.save_dir, "recipe.yaml") # Use the session to fetch the recipe; @@ -104,26 +114,26 @@ def test_lm_eval(self): logger.info("================= Running LM Eval ======================") - model_args = f"pretrained={self.save_dir},add_bos_token=True" + model_args = {"pretrained": self.save_dir} + model_args.update(self.lmeval.model_args) results = lm_eval.simple_evaluate( - model="hf", + model=self.lmeval.model, model_args=model_args, - tasks=[self.task], - num_fewshot=self.num_fewshot, - limit=self.limit, + tasks=[self.lmeval.task], + num_fewshot=self.lmeval.num_fewshot, + limit=self.lmeval.limit, device="cuda:0", - batch_size=100, + batch_size=self.lmeval.batch_size, ) - metrics = results["results"][self.task] - exact_match_strict = metrics.get("exact_match,strict-match") - exact_match_flex = metrics.get("exact_match,flexible-extract") - logger.info("Exact Match, Strict") - logger.info(exact_match_strict) - logger.info("Exact Match, Flex") - logger.info(exact_match_flex) - assert numpy.isclose(exact_match_strict, self.exact_strict, rtol=0.05) - assert numpy.isclose(exact_match_flex, self.exact_flex, rtol=0.05) + metrics = results["results"][self.lmeval.task] + for metric, expected_val in self.lmeval.metrics.items(): + actual_val = metrics.get(metric) + logger.info( + f"Comparing {metric}: Expected {expected_val}, Got {actual_val}" + ) + assert numpy.isclose(expected_val, actual_val, rtol=0.05) + self.tear_down() def tear_down(self): diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 257506784..07bb58b99 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -9,7 +9,7 @@ import yaml from datasets import Dataset -from transformers import PreTrainedTokenizer +from transformers import ProcessorMixin from tests.data import 
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 257506784..07bb58b99 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -9,7 +9,7 @@
 import yaml
 from datasets import Dataset
-from transformers import PreTrainedTokenizer
+from transformers import ProcessorMixin
 
 from tests.data import CustomTestConfig, TestConfig
 
@@ -125,8 +125,8 @@ def run_cli_command(cmd: List[str], cwd: Optional[Union[str, Path]] = None):
     return run(cmd, stdout=PIPE, stderr=STDOUT, check=False, encoding="utf-8", cwd=cwd)
 
 
-def preprocess_tokenize_dataset(
-    ds: Dataset, tokenizer: PreTrainedTokenizer, max_seq_length: int
+def process_dataset(
+    ds: Dataset, processor: ProcessorMixin, max_seq_length: int
 ) -> Dataset:
     """
     Helper function to preprocess and tokenize a dataset according to presets
@@ -138,11 +138,8 @@
     ds_name = ds.info.dataset_name.lower()
     if ds_name == "gsm8k":
 
-        def preprocess(example):
-            return example
-
-        def tokenize(sample):
-            return tokenizer(
+        def process(sample):
+            return processor(
                 sample["question"],
                 padding=False,
                 max_length=max_seq_length,
@@ -152,17 +149,12 @@
 
     elif ds_name == "ultrachat_200k":
 
-        def preprocess(example):
-            return {
-                "text": tokenizer.apply_chat_template(
-                    example["messages"],
+        def process(sample):
+            return processor(
+                processor.apply_chat_template(
+                    sample["messages"],
                     tokenize=False,
-                )
-            }
-
-        def tokenize(sample):
-            return tokenizer(
-                sample["text"],
+                ),
                 padding=False,
                 max_length=max_seq_length,
                 truncation=True,
@@ -171,17 +163,12 @@
 
     elif ds_name == "llm_compression_calibration":
 
-        def preprocess(example):
-            return {
-                "text": tokenizer.apply_chat_template(
-                    example["text"],
+        def process(sample):
+            return processor(
+                processor.apply_chat_template(
+                    sample["text"],
                     tokenize=False,
-                )
-            }
-
-        def tokenize(sample):
-            return tokenizer(
-                sample["text"],
+                ),
                 padding=False,
                 max_length=max_seq_length,
                 truncation=True,
@@ -190,17 +177,12 @@
 
     elif ds_name == "open-platypus":
         # use the output rather than the instruction
-        def preprocess(example):
-            return {
-                "text": tokenizer.apply_chat_template(
-                    example["output"],
+        def process(sample):
+            return processor(
+                processor.apply_chat_template(
+                    sample["output"],
                     tokenize=False,
-                )
-            }
-
-        def tokenize(sample):
-            return tokenizer(
-                sample["text"],
+                ),
                 padding=False,
                 max_length=max_seq_length,
                 truncation=True,
@@ -209,32 +191,46 @@
 
     elif ds_name == "slimorca-deduped-cleaned-corrected":
         # find the first element corresponding to a message from a human
-        def preprocess(example):
+        def process(sample):
             conversation_idx = 0
-            for idx, conversation in enumerate(example["conversations"]):
+            for idx, conversation in enumerate(sample["conversations"]):
                 if conversation["from"] == "human":
                     conversation_idx = idx
                     break
-            return {
-                "text": tokenizer.apply_chat_template(
-                    example["conversations"][conversation_idx]["value"],
+            return processor(
+                processor.apply_chat_template(
+                    sample["conversations"][conversation_idx]["value"],
                     tokenize=False,
-                )
-            }
-
-        def tokenize(sample):
-            return tokenizer(
-                sample["text"],
+                ),
                 padding=False,
                 max_length=max_seq_length,
                 truncation=True,
                 add_special_tokens=False,
             )
 
+    elif ds_name == "flickr30k":
+
+        def process(sample):
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image"},
+                        {"type": "text", "text": "What does the image show?"},
+                    ],
+                }
+            ]
+            return {
+                "text": processor.apply_chat_template(
+                    messages,
+                    add_generation_prompt=True,
+                ),
+                "images": sample["image"],
+            }
+
     else:
         raise NotImplementedError(f"Cannot preprocess dataset {ds.info.dataset_name}")
 
-    ds = ds.map(preprocess)
-    ds = ds.map(tokenize, remove_columns=ds.column_names)
+    ds = ds.map(process, remove_columns=ds.column_names)
 
     return ds
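Note: taken together, the renamed `process_dataset` helper serves both the text presets and the new flickr30k multimodal preset. A sketch of the intended call pattern (model id and slice size are illustrative); for the flickr30k preset the mapped dataset keeps only `text` and `images` columns, which is what the single-sample collator in `e2e_utils.py` expects:

```python
from datasets import load_dataset
from transformers import AutoProcessor

from tests.testing_utils import process_dataset

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")  # illustrative
ds = load_dataset("lmms-lab/flickr30k", split="test[:4]")
ds = process_dataset(ds, processor, max_seq_length=2048)
print(ds.column_names)  # expected: ["text", "images"] for this preset
```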