diff --git a/src/cattrs/v/__init__.py b/src/cattrs/v/__init__.py index 024d4c5c..d39def5c 100644 --- a/src/cattrs/v/__init__.py +++ b/src/cattrs/v/__init__.py @@ -1,5 +1,5 @@ """Cattrs validation.""" -from typing import Any, Callable, List, Tuple, Type, TypeVar, Union, overload +from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload from attrs import NOTHING, frozen @@ -152,6 +152,7 @@ def transform_error( T = TypeVar("T") E = TypeVar("E") +TV = TypeVar("TV") @overload @@ -161,15 +162,30 @@ def ensure( ... +@overload +def ensure( + type: Type[Dict], + *validators: Callable[[Dict], Any], + keys: Type[E], + values: Type[TV], +) -> Type[Dict[E, TV]]: + ... + + @overload def ensure(type: Type[T], *validators: Callable[[T], Any]) -> Type[T]: ... -def ensure(type: Any, *validators: Any, elems: Any = NOTHING) -> Any: +def ensure(type, *validators, elems=NOTHING, keys=NOTHING, values=NOTHING): + """Ensure validators run when structuring the given type.""" if elems is not NOTHING: # These are lists. if not validators: return type[elems] - return Annotated[type, VAnnotation(*validators)] + return Annotated[type[elems], VAnnotation(*validators)] + if keys is not NOTHING or values is not NOTHING: + if not validators: + return type[keys, values] + return Annotated[type[keys, values], VAnnotation(*validators)] return Annotated[type, VAnnotation(*validators)] diff --git a/tests/v/test_ensure.py b/tests/v/test_ensure.py index 1c53b0cf..eff81674 100644 --- a/tests/v/test_ensure.py +++ b/tests/v/test_ensure.py @@ -1,18 +1,19 @@ """Tests for `cattrs.v.ensure`.""" import sys -from typing import List, MutableSequence, Sequence +from typing import Dict, List, MutableSequence, Sequence from pytest import fixture, mark, raises from cattrs import BaseConverter from cattrs._compat import ExceptionGroup +from cattrs.errors import IterableValidationError from cattrs.v import ensure from cattrs.v._hooks import is_validated, validator_factory @fixture def valconv(converter) -> BaseConverter: - converter.register_structure_hook_factory(is_validated)(validator_factory) + converter.register_structure_hook_factory(is_validated, validator_factory) return converter @@ -38,7 +39,7 @@ def test_ensured_lists(valconv: BaseConverter): valconv.structure([], ensure(List[int], lambda lst: len(lst) > 0)) if valconv.detailed_validation: - assert isinstance(exc.value, ExceptionGroup) + assert isinstance(exc.value, IterableValidationError) assert isinstance(exc.value.exceptions[0], ValueError) else: assert isinstance(exc.value, ValueError) @@ -53,7 +54,53 @@ def test_ensured_list_elements(valconv: BaseConverter, type): valconv.structure([1, -2], ensure(type, elems=ensure(int, lambda i: i > 0))) if valconv.detailed_validation: - assert isinstance(exc.value, ExceptionGroup) + assert isinstance(exc.value, IterableValidationError) + assert isinstance(exc.value.exceptions[0], ExceptionGroup) + assert isinstance(exc.value.exceptions[0].exceptions[0], ValueError) + else: + assert isinstance(exc.value, ValueError) + + # Now both elements and the list itself. + assert valconv.structure( + [1, 2], + ensure(type, lambda lst: len(lst) < 3, elems=ensure(int, lambda i: i > 0)), + ) + + with raises(Exception) as exc: + valconv.structure( + [1, 2, 3], + ensure(type, lambda lst: len(lst) < 3, elems=ensure(int, lambda i: i > 0)), + ) + + if valconv.detailed_validation: + assert isinstance(exc.value, IterableValidationError) + assert isinstance(exc.value.exceptions[0], ValueError) + else: + assert isinstance(exc.value, ValueError) + + with raises(Exception) as exc: + valconv.structure( + [1, -2], + ensure(type, lambda lst: len(lst) < 3, elems=ensure(int, lambda i: i > 0)), + ) + + if valconv.detailed_validation: + assert isinstance(exc.value, IterableValidationError) + assert isinstance(exc.value.exceptions[0], ExceptionGroup) + assert isinstance(exc.value.exceptions[0].exceptions[0], ValueError) + else: + assert isinstance(exc.value, ValueError) + + +def test_ensured_typing_list(valconv: BaseConverter): + """Ensure works for typing lists.""" + assert valconv.structure([1, 2], ensure(List, elems=ensure(int, lambda i: i > 0))) + + with raises(Exception) as exc: + valconv.structure([1, -2], ensure(List, elems=ensure(int, lambda i: i > 0))) + + if valconv.detailed_validation: + assert isinstance(exc.value, IterableValidationError) assert isinstance(exc.value.exceptions[0], ExceptionGroup) assert isinstance(exc.value.exceptions[0].exceptions[0], ValueError) else: @@ -63,13 +110,42 @@ def test_ensured_list_elements(valconv: BaseConverter, type): @mark.skipif(sys.version_info[:2] < (3, 10), reason="Not supported on older Pythons") def test_ensured_list(valconv: BaseConverter): """Ensure works for builtin lists.""" - assert valconv.structure([1, 2], ensure(List, elems=ensure(int, lambda i: i > 0))) + assert valconv.structure([1, 2], ensure(list, elems=ensure(int, lambda i: i > 0))) with raises(Exception) as exc: - valconv.structure([1, -2], ensure(List, elems=ensure(int, lambda i: i > 0))) + valconv.structure([1, -2], ensure(list, elems=ensure(int, lambda i: i > 0))) if valconv.detailed_validation: - assert isinstance(exc.value, ExceptionGroup) + assert isinstance(exc.value, IterableValidationError) + assert isinstance(exc.value.exceptions[0], ExceptionGroup) + assert isinstance(exc.value.exceptions[0].exceptions[0], ValueError) + else: + assert isinstance(exc.value, ValueError) + + +def test_ensured_typing_dict(valconv: BaseConverter): + """Ensure works for typing.Dicts.""" + assert valconv.structure( + {"a": 1}, ensure(Dict, lambda d: len(d) > 0, keys=str, values=int) + ) + + with raises(Exception) as exc: + valconv.structure({}, ensure(Dict, lambda d: len(d) > 0, keys=str, values=int)) + + if valconv.detailed_validation: + assert isinstance(exc.value, IterableValidationError) + assert isinstance(exc.value.exceptions[0], ValueError) + else: + assert isinstance(exc.value, ValueError) + + with raises(Exception) as exc: + valconv.structure( + {"b": 1, "c": "a"}, + ensure(Dict, keys=ensure(str, lambda s: s.startswith("a")), values=int), + ) + + if valconv.detailed_validation: + assert isinstance(exc.value, IterableValidationError) assert isinstance(exc.value.exceptions[0], ExceptionGroup) assert isinstance(exc.value.exceptions[0].exceptions[0], ValueError) else: