-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathgeneralconv_hete_complete.py
173 lines (149 loc) · 6.66 KB
/
generalconv_hete_complete.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import torch
import torch.nn as nn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from graphgym.config import cfg
from graphgym.register import register_layer
class EdgeConvLayer(MessagePassing):
    r"""Message-passing layer for a single (neigh -> self) message type
    that incorporates edge features.

    Each message is a linear projection of the neighbor features (and, in
    ``'both'`` mode, of the self features) concatenated with the raw edge
    feature; ``update`` maps the aggregated message to ``out_channels``.

    Args:
        in_channels_neigh (int): The input dimension of the end node type.
        out_channels (int): The dimension of the output.
        in_channels_self (int, optional): The input dimension of the start
            node type. Default is ``None``, in which case it is taken to be
            equal to ``in_channels_neigh``.

    Raises:
        ValueError: If ``cfg.gnn.msg_direction`` is neither ``'single'``
            nor ``'both'``.
    """
    def __init__(self, in_channels_neigh, out_channels, in_channels_self=None):
        super().__init__(aggr=cfg.gnn.agg)
        self.in_channels_neigh = in_channels_neigh
        # Fall back to the neighbor dimension when no self dim is given.
        if in_channels_self is None:
            self.in_channels_self = in_channels_neigh
        else:
            self.in_channels_self = in_channels_self
        self.out_channels = out_channels
        self.edge_channels = cfg.dataset.edge_dim
        self.msg_direction = cfg.gnn.msg_direction
        self.lin_neigh = nn.Linear(self.in_channels_neigh, self.out_channels)
        # NOTE(review): lin_self is registered (and adds parameters) even in
        # 'single' mode where message() never uses it; kept as-is so that
        # existing checkpoints / state dicts stay loadable.
        self.lin_self = nn.Linear(self.in_channels_self, self.out_channels)
        if self.msg_direction == 'single':
            # Message = [proj(neigh) | edge_feature]
            self.lin_update = nn.Linear(
                self.out_channels + self.edge_channels,
                self.out_channels)
        elif self.msg_direction == 'both':
            # Message = [proj(neigh) | edge_feature | proj(self)]
            self.lin_update = nn.Linear(
                self.out_channels * 2 + self.edge_channels,
                self.out_channels)
        else:
            raise ValueError(
                f"Unsupported cfg.gnn.msg_direction "
                f"{self.msg_direction!r}; expected 'single' or 'both'.")

    def forward(self, node_feature_neigh, node_feature_self, edge_index,
                edge_feature, edge_weight=None, size=None):
        """Propagate messages along ``edge_index`` and return the updated
        embedding for the self (destination) node type.

        Note: ``edge_weight`` is forwarded to ``message`` but currently
        unused there.
        """
        return self.propagate(
            edge_index, size=size,
            node_feature_neigh=node_feature_neigh,
            node_feature_self=node_feature_self,
            edge_feature=edge_feature,
            edge_weight=edge_weight
        )

    def message(self, node_feature_neigh_j, node_feature_self_i,
                edge_feature, edge_weight):
        """Build the per-edge message tensor (see ``__init__`` for layout)."""
        # The neighbor projection is common to both directions.
        node_feature_neigh_j = self.lin_neigh(node_feature_neigh_j)
        if self.msg_direction == 'single':
            return torch.cat([node_feature_neigh_j, edge_feature], dim=-1)
        node_feature_self_i = self.lin_self(node_feature_self_i)
        return torch.cat(
            [node_feature_neigh_j, edge_feature, node_feature_self_i],
            dim=-1)

    def update(self, aggr_out):
        """Map the aggregated message back to ``out_channels``."""
        return self.lin_update(aggr_out)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}"
            f"(neigh: {self.in_channels_neigh}, self: {self.in_channels_self}, "
            f"edge: {self.edge_channels}, "
            f"out: {self.out_channels})"
        )
