# Source code for chainer_chemistry.datasets.qm9

import glob
from logging import getLogger
import os
import shutil
import tarfile
import tempfile

from chainer.dataset import download
import numpy
import pandas
from tqdm import tqdm

from chainer_chemistry.dataset.parsers.csv_file_parser import CSVFileParser
from chainer_chemistry.dataset.preprocessors.atomic_number_preprocessor import AtomicNumberPreprocessor  # NOQA

# URL of the QM9 archive; fetched by `download_and_extract_qm9`.
download_url = 'https://ndownloader.figshare.com/files/3195389'
# File name of the csv cache that `download_and_extract_qm9` writes.
file_name = 'qm9.csv'

# Dataset sub-directory name passed to `download.get_dataset_directory`;
# the csv cache is stored inside the directory that call returns.
_root = 'pfnet/chainer/qm9'

# Names of the label columns of the csv, in the same order as the property
# values parsed from each .xyz file.
_label_names = ['A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2',
                'zpve', 'U0', 'U', 'H', 'G', 'Cv']
# Names of the SMILES columns of the csv; `get_qm9` reads 'SMILES1'.
_smiles_column_names = ['SMILES1', 'SMILES2']


def get_qm9_label_names():
    """Returns label names of QM9 datasets.

    Returns (list): a new list holding the label names. A copy is returned
        so that callers modifying the result cannot alter the module-level
        label definition, which is also used as csv column names.

    """
    return list(_label_names)


def get_qm9(preprocessor=None, labels=None, return_smiles=False,
            target_index=None):
    """Downloads, caches and preprocesses QM9 dataset.

    Args:
        preprocessor (BasePreprocessor): Preprocessor.
            This should be chosen based on the network to be trained.
            If it is None, default `AtomicNumberPreprocessor` is used.
        labels (str or list): List of target labels. A single label may be
            given as a str. If it is None or empty, all labels returned by
            `get_qm9_label_names` are used.
        return_smiles (bool): If set to ``True``,
            smiles array is also returned.
        target_index (list or None): target index list to partially extract
            dataset. If None (default), all examples are parsed.

    Returns:
        dataset, which is composed of `features`, which depends on
        `preprocessor`. When `return_smiles` is ``True``, a tuple
        ``(dataset, smiles)`` is returned instead.

    """
    labels = labels or get_qm9_label_names()
    if isinstance(labels, str):
        # Accept a single label name as a convenience.
        labels = [labels, ]

    def postprocess_label(label_list):
        # This is regression task, cast to float value.
        return numpy.asarray(label_list, dtype=numpy.float32)

    if preprocessor is None:
        preprocessor = AtomicNumberPreprocessor()
    parser = CSVFileParser(preprocessor, postprocess_label=postprocess_label,
                           labels=labels, smiles_col='SMILES1')
    # `get_qm9_filepath` downloads and builds the csv cache when missing.
    result = parser.parse(get_qm9_filepath(), return_smiles=return_smiles,
                          target_index=target_index)

    if return_smiles:
        return result['dataset'], result['smiles']
    else:
        return result['dataset']


def get_qm9_filepath(download_if_not_exist=True):
    """Construct a filepath which stores the QM9 dataset.

    This method checks whether the file exists or not, and downloads it
    if necessary.

    Args:
        download_if_not_exist (bool): If ``True``, download the dataset when
            the csv cache does not exist yet.

    Returns (str): file path for qm9 dataset (formatted to csv). The path is
        returned even when the file does not exist.

    """
    cache_path = _get_qm9_filepath()
    if not os.path.exists(cache_path) and download_if_not_exist:
        is_successful = download_and_extract_qm9(save_filepath=cache_path)
        if not is_successful:
            logger = getLogger(__name__)
            logger.warning('Download failed.')
    return cache_path


def _get_qm9_filepath():
    """Construct a filepath which stores QM9 dataset in csv

    This method does not check if the file is already downloaded or not.

    Returns (str): filepath for qm9 dataset

    """
    cache_root = download.get_dataset_directory(_root)
    cache_path = os.path.join(cache_root, file_name)
    return cache_path


def download_and_extract_qm9(save_filepath):
    """Download the QM9 archive and convert it into a single csv file.

    The archive at `download_url` is fetched through
    `download.cached_download`, extracted into a temporary directory, and
    every ``*.xyz`` file found there is parsed into one csv row made of the
    SMILES columns followed by the label columns. The temporary directory
    is removed afterwards, also when extraction or parsing fails.

    Args:
        save_filepath (str): path of the csv file to write.

    Returns (bool): ``True`` when the csv file has been written. Failures
        are reported by the exception that caused them, not by ``False``.

    """
    logger = getLogger(__name__)
    logger.warning('Extracting QM9 dataset, it takes time...')
    download_file_path = download.cached_download(download_url)

    temp_dir = tempfile.mkdtemp()
    try:
        # The context manager closes the archive handle even on error.
        with tarfile.open(download_file_path, 'r') as tf:
            # NOTE(review): `extractall` trusts the member paths stored in
            # the downloaded archive; consider validating members (or an
            # extraction filter on Python versions that provide one).
            tf.extractall(temp_dir)

        # Make sure the order is sorted
        file_pathes = sorted(glob.glob(os.path.join(temp_dir, '*.xyz')))

        ls = []
        for path in tqdm(file_pathes):
            with open(path, 'r') as f:
                data = [line.strip() for line in f]
            # First line holds the number of atoms in the molecule.
            num_atom = int(data[0])
            # Second line is tab separated; the first field is skipped
            # (presumably a molecule tag -- confirm against the QM9 format)
            # and the remaining fields are the numeric properties.
            properties = list(map(float, data[1].split('\t')[1:]))
            # The tab separated SMILES strings are read from the line at
            # index ``3 + num_atom``.
            smiles = data[3 + num_atom].split('\t')
            ls.append(smiles + properties)
    finally:
        # All rows are held in memory now; never leave the extracted files
        # behind, even if parsing raised.
        shutil.rmtree(temp_dir)

    df = pandas.DataFrame(ls, columns=_smiles_column_names + _label_names)
    df.to_csv(save_filepath)
    return True