From 6b2a6cac52ac69f7fc8736fddf25e97d90eb7ea9 Mon Sep 17 00:00:00 2001 From: Jiaxuan Date: Mon, 25 Jul 2022 00:27:47 -0700 Subject: [PATCH] MetaLink code release --- graphgym/config.py | 3 + graphgym/contrib/__init__.py | 4 +- graphgym/contrib/config/metalink.py | 101 +++ graphgym/contrib/layer/kgconv.py | 342 ++++++++++ graphgym/contrib/loader/molecule.py | 645 +++++++++++++++++++ graphgym/contrib/network/metalink.py | 484 ++++++++++++++ graphgym/contrib/train/metalink.py | 282 ++++++++ graphgym/loader.py | 79 ++- graphgym/logger.py | 63 +- graphgym/loss.py | 41 +- graphgym/model_builder.py | 5 +- run/configs/MetaLink/mol_classification.yaml | 61 ++ run/grids/MetaLink/basic.txt | 12 + run/main.py | 2 +- run/main_pyg.py | 2 +- run/scripts/MetaLink/run_metalink.sh | 25 + 16 files changed, 2112 insertions(+), 39 deletions(-) create mode 100644 graphgym/contrib/config/metalink.py create mode 100644 graphgym/contrib/layer/kgconv.py create mode 100644 graphgym/contrib/loader/molecule.py create mode 100644 graphgym/contrib/network/metalink.py create mode 100644 graphgym/contrib/train/metalink.py create mode 100644 run/configs/MetaLink/mol_classification.yaml create mode 100644 run/grids/MetaLink/basic.txt create mode 100644 run/scripts/MetaLink/run_metalink.sh diff --git a/graphgym/config.py b/graphgym/config.py index cd11eed9..652d7bed 100644 --- a/graphgym/config.py +++ b/graphgym/config.py @@ -94,6 +94,9 @@ def set_cfg(cfg): # Number of dataset splits: train/val/test cfg.share.num_splits = 1 + # Number of task in targets. 
def set_cfg_example(cfg):
    r'''
    Install the default values of the MetaLink knowledge-graph options.

    A fresh ``CN`` node is attached as ``cfg.kg`` and every option below is
    then set on it, in the listed order.

    :param cfg: configuration node that receives the ``kg`` option group.
    '''

    # ----------------------------------------------------------------------- #
    # MetaLink KG options
    # ----------------------------------------------------------------------- #
    cfg.kg = CN()

    # (option name, default value) pairs, applied top to bottom.
    kg_defaults = (
        ('kg_mode', True),
        ('dim_emb', 64),
        # type of prediction head: direct, mp, bilinear
        ('head', 'direct'),
        # how to decode pred
        ('decode', 'dot'),
        ('layer_type', 'kgheteconv'),
        # normalize embedding
        ('norm_emb', False),
        # if do fine_tune after training
        ('fine_tune', False),
        # kg message passing layers
        ('layers_mp', 0),
        # kg aggregation
        ('agg', 'mean'),
        # kg message direction
        ('msg_direction', 'single'),
        # kg gate function
        ('gate_self', True),
        ('gate_msg', True),
        # kg self transform
        ('self_trans', True),
        # add self msg passing
        ('self_msg', 'none'),
        ('has_act', True),
        ('has_bn', False),
        ('hete', True),
        # last, every
        ('pred', 'every'),
        # no, every, last
        ('has_l2norm', 'no'),
        # positioning l2 norm: pre, post
        ('pos_l2norm', 'post'),
        ('gate_bias', False),
        # raw, loss, both
        ('edge_feature', 'raw'),
        # pertask, standard
        ('split', 'pertask'),
        # new, standard
        ('experiment', 'new'),
        # standard, relation
        ('setting', 'standard'),
        # whether to do meta inductive learning
        ('meta', False),
        ('meta_ratio', 0.2),
        # whether to add aux targets (logp, qed)
        ('aux', True),
        # keep what percentage of edges
        ('setting_ratio', 0.5),
        # number of repeats for evaluation
        ('repeat', 1),
    )
    for option, default in kg_defaults:
        setattr(cfg.kg, option, default)
