from openchem.models.openchem_model import OpenChemModel
from openchem.optimizer.openchem_optimizer import OpenChemOptimizer
from openchem.optimizer.openchem_lr_scheduler import OpenChemLRScheduler
from openchem.data.utils import cut_padding
import torch
class MoleculeProtein2Label(OpenChemModel):
    r"""Predict one or several labels from a (molecule, protein) input pair.

    The two inputs are processed by independent branches. Each branch is an
    embedding layer followed by an encoder (for example an RNN or CNN
    encoder). The two encoded representations are then merged and fed to a
    multi-layer perceptron that produces the final output.

    Two merge modes are supported:

    * ``'mul'``: element-wise product of the two encoder outputs.
    * ``'concat'``: concatenation of the two outputs along dimension 1.

    Args:
        params (dict): dictionary describing the model architecture. It is
            passed to the ``OpenChemModel`` constructor. This class then reads
            the keys ``'mol_embedding'``, ``'mol_embedding_params'``,
            ``'prot_embedding'``, ``'prot_embedding_params'``,
            ``'mol_encoder'``, ``'mol_encoder_params'``, ``'prot_encoder'``,
            ``'prot_encoder_params'``, ``'merge'``, ``'mlp'`` and
            ``'mlp_params'``. The ``'merge'`` value is not validated here;
            an unsupported value only raises when ``forward`` is called.
    """

    def __init__(self, params):
        super().__init__(params)
        # NOTE(review): self.params and self.use_cuda are presumably set by
        # the OpenChemModel constructor -- confirm against the base class.
        cfg = self.params

        # NOTE: the order of reads and constructions below is kept
        # deliberately. It can determine submodule registration order and
        # the order in which random weight initialisation happens.

        # Embedding layer classes and their configs, then their instances.
        self.mol_embedding = cfg['mol_embedding']
        self.mol_embed_params = cfg['mol_embedding_params']
        self.prot_embedding = cfg['prot_embedding']
        self.prot_embed_params = cfg['prot_embedding_params']
        self.MolEmbedding = self.mol_embedding(self.mol_embed_params)
        self.ProtEmbedding = self.prot_embedding(self.prot_embed_params)

        # Encoder classes and configs; each instance also receives the CUDA flag.
        self.mol_encoder = cfg['mol_encoder']
        self.mol_encoder_params = cfg['mol_encoder_params']
        self.prot_encoder = cfg['prot_encoder']
        self.prot_encoder_params = cfg['prot_encoder_params']
        self.MolEncoder = self.mol_encoder(self.mol_encoder_params,
                                           self.use_cuda)
        self.ProtEncoder = self.prot_encoder(self.prot_encoder_params,
                                             self.use_cuda)

        # Merge strategy name ('mul' or 'concat'; checked in forward) and
        # the output MLP.
        self.merge = cfg['merge']
        self.mlp = cfg['mlp']
        self.mlp_params = cfg['mlp_params']
        self.MLP = self.mlp(self.mlp_params)

    def forward(self, inp, eval=False):
        """Run both branches, merge their encodings and apply the MLP.

        Args:
            inp: indexable whose item 0 is the molecule input and item 1 the
                protein input. The expected shapes depend on the configured
                embedding/encoder layers and are not visible here.
            eval (bool): if truthy, ``self.eval()`` is called before the
                forward pass; otherwise ``self.train()`` is called. The
                module's mode is therefore changed as a side effect.

        Returns:
            The result of ``self.MLP`` applied to the merged encodings.

        Raises:
            ValueError: if ``self.merge`` is neither ``'mul'`` nor
                ``'concat'``. This is raised only after both branches have
                been evaluated.
        """
        # ``eval`` shadows the builtin, but it is part of the public
        # signature (callers may pass it by keyword), so it is kept.
        if not eval:
            self.train()
        else:
            self.eval()

        mol_input, prot_input = inp[0], inp[1]

        # Molecule branch runs first, then the protein branch. Each encoder
        # returns a pair; only its first element is used downstream.
        mol_encoded, _ = self.MolEncoder(self.MolEmbedding(mol_input))
        prot_encoded, _ = self.ProtEncoder(self.ProtEmbedding(prot_input))

        if self.merge == 'mul':
            merged = mol_encoded * prot_encoded
        elif self.merge == 'concat':
            merged = torch.cat((mol_encoded, prot_encoded), dim=1)
        else:
            raise ValueError('Invalid value for merge')

        return self.MLP(merged)