From 873f3f831e5f5f458ae0ce3d153cd447bd0424d5 Mon Sep 17 00:00:00 2001 From: zhong-al <74470739+zhong-al@users.noreply.github.com> Date: Fri, 13 Dec 2024 21:36:36 -0500 Subject: [PATCH 1/5] Add yolo model option --- src/kabr_tools/detector2cvat.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/kabr_tools/detector2cvat.py b/src/kabr_tools/detector2cvat.py index 206da44..4973cd5 100644 --- a/src/kabr_tools/detector2cvat.py +++ b/src/kabr_tools/detector2cvat.py @@ -8,13 +8,14 @@ from kabr_tools.utils.draw import Draw -def detector2cvat(path_to_videos: str, path_to_save: str, show: bool) -> None: +def detector2cvat(path_to_videos: str, path_to_save: str, model: str, show: bool) -> None: """ Detect objects with Ultralytics YOLO detections, apply SORT tracking and convert tracks to CVAT format. Parameters: path_to_videos - str. Path to the folder containing videos. path_to_save - str. Path to the folder to save output xml & mp4 files. + model - str. YOLO model to use with detections. show - bool. Flag to display detector's visualization. """ videos = [] @@ -29,7 +30,7 @@ def detector2cvat(path_to_videos: str, path_to_save: str, show: bool) -> None: videos.append(f"{root}/{file}") - yolo = YOLOv8(weights="yolov8x.pt", imgsz=3840, conf=0.5) + yolo = YOLOv8(weights=model, imgsz=3840, conf=0.5) for i, video in enumerate(videos): try: @@ -120,6 +121,12 @@ def parse_args() -> argparse.Namespace: help="path to save output xml & mp4 files", required=True ) + local_parser.add_argument( + "--yolo", + type=str, + default="yolov8x.pt", + help="yolo model to use with detections" + ) local_parser.add_argument( "--imshow", action="store_true", @@ -130,7 +137,7 @@ def parse_args() -> argparse.Namespace: def main() -> None: args = parse_args() - detector2cvat(args.video, args.save, args.imshow) + detector2cvat(args.video, args.save, args.yolo, args.imshow) if __name__ == "__main__": From 7d4361c41dfd23d04d29650624f8e7610f67d98e Mon Sep 17 00:00:00 2001 From: zhong-al <74470739+zhong-al@users.noreply.github.com> Date: Fri, 13 Dec 2024 21:37:17 -0500 Subject: [PATCH 2/5] Update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d26f71f..2a55e7b 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ You may use [YOLO](https://docs.ultralytics.com/) to automatically perform detec Detect objects with Ultralytics YOLO detections, apply SORT tracking and convert tracks to CVAT format. ``` -detector2cvat --video path_to_videos --save path_to_save [--imshow] +detector2cvat --video path_to_videos --save path_to_save [--yolo yolo_model] [--imshow] ``` From c666e5366eaba87de6a05ac7c1c0371a0b47effa Mon Sep 17 00:00:00 2001 From: zhong-al <74470739+zhong-al@users.noreply.github.com> Date: Fri, 13 Dec 2024 21:45:48 -0500 Subject: [PATCH 3/5] Add test --- tests/test_detector2cvat.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_detector2cvat.py b/tests/test_detector2cvat.py index 675e7bc..79f0e1c 100644 --- a/tests/test_detector2cvat.py +++ b/tests/test_detector2cvat.py @@ -33,6 +33,7 @@ def setUp(self): self.tool = "detector2cvat.py" self.video = TestDetector2Cvat.dir self.save = "tests/detector2cvat" + self.yolo = "yolov5s.pt" def tearDown(self): # delete outputs @@ -55,6 +56,9 @@ def test_parse_arg_min(self): # check parsed argument values self.assertEqual(args.video, self.video) self.assertEqual(args.save, self.save) + + # check default argument values + self.assertEqual(args.yolo, "yolov8x.pt") self.assertEqual(args.imshow, False) def test_parse_arg_full(self): @@ -62,10 +66,15 @@ def test_parse_arg_full(self): sys.argv = [self.tool, "--video", self.video, "--save", self.save, + "--yolo", self.yolo, "--imshow"] args = detector2cvat.parse_args() # check parsed argument values self.assertEqual(args.video, self.video) self.assertEqual(args.save, self.save) + self.assertEqual(args.yolo, self.yolo) self.assertEqual(args.imshow, True) + + # run + run() From 40eb0908f0cb71e9e78e4e07bf048f9a212d0b73 Mon Sep 17 00:00:00 2001 From: zhong-al <74470739+zhong-al@users.noreply.github.com> Date: Fri, 13 Dec 2024 21:56:50 -0500 Subject: [PATCH 4/5] Patch imshow --- tests/test_detector2cvat.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_detector2cvat.py b/tests/test_detector2cvat.py index 79f0e1c..543a290 100644 --- a/tests/test_detector2cvat.py +++ b/tests/test_detector2cvat.py @@ -1,6 +1,7 @@ import unittest import sys import os +from unittest.mock import patch from kabr_tools import detector2cvat from tests.utils import ( del_dir, @@ -60,8 +61,9 @@ def test_parse_arg_min(self): # check default argument values self.assertEqual(args.yolo, "yolov8x.pt") self.assertEqual(args.imshow, False) - - def test_parse_arg_full(self): + + @patch('kabr_tools.detector2cvat.cv2.imshow') + def test_parse_arg_full(self, imshow): # parse arguments sys.argv = [self.tool, "--video", self.video, From f5d6eea2e68428fe8fd555646fd12102ba0e45b0 Mon Sep 17 00:00:00 2001 From: zhong-al <74470739+zhong-al@users.noreply.github.com> Date: Fri, 13 Dec 2024 21:57:40 -0500 Subject: [PATCH 5/5] Ignore all pt files --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 77b27f5..bc2a60f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,7 @@ statistics_all.png statistics_g.png statistics_zg.png statistics_zp.png -yolov8x.pt +*.pt *.mp4