Skip to content

Commit

Permalink
typing overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian-B committed Dec 4, 2023
1 parent 47c3f7e commit 099be16
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 30 deletions.
6 changes: 3 additions & 3 deletions spinn_utilities/ranged/abstract_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# Can't be Iterable[str] or Sequence[str] because that includes str itself
_StrSeq: TypeAlias = Union[
MutableSequence[str], Tuple[str, ...], FrozenSet[str], Set[str]]
_Keys: TypeAlias = Union[None, str, _StrSeq]
_Keys: TypeAlias = Optional[Union[str, _StrSeq]]


class AbstractDict(Generic[T], metaclass=AbstractBase):
Expand Down Expand Up @@ -117,7 +117,7 @@ def iter_all_values(self, key: Optional[_StrSeq],
...

@abstractmethod
def iter_all_values(self, key, update_safe=False):
def iter_all_values(self, key: _Keys, update_safe: bool = False):
"""
Iterates over the value(s) for all IDs covered by this view.
There will be one yield for each ID even if values are repeated.
Expand Down Expand Up @@ -181,7 +181,7 @@ def iter_ranges(self, key: Optional[_StrSeq]) -> Iterator[Tuple[
...

@abstractmethod
def iter_ranges(self, key=None):
def iter_ranges(self, key: _Keys = None):
"""
Iterates over the ranges(s) for all IDs covered by this view.
There will be one yield for each range which may cover one or
Expand Down
26 changes: 16 additions & 10 deletions spinn_utilities/ranged/abstract_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,20 +645,21 @@ def range_based(self):
return self._a_list.range_based()

@overrides(AbstractList.get_value_by_id)
def get_value_by_id(self, the_id):
def get_value_by_id(self, the_id: int) -> T:
return self._operation(self._a_list.get_value_by_id(the_id))

@overrides(AbstractList.get_single_value_by_slice)
def get_single_value_by_slice(self, slice_start, slice_stop):
def get_single_value_by_slice(
self, slice_start: int, slice_stop: int) -> T:
return self._operation(self._a_list.get_single_value_by_slice(
slice_start, slice_stop))

@overrides(AbstractList.get_single_value_by_ids)
def get_single_value_by_ids(self, ids):
def get_single_value_by_ids(self, ids: IdsType) -> T:
return self._operation(self._a_list.get_single_value_by_ids(ids))

@overrides(AbstractList.iter_ranges)
def iter_ranges(self):
def iter_ranges(self) -> Iterator[Tuple[int, int, T]]:
for (start, stop, value) in self._a_list.iter_ranges():
yield (start, stop, self._operation(value))

Expand All @@ -667,7 +668,9 @@ def get_default(self):
return self._operation(self._a_list.get_default())

@overrides(AbstractList.iter_ranges_by_slice)
def iter_ranges_by_slice(self, slice_start, slice_stop):
def iter_ranges_by_slice(
self, slice_start: int, slice_stop: int) -> Iterator[
Tuple[int, int, T]]:
for (start, stop, value) in \
self._a_list.iter_ranges_by_slice(slice_start, slice_stop):
yield (start, stop, self._operation(value))
Expand Down Expand Up @@ -704,25 +707,26 @@ def range_based(self):
return self._left.range_based() and self._right.range_based()

@overrides(AbstractList.get_value_by_id)
def get_value_by_id(self, the_id):
def get_value_by_id(self, the_id: int) -> T:
return self._operation(
self._left.get_value_by_id(the_id),
self._right.get_value_by_id(the_id))

@overrides(AbstractList.get_single_value_by_slice)
def get_single_value_by_slice(self, slice_start, slice_stop):
def get_single_value_by_slice(
self, slice_start: int, slice_stop: int) -> T:
return self._operation(
self._left.get_single_value_by_slice(slice_start, slice_stop),
self._right.get_single_value_by_slice(slice_start, slice_stop))

@overrides(AbstractList.get_single_value_by_ids)
def get_single_value_by_ids(self, ids):
def get_single_value_by_ids(self, ids: IdsType) -> T:
return self._operation(
self._left.get_single_value_by_ids(ids),
self._right.get_single_value_by_ids(ids))

@overrides(AbstractList.iter_by_slice)
def iter_by_slice(self, slice_start, slice_stop):
def iter_by_slice(self, slice_start: int, slice_stop: int) -> Iterator[T]:
slice_start, slice_stop = self._check_slice_in_range(
slice_start, slice_stop)
if self._left.range_based():
Expand Down Expand Up @@ -772,7 +776,9 @@ def iter_ranges(self):
return self._merge_ranges(left_iter, right_iter)

@overrides(AbstractList.iter_ranges_by_slice)
def iter_ranges_by_slice(self, slice_start, slice_stop):
def iter_ranges_by_slice(
self, slice_start: int, slice_stop: int) -> Iterator[
Tuple[int, int, T]]:
left_iter = self._left.iter_ranges_by_slice(slice_start, slice_stop)
right_iter = self._right.iter_ranges_by_slice(slice_start, slice_stop)
return self._merge_ranges(left_iter, right_iter)
Expand Down
7 changes: 4 additions & 3 deletions spinn_utilities/ranged/ids_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def get_value(self, key: _Keys):
for k in key}

