# Source code for chainer_chemistry.links.update.ggnn_update

import chainer
from chainer import functions
from chainer import links

import chainer_chemistry
from chainer_chemistry.links.connection.graph_linear import GraphLinear


class GGNNUpdate(chainer.Chain):
    """Update step of a gated graph neural network (GGNN).

    A single call runs one round of message passing: the atom features are
    projected by ``GraphLinear`` into one message per edge type, the
    messages are multiplied by the adjacency matrices and summed over the
    edge types, and a GRU link combines the result with the current atom
    features.

    Args:
        hidden_dim (int): dimension of feature vector associated to
            each atom
        num_edge_type (int): number of types of edge
    """

    def __init__(self, hidden_dim=16, num_edge_type=4):
        super(GGNNUpdate, self).__init__()
        # Plain attribute (not a link/parameter), so it lives outside
        # ``init_scope``.
        self.num_edge_type = num_edge_type
        with self.init_scope():
            # Emits ``num_edge_type`` messages of size ``hidden_dim`` per
            # atom, packed along the last axis.
            self.graph_linear = GraphLinear(
                hidden_dim, num_edge_type * hidden_dim)
            # Input is the concatenation of the atom feature and the
            # aggregated message, hence ``2 * hidden_dim``.
            self.update_layer = links.GRU(2 * hidden_dim, hidden_dim)

    def __call__(self, h, adj):
        """Apply one message-passing / GRU update step.

        Args:
            h: atom features; its ``shape`` is unpacked as
                ``(minibatch, atom, ch)``.
            adj: adjacency data; it is reshaped to
                ``(minibatch * num_edge_type, atom, atom)``, so it is
                presumably ``(minibatch, num_edge_type, atom, atom)`` --
                confirm against the caller.

        Returns:
            Updated atom features reshaped to ``(minibatch, atom, ch)``.
        """
        n_batch, n_atom, n_feat = h.shape
        n_edge = self.num_edge_type
        n_out = n_feat

        # --- Message part ---
        projected = self.graph_linear(h)
        # Split the packed last axis: (minibatch, atom, ch, edge_type).
        msg = functions.reshape(
            projected, (n_batch, n_atom, n_out, n_edge))
        # Bring edge_type forward: (minibatch, edge_type, atom, ch).
        msg = functions.transpose(msg, (0, 3, 1, 2))

        # Fold edge_type into the batch axis on both operands so a single
        # batched matmul handles every edge type.
        flat_adj = functions.reshape(
            adj, (n_batch * n_edge, n_atom, n_atom))
        msg = functions.reshape(msg, (n_batch * n_edge, n_atom, n_out))
        msg = chainer_chemistry.functions.matmul(flat_adj, msg)

        # Unfold and sum over edge types: (minibatch, atom, out_ch).
        msg = functions.reshape(msg, (n_batch, n_edge, n_atom, n_out))
        msg = functions.sum(msg, axis=1)

        # --- Update part ---
        # Flatten to 2-D so every atom is one row fed to the GRU.
        flat_shape = (n_batch * n_atom, n_feat)
        flat_h = functions.reshape(h, flat_shape)
        flat_msg = functions.reshape(msg, flat_shape)
        gru_input = functions.concat((flat_h, flat_msg), axis=1)
        updated = self.update_layer(gru_input)

        # Restore the (minibatch, atom, ch) layout.
        return functions.reshape(updated, (n_batch, n_atom, n_feat))

    def reset_state(self):
        """Reset the state held by the GRU update layer."""
        self.update_layer.reset_state()