Skip to content

Commit

Permalink
Add task clearing methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeshardmind committed Oct 8, 2024
1 parent a1863c6 commit 1cbc865
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
2 changes: 1 addition & 1 deletion async_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.


__version__ = "5.1.1"
__version__ = "5.2.0"
35 changes: 34 additions & 1 deletion async_utils/task_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@
from collections.abc import Callable, Coroutine, Hashable
from functools import partial
from typing import Any, ParamSpec, TypeVar
from weakref import WeakKeyDictionary

from ._cpython_stuff import make_key

__all__ = ("taskcache",)
__all__ = ("taskcache", "clear_cache", "remove_cache_entry")


P = ParamSpec("P")
T = TypeVar("T")


_caches: WeakKeyDictionary[Hashable, dict[Hashable, asyncio.Task[Any]]] = WeakKeyDictionary()


def taskcache(
ttl: float | None = None,
) -> Callable[[Callable[P, Coroutine[Any, Any, T]]], Callable[P, asyncio.Task[T]]]:
Expand All @@ -40,6 +44,8 @@ def taskcache(
Note: This uses the args and kwargs of the original coroutine function as a cache key.
This includes instances (self) when wrapping methods.
Consider not wrapping instance methods, but what those methods call when feasible in cases where this may matter.
The ordering of args and kwargs matters.
"""

def wrapper(coro: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, asyncio.Task[T]]:
Expand All @@ -63,6 +69,33 @@ def wrapped(*args: P.args, **kwargs: P.kwargs) -> asyncio.Task[T]:
task.add_done_callback(call_after_ttl)
return task

_caches[wrapped] = internal_cache
return wrapped

return wrapper


def clear_cache(f: Callable[..., Any]) -> None:
"""
Clear the cache of a decorated function
"""
cache = _caches.get(f)
if cache is None:
raise RuntimeError(f"{f:!r} is not a function wrapped with taskcache")
cache.clear()


def remove_cache_entry(f: Callable[..., Any], *args: Hashable, **kwargs: Hashable):
"""
Remove the cache entry for a specific arg/kwarg combination.
The ordering of args and kwargs must match.
Will not error under a missing key under the assumption that a race condition on removal
for various reasons (such as ttl) could occur
"""

cache = _caches.get(f)
if cache is None:
raise RuntimeError(f"{f:!r} is not a function wrapped with taskcache")
cache.pop(make_key(args, kwargs))

0 comments on commit 1cbc865

Please sign in to comment.