Source code for modules.embeddings.openchem_embedding

import torch.nn as nn

from openchem.utils.utils import check_params


[docs]class OpenChemEmbedding(nn.Module): def __init__(self, params): super(OpenChemEmbedding, self).__init__() check_params(params, self.get_required_params(), self.get_optional_params()) self.params = params self.num_embeddings = self.params['num_embeddings'] if 'padding_idx' in params.keys(): self.padding_idx = self.params['padding_idx'] else: self.padding_idx = None
[docs] def forward(self, inp): raise NotImplementedError
[docs] @staticmethod def get_required_params(): return { 'num_embeddings': int, }
[docs] @staticmethod def get_optional_params(): return {'padding_idx': int}