Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add use nx option #215

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions asyncio_redis_rate_limit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class RateLimiter:
'_backend',
'_cache_prefix',
'_lock',
'_use_nx_on_expire'
)

def __init__(
Expand All @@ -58,13 +59,15 @@ def __init__(
backend: AnyRedis,
*,
cache_prefix: str,
use_nx_on_expire: bool = True,
) -> None:
"""In the future other backends might be supported as well."""
self._unique_key = unique_key
self._rate_spec = rate_spec
self._backend = backend
self._cache_prefix = cache_prefix
self._lock = asyncio.Lock()
self._use_nx_on_expire = use_nx_on_expire

async def __aenter__(self: _RateLimiterT) -> _RateLimiterT:
"""
Expand Down Expand Up @@ -110,6 +113,7 @@ async def _run_pipeline(
pipeline.incr(cache_key),
cache_key,
self._rate_spec.seconds,
use_nx=self._use_nx_on_expire,
).execute()
return current_rate # type: ignore[no-any-return]

Expand All @@ -130,6 +134,7 @@ def rate_limit( # noqa: WPS320
backend: AnyRedis,
*,
cache_prefix: str = 'aio-rate-limit',
use_nx_on_expire: bool = True,
) -> Callable[
[_CoroutineFunction[_ParamsT, _ResultT]],
_CoroutineFunction[_ParamsT, _ResultT],
Expand Down Expand Up @@ -167,6 +172,7 @@ async def factory(
backend=backend,
rate_spec=rate_spec,
cache_prefix=cache_prefix,
use_nx_on_expire=use_nx_on_expire,
):
return await function(*args, **kwargs)
return factory
Expand Down
5 changes: 5 additions & 0 deletions asyncio_redis_rate_limit/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,13 @@ def pipeline_expire(
pipeline: Any,
cache_key: str,
seconds: int,
*,
use_nx: bool = True,
) -> AnyPipeline:
"""Compatibility mode for `.expire(..., nx=True)` command."""
if not use_nx:
return pipeline.expire(cache_key, seconds) # type: ignore

if isinstance(pipeline, _AsyncPipeline):
return pipeline.expire(cache_key, seconds, nx=True) # type: ignore
# `aioredis` somehow does not have this boolean argument in `.expire`,
Expand Down
27 changes: 27 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def __call__(
self,
requests: int = ...,
seconds: int = ...,
*,
use_nx_on_expire: bool = ...,
) -> _LimitedSig:
"""We use this callback to construct `limited` test function."""

Expand Down Expand Up @@ -246,6 +248,31 @@ async def test_ten_reqs_in_two_secs2(
await asyncio.sleep(1 + 0.5)
await function()

@pytest.mark.repeat(5)
async def test_ten_reqs_in_two_secs_without_nx(
limited: _LimitedCallback,
) -> None:
"""Ensure that several gathered coroutines do respect the rate limit."""
function = limited(requests=10, seconds=2, use_nx_on_expire=False)

# Or just consume all:
for attempt in range(10):
await function(attempt)

# This one will fail:
with pytest.raises(RateLimitError):
await function()

# Now, let's move time to the next second:
await asyncio.sleep(1)

# This one will also fail:
with pytest.raises(RateLimitError):
await function()

# Next attempts will pass:
await asyncio.sleep(1 + 0.5)
await function()

class _Counter:
def __init__(self) -> None:
Expand Down
Loading