@overrides(AbstractDict.set_value)
def set_value(self, key: str, value: T, use_list_as_value=False):
def set_value(
self, key: str, value: T, use_list_as_value: bool = False):
ranged_list = self._range_dict.get_list(key)
for _id in self._ids:
ranged_list.set_value_by_id(the_id=_id, value=value)
Expand All @@ -85,7 +86,7 @@ def iter_all_values(self, key: Optional[_StrSeq],
...

@overrides(AbstractDict.iter_all_values)
def iter_all_values(self, key: _Keys, update_safe=False):
def iter_all_values(self, key: _Keys, update_safe: bool = False):
if isinstance(key, str):
yield from self._range_dict.iter_values_by_ids(
ids=self._ids, key=key, update_safe=update_safe)
Expand All @@ -103,5 +104,5 @@ def iter_ranges(self, key: Optional[_StrSeq] = None) -> Iterator[Tuple[
...

@overrides(AbstractDict.iter_ranges)
def iter_ranges(self, key=None):
def iter_ranges(self, key: _Keys = None):
return self._range_dict.iter_ranges_by_ids(key=key, ids=self._ids)
4 changes: 3 additions & 1 deletion spinn_utilities/ranged/range_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from .abstract_view import AbstractView

_KeyType: TypeAlias = Union[int, slice, Iterable[int]]
_Keys: TypeAlias = Union[None, str, _StrSeq]

_Range: TypeAlias = Tuple[int, int, T]
_SimpleRangeIter: TypeAlias = Iterator[_Range]
_CompoundRangeIter: TypeAlias = Iterator[Tuple[int, int, Dict[str, T]]]
Expand Down Expand Up @@ -246,7 +248,7 @@ def iter_all_values(self, key: Optional[_StrSeq],
...

@overrides(AbstractDict.iter_all_values, extend_defaults=True)
def iter_all_values(self, key=None, update_safe: bool = False):
def iter_all_values(self, key: _Keys, update_safe=False):
if isinstance(key, str):
if update_safe:
return self._value_lists[key].iter()
Expand Down
6 changes: 3 additions & 3 deletions spinn_utilities/ranged/single_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ def iter_all_values(
...

@overrides(AbstractDict.iter_all_values)
def iter_all_values(self, key, update_safe=False):
def iter_all_values(self, key: _Keys, update_safe: bool = False):
if isinstance(key, str):
yield self._range_dict.get_list(key).get_value_by_id(
the_id=self._id)
else:
yield self._range_dict.get_values_by_id(key=key, the_id=self._id)

@overrides(AbstractDict.set_value)
def set_value(self, key: str, value: T, use_list_as_value=False):
def set_value(self, key: str, value: T, use_list_as_value: bool = False):
return self._range_dict.get_list(key).set_value_by_id(
value=value, the_id=self._id)

Expand All @@ -96,5 +96,5 @@ def iter_ranges(self, key: Optional[_StrSeq] = None) -> Iterator[
...

@overrides(AbstractDict.iter_ranges)
def iter_ranges(self, key=None):
def iter_ranges(self, key: _Keys = None):
return self._range_dict.iter_ranges_by_id(key=key, the_id=self._id)
6 changes: 3 additions & 3 deletions spinn_utilities/ranged/slice_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def iter_all_values(
...

@overrides(AbstractDict.iter_all_values, extend_defaults=True)
def iter_all_values(self, key=None, update_safe=False):
def iter_all_values(self, key: _Keys = None, update_safe: bool = False):
if isinstance(key, str):
if update_safe:
return self.update_safe_iter_all_values(key)
Expand All @@ -94,7 +94,7 @@ def iter_all_values(self, key=None, update_safe=False):

@overrides(AbstractDict.set_value)
def set_value(
self, key: str, value: _ValueType, use_list_as_value=False):
self, key: str, value: _ValueType, use_list_as_value: bool = False):
self._range_dict.get_list(key).set_value_by_slice(
slice_start=self._start, slice_stop=self._stop, value=value,
use_list_as_value=use_list_as_value)
Expand All @@ -109,6 +109,6 @@ def iter_ranges(self, key: Optional[_StrSeq] = None) -> Iterator[
...

@overrides(AbstractDict.iter_ranges)
def iter_ranges(self, key=None):
def iter_ranges(self, key: _Keys = None):
return self._range_dict.iter_ranges_by_slice(
key=key, slice_start=self._start, slice_stop=self._stop)
6 changes: 4 additions & 2 deletions unittests/abstract_base/abstract_has_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from spinn_utilities.abstract_base import (
AbstractBase, abstractproperty, abstractmethod)

Expand All @@ -23,7 +24,7 @@ class AbstractHasConstraints(object, metaclass=AbstractBase):
__slots__ = ()

@abstractmethod
def add_constraint(self, constraint):
def add_constraint(self, constraint: Any):
""" Add a new constraint to the collection of constraints
:param constraint: constraint to add
Expand All @@ -33,7 +34,8 @@ def add_constraint(self, constraint):
If the constraint is not valid
"""

@abstractproperty
@property
@abstractmethod
def constraints(self):
""" An iterable of constraints
Expand Down
3 changes: 2 additions & 1 deletion unittests/abstract_base/grandparent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from spinn_utilities.overrides import overrides
from .abstract_grandparent import AbstractGrandParent
from .abstract_has_constraints import AbstractHasConstraints
Expand All @@ -26,7 +27,7 @@ def set_label(selfself, label):
pass

@overrides(AbstractHasConstraints.add_constraint)
def add_constraint(self, constraint):
def add_constraint(self, constraint: Any):
raise NotImplementedError("We set our own constrainst")

@overrides(AbstractHasConstraints.constraints)
Expand Down
3 changes: 2 additions & 1 deletion unittests/abstract_base/no_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from spinn_utilities.overrides import overrides
from .abstract_grandparent import AbstractGrandParent
from .abstract_has_constraints import AbstractHasConstraints
Expand All @@ -23,7 +24,7 @@ def set_label(selfself, label):
pass

@overrides(AbstractHasConstraints.add_constraint)
def add_constraint(self, constraint):
def add_constraint(self, constraint: Any):
raise NotImplementedError("We set our own constraints")

@overrides(AbstractHasConstraints.constraints)
Expand Down
3 changes: 2 additions & 1 deletion unittests/abstract_base/unchecked_bad_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from spinn_utilities.overrides import overrides
from .abstract_grandparent import AbstractGrandParent
from .abstract_has_constraints import AbstractHasConstraints
Expand All @@ -25,7 +26,7 @@ def set_label(selfself, not_label):
pass

@overrides(AbstractHasConstraints.add_constraint)
def add_constraint(self, constraint):
def add_constraint(self, constraint: Any):
raise NotImplementedError("We set our own constrainst")

@overrides(AbstractHasConstraints.constraints)
Expand Down
8 changes: 6 additions & 2 deletions unittests/test_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
import logging
from typing import List, Optional, Tuple
from spinn_utilities.log import (
_BraceMessage, ConfiguredFilter, ConfiguredFormatter, FormatAdapter,
LogLevelTooHighException)
Expand Down Expand Up @@ -54,13 +56,15 @@ def __init__(self):
self.data = []

@overrides(LogStore.store_log)
def store_log(self, level, message, timestamp=None):
def store_log(self, level: int, message: str,
timestamp: Optional[datetime] = None):
if level == logging.CRITICAL:
1/0
self.data.append((level, message))

@overrides(LogStore.retreive_log_messages)
def retreive_log_messages(self, min_level=0):
def retreive_log_messages(
self, min_level: int = 0) -> List[Tuple[int, str]]:
result = []
for (level, message) in self.data:
if level >= min_level:
Expand Down

0 comments on commit 099be16

Please sign in to comment.