Skip to content

Commit

Permalink
Work on dicts
Browse files Browse the repository at this point in the history
  • Loading branch information
Tinche committed Feb 17, 2024
1 parent b13dae8 commit 7bcd631
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 10 deletions.
22 changes: 19 additions & 3 deletions src/cattrs/v/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Cattrs validation."""
from typing import Any, Callable, List, Tuple, Type, TypeVar, Union, overload
from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload

from attrs import NOTHING, frozen

Expand Down Expand Up @@ -152,6 +152,7 @@ def transform_error(

T = TypeVar("T")
E = TypeVar("E")
TV = TypeVar("TV")


@overload
Expand All @@ -161,15 +162,30 @@ def ensure(
...


@overload
def ensure(
type: Type[Dict],
*validators: Callable[[Dict], Any],
keys: Type[E],
values: Type[TV],
) -> Type[Dict[E, TV]]:
...


@overload
def ensure(type: Type[T], *validators: Callable[[T], Any]) -> Type[T]:
...


def ensure(type: Any, *validators: Any, elems: Any = NOTHING) -> Any:
def ensure(type, *validators, elems=NOTHING, keys=NOTHING, values=NOTHING):
"""Ensure validators run when structuring the given type."""
if elems is not NOTHING:
# These are lists.
if not validators:
return type[elems]
return Annotated[type, VAnnotation(*validators)]
return Annotated[type[elems], VAnnotation(*validators)]
if keys is not NOTHING or values is not NOTHING:
if not validators:
return type[keys, values]
return Annotated[type[keys, values], VAnnotation(*validators)]
return Annotated[type, VAnnotation(*validators)]
90 changes: 83 additions & 7 deletions tests/v/test_ensure.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
"""Tests for `cattrs.v.ensure`."""
import sys
from typing import List, MutableSequence, Sequence
from typing import Dict, List, MutableSequence, Sequence

from pytest import fixture, mark, raises

from cattrs import BaseConverter
from cattrs._compat import ExceptionGroup
from cattrs.errors import IterableValidationError
from cattrs.v import ensure
from cattrs.v._hooks import is_validated, validator_factory


@fixture
def valconv(converter) -> BaseConverter:
converter.register_structure_hook_factory(is_validated)(validator_factory)
converter.register_structure_hook_factory(is_validated, validator_factory)
return converter


Expand All @@ -38,7 +39,7 @@ def test_ensured_lists(valconv: BaseConverter):
valconv.structure([], ensure(List[int], lambda lst: len(lst) > 0))

if valconv.detailed_validation:
assert isinstance(exc.value, ExceptionGroup)
assert isinstance(exc.value, IterableValidationError)
assert isinstance(exc.value.exceptions[0], ValueError)
else:
assert isinstance(exc.value, ValueError)
Expand All @@ -53,7 +54,53 @@ def test_ensured_list_elements(valconv: BaseConverter, type):
valconv.structure([1, -2], ensure(type, elems=ensure(int, lambda i: i > 0)))

if valconv.detailed_validation:
assert isinstance(exc.value, ExceptionGroup)
assert isinstance(exc.value, IterableValidationError)
assert isinstance(exc.value.exceptions[0], ExceptionGroup)
assert isinstance(exc.value.exceptions[0].exceptions[0], ValueError)
else:
assert isinstance(exc.value, ValueError)

# Now both elements and the list itself.
assert valconv.structure(
[1, 2],
ensure(type, lambda lst: len(lst) < 3, elems=ensure(int, lambda i: i > 0)),
)

with raises(Exception) as exc:
valconv.structure(
[1, 2, 3],
ensure(type, lambda lst: len(lst) < 3, elems=ensure(int, lambda i: i > 0)),
)

if valconv.detailed_validation:
assert isinstance(exc.value, IterableValidationError)
assert isinstance(exc.value.exceptions[0], ValueError)
else:
assert isinstance(exc.value, ValueError)

with raises(Exception) as exc:
valconv.structure(
[1, -2],
ensure(type, lambda lst: len(lst) < 3, elems=ensure(int, lambda i: i > 0)),
)

if valconv.detailed_validation:
assert isinstance(exc.value, IterableValidationError)
assert isinstance(exc.value.exceptions[0], ExceptionGroup)
assert isinstance(exc.value.exceptions[0].exceptions[0], ValueError)
else:
assert isinstance(exc.value, ValueError)


def test_ensured_typing_list(valconv: BaseConverter):
"""Ensure works for typing lists."""
assert valconv.structure([1, 2], ensure(List, elems=ensure(int, lambda i: i > 0)))

with raises(Exception) as exc:
valconv.structure([1, -2], ensure(List, elems=ensure(int, lambda i: i > 0)))

if valconv.detailed_validation:
assert isinstance(exc.value, IterableValidationError)
assert isinstance(exc.value.exceptions[0], ExceptionGroup)
assert isinstance(exc.value.exceptions[0].exceptions[0], ValueError)
else:
Expand All @@ -63,13 +110,42 @@ def test_ensured_list_elements(valconv: BaseConverter, type):
@mark.skipif(sys.version_info[:2] < (3, 10), reason="Not supported on older Pythons")
def test_ensured_list(valconv: BaseConverter):
"""Ensure works for builtin lists."""
assert valconv.structure([1, 2], ensure(List, elems=ensure(int, lambda i: i > 0)))
assert valconv.structure([1, 2], ensure(list, elems=ensure(int, lambda i: i > 0)))

with raises(Exception) as exc:
valconv.structure([1, -2], ensure(List, elems=ensure(int, lambda i: i > 0)))
valconv.structure([1, -2], ensure(list, elems=ensure(int, lambda i: i > 0)))

if valconv.detailed_validation:
assert isinstance(exc.value, ExceptionGroup)
assert isinstance(exc.value, IterableValidationError)
assert isinstance(exc.value.exceptions[0], ExceptionGroup)
assert isinstance(exc.value.exceptions[0].exceptions[0], ValueError)
else:
assert isinstance(exc.value, ValueError)


def test_ensured_typing_dict(valconv: BaseConverter):
"""Ensure works for typing.Dicts."""
assert valconv.structure(
{"a": 1}, ensure(Dict, lambda d: len(d) > 0, keys=str, values=int)
)

with raises(Exception) as exc:
valconv.structure({}, ensure(Dict, lambda d: len(d) > 0, keys=str, values=int))

if valconv.detailed_validation:
assert isinstance(exc.value, IterableValidationError)
assert isinstance(exc.value.exceptions[0], ValueError)
else:
assert isinstance(exc.value, ValueError)

with raises(Exception) as exc:
valconv.structure(
{"b": 1, "c": "a"},
ensure(Dict, keys=ensure(str, lambda s: s.startswith("a")), values=int),
)

if valconv.detailed_validation:
assert isinstance(exc.value, IterableValidationError)
assert isinstance(exc.value.exceptions[0], ExceptionGroup)
assert isinstance(exc.value.exceptions[0].exceptions[0], ValueError)
else:
Expand Down

0 comments on commit 7bcd631

Please sign in to comment.