Source code for chainer_chemistry.links.readout.nfp_readout

import chainer
from chainer import functions

from chainer_chemistry.links.connection.graph_linear import GraphLinear


class NFPReadout(chainer.Chain):
    """Readout (graph-level pooling) submodule of NFP.

    Each node's features are projected, normalized with a softmax over
    the channel axis, and then summed over the node axis to produce one
    feature vector per graph.

    Args:
        in_channels (int): dimension of feature vector associated to
            each atom (node)
        out_size (int): output dimension of feature vector associated
            to each molecule (graph)
    """

    def __init__(self, in_channels, out_size):
        super(NFPReadout, self).__init__()
        with self.init_scope():
            # Only links created inside init_scope are registered as
            # trainable children of this Chain.
            self.output_weight = GraphLinear(in_channels, out_size)
        self.in_channels = in_channels
        self.out_size = out_size

    def __call__(self, h, is_real_node=None):
        """Pool node features into a single vector per graph.

        Args:
            h: node features laid out as (minibatch, node, ch).
            is_real_node: optional per-node mask laid out as
                (minibatch, node). When given, each node's softmax
                output is multiplied by its mask value, so entries of 0
                remove virtual (padding) nodes from the sum.

        Returns:
            The softmax outputs summed over the node axis (axis 1).
            Presumably shaped (minibatch, out_size), assuming
            GraphLinear maps the last axis to ``out_size``. TODO confirm.
        """
        # Project, then normalize across the channel axis (axis 2).
        node_probs = functions.softmax(self.output_weight(h), axis=2)

        if is_real_node is not None:
            # Add a trailing channel axis to the mask and expand it to
            # the full feature shape so virtual nodes contribute zero.
            expanded_mask = self.xp.broadcast_to(
                is_real_node[:, :, None], node_probs.shape)
            node_probs = node_probs * expanded_mask

        # Aggregate over atoms (axis 1) to obtain graph-level features.
        return functions.sum(node_probs, axis=1)