Skip to content

Commit

Permalink
Add support for type-checking TypedDict objects.
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewwardrop committed Oct 21, 2024
1 parent 756bf53 commit 080ae20
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
29 changes: 28 additions & 1 deletion spec_classes/utils/type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,18 @@
TypeVar,
Union,
_GenericAlias,
get_type_hints,
)

# pylint: disable=protected-access
from typing_extensions import Literal as LiteralExtension
from typing_extensions import (
Literal as LiteralExtension,
)
from typing_extensions import (
NotRequired,
Required,
is_typeddict,
)

try:
from typing import Literal
Expand Down Expand Up @@ -48,6 +56,25 @@ def check_type(value: Any, attr_type: Type) -> bool:
if sys.version_info >= (3, 10) and isinstance(attr_type, types.UnionType):
return any(check_type(value, type_) for type_ in attr_type.__args__)

if is_typeddict(attr_type): # we are dealinq with a TypedDict
if not isinstance(value, dict):
return False

Check warning on line 61 in spec_classes/utils/type_checking.py

View check run for this annotation

Codecov / codecov/patch

spec_classes/utils/type_checking.py#L61

Added line #L61 was not covered by tests
keys = set(value.keys())
if attr_type.__required_keys__.difference(keys):
return False
if keys.difference(attr_type.__required_keys__).difference(
attr_type.__optional_keys__
):
return False

Check warning on line 68 in spec_classes/utils/type_checking.py

View check run for this annotation

Codecov / codecov/patch

spec_classes/utils/type_checking.py#L68

Added line #L68 was not covered by tests
annotations = get_type_hints(attr_type)
for key, value in value.items():
subattr_type = annotations[key]
if getattr(subattr_type, "__origin__", None) in (Required, NotRequired):
subattr_type = subattr_type.__args__[0]
if not check_type(value, subattr_type):
return False
return True

if hasattr(attr_type, "__origin__"): # we are dealing with a `typing` object.
if attr_type.__origin__ is Union:
return any(check_type(value, type_) for type_ in attr_type.__args__)
Expand Down
39 changes: 38 additions & 1 deletion tests/utils/test_type_checking.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from typing import Any, Callable, Dict, List, Set, Tuple, Type, TypeVar, Union

from typing_extensions import Literal
from typing_extensions import Literal, NotRequired, Required, TypedDict

from spec_classes import spec_class
from spec_classes.types import KeyedList, KeyedSet
Expand Down Expand Up @@ -130,3 +130,40 @@ def test_type_instantiate(self):
assert type_instantiate(List[str]) == []
assert type_instantiate(dict, a=1) == {"a": 1}
assert type_instantiate(Dict[str, int], a=1) == {"a": 1}

def test_typed_dict(self):
class Movie(TypedDict):
name: str
year: int

assert check_type({"name": "The Matrix", "year": 1999}, Movie)
assert not check_type({"name": "The Matrix"}, Movie)
assert not check_type({"name": "The Matrix", "year": "1999"}, Movie)

class NonTotalMovie(TypedDict, total=False):
name: str
year: int

assert check_type({"name": "The Matrix", "year": 1999}, NonTotalMovie)
assert check_type({"name": "The Matrix"}, NonTotalMovie)

class AnnotatedMovie(TypedDict):
name: Required[str]
year: NotRequired[int]

assert check_type({"name": "The Matrix"}, AnnotatedMovie)
assert not check_type({"year": 1999}, AnnotatedMovie)

class PartiallyAnnotatedMovie(TypedDict):
name: str
year: NotRequired[int]

assert check_type({"name": "The Matrix"}, PartiallyAnnotatedMovie)
assert not check_type({"year": 1999}, PartiallyAnnotatedMovie)

class PartiallyAnnotatedMovie2(TypedDict, total=False):
name: Required[str]
year: int

assert check_type({"name": "The Matrix"}, PartiallyAnnotatedMovie2)
assert not check_type({"year": 1999}, PartiallyAnnotatedMovie2)

0 comments on commit 080ae20

Please sign in to comment.