"""Training entry point for MTGNN on UCTB datasets.

Loads a UCTB NodeTrafficLoader dataset, builds a distance-based graph with
GraphGenerator, and trains the MTGNN (gtnet) model with sub-graph training
and curriculum learning, checkpointing on validation improvement.
"""
import argparse
import os
import time

import numpy as np
import torch

from UCTB.dataset import NodeTrafficLoader, data_loader
from UCTB.evaluation import metric
from UCTB.model.MTGNN import gtnet
from UCTB.preprocess.GraphGenerator import GraphGenerator
from UCTB.utils.utils_GraphWaveNet import *
from UCTB.utils.utils_MTGNN import *

parser = argparse.ArgumentParser(description='PyTorch Time series forecasting')
# --- model / optimisation options ---
parser.add_argument('--data', type=str, default='./data/solar_AL.txt',
                    help='location of the data file')
parser.add_argument('--log_interval', type=int, default=2000, metavar='N',
                    help='report interval')
parser.add_argument('--save', type=str, default='model/model.pt',
                    help='path to save the final model')
parser.add_argument('--optim', type=str, default='adam')
# NOTE(review): argparse `type=bool` treats any non-empty string as True;
# kept for backward compatibility with the original CLI.
parser.add_argument('--L1Loss', type=bool, default=True)
parser.add_argument('--normalize', type=int, default=2)
parser.add_argument('--device', type=str, default='cuda:1', help='')
parser.add_argument('--gcn_true', type=bool, default=True, help='whether to add graph convolution layer')
parser.add_argument('--buildA_true', type=bool, default=True, help='whether to construct adaptive adjacency matrix')
parser.add_argument('--gcn_depth', type=int, default=2, help='graph convolution depth')
parser.add_argument('--num_nodes', type=int, default=137, help='number of nodes/variables')
parser.add_argument('--dropout', type=float, default=0.3, help='dropout rate')
parser.add_argument('--subgraph_size', type=int, default=20, help='k')
parser.add_argument('--node_dim', type=int, default=40, help='dim of nodes')
parser.add_argument('--dilation_exponential', type=int, default=2, help='dilation exponential')
parser.add_argument('--conv_channels', type=int, default=16, help='convolution channels')
parser.add_argument('--residual_channels', type=int, default=16, help='residual channels')
parser.add_argument('--skip_channels', type=int, default=32, help='skip channels')
parser.add_argument('--end_channels', type=int, default=64, help='end channels')
parser.add_argument('--in_dim', type=int, default=1, help='inputs dimension')
parser.add_argument('--seq_in_len', type=int, default=24 * 7, help='input sequence length')
parser.add_argument('--seq_out_len', type=int, default=1, help='output sequence length')
parser.add_argument('--horizon', type=int, default=3)
parser.add_argument('--layers', type=int, default=5, help='number of layers')
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.00001, help='weight decay rate')
parser.add_argument('--clip', type=int, default=5, help='clip')
parser.add_argument('--propalpha', type=float, default=0.05, help='prop alpha')
parser.add_argument('--tanhalpha', type=float, default=3, help='tanh alpha')
parser.add_argument('--epochs', type=int, default=1, help='')
parser.add_argument('--num_split', type=int, default=1, help='number of splits for graphs')
parser.add_argument('--step_size', type=int, default=100, help='step_size')
# BUG FIX: the options below were read later (args.step_size1, args.step_size2,
# args.cl, args.print_every, args.adj_data) but never declared in the original
# script, which raised AttributeError at runtime.
parser.add_argument('--step_size1', type=int, default=100,
                    help='curriculum-learning step size passed to Trainer')
parser.add_argument('--step_size2', type=int, default=100,
                    help='iteration interval at which nodes are re-permuted')
parser.add_argument('--cl', type=bool, default=True,
                    help='whether to enable curriculum learning in Trainer')
parser.add_argument('--print_every', type=int, default=50,
                    help='training-log frequency (iterations)')
parser.add_argument('--adj_data', type=str, default=None,
                    help='optional path to a pre-computed adjacency file; '
                         'when omitted, the UCTB distance graph is used')

# --- data options ---
parser.add_argument("--dataset", default='DiDi', type=str, help="configuration file path")
parser.add_argument("--city", default='Xian', type=str)
parser.add_argument("--closeness_len", default=6, type=int)
parser.add_argument("--period_len", default=7, type=int)
parser.add_argument("--trend_len", default=4, type=int)
parser.add_argument("--data_range", default="all", type=str)
parser.add_argument("--train_data_length", default="all", type=str)
parser.add_argument("--test_ratio", default=0.1, type=float)
parser.add_argument("--MergeIndex", default=1, type=int)
parser.add_argument("--MergeWay", default="sum", type=str)

