Skip to content

Commit

Permalink
Fix RGB channel ordering for yolov8 (#836)
Browse files Browse the repository at this point in the history
  • Loading branch information
mayrajeo authored Mar 7, 2023
1 parent 1c2b10f commit 9b5aeb2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
6 changes: 1 addition & 5 deletions sahi/models/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ def set_model(self, model: Any):
A YOLOv8 model
"""

# if model.__class__.__module__ not in ["yolov5.models.common", "models.common"]:
# raise Exception(f"Not a yolov5 model: {type(model)}")

# model.conf = self.confidence_threshold
self.model = model

# set category_mapping
Expand All @@ -62,7 +58,7 @@ def perform_inference(self, image: np.ndarray):
# Confirm model is loaded
if self.model is None:
raise ValueError("Model is not loaded, load it by calling .load_model()")
prediction_result = self.model(image, verbose=False)
prediction_result = self.model(image[:, :, ::-1], verbose=False) # YOLOv8 expects numpy arrays to have BGR
prediction_result = [
result.boxes.boxes[result.boxes.boxes[:, 4] >= self.confidence_threshold] for result in prediction_result
]
Expand Down
38 changes: 21 additions & 17 deletions tests/test_yolov8model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

MODEL_DEVICE = "cpu"
CONFIDENCE_THRESHOLD = 0.3
IMAGE_SIZE = 320
IMAGE_SIZE = 640


class TestYolov8DetectionModel(unittest.TestCase):
Expand All @@ -30,7 +30,6 @@ def test_load_model(self):
self.assertNotEqual(yolov8_detection_model.model, None)

def test_set_model(self):

from ultralytics import YOLO

from sahi.models.yolov8 import Yolov8DetectionModel
Expand Down Expand Up @@ -109,6 +108,10 @@ def test_convert_original_predictions(self):
image_path = "tests/data/small-vehicles1.jpeg"
image = read_image(image_path)

# get raw predictions for reference
original_results = yolov8_detection_model.model.predict(image_path, conf=CONFIDENCE_THRESHOLD)[0].boxes
num_results = len(original_results)

# perform inference
yolov8_detection_model.perform_inference(image)

Expand All @@ -117,21 +120,22 @@ def test_convert_original_predictions(self):
object_prediction_list = yolov8_detection_model.object_prediction_list

# compare
self.assertEqual(len(object_prediction_list), 11)
self.assertEqual(object_prediction_list[0].category.id, 2)
self.assertEqual(object_prediction_list[0].category.name, "car")
desired_bbox = [448, 309, 49, 33]
predicted_bbox = object_prediction_list[0].bbox.to_xywh()
margin = 2
for ind, point in enumerate(predicted_bbox):
assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
self.assertEqual(object_prediction_list[2].category.id, 2)
self.assertEqual(object_prediction_list[2].category.name, "car")
desired_bbox = [835, 307, 37, 37]
predicted_bbox = object_prediction_list[2].bbox.to_xywh()
for ind, point in enumerate(predicted_bbox):
assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin

self.assertEqual(len(object_prediction_list), num_results)

# loop through predictions and check that they are equal
for i in range(num_results):
desired_bbox = [
original_results[i].xyxy[0][0],
original_results[i].xyxy[0][1],
original_results[i].xywh[0][2],
original_results[i].xywh[0][3],
]
desired_cat_id = int(original_results[i].cls[0])
self.assertEqual(object_prediction_list[i].category.id, desired_cat_id)
predicted_bbox = object_prediction_list[i].bbox.to_xywh()
margin = 2
for ind, point in enumerate(predicted_bbox):
assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
for object_prediction in object_prediction_list:
self.assertGreaterEqual(object_prediction.score.value, CONFIDENCE_THRESHOLD)

Expand Down

0 comments on commit 9b5aeb2

Please sign in to comment.