Skip to content

Commit

Permalink
fix detectron2 device (#763)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcakyon authored Dec 1, 2022
1 parent 091b38a commit aacf865
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion sahi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.11.5"
__version__ = "0.11.6"

from sahi.annotation import BoundingBox, Category, Mask
from sahi.auto_model import AutoDetectionModel
Expand Down
2 changes: 1 addition & 1 deletion sahi/models/detectron2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def load_model(self):
cfg.MODEL.WEIGHTS = self.model_path

# set model device
cfg.MODEL.DEVICE = self.device
cfg.MODEL.DEVICE = self.device.type
# set input image size
if self.image_size is not None:
cfg.INPUT.MIN_SIZE_TEST = self.image_size
Expand Down
4 changes: 2 additions & 2 deletions sahi/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ def check_package_minimum_version(package_name: str, minimum_version: str):
return True


def ensure_package_minimum_version(package_name: str, minimum_version: str):
def ensure_package_minimum_version(package_name: str, minimum_version: str, verbose=False):
"""
Raise error if module version is not compatible.
"""
from packaging import version

_is_available, _version = get_package_info(package_name)
_is_available, _version = get_package_info(package_name, verbose=verbose)
if _is_available:
if _version == "unknown":
logger.warning(
Expand Down
11 changes: 6 additions & 5 deletions tests/test_detectron2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

# note that detectron2 binaries are available only for linux

if get_package_info("torch", verbose=False)[1] == "1.10.2":
torch_version = get_package_info("torch", verbose=False)[1]
if "1.10." in torch_version:

class TestDetectron2DetectionModel(unittest.TestCase):
def test_load_model(self):
Expand Down Expand Up @@ -145,7 +146,7 @@ def test_convert_original_predictions_with_mask_output(self):
raise AssertionError(f"desired_bbox: {desired_bbox}, predicted_bbox: {predicted_bbox}")

def test_get_prediction_detectron2(self):
from sahi.model import Detectron2DetectionModel
from sahi.models.detectron2 import Detectron2DetectionModel
from sahi.predict import get_prediction
from sahi.utils.detectron2 import Detectron2TestConstants

Expand Down Expand Up @@ -194,7 +195,7 @@ def test_get_prediction_detectron2(self):
self.assertEqual(num_car, 16)

def test_get_sliced_prediction_detectron2(self):
from sahi.model import Detectron2DetectionModel
from sahi.models.detectron2 import Detectron2DetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.detectron2 import Detectron2TestConstants

Expand Down Expand Up @@ -239,7 +240,7 @@ def test_get_sliced_prediction_detectron2(self):
object_prediction_list = prediction_result.object_prediction_list

# compare
self.assertEqual(len(object_prediction_list), 18)
self.assertEqual(len(object_prediction_list), 19)
num_person = 0
for object_prediction in object_prediction_list:
if object_prediction.category.name == "person":
Expand All @@ -254,7 +255,7 @@ def test_get_sliced_prediction_detectron2(self):
for object_prediction in object_prediction_list:
if object_prediction.category.name == "car":
num_car += 1
self.assertEqual(num_car, 18)
self.assertEqual(num_car, 19)


if __name__ == "__main__":
Expand Down

0 comments on commit aacf865

Please sign in to comment.