Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TypeIs support #112

Merged
merged 3 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions py_cachify/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

from .backend.cached import cached
from .backend.exceptions import CachifyInitError, CachifyLockError
from .backend.helpers import Decoder, Encoder
Expand Down
100 changes: 49 additions & 51 deletions py_cachify/backend/cached.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,80 @@
from __future__ import annotations

import inspect
from functools import partial, wraps
from typing import Awaitable, Callable, Optional, Tuple, TypeVar, Union, cast
from functools import wraps
from typing import Awaitable, Callable, Tuple, TypeVar, Union, cast, overload

from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, deprecated

from py_cachify.backend.lib import get_cachify

from .helpers import Decoder, Encoder, encode_decode_value, get_full_key_from_signature, is_coroutine
from .helpers import Decoder, Encoder, SyncOrAsync, encode_decode_value, get_full_key_from_signature, is_coroutine


R = TypeVar('R')
P = ParamSpec('P')


def _decorator(
_func: Union[Callable[P, R], Callable[P, Awaitable[R]]],
key: str,
ttl: Union[int, None] = None,
enc_dec: Optional[Tuple[Encoder, Decoder]] = None,
) -> Union[Callable[P, R], Callable[P, Awaitable[R]]]:
signature = inspect.signature(_func)
def cached(key: str, ttl: Union[int, None] = None, enc_dec: Union[Tuple[Encoder, Decoder], None] = None) -> SyncOrAsync:
@overload
def _cached_inner(
_func: Callable[P, Awaitable[R]],
) -> Callable[P, Awaitable[R]]: ...

enc, dec = None, None
if enc_dec is not None:
enc, dec = enc_dec
@overload
def _cached_inner(
_func: Callable[P, R],
) -> Callable[P, R]: ...

if is_coroutine(_func):
_awaitable_func = _func
def _cached_inner( # type: ignore[misc]
_func: Union[Callable[P, R], Callable[P, Awaitable[R]]],
) -> Union[Callable[P, R], Callable[P, Awaitable[R]]]:
signature = inspect.signature(_func)

@wraps(_awaitable_func)
async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cachify = get_cachify()
_key = get_full_key_from_signature(bound_args=signature.bind(*args, **kwargs), key=key)
if val := await cachify.a_get(key=_key):
return encode_decode_value(encoder_decoder=dec, val=val)
enc, dec = None, None
if enc_dec is not None:
enc, dec = enc_dec

res = await _awaitable_func(*args, **kwargs)
await cachify.a_set(key=_key, val=encode_decode_value(encoder_decoder=enc, val=res), ttl=ttl)
return res
if is_coroutine(_func):
_awaitable_func = _func

return cast(Callable[P, Awaitable[R]], _async_wrapper)
else:
@wraps(_awaitable_func)
async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cachify = get_cachify()
_key = get_full_key_from_signature(bound_args=signature.bind(*args, **kwargs), key=key)
if val := await cachify.a_get(key=_key):
return encode_decode_value(encoder_decoder=dec, val=val)

@wraps(_func)
def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cachify = get_cachify()
_key = get_full_key_from_signature(bound_args=signature.bind(*args, **kwargs), key=key)
if val := cachify.get(key=_key):
return encode_decode_value(encoder_decoder=dec, val=val)
res = await _awaitable_func(*args, **kwargs)
await cachify.a_set(key=_key, val=encode_decode_value(encoder_decoder=enc, val=res), ttl=ttl)
return res

res = _func(*args, **kwargs)
cachify.set(key=_key, val=encode_decode_value(encoder_decoder=enc, val=res), ttl=ttl)
return cast(R, res)
return _async_wrapper
else:

return cast(Callable[P, R], _sync_wrapper)
@wraps(_func) # type: ignore[unreachable]
def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cachify = get_cachify()
_key = get_full_key_from_signature(bound_args=signature.bind(*args, **kwargs), key=key)
if val := cachify.get(key=_key):
return encode_decode_value(encoder_decoder=dec, val=val)

res = _func(*args, **kwargs)
cachify.set(key=_key, val=encode_decode_value(encoder_decoder=enc, val=res), ttl=ttl)
return cast(R, res)

