Skip to content

Commit

Permalink
feat(autorun): add auto_await to AutorunOptions so that one can d…
Browse files Browse the repository at this point in the history
…efine an autorun/view as a decorator of a function without automatically awaiting its result, when `auto_await` is set to `False`, which activates the new behavior, the decorated function passes `asyncio.iscoroutinefunction` test, useful for certain libraries like quart
  • Loading branch information
sassanh committed Oct 8, 2024
1 parent 4527ef3 commit d4fe4de
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 21 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## Version 0.18.0

- feat(autorun): add `auto_await` to `AutorunOptions` so that one can define an autorun/view as a decorator of a function without automatically awaiting its result, when `auto_await` is set to `False`, which activates the new behavior, the decorated function passes `asyncio.iscoroutinefunction` test, useful for certain libraries like quart

## Version 0.17.1

- refactor(core): allow `None` type for state, action and event types in `ReducerResult` and `CompleteReducerResult`
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "python-redux"
version = "0.17.1"
version = "0.18.0"
description = "Redux implementation for Python"
authors = ["Sassan Haradji <[email protected]>"]
license = "Apache-2.0"
Expand Down
72 changes: 60 additions & 12 deletions redux/autorun.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
# ruff: noqa: D100, D101, D102, D103, D104, D105, D107
from __future__ import annotations

import asyncio
import functools
import inspect
import weakref
from asyncio import Future, Task, iscoroutine, iscoroutinefunction
from typing import TYPE_CHECKING, Any, Callable, Concatenate, Generic, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Concatenate,
Coroutine,
Generator,
Generic,
TypeVar,
cast,
)

