# TODO: packed variable length sequence
import time
import math
import numpy as np
import csv
import warnings
from rdkit import Chem
from rdkit import DataStructs
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from openchem.data.smiles_enumerator import SmilesEnumerator
from rdkit import rdBase
from openchem.utils.graph import Graph
rdBase.DisableLog('rdApp.error')
class DummyDataset(Dataset):
    """Placeholder dataset that yields dictionaries of zero-valued tensors."""
    def __init__(self, size=1):
        super(DummyDataset, self).__init__()
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return {
            'x': torch.zeros(1),
            'y': torch.zeros(1),
            'num_nodes': torch.zeros(1),
            'c_in': torch.zeros(1),
            'c_out': torch.zeros(1),
            'max_prev_nodes_local': torch.zeros(1)
        }

class DummyDataLoader(object):
    """Placeholder loader that yields `batch_size` empty (None) batches per epoch."""
    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.current = 0

    def __iter__(self):
        return self

    def __len__(self):
        return self.batch_size

    def __next__(self):
        if self.current < self.batch_size:
            self.current += 1
            return None
        else:
            self.current = 0
            raise StopIteration

def cut_padding(samples, lengths, padding='left'):
    """Trims a padded batch to the length of its longest sequence."""
    max_len = lengths.max(dim=0)[0].cpu().numpy()
    if padding == 'right':
        cut_samples = samples[:, :max_len]
    elif padding == 'left':
        total_len = samples.size()[1]
        cut_samples = samples[:, total_len - max_len:]
    else:
        raise ValueError("Invalid value for padding argument. "
                         "Must be 'right' or 'left'.")
    return cut_samples

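# Illustrative sketch of cut_padding (the tensors below are made-up inputs):
# with left padding, the longest per-sample length determines how many
# padding columns can be dropped from the front of the batch.
#
#   samples = torch.tensor([[0, 0, 1, 2, 3],
#                           [0, 0, 0, 4, 5]])
#   lengths = torch.tensor([3, 2])
#   cut_padding(samples, lengths, padding='left')
#   # -> tensor([[1, 2, 3],
#   #            [0, 4, 5]])
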
def seq2tensor(seqs, tokens, flip=True):
    """Converts equal-length sequences into a matrix of token indices.

    Characters missing from `tokens` are appended to it on the fly.
    """
    tensor = np.zeros((len(seqs), len(seqs[0])))
    for i in range(len(seqs)):
        for j in range(len(seqs[i])):
            if seqs[i][j] in tokens:
                tensor[i, j] = tokens.index(seqs[i][j])
            else:
                tokens = tokens + seqs[i][j]
                tensor[i, j] = tokens.index(seqs[i][j])
    if flip:
        tensor = np.flip(tensor, axis=1).copy()
    return tensor, tokens

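# Illustrative sketch of seq2tensor (made-up inputs): padded SMILES strings are
# turned into an index matrix; any character not in the token string is added
# to it and the updated token string is returned.
#
#   tensor, tokens = seq2tensor(['CCO ', 'C=O '], tokens=' C=O', flip=False)
#   # tensor[i, j] is the index of character j of sequence i in `tokens`
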
def pad_sequences(seqs, max_length=None, pad_symbol=' '):
    """Pads sequences to a common length and returns their original lengths."""
    if max_length is None:
        max_length = -1
        for seq in seqs:
            max_length = max(max_length, len(seq))
    lengths = []
    for i in range(len(seqs)):
        cur_len = len(seqs[i])
        lengths.append(cur_len)
        seqs[i] = seqs[i] + pad_symbol * (max_length - cur_len)
    return seqs, lengths

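# Illustrative sketch of pad_sequences (made-up inputs):
#
#   padded, lengths = pad_sequences(['CCO', 'c1ccccc1'])
#   # padded  -> ['CCO     ', 'c1ccccc1']  (both strings are 8 characters long)
#   # lengths -> [3, 8]
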
def create_loader(dataset, batch_size, shuffle=True, num_workers=1, pin_memory=False, sampler=None):
    """Thin wrapper around torch.utils.data.DataLoader."""
    data_loader = DataLoader(dataset=dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers,
                             pin_memory=pin_memory,
                             sampler=sampler)
    return data_loader

def sanitize_smiles(smiles,
                    canonize=True,
                    min_atoms=-1,
                    max_atoms=-1,
                    return_num_atoms=False,
                    allowed_tokens=None,
                    allow_charges=False,
                    return_max_len=False,
                    logging="warn"):
    """
    Takes a list of SMILES strings and returns a list of their sanitized versions.
    For the definition of sanitized SMILES see
    http://www.rdkit.org/docs/api/rdkit.Chem.rdmolops-module.html#SanitizeMol

    Args:
        smiles (list): list of SMILES strings
        canonize (bool): whether to return canonical SMILES
        min_atoms (int): minimum allowed number of atoms (-1 for no limit)
        max_atoms (int): maximum allowed number of atoms (-1 for no limit)
        return_num_atoms (bool): also return an array of atom counts
        allowed_tokens (iterable, optional): set of allowed tokens
        allow_charges (bool): allow atoms with nonzero formal charges
        return_max_len (bool): also return the length of the longest valid SMILES
        logging ("warn", "info", "none"): logging level

    Output:
        new_smiles (list): list of SMILES, with a blank placeholder (' ') where
            the input string is invalid or could not be sanitized.
            If canonize=True, valid entries are canonical SMILES and the
            function is analogous to canonize_smiles(smiles, sanitize=True).
        idx (list): indices of the valid input strings.
    """
    assert logging in ["warn", "info", "none"]
    new_smiles = []
    idx = []
    num_atoms = []
    smiles_lens = []
    for i in range(len(smiles)):
        sm = smiles[i]
        mol = Chem.MolFromSmiles(sm, sanitize=False)
        sm_new = Chem.MolToSmiles(mol) if canonize and mol is not None else sm
        good = mol is not None
        if good and allowed_tokens is not None:
            good &= all([t in allowed_tokens for t in sm_new])
        if good and not allow_charges:
            good &= all([a.GetFormalCharge() == 0 for a in mol.GetAtoms()])
        if good:
            n = mol.GetNumAtoms()
            if (n < min_atoms and min_atoms > -1) or (n > max_atoms > -1):
                good = False
        else:
            n = 0
        if good:
            new_smiles.append(sm_new)
            idx.append(i)
            num_atoms.append(n)
            smiles_lens.append(len(sm_new))
        else:
            new_smiles.append(' ')
            num_atoms.append(0)

    smiles_set = set(new_smiles)
    num_unique = len(smiles_set) - (' ' in smiles_set)
    if len(idx) > 0:
        valid_unique_rate = float(num_unique) / len(idx)
        invalid_rate = 1.0 - float(len(idx)) / len(smiles)
    else:
        valid_unique_rate = 0.0
        invalid_rate = 1.0
    num_bad = len(smiles) - len(idx)
    if len(idx) != len(smiles) and logging == "warn":
        warnings.warn('{:d}/{:d} unsanitized smiles ({:.1f}%)'.format(num_bad, len(smiles), 100 * invalid_rate))
    elif logging == "info":
        print("Valid: {}/{} ({:.2f}%)".format(len(idx), len(smiles), 100 * (1 - invalid_rate)))
        print("Unique valid: {:.2f}%".format(100 * valid_unique_rate))
    if return_num_atoms and return_max_len:
        return new_smiles, idx, num_atoms, max(smiles_lens)
    elif return_num_atoms and not return_max_len:
        return new_smiles, idx, num_atoms
    elif not return_num_atoms and return_max_len:
        return new_smiles, idx, max(smiles_lens)
    else:
        return new_smiles, idx

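# Illustrative sketch of sanitize_smiles (made-up inputs): invalid SMILES are
# replaced with the blank placeholder and their positions are dropped from `idx`.
#
#   smiles = ['CCO', 'not_a_smiles', 'c1ccccc1']
#   clean, idx = sanitize_smiles(smiles, logging="none")
#   # clean -> roughly ['CCO', ' ', 'c1ccccc1'] (canonical forms), idx -> [0, 2]
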
def canonize_smiles(smiles, sanitize=True):
    """
    Takes a list of SMILES strings and returns a list of their canonical SMILES.

    Args:
        smiles (list): list of SMILES strings
        sanitize (bool): whether to sanitize SMILES before canonizing.
            For the definition of sanitized SMILES see
            www.rdkit.org/docs/api/rdkit.Chem.rdmolops-module.html#SanitizeMol

    Output:
        new_smiles (list): list of canonical SMILES, with an empty string where
            the SMILES string is invalid or could not be sanitized
            (when sanitize=True).
            When sanitize=True the function is analogous to
            sanitize_smiles(smiles, canonize=True).
    """
    new_smiles = []
    idx = []
    for i in range(len(smiles)):
        sm = smiles[i]
        try:
            new_smiles.append(Chem.MolToSmiles(Chem.MolFromSmiles(sm, sanitize=sanitize)))
            idx.append(i)
        except Exception:
            new_smiles.append('')
    if len(idx) != len(smiles):
        invalid_rate = 1.0 - len(idx) / len(smiles)
        warnings.warn('Proportion of unsanitized smiles is %.3f ' % invalid_rate)
    return new_smiles

def save_smi_to_file(filename, smiles, unique=True):
    """
    Takes a path to a file and a list of SMILES strings and writes the SMILES
    to the specified file.

    Args:
        filename (str): path to the file
        smiles (list): list of SMILES strings
        unique (bool): whether to write only unique SMILES

    Output:
        success (bool): whether the operation completed successfully.
    """
    if unique:
        smiles = list(set(smiles))
    else:
        smiles = list(smiles)
    f = open(filename, 'w')
    for mol in smiles:
        f.write(mol + '\n')
    f.close()
    return f.closed

def read_smi_file(filename, unique=True):
    """
    Reads SMILES from a file. The file must contain one SMILES string per line,
    each line terminated by a newline character.

    Args:
        filename (str): path to the file
        unique (bool): return only unique SMILES

    Returns:
        smiles (list): list of SMILES strings from the specified file.
            If unique=True, the list contains only unique copies.
        success (bool): whether the operation completed successfully.
    """
    f = open(filename, 'r')
    molecules = []
    for line in f:
        molecules.append(line[:-1])
    if unique:
        molecules = list(set(molecules))
    else:
        molecules = list(molecules)
    f.close()
    return molecules, f.closed

def get_tokens(smiles, tokens=None):
    """
    Returns the string of unique tokens, a token-to-index dictionary and
    the number of unique tokens for a list of SMILES.

    Args:
        smiles (list): list of SMILES strings to tokenize.
        tokens (str): string of tokens or None.
            If None, tokens are extracted from the dataset.

    Returns:
        tokens (str): string of unique tokens (the SMILES alphabet).
        token2idx (dict): dictionary mapping each token to its index.
        num_tokens (int): number of unique tokens.
    """
    if tokens is None:
        tokens = list(set(''.join(smiles)))
        tokens = sorted(tokens)
        tokens = ''.join(tokens)
    token2idx = dict((token, i) for i, token in enumerate(tokens))
    num_tokens = len(tokens)
    return tokens, token2idx, num_tokens

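# Illustrative sketch of get_tokens (made-up inputs):
#
#   tokens, token2idx, num_tokens = get_tokens(['CCO', 'c1ccccc1'])
#   # tokens     -> '1COc' (sorted unique characters joined into one string)
#   # token2idx  -> {'1': 0, 'C': 1, 'O': 2, 'c': 3}
#   # num_tokens -> 4
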
def augment_smiles(smiles, labels, n_augment=5):
    """Augments each SMILES with n_augment randomized variants plus the original."""
    smiles_augmentation = SmilesEnumerator()
    augmented_smiles = []
    augmented_labels = []
    for i in range(len(smiles)):
        sm = smiles[i]
        for _ in range(n_augment):
            augmented_smiles.append(smiles_augmentation.randomize_smiles(sm))
            augmented_labels.append(labels[i])
        augmented_smiles.append(sm)
        augmented_labels.append(labels[i])
    return augmented_smiles, augmented_labels

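# Illustrative sketch of augment_smiles (made-up inputs): every SMILES string is
# expanded into n_augment randomized atom orderings plus the original, and the
# label is repeated for each copy.
#
#   aug_sm, aug_labels = augment_smiles(['CCO'], [1.0], n_augment=2)
#   # len(aug_sm) == len(aug_labels) == 3
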
def time_since(since):
    s = time.time() - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def read_smiles_property_file(path, cols_to_read, delimiter=',', keep_header=False):
    """Reads selected columns from a delimited SMILES/property file."""
    reader = csv.reader(open(path, 'r'), delimiter=delimiter)
    data = list(reader)
    if keep_header:
        start_position = 0
    else:
        start_position = 1
    assert len(data) > start_position
    data = list(map(list, zip(*data)))  # transpose rows into columns
    data_ = [data[c][start_position:] for c in cols_to_read]
    return data_

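# Illustrative sketch of read_smiles_property_file (hypothetical file name and
# layout): for a comma-separated file whose first column holds SMILES and whose
# second column holds a property, with a header row:
#
#   smiles, labels = read_smiles_property_file('data.csv', cols_to_read=[0, 1])
#   # smiles -> list of column 0 values, labels -> list of column 1 values
#   # (both as strings; the header row is skipped unless keep_header=True)
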
def save_smiles_property_file(path, smiles, labels, delimiter=','):
    """Writes SMILES strings and their labels to a delimited file."""
    f = open(path, 'w')
    n_targets = labels.shape[1]
    for i in range(len(smiles)):
        f.write(smiles[i])
        for j in range(n_targets):
            f.write(delimiter + str(labels[i, j]))
        f.write('\n')
    f.close()

def process_smiles(smiles,
                   sanitized=False,
                   target=None,
                   augment=False,
                   pad=True,
                   tokenize=True,
                   tokens=None,
                   flip=False,
                   allowed_tokens=None):
    """Full SMILES preprocessing pipeline: sanitize, augment, pad and tokenize."""
    if not sanitized:
        clean_smiles, clean_idx = sanitize_smiles(smiles, allowed_tokens=allowed_tokens)
        clean_smiles = [clean_smiles[i] for i in clean_idx]
        if target is not None:
            target = target[clean_idx]
    else:
        clean_smiles = smiles
    length = None
    if augment and target is not None:
        clean_smiles, target = augment_smiles(clean_smiles, target)
    if pad:
        clean_smiles, length = pad_sequences(clean_smiles)
    tokens, token2idx, num_tokens = get_tokens(clean_smiles, tokens)
    if tokenize:
        clean_smiles, tokens = seq2tensor(clean_smiles, tokens, flip)
    return clean_smiles, target, length, tokens, token2idx, num_tokens

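# Illustrative sketch of process_smiles (made-up inputs): the helper chains
# sanitize_smiles -> pad_sequences -> get_tokens -> seq2tensor.
#
#   X, y, lengths, tokens, token2idx, num_tokens = process_smiles(
#       ['CCO', 'c1ccccc1'], target=None)
#   # X is a (num_samples, max_len) matrix of token indices,
#   # lengths holds the unpadded length of every sequence.
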
def process_graphs(smiles,
                   node_attributes,
                   get_atomic_attributes,
                   edge_attributes,
                   get_bond_attributes=None,
                   kekulize=True):
    """Builds adjacency matrices and node feature matrices for a list of SMILES."""
    clean_smiles, clean_idx, num_atoms = sanitize_smiles(smiles, return_num_atoms=True)
    clean_smiles = [clean_smiles[i] for i in clean_idx]
    max_size = np.max(num_atoms)
    adj = []
    node_feat = []
    for sm in clean_smiles:
        graph = Graph(sm, max_size, get_atomic_attributes, get_bond_attributes, kekulize=kekulize)
        node_feature_matrix = graph.get_node_feature_matrix(node_attributes, max_size)
        # TODO: remove diagonal elements from adjacency matrix
        if get_bond_attributes is None:
            adj_matrix = graph.adj_matrix
        else:
            adj_matrix = graph.get_edge_attr_adj_matrix(edge_attributes, max_size)
        adj.append(adj_matrix.astype('float32'))
        node_feat.append(node_feature_matrix.astype('float32'))
    return adj, node_feat

def get_fp(smiles, n_bits=2048):
    """Computes RDKit fingerprints for a list of SMILES strings."""
    fp = []
    processed_indices = []
    invalid_indices = []
    for i in range(len(smiles)):
        sm = smiles[i]
        tmp = np.array(mol2image(sm, n=n_bits))
        if np.isnan(tmp[0]):
            invalid_indices.append(i)
        else:
            fp.append(tmp)
            processed_indices.append(i)
    return np.array(fp, dtype="float32"), processed_indices, invalid_indices

def mol2image(x, n=2048):
    """Converts a single SMILES string to an RDKit fingerprint vector, or [nan] on failure."""
    try:
        m = Chem.MolFromSmiles(x)
        fp = Chem.RDKFingerprint(m, maxPath=4, fpSize=n)
        res = np.zeros(len(fp))
        DataStructs.ConvertToNumpyArray(fp, res)
        return res
    except Exception as e:
        print(e)
        return [np.nan]

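# Illustrative sketch of get_fp / mol2image (made-up inputs): each SMILES string
# is converted to an n_bits-long RDKit fingerprint; strings RDKit cannot parse
# end up in invalid_indices (mol2image prints the underlying exception).
#
#   fps, ok_idx, bad_idx = get_fp(['CCO', 'not_a_smiles'])
#   # fps.shape -> (1, 2048), ok_idx -> [0], bad_idx -> [1]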