diff --git a/CHANGELOG.md b/CHANGELOG.md index 51f0db8..76481e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## Version 0.17.2 + +- feat(autorun): add `auto_await` to `AutorunOptions` so that one can define an autorun/view as a decorator of a function without automatically awaiting its result, when `auto_await` is set to `False`, which activates the new behavior, the decorated function passes `asyncio.iscoroutinefunction` test, useful for certain libraries like quart + ## Version 0.17.1 - refactor(core): allow `None` type for state, action and event types in `ReducerResult` and `CompleteReducerResult` diff --git a/pyproject.toml b/pyproject.toml index 2b6cc44..9f6e591 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "python-redux" -version = "0.17.1" +version = "0.17.2" description = "Redux implementation for Python" authors = ["Sassan Haradji "] license = "Apache-2.0" diff --git a/redux/autorun.py b/redux/autorun.py index 2fcfbc6..d007589 100644 --- a/redux/autorun.py +++ b/redux/autorun.py @@ -1,11 +1,22 @@ # ruff: noqa: D100, D101, D102, D103, D104, D105, D107 from __future__ import annotations +import asyncio import functools import inspect import weakref from asyncio import Future, Task, iscoroutine, iscoroutinefunction -from typing import TYPE_CHECKING, Any, Callable, Concatenate, Generic, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Concatenate, + Coroutine, + Generator, + Generic, + TypeVar, + cast, +) from redux.basic_types import ( Action, @@ -22,6 +33,25 @@ from redux.main import Store +T = TypeVar('T') + + +class AwaitableWrapper(Generic[T]): + def __init__(self, coro: Coroutine[None, None, T]) -> None: + self.coro = coro + self.awaited = False + + def __await__(self) -> Generator[None, None, T]: + self.awaited = True + return self.coro.__await__() + + def close(self) -> None: + self.coro.close() + + def __repr__(self) -> str: + return f'AwaitableWrapper({self.coro}, awaited={self.awaited})' + + class Autorun( Generic[ State, @@ -45,6 +75,7 @@ def __init__( ], options: AutorunOptions[AutorunOriginalReturnType], ) -> None: + self.__name__ = func.__name__ self._store = store self._selector = selector self._comparator = comparator @@ -55,6 +86,11 @@ def __init__( self._func = weakref.WeakMethod(func, self.unsubscribe) else: self._func = weakref.ref(func, self.unsubscribe) + self._is_coroutine = ( + asyncio.coroutines._is_coroutine # pyright: ignore [reportAttributeAccessIssue] # noqa: SLF001 + if asyncio.iscoroutinefunction(func) and not options.auto_await + else None + ) self._options = options self._last_selector_result: SelectorOutput | None = None @@ -120,11 +156,11 @@ def _task_callback( ], task: Task, *, - future: Future | None, + future: Future, ) -> None: task.add_done_callback( lambda result: ( - future.set_result(result.result()) if future else None, + future.set_result(result.result()), self.inform_subscribers(), ), ) @@ -184,15 +220,27 @@ def _call( ) create_task = self._store._create_task # noqa: SLF001 if iscoroutine(value) and create_task: - future = Future() - self._latest_value = cast(AutorunOriginalReturnType, future) - create_task( - value, - callback=functools.partial( - self._task_callback, - future=future, - ), - ) + if self._options.auto_await: + future = Future() + self._latest_value = cast(AutorunOriginalReturnType, future) + create_task( + value, + callback=functools.partial( + self._task_callback, + future=future, + ), + ) + else: + if ( + self._latest_value is not None + and isinstance(self._latest_value, AwaitableWrapper) + and not self._latest_value.awaited + ): + self._latest_value.close() + self._latest_value = cast( + AutorunOriginalReturnType, + AwaitableWrapper(value), + ) else: self._latest_value = value self.inform_subscribers() diff --git a/redux/basic_types.py b/redux/basic_types.py index e227d19..a0498c3 100644 --- a/redux/basic_types.py +++ b/redux/basic_types.py @@ -133,6 +133,7 @@ class CreateStoreOptions(Immutable, Generic[Action, Event]): class AutorunOptions(Immutable, Generic[AutorunOriginalReturnType]): default_value: AutorunOriginalReturnType | None = None + auto_await: bool = True initial_call: bool = True reactive: bool = True keep_ref: bool = True @@ -167,6 +168,8 @@ def subscribe( def unsubscribe(self: AutorunReturnType) -> None: ... + __name__: str + class AutorunDecorator( Protocol, diff --git a/redux/main.py b/redux/main.py index ecb3c6c..cf2fd0b 100644 --- a/redux/main.py +++ b/redux/main.py @@ -388,6 +388,7 @@ def decorator( func=cast(Callable, func), options=AutorunOptions( default_value=_options.default_value, + auto_await=False, initial_call=False, reactive=False, keep_ref=_options.keep_ref, diff --git a/redux_pytest/fixtures/snapshot.py b/redux_pytest/fixtures/snapshot.py index 0fa98e9..6ac1d91 100644 --- a/redux_pytest/fixtures/snapshot.py +++ b/redux_pytest/fixtures/snapshot.py @@ -128,9 +128,7 @@ def monitor(self: StoreSnapshot[State], selector: Callable[[State], Any]) -> Non """Monitor the state of the store and take snapshots.""" @self.store.autorun(selector=selector) - def _(state: object | None) -> None: - if state is None: - return + def _(state: object) -> None: self.take(selector=lambda _: state) def close(self: StoreSnapshot[State]) -> None: diff --git a/tests/test_async.py b/tests/test_async.py index 626db38..c1ab5f5 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from redux_pytest.fixtures.event_loop import LoopThread -INCREMENTS = 2 +INCREMENTS = 20 class StateType(Immutable): @@ -91,16 +91,18 @@ def test_autorun( event_loop: LoopThread, ) -> None: @store.autorun(lambda state: state.value) - async def _(value: int) -> int: + async def sync_mirror(value: int) -> int: await asyncio.sleep(value / 10) store.dispatch(SetMirroredValueAction(value=value)) return value + assert not asyncio.iscoroutinefunction(sync_mirror) + @store.autorun( lambda state: state.mirrored_value, lambda state: state.mirrored_value >= INCREMENTS, ) - async def _(mirrored_value: int) -> None: + def _(mirrored_value: int) -> None: if mirrored_value < INCREMENTS: return event_loop.stop() @@ -109,6 +111,36 @@ async def _(mirrored_value: int) -> None: dispatch_actions(store) +def test_autorun_autoawait( + store: StoreType, + event_loop: LoopThread, +) -> None: + @store.autorun(lambda state: state.value, options=AutorunOptions(auto_await=False)) + async def sync_mirror(value: int) -> int: + store.dispatch(SetMirroredValueAction(value=value)) + return value * 2 + + assert asyncio.iscoroutinefunction(sync_mirror) + + @store.autorun(lambda state: (state.value, state.mirrored_value)) + async def _(values: tuple[int, int]) -> None: + value, mirrored_value = values + if mirrored_value != value: + assert 'awaited=False' in str(sync_mirror()) + await sync_mirror() + assert 'awaited=True' in str(sync_mirror()) + with pytest.raises( + RuntimeError, + match=r'^cannot reuse already awaited coroutine$', + ): + await sync_mirror() + elif value < INCREMENTS: + store.dispatch(IncrementAction()) + else: + event_loop.stop() + store.dispatch(FinishAction()) + + def test_autorun_default_value( store: StoreType, event_loop: LoopThread, @@ -122,7 +154,7 @@ async def _(value: int) -> int: lambda state: state.mirrored_value, lambda state: state.mirrored_value >= INCREMENTS, ) - async def _(mirrored_value: int) -> None: + def _(mirrored_value: int) -> None: if mirrored_value < INCREMENTS: return event_loop.stop() @@ -145,7 +177,10 @@ async def doubled(value: int) -> int: @store.autorun(lambda state: state.value) async def _(value: int) -> None: assert await doubled() == value * 2 - for _ in range(10): + with pytest.raises( + RuntimeError, + match=r'^cannot reuse already awaited coroutine$', + ): await doubled() if value < INCREMENTS: store.dispatch(IncrementAction()) @@ -155,6 +190,30 @@ async def _(value: int) -> None: assert calls == list(range(INCREMENTS + 1)) +def test_view_await(store: StoreType, event_loop: LoopThread) -> None: + calls = [] + + @store.view(lambda state: state.value) + async def doubled(value: int) -> int: + calls.append(value) + return value * 2 + + assert asyncio.iscoroutinefunction(doubled) + + @store.autorun(lambda state: state.value) + async def _(value: int) -> None: + calls_length = len(calls) + assert await doubled() == value * 2 + assert len(calls) == calls_length + 1 + + if value < INCREMENTS: + store.dispatch(IncrementAction()) + else: + event_loop.stop() + store.dispatch(FinishAction()) + assert calls == list(range(INCREMENTS + 1)) + + def test_view_with_args( store: StoreType, event_loop: LoopThread, diff --git a/tests/test_autorun.py b/tests/test_autorun.py index 07da8c5..bfc5277 100644 --- a/tests/test_autorun.py +++ b/tests/test_autorun.py @@ -82,10 +82,12 @@ def store() -> Generator[StoreType, None, None]: def test_general(store_snapshot: StoreSnapshot, store: StoreType) -> None: @store.autorun(lambda state: state.value) - def _(value: int) -> int: + def decorated(value: int) -> int: store_snapshot.take() return value + assert decorated.__name__ == 'decorated' + def test_ignore_attribute_error_in_selector(store: StoreType) -> None: @store.autorun(lambda state: cast(Any, state).non_existing)