+import torch.nn.functional as F +from torch.nn import Parameter +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.nn.inits import glorot, zeros + +from graphgym.config import cfg +from graphgym.register import register_layer + +act_dict = { + 'relu': nn.ReLU(inplace=cfg.mem.inplace), + 'selu': nn.SELU(inplace=cfg.mem.inplace), + 'prelu': nn.PReLU(), + 'elu': nn.ELU(inplace=cfg.mem.inplace), + 'lrelu_01': nn.LeakyReLU(negative_slope=0.1, inplace=cfg.mem.inplace), + 'lrelu_025': nn.LeakyReLU(negative_slope=0.25, inplace=cfg.mem.inplace), + 'lrelu_05': nn.LeakyReLU(negative_slope=0.5, inplace=cfg.mem.inplace), +} + + +class HeteConvLayer(MessagePassing): + r""" + """ + def __init__(self, + in_channels, + out_channels, + edge_channels, + improved=False, + cached=False, + bias=True, + loss='bce', + **kwargs): + super(HeteConvLayer, self).__init__(aggr=cfg.kg.agg, **kwargs) + assert in_channels == out_channels # todo: see if necessary to relax this constraint + self.in_channels = in_channels + self.out_channels = out_channels + if cfg.kg.edge_feature == 'both': + edge_channels = 2 + else: + edge_channels = 1 + self.edge_channels = edge_channels + self.improved = improved + self.cached = cached + self.normalize = False + self.msg_direction = cfg.kg.msg_direction + + # TODO: Take meta learning explaination. 
Consider grad in message passing + + if self.msg_direction == 'single': + self.msg_node = nn.Linear(in_channels, out_channels) + self.msg_edge = nn.Linear(edge_channels, out_channels) + self.gate_self = nn.Linear(in_channels, out_channels) # optional + self.gate_msg_node = nn.Linear(in_channels, + out_channels) # optional + self.gate_msg_edge = nn.Linear(edge_channels, + out_channels) # optional + else: + self.msg_node = nn.Linear(in_channels * 2 + edge_channels, + out_channels) + self.gate_self = nn.Linear(in_channels, out_channels) + self.gate_msg = nn.Linear(in_channels * 2 + edge_channels, + out_channels) + + if bias: + self.bias = Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + + if loss == 'bce': + self.loss = nn.BCEWithLogitsLoss( + size_average=cfg.model.size_average, reduce=False) + + self.reset_parameters() + + def reset_parameters(self): + zeros(self.bias) + self.cached_result = None + self.cached_num_edges = None + + def forward(self, x, edge_index, edge_weight=None, edge_feature=None): + pred = self.propagate(edge_index, x=x, edge_feature=edge_feature) + + return pred + + def message(self, x_i, x_j, edge_feature): + if self.msg_direction == 'both': + x_j = torch.cat((x_i, x_j, edge_feature), dim=-1) + else: + + if cfg.kg.edge_feature != 'raw': + pred = torch.sum(x_i * x_j, dim=-1, keepdim=True) + loss = self.loss(pred, edge_feature) + if cfg.kg.edge_feature == 'loss': + edge_feature = loss + elif cfg.kg.edge_feature == 'both': + edge_feature = torch.cat((edge_feature, loss), dim=-1) + + msg = self.msg_node(x_j) + self.msg_edge(edge_feature) + if cfg.kg.gate_msg: + gate = self.gate_msg_node(x_j) + self.gate_msg_edge( + edge_feature) + gate = torch.sigmoid(gate) + if cfg.kg.gate_bias: + gate = gate - 0.5 + msg = msg * gate + return msg + + def update(self, aggr_out): + if self.bias is not None: + aggr_out = aggr_out + self.bias + return aggr_out + + def __repr__(self): + return '{}({}, 
class KGHeteroConv(torch.nn.Module):
    r"""Wrapper that applies one layer per message type of a heterogeneous
    graph and merges the results per destination node type.

    Args:
        convs (dict): Maps a message key ``(neigh_type, edge_type,
            self_type)`` to the layer that computes that message type.
        selfs (dict, optional): Maps a node type to a layer applied to the
            node's own features; its output is merged with the incoming
            messages. ``None`` disables self messages.
        posts (dict, optional): Maps a node type to a layer applied to the
            merged embedding. ``None`` disables post-processing.
        aggr (str, optional): How several embeddings for the same node type
            are merged: ``"add"``, ``"mean"``, ``"max"``, ``"mul"``,
            ``"concat"`` or ``None``.
    """
    def __init__(self, convs, selfs=None, posts=None, aggr="add"):
        super(KGHeteroConv, self).__init__()

        self.convs = convs
        self.selfs = selfs  # self computations
        self.posts = posts
        # The dicts are kept as plain dicts (they are indexed by tuple / str
        # keys in forward); the ModuleLists register the layers with torch.
        self.modules_convs = torch.nn.ModuleList(convs.values())
        if self.selfs is not None:
            self.modules_selfs = torch.nn.ModuleList(selfs.values())
        if self.posts is not None:
            self.modules_posts = torch.nn.ModuleList(posts.values())

        assert aggr in ["add", "mean", "max", "mul", "concat", None]
        self.aggr = aggr

    def forward(self, node_features, edge_indices, edge_features=None):
        r"""Compute the new embedding of every node type.

        Args:
            node_features (dict): Maps a node type to its feature tensor.
            edge_indices (dict): Maps a message key ``(neigh_type, edge_type,
                self_type)`` to an edge index tensor. Keys without a layer in
                ``convs`` are ignored.
            edge_features (dict, optional): Maps an edge type to its edge
                feature tensor. When ``None``, each layer is called with
                ``edge_feature=None``.

        Returns:
            dict: Maps a node type to its merged (and, if ``posts`` is set,
            post-processed) embedding. Only node types that received at
            least one message or self message are present.
        """
        # Embedding computed from each message type.
        message_type_emb = {}
        for message_key, edge_index in edge_indices.items():
            if message_key not in self.convs:
                continue
            neigh_type, edge_type, self_type = message_key
            # Bind edge_feature on every iteration so a missing
            # edge_features dict neither raises nor reuses a stale value.
            edge_feature = (None if edge_features is None else
                            edge_features[edge_type])
            message_type_emb[message_key] = self.convs[message_key](
                x=(node_features[neigh_type], node_features[self_type]),
                edge_index=edge_index,
                edge_feature=edge_feature)

        # Group the embeddings by destination (tail) node type.
        node_emb = {}
        for (_, _, tail), emb in message_type_emb.items():
            node_emb.setdefault(tail, []).append(emb)

        # Add self messages; a node type without any incoming message type
        # still gets its self message instead of raising KeyError.
        if self.selfs is not None:
            for node_key, node_feature in node_features.items():
                node_emb.setdefault(node_key, []).append(
                    self.selfs[node_key](node_feature))

        # Merge multiple embeddings that share the same tail.
        merged = {}
        for node_type, embs in node_emb.items():
            emb = embs[0] if len(embs) == 1 else self.aggregate(embs)
            if self.posts is not None:
                emb = self.posts[node_type](emb)
            merged[node_type] = emb

        return merged

    def aggregate(self, xs):
        r"""Merge a list of same-shaped embeddings into one tensor.

        ``concat`` concatenates along the last dimension. ``add``, ``mean``,
        ``max`` and ``mul`` stack the inputs on a new trailing dimension and
        reduce over it. With ``aggr=None`` the stacked tensor is returned
        unreduced.
        """
        if self.aggr == "concat":
            return torch.cat(xs, dim=-1)

        x = torch.stack(xs, dim=-1)
        if self.aggr == "add":
            x = x.sum(dim=-1)
        elif self.aggr == "mean":
            x = x.mean(dim=-1)
        elif self.aggr == "max":
            # max over a dim returns (values, indices); keep the values.
            x = x.max(dim=-1)[0]
        elif self.aggr == "mul":
            # prod over a dim returns a plain tensor, so it must not be
            # indexed the way max is.
            x = x.prod(dim=-1)
        return x
self.model(batch.node_feature, + batch.edge_index, + batch.edge_feature) + for key in batch.node_feature.keys(): + batch.node_feature[ + key] = batch.node_feature[key] + node_feature[key] + else: + batch.node_feature = self.model(batch.node_feature, + batch.edge_index, + batch.edge_feature) + return batch + + +register_layer('kgheteconv', KGHeteConv) diff --git a/graphgym/contrib/loader/molecule.py b/graphgym/contrib/loader/molecule.py new file mode 100644 index 00000000..5deaa5d3 --- /dev/null +++ b/graphgym/contrib/loader/molecule.py @@ -0,0 +1,645 @@ +# MIT License +# Copyright (c) 2020 Jiaxuan You, Wengong Jin, Octavian Ganea, +# Regina Barzilay, Tommi Jaakkola +# Copyright (c) 2020 Wengong Jin, Kyle Swanson, Kevin Yang, +# Regina Barzilay, Tommi Jaakkola +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
def onek_encoding_unk(value: int, choices: List[int]) -> List[int]:
    """
    Build a one-hot list with an extra trailing "unknown" slot.

    :param value: The value to encode.
    :param choices: The known values; slot ``i`` stands for ``choices[i]``.
    :return: A list of ``len(choices) + 1`` zeros with a single 1. The 1 is
             at the position of ``value`` in ``choices``, or in the last
             slot when ``value`` is not one of the choices.
    """
    one_hot = [0] * (len(choices) + 1)
    try:
        slot = choices.index(value)
    except ValueError:
        # Not a known choice: mark the trailing "unknown" position.
        slot = -1
    one_hot[slot] = 1

    return one_hot
