Source code for chainer_chemistry.dataset.preprocessors.weavenet_preprocessor

import os

import numpy
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig

from chainer_chemistry.config import WEAVE_DEFAULT_NUM_MAX_ATOMS
from chainer_chemistry.dataset.preprocessors.common \
    import construct_atomic_number_array
from chainer_chemistry.dataset.preprocessors.common \
    import MolFeatureExtractionError
from chainer_chemistry.dataset.preprocessors.common import type_check_num_atoms
from chainer_chemistry.dataset.preprocessors.mol_preprocessor \
    import MolPreprocessor


ATOM = ['H', 'C', 'N', 'O', 'S', 'Cl', 'Br', 'F', 'P', 'I']
MAX_DISTANCE = 2  # 7


# --- Atom feature extraction ---
def construct_atom_type_vec(mol, num_max_atoms=WEAVE_DEFAULT_NUM_MAX_ATOMS,
                            atom_list=None, include_unknown_atom=False):
    atom_list = atom_list or ATOM
    if include_unknown_atom:
        # all atom not in `atom_list` as considered as "unknown atom"
        # and its index is `len(atom_list)`
        n_atom_type = len(atom_list) + 1
    else:
        n_atom_type = len(atom_list)
    n_atom = mol.GetNumAtoms()
    atom_type_vec = numpy.zeros((num_max_atoms, n_atom_type),
                                dtype=numpy.float32)
    for i in range(n_atom):
        a = mol.GetAtomWithIdx(i)
        try:
            atom_idx = atom_list.index(a.GetSymbol())
        except ValueError as e:
            if include_unknown_atom:
                atom_idx = len(atom_list)
            else:
                raise MolFeatureExtractionError(e)
        atom_type_vec[i, atom_idx] = 1.0
    return atom_type_vec


def construct_formal_charge_vec(mol,
                                num_max_atoms=WEAVE_DEFAULT_NUM_MAX_ATOMS):
    n_atom = mol.GetNumAtoms()
    formal_charge_vec = numpy.zeros((num_max_atoms, 1), dtype=numpy.float32)
    for i in range(n_atom):
        a = mol.GetAtomWithIdx(i)
        formal_charge_vec[i, 0] = a.GetFormalCharge()
    return formal_charge_vec


def construct_hybridization_vec(mol,
                                num_max_atoms=WEAVE_DEFAULT_NUM_MAX_ATOMS):
    # TODO(Oono)
    # Can we enhance preprocessing speed by making factory once
    # prior to calling this function many times?
    n_atom = mol.GetNumAtoms()
    hybridization_vec = numpy.zeros((num_max_atoms, 3), dtype=numpy.float32)
    for i in range(n_atom):
        a = mol.GetAtomWithIdx(i)
        if a.GetHybridization() is None:
            continue
        hybridization_type = str(a.GetHybridization())
        if hybridization_type == 'SP1':
            hybridization_vec[i, 0] = 1.0
        elif hybridization_type == 'SP2':
            hybridization_vec[i, 1] = 1.0
        elif hybridization_type == 'SP3':
            hybridization_vec[i, 2] = 1.0
    return hybridization_vec


def construct_partial_charge_vec(
        mol, num_max_atoms=WEAVE_DEFAULT_NUM_MAX_ATOMS):
    AllChem.ComputeGasteigerCharges(mol)
    n = mol.GetNumAtoms()
    partial_charge_vec = numpy.zeros((num_max_atoms, 1), dtype=numpy.float32)
    for i in range(n):
        a = mol.GetAtomWithIdx(i)
        partial_charge_vec[i, 0] = a.GetProp("_GasteigerCharge")
    return partial_charge_vec


def construct_atom_ring_vec(mol, num_max_atoms=WEAVE_DEFAULT_NUM_MAX_ATOMS):
    nAtom = mol.GetNumAtoms()
    sssr = Chem.GetSymmSSSR(mol)
    ring_feature = numpy.zeros((num_max_atoms, 6,), dtype=numpy.float32)
    for ring in sssr:
        ring = list(ring)
        for i in range(nAtom):
            if i in ring:
                ring_size = len(ring)
                if ring_size >= 3 and ring_size <= 8:
                    ring_feature[i, ring_size - 3] = 1.0
    return ring_feature