def cached(
key: str, ttl: Union[int, None] = None, enc_dec: Union[Tuple[Encoder, Decoder], None] = None
) -> Callable[[Union[Callable[P, Awaitable[R]], Callable[P, R]]], Union[Callable[P, Awaitable[R]], Callable[P, R]]]:
return cast(
Callable[[Union[Callable[P, Awaitable[R]], Callable[P, R]]], Union[Callable[P, Awaitable[R]], Callable[P, R]]],
partial(_decorator, key=key, ttl=ttl, enc_dec=enc_dec),
)
return _sync_wrapper

return _cached_inner


@deprecated('sync_cached is deprecated, use cached instead. Scheduled for removal in 1.3.0')
def sync_cached(
key: str, ttl: Union[int, None] = None, enc_dec: Union[Tuple[Encoder, Decoder], None] = None
) -> Callable[[Callable[P, R]], Callable[P, R]]:
return cast(Callable[[Callable[P, R]], Callable[P, R]], partial(_decorator, key=key, ttl=ttl, enc_dec=enc_dec))
return cached(key=key, ttl=ttl, enc_dec=enc_dec)


@deprecated('async_cached is deprecated, use cached instead. Scheduled for removal in 1.3.0')
def async_cached(
key: str, ttl: Union[int, None] = None, enc_dec: Union[Tuple[Encoder, Decoder], None] = None
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
return cast(
Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]],
partial(_decorator, key=key, ttl=ttl, enc_dec=enc_dec),
)
return cached(key=key, ttl=ttl, enc_dec=enc_dec)
20 changes: 17 additions & 3 deletions py_cachify/backend/helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import inspect
from typing import Any, Awaitable, Callable, TypeVar, Union
from typing import Any, Awaitable, Callable, TypeVar, Union, overload

from typing_extensions import ParamSpec, TypeAlias, TypeGuard
from typing_extensions import ParamSpec, Protocol, TypeAlias, TypeIs


R = TypeVar('R')
Expand All @@ -23,7 +23,9 @@ def get_full_key_from_signature(bound_args: inspect.BoundArguments, key: str) ->
raise ValueError('Arguments in a key do not match function signature') from None


def is_coroutine(func: Union[Callable[P, R], Callable[P, Awaitable[R]]]) -> TypeGuard[Callable[P, Awaitable[R]]]:
def is_coroutine(
func: Callable[P, Union[R, Awaitable[R]]],
) -> TypeIs[Callable[P, Awaitable[R]]]:
return asyncio.iscoroutinefunction(func)