def get_fname(path):
    """Return the file name of ``path`` up to (not including) its first dot."""
    base_name = os.path.basename(path)
    stem, _, _ = base_name.partition('.')
    return stem
def mols2data(all_mols, return_scaffold_split=False):
    """
    Convert a ``{smiles: value}`` mapping into a list of graph data objects.

    Molecules whose graph has no edges are dropped. When the mapped value is
    a dict, each of its entries is attached to the graph as a tensor under
    the key ``graph_<name>``.

    :param all_mols: mapping from SMILES string to per-molecule values.
    :param return_scaffold_split: also compute a scaffold split over the
        molecules that were kept.
    :return: the list of graphs, or ``(graphs, splits)`` when
        ``return_scaffold_split`` is true.
    """
    graphs = []
    kept_mols = []
    for smiles, val in all_mols.items():
        graph, mol = smiles2data(smiles, return_mol=True)
        if graph.num_edges == 0:
            continue
        if isinstance(val, dict):
            for key, item in val.items():
                graph['graph_{}'.format(key)] = torch.tensor(item)
        graphs.append(graph)
        kept_mols.append(mol)

    if not return_scaffold_split:
        return graphs
    return graphs, mol_scaffold_split(kept_mols)
def min_max_scaler(x, x_min, x_max):
    """
    Linearly rescale ``x`` so that ``x_min`` maps to 0 and ``x_max`` to 1.

    :param x: value to rescale.
    :param x_min: lower bound of the source range.
    :param x_max: upper bound of the source range.
    :return: ``(x - x_min) / (x_max - x_min)``. When the bounds are plain
        numbers and coincide, the range is treated as having unit width, so
        ``x - x_min`` is returned instead of dividing by zero.
    """
    span = x_max - x_min
    # A constant column (e.g. every molecule has the same value) gives a
    # zero-width range; only scalar bounds are checked so that array-like
    # bounds keep their original element-wise behaviour.
    if isinstance(span, (int, float)) and span == 0:
        return x - x_min
    return (x - x_min) / span
for smiles_temp in motif_smiles + ] + + targets_id = [] + targets_value = [] + # add synthetic molecule feature + try: + mol = Chem.MolFromSmiles(smiles) + value_logp = MolLogP(mol) + targets_id.append(0) + targets_value.append(value_logp) + logp_all.append(value_logp) + value_qed = qed(mol) + targets_id.append(1) + targets_value.append(value_qed) + qed_all.append(value_qed) + except Exception: + print('--cannot compute logp or qed') + + for id, column in enumerate(target_columns): + if row[column] != '': + id_all = id + target_id_bias + value = float(row[column]) + targets_id.append(id_all) + targets_value.append(value) + if mol is None: + print(smiles) + if all_mols.get(smiles) is None: + all_mols[smiles] = { + 'targets_id': targets_id, + 'targets_value': targets_value, + 'motifs_id': motifs_id + } + else: + all_mols[smiles]['targets_id'] += targets_id + all_mols[smiles]['targets_value'] += targets_value + all_mols[smiles]['motifs_id'] += motifs_id + target_id_bias += len(target_columns) + # normalize logp and qed + logp_all_min, logp_all_max = min(logp_all), max(logp_all) + qed_all_min, qed_all_max = min(qed_all), max(qed_all) + for smiles in all_mols.keys(): + for i, targets_id in enumerate(all_mols[smiles]['targets_id']): + + if targets_id == 0: + all_mols[smiles]['targets_value'][i] = min_max_scaler( + all_mols[smiles]['targets_value'][i], logp_all_min, + logp_all_max) + elif targets_id == 1: + all_mols[smiles]['targets_value'][i] = min_max_scaler( + all_mols[smiles]['targets_value'][i], qed_all_min, + qed_all_max) + + all_mols, splits = mols2data(all_mols, return_scaffold_split=True) + all_motifs = mols2data(motifs) + + torch.save(all_mols, cache_mols_name) + torch.save(splits, cache_splits_name) + torch.save(all_motifs, cache_motifs_name) + torch.save(all_targets, cache_targets_name) + else: + all_mols = torch.load(cache_mols_name) + splits = torch.load(cache_splits_name) + all_motifs = torch.load(cache_motifs_name) + all_targets = 
def sanitize(mol, kekulize=True):
    """
    Round-trip ``mol`` through a SMILES string and parse it back.

    With ``kekulize`` set, the helpers ``get_smiles`` / ``get_mol`` are used
    for the round trip; otherwise the plain RDKit SMILES writer and parser
    are used.

    :return: the re-parsed molecule, or ``None`` if any step raised.
    """
    try:
        if kekulize:
            mol = get_mol(get_smiles(mol))
        else:
            mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    except Exception:
        # Best effort: a molecule that cannot be round-tripped is dropped.
        mol = None
    return mol
def generate_scaffold(mol: Union[str, Chem.Mol],
                      include_chirality: bool = False) -> str:
    """
    Compute the Bemis-Murcko scaffold of a molecule.

    :param mol: A SMILES string or an RDKit molecule. Strings (including
        ``str`` subclasses) are parsed with ``Chem.MolFromSmiles`` first.
    :param include_chirality: Whether to include chirality in the scaffold.
    :return: The scaffold as returned by
        ``MurckoScaffold.MurckoScaffoldSmiles``.
    """
    # isinstance (rather than an exact type comparison) also accepts str
    # subclasses, which would otherwise be passed on unparsed.
    if isinstance(mol, str):
        mol = Chem.MolFromSmiles(mol)
    scaffold = MurckoScaffold.MurckoScaffoldSmiles(
        mol=mol, includeChirality=include_chirality)

    return scaffold
