Skip to content

Commit

Permalink
Simplify unions with common base class.
Browse files Browse the repository at this point in the history
Detects if a `Union` type is redundant. Refs #121.

Includes benchmark change.
  • Loading branch information
coady committed Jun 16, 2024
1 parent 7bb21ae commit a98c163
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 7 deletions.
16 changes: 11 additions & 5 deletions multimethod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def get_mro(cls) -> tuple: # `inspect.getmro` doesn't handle all cases
return type.mro(cls) if isinstance(cls, type) else cls.mro()


def common_bases(*bases):
counts = collections.Counter()
for base in bases:
counts.update(cls for cls in get_mro(base) if issubclass(abc.ABCMeta, type(cls)))
return tuple(cls for cls in counts if counts[cls] == len(bases))


class subtype(abc.ABCMeta):
"""A normalized generic type which checks subscripts.
Expand Down Expand Up @@ -58,12 +65,11 @@ def __new__(cls, tp, *args):
return origin
bases = (origin,) if type(origin) in (type, abc.ABCMeta) else ()
if origin is Literal:
bases = (subtype(Union[tuple(map(type, args))]),)
bases = (cls(Union[tuple(map(type, args))]),)
if origin is Union:
counts = collections.Counter()
for arg in args:
counts.update(cls for cls in get_mro(arg) if issubclass(abc.ABCMeta, type(cls)))
bases = tuple(cls for cls in counts if counts[cls] == len(args))[:1]
bases = common_bases(*args)[:1]
if bases[0] in args:
return bases[0]
if origin is Callable and args[:1] == (...,):
args = args[1:]
namespace = {'__origin__': origin, '__args__': args}
Expand Down
1 change: 0 additions & 1 deletion tests/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def _(left, right):
return 'paper covers rock'


@pytest.mark.benchmark
def test_roshambo():
assert roshambo.__name__ == 'roshambo'
r, p, s = rock(), paper(), scissors()
Expand Down
3 changes: 2 additions & 1 deletion tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def join(seq: tree, sep: object):
return join(seq.walk(), sep)


@pytest.mark.benchmark
def test_join():
sep = '<>'
seq = [0, tree([1]), 2]
Expand Down Expand Up @@ -94,7 +95,7 @@ def test_signature():
assert (type,) - signature([object]) == (1,)
# using EnumMeta because it is a standard, stable, metaclass
assert signature([enum.EnumMeta]) - signature([object]) == (2,)
assert signature([Union[type, enum.EnumMeta]]) - signature([object]) == (2,)
assert signature([Union[type, enum.EnumMeta]]) - signature([object]) == (1,)


class namespace:
Expand Down
1 change: 1 addition & 0 deletions tests/test_subscripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def func(arg: Literal['a', 0]):
def test_union():
assert issubclass(int, subtype(int | float))
assert issubclass(subtype(int | float), subtype(int | float | None))
assert subtype(Iterable | Mapping | Sequence) is Iterable


@pytest.mark.skipif(sys.version_info < (3, 12), reason="Type aliases added in 3.12")
Expand Down

0 comments on commit a98c163

Please sign in to comment.