-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGNN.py
48 lines (37 loc) · 1.41 KB
/
GNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import math
np.math = math
from torch_geometric.nn import radius_graph
from torch_geometric.nn.models.dimenet import (
BesselBasisLayer,
EmbeddingBlock,
ResidualLayer,
SphericalBasisLayer,
OutputBlock,
InteractionPPBlock,
InteractionBlock,
triplets
)
import torch.nn as nn
import torch
class GNN():
    """Build DimeNet++-style radial and spherical basis features for a point cloud.

    The class wraps a ``BesselBasisLayer`` (distance features) and a
    ``SphericalBasisLayer`` (distance + angle features) and applies them to
    the neighbourhood graph obtained from ``radius_graph``.

    NOTE(review): this class is a plain Python object, not an ``nn.Module``;
    if the basis layers hold parameters/buffers, the caller is presumably
    responsible for moving ``self.rbf`` / ``self.sbf`` to the right device.
    """

    def __init__(self, cutoff=8.0, max_num_neighbors=32, num_radial=16,
                 envelope_exponent=5, num_spherical=6):
        """Store graph-construction settings and create the basis layers.

        Args:
            cutoff: Radius passed to ``radius_graph`` and to both basis layers.
            max_num_neighbors: Neighbour cap passed to ``radius_graph``.
            num_radial: Number of radial basis functions.
            envelope_exponent: Envelope exponent passed to both basis layers.
            num_spherical: Number of spherical basis functions.
        """
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent)
        self.sbf = SphericalBasisLayer(num_spherical, num_radial, cutoff,
                                       envelope_exponent)

    def DimnetPlusLocalEnvironment(self, z, pos, batch=None):
        """Return the radial and spherical basis expansions of the local graph.

        Args:
            z: Per-node tensor whose first dimension is the number of nodes
                (presumably atomic numbers -- confirm with caller). A plain
                ``int`` node count is also accepted.
            pos: Node coordinates; indexed per node and reduced over the last
                dimension (assumes shape ``(num_nodes, 3)`` -- the cross
                product below requires a trailing size of 3).
            batch: Optional batch assignment forwarded to ``radius_graph``.

        Returns:
            Tuple ``(rbf, sbf)``: the outputs of the Bessel basis layer on the
            edge distances and of the spherical basis layer on the edge
            distances and triplet angles.
        """
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch,
                                  max_num_neighbors=self.max_num_neighbors)

        # ``triplets`` needs the node *count*, not the per-node tensor itself.
        # An int is passed through unchanged so callers that already supply a
        # count keep working.
        num_nodes = z if isinstance(z, int) else z.size(0)
        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(
            edge_index, num_nodes=num_nodes)

        # Calculate distances.
        dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

        # Calculate angles between the two edges of every triplet.
        pos_jk, pos_ij = pos[idx_j] - pos[idx_k], pos[idx_i] - pos[idx_j]
        a = (pos_ij * pos_jk).sum(dim=-1)
        # Explicit ``dim``: the legacy default picks the first dimension of
        # size 3, which is the wrong axis when there are exactly 3 triplets.
        b = torch.cross(pos_ij, pos_jk, dim=-1).norm(dim=-1)
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)
        return rbf, sbf