-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcircuit_test.py
151 lines (131 loc) · 5.59 KB
/
circuit_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from circuit import Circuit, RuntimeCircuit
import time
def test_cleartext_add64():
    """Check the add64 Bristol circuit against 64-bit wrapping addition.

    Each operand is supplied as 64 bits, least-significant bit first, with
    ``x`` filling inputs 0-63 and ``y`` inputs 64-127. The circuit's output
    bits are likewise read least-significant bit first, and the resulting
    integer must equal ``(x + y) % 2**64``.
    """
    c = Circuit("bristol_circuits/add64.txt", ['V' for _ in range(128)])
    cases = [
        (1000, 2010), (100, 200), (111111, 23456), (2**32 - 1, 2**32 - 1),
        (2**60, 2**60 + 5), (2**63, 2**63 + 1), (2**64 - 25, 2**64 - 100),
        (2**64 - 1, 2**64 - 1),
    ]
    for x, y in cases:
        answer = (x + y) % (2**64)
        # LSB-first bit lists; identical to zero-padding bin() to 64 digits
        # and reversing, for operands below 2**64.
        x_bits = [(x >> i) & 1 for i in range(64)]
        y_bits = [(y >> i) & 1 for i in range(64)]
        out_bits = RuntimeCircuit(c, x_bits + y_bits).evaluate()
        # Reverse to MSB-first and parse as base 2 (no eval needed).
        result = int("".join(str(bit) for bit in reversed(out_bits)), 2)
        assert result == answer, "computed wrong value"
def test_cleartext_sub64():
    """Check the sub64 Bristol circuit against 64-bit wrapping subtraction.

    Each operand is supplied as 64 bits, least-significant bit first, with
    ``x`` filling inputs 0-63 and ``y`` inputs 64-127. The circuit's output
    bits are read least-significant bit first, and the resulting integer
    must equal ``(x - y) % 2**64`` (so negative differences wrap around).
    """
    c = Circuit("bristol_circuits/sub64.txt", ['V' for _ in range(128)])
    cases = [
        (1000, 2010), (2010, 1000), (2**32 - 1, 2**32 - 1),
        (2**60, 2**60 + 5), (2**63, 2**63 + 1), (2**64 - 25, 2**64 - 100),
        (2**64 - 1, 2**64 - 1),
    ]
    for x, y in cases:
        answer = (x - y) % (2**64)
        # LSB-first bit lists; identical to zero-padding bin() to 64 digits
        # and reversing, for operands below 2**64.
        x_bits = [(x >> i) & 1 for i in range(64)]
        y_bits = [(y >> i) & 1 for i in range(64)]
        out_bits = RuntimeCircuit(c, x_bits + y_bits).evaluate()
        # Reverse to MSB-first and parse as base 2 (no eval needed).
        result = int("".join(str(bit) for bit in reversed(out_bits)), 2)
        assert result == answer, "computed wrong value"
def test_cleartext_mul64mod():
    """Check the mul64mod Bristol circuit against 64-bit wrapping multiplication.

    Each operand is supplied as 64 bits, least-significant bit first, with
    ``x`` filling inputs 0-63 and ``y`` inputs 64-127. The circuit's output
    bits are read least-significant bit first, and the resulting integer
    must equal ``(x * y) % 2**64``.
    """
    c = Circuit("bristol_circuits/mul64mod.txt", ['V' for _ in range(128)])
    cases = [
        (100, 200), (111111, 23456), (2**30, 2**10), (2**63, 2**63 + 1),
        (2**64 - 1, 2**64 - 1),
    ]
    for x, y in cases:
        answer = (x * y) % (2**64)
        # LSB-first bit lists; identical to zero-padding bin() to 64 digits
        # and reversing, for operands below 2**64.
        x_bits = [(x >> i) & 1 for i in range(64)]
        y_bits = [(y >> i) & 1 for i in range(64)]
        out_bits = RuntimeCircuit(c, x_bits + y_bits).evaluate()
        # Reverse to MSB-first and parse as base 2 (no eval needed).
        result = int("".join(str(bit) for bit in reversed(out_bits)), 2)
        assert result == answer, "computed wrong value"
def test_cleartext_mul64():
    """Check the mul64 Bristol circuit yields the full 128-bit product ``x * y``.

    Each operand is supplied as 64 bits, least-significant bit first, with
    ``x`` filling inputs 0-63 and ``y`` inputs 64-127. The 128 output bits
    are interpreted as two 64-bit words, each least-significant bit first:
    outputs 0-63 are the HIGH word and outputs 64-127 the LOW word.
    """
    c = Circuit("bristol_circuits/mul64.txt", ['V' for _ in range(128)])
    cases = [
        (100, 200), (111111, 23456), (2**30, 2**10), (2**63, 2**63 + 1),
        (2**64 - 1, 2**64 - 1),
    ]
    for x, y in cases:
        answer = x * y
        # LSB-first bit lists; identical to zero-padding bin() to 64 digits
        # and reversing, for operands below 2**64.
        x_bits = [(x >> i) & 1 for i in range(64)]
        y_bits = [(y >> i) & 1 for i in range(64)]
        out_bits = RuntimeCircuit(c, x_bits + y_bits).evaluate()
        # The word split below is only meaningful for a 128-bit output.
        assert len(out_bits) == 128, f"expected 128 output bits, got {len(out_bits)}"
        # This circuit emits the high word before the low word, so the two
        # 64-bit halves are decoded separately and recombined.
        high = int("".join(str(bit) for bit in reversed(out_bits[:64])), 2)
        low = int("".join(str(bit) for bit in reversed(out_bits[64:])), 2)
        assert (high << 64) | low == answer, "computed wrong value"
