diff --git a/pgproto.pyi b/pgproto.pyi new file mode 100644 index 0000000..390a92f --- /dev/null +++ b/pgproto.pyi @@ -0,0 +1,8 @@ +import codecs + +class CodecContext: + def get_text_codec(self) -> codecs.CodecInfo: ... + def is_encoding_utf8(self) -> bool: ... + +class ReadBuffer: ... +class WriteBuffer: ... diff --git a/types.py b/types.py index 9232ae0..3a56451 100644 --- a/types.py +++ b/types.py @@ -5,18 +5,29 @@ # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +import builtins +import sys +import typing +import typing_extensions + + __all__ = ( 'BitString', 'Point', 'Path', 'Polygon', 'Box', 'Line', 'LineSegment', 'Circle', ) +_BS = typing.TypeVar('_BS', bound='BitString') +_P = typing.TypeVar('_P', bound='Point') +_BitOrder = typing_extensions.Literal['big', 'little'] + class BitString: """Immutable representation of PostgreSQL `bit` and `varbit` types.""" __slots__ = '_bytes', '_bitlength' - def __init__(self, bitstring=None): + def __init__(self, + bitstring: typing.Optional[builtins.bytes] = None) -> None: if not bitstring: self._bytes = bytes() self._bitlength = 0 @@ -28,7 +39,7 @@ def __init__(self, bitstring=None): bit_pos = 0 for i, bit in enumerate(bitstring): - if bit == ' ': + if bit == ' ': # type: ignore continue bit = int(bit) if bit != 0 and bit != 1: @@ -53,14 +64,15 @@ def __init__(self, bitstring=None): self._bitlength = bitlen @classmethod - def frombytes(cls, bytes_=None, bitlength=None): - if bitlength is None and bytes_ is None: - bytes_ = bytes() - bitlength = 0 - - elif bitlength is None: - bitlength = len(bytes_) * 8 - + def frombytes(cls: typing.Type[_BS], + bytes_: typing.Optional[builtins.bytes] = None, + bitlength: typing.Optional[int] = None) -> _BS: + if bitlength is None: + if bytes_ is None: + bytes_ = bytes() + bitlength = 0 + else: + bitlength = len(bytes_) * 8 else: if bytes_ is None: bytes_ = bytes(bitlength // 8 + 1) @@ -87,10 +99,10 @@ def frombytes(cls, bytes_=None, bitlength=None): return result @property - def bytes(self): + def bytes(self) -> builtins.bytes: return self._bytes - def as_string(self): + def as_string(self) -> str: s = '' for i in range(self._bitlength): @@ -100,7 +112,8 @@ def as_string(self): return s.strip() - def to_int(self, bitorder='big', *, signed=False): + def to_int(self, bitorder: _BitOrder = 'big', + *, signed: bool = False) -> int: """Interpret the BitString as a Python int. Acts similarly to int.from_bytes. @@ -135,7 +148,8 @@ def to_int(self, bitorder='big', *, signed=False): return x @classmethod - def from_int(cls, x, length, bitorder='big', *, signed=False): + def from_int(cls: typing.Type[_BS], x: int, length: int, + bitorder: _BitOrder = 'big', *, signed: bool = False) -> _BS: """Represent the Python int x as a BitString. Acts similarly to int.to_bytes. @@ -187,27 +201,27 @@ def from_int(cls, x, length, bitorder='big', *, signed=False): bytes_ = x.to_bytes((length + 7) // 8, byteorder='big') return cls.frombytes(bytes_, length) - def __repr__(self): + def __repr__(self) -> str: return ''.format(self.as_string()) __str__ = __repr__ - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, BitString): return NotImplemented return (self._bytes == other._bytes and self._bitlength == other._bitlength) - def __hash__(self): + def __hash__(self) -> int: return hash((self._bytes, self._bitlength)) - def _getitem(self, i): + def _getitem(self, i: int) -> int: byte = self._bytes[i // 8] shift = 8 - i % 8 - 1 return (byte >> shift) & 0x1 - def __getitem__(self, i): + def __getitem__(self, i: int) -> int: if isinstance(i, slice): raise NotImplementedError('BitString does not support slices') @@ -216,19 +230,47 @@ def __getitem__(self, i): return self._getitem(i) - def __len__(self): + def __len__(self) -> int: return self._bitlength -class Point(tuple): +if typing.TYPE_CHECKING or sys.version_info >= (3, 6): + _PointBase = typing.Tuple[float, float] + _BoxBase = typing.Tuple['Point', 'Point'] + _LineBase = typing.Tuple[float, float, float] + _LineSegmentBase = typing.Tuple['Point', 'Point'] + _CircleBase = typing.Tuple['Point', float] +else: + # In Python 3.5, subclassing from typing.Tuple does not make the + # subclass act like a tuple in certain situations (like starred + # expressions) + _PointBase = tuple + _BoxBase = tuple + _LineBase = tuple + _LineSegmentBase = tuple + _CircleBase = tuple + + +class Point(_PointBase): """Immutable representation of PostgreSQL `point` type.""" __slots__ = () - def __new__(cls, x, y): - return super().__new__(cls, (float(x), float(y))) - - def __repr__(self): + def __new__(cls, + x: typing.Union[typing.SupportsFloat, + 'builtins._SupportsIndex', + typing.Text, + builtins.bytes, + builtins.bytearray], + y: typing.Union[typing.SupportsFloat, + 'builtins._SupportsIndex', + typing.Text, + builtins.bytes, + builtins.bytearray]) -> 'Point': + return super().__new__(cls, + typing.cast(typing.Any, (float(x), float(y)))) + + def __repr__(self) -> str: return '{}.{}({})'.format( type(self).__module__, type(self).__name__, @@ -236,23 +278,26 @@ def __repr__(self): ) @property - def x(self): + def x(self) -> float: return self[0] @property - def y(self): + def y(self) -> float: return self[1] -class Box(tuple): +class Box(_BoxBase): """Immutable representation of PostgreSQL `box` type.""" __slots__ = () - def __new__(cls, high, low): - return super().__new__(cls, (Point(*high), Point(*low))) + def __new__(cls, high: typing.Sequence[float], + low: typing.Sequence[float]) -> 'Box': + return super().__new__(cls, + typing.cast(typing.Any, (Point(*high), + Point(*low)))) - def __repr__(self): + def __repr__(self) -> str: return '{}.{}({})'.format( type(self).__module__, type(self).__name__, @@ -260,44 +305,47 @@ def __repr__(self): ) @property - def high(self): + def high(self) -> Point: return self[0] @property - def low(self): + def low(self) -> Point: return self[1] -class Line(tuple): +class Line(_LineBase): """Immutable representation of PostgreSQL `line` type.""" __slots__ = () - def __new__(cls, A, B, C): - return super().__new__(cls, (A, B, C)) + def __new__(cls, A: float, B: float, C: float) -> 'Line': + return super().__new__(cls, typing.cast(typing.Any, (A, B, C))) @property - def A(self): + def A(self) -> float: return self[0] @property - def B(self): + def B(self) -> float: return self[1] @property - def C(self): + def C(self) -> float: return self[2] -class LineSegment(tuple): +class LineSegment(_LineSegmentBase): """Immutable representation of PostgreSQL `lseg` type.""" __slots__ = () - def __new__(cls, p1, p2): - return super().__new__(cls, (Point(*p1), Point(*p2))) + def __new__(cls, p1: typing.Sequence[float], + p2: typing.Sequence[float]) -> 'LineSegment': + return super().__new__(cls, + typing.cast(typing.Any, (Point(*p1), + Point(*p2)))) - def __repr__(self): + def __repr__(self) -> str: return '{}.{}({})'.format( type(self).__module__, type(self).__name__, @@ -305,11 +353,11 @@ def __repr__(self): ) @property - def p1(self): + def p1(self) -> Point: return self[0] @property - def p2(self): + def p2(self) -> Point: return self[1] @@ -318,34 +366,44 @@ class Path: __slots__ = '_is_closed', 'points' - def __init__(self, *points, is_closed=False): + def __init__(self, *points: typing.Sequence[float], + is_closed: bool = False) -> None: self.points = tuple(Point(*p) for p in points) self._is_closed = is_closed @property - def is_closed(self): + def is_closed(self) -> bool: return self._is_closed - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, Path): return NotImplemented return (self.points == other.points and self._is_closed == other._is_closed) - def __hash__(self): + def __hash__(self) -> int: return hash((self.points, self.is_closed)) - def __iter__(self): + def __iter__(self) -> typing.Iterator[Point]: return iter(self.points) - def __len__(self): + def __len__(self) -> int: return len(self.points) - def __getitem__(self, i): + @typing.overload + def __getitem__(self, i: int) -> Point: + ... + + @typing.overload + def __getitem__(self, i: slice) -> typing.Tuple[Point, ...]: + ... + + def __getitem__(self, i: typing.Union[int, slice]) \ + -> typing.Union[Point, typing.Tuple[Point, ...]]: return self.points[i] - def __contains__(self, point): + def __contains__(self, point: object) -> bool: return point in self.points @@ -354,23 +412,23 @@ class Polygon(Path): __slots__ = () - def __init__(self, *points): + def __init__(self, *points: typing.Sequence[float]) -> None: # polygon is always closed super().__init__(*points, is_closed=True) -class Circle(tuple): +class Circle(_CircleBase): """Immutable representation of PostgreSQL `circle` type.""" __slots__ = () - def __new__(cls, center, radius): - return super().__new__(cls, (center, radius)) + def __new__(cls, center: Point, radius: float) -> 'Circle': + return super().__new__(cls, typing.cast(typing.Any, (center, radius))) @property - def center(self): + def center(self) -> Point: return self[0] @property - def radius(self): + def radius(self) -> float: return self[1]