├── images ├── gentrl.pdf └── gentrl.png ├── gentrl ├── __init__.py ├── encoder.py ├── dataloader.py ├── tokenizer.py ├── decoder.py ├── gentrl.py └── lp.py ├── setup.py ├── README.md └── examples ├── train_rl.ipynb ├── sampling.ipynb └── pretrain.ipynb /images/gentrl.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/GENTRL/HEAD/images/gentrl.pdf -------------------------------------------------------------------------------- /images/gentrl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/GENTRL/HEAD/images/gentrl.png -------------------------------------------------------------------------------- /gentrl/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import RNNEncoder 2 | from .decoder import DilConvDecoder 3 | from .gentrl import GENTRL 4 | from .dataloader import MolecularDataset 5 | 6 | 7 | __all__ = ['RNNEncoder', 'DilConvDecoder', 'GENTRL', 'MolecularDataset'] 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup( 5 | name='gentrl', 6 | version='0.1', 7 | python_requires='>=3.5.0', 8 | packages=find_packages(), 9 | install_requires=[ 10 | 'numpy>=1.15', 11 | 'pandas>=0.23', 12 | 'scipy>=1.1.0', 13 | 'torch==1.0.1', 14 | 'molsets==0.1.3' 15 | ], 16 | description='Generative Tensorial Reinforcement Learning (GENTRL)', 17 | ) 18 | -------------------------------------------------------------------------------- /gentrl/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from gentrl.tokenizer import encode, get_vocab_size 4 | 5 | 6 | class RNNEncoder(nn.Module): 7 | def __init__(self, hidden_size=256, num_layers=2, latent_size=50, 8 | bidirectional=False): 9 | super(RNNEncoder, self).__init__() 10 | 11 | self.embs = nn.Embedding(get_vocab_size(), hidden_size) 12 | self.rnn = nn.GRU(input_size=hidden_size, 13 | hidden_size=hidden_size, 14 | num_layers=num_layers, 15 | bidirectional=bidirectional) 16 | 17 | self.final_mlp = nn.Sequential( 18 | nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(), 19 | nn.Linear(hidden_size, 2 * latent_size)) 20 | 21 | def encode(self, sm_list): 22 | """ 23 | Maps smiles onto a latent space 24 | """ 25 | 26 | tokens, lens = encode(sm_list) 27 | to_feed = tokens.transpose(1, 0).to(self.embs.weight.device) 28 | 29 | outputs = self.rnn(self.embs(to_feed))[0] 30 | outputs = outputs[lens, torch.arange(len(lens))] 31 | 32 | return self.final_mlp(outputs) 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Tensorial Reinforcement Learning (GENTRL) 2 | Supporting Information for the paper _"[Deep learning enables rapid identification of potent DDR1 kinase inhibitors](https://www.nature.com/articles/s41587-019-0224-x)"_. 3 | 4 | The GENTRL model is a variational autoencoder with a rich prior distribution of the latent space. We used tensor decompositions to encode the relations between molecular structures and their properties and to learn on data with missing values. We train the model in two steps. First, we learn a mapping of a chemical space on the latent manifold by maximizing the evidence lower bound. We then freeze all the parameters except for the learnable prior and explore the chemical space to find molecules with a high reward. 5 | 6 | ![GENTRL](images/gentrl.png) 7 | 8 | 9 | ## Repository 10 | In this repository, we provide an implementation of a GENTRL model with an example trained on a [MOSES](https://github.com/molecularsets/moses) dataset. 11 | 12 | To run the training procedure, 13 | 1. [Install RDKit](https://www.rdkit.org/docs/Install.html) to process molecules 14 | 2. Install GENTRL model: `python setup.py install` 15 | 3. Install MOSES from the [repository](https://github.com/molecularsets/moses) 16 | 4. Run the [pretrain.ipynb](./examples/pretrain.ipynb) to train an autoencoder 17 | 5. Run the [train_rl.ipynb](./examples/train_rl.ipynb) to optimize a reward function 18 | -------------------------------------------------------------------------------- /gentrl/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch.utils.data import Dataset 4 | 5 | import pandas as pd 6 | import numpy as np 7 | 8 | 9 | class MolecularDataset(Dataset): 10 | def __init__(self, sources=[], props=['logIC50', 'BFL', 'pipeline'], 11 | with_missings=False): 12 | self.num_sources = len(sources) 13 | 14 | self.source_smiles = [] 15 | self.source_props = [] 16 | self.source_missings = [] 17 | self.source_probs = [] 18 | 19 | self.with_missings = with_missings 20 | 21 | self.len = 0 22 | for source_descr in sources: 23 | cur_df = pd.read_csv(source_descr['path']) 24 | cur_smiles = list(cur_df[source_descr['smiles']].values) 25 | 26 | cur_props = torch.zeros(len(cur_smiles), len(props)).float() 27 | cur_missings = torch.zeros(len(cur_smiles), len(props)).long() 28 | 29 | for i, prop in enumerate(props): 30 | if prop in source_descr: 31 | if isinstance(source_descr[prop], str): 32 | cur_props[:, i] = torch.from_numpy( 33 | cur_df[source_descr[prop]].values) 34 | else: 35 | cur_props[:, i] = torch.from_numpy( 36 | cur_df[source_descr['smiles']].map( 37 | source_descr[prop]).values) 38 | else: 39 | cur_missings[:, i] = 1 40 | 41 | self.source_smiles.append(cur_smiles) 42 | self.source_props.append(cur_props) 43 | self.source_missings.append(cur_missings) 44 | self.source_probs.append(source_descr['prob']) 45 | 46 | self.len = max(self.len, 47 | int(len(cur_smiles) / source_descr['prob'])) 48 | 49 | self.source_probs = np.array(self.source_probs).astype(np.float) 50 | 51 | self.source_probs /= self.source_probs.sum() 52 | 53 | def __len__(self): 54 | return self.len 55 | 56 | def __getitem__(self, idx): 57 | trial = np.random.random() 58 | 59 | s = 0 60 | for i in range(self.num_sources): 61 | if (trial >= s) and (trial <= s + self.source_probs[i]): 62 | bin_len = len(self.source_smiles[i]) 63 | sm = self.source_smiles[i][idx % bin_len] 64 | 65 | props = self.source_props[i][idx % bin_len] 66 | miss = self.source_missings[i][idx % bin_len] 67 | 68 | if self.with_missings: 69 | return sm, torch.concat([props, miss]) 70 | else: 71 | return sm, props 72 | 73 | s += self.source_probs[i] 74 | -------------------------------------------------------------------------------- /examples/train_rl.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gentrl\n", 10 | "import torch\n", 11 | "torch.cuda.set_device(0)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "enc = gentrl.RNNEncoder(latent_size=50)\n", 21 | "dec = gentrl.DilConvDecoder(latent_input_size=50)\n", 22 | "model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)\n", 23 | "model.cuda();" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "model.load('saved_gentrl/')\n", 33 | "model.cuda();" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "from moses.metrics import mol_passes_filters, QED, SA, logP\n", 43 | "from moses.metrics.utils import get_n_rings, get_mol\n", 44 | "\n", 45 | "from moses.utils import disable_rdkit_log\n", 46 | "disable_rdkit_log()\n", 47 | "\n", 48 | "def get_num_rings_6(mol):\n", 49 | " r = mol.GetRingInfo()\n", 50 | " return len([x for x in r.AtomRings() if len(x) > 6])\n", 51 | "\n", 52 | "\n", 53 | "def penalized_logP(mol_or_smiles, masked=False, default=-5):\n", 54 | " mol = get_mol(mol_or_smiles)\n", 55 | " if mol is None:\n", 56 | " return default\n", 57 | " reward = logP(mol) - SA(mol) - get_num_rings_6(mol)\n", 58 | " if masked and not mol_passes_filters(mol):\n", 59 | " return default\n", 60 | " return reward" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "model.train_as_rl(penalized_logP)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "! mkdir -p saved_gentrl_after_rl" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "model.save('./saved_gentrl_after_rl/')" 88 | ] 89 | } 90 | ], 91 | "metadata": { 92 | "kernelspec": { 93 | "display_name": "Python 3", 94 | "language": "python", 95 | "name": "python3" 96 | }, 97 | "language_info": { 98 | "codemirror_mode": { 99 | "name": "ipython", 100 | "version": 3 101 | }, 102 | "file_extension": ".py", 103 | "mimetype": "text/x-python", 104 | "name": "python", 105 | "nbconvert_exporter": "python", 106 | "pygments_lexer": "ipython3", 107 | "version": "3.7.3" 108 | } 109 | }, 110 | "nbformat": 4, 111 | "nbformat_minor": 2 112 | } 113 | -------------------------------------------------------------------------------- /gentrl/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import re 3 | 4 | 5 | _atoms = ['He', 'Li', 'Be', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'Cl', 'Ar', 6 | 'Ca', 'Ti', 'Cr', 'Fe', 'Ni', 'Cu', 'Ga', 'Ge', 'As', 'Se', 7 | 'Br', 'Kr', 'Rb', 'Sr', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 8 | 'Pd', 'Ag', 'Cd', 'Sb', 'Te', 'Xe', 'Ba', 'La', 'Ce', 'Pr', 9 | 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Er', 'Tm', 'Yb', 10 | 'Lu', 'Hf', 'Ta', 'Re', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 11 | 'Bi', 'At', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'Pu', 'Am', 'Cm', 12 | 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'Lr', 'Rf', 'Db', 'Sg', 'Mt', 13 | 'Ds', 'Rg', 'Fl', 'Mc', 'Lv', 'Ts', 'Og'] 14 | 15 | 16 | def get_tokenizer_re(atoms): 17 | return re.compile('('+'|'.join(atoms)+r'|\%\d\d|.)') 18 | 19 | 20 | _atoms_re = get_tokenizer_re(_atoms) 21 | 22 | 23 | __i2t = { 24 | 0: 'unused', 1: '>', 2: '<', 3: '2', 4: 'F', 5: 'Cl', 6: 'N', 25 | 7: '[', 8: '6', 9: 'O', 10: 'c', 11: ']', 12: '#', 26 | 13: '=', 14: '3', 15: ')', 16: '4', 17: '-', 18: 'n', 27 | 19: 'o', 20: '5', 21: 'H', 22: '(', 23: 'C', 28 | 24: '1', 25: 'S', 26: 's', 27: 'Br' 29 | } 30 | 31 | 32 | __t2i = { 33 | '>': 1, '<': 2, '2': 3, 'F': 4, 'Cl': 5, 'N': 6, '[': 7, '6': 8, 34 | 'O': 9, 'c': 10, ']': 11, '#': 12, '=': 13, '3': 14, ')': 15, 35 | '4': 16, '-': 17, 'n': 18, 'o': 19, '5': 20, 'H': 21, '(': 22, 36 | 'C': 23, '1': 24, 'S': 25, 's': 26, 'Br': 27 37 | } 38 | 39 | 40 | def smiles_tokenizer(line, atoms=None): 41 | """ 42 | Tokenizes SMILES string atom-wise using regular expressions. While this 43 | method is fast, it may lead to some mistakes: Sn may be considered as Tin 44 | or as Sulfur with Nitrogen in aromatic cycle. Because of this, you should 45 | specify a set of two-letter atoms explicitly. 46 | 47 | Parameters: 48 | atoms: set of two-letter atoms for tokenization 49 | """ 50 | if atoms is not None: 51 | reg = get_tokenizer_re(atoms) 52 | else: 53 | reg = _atoms_re 54 | return reg.split(line)[1::2] 55 | 56 | 57 | def encode(sm_list, pad_size=50): 58 | """ 59 | Encoder list of smiles to tensor of tokens 60 | """ 61 | res = [] 62 | lens = [] 63 | for s in sm_list: 64 | tokens = ([1] + [__t2i[tok] 65 | for tok in smiles_tokenizer(s)])[:pad_size - 1] 66 | lens.append(len(tokens)) 67 | tokens += (pad_size - len(tokens)) * [2] 68 | res.append(tokens) 69 | 70 | return torch.tensor(res).long(), lens 71 | 72 | 73 | def decode(tokens_tensor): 74 | """ 75 | Decodes from tensor of tokens to list of smiles 76 | """ 77 | 78 | smiles_res = [] 79 | 80 | for i in range(tokens_tensor.shape[0]): 81 | cur_sm = '' 82 | for t in tokens_tensor[i].detach().cpu().numpy(): 83 | if t == 2: 84 | break 85 | elif t > 2: 86 | cur_sm += __i2t[t] 87 | 88 | smiles_res.append(cur_sm) 89 | 90 | return smiles_res 91 | 92 | 93 | def get_vocab_size(): 94 | return len(__i2t) 95 | -------------------------------------------------------------------------------- /examples/sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gentrl\n", 10 | "import torch\n", 11 | "from rdkit.Chem import Draw\n", 12 | "from moses.metrics import mol_passes_filters, QED, SA, logP\n", 13 | "from moses.metrics.utils import get_n_rings, get_mol\n", 14 | "\n", 15 | "torch.cuda.set_device(0)" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "enc = gentrl.RNNEncoder(latent_size=50)\n", 25 | "dec = gentrl.DilConvDecoder(latent_input_size=50)\n", 26 | "model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)\n", 27 | "model.cuda();" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "model.load('saved_gentrl_after_rl/')\n", 37 | "model.cuda();" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "def get_num_rings_6(mol):\n", 47 | " r = mol.GetRingInfo()\n", 48 | " return len([x for x in r.AtomRings() if len(x) > 6])\n", 49 | "\n", 50 | "\n", 51 | "def penalized_logP(mol_or_smiles, masked=True, default=-5):\n", 52 | " mol = get_mol(mol_or_smiles)\n", 53 | " if mol is None:\n", 54 | " return default\n", 55 | " reward = logP(mol) - SA(mol) - get_num_rings_6(mol)\n", 56 | " if masked and not mol_passes_filters(mol):\n", 57 | " return default\n", 58 | " return reward" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "generated = []\n", 68 | "\n", 69 | "while len(generated) < 1000:\n", 70 | " sampled = model.sample(100)\n", 71 | " sampled_valid = [s for s in sampled if get_mol(s)]\n", 72 | " \n", 73 | " generated += sampled_valid" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "Draw.MolsToGridImage([get_mol(s) for s in sampled_valid], \n", 83 | " legends=[str(penalized_logP(s)) for s in sampled_valid])" 84 | ] 85 | } 86 | ], 87 | "metadata": { 88 | "kernelspec": { 89 | "display_name": "Python 3", 90 | "language": "python", 91 | "name": "python3" 92 | }, 93 | "language_info": { 94 | "codemirror_mode": { 95 | "name": "ipython", 96 | "version": 3 97 | }, 98 | "file_extension": ".py", 99 | "mimetype": "text/x-python", 100 | "name": "python", 101 | "nbconvert_exporter": "python", 102 | "pygments_lexer": "ipython3", 103 | "version": "3.6.7" 104 | } 105 | }, 106 | "nbformat": 4, 107 | "nbformat_minor": 2 108 | } 109 | -------------------------------------------------------------------------------- /examples/pretrain.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gentrl\n", 10 | "import torch\n", 11 | "import pandas as pd\n", 12 | "torch.cuda.set_device(0)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from moses.metrics import mol_passes_filters, QED, SA, logP\n", 22 | "from moses.metrics.utils import get_n_rings, get_mol\n", 23 | "\n", 24 | "\n", 25 | "def get_num_rings_6(mol):\n", 26 | " r = mol.GetRingInfo()\n", 27 | " return len([x for x in r.AtomRings() if len(x) > 6])\n", 28 | "\n", 29 | "\n", 30 | "def penalized_logP(mol_or_smiles, masked=False, default=-5):\n", 31 | " mol = get_mol(mol_or_smiles)\n", 32 | " if mol is None:\n", 33 | " return default\n", 34 | " reward = logP(mol) - SA(mol) - get_num_rings_6(mol)\n", 35 | " if masked and not mol_passes_filters(mol):\n", 36 | " return default\n", 37 | " return reward" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "! wget https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "df = pd.read_csv('dataset_v1.csv')\n", 56 | "df = df[df['SPLIT'] == 'train']\n", 57 | "df['plogP'] = df['SMILES'].apply(penalized_logP)\n", 58 | "df.to_csv('train_plogp_plogpm.csv', index=None)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "enc = gentrl.RNNEncoder(latent_size=50)\n", 68 | "dec = gentrl.DilConvDecoder(latent_input_size=50)\n", 69 | "model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)\n", 70 | "model.cuda();" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "md = gentrl.MolecularDataset(sources=[\n", 80 | " {'path':'train_plogp_plogpm.csv',\n", 81 | " 'smiles': 'SMILES',\n", 82 | " 'prob': 1,\n", 83 | " 'plogP' : 'plogP',\n", 84 | " }], \n", 85 | " props=['plogP'])\n", 86 | "\n", 87 | "from torch.utils.data import DataLoader\n", 88 | "train_loader = DataLoader(md, batch_size=50, shuffle=True, num_workers=1, drop_last=True)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "model.train_as_vaelp(train_loader, lr=1e-4)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "! mkdir -p saved_gentrl" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "model.save('./saved_gentrl/')" 116 | ] 117 | } 118 | ], 119 | "metadata": { 120 | "kernelspec": { 121 | "display_name": "Python 3", 122 | "language": "python", 123 | "name": "python3" 124 | }, 125 | "language_info": { 126 | "codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.6.7" 136 | } 137 | }, 138 | "nbformat": 4, 139 | "nbformat_minor": 2 140 | } 141 | -------------------------------------------------------------------------------- /gentrl/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from gentrl.tokenizer import get_vocab_size, encode, decode 5 | 6 | 7 | class DilConvDecoder(nn.Module): 8 | ''' 9 | Class for autoregressive model that works in WaveNet manner. 10 | It make conditinioning on previosly sampled tokens by running 11 | stack of dilation convolution on them. 12 | ''' 13 | def __init__(self, latent_input_size, token_weights=None, 14 | split_len=50, num_dilated_layers=7, num_channels=128): 15 | r''' 16 | Args: 17 | latent_input_size: int, size of latent code used in VAE-like models 18 | token_weights: Tensor of shape [num_tokens], where i-th element 19 | contains the weight of i-th token. If None, then all 20 | tokens has the same weight. 21 | split_len: int, maximum length of token sequence 22 | num_dilated_layers: int, how much dilated layers is in stack 23 | num_channels: int, num channels in convolutional layers 24 | ''' 25 | super(DilConvDecoder, self).__init__() 26 | self.vocab_size = get_vocab_size() 27 | self.latent_input_size = latent_input_size 28 | self.split_len = split_len 29 | self.num_dilated_layers = num_dilated_layers 30 | self.num_channels = num_channels 31 | self.token_weights = token_weights 32 | self.eos = 2 33 | 34 | cur_dil = 1 35 | self.dil_conv_layers = [] 36 | for i in range(num_dilated_layers): 37 | self.dil_conv_layers.append( 38 | DilConv1dWithGLU(num_channels, cur_dil)) 39 | cur_dil *= 2 40 | 41 | self.latent_fc = nn.Linear(latent_input_size, num_channels) 42 | self.input_embeddings = nn.Embedding(self.vocab_size, 43 | num_channels) 44 | self.logits_1x1_layer = nn.Conv1d(num_channels, 45 | self.vocab_size, 46 | kernel_size=1) 47 | 48 | cur_parameters = [] 49 | for layer in [self.input_embeddings, self.logits_1x1_layer, 50 | self.latent_fc] + self.dil_conv_layers: 51 | cur_parameters += list(layer.parameters()) 52 | 53 | self.parameters = nn.ParameterList(cur_parameters) 54 | 55 | def get_logits(self, input_tensor, z, sampling=False): 56 | ''' 57 | Computing logits for each token input_tensor by given latent code 58 | 59 | [WORKS ONLY IN TEACHER-FORCING MODE] 60 | 61 | Args: 62 | input_tensor: Tensor of shape [batch_size, max_seq_len] 63 | z: Tensor of shape [batch_size, lat_code_size] 64 | ''' 65 | 66 | input_embedded = self.input_embeddings(input_tensor).transpose(1, 2) 67 | latent_embedded = self.latent_fc(z) 68 | 69 | x = input_embedded + latent_embedded.unsqueeze(-1) 70 | 71 | for dil_conv_layer in self.dil_conv_layers: 72 | x = dil_conv_layer(x, sampling=sampling) 73 | 74 | x = self.logits_1x1_layer(x).transpose(1, 2) 75 | 76 | return F.log_softmax(x, dim=-1) 77 | 78 | def get_log_prob(self, x, z): 79 | ''' 80 | Getting logits of SMILES sequences 81 | Args: 82 | x: tensor of shape [batch_size, seq_size] with tokens 83 | z: tensor of shape [batch_size, lat_size] with latents 84 | Returns: 85 | logits: tensor of shape [batch_size, seq_size] 86 | ''' 87 | seq_logits = torch.gather(self.get_logits(x, z)[:, :-1, :], 88 | 2, x[:, 1:].long().unsqueeze(-1)) 89 | 90 | return seq_logits[:, :, 0] 91 | 92 | def forward(self, x, z): 93 | ''' 94 | Getting logits of SMILES sequences 95 | Args: 96 | x: tensor of shape [batch_size, seq_size] with tokens 97 | z: tensor of shape [batch_size, lat_size] with latents 98 | Returns: 99 | logits: tensor of shape [batch_size, seq_size] 100 | None: since dilconv decoder doesn't have hidden state unlike RNN 101 | ''' 102 | return self.get_log_prob(x, z), None 103 | 104 | def weighted_forward(self, sm_list, z): 105 | ''' 106 | ''' 107 | x = encode(sm_list)[0].to( 108 | self.input_embeddings.weight.data.device 109 | ) 110 | 111 | seq_logits = self.get_log_prob(x, z) 112 | 113 | if self.token_weights is not None: 114 | w = self.token_weights[x[:, 1:].long().contiguous().view(-1)] 115 | w = w.view_as(seq_logits) 116 | seq_logits = seq_logits * w 117 | 118 | non_eof = (x != self.eos)[:, :-1].float() 119 | ans_logits = (seq_logits * non_eof).sum(dim=-1) 120 | ans_logits /= non_eof.sum(dim=-1) 121 | 122 | return ans_logits 123 | 124 | def sample(self, max_len, latents, argmax=True): 125 | ''' Sample SMILES for given latents 126 | 127 | Args: 128 | latents: tensor of shape [n_batch, n_features] 129 | 130 | Returns: 131 | logits: tensor of shape [batch_size, seq_size], logits of tokens 132 | tokens: tensor of shape [batch_size, seq_size], sampled token 133 | None: since dilconv decoder doesn't have hidden state unlike RNN 134 | 135 | ''' 136 | 137 | # clearing buffers 138 | for dil_conv_layer in self.dil_conv_layers: 139 | dil_conv_layer.clear_buffer() 140 | 141 | num_objects = latents.shape[0] 142 | 143 | ans_seqs = [[1] for _ in range(num_objects)] 144 | ans_logits = [] 145 | 146 | cur_tokens = torch.tensor(ans_seqs, device=latents.device).long() 147 | 148 | for s in range(max_len): 149 | logits = self.get_logits(cur_tokens, latents, sampling=True) 150 | logits = logits.detach() 151 | logits = torch.log_softmax(logits[:, 0, :], dim=-1) 152 | ans_logits.append(logits.unsqueeze(0)) 153 | 154 | if argmax: 155 | cur_tokens = torch.max(logits, dim=-1)[1].unsqueeze(-1) 156 | else: 157 | cur_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1) 158 | 159 | det_tokens = cur_tokens.cpu().detach().tolist() 160 | ans_seqs = [a + b for a, b in zip(ans_seqs, det_tokens)] 161 | 162 | # clearing buffers 163 | for dil_conv_layer in self.dil_conv_layers: 164 | dil_conv_layer.clear_buffer() 165 | 166 | ans_logits = torch.cat(ans_logits, dim=0) 167 | ans_seqs = torch.tensor(ans_seqs)[:, 1:] 168 | return decode(ans_seqs) 169 | 170 | 171 | class DilConv1dWithGLU(nn.Module): 172 | def __init__(self, num_channels, dilation, lenght=100, 173 | kernel_size=2, activation=F.leaky_relu, 174 | residual_connection=True, dropout=0.2): 175 | 176 | super(DilConv1dWithGLU, self).__init__() 177 | 178 | self.dilation = dilation 179 | 180 | self.start_ln = nn.LayerNorm(num_channels) 181 | self.start_conv1x1 = nn.Conv1d(num_channels, num_channels, 182 | kernel_size=1) 183 | 184 | self.dilconv_ln = nn.LayerNorm(num_channels) 185 | self.dilated_conv = nn.Conv1d(num_channels, num_channels, 186 | dilation=dilation, 187 | kernel_size=kernel_size, 188 | padding=dilation) 189 | 190 | self.gate_ln = nn.LayerNorm(num_channels) 191 | self.end_conv1x1 = nn.Conv1d(num_channels, num_channels, 192 | kernel_size=1) 193 | self.gated_conv1x1 = nn.Conv1d(num_channels, num_channels, 194 | kernel_size=1) 195 | 196 | self.activation = activation 197 | 198 | self.buffer = None 199 | 200 | self.residual_connection = residual_connection 201 | 202 | def clear_buffer(self): 203 | self.buffer = None 204 | 205 | def forward(self, x_inp, sampling=False): 206 | # applying 1x1 convolution 207 | x = self.start_ln(x_inp.transpose(1, 2)).transpose(1, 2) 208 | x = self.activation(x) 209 | x = self.start_conv1x1(x) 210 | 211 | # applying dilated convolution 212 | # if in sampling mode 213 | x = self.dilconv_ln(x.transpose(1, 2)).transpose(1, 2) 214 | x = self.activation(x) 215 | if sampling: 216 | if self.buffer is None: 217 | self.buffer = x 218 | else: 219 | pre_buffer = torch.cat([self.buffer, x], dim=2) 220 | self.buffer = pre_buffer[:, :, -(self.dilation + 1):] 221 | 222 | if self.buffer.shape[2] == self.dilation + 1: 223 | x = self.buffer 224 | else: 225 | x = torch.cat([torch.zeros(self.buffer.shape[0], 226 | self.buffer.shape[1], 227 | self.dilation + 1 228 | - self.buffer.shape[2], 229 | device=x_inp.device), self.buffer], 230 | dim=2) 231 | 232 | x = self.dilated_conv(x)[:, :, self.dilation:] 233 | x = x[:, :, :x_inp.shape[-1]] 234 | else: 235 | x = self.dilated_conv(x)[:, :, :x_inp.shape[-1]] 236 | 237 | # applying gated linear unit 238 | x = self.gate_ln(x.transpose(1, 2)).transpose(1, 2) 239 | x = self.activation(x) 240 | x = self.end_conv1x1(x) * torch.sigmoid(self.gated_conv1x1(x)) 241 | 242 | # if residual connection 243 | if self.residual_connection: 244 | x = x + x_inp 245 | 246 | return x 247 | -------------------------------------------------------------------------------- /gentrl/gentrl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from math import pi, log 5 | from gentrl.lp import LP 6 | import pickle 7 | 8 | from moses.metrics.utils import get_mol 9 | 10 | 11 | class TrainStats(): 12 | def __init__(self): 13 | self.stats = dict() 14 | 15 | def update(self, delta): 16 | for key in delta.keys(): 17 | if key in self.stats.keys(): 18 | self.stats[key].append(delta[key]) 19 | else: 20 | self.stats[key] = [delta[key]] 21 | 22 | def reset(self): 23 | for key in self.stats.keys(): 24 | self.stats[key] = [] 25 | 26 | def print(self): 27 | for key in self.stats.keys(): 28 | print(str(key) + ": {:4.4};".format( 29 | sum(self.stats[key]) / len(self.stats[key]) 30 | ), end='') 31 | 32 | print() 33 | 34 | 35 | class GENTRL(nn.Module): 36 | ''' 37 | GENTRL model 38 | ''' 39 | def __init__(self, enc, dec, latent_descr, feature_descr, tt_int=40, 40 | tt_type='usual', beta=0.01, gamma=0.1): 41 | super(GENTRL, self).__init__() 42 | 43 | self.enc = enc 44 | self.dec = dec 45 | 46 | self.num_latent = len(latent_descr) 47 | self.num_features = len(feature_descr) 48 | 49 | self.latent_descr = latent_descr 50 | self.feature_descr = feature_descr 51 | 52 | self.tt_int = tt_int 53 | self.tt_type = tt_type 54 | 55 | self.lp = LP(distr_descr=self.latent_descr + self.feature_descr, 56 | tt_int=self.tt_int, tt_type=self.tt_type) 57 | 58 | self.beta = beta 59 | self.gamma = gamma 60 | 61 | def get_elbo(self, x, y): 62 | means, log_stds = torch.split(self.enc.encode(x), 63 | len(self.latent_descr), dim=1) 64 | latvar_samples = (means + torch.randn_like(log_stds) * 65 | torch.exp(0.5 * log_stds)) 66 | 67 | rec_part = self.dec.weighted_forward(x, latvar_samples).mean() 68 | 69 | normal_distr_hentropies = (log(2 * pi) + 1 + log_stds).sum(dim=1) 70 | 71 | latent_dim = len(self.latent_descr) 72 | condition_dim = len(self.feature_descr) 73 | 74 | zy = torch.cat([latvar_samples, y], dim=1) 75 | log_p_zy = self.lp.log_prob(zy) 76 | 77 | y_to_marg = latent_dim * [True] + condition_dim * [False] 78 | log_p_y = self.lp.log_prob(zy, marg=y_to_marg) 79 | 80 | z_to_marg = latent_dim * [False] + condition_dim * [True] 81 | log_p_z = self.lp.log_prob(zy, marg=z_to_marg) 82 | log_p_z_by_y = log_p_zy - log_p_y 83 | log_p_y_by_z = log_p_zy - log_p_z 84 | 85 | kldiv_part = (-normal_distr_hentropies - log_p_zy).mean() 86 | 87 | elbo = rec_part - self.beta * kldiv_part 88 | elbo = elbo + self.gamma * log_p_y_by_z.mean() 89 | 90 | return elbo, { 91 | 'loss': -elbo.detach().cpu().numpy(), 92 | 'rec': rec_part.detach().cpu().numpy(), 93 | 'kl': kldiv_part.detach().cpu().numpy(), 94 | 'log_p_y_by_z': log_p_y_by_z.mean().detach().cpu().numpy(), 95 | 'log_p_z_by_y': log_p_z_by_y.mean().detach().cpu().numpy() 96 | } 97 | 98 | def save(self, folder_to_save='./'): 99 | if folder_to_save[-1] != '/': 100 | folder_to_save = folder_to_save + '/' 101 | torch.save(self.enc.state_dict(), folder_to_save + 'enc.model') 102 | torch.save(self.dec.state_dict(), folder_to_save + 'dec.model') 103 | torch.save(self.lp.state_dict(), folder_to_save + 'lp.model') 104 | 105 | pickle.dump(self.lp.order, open(folder_to_save + 'order.pkl', 'wb')) 106 | 107 | def load(self, folder_to_load='./'): 108 | if folder_to_load[-1] != '/': 109 | folder_to_load = folder_to_load + '/' 110 | 111 | order = pickle.load(open(folder_to_load + 'order.pkl', 'rb')) 112 | self.lp = LP(distr_descr=self.latent_descr + self.feature_descr, 113 | tt_int=self.tt_int, tt_type=self.tt_type, 114 | order=order) 115 | 116 | self.enc.load_state_dict(torch.load(folder_to_load + 'enc.model')) 117 | self.dec.load_state_dict(torch.load(folder_to_load + 'dec.model')) 118 | self.lp.load_state_dict(torch.load(folder_to_load + 'lp.model')) 119 | 120 | def train_as_vaelp(self, train_loader, num_epochs=10, 121 | verbose_step=50, lr=1e-3): 122 | optimizer = optim.Adam(self.parameters(), lr=lr) 123 | 124 | global_stats = TrainStats() 125 | local_stats = TrainStats() 126 | 127 | epoch_i = 0 128 | to_reinit = False 129 | buf = None 130 | while epoch_i < num_epochs: 131 | i = 0 132 | if verbose_step: 133 | print("Epoch", epoch_i, ":") 134 | 135 | if epoch_i in [0, 1, 5]: 136 | to_reinit = True 137 | 138 | epoch_i += 1 139 | 140 | for x_batch, y_batch in train_loader: 141 | if verbose_step: 142 | print("!", end='') 143 | 144 | i += 1 145 | 146 | y_batch = y_batch.float().to(self.lp.tt_cores[0].device) 147 | if len(y_batch.shape) == 1: 148 | y_batch = y_batch.view(-1, 1).contiguous() 149 | 150 | if to_reinit: 151 | if (buf is None) or (buf.shape[0] < 5000): 152 | enc_out = self.enc.encode(x_batch) 153 | means, log_stds = torch.split(enc_out, 154 | len(self.latent_descr), 155 | dim=1) 156 | z_batch = (means + torch.randn_like(log_stds) * 157 | torch.exp(0.5 * log_stds)) 158 | cur_batch = torch.cat([z_batch, y_batch], dim=1) 159 | if buf is None: 160 | buf = cur_batch 161 | else: 162 | buf = torch.cat([buf, cur_batch]) 163 | else: 164 | descr = len(self.latent_descr) * [0] 165 | descr += len(self.feature_descr) * [1] 166 | self.lp.reinit_from_data(buf, descr) 167 | self.lp.cuda() 168 | buf = None 169 | to_reinit = False 170 | 171 | continue 172 | 173 | elbo, cur_stats = self.get_elbo(x_batch, y_batch) 174 | local_stats.update(cur_stats) 175 | global_stats.update(cur_stats) 176 | 177 | optimizer.zero_grad() 178 | loss = -elbo 179 | loss.backward() 180 | optimizer.step() 181 | 182 | if verbose_step and i % verbose_step == 0: 183 | local_stats.print() 184 | local_stats.reset() 185 | i = 0 186 | 187 | epoch_i += 1 188 | if i > 0: 189 | local_stats.print() 190 | local_stats.reset() 191 | 192 | return global_stats 193 | 194 | def train_as_rl(self, 195 | reward_fn, 196 | num_iterations=100000, verbose_step=50, 197 | batch_size=200, 198 | cond_lb=-2, cond_rb=0, 199 | lr_lp=1e-5, lr_dec=1e-6): 200 | optimizer_lp = optim.Adam(self.lp.parameters(), lr=lr_lp) 201 | optimizer_dec = optim.Adam(self.dec.latent_fc.parameters(), lr=lr_dec) 202 | 203 | global_stats = TrainStats() 204 | local_stats = TrainStats() 205 | 206 | cur_iteration = 0 207 | while cur_iteration < num_iterations: 208 | print("!", end='') 209 | 210 | exploit_size = int(batch_size * (1 - 0.3)) 211 | exploit_z = self.lp.sample(exploit_size, 50 * ['s'] + ['m']) 212 | 213 | z_means = exploit_z.mean(dim=0) 214 | z_stds = exploit_z.std(dim=0) 215 | 216 | expl_size = int(batch_size * 0.3) 217 | expl_z = torch.randn(expl_size, exploit_z.shape[1]) 218 | expl_z = 2 * expl_z.to(exploit_z.device) * z_stds[None, :] 219 | expl_z += z_means[None, :] 220 | 221 | z = torch.cat([exploit_z, expl_z]) 222 | smiles = self.dec.sample(50, z, argmax=False) 223 | zc = torch.zeros(z.shape[0], 1).to(z.device) 224 | conc_zy = torch.cat([z, zc], dim=1) 225 | log_probs = self.lp.log_prob(conc_zy, marg=50 * [False] + [True]) 226 | log_probs += self.dec.weighted_forward(smiles, z) 227 | r_list = [reward_fn(s) for s in smiles] 228 | 229 | rewards = torch.tensor(r_list).float().to(exploit_z.device) 230 | rewards_bl = rewards - rewards.mean() 231 | 232 | optimizer_dec.zero_grad() 233 | optimizer_lp.zero_grad() 234 | loss = -(log_probs * rewards_bl).mean() 235 | loss.backward() 236 | optimizer_dec.step() 237 | optimizer_lp.step() 238 | 239 | valid_sm = [s for s in smiles if get_mol(s) is not None] 240 | cur_stats = {'mean_reward': sum(r_list) / len(smiles), 241 | 'valid_perc': len(valid_sm) / len(smiles)} 242 | 243 | local_stats.update(cur_stats) 244 | global_stats.update(cur_stats) 245 | 246 | cur_iteration += 1 247 | 248 | if verbose_step and (cur_iteration + 1) % verbose_step == 0: 249 | local_stats.print() 250 | local_stats.reset() 251 | 252 | return global_stats 253 | 254 | def sample(self, num_samples): 255 | z = self.lp.sample(num_samples, 50 * ['s'] + ['m']) 256 | smiles = self.dec.sample(50, z, argmax=False) 257 | 258 | return smiles 259 | -------------------------------------------------------------------------------- /gentrl/lp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from math import sqrt, pi 4 | from sklearn.mixture import GaussianMixture 5 | import numpy as np 6 | 7 | 8 | class LP(nn.Module): 9 | """ 10 | Class for a Learnable Prior. 11 | """ 12 | 13 | def __init__(self, distr_descr, tt_int=30, distr_init='rand', 14 | tt_type='usual', eps=1e-10, order=None, **kwargs): 15 | """ 16 | Args: 17 | distr_descr: list of n tuples, where n is a number 18 | of variables in lp model distribution, i-th tuple describes 19 | the distribution of the i-th variable; if i-th variable is 20 | continuous the tuple should contain ('c', d, lr, rb), where d 21 | is a number of gaussians to model this variable, lr and rb are 22 | optional elements that descibe lower and upper bounds for 23 | means of the gaussians; if i-th variable is discrete, then it 24 | should be described as ('d', d) where d is number of values 25 | this variable can take 26 | 27 | example: [('c', 10), ('c', 10, -2, 5), ('d', 2), ('c', 20)] 28 | tt_int: int; internal dimension of Tensor-Train decomposition 29 | distr_init: 'rand' or 'uniform'; method to initialize 30 | the distribution 31 | tt_type: 'usual' or 'ring'; type of Tensor Train decomposition 32 | eps: float; small number to avoid devision by zero 33 | order: None or list of int; if None then order of cores corresponds 34 | to distr_descr, otherwise it should be a permutation of 35 | [0, 1, ..., len(distr_descr) - 1] 36 | """ 37 | super(LP, self).__init__() 38 | 39 | self.tt_int = tt_int 40 | self.tt_type = tt_type 41 | 42 | self.distr_descr = distr_descr 43 | self.distr_init = distr_init 44 | 45 | self.tt_cores = [] 46 | self.means = [] 47 | self.log_stds = [] 48 | 49 | self.eps = eps 50 | 51 | if order is None: 52 | self.order = list(range(len(distr_descr))) 53 | else: 54 | self.order = order 55 | 56 | # initialize cores, means, and stds for the distribution 57 | 58 | if self.tt_type not in ['ring', 'usual']: 59 | raise ValueError("Use 'ring' or 'usual' in tt_type, " 60 | "found {}".format(self.tt_type)) 61 | 62 | for var_descr in self.distr_descr: 63 | if distr_init == 'rand': 64 | cur_core = torch.randn(var_descr[1], self.tt_int, self.tt_int) 65 | elif distr_init == 'uniform': 66 | cur_core = torch.ones(var_descr[1], self.tt_int, self.tt_int) 67 | else: 68 | raise ValueError("Use 'rand' or 'uniform' in distr_init, " 69 | "found {}".format(distr_init)) 70 | 71 | cur_core = cur_core / (self.tt_int ** 2 * var_descr[1]) 72 | 73 | self.tt_cores.append(nn.Parameter(cur_core)) 74 | 75 | if var_descr[0] == 'd': # discrete variable 76 | self.means.append(None) 77 | self.log_stds.append(None) 78 | elif var_descr[0] == 'c': # continous variable 79 | if len(var_descr) == 4: 80 | lb = var_descr[2] 81 | rb = var_descr[3] 82 | else: 83 | lb = -1 84 | rb = 1 85 | 86 | if distr_init == 'rand': 87 | cur_means = torch.rand(var_descr[1]) * (rb - lb) + lb 88 | elif distr_init == 'uniform': 89 | cur_means = (torch.arange(var_descr[1]).float() / 90 | (var_descr[1] - 1)) * (rb - lb) + lb 91 | 92 | cur_log_stds = 2 * torch.log( 93 | torch.ones(var_descr[1]) * (rb - lb) / var_descr[1] 94 | ) 95 | 96 | self.means.append(nn.Parameter(cur_means)) 97 | self.log_stds.append(nn.Parameter(cur_log_stds)) 98 | else: 99 | raise ValueError("Use 'c' or 'd' in distribution desciption, " 100 | "found {}".format(var_descr[1])) 101 | 102 | self._make_model_parameters() 103 | 104 | @staticmethod 105 | def __make_contr_vec(x, var, missed, means, log_stds): 106 | if missed is None: 107 | missed = torch.isnan(x).byte() 108 | x[missed] = 0 109 | if var[0] == 'd': 110 | contr_vect = torch.zeros(x.shape[0], var[1]) 111 | contr_vect[missed] = 1 112 | contr_vect[torch.arange(x.shape[0]), x.long().cpu()] = 1 113 | elif var[0] == 'c': 114 | cur_vals = x[:, None] 115 | cur_stds = torch.exp(log_stds)[None, :] 116 | cur_means = means[None, :] 117 | 118 | contr_vect = (cur_vals - cur_means) / cur_stds 119 | contr_vect = torch.exp(-0.5 * (contr_vect ** 2)) 120 | contr_vect = contr_vect / (sqrt(2 * pi) * cur_stds) 121 | contr_vect = contr_vect + 1e-10 122 | 123 | m = missed.float()[:, None].to(x.device) 124 | contr_vect = contr_vect * (1 - m) + m 125 | 126 | contr_vect = contr_vect.to(x.device) 127 | 128 | return contr_vect 129 | 130 | def log_prob(self, x, marg=None): 131 | ''' 132 | Computes logits for each token input_tensor by given latent code 133 | 134 | Args: 135 | x: tensor of shape [num_objects, num_components]; 136 | missing values encoded as nan 137 | marg: None or list of bools; if None, no variables will 138 | be marginalized, else if i-th value of list is True, 139 | then i-th variable will be marginalized 140 | Returns: 141 | log_probs: tensor of shape [num_objects] 142 | ''' 143 | num_objects = x.shape[0] 144 | 145 | if marg is None: 146 | marg = x.shape[1] * [False] 147 | 148 | perm_marg = [marg[i] for i in self.order] 149 | perm_dist_descr = [self.distr_descr[i] for i in self.order] 150 | perm_x = x[:, self.order] 151 | perm_cores = [self.tt_cores[i] for i in self.order] 152 | perm_means = [self.means[i] for i in self.order] 153 | perm_log_stds = [self.log_stds[i] for i in self.order] 154 | 155 | # compute log probabilities 156 | log_probs = torch.zeros(num_objects).to(x.device) 157 | 158 | if self.tt_type == 'usual': 159 | pref = torch.ones(num_objects, 1, perm_cores[0].shape[1]) 160 | norm_pref = torch.ones(num_objects, 1, perm_cores[0].shape[1]) 161 | elif self.tt_type == 'ring': 162 | pref = torch.eye(perm_cores[0].shape[1]) 163 | pref = pref[None, :, :].repeat(num_objects, 1, 1) 164 | 165 | norm_pref = torch.eye(perm_cores[0].shape[1]) 166 | pref = pref.to(x.device) 167 | norm_pref = norm_pref.to(x.device) 168 | 169 | for i, (core, var) in enumerate(zip(perm_cores, perm_dist_descr)): 170 | core = self._pos_func(core) 171 | 172 | if perm_marg[i]: 173 | cond_core = core.sum(dim=0)[None, :, :] 174 | cond_core = cond_core.repeat(num_objects, 1, 1) 175 | else: 176 | cur_contr_vect = self.__make_contr_vec(perm_x[:, i], 177 | var, None, 178 | perm_means[i], 179 | perm_log_stds[i]) 180 | cond_core = core[None, :, :, :] 181 | cond_core = cond_core * cur_contr_vect[:, :, None, None] 182 | cond_core = cond_core.sum(dim=1) 183 | 184 | norm_core = core.sum(dim=0) 185 | 186 | pref = torch.bmm(pref, cond_core) 187 | norm_pref = norm_pref @ norm_core 188 | 189 | cur_norm_const = torch.sum(norm_pref) + self.eps 190 | pref = pref / cur_norm_const 191 | norm_pref = norm_pref / cur_norm_const 192 | 193 | cur_prob_addition = pref.sum(dim=-1).sum(dim=-1) + self.eps 194 | log_probs = log_probs + torch.log(cur_prob_addition) 195 | 196 | pref = pref / cur_prob_addition[:, None, None] 197 | 198 | if self.tt_type == 'ring': 199 | eye = torch.eye(perm_cores[-1].shape[-1])[None, :, :] 200 | eye = eye.to(x.device) 201 | cur_prob_addition = (pref * eye).sum(dim=-1).sum(dim=-1) 202 | cur_div = (norm_pref * eye).sum(dim=-1).sum(dim=-1) + self.eps 203 | cur_prob_addition = cur_prob_addition / cur_div 204 | log_probs = log_probs + torch.log(cur_prob_addition) 205 | elif self.tt_type == 'usual': 206 | cur_prob_addition = pref.sum(dim=-1).sum(dim=-1) 207 | cur_div = norm_pref.sum(dim=-1).sum(dim=-1) + self.eps 208 | cur_prob_addition = cur_prob_addition / cur_div 209 | log_probs = log_probs + torch.log(cur_prob_addition) 210 | 211 | return log_probs 212 | 213 | def sample(self, num_samples, sample_descr, conds=None): 214 | ''' 215 | Sample from the distribution 216 | 217 | Args: 218 | num_samples: int, number objects to sample 219 | sample_descr: list of chars, containining 220 | 's' if we should sample this variable 221 | 'm' if we should marginalise this variable 222 | 'c' if we should condition on this variable 223 | 224 | example: ['s', 's', 'c', 's', 'm', 's'] 225 | conditions: tensor of shape [num_sampled, total_num_of_variables], 226 | if sample_descr has variables for conditioning, then 227 | condition values should be set by this parameter 228 | Returns: 229 | samples: tensor of shape [num_objects, num_vars_to_sample] 230 | ''' 231 | 232 | perm_dist_descr = [self.distr_descr[i] for i in self.order] 233 | 234 | perm_sample_descr = [sample_descr[i] for i in self.order] 235 | perm_cores = [self.tt_cores[i] for i in self.order] 236 | perm_means = [self.means[i] for i in self.order] 237 | perm_log_stds = [self.log_stds[i] for i in self.order] 238 | 239 | if conds is not None: 240 | perm_conds = conds[:, self.order] 241 | 242 | # computing contraction vectors 243 | contr_vect_list = [] 244 | for i, (action, var) in enumerate( 245 | zip(perm_sample_descr, perm_dist_descr)): 246 | if action == 'c': 247 | contr_vect_list.append(self.__make_contr_vec(perm_conds[:, i], 248 | var, 249 | None, 250 | perm_means[i], 251 | perm_log_stds[i])) 252 | elif action in ['m', 's']: 253 | contr_vect_list.append( 254 | torch.ones(num_samples, var[1]).to(self.tt_cores[0].device) 255 | ) 256 | 257 | # computing suffixes to sample via chainrule 258 | sufxs = [] 259 | if self.tt_type == 'usual': 260 | cur_suf = torch.ones(num_samples, perm_cores[-1].shape[-1], 1) 261 | else: 262 | cur_suf = torch.eye(perm_cores[-1].shape[-1]) 263 | cur_suf = cur_suf[None, :, :].repeat(num_samples, 1, 1) 264 | cur_suf = cur_suf.to(self.tt_cores[0]) 265 | sufxs.append(cur_suf) 266 | 267 | for var_descr, core, contr_vect in zip(perm_dist_descr[::-1], 268 | perm_cores[::-1], 269 | contr_vect_list[::-1]): 270 | core = self._pos_func(core) 271 | 272 | cond_core = (core[None, :, :, :] * contr_vect[:, :, None, None]) 273 | cond_core = cond_core.sum(dim=1) 274 | 275 | cur_suf = torch.bmm(cond_core, cur_suf) 276 | 277 | norm_const = torch.sum(cur_suf + self.eps, dim=-1, keepdim=True) 278 | norm_const = torch.sum(norm_const, dim=-2, keepdim=True) 279 | 280 | cur_suf /= norm_const 281 | 282 | sufxs.append(cur_suf) 283 | sufxs = sufxs[-2::-1] 284 | 285 | # sampling 286 | if self.tt_type == 'usual': 287 | pref = torch.ones(num_samples, 1, perm_cores[0].shape[1]) 288 | else: 289 | pref = torch.eye(perm_cores[0].shape[1])[None, :, :] 290 | pref = pref.repeat(num_samples, 1, 1) 291 | 292 | pref = pref.to(self.tt_cores[0]) 293 | 294 | samples_list = [] 295 | for i, (action, var, core, suf, prev_contr_vect) in enumerate( 296 | zip(perm_sample_descr, 297 | perm_dist_descr, 298 | perm_cores, 299 | sufxs, 300 | contr_vect_list)): 301 | core = self._pos_func(core) 302 | if action == 's': 303 | # compute current mixture/discr dist weights 304 | part_to_contract = torch.bmm(suf, pref).permute(0, 2, 1) 305 | part_to_contract = part_to_contract[:, None, :, :] 306 | weights = part_to_contract * core[None, :, :, :] 307 | weights = weights.sum(dim=-1).sum(dim=-1) + self.eps 308 | weights /= torch.sum(weights, dim=-1, keepdim=True) 309 | 310 | # sample 311 | discr_comp_sample = torch.multinomial(weights, num_samples=1) 312 | discr_comp_sample = discr_comp_sample.view(-1) 313 | 314 | # construct cur_contr_vect 315 | if var[0] == 'd': 316 | cur_samples = discr_comp_sample 317 | elif var[0] == 'c': 318 | cur_means = perm_means[i][discr_comp_sample] 319 | cur_log_stds = perm_log_stds[i][discr_comp_sample] 320 | 321 | cur_samples = cur_means + torch.exp( 322 | cur_log_stds) * torch.randn_like(cur_log_stds) 323 | 324 | samples_list.append(cur_samples) 325 | contr_vect = self.__make_contr_vec(cur_samples, 326 | var, 327 | None, 328 | perm_means[i], 329 | perm_log_stds[i]) 330 | elif action in ['m', 'c']: 331 | samples_list.append(None) 332 | contr_vect = prev_contr_vect 333 | 334 | cond_core = (core[None, :, :, :] * contr_vect[:, :, None, None]) 335 | cond_core = cond_core.sum(dim=1) 336 | 337 | pref = torch.bmm(pref, cond_core) + self.eps 338 | 339 | norm_const = torch.sum(pref + self.eps, dim=-1, keepdim=True) 340 | norm_const = torch.sum(norm_const, dim=-2, keepdim=True) 341 | 342 | pref /= norm_const 343 | 344 | inv_perm_samples_list = len(samples_list) * [None] 345 | for i in range(len(samples_list)): 346 | inv_perm_samples_list[self.order[i]] = samples_list[i] 347 | 348 | return torch.cat( 349 | [s.float()[:, None] for s in inv_perm_samples_list if 350 | s is not None], dim=-1).detach() 351 | 352 | def reinit_from_data(self, data, var_types=None): 353 | """ 354 | Reinitializing Gaussians' parameters to better 355 | cover the latent space 356 | Also resets TT cores 357 | 358 | Args: 359 | data: tensor of shape [num_objects, num_vars], data to 360 | reinitialize the Gaussians 361 | var_types: 362 | """ 363 | new_tt_cores = [] 364 | new_means = [] 365 | new_log_stds = [] 366 | 367 | components = [] 368 | for i, var_descr in enumerate(self.distr_descr): 369 | cur_core = torch.randn(var_descr[1], self.tt_int, self.tt_int) 370 | cur_core = cur_core / (self.tt_int ** 2 * var_descr[1]) 371 | new_tt_cores.append(nn.Parameter(cur_core)) 372 | 373 | if torch.sum(torch.isnan(data[:, i])) == data.shape[0]: 374 | new_means.append(self.means[i]) 375 | new_log_stds.append(self.log_stds[i]) 376 | 377 | components.append(-1 * np.ones(data.shape[0])) 378 | continue 379 | 380 | if var_descr[0] == 'd': 381 | new_means.append(None) 382 | new_log_stds.append(None) 383 | cur_components = data[:, i].cpu().detach().numpy() 384 | cur_components[np.isnan(cur_components)] = -1 385 | 386 | elif var_descr[0] == 'c': 387 | gmm = GaussianMixture(n_components=var_descr[1]) 388 | cur_data = data[:, i].cpu().detach().numpy() 389 | non_missings = np.logical_not(np.isnan(cur_data)) 390 | 391 | cur_components = -1 * np.ones_like(cur_data) 392 | 393 | non_missed_data = cur_data[non_missings].reshape(-1, 1) 394 | gmm.fit(non_missed_data) 395 | cur_gmm_comp = gmm.predict(non_missed_data) 396 | cur_components[non_missings] = cur_gmm_comp 397 | 398 | cur_means = torch.from_numpy(gmm.means_[:, 0]) 399 | cur_means = cur_means.float() 400 | cur_log_stds = torch.from_numpy(gmm.covariances_[:, 0, 0]) 401 | cur_log_stds = torch.log(cur_log_stds.float() + self.eps) / 2 402 | 403 | new_means.append(nn.Parameter(cur_means)) 404 | new_log_stds.append(nn.Parameter(cur_log_stds)) 405 | 406 | components.append(cur_components.astype(np.int)) 407 | 408 | if var_types is not None: 409 | usual_vars_idxs = [i for i in range(len(var_types)) if 410 | var_types[i] == 0] 411 | target_vars_idxs = [i for i in range(len(var_types)) if 412 | var_types[i] == 1] 413 | 414 | scores = np.zeros((len(target_vars_idxs), len(usual_vars_idxs))) 415 | 416 | for i in range(len(target_vars_idxs)): 417 | for j in range(len(usual_vars_idxs)): 418 | tg_n = self.distr_descr[target_vars_idxs[i]][1] 419 | us_n = self.distr_descr[usual_vars_idxs[j]][1] 420 | mx = np.zeros((tg_n, us_n)) 421 | 422 | tg_comp = components[target_vars_idxs[i]] 423 | us_comp = components[usual_vars_idxs[j]] 424 | for x, y in zip(tg_comp, us_comp): 425 | if x != -1 and y != -1: 426 | mx[x, y] += 1 427 | 428 | if mx.sum() == 0: 429 | continue 430 | 431 | mx += 1e-10 432 | 433 | s = mx.sum(axis=0) 434 | mx = mx / s[None, :] 435 | 436 | scores[i, j] = -((np.log(mx) * mx).sum( 437 | axis=0) * s).sum() / s.sum() 438 | 439 | groups = np.argmin(scores, axis=0) 440 | 441 | new_order = [] 442 | 443 | for group_i in range(len(target_vars_idxs)): 444 | g_members = np.where(groups == group_i)[0] 445 | g_members = sorted(g_members, key=lambda s: scores[group_i, s]) 446 | new_group = [target_vars_idxs[group_i]] 447 | 448 | for i, member in enumerate(g_members): 449 | if i % 2 == 0: 450 | new_group = new_group + [usual_vars_idxs[member]] 451 | else: 452 | new_group = [usual_vars_idxs[member]] + new_group 453 | 454 | new_order += new_group 455 | 456 | self.order = new_order 457 | 458 | for i in range(len(self.tt_cores)): 459 | self.tt_cores[i].data = new_tt_cores[i].data 460 | if new_means[i] is not None: 461 | self.means[i].data = new_means[i].data 462 | self.log_stds[i].data = new_log_stds[i].data 463 | 464 | @staticmethod 465 | def _pos_func(x): 466 | return x * x 467 | 468 | def _make_model_parameters(self): 469 | parameters = [] 470 | for mean, log_std in zip(self.means, self.log_stds): 471 | if mean is None: 472 | continue 473 | parameters += [mean, log_std] 474 | 475 | for core in self.tt_cores: 476 | parameters.append(core) 477 | 478 | self.parameters = nn.ParameterList(parameters) 479 | --------------------------------------------------------------------------------