def test_cleartext_lessthan32():
    """Check the lessthan32 Bristol circuit computes ``x < y``.

    Each operand is supplied as 32 bits, least-significant bit first, with
    ``x`` filling inputs 0-31 and ``y`` inputs 32-63. The first output bit
    must be 1 when ``x < y`` and 0 otherwise (all cases use non-negative
    operands below 2**32).
    """
    c = Circuit("bristol_circuits/lessthan32.txt", ['V' for _ in range(64)])
    cases = [
        (100, 200), (200, 100), (111111, 23456), (2**30, 2**10),
        (2**10, 2**30), (2**32 - 1, 2**32 - 1),
    ]
    for x, y in cases:
        answer = 1 if x < y else 0
        # LSB-first bit lists; identical to zero-padding bin() to 32 digits
        # and reversing, for operands below 2**32.
        x_bits = [(x >> i) & 1 for i in range(32)]
        y_bits = [(y >> i) & 1 for i in range(32)]
        out_bit = RuntimeCircuit(c, x_bits + y_bits).evaluate()
        assert out_bit[0] == answer, "computed wrong value"
def test_cleartext_dist32():
    """Check the dist32 Bristol circuit's point-in-circle predicate.

    The expected first output bit is 1 when
    ``(x - cx)**2 + (y - cy)**2 < rsq`` and 0 otherwise. The circuit input
    is 32 zero bits followed by ``x``, ``y``, ``cx``, ``cy`` and ``rsq``,
    each encoded as 32 bits least-significant bit first (192 bits total).
    """
    x = 4060000
    y = 7390000
    cx = 4063500
    cy = 7396000
    rsq = 64000000
    answer = 1 if ((x - cx)**2 + (y - cy)**2) < rsq else 0
    # NOTE(review): the role of these leading 32 zero bits is not visible
    # here — presumably an unused/reserved input of dist32.txt; confirm.
    inputs = [0] * 32
    for value in (x, y, cx, cy, rsq):
        # LSB-first; identical to zero-padding bin() to 32 digits and
        # reversing, for values below 2**32.
        inputs.extend((value >> i) & 1 for i in range(32))
    c = Circuit("bristol_circuits/dist32.txt", ['V' for _ in range(192)])
    out_bit = RuntimeCircuit(c, inputs).evaluate()
    assert out_bit[0] == answer, "computed wrong value"
def test_cleartext_unnormalized_subregion_10k():
    """Check the unnormalized_subregion_100_1 circuit and report its timings.

    Feeds 64 zero bits followed by 300 one bits (364 inputs) and expects the
    output bits, read least-significant bit first, to encode 300. Prints the
    circuit load time and the evaluation time in seconds.
    """
    # NOTE(review): the name says "10k" but the file loaded is
    # unnormalized_subregion_100_1.txt — confirm which is intended.
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    c = Circuit(
        "bristol_circuits/unnormalized_subregion_100_1.txt",
        ['V' for _ in range(364)],
    )
    print(f"circuit load time: {round(time.perf_counter()-start, 4)}")
    answer = 300
    # NOTE(review): the expected value equals the count of 1 bits supplied;
    # presumably the circuit counts set bits in this segment — confirm.
    inputs = [0] * 64 + [1] * 300
    start = time.perf_counter()
    out_bits = RuntimeCircuit(c, inputs).evaluate()
    # Reverse to MSB-first and parse as base 2 (no eval needed).
    result = int("".join(str(bit) for bit in reversed(out_bits)), 2)
    assert result == answer, "computed wrong value"
    print(f"PASS, time: {round(time.perf_counter()-start, 4)}")
if __name__ == "__main__":
    # Direct-run entry point: the basic arithmetic/comparison circuits first,
    # then the (slower) subregion circuit with its timing output.
    basic_tests = (
        test_cleartext_add64,
        test_cleartext_sub64,
        test_cleartext_mul64mod,
        test_cleartext_mul64,
        test_cleartext_lessthan32,
        test_cleartext_dist32,
    )
    print("--TEST BASIC CIRCUITS--")
    for run_test in basic_tests:
        run_test()
    print("PASS")

    print("--TEST SUBREGION CIRCUIT--")
    test_cleartext_unnormalized_subregion_10k()
    print("ALL TESTS PASSED")