Source code for chainer_chemistry.iterators.index_iterator

import numpy

from chainer.dataset import iterator


[docs]class IndexIterator(iterator.Iterator): """Index iterator IndexIterator is used internally in `BalancedSerialIterator`, as each label's index iterator Args: index_list (list): list of int which represents indices. shuffle (bool): shuffle flag. If True, indices specified by ``index_list`` will be randomly shuffled. num (int): number of indices to be extracted when ``___next___`` is called. """
[docs] def __init__(self, index_list, shuffle=True, num=0): self.index_list = numpy.asarray(index_list) assert self.index_list.ndim == 1 self.index_length = len(index_list) self.current_index_list = None self.current_pos = 0 self.shuffle = shuffle self.num = num self.update_current_index_list()
def update_current_index_list(self): if self.shuffle: self.current_index_list = numpy.random.permutation(self.index_list) else: self.current_index_list = self.index_list def __next__(self): return self.get_next_indices(self.num) def get_next_indices(self, num): """get next indices Args: num (int): number for indices to extract. Returns (numpy.ndarray): 1d array of indices .. admonition:: Example >>> ii = IndexIterator([1, 3, 5, 10], shuffle=True) >>> print(ii.get_next_indices(5)) [ 5 1 10 3 10] >>> print(ii.get_next_indices(5)) [ 3 1 5 10 1] """ indices = [] if self.current_pos + num < self.index_length: indices.append(self.current_index_list[ self.current_pos: self.current_pos + num]) self.current_pos += num else: indices.append(self.current_index_list[self.current_pos:]) num -= (self.index_length - self.current_pos) # When `num` is twice bigger than `self.index_length`, `index_list` # is repeated `q` times to get desired length of `indices`. q, r = divmod(num, self.index_length) if self.shuffle: for _ in range(q): indices.append(numpy.random.permutation(self.index_list)) else: indices.append(numpy.tile(self.index_list, q)) self.update_current_index_list() indices.append(self.current_index_list[:r]) self.current_pos = r return numpy.concatenate(indices).ravel() def serialize(self, serializer): self.current_index_list = serializer('current_index_list', self.current_index_list) self.current_pos = serializer('current_pos', self.current_pos)