Source code for models.Graph2Label
from openchem.models.openchem_model import OpenChemModel
import torch

class Graph2Label(OpenChemModel):
    r"""
    Creates a model that predicts one or multiple labels given an object of
    class graph as input. Consists of a `graph convolution neural network
    encoder`__, followed by a `graph max pooling layer`__ and a
    multilayer perceptron.

    __ https://arxiv.org/abs/1609.02907
    __ https://pubs.acs.org/doi/full/10.1021/acscentsci.6b00367

    Args:
        params (dict): dictionary of parameters describing the model
            architecture.
    """

    def __init__(self, params):
        super(Graph2Label, self).__init__(params)
        self.encoder = self.params['encoder']
        self.encoder_params = self.params['encoder_params']
        self.Encoder = self.encoder(self.encoder_params, self.use_cuda)
        self.mlp = self.params['mlp']
        self.mlp_params = self.params['mlp_params']
        self.MLP = self.mlp(self.mlp_params)

    def forward(self, inp, eval=False):
        # inp is the (node_feature_matrix, adj_matrix) tuple produced by
        # cast_inputs below
        if eval:
            self.eval()
        else:
            self.train()
        output = self.Encoder(inp)
        output = self.MLP(output)
        return output

    @staticmethod
    def cast_inputs(sample, task, use_cuda, for_prediction=False):
        batch_adj = sample['adj_matrix'].to(torch.float)
        batch_x = sample['node_feature_matrix'].to(torch.float)
        if for_prediction and "object" in sample.keys():
            batch_object = sample["object"]
        else:
            batch_object = None
        if not for_prediction and 'labels' in sample.keys():
            batch_labels = sample['labels']
            if task == 'classification':
                batch_labels = batch_labels.to(torch.long)
            else:
                batch_labels = batch_labels.to(torch.float)
        else:
            batch_labels = None
        if use_cuda:
            batch_x = batch_x.to(device='cuda')
            batch_adj = batch_adj.to(device='cuda')
            if batch_labels is not None:
                batch_labels = batch_labels.to(device='cuda')
        batch_inp = (batch_x, batch_adj)
        if batch_object is not None:
            return batch_inp, batch_object
        else:
            return batch_inp, batch_labels
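
A minimal usage sketch of cast_inputs, assuming the module path
openchem.models.Graph2Label and a toy single-graph batch with made-up
shapes; in practice the sample dictionary is produced by OpenChem's graph
data pipeline rather than built by hand:

    import torch
    from openchem.models.Graph2Label import Graph2Label

    # Hypothetical batch: one 2-node graph with 4 node features per node
    # and a single regression target.
    sample = {
        'adj_matrix': torch.tensor([[[0., 1.], [1., 0.]]]),
        'node_feature_matrix': torch.rand(1, 2, 4),
        'labels': torch.tensor([[1.23]]),
    }

    batch_inp, batch_labels = Graph2Label.cast_inputs(
        sample, task='regression', use_cuda=False)
    # batch_inp is the (node_features, adjacency) tuple consumed by
    # forward(), e.g.: predictions = model(batch_inp, eval=True)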