Expand All @@ -32,3 +34,15 @@ def encode_decode_value(encoder_decoder: Union[Encoder, Decoder, None], val: Any
return val

return encoder_decoder(val)


class SyncOrAsync(Protocol):
@overload
def __call__(self, _func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...

@overload
def __call__(self, _func: Callable[P, R]) -> Callable[P, R]: ...

def __call__( # type: ignore[misc]
self, _func: Union[Callable[P, Awaitable[R]], Callable[P, R]]
) -> Union[Callable[P, Awaitable[R]], Callable[P, R]]: ...
105 changes: 50 additions & 55 deletions py_cachify/backend/lock.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from __future__ import annotations

import inspect
import logging
from contextlib import asynccontextmanager, contextmanager
from functools import partial, wraps
from typing import Any, AsyncGenerator, Awaitable, Callable, Generator, TypeVar, Union, cast
from functools import wraps
from typing import Any, AsyncGenerator, Awaitable, Callable, Generator, TypeVar, Union

from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, deprecated, overload

from .exceptions import CachifyLockError
from .helpers import get_full_key_from_signature, is_coroutine
from .helpers import SyncOrAsync, get_full_key_from_signature, is_coroutine
from .lib import get_cachify


Expand Down Expand Up @@ -52,74 +50,71 @@ def lock(key: str) -> Generator[None, None, None]:
_cachify.delete(key=key)


def _decorator(
_func: Union[Callable[P, R], Callable[P, Awaitable[R]]],
key: str,
raise_on_locked: bool = False,
return_on_locked: Any = None,
) -> Union[Callable[P, R], Callable[P, Awaitable[R]]]:
signature = inspect.signature(_func)
def once(key: str, raise_on_locked: bool = False, return_on_locked: Any = None) -> SyncOrAsync:
@overload
def _once_inner(
_func: Callable[P, Awaitable[R]],
) -> Callable[P, Awaitable[R]]: ...

if is_coroutine(_func):
_awaitable_func = _func
@overload
def _once_inner(
_func: Callable[P, R],
) -> Callable[P, R]: ...

@wraps(_awaitable_func)
async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
bound_args = signature.bind(*args, **kwargs)
_key = get_full_key_from_signature(bound_args=bound_args, key=key)
def _once_inner( # type: ignore[misc]
_func: Union[Callable[P, R], Callable[P, Awaitable[R]]],
) -> Union[Callable[P, R], Callable[P, Awaitable[R]]]:
signature = inspect.signature(_func)

try:
async with async_lock(key=_key):
return await _awaitable_func(*args, **kwargs)
except CachifyLockError:
if raise_on_locked:
raise
if is_coroutine(_func):
_awaitable_func = _func

return return_on_locked
@wraps(_awaitable_func)
async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
bound_args = signature.bind(*args, **kwargs)
_key = get_full_key_from_signature(bound_args=bound_args, key=key)

return cast(Callable[P, Awaitable[R]], _async_wrapper)
try:
async with async_lock(key=_key):
return await _awaitable_func(*args, **kwargs)
except CachifyLockError:
if raise_on_locked:
raise

else:
return return_on_locked

@wraps(_func)
def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
bound_args = signature.bind(*args, **kwargs)
_key = get_full_key_from_signature(bound_args=bound_args, key=key)
return _async_wrapper

try:
with lock(key=_key):
return _func(*args, **kwargs)
except CachifyLockError:
if raise_on_locked:
raise
else:

return return_on_locked
@wraps(_func) # type: ignore[unreachable]
def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
bound_args = signature.bind(*args, **kwargs)
_key = get_full_key_from_signature(bound_args=bound_args, key=key)

return cast(Callable[P, R], _sync_wrapper)
try:
with lock(key=_key):
return _func(*args, **kwargs)
except CachifyLockError:
if raise_on_locked:
raise

return return_on_locked

def once(
key: str, raise_on_locked: bool = False, return_on_locked: Any = None
) -> Callable[[Union[Callable[P, Awaitable[R]], Callable[P, R]]], Union[Callable[P, Awaitable[R]], Callable[P, R]]]:
return cast(
Callable[[Union[Callable[P, Awaitable[R]], Callable[P, R]]], Union[Callable[P, Awaitable[R]], Callable[P, R]]],
partial(_decorator, key=key, raise_on_locked=raise_on_locked, return_on_locked=return_on_locked),
)
return _sync_wrapper

return _once_inner


@deprecated('sync_once is deprecated, use once instead. Scheduled for removal in 1.3.0')
def sync_once(
key: str, raise_on_locked: bool = False, return_on_locked: Any = None
) -> Callable[[Callable[P, R]], Callable[P, R]]:
return cast(
Callable[[Callable[P, R]], Callable[P, R]],
partial(_decorator, key=key, raise_on_locked=raise_on_locked, return_on_locked=return_on_locked),
)
return once(key=key, raise_on_locked=raise_on_locked, return_on_locked=return_on_locked)


@deprecated('async_once is deprecated, use once instead. Scheduled for removal in 1.3.0')
def async_once(
key: str, raise_on_locked: bool = False, return_on_locked: Any = None
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
return cast(
Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]],
partial(_decorator, key=key, raise_on_locked=raise_on_locked, return_on_locked=return_on_locked),
)
return once(key=key, raise_on_locked=raise_on_locked, return_on_locked=return_on_locked)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ omit = [

exclude_lines = [
'pragma: no cover',
'@overload',
'SyncOrAsync',
'@abstract',
'def __repr__',
'raise AssertionError',
Expand Down
Loading