From 4e4f966c583ee7ee6b929e2d55461ddbbd670370 Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Sat, 3 Feb 2024 17:05:31 +0530 Subject: [PATCH] Add type hints to types module and enable mypy --- setup.cfg | 2 +- trino/types.py | 23 +++++++++++++---------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7d391545..372d84b0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,5 +19,5 @@ ignore_missing_imports = true no_implicit_optional = true warn_unused_ignores = true -[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*,trino.types] +[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*] ignore_errors = true diff --git a/trino/types.py b/trino/types.py index 8a745f52..f5f2e32a 100644 --- a/trino/types.py +++ b/trino/types.py @@ -3,7 +3,7 @@ import abc from datetime import datetime, time, timedelta from decimal import Decimal -from typing import Any, Dict, Generic, List, TypeVar, Union +from typing import Any, Dict, Generic, List, Tuple, TypeVar, Union, cast from dateutil import tz @@ -26,7 +26,7 @@ def new_instance(self, value: PythonTemporalType, fraction: Decimal) -> Temporal def to_python_type(self) -> PythonTemporalType: pass - def round_to(self, precision: int) -> TemporalType: + def round_to(self, precision: int) -> TemporalType[PythonTemporalType]: """ Python datetime and time only support up to microsecond precision In case the supplied value exceeds the specified precision, @@ -34,7 +34,8 @@ def round_to(self, precision: int) -> TemporalType: """ precision = min(precision, MAX_PYTHON_TEMPORAL_PRECISION_POWER) remaining_fractional_seconds = self._remaining_fractional_seconds - digits = abs(remaining_fractional_seconds.as_tuple().exponent) + # exponent can return `n`, `N`, `F` too if the value is a NaN for example + digits = abs(remaining_fractional_seconds.as_tuple().exponent) # type: ignore if digits > precision: rounding_factor = POWERS_OF_TEN[precision] rounded = remaining_fractional_seconds.quantize(Decimal(1 / rounding_factor)) @@ -101,16 +102,18 @@ def new_instance(self, value: datetime, fraction: Decimal) -> TimestampWithTimeZ def normalize(self, value: datetime) -> datetime: if tz.datetime_ambiguous(value): - return self._whole_python_temporal_value.tzinfo.normalize(value) + # This appears to be dead code since tzinfo doesn't actually have a `normalize` method. + # TODO: Fix this or remove. + return self._whole_python_temporal_value.tzinfo.normalize(value) # type: ignore return value -class NamedRowTuple(tuple): +class NamedRowTuple(Tuple[Any, ...]): """Custom tuple class as namedtuple doesn't support missing or duplicate names""" - def __new__(cls, values, names: List[str], types: List[str]): - return super().__new__(cls, values) + def __new__(cls, values: List[Any], names: List[str], types: List[str]) -> NamedRowTuple: + return cast(NamedRowTuple, super().__new__(cls, values)) - def __init__(self, values, names: List[str], types: List[str]): + def __init__(self, values: List[Any], names: List[str], types: List[str]): self._names = names # With names and types users can retrieve the name and Trino data type of a row self.__annotations__ = dict() @@ -125,9 +128,9 @@ def __init__(self, values, names: List[str], types: List[str]): elements.append(repr(value)) self._repr = "(" + ", ".join(elements) + ")" - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if self._names.count(name): raise ValueError("Ambiguous row field reference: " + name) - def __repr__(self): + def __repr__(self) -> str: return self._repr