# Source code for chainer_chemistry.links.readout.schnet_readout

import chainer
from chainer import functions

from chainer_chemistry.links.connection.graph_linear import GraphLinear


class SchNetReadout(chainer.Chain):
    """Readout submodule of SchNet.

    Applies two ``GraphLinear`` layers with a softplus non-linearity in
    between, then sums the result over axis 1.

    Args:
        out_dim (int): output size of the second ``GraphLinear`` layer,
            i.e. the dimension of the output feature vector.
        hidden_dim (int): output size of the first ``GraphLinear`` layer,
            i.e. the dimension of the intermediate (hidden) feature vector.
    """

    def __init__(self, out_dim=1, hidden_dim=32):
        super(SchNetReadout, self).__init__()
        # Child links have to be created inside ``init_scope`` so that the
        # chain registers them (and their parameters).
        with self.init_scope():
            self.linear1 = GraphLinear(hidden_dim)
            self.linear2 = GraphLinear(out_dim)
        # Plain hyper-parameters, kept for later inspection.
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim

    def __call__(self, h):
        """Compute the readout of the input features.

        Args:
            h: input feature array/variable.
                NOTE(review): presumably shaped (minibatch, atom, channel)
                -- confirm against the caller.

        Returns:
            The output of the second linear layer summed over axis 1.
        """
        hidden = functions.softplus(self.linear1(h))
        projected = self.linear2(hidden)
        # Axis 1 is collapsed by the sum (presumably the atom axis --
        # TODO confirm).
        return functions.sum(projected, axis=1)