# Source code for chainer_chemistry.dataset.splitters.scaffold_splitter

from collections import defaultdict

import numpy
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold

from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter


def generate_scaffold(smiles, include_chirality=False):
    """Return the Murcko scaffold SMILES string of a molecule.

    Args:
        smiles (str): SMILES string of the target molecule.
        include_chirality (bool): Forwarded to RDKit's
            ``MurckoScaffoldSmiles`` as ``includeChirality``.

    Returns:
        str: SMILES of the molecule's Murcko scaffold as produced by RDKit.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail here
        # with a message naming the bad input instead of passing None on.
        raise ValueError('Invalid SMILES string: {!r}'.format(smiles))
    scaffold = MurckoScaffold.MurckoScaffoldSmiles(
        mol=mol, includeChirality=include_chirality)
    return scaffold


class ScaffoldSplitter(BaseSplitter):
    """Class for doing data splits by chemical scaffold.

    Examples whose SMILES share the same scaffold (see
    ``generate_scaffold``) are always placed in the same subset.

    Referred Deepchem for the implementation, https://git.io/fXzF4
    """

    def _split(self, dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1,
               **kwargs):
        """Compute train/valid/test indices grouped by scaffold.

        Args:
            dataset: Dataset; only ``len(dataset)`` is used here.
            frac_train (float): Fraction for training. Only used to check
                that the three fractions sum to 1; training receives every
                scaffold set not placed in valid or test.
            frac_valid (float): Upper bound on the validation fraction.
            frac_test (float): Upper bound on the test fraction.
            **kwargs: ``smiles_list`` (required; one SMILES per example),
                ``seed`` (optional; seeds ``numpy.random.RandomState``),
                ``include_chirality`` (optional; defaults to ``False``).

        Returns:
            tuple: Three integer ``numpy.ndarray`` of train, valid and test
            indices.

        Raises:
            AssertionError: If the fractions do not sum to 1.
            TypeError: If ``smiles_list`` is not given.
            ValueError: If ``dataset`` and ``smiles_list`` differ in length.
        """
        numpy.testing.assert_almost_equal(frac_train + frac_valid + frac_test,
                                          1.)
        seed = kwargs.get('seed', None)
        smiles_list = kwargs.get('smiles_list')
        include_chirality = kwargs.get('include_chirality', False)
        if smiles_list is None:
            raise TypeError("smiles_list must be specified for "
                            "ScaffoldSplitter")
        if len(dataset) != len(smiles_list):
            raise ValueError("The lengths of dataset and smiles_list are "
                             "different")

        rng = numpy.random.RandomState(seed)

        # Group example indices by scaffold string (insertion-ordered dict,
        # so grouping order is deterministic for a given smiles_list).
        scaffolds = defaultdict(list)
        for ind, smiles in enumerate(smiles_list):
            scaffold = generate_scaffold(smiles, include_chirality)
            scaffolds[scaffold].append(ind)

        # Shuffle the groups by permuting their positions. Passing the ragged
        # list of lists straight to rng.permutation would build an object
        # array, which NumPy >= 1.24 rejects with ValueError.
        scaffold_groups = list(scaffolds.values())
        scaffold_sets = [scaffold_groups[i]
                         for i in rng.permutation(len(scaffold_groups))]

        n_total_valid = int(numpy.floor(frac_valid * len(dataset)))
        n_total_test = int(numpy.floor(frac_test * len(dataset)))

        train_index = []
        valid_index = []
        test_index = []

        # Greedy fill: a scaffold set goes to valid if it still fits, else to
        # test if it fits there, otherwise to train.
        for scaffold_set in scaffold_sets:
            if len(valid_index) + len(scaffold_set) <= n_total_valid:
                valid_index.extend(scaffold_set)
            elif len(test_index) + len(scaffold_set) <= n_total_test:
                test_index.extend(scaffold_set)
            else:
                train_index.extend(scaffold_set)

        # Explicit integer dtype so that an empty subset is still a usable
        # index array (numpy.array([]) alone would be float64).
        return (numpy.array(train_index, dtype=numpy.intp),
                numpy.array(valid_index, dtype=numpy.intp),
                numpy.array(test_index, dtype=numpy.intp))

    def train_valid_test_split(self, dataset, smiles_list, frac_train=0.8,
                               frac_valid=0.1, frac_test=0.1, converter=None,
                               return_index=True, seed=None,
                               include_chirality=False, **kwargs):
        """Split dataset into train, valid and test set.

        Split indices are generated by splitting based on the scaffold of
        small molecules.

        Args:
            dataset(NumpyTupleDataset, numpy.ndarray): Dataset.
            smiles_list(list): SMILES list corresponding to dataset.
            frac_train(float): Fraction of dataset put into training data.
            frac_valid(float): Fraction of dataset put into validation data.
            frac_test(float): Fraction of dataset put into test data.
            converter(callable): Passed through to
                ``BaseSplitter.train_valid_test_split``; see that method for
                how it is applied.
            return_index(bool): If `True`, this function returns only
                indices. If `False`, this function returns the split dataset.
            seed (int): Random seed.
            include_chirality(bool): Whether stereochemistry is kept when
                computing scaffolds.

        Returns:
            SplittedDataset(tuple): split dataset or indices
        """
        return super(ScaffoldSplitter, self)\
            .train_valid_test_split(dataset, frac_train, frac_valid, frac_test,
                                    converter, return_index, seed=seed,
                                    smiles_list=smiles_list,
                                    include_chirality=include_chirality,
                                    **kwargs)

    def train_valid_split(self, dataset, smiles_list, frac_train=0.9,
                          frac_valid=0.1, converter=None, return_index=True,
                          seed=None, include_chirality=False, **kwargs):
        """Split dataset into train and valid set.

        Split indices are generated by splitting based on the scaffold of
        small molecules.

        Args:
            dataset(NumpyTupleDataset, numpy.ndarray): Dataset.
            smiles_list(list): SMILES list corresponding to dataset.
            frac_train(float): Fraction of dataset put into training data.
            frac_valid(float): Fraction of dataset put into validation data.
            converter(callable): Passed through to
                ``BaseSplitter.train_valid_split``; see that method for how
                it is applied.
            return_index(bool): If `True`, this function returns only
                indices. If `False`, this function returns the split dataset.
            seed (int): Random seed.
            include_chirality(bool): Whether stereochemistry is kept when
                computing scaffolds.

        Returns:
            SplittedDataset(tuple): split dataset or indices
        """
        return super(ScaffoldSplitter, self)\
            .train_valid_split(dataset, frac_train, frac_valid, converter,
                               return_index, seed=seed,
                               smiles_list=smiles_list,
                               include_chirality=include_chirality, **kwargs)