Source code for optimizer.openchem_optimizer

# Modified from
# github.com/pytorch/fairseq/blob/master/fairseq/optim/fairseq_optimizer.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.optim


[docs]class OpenChemOptimizer(object): def __init__(self, params, model_params): self.params = params[1] self._optimizer = params[0](model_params, **self.params) @property def optimizer(self): if not isinstance(self._optimizer, torch.optim.Optimizer): raise ValueError('_optimizer must be an instance of ' 'torch.optim.Optimizer') return self._optimizer @property def param_groups(self): return self.optimizer.param_groups
[docs] def get_lr(self): """Return the current learning rate.""" return self.optimizer.param_groups[0]['lr']
[docs] def set_lr(self, lr): """Set the learning rate.""" for param_group in self.optimizer.param_groups: param_group['lr'] = lr
[docs] def state_dict(self): """Return the optimizer's state dict.""" return self.optimizer.state_dict()
[docs] def load_state_dict(self, state_dict): """Load an optimizer state dict. In general we should prefer the configuration of the existing optimizer instance (e.g., learning rate) over that found in the state_dict. This allows us to resume training from a checkpoint using a new set of optimizer args. """ self.optimizer.load_state_dict(state_dict) # override learning rate, momentum, etc. with latest values for group in self.optimizer.param_groups: group.update(self.params)
[docs] def step(self, closure=None): """Performs a single optimization step.""" return self.optimizer.step(closure)
[docs] def zero_grad(self): """Clears the gradients of all optimized parameters.""" return self.optimizer.zero_grad()