-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathuf_dict.py
130 lines (114 loc) · 4.55 KB
/
uf_dict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from collections import defaultdict
from stop_watch import StopWatch
"""
UnionFindDict is a dictionary-like structure for the lookup table of the logical core.
It is a dictionary of the form (label, tuple of objects) -> (tuple of objects)
which can be set using
    add(label, input, output)
and retrieved using
    get(label, input)
The "add" function cannot overwrite: it raises a KeyError if (label, input)
is already in the database with a different output (adding an identical
entry again is a no-op). The "get" function returns the stored output tuple,
or None if (label, input) is not in the database.
Moreover, the structure allows "gluing" of objects. The function
    glue(obj1, obj2)
makes the two objects equal from the perspective of the UnionFindDict.
The "glue" function returns the list of all pairs (a, b) that were glued,
given as representatives (a is the surviving representative). They include
the initial (obj1, obj2), possibly swapped, and other pairs glued due to
extensionality: two entries with the same label whose inputs become equal
after gluing must have equal outputs, so their outputs are glued too.
"""
class UnionFindDict:
    """Lookup table keyed by (label, args) whose objects can be glued together.

    Every stored key and value tuple uses root representatives only, so
    ``get`` and ``in`` work for any member of an equivalence class.
    """

    def __init__(self):
        self.data = dict()  # (label, args) -> vals; all tuples of root objects
        self.obj_to_root_d = dict()  # non-root obj -> its root (always one hop)
        self.obj_to_children = defaultdict(set)  # root -> non-root objs pointing at it
        # obj -> keys (label, args) of entries whose args or vals contain obj
        self.obj_to_keys = defaultdict(set)

    def obj_to_root(self, obj):
        """Return the representative of ``obj`` (``obj`` itself if never glued)."""
        return self.obj_to_root_d.get(obj, obj)

    def tup_to_root(self, tup):
        """Return ``tup`` as a tuple with every element replaced by its root."""
        return tuple(map(self.obj_to_root, tup))

    def _weight(self, obj):
        """Heuristic size of ``obj``'s class, used to choose the surviving root."""
        # .get avoids creating empty defaultdict entries just to measure them.
        return (len(self.obj_to_children.get(obj, ()))
                + len(self.obj_to_keys.get(obj, ())))

    def _data_add(self, label, args, vals):
        """Store ``(label, args) -> vals`` and index it under every object involved.

        Re-adding an identical entry is a no-op.

        Raises:
            KeyError: if ``(label, args)`` is already stored with different vals.
        """
        key = label, args
        if key in self.data:
            if self.data[key] == vals:
                return
            raise KeyError("key {} is already in the uf_dictionary".format(key))
        self.data[key] = vals
        for obj in args + vals:
            self.obj_to_keys[obj].add(key)

    def _data_remove(self, label, args):
        """Delete entry ``(label, args)``, drop it from the index, return its vals.

        Raises:
            KeyError: if the entry is not stored.
        """
        key = label, args
        vals = self.data.pop(key)
        for obj in args + vals:
            self.obj_to_keys[obj].discard(key)
        return vals

    def add(self, label, args, vals):
        """Insert ``(label, args) -> vals`` after mapping all objects to roots.

        Returns:
            The rooted ``(args, vals)`` pair that was stored.

        Raises:
            KeyError: if the rooted key already maps to different vals.
        """
        args = self.tup_to_root(args)
        vals = self.tup_to_root(vals)
        self._data_add(label, args, vals)
        return args, vals

    def is_equal(self, n1, n2):
        """Return True if ``n1`` and ``n2`` share a representative."""
        return self.obj_to_root(n1) == self.obj_to_root(n2)

    def glue(self, n1, n2):
        """Glue ``n1`` with ``n2``; return the merged root pairs (see multi_glue)."""
        return self.multi_glue((n1, n2))

    def multi_glue(self, *pairs):
        """Glue every ``(a, b)`` in ``pairs`` and propagate by extensionality.

        When rewriting an entry makes its key collide with an existing entry,
        the two vals tuples are glued component-wise as well.

        Returns:
            List of ``(root, absorbed)`` pairs in merge order; ``absorbed`` is
            the former root that now points at ``root``.
        """
        changed = []
        to_glue = list(pairs)
        while to_glue:
            n1, n2 = to_glue.pop()
            n1, n2 = self.obj_to_root(n1), self.obj_to_root(n2)
            if n1 == n2:
                continue
            # Keep the heavier class as root so fewer entries need rewriting.
            if self._weight(n1) < self._weight(n2):
                n1, n2 = n2, n1
            changed.append((n1, n2))
            self._merge_roots(n1, n2)
            self._rehome_entries(n2, to_glue)
        return changed

    def _merge_roots(self, root, absorbed):
        """Point ``absorbed`` and all of its children directly at ``root``."""
        moved = self.obj_to_children.pop(absorbed, set())
        for child in moved:
            self.obj_to_root_d[child] = root
        self.obj_to_root_d[absorbed] = root
        root_children = self.obj_to_children[root]
        root_children.update(moved)
        root_children.add(absorbed)

    def _rehome_entries(self, absorbed, to_glue):
        """Rewrite entries mentioning ``absorbed``; queue vals that must be glued."""
        # Snapshot: _data_remove mutates the set we are iterating over.
        for key in tuple(self.obj_to_keys.get(absorbed, ())):
            label, args = key
            vals = self._data_remove(label, args)
            args, vals = self.tup_to_root(args), self.tup_to_root(vals)
            ori_val = self.get(label, args)
            if ori_val is None:
                self._data_add(label, args, vals)
            else:
                # Same key now has two vals tuples: glue them component-wise.
                assert len(vals) == len(ori_val), (label, args, vals, ori_val)
                to_glue.extend(zip(vals, ori_val))
        # Every entry now uses the new root, so this index set is empty.
        self.obj_to_keys.pop(absorbed, None)

    def __contains__(self, key):
        """Return True if ``key == (label, args)`` is stored, up to gluing of args."""
        label, args = key
        return (label, self.tup_to_root(args)) in self.data

    def get(self, label, args):
        """Return the vals tuple stored for ``(label, args)``, or None."""
        return self.data.get((label, self.tup_to_root(args)), None)
if __name__ == "__main__":
    # Demo: build a small table, glue 1~4 and 2~5, and print the stored entries.
    table = UnionFindDict()
    demo_entries = (
        ("A", (1, 0), ()),
        ("B", (1, 0), (2,)),
        ("C", (1, 0), (2,)),
        ("D", (1, 2), (3,)),
        ("E", (3,), (4,)),
        ("F", (3,), (4,)),
        ("G", (3,), (5,)),
        ("H", (3,), (5,)),
    )
    for entry_label, entry_args, entry_vals in demo_entries:
        table.add(entry_label, entry_args, entry_vals)
    for left, right in ((1, 4), (2, 5)):
        table.glue(left, right)
    print(table.data)