def mol_scaffold_split(data, sizes=(0.8, 0.1, 0.1), balanced=True, repeat=10):
    '''
    Build ``repeat`` scaffold-based train/valid/test splits of ``data``.

    All items sharing a scaffold are assigned to the same set. Scaffold groups
    are placed greedily: into train while it fits, then into valid while it
    fits, otherwise into test.

    :param data: all_mols, in whatever form ``scaffold_to_smiles`` accepts.
    :param sizes: A length-3 sequence with the proportions of data in the
    train, validation, and test sets. Must sum to 1.
    :param balanced: Try to balance sizes of scaffolds in each set, rather than
    just putting smallest in test set.
    :param repeat: Number of splits to generate. Split ``i`` is drawn after
    seeding the ``random`` module with ``i``.
    :return: A dict with keys 'train', 'valid' and 'test'; each value is a
    list of ``repeat`` lists of indices into ``data``.
    :raises ValueError: If ``sizes`` does not sum to 1.
    '''
    import math

    # Exact float equality rejects valid inputs such as [0.7, 0.2, 0.1],
    # whose sum is 0.9999999999999999.
    if not math.isclose(sum(sizes), 1.0):
        raise ValueError(
            'sizes must sum to 1, got {}'.format(sum(sizes)))

    # Map from scaffold to index in the data
    scaffold_to_indices = scaffold_to_smiles(data, use_indices=True)
    all_index_sets = list(scaffold_to_indices.values())

    num_items = len(data)
    train_size = sizes[0] * num_items
    val_size = sizes[1] * num_items
    test_size = sizes[2] * num_items

    splits = {'train': [], 'valid': [], 'test': []}
    for seed in range(repeat):
        # NOTE: this reseeds the module-level RNG, exactly as before, so
        # split ``seed`` stays reproducible across runs.
        random.seed(seed)

        if balanced:
            # Put stuff that's bigger than half the val/test size first
            # (so it tends to land in train), rest just order randomly.
            # Fresh lists every repeat: shuffling is in place.
            big_index_sets = []
            small_index_sets = []
            for index_set in all_index_sets:
                is_big = (len(index_set) > val_size / 2
                          or len(index_set) > test_size / 2)
                if is_big:
                    big_index_sets.append(index_set)
                else:
                    small_index_sets.append(index_set)
            random.shuffle(big_index_sets)
            random.shuffle(small_index_sets)
            index_sets = big_index_sets + small_index_sets
        else:  # Sort from largest to smallest scaffold sets
            index_sets = sorted(all_index_sets, key=len, reverse=True)

        train, val, test = [], [], []
        for index_set in index_sets:
            if len(train) + len(index_set) <= train_size:
                train.extend(index_set)
            elif len(val) + len(index_set) <= val_size:
                val.extend(index_set)
            else:
                test.extend(index_set)

        splits['train'].append(train)
        splits['valid'].append(val)
        splits['test'].append(test)

    return splits
class MetaLinkPool(nn.Module):
    '''Pooling stage of MetaLink: node embeddings -> graph embeddings.

    Node features are pooled per graph with the function selected by
    cfg.model.graph_pooling, passed through an MLP with
    cfg.gnn.layers_post_mp layers, and optionally L2-normalized
    (cfg.kg.norm_emb). The result is stored in batch.graph_feature.
    '''
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.layer_post_mp = MLP(dim_in,
                                 dim_out,
                                 num_layers=cfg.gnn.layers_post_mp,
                                 bias=True)
        self.pooling_fun = pooling_dict[cfg.model.graph_pooling]

    def forward(self, batch):
        # The 'ego' transform needs the extra node_id_index argument
        pool_args = [batch.node_feature, batch.batch]
        if cfg.dataset.transform == 'ego':
            pool_args.append(batch.node_id_index)

        pooled = self.layer_post_mp(self.pooling_fun(*pool_args))
        if cfg.kg.norm_emb:
            pooled = F.normalize(pooled, p=2, dim=-1)
        batch.graph_feature = pooled
        return batch
class ConcatHead(nn.Module):
    '''Predict labels from graph embeddings concatenated with known labels.

    The known (graph, task, value) triples stored on the batch are scattered
    into a dense (num_graphs, dim_aux) matrix, concatenated to
    batch.graph_feature, passed through an MLP with cfg.kg.layers_mp layers
    and finally multiplied with a learnable task embedding matrix.
    '''
    def __init__(self, dim_in, dim_out, dim_aux):
        super(ConcatHead, self).__init__()
        self.dim_aux = dim_aux
        dim_cat = dim_in + dim_aux
        self.layer = MLP(dim_cat,
                         dim_cat,
                         num_layers=cfg.kg.layers_mp,
                         bias=True)

        self.task_emb = nn.Parameter(torch.Tensor(dim_cat, dim_out))
        self.task_emb.data = nn.init.xavier_uniform_(
            self.task_emb.data, gain=nn.init.calculate_gain('relu'))

    def get_aux(self, batch):
        '''Return the known labels as a dense (num_graphs, dim_aux) matrix.

        Entry [g, t] holds the value of task t for graph g; entries without a
        known label are 0 and duplicate (g, t) pairs are summed.
        '''
        # .long(): an empty label set may arrive as float tensors
        # (torch.tensor([])), which are not valid sparse indices.
        index = torch.stack(
            (batch.graph_targets_id_batch, batch.graph_targets_id),
            dim=0).long()
        value = batch.graph_targets_value.to(batch.graph_feature.dtype)
        size = (batch.graph_feature.shape[0], self.dim_aux)
        # sparse_coo_tensor replaces the deprecated, CPU-only legacy
        # constructor torch.sparse.FloatTensor and works on any device.
        aux = torch.sparse_coo_tensor(index, value, size,
                                      device=value.device)
        return aux.to_dense()

    def pred(self, batch, pred):
        '''Pick pred[graph, task] for every column of batch.pred_index.

        Returns a tensor of shape (num_pairs, 1).
        '''
        id_data = batch.pred_index[0, :]
        id_task = batch.pred_index[1, :]
        return pred[id_data].gather(1, id_task.view(-1, 1))

    def forward(self, batch, pred_index=None):
        '''Return (selected predictions, graph features, task embedding).

        pred_index is a (2, num_pairs) tensor of (graph id, task id) pairs;
        it is also stored on the batch as batch.pred_index.
        '''
        batch.pred_index = pred_index
        aux = self.get_aux(batch)
        graph_feature = torch.cat((batch.graph_feature, aux), dim=1)
        graph_feature = self.layer(graph_feature)

        if cfg.kg.norm_emb:
            weight = F.normalize(self.task_emb, p=2, dim=0)
        else:
            weight = self.task_emb
        pred = torch.matmul(graph_feature, weight)
        pred = self.pred(batch, pred)
        return pred, graph_feature, self.task_emb
