Skip to content

Commit

Permalink
more typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian-B committed Dec 5, 2023
1 parent 4a16be2 commit a4cbe7f
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 28 deletions.
4 changes: 3 additions & 1 deletion spinn_utilities/ranged/abstract_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ def iter_ranges(self, key: Optional[_StrSeq]) -> Iterator[Tuple[
...

@abstractmethod
def iter_ranges(self, key: _Keys = None):
def iter_ranges(self, key: _Keys = None
) -> Union[Iterator[Tuple[int, int, T]],
Iterator[Tuple[int, int, Dict[str, T]]]]:
"""
Iterates over the ranges(s) for all IDs covered by this view.
There will be one yield for each range which may cover one or
Expand Down
32 changes: 22 additions & 10 deletions spinn_utilities/ranged/abstract_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,8 @@ class SingleList(AbstractList[T], Generic[T], metaclass=AbstractBase):
__slots__ = [
"_a_list", "_operation"]

def __init__(self, a_list, operation, key=None):
def __init__(self, a_list: AbstractList[T], operation: Callable[[T], T],
key: str = None):
"""
:param AbstractList a_list: The list to perform the operation on
:param callable operation:
Expand All @@ -641,7 +642,7 @@ def __init__(self, a_list, operation, key=None):
self._operation = operation

@overrides(AbstractList.range_based)
def range_based(self):
def range_based(self) -> bool:
return self._a_list.range_based()

@overrides(AbstractList.get_value_by_id)
Expand All @@ -664,7 +665,10 @@ def iter_ranges(self) -> Iterator[Tuple[int, int, T]]:
yield (start, stop, self._operation(value))

@overrides(AbstractList.get_default)
def get_default(self):
def get_default(self) -> Optional[T]:
default = self._a_list.get_default()
if default is None:
return None
return self._operation(self._a_list.get_default())

@overrides(AbstractList.iter_ranges_by_slice)
Expand All @@ -683,7 +687,8 @@ class DualList(AbstractList[T], Generic[T], metaclass=AbstractBase):
__slots__ = [
"_left", "_operation", "_right"]

def __init__(self, left, right, operation, key=None):
def __init__(self, left: AbstractList[T], right: AbstractList[T],
operation: Callable[[T, T], T], key: str = None):
"""
:param AbstractList left: The first list to combine
:param AbstractList right: The second list to combine
Expand All @@ -703,7 +708,7 @@ def __init__(self, left, right, operation, key=None):
self._operation = operation

@overrides(AbstractList.range_based)
def range_based(self):
def range_based(self) -> bool:
return self._left.range_based() and self._right.range_based()

@overrides(AbstractList.get_value_by_id)
Expand Down Expand Up @@ -770,7 +775,7 @@ def iter_by_slice(self, slice_start: int, slice_stop: int) -> Iterator[T]:
return

@overrides(AbstractList.iter_ranges)
def iter_ranges(self):
def iter_ranges(self) -> Iterator[Tuple[int, int, T]]:
left_iter = self._left.iter_ranges()
right_iter = self._right.iter_ranges()
return self._merge_ranges(left_iter, right_iter)
Expand All @@ -783,7 +788,9 @@ def iter_ranges_by_slice(
right_iter = self._right.iter_ranges_by_slice(slice_start, slice_stop)
return self._merge_ranges(left_iter, right_iter)

def _merge_ranges(self, left_iter, right_iter):
def _merge_ranges(self, left_iter: Iterator[Tuple[int, int, T]],
right_iter: Iterator[Tuple[int, int, T]]
) -> Iterator[Tuple[int, int, T]]:
(left_start, left_stop, left_value) = next(left_iter)
(right_start, right_stop, right_value) = next(right_iter)
try:
Expand All @@ -802,6 +809,11 @@ def _merge_ranges(self, left_iter, right_iter):
return

@overrides(AbstractList.get_default)
def get_default(self):
return self._operation(
self._left.get_default(), self._right.get_default())
def get_default(self) -> Optional[T]:
l_default = self._left.get_default()
if l_default is None:
return None
r_default = self._right.get_default()
if r_default is None:
return None
return self._operation(l_default, r_default)
8 changes: 5 additions & 3 deletions spinn_utilities/ranged/ids_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations
from typing import (
Dict, Generic, Iterable, Iterator, Optional, Sequence, Tuple,
overload, TYPE_CHECKING)
overload, TYPE_CHECKING, Union)
from spinn_utilities.overrides import overrides
from .abstract_dict import AbstractDict, _StrSeq, _Keys
from .abstract_list import IdsType
Expand Down Expand Up @@ -49,7 +49,7 @@ def get_value(self, key: Optional[_StrSeq]) -> Dict[str, T]:
...

@overrides(AbstractDict.get_value)
def get_value(self, key: _Keys):
def get_value(self, key: _Keys) -> Union[T, Dict[str, T]]:
if isinstance(key, str):
return self._range_dict.get_list(key).get_single_value_by_ids(
self._ids)
Expand Down Expand Up @@ -104,5 +104,7 @@ def iter_ranges(self, key: Optional[_StrSeq] = None) -> Iterator[Tuple[
...

@overrides(AbstractDict.iter_ranges)
def iter_ranges(self, key: _Keys = None):
def iter_ranges(self, key: _Keys = None
) -> Union[Iterator[Tuple[int, int, T]],
Iterator[Tuple[int, int, Dict[str, T]]]]:
return self._range_dict.iter_ranges_by_ids(key=key, ids=self._ids)
17 changes: 10 additions & 7 deletions spinn_utilities/ranged/range_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def get_value(self, key: str) -> T: ...
def get_value(self, key: Optional[_StrSeq]) -> Dict[str, T]: ...

@overrides(AbstractDict.get_value, extend_defaults=True)
def get_value(self, key: Union[str, None, _StrSeq] = None):
def get_value(self, key: Union[str, None, _StrSeq] = None
) -> Union[T, Dict[str, T]]:
if isinstance(key, str):
return self._value_lists[key].get_single_value_all()
if key is None:
Expand Down Expand Up @@ -227,7 +228,8 @@ def update_safe_iter_all_values(
ids: IdsType) -> Iterator[Dict[str, T]]: ...

def update_safe_iter_all_values(
self, key: Union[str, Optional[_StrSeq]], ids: IdsType):
self, key: Union[str, Optional[_StrSeq]],
ids: IdsType) -> Iterator[T]:
"""
Same as
:py:meth:`iter_all_values`
Expand All @@ -248,7 +250,7 @@ def iter_all_values(self, key: Optional[_StrSeq],
...

@overrides(AbstractDict.iter_all_values, extend_defaults=True)
def iter_all_values(self, key: _Keys, update_safe=False):
def iter_all_values(self, key: _Keys, update_safe: bool = False):
if isinstance(key, str):
if update_safe:
return self._value_lists[key].iter()
Expand Down Expand Up @@ -373,7 +375,7 @@ def keys(self) -> Iterable[str]:

def _merge_ranges(
self, range_iters: Dict[str, Iterator[Tuple[int, int, T]]]
) -> _CompoundRangeIter:
) -> Iterator[Tuple[int, int, Dict[str, T]]]:
current: Dict[str, T] = dict()
ranges: Dict[str, Tuple[int, int, T]] = dict()
start = 0
Expand Down Expand Up @@ -402,12 +404,13 @@ def _merge_ranges(
yield (start, stop, current)

@overrides(AbstractDict.iter_ranges)
def iter_ranges(self, key=None):
def iter_ranges(self, key: _Keys = None) -> \
Iterator[Tuple[int, int, Dict[str, T]]]:
if isinstance(key, str):
return self._value_lists[key].iter_ranges()
return self._value_lists[key].iter_ranges() # Iterator[Tuple[int, int, T]]
if key is None:
key = self.keys()
return self._merge_ranges({
return self._merge_ranges({ # Iterator[Tuple[int, int, Dict[str, T]]]
a_key: self._value_lists[a_key].iter_ranges()
for a_key in key})

Expand Down
3 changes: 2 additions & 1 deletion spinn_utilities/ranged/ranged_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def get_value_by_id(self, the_id: int) -> T:
return self.__the_values[the_id]

@overrides(AbstractList.get_single_value_by_slice)
def get_single_value_by_slice(self, slice_start: int, slice_stop: int):
def get_single_value_by_slice(
self, slice_start: int, slice_stop: int) -> T:
slice_start, slice_stop = self._check_slice_in_range(
slice_start, slice_stop)

Expand Down
8 changes: 5 additions & 3 deletions spinn_utilities/ranged/single_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations
from typing import (
Dict, Generic, Iterator, Optional, Sequence, Tuple, overload,
TYPE_CHECKING)
TYPE_CHECKING, Union)
from spinn_utilities.overrides import overrides
from .abstract_dict import AbstractDict, T, _StrSeq, _Keys
from .abstract_view import AbstractView
Expand Down Expand Up @@ -48,7 +48,7 @@ def get_value(self, key: Optional[_StrSeq]) -> Dict[str, T]:
...

@overrides(AbstractDict.get_value)
def get_value(self, key: _Keys):
def get_value(self, key: _Keys) -> Union[T, Dict[str, T]]:
if isinstance(key, str):
return self._range_dict.get_list(key).get_value_by_id(
the_id=self._id)
Expand Down Expand Up @@ -96,5 +96,7 @@ def iter_ranges(self, key: Optional[_StrSeq] = None) -> Iterator[
...

@overrides(AbstractDict.iter_ranges)
def iter_ranges(self, key: _Keys = None):
def iter_ranges(self, key: _Keys = None
) -> Union[Iterator[Tuple[int, int, T]],
Iterator[Tuple[int, int, Dict[str, T]]]]:
return self._range_dict.iter_ranges_by_id(key=key, the_id=self._id)
8 changes: 5 additions & 3 deletions spinn_utilities/ranged/slice_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations
from typing import (
Dict, Generic, Iterable, Iterator, Optional, Sequence, Tuple, overload,
TYPE_CHECKING)
TYPE_CHECKING, Union)
from spinn_utilities.overrides import overrides
from .abstract_dict import AbstractDict, T, _StrSeq, _Keys
from .abstract_view import AbstractView
Expand Down Expand Up @@ -50,7 +50,7 @@ def get_value(self, key: Optional[_StrSeq]) -> Dict[str, T]:
...

@overrides(AbstractDict.get_value)
def get_value(self, key: _Keys):
def get_value(self, key: _Keys) -> Union[T, Dict[str, T]]:
if isinstance(key, str):
return self._range_dict.get_list(key).get_single_value_by_slice(
slice_start=self._start, slice_stop=self._stop)
Expand Down Expand Up @@ -109,6 +109,8 @@ def iter_ranges(self, key: Optional[_StrSeq] = None) -> Iterator[
...

@overrides(AbstractDict.iter_ranges)
def iter_ranges(self, key: _Keys = None):
def iter_ranges(self, key: _Keys = None
) -> Union[Iterator[Tuple[int, int, T]],
Iterator[Tuple[int, int, Dict[str, T]]]]:
return self._range_dict.iter_ranges_by_slice(
key=key, slice_start=self._start, slice_stop=self._stop)

0 comments on commit a4cbe7f

Please sign in to comment.