# Source code for chainer_chemistry.datasets.numpy_tuple_dataset

import os
import six

import numpy

from chainer_chemistry.dataset.indexers.numpy_tuple_dataset_feature_indexer import NumpyTupleDatasetFeatureIndexer  # NOQA


class NumpyTupleDataset(object):

    """Dataset of a tuple of datasets.

    It combines multiple datasets into one dataset. Each example is
    represented by a tuple whose ``i``-th item corresponds to the ``i``-th
    dataset. Each ``i``-th dataset is expected to be an instance of
    numpy.ndarray.

    Args:
        datasets: Underlying datasets. The ``i``-th one is used for the
            ``i``-th item of each example. All datasets must have the same
            length.

    """

    def __init__(self, *datasets):
        if not datasets:
            raise ValueError('no datasets are given')
        length = len(datasets[0])
        for i, dataset in enumerate(datasets):
            if len(dataset) != length:
                raise ValueError(
                    'dataset of the index {} has a wrong length'.format(i))
        self._datasets = datasets
        self._length = length
        self._features_indexer = NumpyTupleDatasetFeatureIndexer(self)

    def __getitem__(self, index):
        """Return one example, or a list of examples for a batch index.

        ``index`` is applied to every underlying dataset.

        Args:
            index: A slice, a list, or an ndarray with at least one dimension
                (or a non-integer 0-d ndarray) selects several examples, and
                a list of per-example tuples is returned. Any other index
                (e.g. an int or a 0-d integer ndarray) returns a single tuple
                holding the item taken from each dataset.

        Returns:
            tuple, or list of tuples, as described above.

        """
        batches = [dataset[index] for dataset in self._datasets]
        if self._is_batch_index(index):
            # ``batches`` is laid out per dataset; transpose it into one
            # tuple per example.
            return list(zip(*batches))
        return tuple(batches)

    @staticmethod
    def _is_batch_index(index):
        """Return True if ``index`` selects multiple examples."""
        if isinstance(index, (slice, list)):
            return True
        if isinstance(index, numpy.ndarray):
            # numpy treats a 0-d integer array like a plain integer index, so
            # each dataset yields a single item that has no len(); handle it
            # like a scalar index instead of a batch.
            return not (index.ndim == 0 and index.dtype.kind in 'iu')
        return False

    def __len__(self):
        """Return the number of examples (shared length of all datasets)."""
        return self._length

    def get_datasets(self):
        """Return the tuple of underlying datasets passed to the constructor."""
        return self._datasets

    @property
    def features(self):
        """Extract features according to the specified index.

        - axis 0 is used to specify dataset id (`i`-th dataset)
        - axis 1 is used to specify feature index

        .. admonition:: Example

           >>> import numpy
           >>> from chainer_chemistry.datasets import NumpyTupleDataset
           >>> x = numpy.array([0, 1, 2], dtype=numpy.float32)
           >>> t = x * x
           >>> numpy_tuple_dataset = NumpyTupleDataset(x, t)
           >>> targets = numpy_tuple_dataset.features[:, 1]
           >>> print('targets', targets)  # We can extract only target value
           targets [0, 1, 4]

        """
        return self._features_indexer

    @classmethod
    def save(cls, filepath, numpy_tuple_dataset):
        """Save the dataset to ``filepath`` in npz format.

        Each underlying dataset is stored positionally by ``numpy.savez``
        (keys ``arr_0``, ``arr_1``, ...), which is the layout :meth:`load`
        reads back.

        Args:
            filepath (str): filepath to save dataset. It is recommended to
                end with '.npz' extension: ``numpy.savez`` appends '.npz' to
                a string path that lacks it, and :meth:`load` looks for the
                exact path it is given.
            numpy_tuple_dataset (NumpyTupleDataset): dataset instance

        Raises:
            TypeError: if ``numpy_tuple_dataset`` is not a NumpyTupleDataset.

        """
        if not isinstance(numpy_tuple_dataset, NumpyTupleDataset):
            raise TypeError('numpy_tuple_dataset is not instance of '
                            'NumpyTupleDataset, got {}'
                            .format(type(numpy_tuple_dataset)))
        numpy.savez(filepath, *numpy_tuple_dataset._datasets)

    @classmethod
    def load(cls, filepath, allow_pickle=None):
        """Load a dataset written by :meth:`save`.

        Arrays are read in key order ``arr_0``, ``arr_1``, ... until the
        first missing key.

        Args:
            filepath (str): path of the npz file to read.
            allow_pickle (bool or None): forwarded to ``numpy.load`` when not
                None; None keeps numpy's own default. It must be True to read
                object-dtype arrays. Only enable it for trusted files, since
                unpickling can execute arbitrary code.

        Returns:
            NumpyTupleDataset, or None if ``filepath`` does not exist.

        Raises:
            ValueError: if the file holds no ``arr_0`` entry or the stored
                arrays differ in length (raised by the constructor).

        """
        if not os.path.exists(filepath):
            return None
        load_kwargs = {}
        if allow_pickle is not None:
            load_kwargs['allow_pickle'] = allow_pickle
        # Closing the NpzFile releases its file handle; arrays fetched by key
        # have already been read into memory, so they stay valid afterwards.
        with numpy.load(filepath, **load_kwargs) as load_data:
            names = set(load_data.files)
            result = []
            i = 0
            while 'arr_{}'.format(i) in names:
                result.append(load_data['arr_{}'.format(i)])
                i += 1
        return NumpyTupleDataset(*result)