args = parser.parse_args()

# Loading data: normalization is left to the downstream utilities.
uctb_data_loader = NodeTrafficLoader(dataset=args.dataset, city=args.city,
                                     data_range=args.data_range,
                                     train_data_length=args.train_data_length,
                                     test_ratio=float(args.test_ratio),
                                     closeness_len=args.closeness_len,
                                     period_len=args.period_len,
                                     trend_len=args.trend_len,
                                     normalize=False,
                                     MergeIndex=args.MergeIndex,
                                     MergeWay=args.MergeWay)

# Derive model shape from the dataset rather than the CLI defaults.
args.num_nodes = uctb_data_loader.station_number
args.in_dim = uctb_data_loader.closeness_len + uctb_data_loader.period_len + uctb_data_loader.trend_len
args.seq_length = 1
args.save = os.path.abspath('./experiment/{}_{}_{}'.format(args.dataset, args.city, args.MergeIndex))
if not os.path.exists(args.save):
    os.makedirs(args.save)

# Build graph from station distances.
graph_obj = GraphGenerator(graph='distance', data_loader=uctb_data_loader)

device = torch.device(args.device)
data_dict = load_dataset(uctb_data_loader, args.batch_size, args.batch_size, args.batch_size)

# BUG FIX: the original unconditionally loaded `args.adj_data` (never declared)
# and left `graph_obj` unused (marked "needs fixing" in the original comment).
# Prefer the UCTB-generated distance graph; fall back to a user-supplied file.
if args.adj_data is not None:
    predefined_A = load_adj(args.adj_data)
else:
    # NOTE(review): assumes GraphGenerator exposes adjacency matrices as .AM
    # with the requested graph first — confirm against the UCTB API.
    predefined_A = graph_obj.AM[0]
# Remove self-loops; gtnet adds its own where needed.
predefined_A = torch.tensor(predefined_A, dtype=torch.float32) - torch.eye(args.num_nodes)
predefined_A = predefined_A.to(device)

model = gtnet(args.gcn_true, args.buildA_true, args.gcn_depth, args.num_nodes,
              device, predefined_A=predefined_A,
              dropout=args.dropout, subgraph_size=args.subgraph_size,
              node_dim=args.node_dim, dilation_exponential=args.dilation_exponential,
              conv_channels=args.conv_channels, residual_channels=args.residual_channels,
              skip_channels=args.skip_channels, end_channels=args.end_channels,
              seq_length=args.seq_in_len, in_dim=args.in_dim, out_dim=args.seq_out_len,
              layers=args.layers, propalpha=args.propalpha, tanhalpha=args.tanhalpha,
              layer_norm_affline=False)
model = model.to(device)
print(args)
print('The receptive field size is', model.receptive_field)
nParams = sum(p.nelement() for p in model.parameters())
print('Number of model parameters is', nParams)

# BUG FIX: original passed non-existent args.learning_rate / args.step_size1 /
# args.cl; all are now real CLI options (learning rate is --lr).
engine = Trainer(model, args.lr, args.weight_decay, args.clip, args.step_size1,
                 args.seq_out_len, device, args.cl)

