import chainer
from chainer import functions
from chainer_chemistry.links.connection.graph_linear import GraphLinear


class GGNNReadout(chainer.Chain):
"""GGNN submodule for readout part.
Args:
out_dim (int): dimension of output feature vector
hidden_dim (int): dimension of feature vector associated to
each atom
nobias (bool): If ``True``, then this function does not use
the bias
activation (~chainer.Function or ~chainer.FunctionNode):
activate function for node representation
`functions.tanh` was suggested in original paper.
activation_agg (~chainer.Function or ~chainer.FunctionNode):
activate function for aggregation
`functions.tanh` was suggested in original paper.
"""

    def __init__(self, out_dim, hidden_dim=16, nobias=False,
                 activation=functions.identity,
                 activation_agg=functions.identity):
        super(GGNNReadout, self).__init__()
        with self.init_scope():
            self.i_layer = GraphLinear(None, out_dim, nobias=nobias)
            self.j_layer = GraphLinear(None, out_dim, nobias=nobias)
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.nobias = nobias
        self.activation = activation
        self.activation_agg = activation_agg

    def __call__(self, h, h0=None, is_real_node=None):
        # --- Readout part ---
        # h, h0: (minibatch, node, ch)
        # is_real_node: (minibatch, node)
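        # Graph-level readout of Li et al. (2016), "Gated Graph Sequence
        # Neural Networks":
        #     g = activation_agg(sum_v sigmoid(i(h_v, h0_v)) * activation(j(h_v)))
        # where i and j are the GraphLinear layers defined in __init__, and
        # the sigmoid gate decides how much each node contributes.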
        h1 = functions.concat((h, h0), axis=2) if h0 is not None else h
        g1 = functions.sigmoid(self.i_layer(h1))
        g2 = self.activation(self.j_layer(h1))
        g = g1 * g2
        if is_real_node is not None:
            # zero out virtual (padding) node features so they do not
            # contribute to the sum over nodes
            mask = self.xp.broadcast_to(
                is_real_node[:, :, None], g.shape)
            g = g * mask
        # sum over the node axis to get one vector per graph
        g = self.activation_agg(functions.sum(g, axis=1))
        return g
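

# A minimal usage sketch (illustration, not part of the library): the shapes,
# the random input, and the padding mask below are assumptions chosen to show
# how the readout maps per-node features to one graph-level vector.
if __name__ == '__main__':
    import numpy

    minibatch, num_nodes, hidden_dim = 4, 9, 16
    readout = GGNNReadout(out_dim=32, hidden_dim=hidden_dim)
    # batch of node feature vectors: (minibatch, node, hidden_dim)
    h = numpy.random.rand(
        minibatch, num_nodes, hidden_dim).astype(numpy.float32)
    # mark the first 5 nodes of each graph as real; the rest are padding
    is_real_node = numpy.zeros((minibatch, num_nodes), dtype=numpy.float32)
    is_real_node[:, :5] = 1.0
    g = readout(h, is_real_node=is_real_node)
    print(g.shape)  # (4, 32): one out_dim-dimensional vector per graph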