+ ''' + def __init__(self, dim_in, dim_out, num_layers=1): + super(MPHead, self).__init__() + self.task_emb = nn.Parameter(torch.Tensor(dim_out, dim_in)) + self.task_emb.data = nn.init.xavier_uniform_( + self.task_emb.data, gain=nn.init.calculate_gain('relu')) + self.mp = GNNStackStageUser(dim_in, + dim_in, + cfg.kg.layers_mp, + layer_name='kgconvv1') + + def forward(self, batch): + batch_kg = Batch() + # create nodes + if cfg.kg.norm_emb: + task_emb = F.normalize(self.task_emb, p=2, dim=1) + else: + task_emb = self.task_emb + data_emb = batch.graph_feature + + batch_kg.node_feature = torch.cat((data_emb, task_emb), dim=0) + + # create edges + n_data = data_emb.shape[0] + # shift task node id + graph_targets_id = batch.graph_targets_id + n_data + edge_index = torch.stack( + (batch.graph_targets_id_batch, graph_targets_id), dim=0) + edge_feature = batch.graph_targets_value.unsqueeze(-1) + batch_kg.edge_index = torch.cat( + (edge_index, torch.flip(edge_index, [0])), dim=1) + batch_kg.edge_feature = torch.cat((edge_feature, edge_feature), dim=0) + batch_kg = self.mp(batch_kg) + + # pred + data_emb = batch_kg.node_feature[:n_data] + task_emb = batch_kg.node_feature[n_data:] + pred = torch.matmul(data_emb, task_emb.transpose(0, 1)) + + return pred, data_emb, task_emb + + +class MPHeteHead(nn.Module): + '''Heterogenous Message pass then predict label + + The optional post_mp layer (specified by cfg.gnn.post_mp) is used + to transform the pooled embedding using an MLP. 
class MPHeteNewHead(nn.Module):
    '''Heterogeneous message passing over a data-task graph, then decoding.

    Graph embeddings ('data' nodes) and task embeddings ('task' nodes) are
    connected by the known labels stored on the batch, passed through a
    StackHeteConv with cfg.kg.layers_mp layers, and every requested
    (data, task) pair is scored according to cfg.kg.decode:

    - 'dot': inner product of the data and task embeddings;
    - 'concat': linear layer on the concatenated embeddings.
    '''
    def __init__(self, dim_in, dim_out, num_layers=1):
        # num_layers is unused; the depth comes from cfg.kg.layers_mp.
        super(MPHeteNewHead, self).__init__()
        if cfg.kg.meta:  # inductive: fixed, identical task embeddings
            self.task_emb = nn.Parameter(F.normalize(torch.ones(
                dim_out, dim_in),
                                                     p=2,
                                                     dim=1),
                                         requires_grad=False)
        else:  # transductive: learnable task embeddings
            self.task_emb = nn.Parameter(torch.Tensor(dim_out, dim_in))
            self.task_emb.data = nn.init.xavier_uniform_(
                self.task_emb.data, gain=nn.init.calculate_gain('relu'))
        self.mp = StackHeteConv(dim_in,
                                dim_in,
                                cfg.kg.layers_mp,
                                layer_name=cfg.kg.layer_type)
        if cfg.kg.decode == 'concat':
            self.decode = nn.Linear(dim_in * 2, 1)

    def _decode(self, data_emb, task_emb, pred_index, mode):
        '''Score the (data, task) pairs in pred_index; returns (num_pairs, 1).

        pred_index is a (2, num_pairs) tensor: row 0 indexes data_emb, row 1
        indexes task_emb.
        '''
        id_data = pred_index[0, :]
        id_task = pred_index[1, :]
        if mode == 'dot':
            # todo: compare with index_select then sparse dot product
            scores = torch.matmul(data_emb, task_emb.transpose(0, 1))
            return scores[id_data].gather(1, id_task.view(-1, 1))
        if mode == 'concat':
            # Previously this branch stored the bound method `.view` and then
            # tried to index it, which raised TypeError.
            pair = torch.cat((data_emb[id_data], task_emb[id_task]), dim=-1)
            return self.decode(pair)
        raise ValueError('Unknown cfg.kg.decode: {}'.format(mode))

    def pred(self, batch):
        '''Decode batch.pred_index and accumulate the result in batch.pred.'''
        pred = self._decode(batch.node_feature['data'],
                            batch.node_feature['task'], batch.pred_index,
                            cfg.kg.decode)
        if not hasattr(batch, 'pred'):
            batch.pred = pred
        else:
            batch.pred = batch.pred + pred
        return batch

    def forward(self, batch, pred_index=None):
        '''Return (predictions, data embeddings, task embeddings).

        The returned embeddings are the ones fed INTO message passing.
        pred_index is also stored on the input batch as batch.pred_index.
        '''
        batch_kg = Batch()
        # create nodes
        if cfg.kg.has_l2norm == 'every':
            task_emb = F.normalize(self.task_emb, p=2, dim=1)
            data_emb = F.normalize(batch.graph_feature, p=2, dim=1)
        else:
            task_emb = self.task_emb
            data_emb = batch.graph_feature
        batch_kg.node_feature = {'data': data_emb, 'task': task_emb}

        # create edges: one edge per known label, in both directions
        batch_kg.edge_index, batch_kg.edge_feature = {}, {}
        batch_kg.edge_index[('data', 'property', 'task')] = torch.stack(
            (batch.graph_targets_id_batch, batch.graph_targets_id), dim=0)
        batch_kg.edge_index[('task', 'property', 'data')] = torch.flip(
            batch_kg.edge_index[('data', 'property', 'task')], [0])
        # todo: diff discrete and continuous feature
        batch_kg.edge_feature[
            'property'] = batch.graph_targets_value.unsqueeze(-1)

        batch.pred_index = pred_index
        batch_kg.pred_index = pred_index

        batch_kg = self.mp(batch_kg)
        batch_kg = self.pred(batch_kg)

        return batch_kg.pred, data_emb, task_emb
def kg_edges(self, n, degree=20):
    '''Build a ring-shaped edge index over n nodes.

    Node i is connected to nodes i+1, ..., i+degree (modulo n).

    :param n: Number of nodes.
    :param degree: Number of outgoing edges per node.
    :return: Tensor of shape (2, n * degree) on cfg.device; row 0 holds the
        source nodes, row 1 the target nodes.
    '''
    source_all = []
    target_all = []
    for i in range(n):
        source_all.extend([i] * degree)
        target_all.extend(
            node % n for node in range(i + 1, i + degree + 1))
    edge_index = torch.stack(
        (torch.tensor(source_all), torch.tensor(target_all)), dim=0)
    # Tensor.to() is not in place: the moved tensor has to be kept and
    # returned (the result used to be discarded and None was returned).
    return edge_index.to(torch.device(cfg.device))
