Adding support for vectors longer than dim 3 #11

Open · wants to merge 3 commits into base: main
166 changes: 133 additions & 33 deletions README.md

Large diffs are not rendered by default.

130 changes: 75 additions & 55 deletions gvp/__init__.py
@@ -1,15 +1,18 @@
import torch, functools
from torch import nn
import functools
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_add


def tuple_sum(*args):
'''
Sums any number of tuples (s, V) elementwise.
'''
return tuple(map(sum, zip(*args)))


def tuple_cat(*args, dim=-1):
'''
Concatenates any number of tuples (s, V) elementwise.
@@ -23,6 +26,7 @@ def tuple_cat(*args, dim=-1):
s_args, v_args = list(zip(*args))
return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim)


def tuple_index(x, idx):
'''
Indexes into a tuple (s, V) along the first dimension.
@@ -31,7 +35,8 @@ def tuple_index(x, idx):
'''
return x[0][idx], x[1][idx]
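# Illustrative usage (sketch, not part of this diff): the tuple helpers treat
# the scalar and vector parts independently, so they already work for vectors
# of any dimension d (here d = 6).
a = (torch.randn(10, 8), torch.randn(10, 2, 6))
b = (torch.randn(10, 8), torch.randn(10, 2, 6))
s, V = tuple_sum(a, b)      # elementwise sum: shapes (10, 8) and (10, 2, 6)
s0, V0 = tuple_index(a, 0)  # first entry: shapes (8,) and (2, 6)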

def randn(n, dims, device="cpu"):

def randn(n, dims, d=3, device="cpu"):
'''
Returns random tuples (s, V) drawn elementwise from a normal distribution.

@@ -42,7 +47,8 @@ def randn(n, dims, device="cpu"):
V.shape = (n, n_vector, 3)
'''
return torch.randn(n, dims[0], device=device), \
torch.randn(n, dims[1], 3, device=device)
torch.randn(n, dims[1], d, device=device)
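# Usage sketch (not part of this diff), assuming the new `d` keyword: draw
# features for 5 nodes with 8 scalar channels and 4 vector channels living in
# a 6-dimensional space instead of the default 3.
s, V = randn(5, (8, 4), d=6)
assert s.shape == (5, 8) and V.shape == (5, 4, 6)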


def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
'''
@@ -53,7 +59,8 @@ def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
return torch.sqrt(out) if sqrt else out
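# Sketch (not part of this diff): the norm is taken over whichever axis holds
# the vector components, so it is independent of the vector dimension.
V = torch.randn(10, 4, 6)
vn = _norm_no_nan(V, axis=-1)   # shape (10, 4): one norm per vector channel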

def _split(x, nv):

def _split(x, nv, vector_dim=3):
'''
Splits a merged representation of (s, V) back into a tuple.
Should be used only with `_merge(s, V)` and only if the tuple
@@ -62,20 +69,22 @@ def tuple_index(x, idx):
:param x: the `torch.Tensor` returned from `_merge`
:param nv: the number of vector channels in the input to `_merge`
'''
v = torch.reshape(x[..., -3*nv:], x.shape[:-1] + (nv, 3))
s = x[..., :-3*nv]
v = torch.reshape(x[..., -vector_dim * nv:], x.shape[:-1] + (nv, vector_dim))
s = x[..., :-vector_dim * nv]
return s, v

def _merge(s, v):

def _merge(s, v, vector_dim=3):
'''
Merges a tuple (s, V) into a single `torch.Tensor`, where the
vector channels are flattened and appended to the scalar channels.
Should be used only if the tuple representation cannot be used.
Use `_split(x, nv)` to reverse.
'''
v = torch.reshape(v, v.shape[:-2] + (3*v.shape[-2],))
v = torch.reshape(v, v.shape[:-2] + (vector_dim * v.shape[-2],))
return torch.cat([s, v], -1)
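# Round-trip sketch (not part of this diff), assuming the new `vector_dim`
# arguments: _merge flattens the d-dimensional vector channels onto the scalar
# channels, and _split recovers the original tuple.
s = torch.randn(10, 16)             # 16 scalar channels
V = torch.randn(10, 4, 6)           # 4 vector channels of dimension 6
x = _merge(s, V, vector_dim=6)      # shape (10, 16 + 4*6) = (10, 40)
s2, V2 = _split(x, nv=4, vector_dim=6)
assert torch.equal(s, s2) and torch.equal(V, V2)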


class GVP(nn.Module):
'''
Geometric Vector Perceptron. See manuscript and README.md
@@ -88,25 +97,26 @@ class GVP(nn.Module):
:param vector_gate: whether to use vector gating.
(vector_act will be used as sigma^+ in vector gating if `True`)
'''

def __init__(self, in_dims, out_dims, h_dim=None,
activations=(F.relu, torch.sigmoid), vector_gate=False):
super(GVP, self).__init__()
self.si, self.vi = in_dims
self.so, self.vo = out_dims
self.vector_gate = vector_gate
if self.vi:
self.h_dim = h_dim or max(self.vi, self.vo)
if self.vi:
self.h_dim = h_dim or max(self.vi, self.vo)
self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
self.ws = nn.Linear(self.h_dim + self.si, self.so)
if self.vo:
self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
if self.vector_gate: self.wsv = nn.Linear(self.so, self.vo)
else:
self.ws = nn.Linear(self.si, self.so)

self.scalar_act, self.vector_act = activations
self.dummy_param = nn.Parameter(torch.empty(0))

def forward(self, x):
'''
:param x: tuple (s, V) of `torch.Tensor`,
@@ -117,13 +127,13 @@ def forward(self, x):
if self.vi:
s, v = x
v = torch.transpose(v, -1, -2)
vh = self.wh(v)
vh = self.wh(v)
vn = _norm_no_nan(vh, axis=-2)
s = self.ws(torch.cat([s, vn], -1))
if self.vo:
v = self.wv(vh)
if self.vo:
v = self.wv(vh)
v = torch.transpose(v, -1, -2)
if self.vector_gate:
if self.vector_gate:
if self.vector_act:
gate = self.wsv(self.vector_act(s))
else:
@@ -139,14 +149,16 @@ def forward(self, x):
device=self.dummy_param.device)
if self.scalar_act:
s = self.scalar_act(s)

return (s, v) if self.vo else s
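# Sketch (not part of this diff): GVP itself needs no vector_dim argument,
# since wh/wv act on the channel axis after the transpose and the norms are
# taken over the spatial axis, so any vector dimension d passes through.
gvp_layer = GVP((16, 4), (8, 2))
s_in, V_in = torch.randn(10, 16), torch.randn(10, 4, 6)   # d = 6
s_out, V_out = gvp_layer((s_in, V_in))
assert s_out.shape == (10, 8) and V_out.shape == (10, 2, 6)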


class _VDropout(nn.Module):
'''
Vector channel dropout where the elements of each
vector channel are dropped together.
'''

def __init__(self, drop_rate):
super(_VDropout, self).__init__()
self.drop_rate = drop_rate
@@ -165,11 +177,13 @@ def forward(self, x):
x = mask * x / (1 - self.drop_rate)
return x


class Dropout(nn.Module):
'''
Combined dropout for tuples (s, V).
Takes tuples (s, V) as input and as output.
'''

def __init__(self, drop_rate):
super(Dropout, self).__init__()
self.sdropout = nn.Dropout(drop_rate)
@@ -186,16 +200,18 @@ def forward(self, x):
s, v = x
return self.sdropout(s), self.vdropout(v)


class LayerNorm(nn.Module):
'''
Combined LayerNorm for tuples (s, V).
Takes tuples (s, V) as input and as output.
'''

def __init__(self, dims):
super(LayerNorm, self).__init__()
self.s, self.v = dims
self.scalar_norm = nn.LayerNorm(self.s)

def forward(self, x):
'''
:param x: tuple (s, V) of `torch.Tensor`,
@@ -209,6 +225,7 @@ def forward(self, x):
vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True))
return self.scalar_norm(s), v / vn
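# Sketch (not part of this diff): the combined LayerNorm normalizes the scalar
# channels with nn.LayerNorm and rescales the vector channels by the
# root-mean-square of their norms, which likewise works for any vector
# dimension d.
norm = LayerNorm((16, 4))
s, V = norm((torch.randn(10, 16), torch.randn(10, 4, 6)))
assert s.shape == (10, 16) and V.shape == (10, 4, 6)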


class GVPConv(MessagePassing):
'''
Graph convolution / message passing with Geometric Vector Perceptrons.
@@ -229,31 +246,33 @@ class GVPConv(MessagePassing):
:param vector_gate: whether to use vector gating.
(vector_act will be used as sigma^+ in vector gating if `True`)
'''
def __init__(self, in_dims, out_dims, edge_dims,
n_layers=3, module_list=None, aggr="mean",

def __init__(self, in_dims, out_dims, edge_dims, vector_dim=3,
n_layers=3, module_list=None, aggr="mean",
activations=(F.relu, torch.sigmoid), vector_gate=False):
super(GVPConv, self).__init__(aggr=aggr)
self.si, self.vi = in_dims
self.so, self.vo = out_dims
self.se, self.ve = edge_dims

GVP_ = functools.partial(GVP,
activations=activations, vector_gate=vector_gate)

self.vector_dim = vector_dim

GVP_ = functools.partial(GVP,
activations=activations, vector_gate=vector_gate)

module_list = module_list or []
if not module_list:
if n_layers == 1:
module_list.append(
GVP_((2*self.si + self.se, 2*self.vi + self.ve),
(self.so, self.vo), activations=(None, None)))
GVP_((2 * self.si + self.se, 2 * self.vi + self.ve),
(self.so, self.vo), activations=(None, None)))
else:
module_list.append(
GVP_((2*self.si + self.se, 2*self.vi + self.ve), out_dims)
GVP_((2 * self.si + self.se, 2 * self.vi + self.ve), out_dims)
)
for i in range(n_layers - 2):
module_list.append(GVP_(out_dims, out_dims))
module_list.append(GVP_(out_dims, out_dims,
activations=(None, None)))
activations=(None, None)))
self.message_func = nn.Sequential(*module_list)

def forward(self, x, edge_index, edge_attr):
@@ -263,17 +282,17 @@ def forward(self, x, edge_index, edge_attr):
:param edge_attr: tuple (s, V) of `torch.Tensor`
'''
x_s, x_v = x
message = self.propagate(edge_index,
s=x_s, v=x_v.reshape(x_v.shape[0], 3*x_v.shape[1]),
edge_attr=edge_attr)
return _split(message, self.vo)
message = self.propagate(edge_index,
s=x_s, v=x_v.reshape(x_v.shape[0], x_v.shape[1] * x_v.shape[2]),
edge_attr=edge_attr)
return _split(message, self.vo, vector_dim=self.vector_dim)

def message(self, s_i, v_i, s_j, v_j, edge_attr):
v_j = v_j.view(v_j.shape[0], v_j.shape[1]//3, 3)
v_i = v_i.view(v_i.shape[0], v_i.shape[1]//3, 3)
v_j = v_j.view(v_j.shape[0], v_j.shape[1] // self.vector_dim, self.vector_dim)
v_i = v_i.view(v_i.shape[0], v_i.shape[1] // self.vector_dim, self.vector_dim)
message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
message = self.message_func(message)
return _merge(*message)
return _merge(*message, vector_dim=self.vector_dim)
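# Usage sketch (not part of this diff), assuming the new vector_dim argument:
# message passing on a toy 3-node cycle with 6-dimensional vector features.
conv = GVPConv(in_dims=(16, 4), out_dims=(16, 4), edge_dims=(8, 1), vector_dim=6)
x = (torch.randn(3, 16), torch.randn(3, 4, 6))
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
edge_attr = (torch.randn(3, 8), torch.randn(3, 1, 6))
s_out, V_out = conv(x, edge_index, edge_attr)
assert s_out.shape == (3, 16) and V_out.shape == (3, 4, 6)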


class GVPConvLayer(nn.Module):
@@ -297,27 +316,28 @@ class GVPConvLayer(nn.Module):
:param vector_gate: whether to use vector gating.
(vector_act will be used as sigma^+ in vector gating if `True`)
'''
def __init__(self, node_dims, edge_dims,

def __init__(self, node_dims, edge_dims, vector_dim=3,
n_message=3, n_feedforward=2, drop_rate=.1,
autoregressive=False,
autoregressive=False,
activations=(F.relu, torch.sigmoid), vector_gate=False):

super(GVPConvLayer, self).__init__()
self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message,
aggr="add" if autoregressive else "mean",
activations=activations, vector_gate=vector_gate)
GVP_ = functools.partial(GVP,
activations=activations, vector_gate=vector_gate)
self.conv = GVPConv(node_dims, node_dims, edge_dims, vector_dim, n_message,
aggr="add" if autoregressive else "mean",
activations=activations, vector_gate=vector_gate)
GVP_ = functools.partial(GVP,
activations=activations, vector_gate=vector_gate)
self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)])
self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])

ff_func = []
if n_feedforward == 1:
ff_func.append(GVP_(node_dims, node_dims, activations=(None, None)))
else:
hid_dims = 4*node_dims[0], 2*node_dims[1]
hid_dims = 4 * node_dims[0], 2 * node_dims[1]
ff_func.append(GVP_(node_dims, hid_dims))
for i in range(n_feedforward-2):
for i in range(n_feedforward - 2):
ff_func.append(GVP_(hid_dims, hid_dims))
ff_func.append(GVP_(hid_dims, node_dims, activations=(None, None)))
self.ff_func = nn.Sequential(*ff_func)
@@ -337,37 +357,37 @@ def forward(self, x, edge_index, edge_attr,
dim of node embeddings (s, V). If not `None`, only
these nodes will be updated.
'''

if autoregressive_x is not None:
src, dst = edge_index
mask = src < dst
edge_index_forward = edge_index[:, mask]
edge_index_backward = edge_index[:, ~mask]
edge_attr_forward = tuple_index(edge_attr, mask)
edge_attr_backward = tuple_index(edge_attr, ~mask)

dh = tuple_sum(
self.conv(x, edge_index_forward, edge_attr_forward),
self.conv(autoregressive_x, edge_index_backward, edge_attr_backward)
)

count = scatter_add(torch.ones_like(dst), dst,
dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)
dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)

dh = dh[0] / count, dh[1] / count.unsqueeze(-1)

else:
dh = self.conv(x, edge_index, edge_attr)

if node_mask is not None:
x_ = x
x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)

x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))

dh = self.ff_func(x)
x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))

if node_mask is not None:
x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
x = x_
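# End-to-end sketch (not part of this diff), assuming vector_dim is threaded
# through GVPConvLayer -> GVPConv as above and that the layer returns the
# updated (s, V) tuple: one residual message-passing and feedforward update
# on 6-dimensional vector features.
layer = GVPConvLayer(node_dims=(16, 4), edge_dims=(8, 1), vector_dim=6)
x = (torch.randn(3, 16), torch.randn(3, 4, 6))
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
edge_attr = (torch.randn(3, 8), torch.randn(3, 1, 6))
s_out, V_out = layer(x, edge_index, edge_attr)
assert s_out.shape == (3, 16) and V_out.shape == (3, 4, 6)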