Skip to content

Commit

Permalink
Factor out callable check type checks
Browse files Browse the repository at this point in the history
Signed-off-by: Nijat Khanbabayev <[email protected]>
  • Loading branch information
NeejWeej committed Jan 21, 2025
1 parent 7cb34c1 commit e48c2a7
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
4 changes: 1 addition & 3 deletions csp/impl/types/instantiation_type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,7 @@ def _add_scalar_value(self, arg, in_out_def):
def _is_scalar_value_matching_spec(self, inp_def_type, arg):
if inp_def_type is typing.Any:
return True
if inp_def_type is typing.Callable or (
hasattr(inp_def_type, "__origin__") and CspTypingUtils.get_origin(inp_def_type) is collections.abc.Callable
):
if CspTypingUtils.is_callable(inp_def_type):
return callable(arg)
resolved_type = UpcastRegistry.instance().resolve_type(inp_def_type, type(arg), raise_on_error=False)
if resolved_type is inp_def_type:
Expand Down
10 changes: 10 additions & 0 deletions csp/impl/types/typing_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# utils for dealing with typing types
import collections
import numpy
import sys
import types
Expand Down Expand Up @@ -29,6 +30,15 @@ def get_origin(cls, typ):
raw_origin = typ.__origin__
return cls._ORIGIN_COMPAT_MAP.get(raw_origin, raw_origin)

@classmethod
def is_callable(cls, typ):
# Checks if a type annotation refers to a callable
if typ is typing.Callable:
return True
if not hasattr(typ, "__origin__"):
return False
return CspTypingUtils.get_origin(typ) is collections.abc.Callable

@classmethod
def is_numpy_array_type(cls, typ):
return CspTypingUtils.is_generic_container(typ) and CspTypingUtils.get_orig_base(typ) is numpy.ndarray
Expand Down
21 changes: 21 additions & 0 deletions csp/tests/test_type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import csp
import csp.impl.types.instantiation_type_resolver as type_resolver
from csp import ts
from csp.impl.types.typing_utils import CspTypingUtils
from csp.impl.wiring.runtime import build_graph

USE_PYDANTIC = os.environ.get("CSP_PYDANTIC")
Expand Down Expand Up @@ -916,6 +917,26 @@ def graph():

csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))

def test_is_callable(self):
"""Test CspTypingUtils.is_callable with various input types"""
# Test cases as (input, expected_result) pairs
test_cases = [
# Direct Callable types
(Callable, True),
(Callable[[int, str], bool], True),
(Callable[..., None], True),
(Callable[[int], str], True),
# optional Callable is not Callable
(Optional[Callable], False),
# Typing module types
(List[int], False),
(Dict[str, int], False),
(typing.Set[str], False),
]
for input_type, expected in test_cases:
result = CspTypingUtils.is_callable(input_type)
self.assertEqual(result, expected)


if __name__ == "__main__":
unittest.main()

0 comments on commit e48c2a7

Please sign in to comment.