# Source code for chainer_chemistry.models.rsgcn

import chainer
from chainer import functions
from chainer import Variable

import chainer_chemistry
from chainer_chemistry.config import MAX_ATOMIC_NUM
from chainer_chemistry.links.readout.general_readout import GeneralReadout
from chainer_chemistry.links.update.rsgcn_update import RSGCNUpdate


class RSGCN(chainer.Chain):
    """Renormalized Spectral Graph Convolutional Network (RSGCN)

    See: Thomas N. Kipf and Max Welling, \
        Semi-Supervised Classification with Graph Convolutional Networks. \
        September 2016. \
        `arXiv:1609.02907 <https://arxiv.org/abs/1609.02907>`_

    The name of this model "Renormalized Spectral Graph Convolutional Network
    (RSGCN)" is named by us rather than the authors of the paper above.
    The authors call this model just "Graph Convolution Network (GCN)", but
    we think that "GCN" is a bit too general and may cause namespace issues.
    That is why we did not name this model GCN.

    Args:
        out_dim (int): dimension of output feature vector
        hidden_dim (int): dimension of feature vector
            associated to each atom
        n_layers (int): number of layers
        n_atom_types (int): number of types of atoms
        use_batch_norm (bool): If True, batch normalization is applied after
            graph convolution.
        readout (Callable): readout function. If None, a default-constructed
            `GeneralReadout()` (documented as `mode='sum'`) is used.
            To the best of our knowledge, the paper of RSGCN model does
            not give any suggestion on readout.
        dropout_ratio (float): ratio used in dropout function.
            If 0 or negative value is set, dropout function is skipped.
    """

    def __init__(self, out_dim, hidden_dim=32, n_layers=4,
                 n_atom_types=MAX_ATOMIC_NUM,
                 use_batch_norm=False, readout=None, dropout_ratio=0.5):
        super(RSGCN, self).__init__()
        # Every layer maps hidden_dim -> hidden_dim, except the last one,
        # which projects to the requested output dimension.
        in_dims = [hidden_dim for _ in range(n_layers)]
        out_dims = [hidden_dim for _ in range(n_layers)]
        out_dims[n_layers - 1] = out_dim
        if readout is None:
            readout = GeneralReadout()
        with self.init_scope():
            self.embed = chainer_chemistry.links.EmbedAtomID(
                in_size=n_atom_types, out_size=hidden_dim)
            self.gconvs = chainer.ChainList(
                *[RSGCNUpdate(in_dims[i], out_dims[i])
                  for i in range(n_layers)])
            if use_batch_norm:
                self.bnorms = chainer.ChainList(
                    *[chainer_chemistry.links.GraphBatchNormalization(
                        out_dims[i]) for i in range(n_layers)])
            else:
                # Placeholders so that __call__ can zip gconvs with bnorms
                # regardless of whether batch normalization is enabled.
                self.bnorms = [None for _ in range(n_layers)]
            # A readout that is a Link is assigned inside init_scope so that
            # it is registered as a child link of this chain.
            if isinstance(readout, chainer.Link):
                self.readout = readout
        # Any other callable readout is stored as a plain attribute.
        if not isinstance(readout, chainer.Link):
            self.readout = readout
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.dropout_ratio = dropout_ratio

    def __call__(self, graph, adj):
        """Forward propagation

        Args:
            graph (numpy.ndarray): minibatch of molecules represented with
                atom IDs (representing C, O, S, ...).
                `graph[mol_index, atom_index]` represents `mol_index`-th
                molecule's `atom_index`-th atomic number.
                If the dtype is not int32, the input is used as-is as the
                per-atom feature array and the embedding is skipped.
            adj (numpy.ndarray): minibatch of adjacency matrix
                `adj[mol_index]` represents `mol_index`-th molecule's
                adjacency matrix

        Returns:
            ~chainer.Variable: minibatch of fingerprint (the output of the
            readout function)
        """
        if graph.dtype == self.xp.int32:
            # graph: (minibatch, nodes)
            h = self.embed(graph)
        else:
            h = graph
        # h: (minibatch, nodes, ch)

        # Unwrap a Variable and re-wrap its raw array without gradient
        # tracking, so no gradient is computed for the adjacency matrix.
        if isinstance(adj, Variable):
            w_adj = adj.data
        else:
            w_adj = adj
        w_adj = Variable(w_adj, requires_grad=False)

        # --- RSGCN update ---
        for i, (gconv, bnorm) in enumerate(zip(self.gconvs, self.bnorms)):
            h = gconv(h, w_adj)
            if bnorm is not None:
                h = bnorm(h)
            if self.dropout_ratio > 0.:
                h = functions.dropout(h, ratio=self.dropout_ratio)
            # No ReLU after the final layer.
            if i < self.n_layers - 1:
                h = functions.relu(h)

        # --- readout ---
        y = self.readout(h)
        return y