Source code for chainer.links.normalization.batch_normalization

import chainer
from chainer.backends import cuda
from chainer import configuration
from chainer import functions
from chainer import initializers
from chainer import link
from chainer.utils import argument
from chainer import variable


class BatchNormalization(link.Link):

    """Batch normalization layer on outputs of linear or convolution functions.

    This link wraps the :func:`~chainer.functions.batch_normalization` and
    :func:`~chainer.functions.fixed_batch_normalization` functions.

    It runs in three modes: training mode, fine-tuning mode, and testing mode.

    In training mode, it normalizes the input by *batch statistics*. It also
    maintains approximated population statistics by moving averages, which can
    be used for instant evaluation in testing mode. Training mode is enabled
    when ``chainer.config.train`` is set to ``True`` and :meth:`__call__`
    is invoked with ``finetune=False`` (the default).

    In fine-tuning mode, it accumulates the input to compute *population
    statistics*. In order to compute the population statistics correctly, a
    user must use this mode to feed mini-batches that run through the whole
    training dataset. Fine-tuning mode is enabled when ``chainer.config.train``
    is set to ``True`` and :meth:`__call__` is invoked with ``finetune=True``.

    In testing mode, it uses pre-computed population statistics to normalize
    the input variable. The population statistics are approximate if they were
    computed in training mode, and accurate if they were correctly computed in
    fine-tuning mode. Testing mode is enabled when ``chainer.config.train``
    is set to ``False``.
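
    The following is a minimal sketch of selecting each mode at call time
    (``x`` here is an arbitrary input batch introduced only for illustration):

    >>> x = np.random.randn(8, 3).astype(np.float32)
    >>> bn = chainer.links.BatchNormalization(3)
    >>> y_train = bn(x)                    # training mode (the default)
    >>> y_finetune = bn(x, finetune=True)  # fine-tuning mode
    >>> with chainer.using_config('train', False):
    ...     y_test = bn(x)                 # testing mode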

    Args:
        size (int, tuple of ints, or None): Size (or shape) of channel
            dimensions.  If ``None``, the size will be determined from
            dimension(s) of the input batch during the first forward pass.
        decay (float): Decay rate of the moving average. It is used during
            training.
        eps (float): Epsilon value for numerical stability.
        dtype (numpy.dtype): Type to use in computing.
        use_gamma (bool): If ``True``, use the scaling parameter. Otherwise,
            the scale is fixed to 1 and has no effect.
        use_beta (bool): If ``True``, use the shifting parameter. Otherwise,
            the shift is fixed to 0 and has no effect.
        axis (int or tuple of int): Axis over which normalization is
            performed. When axis is ``None``, it is determined from the input
            dimensions. For example, if ``x.ndim`` is 4, axis becomes
            ``(0, 2, 3)`` and normalization is performed over the 0th, 2nd and
            3rd axes of the input. If it is 2, axis becomes ``(0,)`` and
            normalization is performed over the 0th axis of the input. When a
            tuple of ints is given to this option, the numbers in the tuple
            must be sorted in ascending order. For example, ``(0, 2)`` is OK,
            but ``(2, 0)`` is not.

        initial_gamma: Initializer of the scaling parameter. The default value
            is ``1``.
        initial_beta: Initializer of the shifting parameter. The default value
            is ``0``.
        initial_avg_mean: Initializer of the moving average of population mean.
            The default value is ``0``.
        initial_avg_var: Initializer of the moving average of population
            variance. The default value is ``1``.

    .. note::

        From v5.0.0, the initial value of the population variance was changed
        to 1. This does not change the behavior of training, but the resulting
        model may behave slightly differently at inference time. To emulate the
        old behavior, pass ``initial_avg_var=0`` for training.
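
        A minimal illustration of restoring the pre-v5 initialization:

        >>> bn = chainer.links.BatchNormalization(3, initial_avg_var=0)
        >>> bn.avg_var
        array([0., 0., 0.], dtype=float32)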

    See: `Batch Normalization: Accelerating Deep Network Training by Reducing\
          Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_

    .. seealso::
       :func:`~chainer.functions.batch_normalization`,
       :func:`~chainer.functions.fixed_batch_normalization`

    Attributes:
        gamma (~chainer.Variable): Scaling parameter.
        beta (~chainer.Variable): Shifting parameter.
        avg_mean (:ref:`ndarray`): Population mean.
        avg_var (:ref:`ndarray`): Population variance.
        N (int): Count of batches given for fine-tuning.
        decay (float): Decay rate of the moving average. It is used during
            training.
        eps (float): Epsilon value for numerical stability. This value is added
            to the batch variances.

    .. admonition:: Example

        >>> x = np.arange(12).reshape(4, 3).astype(np.float32) ** 2
        >>> x
        array([[  0.,   1.,   4.],
               [  9.,  16.,  25.],
               [ 36.,  49.,  64.],
               [ 81., 100., 121.]], dtype=float32)
        >>> bn = chainer.links.BatchNormalization(3)
        >>> bn(x)
        variable([[-1.        , -1.0664359 , -1.1117983 ],
                  [-0.71428573, -0.6714596 , -0.6401263 ],
                  [ 0.14285715,  0.19748813,  0.23583598],
                  [ 1.5714287 ,  1.5404074 ,  1.5160885 ]])
        >>> (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + 2e-5)
        array([[-1.        , -1.0664359 , -1.1117983 ],
               [-0.71428573, -0.6714596 , -0.6401263 ],
               [ 0.14285715,  0.19748813,  0.235836  ],
               [ 1.5714285 ,  1.5404074 ,  1.5160886 ]], dtype=float32)

        There are several ways to make a BatchNormalization link.
        Consider an input batch of 10 images of size 32x32 with 3 channels.

        >>> x = np.random.randn(10, 3, 32, 32).astype(np.float32)

        1. Give the parameter size:

            To normalize for each channel, give the number of channels
            to ``size``.

            >>> bn = chainer.links.BatchNormalization(3)
            >>> bn.avg_mean.shape
            (3,)
            >>> bn.beta += 2.0
            >>> bn.gamma *= 5.0
            >>> list(sorted(bn.namedparams()))  # doctest: +ELLIPSIS
            [('/beta', variable([2., ...])), ('/gamma', variable([5., ...]))]
            >>> y = bn(x)
            >>> y.shape
            (10, 3, 32, 32)
            >>> np.testing.assert_allclose(
            ...     y.array.mean(axis=(0, 2, 3)), bn.beta.array, atol=1e-6)
            >>> np.testing.assert_allclose(
            ...     y.array.std(axis=(0, 2, 3)),
            ...     bn.gamma.array, atol=1e-3)

            To normalize each channel of each pixel separately, give the
            tuple of those dimensions as ``size``.

            >>> bn = chainer.links.BatchNormalization((3, 32, 32))
            >>> bn.avg_mean.shape
            (3, 32, 32)
            >>> y = bn(x)
            >>> y.shape
            (10, 3, 32, 32)
            >>> np.testing.assert_allclose(
            ...     y.array.mean(axis=0), bn.beta.array, atol=1e-6)
            >>> np.testing.assert_allclose(
            ...     y.array.std(axis=0),
            ...     bn.gamma.array, atol=1e-3)

            By default, the channel axes are (or start from) axis 1 of the
            input shape.

        2. Give the aggregate axes:

            Available from Chainer v5.

            With the ``axis`` option, similarly to NumPy, you may specify the
            aggregate axes, which are treated as the "batch" axes for the
            batch statistics.

            You can omit ``size`` if ``axis`` is given. In this case, creation
            of the persistent values ``avg_mean`` and ``avg_var`` and the
            parameters ``beta`` and ``gamma`` is deferred until the first
            forward propagation.

            The examples in 1. correspond to the following, respectively.

            >>> bn = chainer.links.BatchNormalization(axis=(0, 2, 3))
            >>> print(bn.avg_mean)
            None
            >>> y = bn(x)
            >>> bn.avg_mean.shape
            (3,)

            >>> bn = chainer.links.BatchNormalization(axis=0)
            >>> print(bn.avg_mean)
            None
            >>> y = bn(x)
            >>> bn.avg_mean.shape
            (3, 32, 32)

    """

    gamma = None
    beta = None
    avg_mean = None
    avg_var = None

    def __init__(self, size=None, decay=0.9, eps=2e-5, dtype=None,
                 use_gamma=True, use_beta=True,
                 initial_gamma=None, initial_beta=None, axis=None,
                 initial_avg_mean=None, initial_avg_var=None):
        super(BatchNormalization, self).__init__()

        if size is None and axis is None:
            raise RuntimeError('size or axis is required')
        self._initial_avg_mean = initial_avg_mean
        self._initial_avg_var = initial_avg_var
        self.N = 0
        self.register_persistent('N')
        self.decay = decay
        self.eps = eps
        if isinstance(axis, int):
            axis = (axis,)
        self.axis = axis
        self._dtype = chainer.get_dtype(dtype)

        with self.init_scope():
            if use_gamma:
                if initial_gamma is None:
                    initial_gamma = 1
                gamma_initializer = \
                    initializers._get_initializer(initial_gamma)
                gamma_initializer.dtype = self._dtype
                self.gamma = variable.Parameter(gamma_initializer)
            if use_beta:
                if initial_beta is None:
                    initial_beta = 0
                beta_initializer = initializers._get_initializer(initial_beta)
                beta_initializer.dtype = self._dtype
                self.beta = variable.Parameter(beta_initializer)

        if size is not None:
            self._initialize_params(size)

    def _initialize_params(self, shape):
        self.avg_mean = self._init_array(self._initial_avg_mean, 0, shape)
        self._initial_avg_mean = None
        self.register_persistent('avg_mean')
        self.avg_var = self._init_array(self._initial_avg_var, 1, shape)
        self._initial_avg_var = None
        self.register_persistent('avg_var')
        if self.gamma is not None:
            self.gamma.initialize(shape)
        if self.beta is not None:
            self.beta.initialize(shape)

    def _init_array(self, initializer, default_value, size):
        if initializer is None:
            initializer = default_value
        initializer = initializers._get_initializer(initializer)
        return initializers.generate_array(
            initializer, size, self.xp, dtype=self._dtype)

    def forward(self, x, **kwargs):
        """forward(self, x, finetune=False)

        Invokes the forward propagation of BatchNormalization.

        In training mode, BatchNormalization normalizes the input using batch
        statistics and also updates the moving averages of the mean and
        variance, which are used for evaluation after training.

        .. warning::

           The ``test`` argument is no longer supported since v2.
           Instead, use ``chainer.using_config('train', False)``.
           See :func:`chainer.using_config`.

        Args:
            x (Variable): Input variable.
            finetune (bool): If it is in training mode and ``finetune`` is
                ``True``, BatchNormalization runs in fine-tuning mode; it
                accumulates the input array to compute population statistics
                for normalization, and normalizes the input using batch
                statistics.

        """
        finetune, = argument.parse_kwargs(
            kwargs, ('finetune', False),
            test='test argument is not supported anymore. '
                 'Use chainer.using_config')

        if self.avg_mean is None:
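            # Infer the parameter shape from the input dimensions that are
            # not reduced over (i.e. the axes not listed in ``self.axis``).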
            param_shape = tuple([
                d
                for i, d in enumerate(x.shape)
                if i not in self.axis])
            self._initialize_params(param_shape)

        gamma = self.gamma
        if gamma is None:
            with cuda.get_device_from_id(self._device_id):
                gamma = self.xp.ones(
                    self.avg_mean.shape, dtype=x.dtype)

        beta = self.beta
        if beta is None:
            with cuda.get_device_from_id(self._device_id):
                beta = self.xp.zeros(
                    self.avg_mean.shape, dtype=x.dtype)

        if configuration.config.train:
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
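                # With this decay, the running statistics accumulate as the
                # exact average over all mini-batches seen during fine-tuning.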
            else:
                decay = self.decay

            avg_mean = self.avg_mean
            avg_var = self.avg_var

            if chainer.config.in_recomputing:
                # Do not update statistics when extra forward computation is
                # called.
                if finetune:
                    self.N -= 1  # Revert the count
                avg_mean = None
                avg_var = None

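            # When running_mean/running_var are given,
            # functions.batch_normalization also updates them in place,
            # roughly as: running = decay * running + (1 - decay) * batch_stat.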
            ret = functions.batch_normalization(
                x, gamma, beta, eps=self.eps, running_mean=avg_mean,
                running_var=avg_var, decay=decay, axis=self.axis)
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = self.avg_mean
            var = self.avg_var
            ret = functions.fixed_batch_normalization(
                x, gamma, beta, mean, var, self.eps, axis=self.axis)
        return ret

    def start_finetuning(self):
        """Resets the population count for collecting population statistics.

        This method can be skipped if fine-tuning mode is being used for the
        first time. Otherwise, it should be called before starting
        fine-tuning again.
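
        A typical fine-tuning pass looks like the following sketch
        (``train_batches`` is a hypothetical iterable over the training
        dataset):

        .. code-block:: python

            bn.start_finetuning()
            with chainer.using_config('train', True):
                for batch in train_batches:
                    bn(batch, finetune=True)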

        """
        self.N = 0