Skip to content

Commit

Permalink
Add typings
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanforbes committed May 18, 2020
1 parent 7609144 commit c6cdeac
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 59 deletions.
8 changes: 8 additions & 0 deletions pgproto.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import codecs

class CodecContext:
def get_text_codec(self) -> codecs.CodecInfo: ...
def is_encoding_utf8(self) -> bool: ...

class ReadBuffer: ...
class WriteBuffer: ...
176 changes: 117 additions & 59 deletions types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,29 @@
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import builtins
import sys
import typing
import typing_extensions


__all__ = (
'BitString', 'Point', 'Path', 'Polygon',
'Box', 'Line', 'LineSegment', 'Circle',
)

_BS = typing.TypeVar('_BS', bound='BitString')
_P = typing.TypeVar('_P', bound='Point')
_BitOrder = typing_extensions.Literal['big', 'little']


class BitString:
"""Immutable representation of PostgreSQL `bit` and `varbit` types."""

__slots__ = '_bytes', '_bitlength'

def __init__(self, bitstring=None):
def __init__(self,
bitstring: typing.Optional[builtins.bytes] = None) -> None:
if not bitstring:
self._bytes = bytes()
self._bitlength = 0
Expand All @@ -28,7 +39,7 @@ def __init__(self, bitstring=None):
bit_pos = 0

for i, bit in enumerate(bitstring):
if bit == ' ':
if bit == ' ': # type: ignore
continue
bit = int(bit)
if bit != 0 and bit != 1:
Expand All @@ -53,14 +64,15 @@ def __init__(self, bitstring=None):
self._bitlength = bitlen

@classmethod
def frombytes(cls, bytes_=None, bitlength=None):
if bitlength is None and bytes_ is None:
bytes_ = bytes()
bitlength = 0

elif bitlength is None:
bitlength = len(bytes_) * 8

