Source code for openchem.utils.utils

from __future__ import print_function

import time
import math
from six import string_types
from collections import defaultdict

import torch
import glob
import os
import six


def move_to_cuda(sample):
    # copy-pasted from
    # https://github.com/pytorch/fairseq/blob/master/fairseq/utils.py
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda()
        elif isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value)
                    for key, value in maybe_tensor.items()}
        elif isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        else:
            return maybe_tensor

    return _move_to_cuda(sample)

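# Usage sketch (not part of the original module; assumes a CUDA device is
# available). move_to_cuda recursively transfers every tensor nested inside
# dicts and lists onto the GPU and leaves non-tensor values untouched:
#
#     >>> batch = {'tokens': torch.zeros(4, 16, dtype=torch.long),
#     ...          'lengths': [16, 12, 9, 4]}
#     >>> cuda_batch = move_to_cuda(batch)
#     >>> cuda_batch['tokens'].is_cuda
#     True
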
def get_latest_checkpoint(path):
    # Return the most recently created file in `path`, or None if `path`
    # is not a directory or is empty.
    if os.path.isdir(path) and os.listdir(path) != []:
        list_of_files = glob.glob(os.path.join(path, '*'))
        latest_file = max(list_of_files, key=os.path.getctime)
        return latest_file
    else:
        return None

def deco_print(line, offset=0, start="*** ", end='\n'):
    # copy-pasted from
    # github.com/NVIDIA/OpenSeq2Seq/blob/master/open_seq2seq/utils/utils.py
    if six.PY2:
        print((start + " " * offset + line).encode('utf-8'), end=end)
    else:
        print(start + " " * offset + line, end=end)

def flatten_dict(dct):
    # copy-pasted from
    # github.com/NVIDIA/OpenSeq2Seq/blob/master/open_seq2seq/utils/utils.py
    flat_dict = {}
    for key, value in dct.items():
        if isinstance(value, int) or isinstance(value, float) or \
                isinstance(value, string_types) or isinstance(value, bool):
            flat_dict.update({key: value})
        elif isinstance(value, dict):
            flat_dict.update({key + '/' + k: v
                              for k, v in flatten_dict(dct[key]).items()})
    return flat_dict

def nest_dict(flat_dict):
    # copy-pasted from
    # github.com/NVIDIA/OpenSeq2Seq/blob/master/open_seq2seq/utils/utils.py
    nst_dict = {}
    for key, value in flat_dict.items():
        nest_keys = key.split('/')
        cur_dict = nst_dict
        for i in range(len(nest_keys) - 1):
            if nest_keys[i] not in cur_dict:
                cur_dict[nest_keys[i]] = {}
            cur_dict = cur_dict[nest_keys[i]]
        cur_dict[nest_keys[-1]] = value
    return nst_dict

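# Usage sketch (illustrative values, not from the original source):
# flatten_dict and nest_dict are inverses of each other for configs whose
# leaves are scalars, joining nested keys with '/':
#
#     >>> cfg = {'optimizer': {'lr': 1e-3, 'name': 'adam'}, 'batch_size': 64}
#     >>> flatten_dict(cfg)
#     {'optimizer/lr': 0.001, 'optimizer/name': 'adam', 'batch_size': 64}
#     >>> nest_dict(flatten_dict(cfg)) == cfg
#     True
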
def nested_update(org_dict, upd_dict):
    # copy-pasted from
    # github.com/NVIDIA/OpenSeq2Seq/blob/master/open_seq2seq/utils/utils.py
    for key, value in upd_dict.items():
        if isinstance(value, dict):
            if key in org_dict:
                if not isinstance(org_dict[key], dict):
                    raise ValueError("Mismatch between org_dict and upd_dict "
                                     "at node {}".format(key))
                nested_update(org_dict[key], value)
            else:
                org_dict[key] = value
        else:
            org_dict[key] = value

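# Usage sketch (illustrative values, not from the original source):
# nested_update merges upd_dict into org_dict in place, recursing into
# sub-dicts instead of replacing them wholesale:
#
#     >>> base = {'model': {'layers': 2, 'dropout': 0.1}}
#     >>> nested_update(base, {'model': {'dropout': 0.5}})
#     >>> base
#     {'model': {'layers': 2, 'dropout': 0.5}}
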
def time_since(since):
    s = time.time() - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def identity(input):
    return input

def calculate_metrics(predicted, ground_truth, metrics):
    return metrics(ground_truth, predicted)

def check_params(config, required_dict, optional_dict):
    if required_dict is None or optional_dict is None:
        return
    for pm, vals in required_dict.items():
        if pm not in config:
            raise ValueError("{} parameter has to be specified".format(pm))
        else:
            if vals == str:
                vals = string_types
            if vals and isinstance(vals, list) and config[pm] not in vals:
                raise ValueError("{} has to be one of {}".format(pm, vals))
            if vals and not isinstance(vals, list) and \
                    not isinstance(config[pm], vals):
                raise ValueError("{} has to be of type {}".format(pm, vals))
    for pm, vals in optional_dict.items():
        if vals == str:
            vals = string_types
        if pm in config:
            if vals and isinstance(vals, list) and config[pm] not in vals:
                raise ValueError("{} has to be one of {}".format(pm, vals))
            if vals and not isinstance(vals, list) and \
                    not isinstance(config[pm], vals):
                raise ValueError("{} has to be of type {}".format(pm, vals))

    # for pm in config:
    #     if pm not in required_dict and pm not in optional_dict:
    #         raise ValueError("Unknown parameter: {}".format(pm))

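# Usage sketch (hypothetical dicts, not from the original source): in both
# required_dict and optional_dict, each parameter name maps either to a list
# of allowed values or to an allowed type:
#
#     >>> required = {'task': ['classification', 'regression'],
#     ...             'batch_size': int}
#     >>> optional = {'logdir': str}
#     >>> check_params({'task': 'regression', 'batch_size': 64},
#     ...              required, optional)          # passes silently
#     >>> check_params({'task': 'ranking', 'batch_size': 64},
#     ...              required, optional)          # raises ValueError
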
def cross_validation_split(data, targets, n_folds=5, split='random',
                           stratified=True, folds=None):
    if split not in ['random', 'fixed']:
        raise ValueError('Invalid value for argument \'split\': '
                         'must be either \'random\' or \'fixed\'.')
    if split == 'fixed' and folds is None:
        raise ValueError('When \'split\' is \'fixed\' '
                         'argument \'folds\' must be provided.')
    if split == 'fixed':
        assert len(targets) == len(folds)
    # Only argument validation is implemented in this version of the source;
    # the splitting itself is not.
    raise NotImplementedError

def make_positions(tensor, padding_idx, left_pad):
    # Copy-pasted from
    # https://github.com/pytorch/fairseq/blob/master/fairseq/utils.py
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored, but
    it is necessary to specify whether padding is added on the left side
    (left_pad=True) or right side (left_pad=False).
    """
    max_pos = padding_idx + 1 + tensor.size(1)
    if not hasattr(make_positions, 'range_buf'):
        make_positions.range_buf = tensor.new()
    make_positions.range_buf = make_positions.range_buf.type_as(tensor)
    if make_positions.range_buf.numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
    mask = tensor.ne(padding_idx)
    positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        positions = positions - mask.size(1) + \
            mask.long().sum(dim=1).unsqueeze(1)
    return tensor.clone().masked_scatter_(mask, positions[mask])
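
# Usage sketch (hypothetical values, not from the original source): with
# padding_idx=0 and right padding, padded entries keep the padding index
# while real tokens receive positions starting at padding_idx + 1:
#
#     >>> tokens = torch.tensor([[5, 6, 7, 0],
#     ...                        [8, 9, 0, 0]])
#     >>> make_positions(tokens, padding_idx=0, left_pad=False)
#     tensor([[1, 2, 3, 0],
#             [1, 2, 0, 0]])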