Skip to content

Commit

Permalink
fix(scheduler): Allow ScheduledJob to run with args/kwargs of type Py…
Browse files Browse the repository at this point in the history
…dantic's BaseModel along with global encoder PydanticEncoder
  • Loading branch information
nkphan committed Dec 2, 2024
1 parent adf1424 commit a29b3bb
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 7 deletions.
3 changes: 2 additions & 1 deletion remoulade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .broker import Broker, Consumer, MessageProxy, change_broker, declare_actors, get_broker, set_broker
from .collection_results import CollectionResults
from .composition import group, pipeline
from .encoder import Encoder, JSONEncoder, PickleEncoder
from .encoder import Encoder, JSONEncoder, PickleEncoder, PydanticEncoder
from .errors import (
ActorNotFound,
BrokerError,
Expand Down Expand Up @@ -62,6 +62,7 @@
"Encoder",
"JSONEncoder",
"PickleEncoder",
"PydanticEncoder",
# Errors
"RemouladeError",
"BrokerError",
Expand Down
34 changes: 29 additions & 5 deletions remoulade/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import hashlib
import json
import time
from typing import Dict, List, Union
from typing import Dict, List, Union, Any

import pytz
import redis
from pydantic import BaseModel

from remoulade import Broker, get_encoder, get_logger

Expand Down Expand Up @@ -62,9 +63,32 @@ def __init__(
self.args = args if args is not None else []
self.kwargs = kwargs if kwargs is not None else {}

@property
def args_json_serializable(self) -> List:
results = []
for arg in self.args:
if isinstance(arg, BaseModel):
results.append(arg.model_dump(mode="json"))
results.append(arg)
return results

@property
def kwargs_json_serializable(self) -> Dict:
def _value_as_dict(value: Dict | BaseModel | Any) -> Dict:
if isinstance(value, dict):
return {
_key: _value_as_dict(_value) for _key, _value in value.items()
}
elif isinstance(value, BaseModel):
return value.model_dump(mode="json")
else:
return value
return {k: _value_as_dict(v) for k, v in self.kwargs.items()}

def get_hash(self) -> str:
args = json.dumps(self.args)
kwargs = json.dumps(sorted(list(self.kwargs.items()), key=lambda x: x[0]))
args = json.dumps(self.args_json_serializable)
kwargs = json.dumps(sorted(list(self.kwargs_json_serializable.items()), key=lambda x: x[0]))


path = [
str(getattr(self, k)) for k in ("actor_name", "interval", "daily_time", "iso_weekday", "enabled", "tz")
Expand All @@ -85,8 +109,8 @@ def as_dict(self, encode: bool = False) -> Dict:
"enabled": self.enabled,
"last_queued": self.last_queued,
"tz": self.tz,
"args": self.args,
"kwargs": self.kwargs,
"args": self.args_json_serializable,
"kwargs": self.kwargs_json_serializable,
}
if encode:
job_dict["daily_time"] = (
Expand Down
26 changes: 26 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest
import redis
from freezegun import freeze_time
from pydantic import BaseModel
from sqlalchemy.engine import create_engine
from sqlalchemy.inspection import inspect
from sqlalchemy.orm.session import sessionmaker
Expand Down Expand Up @@ -377,6 +378,31 @@ def pickle_encoder():
remoulade.set_encoder(old_encoder)


@pytest.fixture
def pydantic_encoder():
old_encoder = remoulade.get_encoder()
new_encoder = remoulade.PydanticEncoder()
remoulade.set_encoder(new_encoder)
yield new_encoder
remoulade.set_encoder(old_encoder)


class SimpleStructure(BaseModel):
data: str


@pytest.fixture
def actor_with_pydantic_args_kwargs():
broker = remoulade.get_broker()

@remoulade.actor()
def actor_with_pydantic_args_kwargs(input_data: SimpleStructure) -> SimpleStructure:
return SimpleStructure(data=input_data.data + " processed")

broker.declare_actor(actor_with_pydantic_args_kwargs)
return actor_with_pydantic_args_kwargs


def mock_func(func):
event = threading.Event()

Expand Down
30 changes: 29 additions & 1 deletion tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import json
import threading
import time
from unittest.mock import ANY

import pytest
import pytz

import remoulade
from remoulade.scheduler import ScheduledJob
from tests.conftest import check_redis, mock_func, new_scheduler
from tests.conftest import check_redis, mock_func, new_scheduler, SimpleStructure


def test_simple_interval_scheduler(stub_broker, stub_worker, scheduler, scheduler_thread, mul, add):
Expand Down Expand Up @@ -396,3 +397,30 @@ def test_tz_aware_last_queued(scheduler, api_client, do_work):
)

assert res.status_code == 400


def test_scheduled_job_with_pydantic_encoder_correct_hashing(
pydantic_encoder, scheduler, actor_with_pydantic_args_kwargs
):
scheduler.schedule = [
job := ScheduledJob(
actor_name="actor_with_pydantic_args_kwargs",
kwargs={"input_data": SimpleStructure(data="data")},
iso_weekday=1,
daily_time=datetime.time(3, 0, 0),
)
]
scheduler.sync_config()
scheduled_job_by_job_hash = scheduler.get_redis_schedule()
assert scheduled_job_by_job_hash[job.get_hash()].as_dict() == {
"actor_name": "actor_with_pydantic_args_kwargs",
"args": [],
"daily_time": datetime.time(3, 0),
"enabled": True,
"hash": job.get_hash(),
"interval": 86400,
"iso_weekday": 1,
"kwargs": {"input_data": {"data": "data"}},
"last_queued": ANY,
"tz": "UTC",
}

0 comments on commit a29b3bb

Please sign in to comment.