forked from rasmusbergpalm/hebbian-evolution
hebbian_layer.py
import torch as t
import shapeguard  # noqa: F401  (imported for its Tensor.sg() shape-assertion side effect)


class HebbianLayer:
    """A fully connected layer whose weights are updated by a Hebbian ABCD rule instead of backprop."""

    def __init__(self, hebb_coeff: t.Tensor, activation_fn, weights: t.Tensor = None, normalize=False):
        # hebb_coeff holds one (eta, A, B, C, D) coefficient vector per synapse: shape (n_in, n_out, 5).
        self.n_in, self.n_out, _ = hebb_coeff.shape
        if weights is not None:
            weights.sg((self.n_in, self.n_out))
        else:
            # Random initial weights, uniform in [-0.1, 0.1); no gradients are needed.
            weights = 0.2 * t.rand((self.n_in, self.n_out), requires_grad=False) - 0.1
        self.W = weights
        self.h = hebb_coeff
        self.normalize = normalize
        self.activation_fn = activation_fn

    def get_weights(self):
        # Adding 0 returns a copy, so callers cannot mutate the internal weights.
        return self.W + 0

    def forward(self, pre):
        pre.sg((self.n_in,))
        post = self.activation_fn(pre @ self.W).sg((self.n_out,))
        # The Hebbian update is applied as a side effect of every forward pass.
        self.update(pre, post)
        return post

    def update(self, pre, post):
        pre.sg((self.n_in,))
        post.sg((self.n_out,))
        # Split the per-synapse coefficients: learning rate eta and the ABCD terms.
        eta, A, B, C, D = [v.squeeze().sg((self.n_in, self.n_out)) for v in self.h.split(1, -1)]
        # Generalised Hebbian ABCD rule: dW = eta * (A*pre*post + B*pre + C*post + D).
        self.W += eta * (
            A * (pre[:, None] @ post[None, :]).sg((self.n_in, self.n_out)) +
            (B * pre[:, None]).sg((self.n_in, self.n_out)) +
            (C * post[None, :]).sg((self.n_in, self.n_out)) +
            D
        )
        if self.normalize:
            # Keep weights bounded by rescaling to a maximum absolute value of 1.
            self.W = self.W / self.W.abs().max()
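

# --- Usage sketch (not part of the original file) ---
# A minimal, hedged example of driving the layer, assuming the `shapeguard`
# package is installed. The layer sizes, the tanh activation, and the random
# coefficient scale are illustrative choices, not values taken from the repository.
if __name__ == "__main__":
    n_in, n_out = 4, 3
    hebb_coeff = 0.1 * t.randn((n_in, n_out, 5))  # one (eta, A, B, C, D) vector per synapse
    layer = HebbianLayer(hebb_coeff, t.tanh, normalize=True)

    pre = t.randn((n_in,))
    post = layer.forward(pre)  # the forward pass also applies the Hebbian weight update
    print("post-activation:", post)
    print("updated weights:", layer.get_weights())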