diff --git a/redux/combine_reducers.py b/redux/combine_reducers.py index 3d4286c..b6825ad 100644 --- a/redux/combine_reducers.py +++ b/redux/combine_reducers.py @@ -48,7 +48,7 @@ def combine_reducers( type[state_type], make_dataclass( state_type.__name__, - ('_id', *reducers.keys()), + {'_id', *reducers.keys(), *(field.name for field in fields(state_type))}, frozen=True, kw_only=True, ), @@ -66,28 +66,32 @@ def combined_reducer( key = action.key reducer = action.reducer reducers[key] = reducer + field_names = {field.name for field in fields(state_class)} state_class = make_dataclass( state_type.__name__, - ('_id', *reducers.keys()), + {*field_names, key}, frozen=True, kw_only=True, ) - reducer_result = reducer( - None, - CombineReducerInitAction(_id=_id, key=key), + reducer_result = ( + getattr(state, key) + if hasattr(state, key) + else reducer( + None, + CombineReducerInitAction(_id=_id, key=key), + ) ) state = state_class( - _id=state._id, # noqa: SLF001 **( { - key_: ( + field: ( reducer_result.state if is_complete_reducer_result(reducer_result) else reducer_result ) - if key == key_ - else getattr(state, key_) - for key_ in reducers + if key == field + else getattr(state, field) + for field in field_names } ), ) @@ -137,10 +141,15 @@ def combined_reducer( for key, reducer in reducers.items() } result_state = state_class( - _id=_id, - **{ - key: result.state if is_complete_reducer_result(result) else result - for key, result in reducers_results.items() + **{ # pyright: ignore [reportArgumentType] + field.name: ( + reducers_results[field.name].state + if is_complete_reducer_result(reducers_results[field.name]) + else reducers_results[field.name] + ) + if field.name in reducers_results + else getattr(state, field.name, None) + for field in fields(state_class) }, ) result_actions += functools.reduce(