# Source code for chainer_chemistry.links.update.nfp_update

import chainer
from chainer import functions
import numpy

import chainer_chemistry
from chainer_chemistry.links.connection.graph_linear import GraphLinear


class NFPUpdate(chainer.Chain):
    """NFP submodule for update part.

    Holds one ``GraphLinear`` per degree value ``0 .. max_degree``. At call
    time it sums neighbour features through the adjacency matrix, masks the
    result per degree, feeds each masked copy to its degree's layer, sums the
    layer outputs and applies a sigmoid.

    Args:
        in_channels (int): input channel dimension
        out_channels (int): output channel dimension
        max_degree (int): max degree of edge
    """

    def __init__(self, in_channels, out_channels, max_degree=6):
        super(NFPUpdate, self).__init__()
        # One linear layer for each degree value 0, 1, ..., max_degree.
        num_degree_type = max_degree + 1
        with self.init_scope():
            self.graph_linears = chainer.ChainList(
                *[GraphLinear(in_channels, out_channels)
                  for _ in range(num_degree_type)])
        self.max_degree = max_degree
        self.in_channels = in_channels
        self.out_channels = out_channels

    def __call__(self, h, adj, deg_conds):
        """Aggregate neighbour features and apply degree-specific layers.

        Args:
            h: atom features, (minibatch, atom, ch); the ch axis encodes
                each atom's info.
            adj: adjacency matrix, (minibatch, atom, atom).
            deg_conds: sequence of boolean masks, entry ``d`` paired with
                the ``d``-th degree layer. Each is the condition argument of
                ``chainer.functions.where`` against the aggregated features,
                so presumably it has the same shape as them — TODO confirm.
                Masks and layers are paired with ``zip``, so surplus entries
                on either side are ignored.

        Returns:
            chainer.Variable: sigmoid of the summed per-degree layer outputs,
            presumably (minibatch, atom, out_channels).
        """
        # --- Message part ---
        # Take sum along adjacent atoms; fv: (minibatch, atom, ch).
        fv = chainer_chemistry.functions.matmul(adj, h)

        # --- Update part ---
        # Fill value for masked-out positions. Use the link's array module
        # and fv's own dtype: functions.where needs both branches to share a
        # dtype, so a hard-coded float32 broke non-float32 inputs on CPU.
        zero_array = self.xp.zeros(fv.shape, dtype=fv.dtype)

        out_h = 0
        for graph_linear, cond in zip(self.graph_linears, deg_conds):
            # Keep only the entries selected by this degree's mask.
            fvd = functions.where(cond, fv, zero_array)
            out_h = out_h + graph_linear(fvd)
        # out_h shape (minibatch, max_num_atoms, hidden_dim)
        return functions.sigmoid(out_h)