batch.motif_label = label + return batch + + def forward_compute(self, batch): + batch = self.preprocess(batch) + batch = self.pre_mp(batch) + batch = self.mp(batch) + batch = self.post_mp(batch) + + return batch + + def forward_emb(self, batch, motifs=None): + batch = self.forward_compute(batch) + return batch + + def forward_pred(self, batch, **kwargs): + batch = self.head(batch, **kwargs) + return batch + + +register_network('metalink', MetaLink) diff --git a/graphgym/contrib/train/metalink.py b/graphgym/contrib/train/metalink.py new file mode 100644 index 00000000..721d1bdb --- /dev/null +++ b/graphgym/contrib/train/metalink.py @@ -0,0 +1,282 @@ +import logging +import time + +import torch + +from graphgym.checkpoint import clean_ckpt, load_ckpt, save_ckpt +from graphgym.config import cfg +from graphgym.loss import compute_loss +from graphgym.register import register_train +from graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch + + +def train_epoch(logger, loader, model, optimizer, scheduler, motifs=None): + model.train() + # Meta-learning preparation + if cfg.kg.meta: + targets = list(range(cfg.share.num_task)) + meta_num = int(len(targets) * cfg.kg.meta_ratio) + keep_target = targets[:-meta_num] + time_start = time.time() + for i, batch in enumerate(loader): + if i % 2 == 0: + batch_support = batch + continue + batch_query = batch + + if cfg.kg.meta: + batch_support = mask_label(batch_support, keep_target, cfg.kg.aux) + batch_query = mask_label(batch_query, keep_target, cfg.kg.aux) + + batch_support.to(torch.device(cfg.device)) + batch_query.to(torch.device(cfg.device)) + + batch_query, pred_index, true = split_label( + batch_query, + setting=cfg.kg.setting, + keep_ratio=cfg.kg.setting_ratio, + aux=cfg.kg.aux) + + optimizer.zero_grad() + + batch_query = model.forward_emb(batch_query, motifs) + batch_support = model.forward_emb(batch_support, motifs=None) + + # concat batch + batch = concat_batch(batch_query, batch_support) + + pred, data_emb, task_emb 
= model.forward_pred(batch, + pred_index=pred_index) + # todo: see if batch_query need forward_pred + + task_type = 'classification_binary' + loss, pred_score = compute_loss(pred, true, task_type) + + pred_dict, true_dict = get_multi_task( + pred=pred_score.detach().cpu(), + true=true.detach().cpu(), + pred_index=batch.pred_index.detach().cpu(), + num_targets=len(targets)) + + logger.update_stats(true=true_dict, + pred=pred_dict, + loss=loss.item(), + lr=scheduler.get_last_lr()[0], + time_used=time.time() - time_start, + params=cfg.params) + + time_start = time.time() + + # 3 update parameters + loss.backward() # accumulate gradient from each target + optimizer.step() + + scheduler.step() + + +def eval_epoch(logger, loader_train, loader_val, model, motifs=None): + model.eval() + if cfg.kg.meta: + targets = list(range(cfg.share.num_task)) + meta_num = int(len(targets) * cfg.kg.meta_ratio) + keep_target = [0, 1] + targets[-meta_num:] + time_start = time.time() + + for i, batch_query in enumerate(loader_val): + if cfg.kg.meta: + batch_query = mask_label(batch_query, keep_target, aux=cfg.kg.aux) + batch_query.to(torch.device(cfg.device)) + batch_query, pred_index, true = split_label( + batch_query, + setting=cfg.kg.setting, + keep_ratio=cfg.kg.setting_ratio, + aux=cfg.kg.aux) + batch_query = model.forward_emb(batch_query, motifs) + + pred_all = 0 + for j in range(cfg.kg.repeat): + batch_support = next(iter(loader_train)) + + if cfg.kg.meta: + batch_support = mask_label(batch_support, + keep_target, + aux=cfg.kg.aux) + + batch_support.to(torch.device(cfg.device)) + + batch_support = model.forward_emb(batch_support, motifs=None) + + # concat batch + batch = concat_batch(batch_query, batch_support) + + pred, data_emb, task_emb = model.forward_pred( + batch, pred_index=pred_index) + # todo: see if batch_query need forward_pred + pred_all += pred + pred = pred_all / cfg.kg.repeat + + task_type = 'classification_binary' + loss, pred_score = compute_loss(pred, true, 
def mask_label(batch, keep_target=None, aux=True):
    '''Keep only the labels of the given tasks on ``batch`` (in place).

    :param batch: Object with the aligned 1-D tensors graph_targets_id_batch,
        graph_targets_id and graph_targets_value.
    :param keep_target: Iterable of task ids whose labels are kept. With the
        default (None, same as empty) every label is removed.
    :param aux: If False, tasks 0 and 1 are removed from ``keep_target``.
    :return: The same ``batch`` object, with the three tensors filtered.
    '''
    keep_target = [] if keep_target is None else keep_target
    if not aux:
        target_aux = (0, 1)
        keep_target = [
            target for target in keep_target if target not in target_aux
        ]

    target_ids = batch.graph_targets_id
    keep_mask = torch.zeros_like(target_ids,
                                 dtype=torch.bool,
                                 device=batch.graph_targets_id_batch.device)
    for target in keep_target:
        keep_mask |= target_ids == target

    batch.graph_targets_id_batch = batch.graph_targets_id_batch[keep_mask]
    batch.graph_targets_id = batch.graph_targets_id[keep_mask]
    batch.graph_targets_value = batch.graph_targets_value[keep_mask]
    return batch
= batch.graph_targets_id_batch.device + if setting == 'standard': + if aux: + target_aux = [0, 1] + keep_mask = torch.zeros_like(batch.graph_targets_id, + dtype=torch.bool, + device=device) + for target in target_aux: + keep_mask += batch.graph_targets_id == target + pred_index = torch.stack((batch.graph_targets_id_batch[~keep_mask], + batch.graph_targets_id[~keep_mask]), + dim=0) + true = batch.graph_targets_value[~keep_mask] + batch.graph_targets_id_batch = batch.graph_targets_id_batch[ + keep_mask] + batch.graph_targets_id = batch.graph_targets_id[keep_mask] + batch.graph_targets_value = batch.graph_targets_value[keep_mask] + else: + pred_index = torch.stack( + (batch.graph_targets_id_batch, batch.graph_targets_id), dim=0) + true = batch.graph_targets_value + batch.graph_targets_id_batch = torch.tensor([], device=device) + batch.graph_targets_id = torch.tensor([], device=device) + batch.graph_targets_value = torch.tensor([], device=device) + elif setting == 'relation': + if aux: + target_aux = [0, 1] + keep_mask = torch.zeros_like(batch.graph_targets_id, + dtype=torch.float, + device=device).uniform_() < keep_ratio + for target in target_aux: + keep_mask += batch.graph_targets_id == target + pred_index = torch.stack((batch.graph_targets_id_batch[~keep_mask], + batch.graph_targets_id[~keep_mask]), + dim=0) + true = batch.graph_targets_value[~keep_mask] + batch.graph_targets_id_batch = batch.graph_targets_id_batch[ + keep_mask] + batch.graph_targets_id = batch.graph_targets_id[keep_mask] + batch.graph_targets_value = batch.graph_targets_value[keep_mask] + else: + keep_mask = torch.zeros_like(batch.graph_targets_id, + dtype=torch.float, + device=device).uniform_() < keep_ratio + pred_index = torch.stack((batch.graph_targets_id_batch[~keep_mask], + batch.graph_targets_id[~keep_mask]), + dim=0) + true = batch.graph_targets_value[~keep_mask] + batch.graph_targets_id_batch = batch.graph_targets_id_batch[ + keep_mask] + batch.graph_targets_id = 
def concat_batch(batch_query, batch_support):
    '''Append the support batch to the query batch (query is modified).

    Graph features are concatenated as [query, support], so every support
    graph id is shifted by the number of query graphs before its labels are
    appended to the query labels.

    :param batch_query: Batch with graph_feature and the aligned label
        tensors graph_targets_id_batch / graph_targets_id /
        graph_targets_value. Updated in place.
    :param batch_support: Batch with the same attributes. Left unmodified.
    :return: ``batch_query``.
    '''
    id_bias = batch_query.graph_feature.shape[0]
    batch_query.graph_feature = torch.cat(
        (batch_query.graph_feature, batch_support.graph_feature), dim=0)

    # Out-of-place shift: the support batch is not mutated, and the shift is
    # applied in BOTH branches below. It used to be skipped when the query
    # had no labels, which attached support labels to query graphs.
    support_graph_ids = batch_support.graph_targets_id_batch + id_bias

    if batch_query.graph_targets_id_batch.shape[0] > 0:
        batch_query.graph_targets_id = torch.cat(
            (batch_query.graph_targets_id, batch_support.graph_targets_id),
            dim=0)
        batch_query.graph_targets_value = torch.cat(
            (batch_query.graph_targets_value,
             batch_support.graph_targets_value),
            dim=0)
        batch_query.graph_targets_id_batch = torch.cat(
            (batch_query.graph_targets_id_batch, support_graph_ids), dim=0)
    else:
        # The query's empty label tensors may have a different dtype, so
        # take the support tensors as they are instead of concatenating.
        batch_query.graph_targets_id = batch_support.graph_targets_id
        batch_query.graph_targets_value = batch_support.graph_targets_value
        batch_query.graph_targets_id_batch = support_graph_ids

    return batch_query