print("start training...", flush=True)
his_loss = []
val_time = []
train_time = []
minl = 1e5
for i in range(1, args.epochs + 1):
    train_loss = []
    train_mape = []
    train_rmse = []
    t1 = time.time()
    data_dict['train_loader'].shuffle()
    for iter, (x, y) in enumerate(data_dict['train_loader'].get_iterator()):
        trainx = torch.Tensor(x).to(device)
        trainx = trainx.transpose(1, 3)
        trainy = torch.Tensor(y).to(device)
        trainy = trainy.transpose(1, 3)
        # Sub-graph training: periodically re-permute the node set and train
        # on num_split random node subsets per iteration.
        if iter % args.step_size2 == 0:
            perm = np.random.permutation(range(args.num_nodes))
        num_sub = int(args.num_nodes / args.num_split)
        for j in range(args.num_split):
            if j != args.num_split - 1:
                node_idx = perm[j * num_sub:(j + 1) * num_sub]
            else:
                node_idx = perm[j * num_sub:]
            node_idx = torch.tensor(node_idx).to(device)
            tx = trainx[:, :, node_idx, :]
            ty = trainy[:, :, node_idx, :]
            metrics = engine.train(tx, ty[:, 0, :, :], node_idx)
            train_loss.append(metrics[0])
            train_mape.append(metrics[1])
            train_rmse.append(metrics[2])
        if iter % args.print_every == 0:
            log = 'Iter: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}'
            print(log.format(iter, train_loss[-1], train_mape[-1], train_rmse[-1]), flush=True)
    t2 = time.time()
    train_time.append(t2 - t1)

    # Validation pass.
    valid_loss = []
    valid_mape = []
    valid_rmse = []
    s1 = time.time()
    for iter, (x, y) in enumerate(data_dict['val_loader'].get_iterator()):
        testx = torch.Tensor(x).to(device)
        testx = testx.transpose(1, 3)
        testy = torch.Tensor(y).to(device)
        testy = testy.transpose(1, 3)
        metrics = engine.eval(testx, testy[:, 0, :, :])
        valid_loss.append(metrics[0])
        valid_mape.append(metrics[1])
        valid_rmse.append(metrics[2])
    s2 = time.time()
    log = 'Epoch: {:03d}, Inference Time: {:.4f} secs'
    print(log.format(i, (s2 - s1)))
    val_time.append(s2 - s1)

    mtrain_loss = np.mean(train_loss)
    mtrain_mape = np.mean(train_mape)
    mtrain_rmse = np.mean(train_rmse)
    mvalid_loss = np.mean(valid_loss)
    mvalid_mape = np.mean(valid_mape)
    mvalid_rmse = np.mean(valid_rmse)
    his_loss.append(mvalid_loss)
    log = ('Epoch: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}, '
           'Valid Loss: {:.4f}, Valid MAPE: {:.4f}, Valid RMSE: {:.4f}, Training Time: {:.4f}/epoch')
    print(log.format(i, mtrain_loss, mtrain_mape, mtrain_rmse,
                     mvalid_loss, mvalid_mape, mvalid_rmse, (t2 - t1)), flush=True)

    # NOTE(review): the original patch text is garbled past "if mvalid_loss";
    # upstream MTGNN checkpoints on validation improvement — reconstructed here.
    if mvalid_loss < minl:
        torch.save(engine.model.state_dict(), os.path.join(args.save, 'best_model.pt'))
        minl = mvalid_loss
class nconv(nn.Module):
    """Static graph convolution: mixes node channels along a fixed adjacency.

    NOTE(review): this class was garbled at a splice point in the patch; only
    the tail of its einsum survived (`...ncvl',(x,A))`).  Reconstructed from
    the upstream MTGNN implementation — confirm against the original source.
    """

    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        # x: (batch, channels, nodes, time); A: (nodes, nodes).
        x = torch.einsum('ncwl,vw->ncvl', (x, A))
        return x.contiguous()


class dy_nconv(nn.Module):
    """Dynamic graph convolution: adjacency varies per sample and time step."""

    def __init__(self):
        super(dy_nconv, self).__init__()

    def forward(self, x, A):
        # x: (batch, channels, nodes, time); A: (batch, nodes, nodes, time).
        x = torch.einsum('ncvl,nvwl->ncwl', (x, A))
        return x.contiguous()


class linear(nn.Module):
    """1x1 convolution used as a position-wise linear projection."""

    def __init__(self, c_in, c_out, bias=True):
        super(linear, self).__init__()
        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1),
                                   padding=(0, 0), stride=(1, 1), bias=bias)

    def forward(self, x):
        return self.mlp(x)


class prop(nn.Module):
    """Plain graph-propagation layer (no mix-hop concatenation).

    Repeatedly propagates along the row-normalised adjacency while retaining
    a fraction `alpha` of the original input signal each hop.
    """

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super(prop, self).__init__()
        self.nconv = nconv()
        self.mlp = linear(c_in, c_out)
        self.gdep = gdep          # number of propagation hops
        self.dropout = dropout    # stored for interface parity; not applied here
        self.alpha = alpha        # retain ratio of the input signal

    def forward(self, x, adj):
        adj = adj + torch.eye(adj.size(0)).to(x.device)  # add self loops
        d = adj.sum(1)
        h = x
        a = adj / d.view(-1, 1)   # row-normalise
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
        return self.mlp(h)
