Source code for chainer_chemistry.links.connection.graph_linear

import chainer


class GraphLinear(chainer.links.Linear):
    """Graph Linear layer.

    This link assumes its input is 3-dimensional. Unlike
    :class:`chainer.links.Linear`, it applies the affine transformation to
    the third (last) axis of the input `x`, independently for every position
    along the first two axes.

    .. seealso:: :class:`chainer.links.Linear`

    """

    def __call__(self, x):
        """Forward propagation.

        Args:
            x (:class:`chainer.Variable`, :class:`numpy.ndarray` \
            or :class:`cupy.ndarray`):
                Input array that should be a float array whose ``ndim`` is
                3, laid out as ``(minibatch, atom, ch)``: the first axis is
                the minibatch dimension, the second one is an index of
                atoms, and the third one holds the channels that the affine
                transformation is applied to.

        Returns:
            :class:`chainer.Variable`:
                A 3-dimensional array of shape
                ``(minibatch, atom, out_size)``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.

        """
        shape = x.shape
        if len(shape) != 3:
            # Fail with an explicit message instead of the opaque
            # "not enough/too many values to unpack" error.
            raise ValueError(
                'GraphLinear expects a 3-dimensional input of shape '
                '(minibatch, atom, ch), but got shape {}'.format(shape))
        n_batch, n_atom, n_ch = shape

        # The parent link works on 2-dimensional input, so fold the first
        # two axes together, apply it, and unfold them again afterwards.
        h = chainer.functions.reshape(x, (n_batch * n_atom, n_ch))
        h = super(GraphLinear, self).__call__(h)
        h = chainer.functions.reshape(h, (n_batch, n_atom, self.out_size))
        return h