splits from OGB split_idx = dataset.get_idx_split() return graphs, split_idx + elif cfg.dataset.format == 'mol': + # names = ['bace', 'bbbp'] + names = cfg.dataset.name.split('_') + graphs, splits, motifs, targets = load_mol_datasets(names, + use_cache=True) + if 'aux' not in cfg.kg.setting: + for i in range(len(graphs)): + graphs[i].graph_targets_id = graphs[i].graph_targets_id[2:] + graphs[i].graph_targets_value = graphs[i].graph_targets_value[ + 2:] + return graphs, splits, motifs, targets else: raise ValueError('Unknown data format: {}'.format(cfg.dataset.format)) return graphs @@ -207,6 +219,9 @@ def set_dataset_info(datasets): if 'classification' in cfg.dataset.task_type and \ cfg.share.dim_out == 2: cfg.share.dim_out = 1 + elif 'binary' in cfg.dataset.task_type: + cfg.share.dim_out = 1 + except Exception: cfg.share.dim_out = 1 @@ -219,6 +234,10 @@ def create_dataset(): time1 = time.time() if cfg.dataset.format == 'OGB': graphs, splits = load_dataset() + elif cfg.dataset.format == 'mol': + graphs, splits, motifs, targets = load_dataset() + cfg.dataset.edge_dim = graphs[0].num_edge_features + cfg.share.num_task = len(targets) else: graphs = load_dataset() @@ -247,6 +266,12 @@ def create_dataset(): datasets.append(dataset[splits['train']]) datasets.append(dataset[splits['valid']]) datasets.append(dataset[splits['test']]) + elif cfg.dataset.format == 'mol': + datasets = [] + # seed starts from 1 + datasets.append(dataset[splits['train'][cfg.seed - 1]]) + datasets.append(dataset[splits['valid'][cfg.seed - 1]]) + datasets.append(dataset[splits['test'][cfg.seed - 1]]) # Use random split, supported by DeepSNAP else: datasets = dataset.split(transductive=cfg.dataset.transductive, @@ -271,21 +296,43 @@ def create_dataset(): def create_loader(datasets): - loader_train = DataLoader(datasets[0], - collate_fn=Batch.collate(), - batch_size=cfg.train.batch_size, - shuffle=True, - num_workers=cfg.num_workers, - pin_memory=False) - - loaders = [loader_train] - for i in 
def collate_dict(self, pred):
    '''Collate a list of {key: tensor} dicts into a single dict.

    Tensors stored under the same key are concatenated along the last
    dimension, in list order. Keys keep the order of their first appearance.
    The input dicts are left untouched, so collating the same list twice
    gives the same result.

    :param pred: List of dicts mapping a key to a tensor.
    :return: A new dict mapping every key to the concatenated tensor; empty
        if ``pred`` is empty.
    '''
    chunks = {}
    for pred_tmp in pred:
        for key, val in pred_tmp.items():
            chunks.setdefault(key, []).append(val)
    # One cat per key instead of re-concatenating after every batch
    return {key: torch.cat(vals, dim=-1) for key, vals in chunks.items()}