class mixprop(nn.Module):
    """Mix-hop propagation layer: collects representations from hop 0..gdep,
    concatenates them on the channel axis, and fuses with a 1x1 convolution."""

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super(mixprop, self).__init__()
        self.nconv = nconv()
        self.mlp = linear((gdep + 1) * c_in, c_out)
        self.gdep = gdep        # number of propagation hops
        self.dropout = dropout  # stored but not applied in forward
        self.alpha = alpha      # retain ratio of the original signal

    def forward(self, x, adj):
        # Row-normalised adjacency with self loops.
        adj = adj + torch.eye(adj.size(0)).to(x.device)
        deg = adj.sum(1)
        norm_adj = adj / deg.view(-1, 1)
        hop = x
        hops = [hop]
        for _ in range(self.gdep):
            hop = self.alpha * x + (1 - self.alpha) * self.nconv(hop, norm_adj)
            hops.append(hop)
        return self.mlp(torch.cat(hops, dim=1))


class dy_mixprop(nn.Module):
    """Dynamic mix-hop propagation: learns a per-sample adjacency from the
    input itself and propagates along both it and its transpose."""

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super(dy_mixprop, self).__init__()
        self.nconv = dy_nconv()
        self.mlp1 = linear((gdep + 1) * c_in, c_out)
        self.mlp2 = linear((gdep + 1) * c_in, c_out)
        self.gdep = gdep
        self.dropout = dropout
        self.alpha = alpha
        self.lin1 = linear(c_in, c_in)
        self.lin2 = linear(c_in, c_in)

    def forward(self, x):
        # Data-dependent adjacency (forward and backward directions).
        proj1 = torch.tanh(self.lin1(x))
        proj2 = torch.tanh(self.lin2(x))
        raw_adj = self.nconv(proj1.transpose(2, 1), proj2)
        adj_fwd = torch.softmax(raw_adj, dim=2)
        adj_bwd = torch.softmax(raw_adj.transpose(2, 1), dim=2)

        def _propagate(adjacency):
            # Mix-hop propagation along one direction of the learned graph.
            hop = x
            hops = [hop]
            for _ in range(self.gdep):
                hop = self.alpha * x + (1 - self.alpha) * self.nconv(hop, adjacency)
                hops.append(hop)
            return torch.cat(hops, dim=1)

        return self.mlp1(_propagate(adj_fwd)) + self.mlp2(_propagate(adj_bwd))


class dilated_1D(nn.Module):
    """Single dilated temporal convolution with a fixed (1, 7) kernel."""

    def __init__(self, cin, cout, dilation_factor=2):
        super(dilated_1D, self).__init__()
        # The ModuleList/kernel_set assignments mirror the upstream code; the
        # ModuleList is immediately replaced by the single Conv2d below.
        self.tconv = nn.ModuleList()
        self.kernel_set = [2, 3, 6, 7]
        self.tconv = nn.Conv2d(cin, cout, (1, 7), dilation=(1, dilation_factor))

    def forward(self, input):
        return self.tconv(input)
class dilated_inception(nn.Module):
    """Inception-style temporal convolution: parallel dilated convolutions with
    kernel widths {2, 3, 6, 7}, cropped to a common length and concatenated."""

    def __init__(self, cin, cout, dilation_factor=2):
        super(dilated_inception, self).__init__()
        self.tconv = nn.ModuleList()
        self.kernel_set = [2, 3, 6, 7]
        cout = int(cout / len(self.kernel_set))  # split channels across branches
        for kern in self.kernel_set:
            self.tconv.append(nn.Conv2d(cin, cout, (1, kern), dilation=(1, dilation_factor)))

    def forward(self, input):
        branches = [conv(input) for conv in self.tconv]
        # The widest kernel produces the shortest output; crop all branches to it.
        target_len = branches[-1].size(3)
        branches = [b[..., -target_len:] for b in branches]
        return torch.cat(branches, dim=1)