def construct_hydrogen_bonding(mol, num_max_atoms=WEAVE_DEFAULT_NUM_MAX_ATOMS):
    fdefName = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    factory = ChemicalFeatures.BuildFeatureFactory(fdefName)
    feats = factory.GetFeaturesForMol(mol)
    hydrogen_bonding_vec = numpy.zeros((num_max_atoms, 2), dtype=numpy.float32)
    for f in feats:
        if f.GetFamily() == 'Donor':
            idx = f.GetAtomIds()[0]
            hydrogen_bonding_vec[idx, 0] = 1.0
        if f.GetFamily() == 'Acceptor':
            idx = f.GetAtomIds()[0]
            hydrogen_bonding_vec[idx, 1] = 1.0
    return hydrogen_bonding_vec


def construct_num_hydrogens_vec(mol,
                                num_max_atoms=WEAVE_DEFAULT_NUM_MAX_ATOMS):
    n_hydrogen_vec = numpy.zeros((num_max_atoms, 1), dtype=numpy.float32)
    n_atom = mol.GetNumAtoms()
    for i in range(n_atom):
        n = 0
        for j in range(n_atom):
            if i == j:
                continue
            a = mol.GetAtomWithIdx(j)
            if a.GetSymbol() != 'H':
                continue
            k = mol.GetBondBetweenAtoms(i, j)
            if k is not None:
                n += 1
        n_hydrogen_vec[i, 0] = n
    return n_hydrogen_vec


def construct_aromaticity_vec(mol, num_max_atoms=WEAVE_DEFAULT_NUM_MAX_ATOMS):
    aromaticity_vec = numpy.zeros((num_max_atoms, 1), dtype=numpy.float32)
    aromatix_atoms = mol.GetAromaticAtoms()
    for a in aromatix_atoms:
        aromaticity_vec[a.GetIdx()] = 1.0

    return aromaticity_vec


def construct_atom_feature(mol, add_Hs,
                           num_max_atoms=WEAVE_DEFAULT_NUM_MAX_ATOMS,
                           atom_list=None, include_unknown_atom=False):
    """construct atom feature

    Args:
        mol (Mol): mol instance
        add_Hs (bool): if the `mol` instance was added Hs, set True.
        num_max_atoms (int): number of max atoms
        atom_list (list): list of atoms to extract feature. If None, default
            `ATOM` is used as `atom_list`
        include_unknown_atom (bool): If False, when the `mol` includes atom
            which is not in `atom_list`, it will raise
            `MolFeatureExtractionError`.
            If True, even the atom is not in `atom_list`, `atom_type` is set
            as "unknown" atom.

    Returns (numpy.ndarray): 2 dimensional array. First axis size is
        `num_max_atoms`, representing each atom index.
        Second axis for feature.

    """
    atom_type_vec = construct_atom_type_vec(
        mol, num_max_atoms, atom_list=atom_list,
        include_unknown_atom=include_unknown_atom)
    # TODO(nakago): Chilarity
    formal_charge_vec = construct_formal_charge_vec(mol)
    partial_charge_vec = construct_partial_charge_vec(mol)
    atom_ring_vec = construct_atom_ring_vec(mol)
    hybridization_vec = construct_hybridization_vec(mol)
    hydrogen_bonding = construct_hydrogen_bonding(mol)
    aromaticity_vec = construct_aromaticity_vec(mol)
    if add_Hs:
        num_hydrogens_vec = construct_num_hydrogens_vec(mol)
        feature = numpy.hstack((atom_type_vec, formal_charge_vec,
                                partial_charge_vec, atom_ring_vec,
                                hybridization_vec, hydrogen_bonding,
                                aromaticity_vec, num_hydrogens_vec))
    else:
        feature = numpy.hstack((atom_type_vec, formal_charge_vec,
                                partial_charge_vec, atom_ring_vec,
                                hybridization_vec, hydrogen_bonding,
                                aromaticity_vec))
    return feature


# --- Pair feature extraction ---
def construct_bond_vec(mol, i, j):
    bond_feature_vec = numpy.zeros((4, ), dtype=numpy.float32)
    k = mol.GetBondBetweenAtoms(i, j)
    if k is not None:
        bond_type = str(k.GetBondType())
        if bond_type == 'SINGLE':
            bond_feature_vec[0] = 1.0
        elif bond_type == 'DOUBLE':
            bond_feature_vec[1] = 1.0
        elif bond_type == 'TRIPLE':
            bond_feature_vec[2] = 1.0
        elif bond_type == 'AROMATIC':
            bond_feature_vec[3] = 1.0
        else:
            raise ValueError("Unknown bond type {}".format(bond_type))
    return bond_feature_vec


