import chainer
import chainer_chemistry
from chainer_chemistry.links.connection.graph_linear import GraphLinear
class RSGCNUpdate(chainer.Chain):
    """RSGCN submodule for message and update part.

    One propagation step: node features are first aggregated by
    left-multiplying them with the adjacency matrix (message part), then
    projected from ``in_channels`` to ``out_channels`` by a ``GraphLinear``
    layer without a bias term (update part).

    Args:
        in_channels (int): input channel dimension
        out_channels (int): output channel dimension
    """

    def __init__(self, in_channels, out_channels):
        super(RSGCNUpdate, self).__init__()
        # Links assigned inside init_scope become child links of this Chain,
        # so their parameters are tracked by the Chain.
        with self.init_scope():
            # nobias=True: the update is a pure linear map with no bias term.
            self.graph_linear = GraphLinear(
                in_channels, out_channels, nobias=True)
        # Plain (non-parameter) attributes recording the configured sizes.
        self.in_channels = in_channels
        self.out_channels = out_channels

    def __call__(self, h, adj):
        """Apply one message + update step.

        Args:
            h: node feature array. Presumably of shape
                ``(minibatch, atom, in_channels)`` -- confirm against caller.
            adj: adjacency matrix that is left-multiplied onto ``h``.
                Presumably of shape ``(minibatch, atom, atom)`` -- confirm
                against caller.

        Returns:
            ``graph_linear(matmul(adj, h))``, i.e. the aggregated features
            passed through the bias-free linear projection.
        """
        # --- Message part ---
        # Aggregate neighbour features: adj is on the left, so each output
        # row is a weighted combination of rows of h.
        h = chainer_chemistry.functions.matmul(adj, h)
        # --- Update part ---
        h = self.graph_linear(h)
        return h