# Source code for chainer_chemistry.models.prediction.classifier

import warnings

import numpy

import chainer
from chainer.dataset.convert import concat_examples
from chainer.functions.evaluation import accuracy
from chainer.functions.loss import softmax_cross_entropy
from chainer import cuda, Variable  # NOQA
from chainer import reporter
from chainer_chemistry.models.prediction.base import BaseForwardModel

def _argmax(*args):
    """Return argmax indices along axis 1 of the first positional argument.

    This is the default ``postprocess_fn`` of ``Classifier.predict``. Any
    additional positional arguments are accepted and ignored.
    """
    return chainer.functions.argmax(args[0], axis=1)

class Classifier(BaseForwardModel):

    """A simple classifier model.

    This is an example of chain that wraps another chain. It computes the
    loss and metrics (accuracy by default) based on a given input/label pair.

    Args:
        predictor (~chainer.Link): Predictor network.
        lossfun (function): Loss function.
        accfun (function): DEPRECATED. Please use `metrics_fun` instead.
        metrics_fun (function or dict or None): Function that computes
            metrics. A single callable is stored under the key
            ``'accuracy'``; a dict maps metric names to functions; ``None``
            disables metric computation.
        label_key (int or str): Key to specify label variable from arguments.
            When it is ``int``, a variable in positional arguments is used.
            And when it is ``str``, a variable in keyword arguments is used.
        device (int): GPU device id of this Classifier to be used.
            -1 indicates to use in CPU.

    Attributes:
        predictor (~chainer.Link): Predictor network.
        lossfun (function): Loss function.
        accfun (function): DEPRECATED. Please use `metrics_fun` instead.
        metrics_fun (dict): Mapping from metric name to metric function.
        y (~chainer.Variable): Prediction for the last minibatch.
        loss (~chainer.Variable): Loss value for the last minibatch.
        metrics (dict): Metrics computed in last minibatch
        compute_metrics (bool): If ``True``, compute metrics on the forward
            computation. The default value is ``True``.

    .. note::
        The differences between original `Classifier` class in chainer and
        chainer chemistry are as follows.

        1. `predict` and `predict_proba` methods are supported.
        2. `device` can be managed internally by the `Classifier`
        3. `accfun` is deprecated, `metrics_fun` is used instead.
        4. `metrics_fun` can be `dict` which specifies the metrics name as key
           and function as value.

    .. note::
        This link uses :func:`chainer.functions.softmax_cross_entropy` with
        default arguments as a loss function (specified by ``lossfun``),
        if users do not explicitly change it. In particular, the loss function
        does not support double backpropagation.
        If you need second or higher order differentiation, you need to turn
        it on with ``enable_double_backprop=True``:

          >>> import chainer.functions as F
          >>> import chainer.links as L
          >>> from chainer_chemistry.models.prediction.classifier import (
          ...     Classifier)
          >>>
          >>> def lossfun(x, t):
          ...     return F.softmax_cross_entropy(
          ...         x, t, enable_double_backprop=True)
          >>>
          >>> predictor = L.Linear(10)
          >>> model = Classifier(predictor, lossfun=lossfun)

    """

    compute_metrics = True

    def __init__(self, predictor,
                 lossfun=softmax_cross_entropy.softmax_cross_entropy,
                 accfun=None, metrics_fun=accuracy.accuracy,
                 label_key=-1, device=-1):
        if not isinstance(label_key, (int, str)):
            raise TypeError('label_key must be int or str, but is %s' %
                            type(label_key))
        if accfun is not None:
            warnings.warn(
                'accfun is deprecated, please use metrics_fun instead')
            warnings.warn('overriding metrics by accfun...')
            # override metrics by accfun
            metrics_fun = accfun

        super(Classifier, self).__init__()
        self.lossfun = lossfun
        if metrics_fun is None:
            self.compute_metrics = False
            self.metrics_fun = {}
        elif callable(metrics_fun):
            self.metrics_fun = {'accuracy': metrics_fun}
        elif isinstance(metrics_fun, dict):
            self.metrics_fun = metrics_fun
        else:
            # Report the type of the offending argument (`metrics_fun`),
            # not `accfun`, which is always None on this path.
            raise TypeError('Unexpected type metrics_fun must be None or '
                            'Callable or dict. actual {}'
                            .format(type(metrics_fun)))
        self.y = None
        self.loss = None
        self.metrics = None
        self.label_key = label_key

        with self.init_scope():
            self.predictor = predictor

        # `initialize` must be called after `init_scope`.
        self.initialize(device)

    def _convert_to_scalar(self, value):
        """Convert a Variable / numpy / cupy value to a Python scalar.

        Values that are already scalars (per ``numpy.isscalar``) are
        returned unchanged. Arrays must contain exactly one element.
        """
        if isinstance(value, Variable):
            value = value.array
        if numpy.isscalar(value):
            return value
        if not isinstance(value, numpy.ndarray):
            # Non-numpy arrays (e.g. cupy) are moved to host memory first.
            value = cuda.to_cpu(value)
        # `ndarray.item()` replaces `numpy.asscalar`, which was removed in
        # NumPy 1.23 (it was implemented as `a.item()`).
        return value.item()

    def __call__(self, *args, **kwargs):
        """Computes the loss value for an input and label pair.

        It also computes metrics (when ``compute_metrics`` is ``True``),
        stores them in ``self.metrics`` and reports them, together with the
        loss, via :func:`chainer.reporter.report`.

        Args:
            args (list of ~chainer.Variable): Input minibatch.
            kwargs (dict of ~chainer.Variable): Input minibatch.

        When ``label_key`` is ``int``, the corresponding element in ``args``
        is treated as ground truth labels. And when it is ``str``, the
        element in ``kwargs`` is used.
        All elements of ``args`` and ``kwargs`` except the ground truth
        labels are features.
        It feeds features to the predictor and compares the result
        with ground truth labels.

        Returns:
            ~chainer.Variable: Loss value.

        Raises:
            ValueError: If the label cannot be found with ``label_key``.
            TypeError: If ``label_key`` is neither ``int`` nor ``str``.

        """
        # --- Separate `args` and `t` ---
        if isinstance(self.label_key, int):
            if not (-len(args) <= self.label_key < len(args)):
                msg = 'Label key %d is out of bounds' % self.label_key
                raise ValueError(msg)
            t = args[self.label_key]
            if self.label_key == -1:
                # `args[-1 + 1:]` would be the whole tuple, so -1 is
                # special-cased.
                args = args[:-1]
            else:
                args = args[:self.label_key] + args[self.label_key + 1:]
        elif isinstance(self.label_key, str):
            if self.label_key not in kwargs:
                msg = 'Label key "%s" is not found' % self.label_key
                raise ValueError(msg)
            t = kwargs[self.label_key]
            del kwargs[self.label_key]
        else:
            raise TypeError('Label key type {} not supported'
                            .format(type(self.label_key)))

        self.y = None
        self.loss = None
        self.metrics = None
        self.y = self.predictor(*args, **kwargs)
        self.loss = self.lossfun(self.y, t)
        reporter.report(
            {'loss': self._convert_to_scalar(self.loss)}, self)
        if self.compute_metrics:
            # Note: self.metrics is `dict`, which is different from original
            # chainer implementation
            self.metrics = {key: self._convert_to_scalar(value(self.y, t))
                            for key, value in self.metrics_fun.items()}
            reporter.report(self.metrics, self)
        return self.loss

    def predict_proba(
            self, data, batchsize=16, converter=concat_examples,
            retain_inputs=False, preprocess_fn=None,
            postprocess_fn=chainer.functions.softmax):
        """Calculate probability of each category.

        The forward pass runs with backprop disabled and the ``train``
        config set to ``False``.

        Args:
            data: "train_x array" or "chainer dataset"
            batchsize (int): batch size
            converter (Callable): convert from `data` to `inputs`
            retain_inputs (bool): If True, this instance keeps inputs in
                `self.inputs` or not.
            preprocess_fn (Callable): Its input is numpy.ndarray or
                cupy.ndarray, it can return either Variable, cupy.ndarray or
                numpy.ndarray
            postprocess_fn (Callable): Its input argument is Variable,
                but this method may return either Variable, cupy.ndarray or
                numpy.ndarray. Defaults to ``chainer.functions.softmax``.

        Returns (tuple or numpy.ndarray): Typically, it is 2-dimensional float
            array with shape (batchsize, number of category) which represents
            each examples probability to be each category.

        """
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            proba = self._forward(
                data, fn=self.predictor, batchsize=batchsize,
                converter=converter, retain_inputs=retain_inputs,
                preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn)
        return proba

    def predict(
            self, data, batchsize=16, converter=concat_examples,
            retain_inputs=False, preprocess_fn=None, postprocess_fn=_argmax):
        """Predict the label of each example.

        By default the predictor output is post-processed with an argmax
        along axis 1, i.e. the highest-scoring category is selected. The
        forward pass runs with backprop disabled and the ``train`` config
        set to ``False``.

        Args:
            data: input data
            batchsize (int): batch size
            converter (Callable): convert from `data` to `inputs`
            retain_inputs (bool): If True, this instance keeps inputs in
                `self.inputs` or not.
            preprocess_fn (Callable): Its input is numpy.ndarray or
                cupy.ndarray, it can return either Variable, cupy.ndarray or
                numpy.ndarray
            postprocess_fn (Callable): Its input argument is Variable,
                but this method may return either Variable, cupy.ndarray or
                numpy.ndarray.

        Returns (tuple or numpy.ndarray): Typically, it is 1-dimensional int
            array with shape (batchsize, ) which represents each examples
            category prediction.

        """
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            predict_labels = self._forward(
                data, fn=self.predictor, batchsize=batchsize,
                converter=converter, retain_inputs=retain_inputs,
                preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn)
        return predict_labels

    # --- For backward compatibility ---
    @property
    def compute_accuracy(self):
        warnings.warn('compute_accuracy is deprecated, '
                      'please use compute_metrics instead')
        return self.compute_metrics

    @compute_accuracy.setter
    def compute_accuracy(self, value):
        warnings.warn('compute_accuracy is deprecated, '
                      'please use compute_metrics instead')
        self.compute_metrics = value

    @property
    def accuracy(self):
        warnings.warn('accuracy is deprecated, '
                      'please use metrics instead')
        return self.metrics

    @accuracy.setter
    def accuracy(self, value):
        warnings.warn('accuracy is deprecated, '
                      'please use metrics instead')
        self.metrics = value

    @property
    def accfun(self):
        warnings.warn('accfun is deprecated, '
                      'please use metrics_fun instead')
        return self.metrics_fun

    @accfun.setter
    def accfun(self, value):
        warnings.warn('accfun is deprecated, '
                      'please use metrics_fun instead')
        self.metrics_fun = value