class HeteroGNNWrapperConv(torch.nn.Module):
    """Wraps one conv module per message type and aggregates, per node
    type, the embeddings produced by all message types ending there.

    Args:
        convs (Mapping[str, nn.Module]): Maps ``str(message_type)`` to the
            conv module handling that message type.
        dim_in (int): Input feature dimension.
        dim_out (int): Output feature dimension (used for the all-zeros
            fallback when a node type receives no messages).
        aggr (str or None): How to combine embeddings that multiple message
            types produce for the same node type: 'add', 'mean' or 'max'.
            NOTE: this aggregation is different from cfg.gnn.agg, which
            aggregates messages *within* a single conv.
    """
    def __init__(self, convs, dim_in, dim_out, aggr='add'):
        super().__init__()
        self.convs = convs
        self.dim_in = dim_in
        self.dim_out = dim_out
        assert aggr in ['add', 'mean', 'max', None]
        self.aggr = aggr
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the parameters of every wrapped conv."""
        for conv in self.convs.values():
            reset(conv)

    def forward(self, node_features, edge_indices, edge_features):
        r"""The forward function for `HeteroConv`.

        Args:
            node_features (dict): A dictionary each key is node type and the
                corresponding value is a node feature tensor.
            edge_indices (dict): A dictionary each key is message type and the
                corresponding value is an edge index tensor.
            edge_features (dict): A dictionary each key is message type and
                the corresponding value is an edge feature tensor.

        Returns:
            dict: Maps node type to its aggregated embedding tensor.
        """
        # Node embedding computed from each message type.
        message_type_emb = {}
        for message_key in edge_indices:
            # message_key = (neigh_type, edge_type, self_type), meaning
            # neigh_type --(edge_type)--> self_type.
            neigh_type, _, self_type = message_key
            message_type_emb[message_key] = self.convs[str(message_key)](
                node_features[neigh_type],
                node_features[self_type],
                edge_indices[message_key],
                edge_features[message_key],
            )
        # Group embeddings by the destination (tail) node type.
        node_emb = {typ: [] for typ in node_features.keys()}
        for (_, _, tail), item in message_type_emb.items():
            node_emb[tail].append(item)
        # Aggregate multiple embeddings with the same tail.
        for node_type, embs in node_emb.items():
            if len(embs) == 0:
                # This type of nodes did not receive any incoming edge;
                # put all zeros, keep_ratio will be 1, this does not matter.
                # Allocate directly on the target device (avoids a copy).
                node_emb[node_type] = torch.zeros(
                    (node_features[node_type].shape[0], self.dim_out),
                    device=cfg.device)
            elif len(embs) == 1:
                node_emb[node_type] = embs[0]
            else:
                node_emb[node_type] = self.aggregate(embs)
        return node_emb  # Dict[NodeType, NodeEmb]

    def aggregate(self, xs):
        """Combine a list of same-shaped embeddings according to self.aggr."""
        x = torch.stack(xs, dim=-1)
        if self.aggr == 'add':
            return x.sum(dim=-1)
        if self.aggr == 'mean':
            return x.mean(dim=-1)
        if self.aggr == 'max':
            return x.max(dim=-1)[0]
        # Previously this fell through and returned None for aggr=None,
        # crashing far from the cause; fail loudly here instead.
        raise ValueError(f"Unsupported aggregation {self.aggr!r}")
class HeteroGeneralEdgeConv(nn.Module):
    """Heterogeneous edge-aware conv layer: one EdgeConvLayer per message
    type listed in cfg.dataset.message_types, combined via
    HeteroGNNWrapperConv with 'mean' cross-type aggregation.
    """
    def __init__(self, dim_in, dim_out, **kwargs):
        super().__init__()
        # One conv per (source, relation, destination) message type, keyed
        # by the stringified tuple.
        convs = nn.ModuleDict({
            str((src, rel, dst)): EdgeConvLayer(dim_in, dim_out)
            for src, rel, dst in cfg.dataset.message_types
        })
        self.model = HeteroGNNWrapperConv(convs, dim_in, dim_out, 'mean')

    def forward(self, batch):
        # Replace per-type node features with the updated embeddings.
        batch.node_feature = self.model(batch.node_feature,
                                        batch.edge_index,
                                        batch.edge_feature)
        return batch
# Expose this layer to GraphGym's layer registry under this key.
register_layer('generaledgeheteconv_complete', HeteroGeneralEdgeConv)