Commit

refactored, multimodal working with num_fewshot=0
Signed-off-by: Brian Dellabetta <[email protected]>
brian-dellabetta committed Feb 13, 2025
1 parent 5a99384 commit 24c757e
Showing 3 changed files with 45 additions and 31 deletions.
15 changes: 10 additions & 5 deletions tests/lmeval/configs/fp8_dynamic_per_token.yaml
@@ -1,8 +1,13 @@
cadence: "weekly"
model: meta-llama/Meta-Llama-3-8B-Instruct
scheme: FP8_DYNAMIC
num_fewshot: 5
limit: 1000
task: "gsm8k"
exact_match,flexible-extract: 0.75
exact_match,strict-match: 0.75
lmeval:
model_args:
dtype: bfloat16
add_bos_token: True
task: "gsm8k"
num_fewshot: 5
limit: 1000
metrics:
exact_match,flexible-extract: 0.75
exact_match,strict-match: 0.75
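
For context, this refactor nests all evaluator settings under an `lmeval` key that is validated by the pydantic `LmEvalConfig` model added in `tests/lmeval/test_lmeval.py` below. A minimal sketch, not part of this commit, of loading and validating such a config (the path and `print` are illustrative only):

```python
# Minimal sketch: load a config like the one above and validate its nested
# "lmeval" block with the pydantic model from tests/lmeval/test_lmeval.py.
import yaml
from pydantic import BaseModel


class LmEvalConfig(BaseModel):
    model: str = "hf"
    model_args: dict = {"add_bos_token": True, "dtype": "bfloat16"}
    task: str
    num_fewshot: int
    limit: int
    metrics: dict


with open("tests/lmeval/configs/fp8_dynamic_per_token.yaml") as f:
    eval_config = yaml.safe_load(f)

# Top-level keys (cadence, model, scheme) stay outside the model; only the
# nested "lmeval" block is validated here.
lmeval = LmEvalConfig(**eval_config["lmeval"])
print(lmeval.task, lmeval.num_fewshot, lmeval.metrics)
```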
16 changes: 11 additions & 5 deletions tests/lmeval/configs/vl_fp8_dynamic_per_token.yaml
@@ -2,8 +2,14 @@ cadence: weekly
 model: Qwen/Qwen2-VL-2B-Instruct
 model_type: qwen2_vl
 scheme: FP8_DYNAMIC
-num_fewshot: 5
-limit: 1000
-task: mmmu_val_accounting
-exact_match,flexible-extract: 0.75
-exact_match,strict-match: 0.75
+lmeval:
+  model: "hf-multimodal"
+  model_args:
+    dtype: bfloat16
+    add_bos_token: True
+    convert_img_format: True
+  task: mmmu_val_economics
+  num_fewshot: 0
+  limit: 10
+  metrics:
+    acc,none: 0.3
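
The vision-language config exercises the non-default path: `hf-multimodal` replaces the default `hf` backend, `convert_img_format: True` is forwarded through `model_args`, and `num_fewshot` drops to 0 (the commit title indicates the multimodal run works with zero-shot prompting). Reusing the `LmEvalConfig` sketch above, the equivalent construction would be:

```python
# Values mirror the YAML above; LmEvalConfig is as defined in the earlier
# sketch (and in tests/lmeval/test_lmeval.py below).
vl = LmEvalConfig(
    model="hf-multimodal",           # overrides the "hf" default
    model_args={
        "dtype": "bfloat16",
        "add_bos_token": True,
        "convert_img_format": True,  # image-format conversion for hf-multimodal
    },
    task="mmmu_val_economics",
    num_fewshot=0,                   # zero-shot; see commit title
    limit=10,
    metrics={"acc,none": 0.3},
)
assert vl.model == "hf-multimodal" and vl.num_fewshot == 0
```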
45 changes: 24 additions & 21 deletions tests/lmeval/test_lmeval.py
@@ -7,10 +7,20 @@
 import yaml
 from loguru import logger
 
+from pydantic import BaseModel
 from llmcompressor.core import active_session
 from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing
 from tests.examples.utils import requires_gpu_count
 
 
+class LmEvalConfig(BaseModel):
+    model: str = "hf"
+    model_args: dict = {"add_bos_token": True, "dtype": "bfloat16"}
+    task: str
+    num_fewshot: int
+    limit: int
+    metrics: dict
+
+
 try:
     import lm_eval

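With the `LmEvalConfig` model above, text-only configs can rely on the defaults, and only multimodal configs need to override `model` and `model_args`. A quick sketch of the default behavior, reusing the class exactly as defined in the hunk above:

```python
# Sketch: only the required fields are supplied; pydantic fills in the
# defaults (model="hf" and the base model_args) declared on LmEvalConfig.
cfg = LmEvalConfig(task="gsm8k", num_fewshot=5, limit=1000, metrics={})
assert cfg.model == "hf"
assert cfg.model_args == {"add_bos_token": True, "dtype": "bfloat16"}
```
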
@@ -51,19 +61,15 @@ def set_up(self):
             pytest.skip("Skipping test; cadence mismatch")
 
         self.model = eval_config["model"]
-        self.model_type = eval_config["model_type"]
+        self.model_type = eval_config.get("model_type", "")
+        self.lmeval = LmEvalConfig(**eval_config.get("lmeval"))
         self.scheme = eval_config.get("scheme")
         self.dataset_id = eval_config.get("dataset_id")
         self.dataset_config = eval_config.get("dataset_config")
         self.dataset_split = eval_config.get("dataset_split")
         self.recipe = eval_config.get("recipe")
         self.quant_type = eval_config.get("quant_type")
         self.save_dir = eval_config.get("save_dir")
-        self.task = eval_config.get("task")
-        self.num_fewshot = eval_config.get("num_fewshot")
-        self.limit = eval_config.get("limit")
-        self.exact_flex = eval_config.get("exact_match,flexible-extract")
-        self.exact_strict = eval_config.get("exact_match,strict-match")
 
         logger.info("========== RUNNING ==============")
         logger.info(self.scheme)
@@ -106,27 +112,23 @@ def test_lm_eval(self):
 
         logger.info("================= Running LM Eval ======================")
 
-        model_args = f"pretrained={self.save_dir},add_bos_token=True"
+        model_args = {"pretrained": self.save_dir}
+        model_args.update(self.lmeval.model_args)
         results = lm_eval.simple_evaluate(
-            # TODO conditional on task type?
-            model="hf-multimodal",  # "hf",
+            model=self.lmeval.model,
             model_args=model_args,
-            tasks=[self.task],
-            num_fewshot=self.num_fewshot,
-            limit=self.limit,
+            tasks=[self.lmeval.task],
+            num_fewshot=self.lmeval.num_fewshot,
+            limit=self.lmeval.limit,
             device="cuda:0",
             batch_size=100,
         )
 
-        metrics = results["results"][self.task]
-        exact_match_strict = metrics.get("exact_match,strict-match")
-        exact_match_flex = metrics.get("exact_match,flexible-extract")
-        logger.info("Exact Match, Strict")
-        logger.info(exact_match_strict)
-        logger.info("Exact Match, Flex")
-        logger.info(exact_match_flex)
-        assert numpy.isclose(exact_match_strict, self.exact_strict, rtol=0.05)
-        assert numpy.isclose(exact_match_flex, self.exact_flex, rtol=0.05)
+        metrics = results["results"][self.lmeval.task]
+        for metric, expected_val in self.lmeval.metrics.items():
+            actual_val = metrics.get(metric)
+            logger.info(f"Comparing {metric}: Expected {expected_val}, Got {actual_val}")
+            assert numpy.isclose(expected_val, actual_val, rtol=0.05)
 
         self.tear_down()
 
     def tear_down(self):
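Two behaviors of the refactored test body are worth spelling out: per-config `model_args` are merged over the base `pretrained` path (later keys win on collisions), and each expected metric is checked against the measured value with a 5% relative tolerance. A self-contained sketch, where the directory name and measured scores are hypothetical:

```python
import numpy

# Hypothetical compressed-model output directory and measured scores.
save_dir = "Meta-Llama-3-8B-Instruct-FP8_DYNAMIC"

model_args = {"pretrained": save_dir}
model_args.update({"add_bos_token": True, "dtype": "bfloat16"})
# dict.update means per-config model_args override the base entries.

expected = {"exact_match,strict-match": 0.75, "exact_match,flexible-extract": 0.75}
measured = {"exact_match,strict-match": 0.76, "exact_match,flexible-extract": 0.74}

for metric, expected_val in expected.items():
    actual_val = measured.get(metric)
    # rtol=0.05 tolerates roughly 5% relative drift in the measured score
    assert numpy.isclose(expected_val, actual_val, rtol=0.05)
```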
