-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstep16.py
99 lines (81 loc) · 2.75 KB
/
step16.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
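# While walking the graph backwards, two or more creators can be waiting in
# `funcs` at the same time, so a single "already seen" check is not enough:
# processing them in the wrong order breaks the gradient computation.
# To handle this case, the concept of a "generation" is introduced.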
import numpy as np
def as_array(x):
    """Return ``x`` as a NumPy array if it is a scalar, otherwise ``x`` itself.

    Scalars (as judged by ``np.isscalar``) are wrapped with ``np.array``;
    every other object is handed back unchanged (same object, no copy).
    """
    return np.array(x) if np.isscalar(x) else x
class Variable:
    """A value in the computation graph together with its gradient.

    Instance fields (all assigned in ``__init__``):
        data: the wrapped value (``None`` or a ``numpy.ndarray``).
        grad: gradient accumulated by ``backward``; ``None`` until computed.
        creator: the function object that produced this variable, or ``None``.
        generations: depth marker used to order the backward pass; 0 for a
            variable without a creator.
    """

    def __init__(self, data):
        """Wrap ``data`` in a new variable.

        Args:
            data: ``None``, a ``numpy.ndarray``, or a scalar (as judged by
                ``np.isscalar``). A scalar is converted to a 0-d array.

        Raises:
            TypeError: if ``data`` is any other kind of object.
        """
        # Validate by type, never by truthiness: ``not data`` raises
        # ValueError for multi-element arrays and is True for valid values
        # such as ``np.array(0.0)``.
        if data is not None and not isinstance(data, np.ndarray):
            if not np.isscalar(data):
                raise TypeError("{} is not supported.".format(type(data)))
            data = np.array(data)
        self.data = data
        self.grad = None
        self.creator = None
        self.generations = 0

    def set_creator(self, func):
        """Record ``func`` as the creator and place this variable one
        generation after it (``func.generations`` must already be set)."""
        self.creator = func
        self.generations = func.generations + 1

    def cleargrad(self):
        """Reset the gradient so a new backward pass does not add to it."""
        self.grad = None

    def backward(self):
        """Propagate gradients from this variable back through its creators.

        If ``grad`` is still ``None`` it is seeded with ``np.ones_like(data)``.
        Pending functions are processed from the largest ``generations``
        value downwards; gradients reaching the same input are summed.
        """
        if self.grad is None:
            self.grad = np.ones_like(self.data)
        # A variable without a creator has nothing to propagate to.
        if self.creator is None:
            return

        funcs = []
        seen_set = set()

        def add_func(f):
            # Queue each function once; keep the list sorted ascending so
            # that ``pop()`` yields the highest generation first.
            if f not in seen_set:
                funcs.append(f)
                seen_set.add(f)
                funcs.sort(key=lambda x: x.generations)

        add_func(self.creator)
        while funcs:
            f = funcs.pop()
            gys = [output.grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)
            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Build a new object instead of ``+=`` so an array
                    # shared with another variable is not modified.
                    x.grad = x.grad + gx
                if x.creator is not None:
                    add_func(x.creator)
# Note: running the computation repeatedly makes gradients accumulate, so a clear function (cleargrad) was added.
class Function:
    """Base class for operations applied to ``Variable`` objects.

    Subclasses implement ``forward`` (raw values in, raw values out) and
    ``backward`` (output gradients in, input gradients out).
    """

    def __call__(self, *inputs):
        """Apply the operation and link the results into the graph.

        Args:
            *inputs: objects exposing ``data`` and ``generations``.

        Returns:
            A single ``Variable`` when ``forward`` yields one value,
            otherwise a list of ``Variable`` objects.
        """
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
            ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]

        # With several inputs, take the largest generation among them.
        # This has to be assigned BEFORE set_creator() is called, because
        # set_creator() reads ``self.generations``; doing it afterwards
        # raises AttributeError on the first call.
        self.generations = max(x.generations for x in inputs)
        for output in outputs:
            output.set_creator(self)

        self.inputs = inputs
        self.outputs = outputs
        return outputs if len(outputs) > 1 else outputs[0]

    def forward(self, xs):
        """Compute the output value(s); must be overridden."""
        raise NotImplementedError

    def backward(self, gy):
        """Compute the input gradient(s); must be overridden."""
        raise NotImplementedError
class Add(Function):
    """Function that adds its two inputs."""

    def forward(self, x0, x1):
        """Return ``x0 + x1``."""
        return x0 + x1

    def backward(self, gy):
        """Hand the incoming gradient to both inputs unchanged."""
        return (gy, gy)
class Square(Function):
    """Function that raises its input to the second power."""

    def forward(self, x):
        """Return ``x ** 2``."""
        return x ** 2

    def backward(self, gy):
        """Return ``2 * x * gy``, where ``x`` is the data of the first
        recorded input."""
        return 2 * self.inputs[0].data * gy
def square(x):
    """Apply a fresh ``Square`` function object to ``x`` and return the result."""
    op = Square()
    return op(x)
# Use the generation as a priority: the larger the generation, the higher the priority, so that function is called first.