class graph_constructor(nn.Module):
    """Learns a sparse directed adjacency matrix from node embeddings
    (MTGNN graph-learning layer)."""

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(graph_constructor, self).__init__()
        self.nnodes = nnodes
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1 = nn.Linear(xd, dim)
            self.lin2 = nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1 = nn.Linear(dim, dim)
            self.lin2 = nn.Linear(dim, dim)
        self.device = device
        self.k = k          # keep only the top-k outgoing edges per node
        self.dim = dim
        self.alpha = alpha  # saturation rate of tanh
        self.static_feat = static_feat

    def _node_vectors(self, idx):
        # Learned embeddings, or externally supplied static features.
        if self.static_feat is None:
            vec1, vec2 = self.emb1(idx), self.emb2(idx)
        else:
            vec1 = self.static_feat[idx, :]
            vec2 = vec1
        vec1 = torch.tanh(self.alpha * self.lin1(vec1))
        vec2 = torch.tanh(self.alpha * self.lin2(vec2))
        return vec1, vec2

    def forward(self, idx):
        vec1, vec2 = self._node_vectors(idx)
        # Anti-symmetric similarity => uni-directional edges after ReLU.
        scores = torch.mm(vec1, vec2.transpose(1, 0)) - torch.mm(vec2, vec1.transpose(1, 0))
        adj = F.relu(torch.tanh(self.alpha * scores))
        # Sparsify: keep the k strongest outgoing edges per node
        # (tiny random noise breaks ties).
        mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device)
        mask.fill_(float('0'))
        s1, t1 = (adj + torch.rand_like(adj) * 0.01).topk(self.k, 1)
        mask.scatter_(1, t1, s1.fill_(1))
        return adj * mask

    def fullA(self, idx):
        """Dense (un-sparsified) version of the learned adjacency."""
        vec1, vec2 = self._node_vectors(idx)
        scores = torch.mm(vec1, vec2.transpose(1, 0)) - torch.mm(vec2, vec1.transpose(1, 0))
        return F.relu(torch.tanh(self.alpha * scores))
class graph_global(nn.Module):
    """Fully learnable dense adjacency, shared across all inputs."""

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(graph_global, self).__init__()
        self.nnodes = nnodes
        # k / dim / alpha / static_feat are accepted only for interface
        # parity with the other graph constructors.
        self.A = nn.Parameter(torch.randn(nnodes, nnodes).to(device), requires_grad=True).to(device)

    def forward(self, idx):
        # idx is ignored: the same global graph is returned for any node subset.
        return F.relu(self.A)


class graph_undirected(nn.Module):
    """Learns a sparse symmetric adjacency from a single shared embedding."""

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(graph_undirected, self).__init__()
        self.nnodes = nnodes
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1 = nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.lin1 = nn.Linear(dim, dim)
        self.device = device
        self.k = k          # top-k neighbours kept per row
        self.dim = dim
        self.alpha = alpha  # tanh saturation rate
        self.static_feat = static_feat

    def forward(self, idx):
        # One embedding/projection for both endpoints => symmetric scores.
        if self.static_feat is None:
            vec1 = self.emb1(idx)
            vec2 = self.emb1(idx)
        else:
            vec1 = self.static_feat[idx, :]
            vec2 = vec1
        vec1 = torch.tanh(self.alpha * self.lin1(vec1))
        vec2 = torch.tanh(self.alpha * self.lin1(vec2))
        adj = F.relu(torch.tanh(self.alpha * torch.mm(vec1, vec2.transpose(1, 0))))
        # Sparsify to the k strongest edges per row.
        mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device)
        mask.fill_(float('0'))
        s1, t1 = adj.topk(self.k, 1)
        mask.scatter_(1, t1, s1.fill_(1))
        return adj * mask
class graph_directed(nn.Module):
    """Learns a sparse directed adjacency from two separate node embeddings
    (no anti-symmetric term, unlike graph_constructor)."""

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(graph_directed, self).__init__()
        self.nnodes = nnodes
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1 = nn.Linear(xd, dim)
            self.lin2 = nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1 = nn.Linear(dim, dim)
            self.lin2 = nn.Linear(dim, dim)
        self.device = device
        self.k = k          # top-k neighbours kept per row
        self.dim = dim
        self.alpha = alpha  # tanh saturation rate
        self.static_feat = static_feat

    def forward(self, idx):
        if self.static_feat is None:
            vec1 = self.emb1(idx)
            vec2 = self.emb2(idx)
        else:
            vec1 = self.static_feat[idx, :]
            vec2 = vec1
        vec1 = torch.tanh(self.alpha * self.lin1(vec1))
        vec2 = torch.tanh(self.alpha * self.lin2(vec2))
        adj = F.relu(torch.tanh(self.alpha * torch.mm(vec1, vec2.transpose(1, 0))))
        # Sparsify to the k strongest edges per row.
        mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device)
        mask.fill_(float('0'))
        s1, t1 = adj.topk(self.k, 1)
        mask.scatter_(1, t1, s1.fill_(1))
        return adj * mask