from redux.basic_types import (
Action,
Expand All @@ -22,6 +33,25 @@
from redux.main import Store


T = TypeVar('T')


class AwaitableWrapper(Generic[T]):
def __init__(self, coro: Coroutine[None, None, T]) -> None:
self.coro = coro
self.awaited = False

def __await__(self) -> Generator[None, None, T]:
self.awaited = True
return self.coro.__await__()

def close(self) -> None:
self.coro.close()

def __repr__(self) -> str:
return f'AwaitableWrapper({self.coro}, awaited={self.awaited})'


class Autorun(
Generic[
State,
Expand All @@ -45,6 +75,7 @@ def __init__(
],
options: AutorunOptions[AutorunOriginalReturnType],
) -> None:
self.__name__ = func.__name__
self._store = store
self._selector = selector
self._comparator = comparator
Expand All @@ -55,6 +86,11 @@ def __init__(
self._func = weakref.WeakMethod(func, self.unsubscribe)
else:
self._func = weakref.ref(func, self.unsubscribe)
self._is_coroutine = (
asyncio.coroutines._is_coroutine # pyright: ignore [reportAttributeAccessIssue] # noqa: SLF001
if asyncio.iscoroutinefunction(func)
else None
)
self._options = options

self._last_selector_result: SelectorOutput | None = None
Expand Down Expand Up @@ -120,11 +156,11 @@ def _task_callback(
],
task: Task,
*,
future: Future | None,
future: Future,
) -> None:
task.add_done_callback(
lambda result: (
future.set_result(result.result()) if future else None,
future.set_result(result.result()),
self.inform_subscribers(),
),
)
Expand Down Expand Up @@ -184,15 +220,27 @@ def _call(
)
create_task = self._store._create_task # noqa: SLF001
if iscoroutine(value) and create_task:
future = Future()
self._latest_value = cast(AutorunOriginalReturnType, future)
create_task(
value,
callback=functools.partial(
self._task_callback,
future=future,
),
)
if self._options.auto_await:
future = Future()
self._latest_value = cast(AutorunOriginalReturnType, future)
create_task(
value,
callback=functools.partial(
self._task_callback,
future=future,
),
)
else:
if (
self._latest_value is not None
and isinstance(self._latest_value, AwaitableWrapper)
and not self._latest_value.awaited
):
self._latest_value.close()
self._latest_value = cast(
AutorunOriginalReturnType,
AwaitableWrapper(value),
)
else:
self._latest_value = value
self.inform_subscribers()
Expand Down
3 changes: 3 additions & 0 deletions redux/basic_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class CreateStoreOptions(Immutable, Generic[Action, Event]):

class AutorunOptions(Immutable, Generic[AutorunOriginalReturnType]):
default_value: AutorunOriginalReturnType | None = None
auto_await: bool = True
initial_call: bool = True
reactive: bool = True
keep_ref: bool = True
Expand Down Expand Up @@ -167,6 +168,8 @@ def subscribe(

def unsubscribe(self: AutorunReturnType) -> None: ...

__name__: str


class AutorunDecorator(
Protocol,
Expand Down
1 change: 1 addition & 0 deletions redux/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ def decorator(
func=cast(Callable, func),
options=AutorunOptions(
default_value=_options.default_value,
auto_await=True,
initial_call=False,
reactive=False,
keep_ref=_options.keep_ref,
Expand Down
4 changes: 1 addition & 3 deletions redux_pytest/fixtures/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,7 @@ def monitor(self: StoreSnapshot[State], selector: Callable[[State], Any]) -> Non
"""Monitor the state of the store and take snapshots."""

@self.store.autorun(selector=selector)
def _(state: object | None) -> None:
if state is None:
return
def _(state: object) -> None:
self.take(selector=lambda _: state)

def close(self: StoreSnapshot[State]) -> None:
Expand Down
64 changes: 60 additions & 4 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
if TYPE_CHECKING:
from redux_pytest.fixtures.event_loop import LoopThread

INCREMENTS = 2
INCREMENTS = 20


class StateType(Immutable):
Expand Down Expand Up @@ -91,16 +91,18 @@ def test_autorun(
event_loop: LoopThread,
) -> None:
@store.autorun(lambda state: state.value)
async def _(value: int) -> int:
async def sync_mirror(value: int) -> int:
await asyncio.sleep(value / 10)
store.dispatch(SetMirroredValueAction(value=value))
return value

assert asyncio.iscoroutinefunction(sync_mirror)

@store.autorun(
lambda state: state.mirrored_value,
lambda state: state.mirrored_value >= INCREMENTS,
)
async def _(mirrored_value: int) -> None:
def _(mirrored_value: int) -> None:
if mirrored_value < INCREMENTS:
return
event_loop.stop()
Expand All @@ -109,6 +111,36 @@ async def _(mirrored_value: int) -> None:
dispatch_actions(store)


def test_autorun_autoawait(
store: StoreType,
event_loop: LoopThread,
) -> None:
@store.autorun(lambda state: state.value, options=AutorunOptions(auto_await=False))
async def sync_mirror(value: int) -> int:
store.dispatch(SetMirroredValueAction(value=value))
return value * 2

assert asyncio.iscoroutinefunction(sync_mirror)

@store.autorun(lambda state: (state.value, state.mirrored_value))
async def _(values: tuple[int, int]) -> None:
value, mirrored_value = values
if mirrored_value != value:
assert 'awaited=False' in str(sync_mirror())
await sync_mirror()
assert 'awaited=True' in str(sync_mirror())
with pytest.raises(
RuntimeError,
match=r'^cannot reuse already awaited coroutine$',
):
await sync_mirror()
elif value < INCREMENTS:
store.dispatch(IncrementAction())
else:
event_loop.stop()
store.dispatch(FinishAction())


def test_autorun_default_value(
store: StoreType,
event_loop: LoopThread,
Expand All @@ -122,7 +154,7 @@ async def _(value: int) -> int:
lambda state: state.mirrored_value,
lambda state: state.mirrored_value >= INCREMENTS,
)
async def _(mirrored_value: int) -> None:
def _(mirrored_value: int) -> None:
if mirrored_value < INCREMENTS:
return
event_loop.stop()
Expand Down Expand Up @@ -155,6 +187,30 @@ async def _(value: int) -> None:
assert calls == list(range(INCREMENTS + 1))


def test_view_await(store: StoreType, event_loop: LoopThread) -> None:
calls = []

@store.view(lambda state: state.value)
async def doubled(value: int) -> int:
calls.append(value)
return value * 2

assert asyncio.iscoroutinefunction(doubled)

@store.autorun(lambda state: state.value)
async def _(value: int) -> None:
calls_length = len(calls)
assert await doubled() == value * 2
assert len(calls) == calls_length + 1

if value < INCREMENTS:
store.dispatch(IncrementAction())
else:
event_loop.stop()
store.dispatch(FinishAction())
assert calls == list(range(INCREMENTS + 1))


def test_view_with_args(
store: StoreType,
event_loop: LoopThread,
Expand Down
4 changes: 3 additions & 1 deletion tests/test_autorun.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,12 @@ def store() -> Generator[StoreType, None, None]:

def test_general(store_snapshot: StoreSnapshot, store: StoreType) -> None:
@store.autorun(lambda state: state.value)
def _(value: int) -> int:
def decorated(value: int) -> int:
store_snapshot.take()
return value

assert decorated.__name__ == 'decorated'


def test_ignore_attribute_error_in_selector(store: StoreType) -> None:
@store.autorun(lambda state: cast(Any, state).non_existing)
Expand Down

0 comments on commit d4fe4de

Please sign in to comment.