-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathndimgraph.py
142 lines (109 loc) · 3.99 KB
/
ndimgraph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from __future__ import annotations
from typing import Optional, Union, Tuple, List, Sequence, Deque, Literal
import numpy as np
class Point:
    """A point in n-dimensional space with (int-typed) coordinates.

    Coordinates are stored as a tuple. Points support element-wise ``+`` and
    ``-``, negation, multiplication by a scalar, hashing, and lexicographic
    ordering (the ordering of the underlying coordinate tuples). Arithmetic
    and ordering require both operands to have the same number of coordinates.
    """

    def __init__(self, coords: Sequence[int]):
        self._p = tuple(coords)
        # Dimensionality of the point; ``len(point)`` returns the same value.
        self.length = len(self._p)

    def _from_coords(self, coords: Sequence[int]) -> Point:
        """Build the result of an arithmetic operation from ``coords``.

        Subclasses that keep Point's constructor get results of their own
        type. Subclasses whose constructor needs extra arguments (such as
        ``Vertex``, which also takes a bag) cannot be rebuilt from coordinates
        alone, so they get a plain ``Point`` instead of a TypeError. Override
        this method to customise the behaviour.
        """
        cls = type(self)
        if cls.__init__ is Point.__init__:
            return cls(coords)
        return Point(coords)

    def _check_same_length(self, other: Point) -> None:
        """Assert that ``other`` has the same dimensionality as ``self``."""
        assert self.length == other.length, (
            f"dimension mismatch: {self.length} != {other.length}"
        )

    def as_tuple(self) -> Tuple[int, ...]:
        """Return the coordinates as a tuple."""
        return self._p

    def __repr__(self) -> str:
        return f"Node{str(self._p)}"

    def __len__(self) -> int:
        return self.length

    def __iter__(self):
        return iter(self._p)

    def __hash__(self) -> int:
        # Must agree with __eq__, which compares the coordinate tuples.
        return hash(self._p)

    def __neg__(self) -> Point:
        return self._from_coords([-x for x in self._p])

    def __add__(self, other: Point) -> Point:
        self._check_same_length(other)
        return self._from_coords([a + b for a, b in zip(self._p, other._p)])

    def __sub__(self, other: Point) -> Point:
        # Computed directly rather than as ``self + (-other)`` to avoid
        # building an intermediate negated point.
        self._check_same_length(other)
        return self._from_coords([a - b for a, b in zip(self._p, other._p)])

    def __rmul__(self, other: int) -> Point:
        return self._from_coords([other * x for x in self._p])

    def __mul__(self, other: int) -> Point:
        return self._from_coords([other * x for x in self._p])

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented lets Python fall back to its default
        # handling (e.g. ``p == None`` is False) instead of raising.
        if not isinstance(other, Point):
            return NotImplemented
        # Points of different dimensionality are simply unequal.
        return self._p == other._p

    def __ne__(self, other: object) -> bool:
        if not isinstance(other, Point):
            return NotImplemented
        return self._p != other._p

    def __gt__(self, other: Point) -> bool:
        if not isinstance(other, Point):
            return NotImplemented
        self._check_same_length(other)
        return self._p > other._p

    def __ge__(self, other: Point) -> bool:
        if not isinstance(other, Point):
            return NotImplemented
        self._check_same_length(other)
        return self._p >= other._p

    def __lt__(self, other: Point) -> bool:
        if not isinstance(other, Point):
            return NotImplemented
        self._check_same_length(other)
        return self._p < other._p

    def __le__(self, other: Point) -> bool:
        if not isinstance(other, Point):
            return NotImplemented
        self._check_same_length(other)
        return self._p <= other._p
class Vertex(Point):
    """A :class:`Point` that belongs to a :class:`VertexBag`.

    Vertices are obtained from the bag (``bag[coords]``), which returns one
    shared instance per coordinate tuple; ``step`` also goes through the bag,
    so neighbouring vertices are shared instances too.
    """

    def __init__(self, bag: VertexBag, coords: Sequence[int]):
        super().__init__(coords)
        # Owning bag; used to look up / create neighbouring vertices.
        self._bag = bag

    def step(self, direction: int):
        """Return the neighbouring vertex one unit away along ``direction``.

        ``direction`` is split with ``divmod(direction, n_dims)``: the
        remainder selects the axis, and an odd quotient means a step of -1
        (an even quotient means +1). Hence ``0 .. n_dims - 1`` are the
        positive steps and ``n_dims .. 2 * n_dims - 1`` the negative ones.
        """
        polarity, axis = divmod(direction, self._bag.n_dims)
        coords = list(self)
        coords[axis] += -1 if polarity % 2 else 1
        return self._bag[coords]

    def get_xy(self, edge_length: float = 1.0, form: Union[Literal["xy"], Literal["xy1"]] = "xy"):
        """Project this vertex onto the plane.

        Axis ``i`` (of ``self.length`` axes) is drawn as a vector of length
        ``edge_length`` at angle ``pi * i / self.length``; the projected
        position is the coordinate-weighted sum of those vectors.

        Args:
            edge_length: Drawn length of a unit step along any axis.
            form: ``"xy"`` for ``[x, y]`` or ``"xy1"`` for the homogeneous
                ``[x, y, 1]``.

        Returns:
            A NumPy array with 2 (``"xy"``) or 3 (``"xy1"``) entries.

        Raises:
            ValueError: If ``form`` is neither ``"xy"`` nor ``"xy1"``
                (consistent with ``VertexBag.get_xy``).
        """
        if form not in ("xy", "xy1"):
            raise ValueError(f"Unknown form: '{form}'")
        angles = [np.pi / self.length * i for i in range(self.length)]
        x = edge_length * sum([p * np.cos(a) for p, a in zip(self._p, angles)])
        y = edge_length * sum([p * np.sin(a) for p, a in zip(self._p, angles)])
        if form == "xy1":
            return np.asarray([x, y, 1])
        return np.asarray([x, y])
