# Source code for data.smiles_protein_data_layer

# TODO: packed variable length sequence

import numpy as np
import pickle

from torch.utils.data import Dataset

from openchem.data.utils import read_smiles_property_file
from openchem.data.utils import sanitize_smiles, pad_sequences, seq2tensor
from openchem.data.utils import get_tokens


class SmilesProteinDataset(Dataset):
    """Dataset of (SMILES, protein sequence, target) triples.

    Samples come from one of two sources:

    * a delimited text file (``tokenized=False``). Its SMILES are sanitized,
      both sequence columns are optionally padded, token sets are built with
      ``get_tokens``, and the sequences are converted with ``seq2tensor``;
    * a pickle of already-tokenized data (``tokenized=True``).

    Args:
        filename: path to the delimited text file, or to the pickle when
            ``tokenized`` is True.
        tokenized: load pre-tokenized data from a pickle instead of parsing
            a text file.
        cols_to_read: columns passed to ``read_smiles_property_file``. The
            first three returned columns are used as SMILES, protein
            sequence and target, in that order.
        delimiter: field delimiter of the text file.
        mol_tokens: optional predefined SMILES token set for ``get_tokens``.
        prot_tokens: optional predefined protein token set for
            ``get_tokens``.
        pad: pad sequences with ``pad_sequences`` before tokenization
            (text-file input only).
    """

    def __init__(self, filename, tokenized=False, cols_to_read=None,
                 delimiter=',', mol_tokens=None, prot_tokens=None, pad=True):
        super().__init__()
        # Unpadded sequence lengths. They stay None when they cannot be
        # recovered (pre-tokenized pickle input); __getitem__ then omits the
        # corresponding keys instead of failing with AttributeError.
        self.mol_lengths = None
        self.prot_lengths = None
        if tokenized:
            self._load_tokenized(filename)
        else:
            self._load_raw(filename, cols_to_read, delimiter,
                           mol_tokens, prot_tokens, pad)
        assert len(self.molecules) == len(self.proteins), \
            'number of molecules and proteins differ'
        assert len(self.molecules) == len(self.target), \
            'number of molecules and targets differ'

    def _load_raw(self, filename, cols_to_read, delimiter,
                  mol_tokens, prot_tokens, pad):
        """Read, sanitize, optionally pad and tokenize a delimited text file."""
        data = read_smiles_property_file(filename, cols_to_read, delimiter)
        smiles = data[0]
        proteins = np.array(data[1])
        target = np.array(data[2], dtype='float')
        # sanitize_smiles returns the cleaned SMILES plus the indices of the
        # rows it kept; the same indices filter targets and proteins so all
        # three stay aligned.
        clean_smiles, clean_idx = sanitize_smiles(smiles)
        self.target = target[clean_idx]
        proteins = list(proteins[clean_idx])
        if pad:
            clean_smiles, self.mol_lengths = pad_sequences(clean_smiles)
            proteins, self.prot_lengths = pad_sequences(proteins)
        else:
            # Without padding each sequence already has its true length.
            self.mol_lengths = [len(seq) for seq in clean_smiles]
            self.prot_lengths = [len(seq) for seq in proteins]
        self.mol_tokens, self.mol_token2idx, self.mol_num_tokens = \
            get_tokens(clean_smiles, mol_tokens)
        self.prot_tokens, self.prot_token2idx, self.prot_num_tokens = \
            get_tokens(proteins, prot_tokens)
        self.molecules = seq2tensor(clean_smiles, self.mol_tokens)
        self.proteins = seq2tensor(proteins, self.prot_tokens)

    def _load_tokenized(self, filename):
        """Load pre-tokenized molecules, proteins, labels and token sets."""
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; only load pickles from a trusted source.
        with open(filename, 'rb') as f:
            data = pickle.load(f)
        self.mol_tokens = data['smiles_tokens']
        self.prot_tokens = data['proteins_tokens']
        self.mol_num_tokens = len(self.mol_tokens)
        self.prot_num_tokens = len(self.prot_tokens)
        # NOTE(review): mol_token2idx / prot_token2idx are not set on this
        # path, unlike the text-file path.
        self.molecules = data['smiles']
        self.proteins = data['proteins']
        self.target = data['labels']

    def __len__(self):
        """Return the number of samples."""
        return len(self.target)

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict.

        Always contains ``'tokenized_smiles'``, ``'tokenized_protein'`` and
        ``'labels'``. ``'mol_length'`` and ``'prot_length'`` are added when the
        corresponding sequence lengths are known (text-file input).
        """
        sample = {'tokenized_smiles': self.molecules[index],
                  'tokenized_protein': self.proteins[index],
                  'labels': self.target[index]}
        if self.mol_lengths is not None:
            sample['mol_length'] = self.mol_lengths[index]
        if self.prot_lengths is not None:
            sample['prot_length'] = self.prot_lengths[index]
        return sample