# roland.py: node / edge feature encoders for transaction graphs (forked from snap-stanford/roland).
import torch
import torch.nn as nn
import deepsnap
from graphgym.config import cfg
from graphgym.register import (register_node_encoder,
                               register_edge_encoder)


class TransactionEdgeEncoder(torch.nn.Module):
    r"""A module that encodes edge features in the transaction graph.

    Example:
        TransactionEdgeEncoder(
            (embedding_list): ModuleList(
                (0): Embedding(50, 32)  # the first integer edge feature has 50 unique
                                        # values; embed it into a 32-dimensional space.
                (1): Embedding(8, 32)
                (2): Embedding(252, 32)
                (3): Embedding(252, 32)
            )
            (linear_amount): Linear(in_features=1, out_features=64, bias=True)
            (linear_time): Linear(in_features=1, out_features=64, bias=True)
        )
        Initial edge feature dimension = 6
        Final edge embedding dimension = 32 + 32 + 32 + 32 + 64 + 64 = 256
    """

    def __init__(self, emb_dim: int):
        # emb_dim is not used here; the output dimension is derived from cfg below.
        super(TransactionEdgeEncoder, self).__init__()
        self.embedding_list = torch.nn.ModuleList()
        # Note: feature_edge_int_num[i] = len(torch.unique(graph.edge_feature[:, i]))
        # where the i-th edge feature is integer-valued.
        for num in cfg.transaction.feature_edge_int_num:
            emb = torch.nn.Embedding(num, cfg.transaction.feature_int_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.embedding_list.append(emb)
        # Embed the non-integer features (transaction amount and time).
        self.linear_amount = nn.Linear(1, cfg.transaction.feature_amount_dim)
        self.linear_time = nn.Linear(1, cfg.transaction.feature_time_dim)
        # Update edge_dim so downstream layers see the encoded edge dimension.
        cfg.dataset.edge_dim = len(cfg.transaction.feature_edge_int_num) \
            * cfg.transaction.feature_int_dim \
            + cfg.transaction.feature_amount_dim \
            + cfg.transaction.feature_time_dim

    def forward(self, batch: deepsnap.batch.Batch) -> deepsnap.batch.Batch:
        edge_embedding = []
        for i in range(len(self.embedding_list)):
            edge_embedding.append(
                self.embedding_list[i](batch.edge_feature[:, i].long())
            )
        # By default, edge_feature[:, -2] holds the transaction amount and
        # edge_feature[:, -1] holds the transaction time.
        edge_embedding.append(
            self.linear_amount(batch.edge_feature[:, -2].view(-1, 1))
        )
        edge_embedding.append(
            self.linear_time(batch.edge_feature[:, -1].view(-1, 1))
        )
        batch.edge_feature = torch.cat(edge_embedding, dim=1)
        return batch


register_edge_encoder('roland', TransactionEdgeEncoder)
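

# A minimal standalone sketch (illustrative only) of the encoding performed by
# TransactionEdgeEncoder, using the numbers from the docstring example above:
# four integer edge features with [50, 8, 252, 252] unique values embedded into
# 32 dimensions each, plus 64-dimensional projections of amount and time, giving
# a 256-dimensional encoded edge feature. The helper name, the toy tensors, and
# the feature layout (amount in column -2, time in column -1) are assumptions
# mirroring the comments in forward(); graphgym's cfg and deepsnap batches are
# deliberately not used here.
def _demo_edge_encoding():
    num_edges = 5
    int_cardinalities = [50, 8, 252, 252]  # assumed unique-value counts
    embeddings = nn.ModuleList(
        [nn.Embedding(n, 32) for n in int_cardinalities]
    )
    linear_amount = nn.Linear(1, 64)
    linear_time = nn.Linear(1, 64)

    # Raw edge features: four integer columns, then amount and time.
    edge_feature = torch.cat(
        [
            torch.stack(
                [torch.randint(0, n, (num_edges,)) for n in int_cardinalities],
                dim=1,
            ).float(),
            torch.rand(num_edges, 2),  # amount, time
        ],
        dim=1,
    )

    parts = [emb(edge_feature[:, i].long()) for i, emb in enumerate(embeddings)]
    parts.append(linear_amount(edge_feature[:, -2].view(-1, 1)))
    parts.append(linear_time(edge_feature[:, -1].view(-1, 1)))
    encoded = torch.cat(parts, dim=1)
    # 4 * 32 + 64 + 64 = 256, matching the docstring's final edge dimension.
    assert encoded.shape == (num_edges, 256)
    return encoded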


class TransactionNodeEncoder(torch.nn.Module):
    r"""A module that encodes node features in the transaction graph.

    Parameters:
        num_classes - the number of classes for the embedding mapping to learn.

    Example:
        3 unique values for the first integer node feature.
        3 unique values for the second integer node feature.
        cfg.transaction.feature_node_int_num = [3, 3]
        cfg.transaction.feature_int_dim = 32
        TransactionNodeEncoder(
            (embedding_list): ModuleList(
                (0): Embedding(3, 32)  # embed the first node feature into a 32-dimensional space.
                (1): Embedding(3, 32)  # embed the second node feature into a 32-dimensional space.
            )
        )
        Initial node feature dimension = 2
        Final node embedding dimension = 32 + 32 = 64
    """

    def __init__(self, emb_dim: int, num_classes=None):
        super(TransactionNodeEncoder, self).__init__()
        self.embedding_list = torch.nn.ModuleList()
        for i, num in enumerate(cfg.transaction.feature_node_int_num):
            emb = torch.nn.Embedding(num, cfg.transaction.feature_int_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.embedding_list.append(emb)
        # Update encoder_dim so downstream layers see the encoded node dimension.
        cfg.dataset.encoder_dim = len(cfg.transaction.feature_node_int_num) \
            * cfg.transaction.feature_int_dim

    def forward(self, batch: deepsnap.batch.Batch) -> deepsnap.batch.Batch:
        node_embedding = []
        for i in range(len(self.embedding_list)):
            node_embedding.append(
                self.embedding_list[i](batch.node_feature[:, i].long())
            )
        batch.node_feature = torch.cat(node_embedding, dim=1)
        return batch


register_node_encoder('roland', TransactionNodeEncoder)
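

# A minimal standalone sketch (illustrative only) of the node-feature path,
# mirroring the docstring example above: feature_node_int_num = [3, 3] and
# feature_int_dim = 32 give a 32 + 32 = 64-dimensional encoded node feature.
# The helper name and toy tensors are assumptions; in an actual GraphGym run
# the encoders registered above are selected by name ('roland') through the
# dataset configuration rather than called directly like this.
def _demo_node_encoding():
    num_nodes = 4
    cardinalities = [3, 3]  # assumed unique-value counts
    embeddings = nn.ModuleList([nn.Embedding(n, 32) for n in cardinalities])

    # Raw node features: one integer column per embedding table.
    node_feature = torch.stack(
        [torch.randint(0, n, (num_nodes,)) for n in cardinalities], dim=1
    )

    parts = [emb(node_feature[:, i].long()) for i, emb in enumerate(embeddings)]
    encoded = torch.cat(parts, dim=1)
    # 2 * 32 = 64, matching the final node dimension in the docstring above.
    assert encoded.shape == (num_nodes, 64)
    return encoded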