Source code for modules.embeddings.basic_embedding

from openchem.modules.embeddings.openchem_embedding import OpenChemEmbedding

from torch import nn


class Embedding(OpenChemEmbedding):
    """Token embedding layer that wraps ``torch.nn.Embedding``.

    The only parameter this class reads itself is ``params['embedding_dim']``.
    ``self.params``, ``self.num_embeddings`` and ``self.padding_idx`` are used
    here but not assigned in this class; presumably they are set by
    ``OpenChemEmbedding.__init__`` -- confirm against the base class.
    """

    def __init__(self, params):
        """Build the underlying ``nn.Embedding`` lookup table.

        Args:
            params: Configuration object forwarded unchanged to the base
                class constructor. After that call, ``self.params`` must
                contain the key ``'embedding_dim'``.
        """
        super(Embedding, self).__init__(params)
        self.embedding_dim = self.params['embedding_dim']
        self.embedding = nn.Embedding(
            num_embeddings=self.num_embeddings,
            embedding_dim=self.embedding_dim,
            padding_idx=self.padding_idx,
        )

    def forward(self, inp):
        """Return the result of applying ``self.embedding`` to ``inp``.

        Args:
            inp: Input passed directly to the wrapped ``nn.Embedding``
                (presumably an integer index tensor -- not checked here).

        Returns:
            The output of ``self.embedding(inp)``.
        """
        embedded = self.embedding(inp)
        return embedded

    @staticmethod
    def get_required_params():
        """Return the required configuration keys mapped to their types."""
        return {
            'embedding_dim': int,
        }