Skip to content

Commit

Permalink
refactor: move store serializer from test framework to code Store c…
Browse files Browse the repository at this point in the history
…lass

feat: add ability to set custom serializer for store snapshots
  • Loading branch information
sassanh committed Mar 16, 2024
1 parent a17c652 commit 8cd981b
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 47 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## Version 0.12.1

- refactor: move store serializer from test framework to code `Store` class
- feat: add ability to set custom serializer for store snapshots

## Version 0.12.0

- refactor: improve creating new state classes in `combine_reducers` upon registering/unregistering
Expand Down
12 changes: 12 additions & 0 deletions redux/basic_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# ruff: noqa: A003, D100, D101, D102, D103, D104, D105, D107
from __future__ import annotations

from types import NoneType
from typing import Any, Callable, Coroutine, Generic, Protocol, TypeAlias, TypeGuard

from immutable import Immutable
Expand Down Expand Up @@ -191,3 +192,14 @@ class CombineReducerRegisterAction(CombineReducerAction):

class CombineReducerUnregisterAction(CombineReducerAction):
key: str


SnapshotAtom = (
int
| float
| str
| bool
| NoneType
| dict[str, 'SnapshotAtom']
| list['SnapshotAtom']
)
47 changes: 47 additions & 0 deletions redux/main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
# ruff: noqa: D100, D101, D102, D103, D104, D105, D107
from __future__ import annotations

import dataclasses
import inspect
import queue
import threading
import weakref
from asyncio import create_task, iscoroutine
from collections import defaultdict
from enum import IntEnum, StrEnum
from inspect import signature
from threading import Lock
from types import NoneType
from typing import Any, Callable, Coroutine, Generic, cast

from immutable import Immutable, is_immutable

from redux.autorun import Autorun
from redux.basic_types import (
Action,
Expand All @@ -32,6 +37,7 @@
InitAction,
ReducerType,
SelectorOutput,
SnapshotAtom,
State,
is_complete_reducer_result,
is_state_reducer_result,
Expand Down Expand Up @@ -68,6 +74,8 @@ def run(self: _SideEffectRunnerThread[Event]) -> None:


class Store(Generic[State, Action, Event]):
custom_serializer = None

def __init__(
self: Store[State, Action, Event],
reducer: ReducerType[State, Action, Event],
Expand Down Expand Up @@ -276,3 +284,42 @@ def decorator(
)

return decorator

def set_custom_serializer(
self: Store,
serializer: Callable[[object | type], SnapshotAtom],
) -> None:
"""Set a custom serializer for the store snapshot."""
self.custom_serializer = serializer

@property
def snapshot(self: Store[State, Action, Event]) -> SnapshotAtom:
return self._serialize_value(self._state)

def _serialize_value(self: Store, obj: object | type) -> SnapshotAtom:
if self.custom_serializer:
return self.custom_serializer(obj)
if is_immutable(obj):
return self._serialize_dataclass_to_dict(obj)
if isinstance(obj, (list, tuple)):
return [self._serialize_value(i) for i in obj]
if callable(obj):
return self._serialize_value(obj())
if isinstance(obj, StrEnum):
return str(obj)
if isinstance(obj, IntEnum):
return int(obj)
if isinstance(obj, (int, float, str, bool, NoneType)):
return obj
msg = f'Unable to serialize object with type {type(obj)}.'
raise ValueError(msg)

def _serialize_dataclass_to_dict(
self: Store,
obj: Immutable,
) -> dict[str, Any]:
result = {}
for field in dataclasses.fields(obj):
value = self._serialize_value(getattr(obj, field.name))
result[field.name] = value
return result
49 changes: 2 additions & 47 deletions redux/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@
"""Let the test check snapshots of the window during execution."""
from __future__ import annotations

import dataclasses
import json
import os
from enum import IntEnum, StrEnum
from types import NoneType
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

import pytest
from immutable import Immutable, is_immutable

if TYPE_CHECKING:
from logging import Logger
Expand All @@ -24,9 +20,6 @@
override_store_snapshots = os.environ.get('REDUX_TEST_OVERRIDE_SNAPSHOTS', '0') == '1'


Atom = int | float | str | bool | NoneType | dict[str, 'Atom'] | list['Atom']


class StoreSnapshotContext:
"""Context object for tests taking snapshots of the store."""

Expand All @@ -43,44 +36,6 @@ def __init__(
self.logger = logger
self.results_dir.mkdir(exist_ok=True)

def _convert_value(self: StoreSnapshotContext, obj: object | type) -> Atom:
import sys
from pathlib import Path

if is_immutable(obj):
return self._convert_dataclass_to_dict(obj)
if isinstance(obj, (list, tuple)):
return [self._convert_value(i) for i in obj]
if isinstance(obj, type):
file_path = sys.modules[obj.__module__].__file__
if file_path:
return f"""{Path(file_path).relative_to(Path().absolute()).as_posix()}:{
obj.__name__}"""
return f'{obj.__module__}:{obj.__name__}'
if callable(obj):
return self._convert_value(obj())
if isinstance(obj, StrEnum):
return str(obj)
if isinstance(obj, IntEnum):
return int(obj)
if isinstance(obj, (int, float, str, bool, NoneType)):
return obj
self.logger.warning(
'Unable to serialize',
extra={'type': type(obj), 'value': obj},
)
return None

def _convert_dataclass_to_dict(
self: StoreSnapshotContext,
obj: Immutable,
) -> dict[str, Any]:
result = {}
for field in dataclasses.fields(obj):
value = self._convert_value(getattr(obj, field.name))
result[field.name] = value
return result

def set_store(self: StoreSnapshotContext, store: Store) -> None:
"""Set the store to take snapshots of."""
self.store = store
Expand All @@ -89,7 +44,7 @@ def set_store(self: StoreSnapshotContext, store: Store) -> None:
def snapshot(self: StoreSnapshotContext) -> str:
"""Return the snapshot of the current state of the store."""
return (
json.dumps(self._convert_value(self.store._state), indent=2) # noqa: SLF001
json.dumps(self.store.snapshot, indent=2)
if self.store._state # noqa: SLF001
else ''
)
Expand Down

0 comments on commit 8cd981b

Please sign in to comment.