diff --git a/.gitignore b/.gitignore index 77b27f5..bc2a60f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,7 @@ statistics_all.png statistics_g.png statistics_zg.png statistics_zp.png -yolov8x.pt +*.pt *.mp4 diff --git a/README.md b/README.md index d26f71f..2a55e7b 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ You may use [YOLO](https://docs.ultralytics.com/) to automatically perform detec Detect objects with Ultralytics YOLO detections, apply SORT tracking and convert tracks to CVAT format. ``` -detector2cvat --video path_to_videos --save path_to_save [--imshow] +detector2cvat --video path_to_videos --save path_to_save [--yolo yolo_model] [--imshow] ``` diff --git a/src/kabr_tools/detector2cvat.py b/src/kabr_tools/detector2cvat.py index 206da44..4973cd5 100644 --- a/src/kabr_tools/detector2cvat.py +++ b/src/kabr_tools/detector2cvat.py @@ -8,13 +8,14 @@ from kabr_tools.utils.draw import Draw -def detector2cvat(path_to_videos: str, path_to_save: str, show: bool) -> None: +def detector2cvat(path_to_videos: str, path_to_save: str, model: str, show: bool) -> None: """ Detect objects with Ultralytics YOLO detections, apply SORT tracking and convert tracks to CVAT format. Parameters: path_to_videos - str. Path to the folder containing videos. path_to_save - str. Path to the folder to save output xml & mp4 files. + model - str. YOLO model to use with detections. show - bool. Flag to display detector's visualization. """ videos = [] @@ -29,7 +30,7 @@ def detector2cvat(path_to_videos: str, path_to_save: str, show: bool) -> None: videos.append(f"{root}/{file}") - yolo = YOLOv8(weights="yolov8x.pt", imgsz=3840, conf=0.5) + yolo = YOLOv8(weights=model, imgsz=3840, conf=0.5) for i, video in enumerate(videos): try: @@ -120,6 +121,12 @@ def parse_args() -> argparse.Namespace: help="path to save output xml & mp4 files", required=True ) + local_parser.add_argument( + "--yolo", + type=str, + default="yolov8x.pt", + help="yolo model to use with detections" + ) local_parser.add_argument( "--imshow", action="store_true", @@ -130,7 +137,7 @@ def parse_args() -> argparse.Namespace: def main() -> None: args = parse_args() - detector2cvat(args.video, args.save, args.imshow) + detector2cvat(args.video, args.save, args.yolo, args.imshow) if __name__ == "__main__": diff --git a/tests/test_detector2cvat.py b/tests/test_detector2cvat.py index 675e7bc..543a290 100644 --- a/tests/test_detector2cvat.py +++ b/tests/test_detector2cvat.py @@ -1,6 +1,7 @@ import unittest import sys import os +from unittest.mock import patch from kabr_tools import detector2cvat from tests.utils import ( del_dir, @@ -33,6 +34,7 @@ def setUp(self): self.tool = "detector2cvat.py" self.video = TestDetector2Cvat.dir self.save = "tests/detector2cvat" + self.yolo = "yolov5s.pt" def tearDown(self): # delete outputs @@ -55,17 +57,26 @@ def test_parse_arg_min(self): # check parsed argument values self.assertEqual(args.video, self.video) self.assertEqual(args.save, self.save) - self.assertEqual(args.imshow, False) - def test_parse_arg_full(self): + # check default argument values + self.assertEqual(args.yolo, "yolov8x.pt") + self.assertEqual(args.imshow, False) + + @patch('kabr_tools.detector2cvat.cv2.imshow') + def test_parse_arg_full(self, imshow): # parse arguments sys.argv = [self.tool, "--video", self.video, "--save", self.save, + "--yolo", self.yolo, "--imshow"] args = detector2cvat.parse_args() # check parsed argument values self.assertEqual(args.video, self.video) self.assertEqual(args.save, self.save) + self.assertEqual(args.yolo, self.yolo) self.assertEqual(args.imshow, True) + + # run + run()