diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 0999bf7ba6b..c539319df19 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -7,7 +7,7 @@ from functools import partial from inspect import signature from types import ModuleType -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union +from typing import Any, Callable, Dict, get_args, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union from torch import nn @@ -168,14 +168,13 @@ def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]: if "weights" not in sig.parameters: raise ValueError("The method is missing the 'weights' argument.") - ann = signature(fn).parameters["weights"].annotation + ann = sig.parameters["weights"].annotation weights_enum = None if isinstance(ann, type) and issubclass(ann, WeightsEnum): weights_enum = ann else: # handle cases like Union[Optional, T] - # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8 - for t in ann.__args__: # type: ignore[union-attr] + for t in get_args(ann): # type: ignore[union-attr] if isinstance(t, type) and issubclass(t, WeightsEnum): weights_enum = t break