self.collate_dict(self._true) + pred_score_dict = self.collate_dict(self._pred) + count = 0 + stats = None + for key in true_dict.keys(): + true = true_dict[key] + if len(torch.unique(true)) == 1: + continue + pred_score = pred_score_dict[key] + pred_int = self._get_pred_int(pred_score_dict[key]) + if stats is None: + stats = { + 'accuracy': accuracy_score(true, pred_int), + 'precision': precision_score(true, pred_int), + 'recall': recall_score(true, pred_int), + 'f1': f1_score(true, pred_int), + 'auc': roc_auc_score(true, pred_score), + } + else: + stats['accuracy'] += accuracy_score(true, pred_int) + stats['precision'] += precision_score(true, pred_int) + stats['recall'] += recall_score(true, pred_int) + stats['f1'] += f1_score(true, pred_int) + stats['auc'] += roc_auc_score(true, pred_score) + count += 1 + for key in stats.keys(): + stats[key] = round(stats[key] / count, cfg.round) + return stats + def regression(self): from sklearn.metrics import mean_absolute_error, mean_squared_error @@ -140,11 +184,15 @@ def eta(self, epoch_current): return time_per_epoch * (self._epoch_total - epoch_current) def update_stats(self, true, pred, loss, lr, time_used, params, **kwargs): - assert true.shape[0] == pred.shape[0] self._iter += 1 - self._true.append(true) - self._pred.append(pred) - batch_size = true.shape[0] + if type(true) is dict and type(pred) is dict: + self._true.append(true) + self._pred.append(pred) + batch_size = sum([val.shape[0] for val in true.values()]) + else: + self._true.append(true) + self._pred.append(pred) + batch_size = true.shape[0] self._size_current += batch_size self._loss += loss * batch_size self._lr = lr @@ -180,6 +228,8 @@ def write_epoch(self, cur_epoch): task_stats = self.classification_binary() elif self.task_type == 'classification_multi': task_stats = self.classification_multi() + elif self.task_type == 'classification_binary_multi': + task_stats = self.classification_binary_multi() else: raise ValueError('Task has to be regression or 
classification') @@ -221,7 +271,10 @@ def infer_task(): num_label = cfg.share.dim_out if cfg.dataset.task_type == 'classification': if num_label <= 2: - task_type = 'classification_binary' + if cfg.share.num_task == 1: + task_type = 'classification_binary' + else: + task_type = 'classification_binary_multi' else: task_type = 'classification_multi' else: diff --git a/graphgym/loss.py b/graphgym/loss.py index 505207bd..656dd2c6 100644 --- a/graphgym/loss.py +++ b/graphgym/loss.py @@ -6,13 +6,14 @@ from graphgym.config import cfg -def compute_loss(pred, true): +def compute_loss(pred, true, task_type=None): """ Compute loss and prediction score Args: pred (torch.tensor): Unnormalized prediction - true (torch.tensor): Grou + true (torch.tensor): Ground truth labels + task_type (str): User specified task type Returns: Loss, normalized prediction score @@ -30,19 +31,31 @@ def compute_loss(pred, true): value = func(pred, true) if value is not None: return value - - if cfg.model.loss_fun == 'cross_entropy': - # multiclass - if pred.ndim > 1 and true.ndim == 1: + if task_type is None: + if cfg.model.loss_fun == 'cross_entropy': + # multiclass + if pred.ndim > 1 and true.ndim == 1: + pred = F.log_softmax(pred, dim=-1) + return F.nll_loss(pred, true), pred + # binary or multilabel + else: + true = true.float() + return bce_loss(pred, true), torch.sigmoid(pred) + elif cfg.model.loss_fun == 'mse': + true = true.float() + return mse_loss(pred, true), pred + else: + raise ValueError('Loss func {} not supported'.format( + cfg.model.loss_fun)) + else: + if task_type == 'classification_multi': pred = F.log_softmax(pred, dim=-1) return F.nll_loss(pred, true), pred - # binary or multilabel - else: + elif 'classification' in task_type and 'binary' in task_type: true = true.float() return bce_loss(pred, true), torch.sigmoid(pred) - elif cfg.model.loss_fun == 'mse': - true = true.float() - return mse_loss(pred, true), pred - else: - raise ValueError('Loss func {} not supported'.format( - 
cfg.model.loss_fun)) + elif task_type == 'regression': + true = true.float() + return mse_loss(pred, true), pred + else: + raise ValueError('Task type {} not supported'.format(task_type)) diff --git a/graphgym/model_builder.py b/graphgym/model_builder.py index aa300c48..f9c6592e 100644 --- a/graphgym/model_builder.py +++ b/graphgym/model_builder.py @@ -20,7 +20,10 @@ def create_model(to_device=True, dim_in=None, dim_out=None): dim_out (int, optional): Output dimension to the model """ dim_in = cfg.share.dim_in if dim_in is None else dim_in - dim_out = cfg.share.dim_out if dim_out is None else dim_out + if cfg.share.num_task == 1: + dim_out = cfg.share.dim_out if dim_out is None else dim_out + else: + dim_out = cfg.share.num_task # binary classification, output dim = 1 if 'classification' in cfg.dataset.task_type and dim_out == 2: dim_out = 1 diff --git a/run/configs/MetaLink/mol_classification.yaml b/run/configs/MetaLink/mol_classification.yaml new file mode 100644 index 00000000..86493a7c --- /dev/null +++ b/run/configs/MetaLink/mol_classification.yaml @@ -0,0 +1,61 @@ +out_dir: results +dataset: + format: mol + name: tox21 + task: graph + task_type: classification_binary_multi + transductive: False + split: [0.8, 0.1, 0.1] + transform: none + task_main: [] + subgraph: False +train: + batch_size: 128 + eval_period: 1 + ckpt_period: 100 +model: + type: metalink + loss_fun: cross_entropy + size_average: True + edge_decoding: dot + graph_pooling: add +gnn: + layers_pre_mp: 1 + layers_mp: 2 + layers_post_mp: 2 + self_msg: add + dim_inner: 256 + layer_type: generaledgeconv + stage_type: stack + batchnorm: True + act: prelu + dropout: 0.0 + agg: add + normalize_adj: False + msg_direction: both + att_heads: 4 + att_final_linear: False + att_final_linear_bn: False + l2norm: False +optim: + optimizer: adam + base_lr: 0.001 + max_epoch: 30 + scheduler: cos +kg: + meta: False + setting: relation + setting_ratio: 0.2 + aux: True + head: mphetenew + self_msg: add + layer_type: 
kgheteconv + decode: dot + layers_mp: 1 + norm_emb: True + has_bn: False + gate_self: True + gate_msg: True + gate_bias: False + edge_feature: raw +view_emb: False \ No newline at end of file diff --git a/run/grids/MetaLink/basic.txt b/run/grids/MetaLink/basic.txt new file mode 100644 index 00000000..dbb30b30 --- /dev/null +++ b/run/grids/MetaLink/basic.txt @@ -0,0 +1,12 @@ +# name in config.py; short name; range to search + +dataset.name dataset ['sider','tox21','toxcast'] +gnn.layers_mp l_mp [2,3,4,5] +kg.head head ['mphetenew'] +kg.layers_mp kg_layers [1,2,3,4,5] +kg.meta meta [False,True] +kg.meta_ratio meta_ratio [0.2] +kg.setting setting ['relation'] +kg.setting_ratio keep_ratio [0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9] +optim.max_epoch ep [30] + diff --git a/run/main.py b/run/main.py index 0afa1a52..0e37d95d 100644 --- a/run/main.py +++ b/run/main.py @@ -5,7 +5,7 @@ from torch_geometric import seed_everything from graphgym.cmd_args import parse_args -from graphgym.config import cfg, dump_cfg, load_cfg, set_run_dir, set_out_dir +from graphgym.config import cfg, dump_cfg, load_cfg, set_out_dir, set_run_dir from graphgym.loader import create_dataset, create_loader from graphgym.logger import create_logger, setup_printing from graphgym.model_builder import create_model diff --git a/run/main_pyg.py b/run/main_pyg.py index dc15386d..a16faf46 100644 --- a/run/main_pyg.py +++ b/run/main_pyg.py @@ -5,7 +5,7 @@ from torch_geometric import seed_everything from graphgym.cmd_args import parse_args -from graphgym.config import cfg, dump_cfg, load_cfg, set_run_dir, set_out_dir +from graphgym.config import cfg, dump_cfg, load_cfg, set_out_dir, set_run_dir from graphgym.loader_pyg import create_loader from graphgym.logger import create_logger, setup_printing from graphgym.model_builder_pyg import create_model diff --git a/run/scripts/MetaLink/run_metalink.sh b/run/scripts/MetaLink/run_metalink.sh new file mode 100644 index 00000000..60125c60 --- /dev/null +++ 
b/run/scripts/MetaLink/run_metalink.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +cd ../.. + +DIR=MetaLink +CONFIG=mol_classification +GRID=basic +REPEAT=10 +MAX_JOBS=8 + +# generate configs (the computational budget is not controlled here) +# to control the computational budget, add a --config_budget argument to the command below +python configs_gen.py --config configs/${DIR}/${CONFIG}.yaml \ + --grid grids/${DIR}/${GRID}.txt \ + --out_dir configs +# run batch of configs +# Args: config_dir, num of repeats, max jobs running +bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS +# rerun missed / stopped experiments +bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS +# rerun missed / stopped experiments +bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS + +# aggregate results for the batch +python agg_batch.py --dir results/${CONFIG}_grid_${GRID}