Source code for openchem.utils.graph

from rdkit import Chem
import numpy as np


class Attribute:
    """Description of one node or edge attribute and how it is encoded.

    Args:
        attr_type: Either ``"node"`` or ``"edge"``.
        name: Name of the attribute; used as the key when looking up the
            attribute value in a node's or edge's ``attributes_dict``.
        one_hot: If true, every allowed value is mapped to a one-hot vector.
        values: Sequence of allowed attribute values. Required when
            ``one_hot`` is true.

    Attributes:
        attr_type: The attribute type passed in.
        name: The attribute name passed in.
        one_hot: The ``one_hot`` flag passed in.
        attr_values: The ``values`` sequence, or ``None`` if none was given.
        n_values: ``len(values)``, or ``0`` if no values were given.
        one_hot_dict: Only present when ``one_hot`` is true. Maps each
            allowed value to a float array of length ``n_values`` holding a
            single ``1`` at the position of that value.

    Raises:
        ValueError: If ``attr_type`` is not ``"node"`` or ``"edge"``, or if
            ``one_hot`` is true but ``values`` is ``None``.
    """

    def __init__(self, attr_type, name, one_hot=True, values=None):
        if attr_type not in ('node', 'edge'):
            raise ValueError('Invalid value for attribute type: must be "node" '
                             'or "edge"')
        # A one-hot encoding cannot be built without the list of allowed
        # values; fail with a clear message instead of an AttributeError.
        if one_hot and values is None:
            raise ValueError('One-hot attribute "{}" requires a sequence of '
                             'allowed values'.format(name))

        self.attr_type = attr_type
        self.name = name
        self.one_hot = one_hot

        # Always define both fields so later attribute access never fails.
        self.attr_values = values
        self.n_values = len(values) if values is not None else 0

        if self.one_hot:
            self.one_hot_dict = {}
            for position, value in enumerate(self.attr_values):
                # Each value gets its own independent vector.
                encoded = np.zeros(self.n_values)
                encoded[position] = 1
                self.one_hot_dict[value] = encoded
class Node:
    """A single atom of a molecule, viewed as a graph node.

    Args:
        idx: Index of the atom inside ``rdmol``.
        rdmol: Molecule object the atom is read from; it must provide
            ``GetAtoms()`` and, when ``has_3D`` is true, ``GetConformer()``.
        get_atom_attributes: Callable that receives the atom object and
            returns its attributes; the result is stored unchanged in
            ``attributes_dict``.
        has_3D: If true, the atom position is read from the molecule's
            conformer and stored in ``pos_x``, ``pos_y`` and ``pos_z``.
            These three fields do not exist otherwise.
    """

    def __init__(self, idx, rdmol, get_atom_attributes, has_3D=False):
        atom = rdmol.GetAtoms()[idx]
        self.node_idx = idx
        self.atom_type = atom.GetAtomicNum()
        self.attributes_dict = get_atom_attributes(atom)

        # Coordinates are optional; nothing more to do for 2D-only input.
        if not has_3D:
            return

        coords = rdmol.GetConformer().GetAtomPosition(idx)
        self.pos_x, self.pos_y, self.pos_z = coords.x, coords.y, coords.z
class Edge:
    """A single bond of a molecule, viewed as a graph edge.

    Args:
        rdbond: Bond object; it must provide ``GetBeginAtomIdx()`` and
            ``GetEndAtomIdx()``.
        get_bond_attributes: Optional callable that receives the bond object
            and returns its attributes. When it is ``None`` the edge has no
            ``attributes_dict`` field at all.
    """

    def __init__(self, rdbond, get_bond_attributes=None):
        self.begin_atom_idx = rdbond.GetBeginAtomIdx()
        self.end_atom_idx = rdbond.GetEndAtomIdx()

        # Without an attribute extractor only the endpoints are recorded.
        if get_bond_attributes is None:
            return

        self.attributes_dict = get_bond_attributes(rdbond)
