# Source code for chainer_chemistry.links.update.schnet_update

"""
Chainer implementation of CFConv.

SchNet: A continuous-filter convolutional neural network for modeling
    quantum interactions
K. T. Schutt et al., NIPS 2017
See: https://arxiv.org/abs/1706.08566
"""

import chainer
from chainer import functions
from chainer import links

from chainer_chemistry.links.connection.graph_linear import GraphLinear


class CFConv(chainer.Chain):
    """Continuous-filter convolution (CFConv) used by SchNet.

    Every pairwise distance is expanded on a Gaussian radial basis, fed
    through a two-layer dense "filter network", and the resulting per-pair
    filters weight the atom features that are then summed over atoms.

    Args:
        num_rbf (int): Number of Gaussian basis functions. Their centers are
            ``k * radius_resolution`` for ``k = 0, ..., num_rbf - 1``.
        radius_resolution (float): Spacing between neighbouring centers.
        gamma (float): Width coefficient in
            ``exp(-gamma * (d - center) ** 2)``.
        hidden_dim (int): Feature size of the input ``h`` and of the
            generated filters.
    """

    def __init__(self, num_rbf=300, radius_resolution=0.1, gamma=10.0,
                 hidden_dim=64):
        super(CFConv, self).__init__()
        self.num_rbf = num_rbf
        self.radius_resolution = radius_resolution
        self.gamma = gamma
        self.hidden_dim = hidden_dim
        with self.init_scope():
            # Filter network: num_rbf -> hidden_dim -> hidden_dim.
            self.dense1 = links.Linear(num_rbf, hidden_dim)
            # in_size=None: input width is inferred on the first forward pass.
            self.dense2 = links.Linear(None, hidden_dim)

    def __call__(self, h, dist):
        """Apply the continuous-filter convolution.

        Args:
            h (numpy.ndarray or chainer.Variable): Atom features of shape
                ``(batch, n_atoms, hidden_dim)``.
            dist (numpy.ndarray or chainer.Variable): Pairwise distances,
                reshaped internally to ``(batch, n_atoms, n_atoms)``.

        Returns:
            chainer.Variable: Shape ``(batch, n_atoms, hidden_dim)`` with
            ``out[b, j, c] = sum_i h[b, i, c] * W[b, i, j, c]`` where ``W``
            is the generated filter. No masking is applied, so every atom
            pair (including ``i == j``) contributes to the sum.

        Raises:
            ValueError: If ``h.shape[2]`` differs from ``hidden_dim``.
        """
        batch, n_atoms, n_feat = h.shape
        if n_feat != self.hidden_dim:
            raise ValueError('h.shape[2] {} and hidden_dim {} must be same!'
                             .format(n_feat, self.hidden_dim))
        F = functions
        pair_shape = (batch, n_atoms, n_atoms)

        # Gaussian radial-basis expansion of each distance. The explicit
        # broadcast keeps the Variable shape equal to the result shape so
        # that backprop through the subtraction is well defined.
        centers = self.xp.arange(self.num_rbf).astype('f') \
            * self.radius_resolution
        d_col = F.reshape(dist, pair_shape + (1,))
        d_all = F.broadcast_to(d_col, pair_shape + (self.num_rbf,))
        rbf = F.exp(-self.gamma * (d_all - centers) ** 2)

        # Filter network applied independently to every atom pair.
        # NOTE(review): the SchNet paper uses a shifted softplus
        # (softplus(x) - log 2); plain softplus is used here.
        flat = F.reshape(rbf, (-1, self.num_rbf))
        filt = F.softplus(self.dense2(F.softplus(self.dense1(flat))))
        filt = F.reshape(filt, pair_shape + (self.hidden_dim,))

        # Weight source-atom features (axis 1) by the filter and sum them up.
        src = F.reshape(h, (batch, n_atoms, 1, self.hidden_dim))
        src = F.broadcast_to(src, pair_shape + (self.hidden_dim,))
        return F.sum(src * filt, axis=1)


class SchNetUpdate(chainer.Chain):
    """Interaction (update) block of SchNet.

    Computes a residual update of the atom features::

        v = linear[0](x)
        v = cfconv(v, dist)
        v = softplus(linear[1](v))
        out = x + linear[2](v)

    Args:
        hidden_dim (int): Size passed to each of the three ``GraphLinear``
            layers and to the internal ``CFConv``.
        num_rbf (int): Number of radial basis functions for ``CFConv``.
        radius_resolution (float): Spacing of the radial basis centers for
            ``CFConv``.
        gamma (float): Gaussian width coefficient for ``CFConv``.

    The three trailing arguments default to ``CFConv``'s own defaults, so
    ``SchNetUpdate(hidden_dim)`` behaves exactly as before they existed.
    """

    def __init__(self, hidden_dim=64, num_rbf=300, radius_resolution=0.1,
                 gamma=10.0):
        super(SchNetUpdate, self).__init__()
        with self.init_scope():
            # NOTE(review): GraphLinear(hidden_dim) presumably means
            # out_size=hidden_dim with a lazily inferred in_size, mirroring
            # links.Linear -- confirm against graph_linear.py.
            self.linear = chainer.ChainList(
                *[GraphLinear(hidden_dim) for _ in range(3)])
            self.cfconv = CFConv(num_rbf=num_rbf,
                                 radius_resolution=radius_resolution,
                                 gamma=gamma, hidden_dim=hidden_dim)
        self.hidden_dim = hidden_dim

    def __call__(self, x, dist):
        """Update atom features with one SchNet interaction.

        Args:
            x (numpy.ndarray or chainer.Variable): Atom features; assumed
                shape ``(batch, n_atoms, hidden_dim)``. The last axis must
                match the output of ``linear[2]`` for the residual ``x + v``.
            dist (numpy.ndarray or chainer.Variable): Pairwise distances of
                shape ``(batch, n_atoms, n_atoms)``, forwarded to ``CFConv``.

        Returns:
            chainer.Variable: ``x`` plus the computed update ``v``.
        """
        v = self.linear[0](x)
        v = self.cfconv(v, dist)
        v = self.linear[1](v)
        v = functions.softplus(v)
        v = self.linear[2](v)
        return x + v