add Florence-2 evaluation command #38

Merged · 3 commits · Sep 11, 2024
42 changes: 34 additions & 8 deletions README.md
@@ -1,4 +1,3 @@

<div align="center">

<h1>maestro</h1>
@@ -23,13 +22,40 @@ Pip install the supervision package in a
pip install maestro
```

## 🚀 example
## 🔥 quickstart

### CLI

VLMs can be fine-tuned on downstream tasks directly from the command line with the
`maestro` command:

```bash
maestro florence2 train --dataset='<DATASET_PATH>' --epochs=10 --batch-size=8
```
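
This PR also introduces an `evaluate` subcommand (see `entrypoint.py` further down in the diff). Assuming it is exposed on the CLI the same way as `train`, an evaluation run would look roughly like this, with placeholder values:

```bash
maestro florence2 evaluate --dataset='<DATASET_PATH>' --metrics='mean_average_precision'
```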

Documentation and Florence-2 fine-tuning examples for object detection and VQA coming
soon.
### SDK

Alternatively, you can fine-tune VLMs using the Python SDK, which accepts the same
arguments as the CLI example above:

```python
from maestro.trainer.common import MeanAveragePrecisionMetric
from maestro.trainer.models.florence_2 import train, TrainingConfiguration

config = TrainingConfiguration(
dataset='<DATASET_PATH>',
epochs=10,
batch_size=8,
metrics=[MeanAveragePrecisionMetric()]
)

train(config)
```
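
The `evaluate` function added in `core.py` by this PR should be callable in the same style. A sketch, assuming `evaluate` is imported from the `core` module and that only the evaluation-relevant fields need to be set:

```python
from maestro.trainer.common import MeanAveragePrecisionMetric
from maestro.trainer.models.florence_2 import TrainingConfiguration
from maestro.trainer.models.florence_2.core import evaluate

# only the fields used during evaluation are set here; the remaining
# TrainingConfiguration fields keep their defaults
config = TrainingConfiguration(
    dataset='<DATASET_PATH>',
    batch_size=8,
    metrics=[MeanAveragePrecisionMetric()]
)

evaluate(config)
```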

## 🚧 roadmap
## 🦸 contribution

- [ ] Release a CLI for predefined fine-tuning recipes.
- [ ] Multi-GPU fine-tuning support.
- [ ] Allow multi-dataset fine-tuning and support multiple tasks at the same time.
We would love your help in making this repository even better! We are especially
looking for contributors with experience in fine-tuning vision-language models (VLMs).
If you notice any bugs or have suggestions for improvement, feel free to open an
[issue](https://github.com/roboflow/multimodal-maestro/issues) or submit a
[pull request](https://github.com/roboflow/multimodal-maestro/pulls).
1 change: 1 addition & 0 deletions maestro/trainer/common/__init__.py
@@ -0,0 +1 @@
from maestro.trainer.common.utils.metrics import MeanAveragePrecisionMetric
41 changes: 41 additions & 0 deletions maestro/trainer/common/utils/metrics.py
@@ -10,7 +10,9 @@
from typing import Any, Dict, List, Tuple

import matplotlib.pyplot as plt
import supervision as sv
from PIL import Image
from supervision.metrics.mean_average_precision import MeanAveragePrecision


class BaseMetric(ABC):
@@ -45,6 +47,45 @@ def compute(self, targets: List[Any], predictions: List[Any]) -> Dict[str, float
pass


class MeanAveragePrecisionMetric(BaseMetric):
"""
A class used to compute the Mean Average Precision (mAP) metric.
"""

def describe(self) -> List[str]:
"""
Returns a list of metric names that this class will compute.

Returns:
List[str]: A list of metric names.
"""
return ["map50:95", "map50", "map75"]

def compute(
self,
targets: List[sv.Detections],
predictions: List[sv.Detections]
) -> Dict[str, float]:
"""
Computes the mAP metrics based on the targets and predictions.

Args:
targets (List[sv.Detections]): The ground truth detections.
predictions (List[sv.Detections]): The predicted detections.

Returns:
Dict[str, float]: A dictionary of computed mAP metrics with metric names as
keys and their values.
"""
result = MeanAveragePrecision().update(
targets=targets, predictions=predictions).compute()
return {
"map50:95": result.map50_95,
"map50": result.map50,
"map75": result.map75
}


class MetricsTracker:

@classmethod
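
A minimal, self-contained sketch of the new metric in isolation; the boxes below are hypothetical:

```python
import numpy as np
import supervision as sv

from maestro.trainer.common import MeanAveragePrecisionMetric

# one hypothetical ground-truth box and a slightly shifted prediction, both class 0
targets = [
    sv.Detections(xyxy=np.array([[10.0, 10.0, 50.0, 50.0]]), class_id=np.array([0]))
]
predictions = [
    sv.Detections(
        xyxy=np.array([[12.0, 11.0, 49.0, 52.0]]),
        class_id=np.array([0]),
        confidence=np.array([0.9]),
    )
]

metric = MeanAveragePrecisionMetric()
print(metric.describe())  # ['map50:95', 'map50', 'map75']
print(metric.compute(targets=targets, predictions=predictions))
```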
1 change: 0 additions & 1 deletion maestro/trainer/models/florence_2/__init__.py
@@ -1,2 +1 @@
from maestro.trainer.models.florence_2.core import TrainingConfiguration, train
from maestro.trainer.models.florence_2.metrics import MeanAveragePrecisionMetric
14 changes: 14 additions & 0 deletions maestro/trainer/models/florence_2/checkpoints.py
@@ -91,6 +91,20 @@ def load_model(
device: torch.device = DEVICE,
cache_dir: Optional[str] = None,
) -> Tuple[AutoProcessor, AutoModelForCausalLM]:
"""Loads a Florence-2 model and its associated processor.

Args:
model_id_or_path: The identifier or path of the model to load.
revision: The specific model revision to use.
device: The device to load the model onto.
cache_dir: Directory to cache the downloaded model files.

Returns:
A tuple containing the loaded processor and model.

Raises:
ValueError: If the model or processor cannot be loaded.
"""
processor = AutoProcessor.from_pretrained(
model_id_or_path,
trust_remote_code=True,
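
A hedged usage sketch of the loader documented above, assuming the defaults exported from `checkpoints.py`:

```python
from maestro.trainer.models.florence_2.checkpoints import (
    DEFAULT_FLORENCE2_MODEL_ID,
    DEFAULT_FLORENCE2_MODEL_REVISION,
    load_model,
)

# downloads the checkpoint (or reads it from the cache) and returns the
# processor together with the model placed on the default device
processor, model = load_model(
    model_id_or_path=DEFAULT_FLORENCE2_MODEL_ID,
    revision=DEFAULT_FLORENCE2_MODEL_REVISION,
)
```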
77 changes: 72 additions & 5 deletions maestro/trainer/models/florence_2/core.py
@@ -11,14 +11,23 @@
from transformers import AutoModelForCausalLM, AutoProcessor, get_scheduler

from maestro.trainer.common.utils.file_system import create_new_run_directory
from maestro.trainer.common.utils.metrics import BaseMetric, MetricsTracker, \
display_results, save_metric_plots
from maestro.trainer.common.utils.metrics import (
BaseMetric,
MetricsTracker,
display_results,
save_metric_plots,
MeanAveragePrecisionMetric
)
from maestro.trainer.common.utils.reproducibility import make_it_reproducible
from maestro.trainer.models.florence_2.checkpoints import CheckpointManager, load_model, \
DEFAULT_FLORENCE2_MODEL_ID, DEFAULT_FLORENCE2_MODEL_REVISION, DEVICE
from maestro.trainer.models.florence_2.checkpoints import (
CheckpointManager,
load_model,
DEFAULT_FLORENCE2_MODEL_ID,
DEFAULT_FLORENCE2_MODEL_REVISION,
DEVICE
)
from maestro.trainer.models.florence_2.data_loading import prepare_data_loaders
from maestro.trainer.models.florence_2.metrics import (
MeanAveragePrecisionMetric,
extract_unique_detection_dataset_classes,
postprocess_florence2_output_for_mean_average_precision,
run_predictions,
@@ -144,6 +153,10 @@ def train(config: TrainingConfiguration) -> None:
validation_metrics_tracker.as_json(
output_dir=os.path.join(config.output_dir, "metrics"),
filename="validation.json")

# Log out paths for latest and best checkpoints
print(f"Latest checkpoint saved at: {checkpoint_manager.latest_checkpoint_dir}")
print(f"Best checkpoint saved at: {checkpoint_manager.best_checkpoint_dir}")


def prepare_peft_model(
@@ -354,3 +367,57 @@ def get_optimizer(model: PeftModel, config: TrainingConfiguration) -> Optimizer:
if optimizer_type == "sgd":
return SGD(model.parameters(), lr=config.lr)
raise ValueError(f"Unsupported optimizer: {config.optimizer}")


def evaluate(config: TrainingConfiguration) -> None:
processor, model = load_model(
model_id_or_path=config.model_id,
revision=config.revision,
device=config.device,
cache_dir=config.cache_dir,
)
train_loader, val_loader, test_loader = prepare_data_loaders(
dataset_location=config.dataset,
train_batch_size=config.batch_size,
processor=processor,
device=config.device,
num_workers=config.num_workers,
test_loaders_workers=config.val_num_workers,
)
evaluation_loader = test_loader if test_loader is not None else val_loader

metrics = []
for metric in config.metrics:
metrics += metric.describe()
evaluation_metrics_tracker = MetricsTracker.init(metrics=metrics)

# Run inference once for all metrics
_, expected_responses, generated_texts, images = run_predictions(
dataset=evaluation_loader.dataset,
processor=processor,
model=model,
device=config.device,
)

for metric in config.metrics:
if isinstance(metric, MeanAveragePrecisionMetric):
classes = extract_unique_detection_dataset_classes(train_loader.dataset)
targets, predictions = postprocess_florence2_output_for_mean_average_precision(
expected_responses=expected_responses,
generated_texts=generated_texts,
images=images,
classes=classes,
processor=processor
)
result = metric.compute(targets=targets, predictions=predictions)
for key, value in result.items():
evaluation_metrics_tracker.register(
metric=key,
epoch=1,
step=1,
value=value,
)

evaluation_metrics_tracker.as_json(
output_dir=os.path.join(config.output_dir, "metrics"),
filename="evaluation.json")
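
The tracked values end up in `<output_dir>/metrics/evaluation.json`. A quick way to inspect them, assuming the CLI default `output_dir` is used verbatim:

```python
import json

with open("./evaluation/florence-2/metrics/evaluation.json") as f:
    print(json.dumps(json.load(f), indent=2))
```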
94 changes: 87 additions & 7 deletions maestro/trainer/models/florence_2/entrypoint.py
@@ -1,5 +1,5 @@
import dataclasses
from typing import Optional, Annotated
from typing import Optional, Annotated, List, Dict, Type

import rich
import torch
@@ -8,11 +8,29 @@
from maestro.trainer.models.florence_2.checkpoints import DEFAULT_FLORENCE2_MODEL_ID, \
DEFAULT_FLORENCE2_MODEL_REVISION, DEVICE
from maestro.trainer.models.florence_2.core import TrainingConfiguration
from maestro.trainer.models.florence_2.core import train as train_fun
from maestro.trainer.models.florence_2.core import train as florence2_train
from maestro.trainer.models.florence_2.core import evaluate as florence2_evaluate
from maestro.trainer.common.utils.metrics import BaseMetric, MeanAveragePrecisionMetric

florence_2_app = typer.Typer(help="Fine-tune and evaluate Florence 2 model")


METRIC_CLASSES: Dict[str, Type[BaseMetric]] = {
"mean_average_precision": MeanAveragePrecisionMetric,
}


def parse_metrics(metrics: List[str]) -> List[BaseMetric]:
metric_objects = []
for metric_name in metrics:
metric_class = METRIC_CLASSES.get(metric_name.lower())
if metric_class:
metric_objects.append(metric_class())
else:
raise ValueError(f"Unsupported metric: {metric_name}")
return metric_objects
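
For example, given the `METRIC_CLASSES` mapping above (names are matched case-insensitively):

```python
parse_metrics(["mean_average_precision"])  # -> list with one MeanAveragePrecisionMetric instance
parse_metrics(["Mean_Average_Precision"])  # same result, thanks to .lower()
parse_metrics(["iou"])                     # raises ValueError: Unsupported metric: iou
```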


@florence_2_app.command(
help="Train Florence 2 model",
context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
@@ -98,7 +116,12 @@ def train(
str,
typer.Option("--output_dir", help="Directory to save output files"),
] = "./training/florence-2",
metrics: Annotated[
List[str],
typer.Option("--metrics", help="List of metrics to track during training"),
] = [],
) -> None:
metric_objects = parse_metrics(metrics)
config = TrainingConfiguration(
dataset=dataset,
model_id=model_id,
@@ -119,21 +142,78 @@ def train(
bias=bias,
use_rslora=use_rslora,
init_lora_weights=init_lora_weights,
output_dir=output_dir
output_dir=output_dir,
metrics=metric_objects
)
typer.echo(typer.style(
text="Training configuration",
fg=typer.colors.BRIGHT_GREEN,
bold=True
))
rich.print(dataclasses.asdict(config))
train_fun(config=config)
florence2_train(config=config)


@florence_2_app.command(help="Evaluate Florence 2 model")
def evaluate() -> None:
def evaluate(
dataset: Annotated[
str,
typer.Option("--dataset", help="Path to the dataset used for evaluation"),
],
model_id: Annotated[
str,
typer.Option("--model_id", help="Identifier for the Florence-2 model"),
] = DEFAULT_FLORENCE2_MODEL_ID,
revision: Annotated[
str,
typer.Option("--revision", help="Revision of the model to use"),
] = DEFAULT_FLORENCE2_MODEL_REVISION,
device: Annotated[
str,
typer.Option("--device", help="Device to use for evaluation"),
] = DEVICE,
cache_dir: Annotated[
Optional[str],
typer.Option("--cache_dir", help="Directory to cache the model"),
] = None,
batch_size: Annotated[
int,
typer.Option("--batch_size", help="Batch size for evaluation"),
] = 4,
num_workers: Annotated[
int,
typer.Option("--num_workers", help="Number of workers for data loading"),
] = 0,
val_num_workers: Annotated[
Optional[int],
typer.Option("--val_num_workers", help="Number of workers for validation data loading"),
] = None,
output_dir: Annotated[
str,
typer.Option("--output_dir", help="Directory to save output files"),
] = "./evaluation/florence-2",
metrics: Annotated[
List[str],
typer.Option("--metrics", help="List of metrics to track during evaluation"),
] = [],
) -> None:
metric_objects = parse_metrics(metrics)
config = TrainingConfiguration(
dataset=dataset,
model_id=model_id,
revision=revision,
device=torch.device(device),
cache_dir=cache_dir,
batch_size=batch_size,
num_workers=num_workers,
val_num_workers=val_num_workers,
output_dir=output_dir,
metrics=metric_objects
)
typer.echo(typer.style(
"Evaluation command for Florence 2 is not yet implemented.",
fg=typer.colors.YELLOW,
text="Evaluation configuration",
fg=typer.colors.BRIGHT_GREEN,
bold=True
))
rich.print(dataclasses.asdict(config))
florence2_evaluate(config=config)