diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0b61fc3..19b1008 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,7 +21,7 @@ jobs: fail-fast: false matrix: python-version: ['3.9', '3.10', '3.11', '3.12'] - redis-image: ['redis:7.0-alpine'] + redis-image: ['redis:7.0-alpine', 'eqalpha/keydb:alpine'] env-type: ['redis'] include: @@ -31,6 +31,13 @@ jobs: - python-version: '3.9' env-type: 'dev' redis-image: 'redis:7.0-alpine' + - python-version: '3.10' + env-type: 'aioredis' + redis-image: 'eqalpha/keydb:alpine' + - python-version: '3.9' + env-type: 'dev' + redis-image: 'eqalpha/keydb:alpine' + steps: - uses: actions/checkout@v4 diff --git a/asyncio_redis_rate_limit/__init__.py b/asyncio_redis_rate_limit/__init__.py index fd0157d..4110a2f 100644 --- a/asyncio_redis_rate_limit/__init__.py +++ b/asyncio_redis_rate_limit/__init__.py @@ -6,11 +6,7 @@ from typing_extensions import ParamSpec, TypeAlias, final -from asyncio_redis_rate_limit.compat import ( - AnyPipeline, - AnyRedis, - pipeline_expire, -) +from asyncio_redis_rate_limit.compat import AnyPipeline, AnyRedis #: These aliases makes our code more readable. _Seconds: TypeAlias = int @@ -106,11 +102,9 @@ async def _run_pipeline( pipeline: AnyPipeline, ) -> int: # https://redis.io/commands/incr/#pattern-rate-limiter-1 - current_rate, _ = await pipeline_expire( - pipeline.incr(cache_key), - cache_key, - self._rate_spec.seconds, - ).execute() + _, current_rate = await pipeline.set( # type: ignore[union-attr] + cache_key, 0, nx=True, ex=self._rate_spec.seconds, + ).incr(cache_key).execute() return current_rate # type: ignore[no-any-return] def _make_cache_key( diff --git a/asyncio_redis_rate_limit/compat.py b/asyncio_redis_rate_limit/compat.py index fdedb2d..eacfc31 100644 --- a/asyncio_redis_rate_limit/compat.py +++ b/asyncio_redis_rate_limit/compat.py @@ -38,21 +38,3 @@ class _AIORedis: # type: ignore # noqa: WPS306, WPS440 AnyPipeline: TypeAlias = Union['_AsyncPipeline[Any]', _AIOPipeline] AnyRedis: TypeAlias = Union['_AsyncRedis[Any]', _AIORedis] - - -def pipeline_expire( - pipeline: Any, - cache_key: str, - seconds: int, -) -> AnyPipeline: - """Compatibility mode for `.expire(..., nx=True)` command.""" - if isinstance(pipeline, _AsyncPipeline): - return pipeline.expire(cache_key, seconds, nx=True) # type: ignore - # `aioredis` somehow does not have this boolean argument in `.expire`, - # so, we use `EXPIRE` directly with `NX` flag. - return pipeline.execute_command( # type: ignore - 'EXPIRE', - cache_key, - seconds, - 'NX', - )