import torch
from torch import nn
from openchem.utils.utils import check_params
class OpenChemEncoder(nn.Module):
    """Base class for OpenChem encoder (embedding) modules."""

    def __init__(self, params, use_cuda=None):
        super(OpenChemEncoder, self).__init__()
        # Validate the params dict against the required and optional parameter specs.
        check_params(params, self.get_required_params(), self.get_optional_params())
        self.params = params
        if use_cuda is None:
            use_cuda = torch.cuda.is_available()
        self.use_cuda = use_cuda
        self.input_size = self.params['input_size']
        self.encoder_dim = self.params['encoder_dim']
    @staticmethod
    def get_required_params():
        return {'input_size': int, 'encoder_dim': int}

    @staticmethod
    def get_optional_params():
        return {}

    def forward(self, inp):
        raise NotImplementedError
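

# Illustrative sketch only, not part of OpenChem: a hypothetical subclass showing
# how a concrete encoder builds on OpenChemEncoder -- it reuses the validated
# 'input_size' and 'encoder_dim' params and implements forward(). The class name,
# the single nn.Linear layer, and the ReLU are assumptions for demonstration.
class _ExampleLinearEncoder(OpenChemEncoder):
    """Hypothetical encoder that projects inputs with one linear layer."""

    def __init__(self, params, use_cuda=None):
        super(_ExampleLinearEncoder, self).__init__(params, use_cuda)
        self.layer = nn.Linear(self.input_size, self.encoder_dim)

    def forward(self, inp):
        # inp is assumed to be a float tensor of shape (batch_size, input_size)
        return torch.relu(self.layer(inp))


if __name__ == '__main__':
    # Quick sanity check of the hypothetical subclass.
    params = {'input_size': 128, 'encoder_dim': 64}
    encoder = _ExampleLinearEncoder(params, use_cuda=False)
    out = encoder(torch.randn(8, 128))
    print(out.shape)  # expected: torch.Size([8, 64])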