import torch
from torch.nn import Parameter
from torch_geometric.nn import ChebConv
from torch_geometric.nn.inits import glorot, zeros

from graphgym.register import register_layer


class GConvLSTM(torch.nn.Module):
    r"""An implementation of the Chebyshev Graph Convolutional Long Short-Term
    Memory Cell. For details see this paper: `"Structured Sequence Modeling
    with Graph Convolutional Recurrent Networks."
    <https://arxiv.org/abs/1612.07659>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        K (int): Chebyshev filter size :math:`K`.
        normalization (str, optional): The normalization scheme for the graph
            Laplacian (default: :obj:`"sym"`):

            1. :obj:`None`: No normalization
               :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`

            2. :obj:`"sym"`: Symmetric normalization
               :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
               \mathbf{D}^{-1/2}`

            3. :obj:`"rw"`: Random-walk normalization
               :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`

            You need to pass :obj:`lambda_max` to the :meth:`forward` method of
            this operator in case the normalization is non-symmetric.
            :obj:`lambda_max` should be a :class:`torch.Tensor` of size
            :obj:`[num_graphs]` in a mini-batch scenario and a
            scalar/zero-dimensional tensor when operating on single graphs.
            You can pre-compute :obj:`lambda_max` via the
            :class:`torch_geometric.transforms.LaplacianLambdaMax` transform.
        id (int): Index of this layer's hidden and cell state in
            :obj:`batch.node_states` / :obj:`batch.node_cells`
            (required, must be non-negative).
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
    """
    def __init__(self, in_channels: int, out_channels: int, K: int = 7,
                 normalization: str = "sym", id: int = -1, bias: bool = True):
        super(GConvLSTM, self).__init__()
        assert id >= 0, 'kwarg id is required.'
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.normalization = normalization
        self.bias = bias
        self._create_parameters_and_layers()
        self._set_parameters()
        self.id = id

    def _create_input_gate_parameters_and_layers(self):
        self.conv_x_i = ChebConv(in_channels=self.in_channels,
                                 out_channels=self.out_channels,
                                 K=self.K,
                                 normalization=self.normalization,
                                 bias=self.bias)

        self.conv_h_i = ChebConv(in_channels=self.out_channels,
                                 out_channels=self.out_channels,
                                 K=self.K,
                                 normalization=self.normalization,
                                 bias=self.bias)

        self.w_c_i = Parameter(torch.Tensor(1, self.out_channels))
        self.b_i = Parameter(torch.Tensor(1, self.out_channels))

    def _create_forget_gate_parameters_and_layers(self):
        self.conv_x_f = ChebConv(in_channels=self.in_channels,
                                 out_channels=self.out_channels,
                                 K=self.K,
                                 normalization=self.normalization,
                                 bias=self.bias)

        self.conv_h_f = ChebConv(in_channels=self.out_channels,
                                 out_channels=self.out_channels,
                                 K=self.K,
                                 normalization=self.normalization,
                                 bias=self.bias)

        self.w_c_f = Parameter(torch.Tensor(1, self.out_channels))
        self.b_f = Parameter(torch.Tensor(1, self.out_channels))

    def _create_cell_state_parameters_and_layers(self):
        self.conv_x_c = ChebConv(in_channels=self.in_channels,
                                 out_channels=self.out_channels,
                                 K=self.K,
                                 normalization=self.normalization,
                                 bias=self.bias)

        self.conv_h_c = ChebConv(in_channels=self.out_channels,
                                 out_channels=self.out_channels,
                                 K=self.K,
                                 normalization=self.normalization,
                                 bias=self.bias)

        self.b_c = Parameter(torch.Tensor(1, self.out_channels))

    def _create_output_gate_parameters_and_layers(self):
        self.conv_x_o = ChebConv(in_channels=self.in_channels,
                                 out_channels=self.out_channels,
                                 K=self.K,
                                 normalization=self.normalization,
                                 bias=self.bias)

        self.conv_h_o = ChebConv(in_channels=self.out_channels,
                                 out_channels=self.out_channels,
                                 K=self.K,
                                 normalization=self.normalization,
                                 bias=self.bias)

        self.w_c_o = Parameter(torch.Tensor(1, self.out_channels))
        self.b_o = Parameter(torch.Tensor(1, self.out_channels))

    def _create_parameters_and_layers(self):
        self._create_input_gate_parameters_and_layers()
        self._create_forget_gate_parameters_and_layers()
        self._create_cell_state_parameters_and_layers()
        self._create_output_gate_parameters_and_layers()

    def _set_parameters(self):
        glorot(self.w_c_i)
        glorot(self.w_c_f)
        glorot(self.w_c_o)
        zeros(self.b_i)
        zeros(self.b_f)
        zeros(self.b_c)
        zeros(self.b_o)
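        # Note: the ChebConv sub-layers initialize their own weights when they
        # are constructed; this method only initializes the extra peephole
        # weights (w_c_*) and biases (b_*) defined by this module.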

    def _set_hidden_state(self, X, H):
        if not isinstance(H, torch.Tensor):
            H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return H

    def _set_cell_state(self, X, C):
        if not isinstance(C, torch.Tensor):
            C = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return C

    def _calculate_input_gate(self, X, edge_index, edge_weight, H, C):
        I = self.conv_x_i(X, edge_index, edge_weight)
        I = I + self.conv_h_i(H, edge_index, edge_weight)
        I = I + (self.w_c_i * C)
        I = I + self.b_i
        I = torch.sigmoid(I)
        return I

    def _calculate_forget_gate(self, X, edge_index, edge_weight, H, C):
        F = self.conv_x_f(X, edge_index, edge_weight)
        F = F + self.conv_h_f(H, edge_index, edge_weight)
        F = F + (self.w_c_f * C)
        F = F + self.b_f
        F = torch.sigmoid(F)
        return F

    def _calculate_cell_state(self, X, edge_index, edge_weight, H, C, I, F):
        T = self.conv_x_c(X, edge_index, edge_weight)
        T = T + self.conv_h_c(H, edge_index, edge_weight)
        T = T + self.b_c
        T = torch.tanh(T)
        C = F * C + I * T
        return C

    def _calculate_output_gate(self, X, edge_index, edge_weight, H, C):
        O = self.conv_x_o(X, edge_index, edge_weight)
        O = O + self.conv_h_o(H, edge_index, edge_weight)
        O = O + (self.w_c_o * C)
        O = O + self.b_o
        O = torch.sigmoid(O)
        return O

    def _calculate_hidden_state(self, O, C):
        H = O * torch.tanh(C)
        return H
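
    # The helper methods above implement a peephole-style LSTM cell whose
    # dense input/hidden transforms are replaced by ChebConv graph
    # convolutions:
    #   I = sigmoid(conv_x_i(X) + conv_h_i(H) + w_c_i * C + b_i)   # input gate
    #   F = sigmoid(conv_x_f(X) + conv_h_f(H) + w_c_f * C + b_f)   # forget gate
    #   T = tanh(conv_x_c(X) + conv_h_c(H) + b_c)                  # candidate state
    #   C = F * C + I * T                                          # new cell state
    #   O = sigmoid(conv_x_o(X) + conv_h_o(H) + w_c_o * C + b_o)   # output gate
    #   H = O * tanh(C)                                            # new hidden state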

    def _forward(self, X: torch.FloatTensor, edge_index: torch.LongTensor,
                 edge_weight: torch.FloatTensor = None,
                 H: torch.FloatTensor = None,
                 C: torch.FloatTensor = None
                 ) -> (torch.FloatTensor, torch.FloatTensor):
        """
        Makes a forward pass. If edge weights are not present, the forward pass
        defaults to an unweighted graph. If the hidden state and cell state
        matrices are not present when the forward pass is called, they are
        initialized with zeros.

        Arg types:
            * **X** *(PyTorch Float Tensor)* - Node features.
            * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
            * **edge_weight** *(PyTorch Float Tensor, optional)* - Edge weight vector.
            * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.
            * **C** *(PyTorch Float Tensor, optional)* - Cell state matrix for all nodes.

        Return types:
            * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
            * **C** *(PyTorch Float Tensor)* - Cell state matrix for all nodes.
        """
        H = self._set_hidden_state(X, H)
        C = self._set_cell_state(X, C)
        I = self._calculate_input_gate(X, edge_index, edge_weight, H, C)
        F = self._calculate_forget_gate(X, edge_index, edge_weight, H, C)
        C = self._calculate_cell_state(X, edge_index, edge_weight, H, C, I, F)
        O = self._calculate_output_gate(X, edge_index, edge_weight, H, C)
        H = self._calculate_hidden_state(O, C)
        return H, C

    def forward(self, batch):
        if hasattr(batch, 'edge_weight'):
            edge_weight = batch.edge_weight
        else:
            edge_weight = None
        H, C = self._forward(X=batch.node_feature,
                             edge_index=batch.edge_index,
                             edge_weight=edge_weight,
                             H=batch.node_states[self.id],
                             C=batch.node_cells[self.id])
        batch.node_states[self.id] = H
        batch.node_cells[self.id] = C
        batch.node_feature = H
        return batch


register_layer('gconv_lstm', GConvLSTM)
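
# Minimal usage sketch (illustrative): within GraphGym this layer is built via
# the 'gconv_lstm' registration above and driven by the pipeline's batch
# objects, but the recurrent cell can also be exercised directly through
# _forward on a toy graph. Tensor names and sizes below are arbitrary.
#
#     cell = GConvLSTM(in_channels=4, out_channels=8, K=2, id=0)
#     x = torch.randn(5, 4)                              # 5 nodes, 4 features
#     edge_index = torch.tensor([[0, 1, 2, 3],
#                                [1, 2, 3, 4]])          # 4 directed edges
#     H, C = cell._forward(X=x, edge_index=edge_index)   # states start at zero
#     H, C = cell._forward(X=x, edge_index=edge_index, H=H, C=C)  # next step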