class Graph:
    """Describes an undirected graph class.

    The graph is built from a molecule: atoms become ``Node`` objects and
    bonds become ``Edge`` objects.

    Args:
        mol: SMILES string, or a molecule object when ``from_rdmol`` is true.
        max_size: Side length that all padded matrices are allocated with.
        get_atom_attributes: Callable forwarded to every ``Node``.
        get_bond_attributes: Optional callable forwarded to every ``Edge``.
        kekulize: If true, the molecule is kekulized before it is read.
        addHs: If true, hydrogens are added before the molecule is read.
        has_3D: If true, atom positions are read from the molecule's
            conformer and stored in ``xyz``.
        from_rdmol: If true, ``mol`` is a molecule object rather than SMILES.

    Attributes:
        smiles: The SMILES string of the input molecule.
        num_nodes: Number of atoms.
        num_edges: Number of bonds.
        nodes: List of ``Node`` objects, one per atom, in atom-index order.
        edges: List of ``Edge`` objects, one per bond.
        adj_matrix: ``(max_size, max_size)`` array. The top-left
            ``num_nodes`` square holds ones on the diagonal and at both
            ``[i, j]`` and ``[j, i]`` for every bond; the rest is zero.
        xyz: ``(max_size, 3)`` array of atom positions, zero-padded. Only
            present when ``has_3D`` is true.
        n_attr: Number of attributes on the first edge. Only present when
            ``get_bond_attributes`` was given and there is at least one bond.

    Raises:
        ValueError: If the SMILES string cannot be parsed, or if the molecule
            has more atoms than ``max_size``.
    """

    def __init__(self, mol, max_size, get_atom_attributes, get_bond_attributes=None,
                 kekulize=True, addHs=False, has_3D=False, from_rdmol=False):
        self.kekulize = kekulize
        self.addHs = addHs
        self.has_3D = has_3D

        if from_rdmol:
            self.smiles = Chem.MolToSmiles(mol)
            # Chem.Kekulize below modifies its argument in place, so work on
            # a copy to leave the caller's molecule untouched.
            rdmol = Chem.Mol(mol)
        else:
            self.smiles = mol
            rdmol = Chem.MolFromSmiles(mol)
            # MolFromSmiles signals a parse failure by returning None.
            if rdmol is None:
                raise ValueError('Could not parse SMILES string: '
                                 '{!r}'.format(mol))

        if self.addHs:
            rdmol = Chem.AddHs(rdmol)
        if kekulize:
            Chem.Kekulize(rdmol)

        self.num_nodes = rdmol.GetNumAtoms()
        self.num_edges = rdmol.GetNumBonds()
        # Fail early with a readable message instead of a broadcasting or
        # indexing error while the padded arrays are being filled.
        if self.num_nodes > max_size:
            raise ValueError('Molecule has {} atoms, which exceeds '
                             'max_size={}'.format(self.num_nodes, max_size))

        self.nodes = []
        if has_3D:
            self.xyz = np.zeros((max_size, 3))
        for k in range(self.num_nodes):
            cur_node = Node(k, rdmol, get_atom_attributes, has_3D)
            self.nodes.append(cur_node)
            if has_3D:
                self.xyz[k, 0] = cur_node.pos_x
                self.xyz[k, 1] = cur_node.pos_y
                self.xyz[k, 2] = cur_node.pos_z

        # Identity start: every node is adjacent to itself.
        adj_matrix = np.eye(self.num_nodes)
        self.edges = []
        for bond in rdmol.GetBonds():
            cur_edge = Edge(bond, get_bond_attributes)
            self.edges.append(cur_edge)
            adj_matrix[cur_edge.begin_atom_idx, cur_edge.end_atom_idx] = 1.0
            adj_matrix[cur_edge.end_atom_idx, cur_edge.begin_atom_idx] = 1.0

        self.adj_matrix = np.zeros((max_size, max_size))
        self.adj_matrix[:self.num_nodes, :self.num_nodes] = adj_matrix

        if get_bond_attributes is not None and len(self.edges) > 0:
            self.n_attr = len(self.edges[0].attributes_dict)

    @staticmethod
    def _encode_attributes(attributes_dict, all_atr_dict):
        """Flatten the attribute values of one node or edge into a list."""
        encoded = []
        for attr_name in attributes_dict:
            cur_attr = all_atr_dict[attr_name]
            value = attributes_dict[cur_attr.name]
            if cur_attr.one_hot:
                encoded.extend(cur_attr.one_hot_dict[value])
            else:
                encoded.append(value)
        return encoded

    @staticmethod
    def _expected_feature_length(all_atr_dict, attr_type):
        """Sum the encoded widths of all attributes of the given type."""
        return sum(attr.n_values if attr.one_hot else 1
                   for attr in all_atr_dict.values()
                   if attr.attr_type == attr_type)

    def get_node_attr_adj_matrix(self, attr):
        """Build an adjacency tensor carrying one node attribute.

        Args:
            attr: One-hot attribute description providing ``name``,
                ``n_values`` and ``one_hot_dict``.

        Returns:
            Array of shape ``(num_nodes, num_nodes, attr.n_values)``. Entry
            ``[i, i]`` is the one-hot vector of node ``i``; entry
            ``[begin, end]`` of every edge is the mean of the one-hot vectors
            of its two endpoints. All other entries are zero.
        """
        node_attr_adj_matrix = np.zeros((self.num_nodes, self.num_nodes,
                                         attr.n_values))
        # Direct index -> vector lookup avoids a linear search per edge.
        one_hot_by_idx = {}
        for node in self.nodes:
            one_hot = attr.one_hot_dict[node.attributes_dict[attr.name]]
            node_attr_adj_matrix[node.node_idx, node.node_idx] = one_hot
            one_hot_by_idx.setdefault(node.node_idx, one_hot)

        for edge in self.edges:
            begin = edge.begin_atom_idx
            end = edge.end_atom_idx
            # NOTE(review): only [begin, end] is written, not [end, begin],
            # unlike adj_matrix and the edge attribute matrix. Kept as is to
            # preserve existing output -- confirm whether this is intended.
            node_attr_adj_matrix[begin, end, :] = (
                one_hot_by_idx[begin] + one_hot_by_idx[end]) / 2

        return node_attr_adj_matrix

    def get_edge_attr_adj_matrix(self, all_atr_dict, max_size):
        """Build a symmetric adjacency tensor carrying edge features.

        Args:
            all_atr_dict: Maps attribute name to its attribute description.
            max_size: Side length of the returned tensor.

        Returns:
            Array of shape ``(max_size, max_size, n_features)``. For every
            edge the encoded feature vector is stored at both
            ``[begin, end]`` and ``[end, begin]``. ``n_features`` is the
            feature length of the first edge. For a graph without edges a
            zero tensor is returned whose last dimension is derived from the
            ``"edge"`` attributes in ``all_atr_dict`` (this assumes those are
            exactly the attributes edges carry).
        """
        edge_attr_adj_matrix = None
        for edge in self.edges:
            cur_features = np.array(
                self._encode_attributes(edge.attributes_dict, all_atr_dict))
            if edge_attr_adj_matrix is None:
                # Allocate lazily: the width is known only after encoding.
                edge_attr_adj_matrix = np.zeros(
                    (max_size, max_size, len(cur_features)))
            edge_attr_adj_matrix[edge.begin_atom_idx, edge.end_atom_idx, :] = cur_features
            edge_attr_adj_matrix[edge.end_atom_idx, edge.begin_atom_idx, :] = cur_features

        if edge_attr_adj_matrix is None:
            # No edges (e.g. a single-atom molecule): return zeros instead
            # of failing on an unassigned local variable.
            width = self._expected_feature_length(all_atr_dict, 'edge')
            edge_attr_adj_matrix = np.zeros((max_size, max_size, width))

        return edge_attr_adj_matrix

    def get_node_feature_matrix(self, all_atr_dict, max_size):
        """Build the zero-padded node feature matrix.

        Args:
            all_atr_dict: Maps attribute name to its attribute description.
            max_size: Number of rows of the returned matrix.

        Returns:
            Array of shape ``(max_size, n_features)`` whose first rows hold
            the encoded features of the nodes, in node order; remaining rows
            are zero. For a graph without nodes a zero matrix is returned
            whose width is derived from the ``"node"`` attributes in
            ``all_atr_dict`` (this assumes those are exactly the attributes
            nodes carry).
        """
        features = [self._encode_attributes(node.attributes_dict, all_atr_dict)
                    for node in self.nodes]
        if not features:
            # An empty list has no second dimension to read the width from.
            width = self._expected_feature_length(all_atr_dict, 'node')
            return np.zeros((max_size, width))

        features = np.array(features)
        padded_features = np.zeros((max_size, features.shape[1]))
        padded_features[:features.shape[0], :features.shape[1]] = features
        return padded_features

    def xyz_to_zmat(self):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def zmat_to_xyz(self):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()