def frombytes(cls: typing.Type[_BS],
bytes_: typing.Optional[builtins.bytes] = None,
bitlength: typing.Optional[int] = None) -> _BS:
if bitlength is None:
if bytes_ is None:
bytes_ = bytes()
bitlength = 0
else:
bitlength = len(bytes_) * 8
else:
if bytes_ is None:
bytes_ = bytes(bitlength // 8 + 1)
Expand All @@ -87,10 +99,10 @@ def frombytes(cls, bytes_=None, bitlength=None):
return result

@property
def bytes(self):
def bytes(self) -> builtins.bytes:
return self._bytes

def as_string(self):
def as_string(self) -> str:
s = ''

for i in range(self._bitlength):
Expand All @@ -100,7 +112,8 @@ def as_string(self):

return s.strip()

def to_int(self, bitorder='big', *, signed=False):
def to_int(self, bitorder: _BitOrder = 'big',
*, signed: bool = False) -> int:
"""Interpret the BitString as a Python int.
Acts similarly to int.from_bytes.
Expand Down Expand Up @@ -135,7 +148,8 @@ def to_int(self, bitorder='big', *, signed=False):
return x

@classmethod
def from_int(cls, x, length, bitorder='big', *, signed=False):
def from_int(cls: typing.Type[_BS], x: int, length: int,
bitorder: _BitOrder = 'big', *, signed: bool = False) -> _BS:
"""Represent the Python int x as a BitString.
Acts similarly to int.to_bytes.
Expand Down Expand Up @@ -187,27 +201,27 @@ def from_int(cls, x, length, bitorder='big', *, signed=False):
bytes_ = x.to_bytes((length + 7) // 8, byteorder='big')
return cls.frombytes(bytes_, length)

def __repr__(self):
def __repr__(self) -> str:
return '<BitString {}>'.format(self.as_string())

__str__ = __repr__

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, BitString):
return NotImplemented

return (self._bytes == other._bytes and
self._bitlength == other._bitlength)

def __hash__(self):
def __hash__(self) -> int:
return hash((self._bytes, self._bitlength))

def _getitem(self, i):
def _getitem(self, i: int) -> int:
byte = self._bytes[i // 8]
shift = 8 - i % 8 - 1
return (byte >> shift) & 0x1

def __getitem__(self, i):
def __getitem__(self, i: int) -> int:
if isinstance(i, slice):
raise NotImplementedError('BitString does not support slices')

Expand All @@ -216,100 +230,134 @@ def __getitem__(self, i):

return self._getitem(i)

def __len__(self):
def __len__(self) -> int:
return self._bitlength


class Point(tuple):
if typing.TYPE_CHECKING or sys.version_info >= (3, 6):
_PointBase = typing.Tuple[float, float]
_BoxBase = typing.Tuple['Point', 'Point']
_LineBase = typing.Tuple[float, float, float]
_LineSegmentBase = typing.Tuple['Point', 'Point']
_CircleBase = typing.Tuple['Point', float]
else:
# In Python 3.5, subclassing from typing.Tuple does not make the
# subclass act like a tuple in certain situations (like starred
# expressions)
_PointBase = tuple
_BoxBase = tuple
_LineBase = tuple
_LineSegmentBase = tuple
_CircleBase = tuple


class Point(_PointBase):
"""Immutable representation of PostgreSQL `point` type."""

__slots__ = ()

def __new__(cls, x, y):
return super().__new__(cls, (float(x), float(y)))

def __repr__(self):
def __new__(cls,
x: typing.Union[typing.SupportsFloat,
'builtins._SupportsIndex',
typing.Text,
builtins.bytes,
builtins.bytearray],
y: typing.Union[typing.SupportsFloat,
'builtins._SupportsIndex',
typing.Text,
builtins.bytes,
builtins.bytearray]) -> 'Point':
return super().__new__(cls,
typing.cast(typing.Any, (float(x), float(y))))

def __repr__(self) -> str:
return '{}.{}({})'.format(
type(self).__module__,
type(self).__name__,
tuple.__repr__(self)
)

@property
def x(self):
def x(self) -> float:
return self[0]

@property
def y(self):
def y(self) -> float:
return self[1]


class Box(tuple):
class Box(_BoxBase):
"""Immutable representation of PostgreSQL `box` type."""

__slots__ = ()

def __new__(cls, high, low):
return super().__new__(cls, (Point(*high), Point(*low)))
def __new__(cls, high: typing.Sequence[float],
low: typing.Sequence[float]) -> 'Box':
return super().__new__(cls,
typing.cast(typing.Any, (Point(*high),
Point(*low))))

def __repr__(self):
def __repr__(self) -> str:
return '{}.{}({})'.format(
type(self).__module__,
type(self).__name__,
tuple.__repr__(self)
)

@property
def high(self):
def high(self) -> Point:
return self[0]

@property
def low(self):
def low(self) -> Point:
return self[1]


class Line(tuple):
class Line(_LineBase):
"""Immutable representation of PostgreSQL `line` type."""

__slots__ = ()

def __new__(cls, A, B, C):
return super().__new__(cls, (A, B, C))
def __new__(cls, A: float, B: float, C: float) -> 'Line':
return super().__new__(cls, typing.cast(typing.Any, (A, B, C)))

@property
def A(self):
def A(self) -> float:
return self[0]

@property
def B(self):
def B(self) -> float:
return self[1]

@property
def C(self):
def C(self) -> float:
return self[2]


class LineSegment(tuple):
class LineSegment(_LineSegmentBase):
"""Immutable representation of PostgreSQL `lseg` type."""

__slots__ = ()

def __new__(cls, p1, p2):
return super().__new__(cls, (Point(*p1), Point(*p2)))
def __new__(cls, p1: typing.Sequence[float],
p2: typing.Sequence[float]) -> 'LineSegment':
return super().__new__(cls,
typing.cast(typing.Any, (Point(*p1),
Point(*p2))))

def __repr__(self):
def __repr__(self) -> str:
return '{}.{}({})'.format(
type(self).__module__,
type(self).__name__,
tuple.__repr__(self)
)

@property
def p1(self):
def p1(self) -> Point:
return self[0]

@property
def p2(self):
def p2(self) -> Point:
return self[1]


Expand All @@ -318,34 +366,44 @@ class Path:

__slots__ = '_is_closed', 'points'

def __init__(self, *points, is_closed=False):
def __init__(self, *points: typing.Sequence[float],
is_closed: bool = False) -> None:
self.points = tuple(Point(*p) for p in points)
self._is_closed = is_closed

@property
def is_closed(self):
def is_closed(self) -> bool:
return self._is_closed

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, Path):
return NotImplemented

return (self.points == other.points and
self._is_closed == other._is_closed)

def __hash__(self):
def __hash__(self) -> int:
return hash((self.points, self.is_closed))

def __iter__(self):
def __iter__(self) -> typing.Iterator[Point]:
return iter(self.points)

def __len__(self):
def __len__(self) -> int:
return len(self.points)

def __getitem__(self, i):
@typing.overload
def __getitem__(self, i: int) -> Point:
...

@typing.overload
def __getitem__(self, i: slice) -> typing.Tuple[Point, ...]:
...

def __getitem__(self, i: typing.Union[int, slice]) \
-> typing.Union[Point, typing.Tuple[Point, ...]]:
return self.points[i]

def __contains__(self, point):
def __contains__(self, point: object) -> bool:
return point in self.points


Expand All @@ -354,23 +412,23 @@ class Polygon(Path):

__slots__ = ()

def __init__(self, *points):
def __init__(self, *points: typing.Sequence[float]) -> None:
# polygon is always closed
super().__init__(*points, is_closed=True)


class Circle(tuple):
class Circle(_CircleBase):
"""Immutable representation of PostgreSQL `circle` type."""

__slots__ = ()

def __new__(cls, center, radius):
return super().__new__(cls, (center, radius))
def __new__(cls, center: Point, radius: float) -> 'Circle':
return super().__new__(cls, typing.cast(typing.Any, (center, radius)))

@property
def center(self):
def center(self) -> Point:
return self[0]

@property
def radius(self):
def radius(self) -> float:
return self[1]

0 comments on commit c6cdeac

Please sign in to comment.