def construct_distance_vec(distance_matrix, i, j):
    distance = min(MAX_DISTANCE, int(distance_matrix[i][j]))
    distance_feature = numpy.zeros((MAX_DISTANCE, ), dtype=numpy.float32)
    distance_feature[:distance] = 1.0
    return distance_feature


def construct_ring_feature_vec(mol, num_max_atoms=WEAVE_DEFAULT_NUM_MAX_ATOMS):
    n_atom = mol.GetNumAtoms()
    sssr = Chem.GetSymmSSSR(mol)
    ring_feature_vec = numpy.zeros(
        (num_max_atoms ** 2, 1,), dtype=numpy.float32)
    for ring in sssr:
        ring = list(ring)
        n_atom_in_ring = len(ring)
        for i in range(n_atom_in_ring):
            for j in range(n_atom_in_ring):
                a0 = ring[i]
                a1 = ring[j]
                ring_feature_vec[a0 * n_atom + a1] = 1
    return ring_feature_vec


def construct_pair_feature(mol, num_max_atoms=WEAVE_DEFAULT_NUM_MAX_ATOMS):
    """construct pair feature

    Args:
        mol (Mol): mol instance
        num_max_atoms (int): number of max atoms

    Returns (numpy.ndarray): 2 dimensional array. First axis size is
        `num_max_atoms` ** 2, representing index of each atom pair.
        Second axis for feature.

    """
    n_atom = mol.GetNumAtoms()
    distance_matrix = Chem.GetDistanceMatrix(mol)
    distance_feature = numpy.zeros((num_max_atoms ** 2, MAX_DISTANCE,),
                                   dtype=numpy.float32)
    for i in range(n_atom):
        for j in range(n_atom):
            distance_feature[i * n_atom + j] = construct_distance_vec(
                distance_matrix, i, j)
    bond_feature = numpy.zeros((num_max_atoms ** 2, 4,), dtype=numpy.float32)
    for i in range(n_atom):
        for j in range(n_atom):
            bond_feature[i * n_atom + j] = construct_bond_vec(mol, i, j)
    ring_feature = construct_ring_feature_vec(mol, num_max_atoms=num_max_atoms)
    feature = numpy.hstack((distance_feature, bond_feature, ring_feature))
    return feature


[docs]class WeaveNetPreprocessor(MolPreprocessor): """WeaveNetPreprocessor WeaveNet must have fixed-size atom list for now, zero_padding option is always set to True. Args: max_atoms (int): Max number of atoms for each molecule, if the number of atoms is more than this value, this data is simply ignored. Setting negative value indicates no limit for max atoms. add_Hs (bool): If True, implicit Hs are added. use_fixed_atom_feature (bool): If True, atom feature is extracted used in original paper. If it is False, atomic number is used instead. atom_list (list): list of atoms to extract feature. If None, default `ATOM` is used as `atom_list` include_unknown_atom (bool): If False, when the `mol` includes atom which is not in `atom_list`, it will raise `MolFeatureExtractionError`. If True, even the atom is not in `atom_list`, `atom_type` is set as "unknown" atom. kekulize (bool): If True, Kekulizes the molecule. """
[docs] def __init__(self, max_atoms=WEAVE_DEFAULT_NUM_MAX_ATOMS, add_Hs=True, use_fixed_atom_feature=False, atom_list=None, include_unknown_atom=False, kekulize=False): super(WeaveNetPreprocessor, self).__init__( add_Hs=add_Hs, kekulize=kekulize) zero_padding = True if zero_padding and max_atoms <= 0: raise ValueError('max_atoms must be set to positive value when ' 'zero_padding is True') self.max_atoms = max_atoms self.add_Hs = add_Hs self.zero_padding = zero_padding self.use_fixed_atom_feature = use_fixed_atom_feature self.atom_list = atom_list self.include_unknown_atom = include_unknown_atom
def get_input_features(self, mol): """get input features for WeaveNet WeaveNetPreprocessor automatically add `H` to `mol` Args: mol (Mol): """ type_check_num_atoms(mol, self.max_atoms) if self.use_fixed_atom_feature: # original paper feature extraction atom_array = construct_atom_feature(mol, self.add_Hs, self.max_atoms, self.atom_list, self.include_unknown_atom) else: # embed id of atomic numbers atom_array = construct_atomic_number_array(mol, self.max_atoms) pair_feature = construct_pair_feature(mol, num_max_atoms=self.max_atoms) return atom_array, pair_feature