class LayerNorm(nn.Module):
    """Layer normalisation whose affine parameters are indexed per node, so a
    node subset (idx) is normalised with its own weight/bias slice."""
    __constants__ = ['normalized_shape', 'weight', 'bias', 'eps', 'elementwise_affine']

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.Tensor(*normalized_shape))
            self.bias = nn.Parameter(torch.Tensor(*normalized_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # weight -> 1, bias -> 0: identity transform at initialisation.
        if self.elementwise_affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input, idx):
        if self.elementwise_affine:
            # Slice the node axis so the affine parameters line up with idx.
            return F.layer_norm(input, tuple(input.shape[1:]),
                                self.weight[:, idx, :], self.bias[:, idx, :], self.eps)
        return F.layer_norm(input, tuple(input.shape[1:]), self.weight, self.bias, self.eps)

    def extra_repr(self):
        return '{normalized_shape}, eps={eps}, ' \
               'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
propalpha=0.05, tanhalpha=3, layer_norm_affline=True): + super(gtnet, self).__init__() + self.gcn_true = gcn_true + self.buildA_true = buildA_true + self.num_nodes = num_nodes + self.dropout = dropout + self.predefined_A = predefined_A + self.filter_convs = nn.ModuleList() + self.gate_convs = nn.ModuleList() + self.residual_convs = nn.ModuleList() + self.skip_convs = nn.ModuleList() + self.gconv1 = nn.ModuleList() + self.gconv2 = nn.ModuleList() + self.norm = nn.ModuleList() + self.start_conv = nn.Conv2d(in_channels=in_dim, + out_channels=residual_channels, + kernel_size=(1, 1)) + self.gc = graph_constructor(num_nodes, subgraph_size, node_dim, device, alpha=tanhalpha, static_feat=static_feat) + + self.seq_length = seq_length + kernel_size = 7 + if dilation_exponential>1: + self.receptive_field = int(1+(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1)) + else: + self.receptive_field = layers*(kernel_size-1) + 1 + + for i in range(1): + if dilation_exponential>1: + rf_size_i = int(1 + i*(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1)) + else: + rf_size_i = i*layers*(kernel_size-1)+1 + new_dilation = 1 + for j in range(1,layers+1): + if dilation_exponential > 1: + rf_size_j = int(rf_size_i + (kernel_size-1)*(dilation_exponential**j-1)/(dilation_exponential-1)) + else: + rf_size_j = rf_size_i+j*(kernel_size-1) + + self.filter_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation)) + self.gate_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation)) + self.residual_convs.append(nn.Conv2d(in_channels=conv_channels, + out_channels=residual_channels, + kernel_size=(1, 1))) + if self.seq_length>self.receptive_field: + self.skip_convs.append(nn.Conv2d(in_channels=conv_channels, + out_channels=skip_channels, + kernel_size=(1, self.seq_length-rf_size_j+1))) + else: + self.skip_convs.append(nn.Conv2d(in_channels=conv_channels, + 
out_channels=skip_channels, + kernel_size=(1, self.receptive_field-rf_size_j+1))) + + if self.gcn_true: + self.gconv1.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha)) + self.gconv2.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha)) + + if self.seq_length>self.receptive_field: + self.norm.append(LayerNorm((residual_channels, num_nodes, self.seq_length - rf_size_j + 1),elementwise_affine=layer_norm_affline)) + else: + self.norm.append(LayerNorm((residual_channels, num_nodes, self.receptive_field - rf_size_j + 1),elementwise_affine=layer_norm_affline)) + + new_dilation *= dilation_exponential + + self.layers = layers + self.end_conv_1 = nn.Conv2d(in_channels=skip_channels, + out_channels=end_channels, + kernel_size=(1,1), + bias=True) + self.end_conv_2 = nn.Conv2d(in_channels=end_channels, + out_channels=out_dim, + kernel_size=(1,1), + bias=True) + if self.seq_length > self.receptive_field: + self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.seq_length), bias=True) + self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, self.seq_length-self.receptive_field+1), bias=True) + + else: + self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.receptive_field), bias=True) + self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, 1), bias=True) + + + self.idx = torch.arange(self.num_nodes).to(device) + + + def forward(self, input, idx=None): + seq_len = input.size(3) + # pdb.set_trace() + assert seq_len==self.seq_length, 'input sequence length not equal to preset sequence length' + + if self.seq_length