Skip to content

Commit

Permalink
test mAP metric
Browse files Browse the repository at this point in the history
  • Loading branch information
SkalskiP committed Sep 9, 2024
1 parent 9e52912 commit 4c3fbd0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
3 changes: 1 addition & 2 deletions maestro/trainer/models/florence_2/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,13 @@ def get_ground_truths_and_predictions(
dataset: DetectionDataset,
processor: AutoProcessor,
model: AutoModelForCausalLM,
split_name: str,
device: torch.device,
) -> Tuple[List[sv.Detections], List[sv.Detections], List[str]]:
classes = extract_classes(dataset=dataset)
targets = []
predictions = []
post_processed_text_outputs = []
for idx in tqdm(list(range(len(dataset))), desc=f"Generating {split_name} predictions..."):
for idx in tqdm(list(range(len(dataset))), desc="Generating predictions..."):
image, data = dataset.dataset[idx]
prefix = data["prefix"]
suffix = data["suffix"]
Expand Down
24 changes: 24 additions & 0 deletions maestro/trainer/models/florence_2/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
save_metric_plots, BaseMetric
from maestro.trainer.common.utils.reproducibility import make_it_reproducible
from maestro.trainer.models.florence_2.data_loading import prepare_data_loaders
from maestro.trainer.models.florence_2.metrics import MeanAveragePrecisionMetric, get_ground_truths_and_predictions
from maestro.trainer.models.paligemma.training import LoraInitLiteral

DEFAULT_FLORENCE2_MODEL_ID = "microsoft/Florence-2-base-ft"
Expand Down Expand Up @@ -340,6 +341,29 @@ def run_validation_epoch(
)
print(f"Average Validation Loss: {avg_val_loss}")

# TODO: standardize the calculation of metrics input to run inference only once

for metric in configuration.metrics:
if isinstance(metric, MeanAveragePrecisionMetric):
targets, predictions, _ = get_ground_truths_and_predictions(
dataset=loader.dataset,
processor=processor,
model=model,
device=configuration.device,
)
map_result = metric.compute(targets=targets, predictions=predictions)
for map_key, map_value in map_result.items():
metrics_tracker.register(
metric=map_key,
epoch=epoch_number,
step=1,
value=map_value,
)
print(f"Validation {map_key}: {map_value:.4f}")
else:
# Handle other metric types
pass


def save_model(
target_dir: str,
Expand Down

0 comments on commit 4c3fbd0

Please sign in to comment.