class VertexBag:
    """Interning container for the vertices of one ``n_dims``-dimensional lattice.

    Indexing with coordinates returns the vertex previously created for them,
    or creates, stores and returns a new one, so each coordinate tuple maps
    to a single shared vertex object.
    """

    def __init__(self, n_dims: int):
        # Coordinate tuple -> Vertex. Dict insertion order is the row order
        # of as_nparray() and get_xy().
        self._dict = dict()
        self.n_dims = n_dims

    def __getitem__(self, item):
        """Return the vertex at coordinates ``item``, creating it if needed.

        ``item`` may be any iterable of ``n_dims`` coordinates (tuple, list,
        Point/Vertex, ...).
        """
        # Normalise to a tuple first so any iterable (including Point/Vertex
        # instances) can be measured and used as a dict key.
        key = tuple(item)
        assert len(key) == self.n_dims
        vertex = self._dict.get(key)
        if vertex is None:
            vertex = Vertex(self, key)
            self._dict[key] = vertex
        return vertex

    def keys(self):
        """Return a view of the stored coordinate tuples."""
        return self._dict.keys()

    def as_nparray(self):
        """Return the stored coordinates as an array of shape ``(k, n_dims)``.

        Rows follow insertion order. An empty bag yields a ``(0, n_dims)``
        array instead of raising.
        """
        # np.stack needs a real sequence (recent NumPy rejects dict views),
        # so materialise the keys into a list first.
        coords = list(self._dict)
        if not coords:
            return np.empty((0, self.n_dims), dtype=int)
        return np.stack(coords)

    def get_xy(self, edge_length: float = 1.0, form: Union[Literal["xy"], Literal["xy1"]] = "xy"):
        """Project every stored vertex onto the plane.

        Axis ``i`` is drawn as a vector of length ``edge_length`` at angle
        ``pi * i / n_dims``; each vertex's position is the coordinate-weighted
        sum of those vectors.

        Args:
            edge_length: Drawn length of a unit step along any axis.
            form: ``"xy"`` for ``(k, 2)`` positions, ``"xy1"`` for ``(k, 3)``
                homogeneous positions with a trailing column of ones.

        Returns:
            A NumPy array with one row per vertex, in insertion order.

        Raises:
            ValueError: If ``form`` is neither ``"xy"`` nor ``"xy1"``.
        """
        if form not in ("xy", "xy1"):
            raise ValueError(f"Unknown form: '{form}'")
        transmat = np.asarray(
            [(np.cos(np.pi / self.n_dims * i), np.sin(np.pi / self.n_dims * i)) for i in range(self.n_dims)]
        ) * edge_length
        xy = np.matmul(self.as_nparray(), transmat)
        if form == "xy1":
            # Append a column of ones (homogeneous coordinates): the rows keep
            # their count, so stack horizontally, not vertically.
            return np.hstack([xy, np.ones((xy.shape[0], 1))])
        return xy
class EdgeBag:
    """A set of undirected edges.

    Every edge is stored exactly once as an ordered pair whose first element
    is not greater than the second (by the endpoints' ``>`` comparison), so
    ``(a, b)`` and ``(b, a)`` name the same edge.
    """

    def __init__(self):
        self._set = set()

    @staticmethod
    def _key(a, b):
        """Return the canonical storage pair for the edge between a and b."""
        return (b, a) if a > b else (a, b)

    def exists(self, a: Vertex, b: Vertex):
        """Return whether the edge between ``a`` and ``b`` is stored."""
        return self._key(a, b) in self._set

    def add(self, a: Vertex, b: Vertex):
        """Store the edge between ``a`` and ``b``; re-adding is a no-op."""
        self._set.add(self._key(a, b))

    def remove(self, a: Vertex, b: Vertex):
        """Delete the edge between ``a`` and ``b``; KeyError if it is absent."""
        self._set.remove(self._key(a, b))