# Source code for chainer_chemistry.iterators.balanced_serial_iterator

from __future__ import division

from logging import getLogger

from chainer.dataset import iterator
import numpy

from chainer_chemistry.iterators.index_iterator import IndexIterator


class BalancedSerialIterator(iterator.Iterator):
    """Dataset iterator that serially reads the examples with balancing label.

    Every included label contributes the same number of indices
    (``max_label_count``, the size of the largest included label) to each
    epoch, so one epoch consists of
    ``N_augmented = max_label_count * (number of included labels)`` examples.

    Args:
        dataset: Dataset to iterate. It must support ``len()`` and integer
            indexing.
        batch_size (int): Number of examples within each minibatch.
        labels (list or numpy.ndarray): Array which specifies label feature
            of `dataset`. Its total size must be same as the length of
            `dataset`; it is flattened to 1d internally.
        repeat (bool): If ``True``, it infinitely loops over the dataset.
            Otherwise, it stops iteration at the end of the first epoch.
        shuffle (bool): Forwarded to the per-label ``IndexIterator``
            instances. Note that when `batch_balancing` is ``False`` the
            combined order of an epoch is additionally permuted with
            ``numpy.random.permutation`` regardless of this flag.
        batch_balancing (bool): If ``True``, examples are sampled in the way
            that each label examples are roughly evenly sampled in each
            minibatch. Otherwise, the iterator only guarantees that total
            numbers of examples are same among label features.
        ignore_labels (int or list or None): Labels to be ignored.
            If not ``None``, the example whose label is in `ignore_labels`
            are not sampled by this iterator. A single Python or NumPy
            integer is accepted as shorthand for a one-element list.
        logger: Logger used by :meth:`show_label_stats`.

    Raises:
        ValueError: If the total size of `labels` differs from
            ``len(dataset)``.

    """

    def __init__(self, dataset, batch_size, labels, repeat=True, shuffle=True,
                 batch_balancing=False, ignore_labels=None,
                 logger=getLogger(__name__)):
        # Validate with an explicit exception rather than ``assert``:
        # asserts are stripped under ``python -O`` and would otherwise
        # shadow the more informative ValueError below.
        labels = numpy.asarray(labels)
        if len(dataset) != labels.size:
            raise ValueError('dataset length {} and labels size {} must be '
                             'same!'.format(len(dataset), labels.size))
        labels = numpy.ravel(labels)

        self.dataset = dataset
        self.batch_size = batch_size
        self.labels = labels
        self.logger = logger

        if ignore_labels is None:
            ignore_labels = []
        elif isinstance(ignore_labels, (int, numpy.integer)):
            # A scalar label (Python int or NumPy integer) is wrapped so
            # that ``list(...)`` below does not fail on a non-iterable.
            ignore_labels = [ignore_labels]
        self.ignore_labels = list(ignore_labels)

        self._repeat = repeat
        self._shuffle = shuffle
        self._batch_balancing = batch_balancing

        # Maps each unique label (ignored ones included) to the iterator
        # over the dataset indices carrying that label.
        self.labels_iterator_dict = {}

        max_label_count = -1
        include_label_count = 0
        for label in numpy.unique(labels):
            label_index = numpy.argwhere(labels == label).ravel()
            self.labels_iterator_dict[label] = IndexIterator(
                label_index, shuffle=shuffle)
            if label in self.ignore_labels:
                # Ignored labels keep an iterator (for statistics) but do
                # not take part in sizing the epoch.
                continue
            max_label_count = max(max_label_count, len(label_index))
            include_label_count += 1

        self.max_label_count = max_label_count
        # Number of examples that make up one (augmented) epoch.
        self.N_augmented = max_label_count * include_label_count
        self.reset()

    def __next__(self):
        """Return the next minibatch as a list of dataset examples.

        Raises:
            StopIteration: If ``repeat`` is ``False`` and one epoch has
                already been completed.

        """
        if not self._repeat and self.epoch > 0:
            raise StopIteration

        self._previous_epoch_detail = self.epoch_detail

        begin = self.current_position
        end = begin + self.batch_size
        n_total = self.N_augmented

        batch = [self.dataset[index] for index in self._order[begin:end]]

        if end >= n_total:
            if self._repeat:
                # Top the batch up from the beginning of a freshly drawn
                # order so that every batch has ``batch_size`` examples.
                rest = end - n_total
                self._update_order()
                if rest > 0:
                    batch.extend([self.dataset[index]
                                  for index in self._order[:rest]])
                self.current_position = rest
            else:
                self.current_position = 0

            self.epoch += 1
            self.is_new_epoch = True
        else:
            self.is_new_epoch = False
            self.current_position = end

        return batch

    next = __next__

    @property
    def epoch_detail(self):
        """Completed epochs plus the fraction of the current epoch consumed."""
        return self.epoch + self.current_position / self.N_augmented

    @property
    def previous_epoch_detail(self):
        """Value of ``epoch_detail`` before the last batch, or ``None``."""
        # This iterator saves ``-1`` as _previous_epoch_detail instead of
        # ``None`` because some serializers do not support ``None``.
        if self._previous_epoch_detail < 0:
            return None
        return self._previous_epoch_detail

    def serialize(self, serializer):
        """Save or restore the iterator state through `serializer`.

        Args:
            serializer: Callable serializer that also supports ``[]`` access
                to obtain a child serializer.

        """
        self.current_position = serializer('current_position',
                                           self.current_position)
        self.epoch = serializer('epoch', self.epoch)
        self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)
        if self._order is not None:
            # NOTE(review): the return value is discarded here; presumably
            # the serializer fills the ndarray in place on load - confirm.
            serializer('order', self._order)
        self._previous_epoch_detail = serializer(
            'previous_epoch_detail', self._previous_epoch_detail)

        for label, index_iterator in self.labels_iterator_dict.items():
            index_iterator.serialize(
                serializer['index_iterator_{}'.format(label)])

    def _update_order(self):
        """Draw the example order for the next epoch into ``self._order``."""
        indices_list = []
        for label, index_iterator in self.labels_iterator_dict.items():
            if label in self.ignore_labels:
                # Not include index of ignore_labels
                continue
            indices_list.append(index_iterator.get_next_indices(
                self.max_label_count))

        if self._batch_balancing:
            # `indices_list` contains same number of indices of each label.
            # we can `transpose` and `ravel` it to get each label's index in
            # sequence, which guarantees that label in each batch is balanced.
            indices = numpy.array(indices_list).transpose().ravel()
            self._order = indices
        else:
            indices = numpy.array(indices_list).ravel()
            self._order = numpy.random.permutation(indices)

    def reset(self):
        """Draw a fresh order and rewind to the start of epoch 0."""
        self._update_order()
        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

        # use -1 instead of None internally.
        self._previous_epoch_detail = -1.

    def show_label_stats(self):
        """Log count, rate and included/ignored status of every label."""
        self.logger.warning(' label count rate status')
        for label, index_iterator in self.labels_iterator_dict.items():
            count = len(index_iterator.index_list)
            rate = count / len(self.dataset)
            status = 'ignored' if label in self.ignore_labels else 'included'
            self.logger.warning('{:>8} {:>8} {:>8.4f} {:>10}'
                                .format(label, count, rate, status))