Source code for chainer_chemistry.models.schnet

"""
Chainer implementation of SchNet

SchNet: A continuous-filter convolutional neural network for modeling quantum
interactions
Kristof et al.
See: https://arxiv.org/abs/1706.08566
"""
import chainer
from chainer import functions

from chainer_chemistry.config import MAX_ATOMIC_NUM
from chainer_chemistry.links.connection.embed_atom_id import EmbedAtomID
from chainer_chemistry.links.readout.schnet_readout import SchNetReadout
from chainer_chemistry.links.update.schnet_update import SchNetUpdate


[docs]class SchNet(chainer.Chain): """SchNet Args: out_dim (int): dimension of output feature vector hidden_dim (int): dimension of feature vector associated to each atom n_layers (int): number of layers readout_hidden_dim (int): dimension of feature vector associated to each molecule n_atom_types (int): number of types of atoms concat_hidden (bool): If set to True, readout is executed in each layer and the result is concatenated """
[docs] def __init__(self, out_dim=1, hidden_dim=64, n_layers=3, readout_hidden_dim=32, n_atom_types=MAX_ATOMIC_NUM, concat_hidden=False): super(SchNet, self).__init__() with self.init_scope(): self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types) self.update_layers = chainer.ChainList( *[SchNetUpdate(hidden_dim) for _ in range(n_layers)]) self.readout_layer = SchNetReadout(out_dim, readout_hidden_dim) self.out_dim = out_dim self.hidden_dim = hidden_dim self.readout_hidden_dim = readout_hidden_dim self.n_layers = n_layers self.concat_hidden = concat_hidden
def __call__(self, atom_features, dist_features): x = self.embed(atom_features) h = [] # --- update part --- for i in range(self.n_layers): x = self.update_layers[i](x, dist_features) if self.concat_hidden: h.append(x) # --- readout part --- if self.concat_hidden: x = functions.concat(h, axis=2) x = self.readout_layer(x) return x