├── .dockerignore ├── .gitignore ├── Dockerfile ├── README.md ├── bo.ipynb ├── bo ├── gen_latent.py ├── run_bo.py └── run_experiment.sh ├── configs ├── mnist │ ├── N_N.ini │ └── U_U.ini ├── moses │ ├── DD-VAE_gaussian_seed1.ini │ ├── DD-VAE_gaussian_seed2.ini │ ├── DD-VAE_gaussian_seed3.ini │ ├── DD-VAE_triweight_seed1.ini │ ├── DD-VAE_triweight_seed2.ini │ ├── DD-VAE_triweight_seed3.ini │ ├── VAE_gaussian_seed1.ini │ ├── VAE_gaussian_seed2.ini │ ├── VAE_gaussian_seed3.ini │ ├── VAE_triweight_seed1.ini │ ├── VAE_triweight_seed2.ini │ └── VAE_triweight_seed3.ini ├── synthetic │ ├── N_N.ini │ ├── U_T.ini │ └── U_U.ini └── zinc │ ├── dd_vae_gaussian.ini │ ├── dd_vae_tricube.ini │ ├── vae_gaussian.ini │ └── vae_tricube.ini ├── data ├── moses │ ├── test.csv.gz │ ├── test_scaffolds.csv.gz │ ├── test_scaffolds_stats.npz │ ├── test_stats.npz │ └── train.csv.gz ├── synthetic │ └── 2d_map_0.2.csv.gz └── zinc │ ├── 250k_rndm_zinc_drugs_clean.smi.gz │ ├── test.csv.gz │ ├── train.csv.gz │ └── valid.csv.gz ├── dd_vae ├── __init__.py ├── bo │ ├── __init__.py │ ├── gauss.py │ ├── psd_theano.py │ ├── sparse_gp.py │ ├── sparse_gp_theano_internal.py │ └── utils.py ├── proposals.py ├── utils.py ├── vae_base.py ├── vae_mnist.py └── vae_rnn.py ├── illustrations.ipynb ├── images ├── .DS_Store ├── kernels.pdf ├── mnist │ ├── latent_N_N.png │ └── latent_U_U.png ├── moses_FCD.pdf ├── moses_SNN.pdf ├── smoothed_indicator.pdf ├── synthetic │ ├── N_N.png │ ├── U_T.png │ └── U_U.png └── zinc │ ├── DD_VAE_GAUSSIAN_molecule_0.pdf │ ├── DD_VAE_GAUSSIAN_molecule_1.pdf │ ├── DD_VAE_GAUSSIAN_molecule_2.pdf │ ├── DD_VAE_GAUSSIAN_top50_molecules.pdf │ ├── DD_VAE_TRICUBE_molecule_0.pdf │ ├── DD_VAE_TRICUBE_molecule_1.pdf │ ├── DD_VAE_TRICUBE_molecule_2.pdf │ ├── DD_VAE_TRICUBE_top50_molecules.pdf │ ├── VAE_GAUSSIAN_molecule_0.pdf │ ├── VAE_GAUSSIAN_molecule_1.pdf │ ├── VAE_GAUSSIAN_molecule_2.pdf │ ├── VAE_GAUSSIAN_top50_molecules.pdf │ ├── VAE_TRICUBE_molecule_0.pdf │ ├── VAE_TRICUBE_molecule_1.pdf │ 
├── VAE_TRICUBE_molecule_2.pdf │ └── VAE_TRICUBE_top50_molecules.pdf ├── mnist.ipynb ├── moses_plots.ipynb ├── moses_prepare_metrics.ipynb ├── setup.py ├── synthetic.ipynb ├── train.py └── unit_test.py /.dockerignore: -------------------------------------------------------------------------------- 1 | models/ 2 | metrics/ 3 | logs/ 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | data/mnist 107 | models/ 108 | logs/ 109 | metrics/ 110 | bo/results/ 111 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 2 | 3 | RUN mkdir -p /code 4 | 5 | RUN set -ex \ 6 | && apt-get update \ 7 | && apt-get install -y git vim less wget \ 8 | tmux libxrender1 libxext6 9 | 10 | RUN set -ex \ 11 | && wget https://repo.continuum.io/miniconda/Miniconda3-4.7.10-Linux-x86_64.sh \ 12 | && /bin/bash Miniconda3-4.7.10-Linux-x86_64.sh -f -b -p /opt/miniconda 13 | 14 | ENV PATH /opt/miniconda/bin:$PATH 15 | 16 | RUN conda install -y numpy=1.17.2 \ 17 | scipy=1.3.1 \ 18 | scikit-learn=0.20.3 \ 19 | matplotlib=3.1.1 \ 20 | pandas=0.25.1 \ 21 | notebook=6.0.0 \ 22 | networkx=2.3 
\ 23 | ipywidgets=7.5.1 24 | 25 | RUN conda install -y -c pytorch cudatoolkit=9.0 pytorch=1.1.0 torchvision=0.2.1 26 | 27 | RUN conda install -y -c rdkit rdkit=2019.03.4 28 | 29 | RUN pip install Theano==1.0.4 molsets==0.2 tensorboardX==1.9 cairosvg==2.4.2 tqdm==4.42.0 30 | 31 | ADD . /code/dd_vae 32 | 33 | RUN cd /code/dd_vae && python setup.py install 34 | 35 | CMD [ "/bin/bash" ] 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deterministic Decoding for Discrete Data in Variational Autoencoders 2 | 3 | Variational autoencoders are prominent generative models for modeling discrete data. However, with flexible decoders, they tend to ignore the latent codes. In this paper, we study a VAE model with a deterministic decoder (DD-VAE) for sequential data that selects the highest-scoring tokens instead of sampling. Deterministic decoding solely relies on latent codes as the only way to produce diverse objects, which improves the structure of the learned manifold. To implement DD-VAE, we propose a new class of bounded support proposal distributions and derive Kullback-Leibler divergence for Gaussian and uniform priors. We also study a continuous relaxation of deterministic decoding objective function and analyze the relation of reconstruction accuracy and relaxation parameters. We demonstrate the performance of DD-VAE on multiple datasets, including molecular generation and optimization problems. 4 | 5 | For more details, please refer to the [full paper](https://arxiv.org/abs/2003.02174). 6 | 7 | ### Repository 8 | In this repository, we provide all code and data that is necessary to reproduce all the results from the paper. To reproduce the experiments, we recommend using Docker image built using a provided `Dockerfile`: 9 | ```{bash} 10 | nvidia-docker build -t dd_vae . 
11 | nvidia-docker run -it --shm-size 10G --network="host" --name dd_vae -w=/code/dd_vae dd_vae 12 | ``` 13 | All the code will be available inside the `/code/dd_vae` folder. For more details on using Docker, please refer to the [Docker manual](https://docs.docker.com/). 14 | 15 | You can also install `dd_vae` locally by running the `python setup.py install` command. 16 | 17 | ### Reproducing the experiments 18 | You can train any model using the `train.py` script. This script takes only two arguments: `--config` (path to the .ini file that sets up the experiment) and `--device` (PyTorch-style device naming such as `cuda:0`). We provide all configuration files in the `configs/` folder. For each experiment we provide a separate Jupyter Notebook, where you will find further instructions to reproduce the experiments: 19 | * [Synthetic](./synthetic.ipynb) 20 | * [MNIST](./mnist.ipynb) 21 | * [MOSES (metrics)](./moses_prepare_metrics.ipynb), [MOSES (plots)](./moses_plots.ipynb) 22 | * [ZINC](./bo.ipynb) 23 | 24 | ### How to cite 25 | ``` 26 | @InProceedings{pmlr-v108-polykovskiy20a, 27 | title = {Deterministic Decoding for Discrete Data in Variational Autoencoders}, 28 | author = {Polykovskiy, Daniil and Vetrov, Dmitry}, 29 | booktitle = {Proceedings of the Twenty Third International Conference on Artificial Intelligence and Statistics}, 30 | pages = {3046--3056}, 31 | year = {2020}, 32 | editor = {Silvia Chiappa and Roberto Calandra}, 33 | volume = {108}, 34 | series = {Proceedings of Machine Learning Research}, address = {Online}, 35 | month = {26--28 Aug}, 36 | publisher = {PMLR} 37 | } 38 | ``` 39 | -------------------------------------------------------------------------------- /bo/gen_latent.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from functools import partial 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import rdkit 9 | from dd_vae.bo.utils import max_ring_penalty 10 | from 
dd_vae.utils import collate, StringDataset, batch_to_device 11 | from dd_vae.vae_rnn import VAE_RNN 12 | from moses.metrics import SA 13 | from rdkit import Chem 14 | from rdkit.Chem import Descriptors 15 | from torch.utils.data import DataLoader 16 | from tqdm.auto import tqdm 17 | rdkit.rdBase.DisableLog('rdApp.*') 18 | 19 | 20 | def load_csv(path): 21 | if path.endswith('.gz'): 22 | df = pd.read_csv(path, compression='gzip', 23 | dtype='str', header=None) 24 | return list(df[0].values) 25 | return [x.strip() for x in open(path)] 26 | 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--data", type=str, required=True) 30 | parser.add_argument("--model", type=str, required=True) 31 | parser.add_argument("--device", type=str, default="cpu") 32 | parser.add_argument("--save_dir", type=str, required=True) 33 | args = parser.parse_args(sys.argv[1:]) 34 | 35 | model = VAE_RNN.load(args.model) 36 | model = model.to(args.device) 37 | 38 | smiles = load_csv(args.data) 39 | 40 | logP_values = [] 41 | latent_points = [] 42 | cycle_scores = [] 43 | SA_scores = [] 44 | 45 | print("Preparing dataset...") 46 | collate_pad = partial(collate, pad=model.vocab.pad, return_data=True) 47 | dataset = StringDataset(model.vocab, smiles) 48 | data_loader = DataLoader(dataset, collate_fn=collate_pad, 49 | batch_size=512, shuffle=False) 50 | print("Getting latent codes...") 51 | for batch in tqdm(data_loader): 52 | z = model.encode(batch_to_device(batch[:-1], args.device)) 53 | mu, _ = model.get_mu_std(z) 54 | latent_points.append(mu.detach().cpu().numpy()) 55 | romol = [Chem.MolFromSmiles(x.strip()) for x in batch[-1]] 56 | logP_values.extend([Descriptors.MolLogP(m) for m in romol]) 57 | SA_scores.extend([-SA(m) for m in romol]) 58 | cycle_scores.extend([max_ring_penalty(m) for m in romol]) 59 | 60 | SA_scores = np.array(SA_scores) 61 | logP_values = np.array(logP_values) 62 | cycle_scores = np.array(cycle_scores) 63 | 64 | SA_scores_normalized = (SA_scores - 
SA_scores.mean()) / SA_scores.std() 65 | logP_values_normalized = (logP_values - logP_values.mean()) / logP_values.std() 66 | cycle_scores_normalized = ( 67 | cycle_scores - cycle_scores.mean()) / cycle_scores.std() 68 | 69 | latent_points = np.vstack(latent_points) 70 | 71 | targets = (SA_scores_normalized + 72 | logP_values_normalized + 73 | cycle_scores_normalized) 74 | os.makedirs(args.save_dir, exist_ok=True) 75 | np.savez_compressed(os.path.join(args.save_dir, 'features.npz'), 76 | latent_points=latent_points, 77 | targets=targets, logP_values=logP_values, 78 | SA_scores=SA_scores, cycle_scores=cycle_scores) 79 | -------------------------------------------------------------------------------- /bo/run_bo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gzip 3 | import os 4 | import pickle 5 | import sys 6 | 7 | import numpy as np 8 | import rdkit 9 | import scipy.stats as sps 10 | import torch 11 | from dd_vae.bo.sparse_gp import SparseGP 12 | from dd_vae.bo.utils import max_ring_penalty 13 | from dd_vae.utils import prepare_seed 14 | from dd_vae.vae_rnn import VAE_RNN 15 | from moses.metrics import SA 16 | from rdkit import Chem 17 | from rdkit.Chem import Descriptors 18 | from rdkit.Chem import MolFromSmiles 19 | 20 | rdkit.rdBase.DisableLog('rdApp.*') 21 | 22 | 23 | # We define the functions used to load and save objects 24 | def save_object(obj, filename): 25 | result = pickle.dumps(obj) 26 | with gzip.GzipFile(filename, 'wb') as dest: 27 | dest.write(result) 28 | dest.close() 29 | 30 | 31 | def load_object(filename): 32 | with gzip.GzipFile(filename, 'rb') as source: 33 | result = source.read() 34 | ret = pickle.loads(result) 35 | source.close() 36 | return ret 37 | 38 | 39 | parser = argparse.ArgumentParser() 40 | 41 | parser.add_argument("--model", type=str, required=True) 42 | parser.add_argument("--save_dir", type=str, required=True) 43 | parser.add_argument("--device", type=str, 
default="cpu") 44 | parser.add_argument("--seed", type=int, default=777) 45 | parser.add_argument("--load_dir", type=str, required=True) 46 | args = parser.parse_args(sys.argv[1:]) 47 | 48 | prepare_seed(args.seed) 49 | 50 | model = VAE_RNN.load(args.model).to(args.device) 51 | 52 | # We load the data (y is negated!) 53 | data = np.load(os.path.join(args.load_dir, 'features.npz')) 54 | X = data['latent_points'] 55 | y = -data['targets'] 56 | y = y.reshape((-1, 1)) 57 | 58 | n = X.shape[0] 59 | 60 | permutation = np.random.choice(n, n, replace=False) 61 | 62 | X_train = X[permutation, :][0: np.int(np.round(0.9 * n)), :] 63 | X_test = X[permutation, :][np.int(np.round(0.9 * n)):, :] 64 | 65 | y_train = y[permutation][0: np.int(np.round(0.9 * n))] 66 | y_test = y[permutation][np.int(np.round(0.9 * n)):] 67 | 68 | np.random.seed(args.seed) 69 | 70 | logP_values = data['logP_values'] 71 | SA_scores = data['SA_scores'] 72 | cycle_scores = data['cycle_scores'] 73 | SA_scores_normalized = (np.array(SA_scores) - np.mean(SA_scores)) / np.std( 74 | SA_scores) 75 | logP_values_normalized = (np.array(logP_values) - np.mean( 76 | logP_values)) / np.std(logP_values) 77 | cycle_scores_normalized = (np.array(cycle_scores) - np.mean( 78 | cycle_scores)) / np.std(cycle_scores) 79 | 80 | iteration = 0 81 | while iteration < 5: 82 | # We fit the GP 83 | np.random.seed(iteration * args.seed) 84 | M = 500 85 | sgp = SparseGP(X_train, 0 * X_train, y_train, M) 86 | sgp.train_via_ADAM(X_train, 0 * X_train, y_train, X_test, X_test * 0, 87 | y_test, minibatch_size=10 * M, max_iterations=100, 88 | learning_rate=0.001) 89 | 90 | pred, uncert = sgp.predict(X_test, 0 * X_test) 91 | error = np.sqrt(np.mean((pred - y_test) ** 2)) 92 | testll = np.mean(sps.norm.logpdf(pred - y_test, scale=np.sqrt(uncert))) 93 | print('Test RMSE:', error) 94 | print('Test ll:', testll) 95 | 96 | pred, uncert = sgp.predict(X_train, 0 * X_train) 97 | error = np.sqrt(np.mean((pred - y_train) ** 2)) 98 | trainll = 
np.mean(sps.norm.logpdf(pred - y_train, scale=np.sqrt(uncert))) 99 | print('Train RMSE:', error) 100 | print('Train ll:', trainll) 101 | 102 | # We pick the next 60 inputs 103 | iters = 60 104 | next_inputs = sgp.batched_greedy_ei(iters, np.min(X_train, 0), 105 | np.max(X_train, 0)) 106 | valid_smiles = [] 107 | new_features = [] 108 | for i in range(iters): 109 | all_vec = next_inputs[i].reshape((1, -1)) 110 | smiles = model.sample(1, z=torch.tensor(all_vec).float())[0] 111 | mol = Chem.MolFromSmiles(smiles) 112 | if mol is None: 113 | continue 114 | err = Chem.SanitizeMol(mol, catchErrors=True) 115 | if err != 0: 116 | continue 117 | valid_smiles.append(smiles) 118 | new_features.append(all_vec) 119 | 120 | valid_smiles = valid_smiles[:50] 121 | if len(new_features) != 0: 122 | new_features = np.vstack(new_features)[:50] 123 | else: 124 | new_features = np.zeros((0, X_train.shape[1])) 125 | os.makedirs(args.save_dir, exist_ok=True) 126 | save_object(valid_smiles, 127 | os.path.join(args.save_dir, 128 | "valid_smiles{}.dat".format(iteration))) 129 | 130 | scores = [] 131 | for i in range(len(valid_smiles)): 132 | mol = MolFromSmiles(valid_smiles[i]) 133 | current_log_P_value = Descriptors.MolLogP(mol) 134 | current_SA_score = -SA(mol) 135 | current_cycle_score = max_ring_penalty(mol) 136 | 137 | current_SA_score_normalized = (current_SA_score - np.mean( 138 | SA_scores)) / np.std(SA_scores) 139 | current_log_P_value_normalized = (current_log_P_value - np.mean( 140 | logP_values)) / np.std(logP_values) 141 | current_cycle_score_normalized = (current_cycle_score - np.mean( 142 | cycle_scores)) / np.std(cycle_scores) 143 | 144 | score = (current_SA_score_normalized + 145 | current_log_P_value_normalized + 146 | current_cycle_score_normalized) 147 | scores.append(-score) # target is always minused 148 | 149 | print(f"{len(valid_smiles)} molecules found. 
Scores: {scores}") 150 | save_object(scores, 151 | os.path.join(args.save_dir, "scores{}.dat".format(iteration))) 152 | 153 | if len(new_features) > 0: 154 | X_train = np.concatenate([X_train, new_features], 0) 155 | y_train = np.concatenate([y_train, np.array(scores)[:, None]], 0) 156 | 157 | iteration += 1 158 | -------------------------------------------------------------------------------- /bo/run_experiment.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | mkdir -p bo/results/$2/ 3 | python bo/gen_latent.py \ 4 | --data data/zinc/250k_rndm_zinc_drugs_clean.smi.gz \ 5 | --model $1 \ 6 | --device $3 --save_dir bo/results/$2/ 7 | 8 | for SEED in $(seq 1 10) 9 | do 10 | mkdir -p bo/results/$2/experiment\_$SEED/ 11 | python bo/run_bo.py \ 12 | --model $1 \ 13 | --save_dir bo/results/$2/experiment\_$SEED/ \ 14 | --device $3 \ 15 | --seed $SEED \ 16 | --load_dir bo/results/$2/ > bo/results/$2/experiment\_$SEED/log.txt & 17 | sleep 20 18 | done 19 | -------------------------------------------------------------------------------- /configs/mnist/N_N.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | layer_sizes = [256, 128, 32] 3 | latent_size = 2 4 | proposal = 'gaussian' 5 | prior = 'gaussian' 6 | 7 | [data] 8 | title = 'MNIST' 9 | 10 | [train] 11 | epochs = 150 12 | lr_reduce_epochs = 20 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-3 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 10 18 | checkpoint = 'single' 19 | mode = 'sample' 20 | 21 | [kl] 22 | start_epoch = 4 23 | end_epoch = 100 24 | start = 1e-5 25 | end = 0.005 26 | 27 | [temperature] 28 | start_epoch = 0 29 | end_epoch = 100 30 | start = 0.01 31 | end = 0.001 32 | log = True 33 | 34 | [save] 35 | log_dir = 'logs/mnist/N_N/' 36 | model_dir = 'models/mnist/N_N/' 37 | -------------------------------------------------------------------------------- /configs/mnist/U_U.ini: 
-------------------------------------------------------------------------------- 1 | [model] 2 | layer_sizes = [256, 128, 32] 3 | latent_size = 2 4 | proposal = 'uniform' 5 | prior = 'uniform' 6 | 7 | [data] 8 | title = 'MNIST' 9 | 10 | [train] 11 | epochs = 150 12 | lr_reduce_epochs = 20 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-3 15 | batch_size = 512 16 | verbose = 'epoch' 17 | checkpoint = 'single' 18 | mode = 'argmax' 19 | 20 | [kl] 21 | start_epoch = 4 22 | end_epoch = 100 23 | start = 1e-5 24 | end = 0.05 25 | 26 | [temperature] 27 | start_epoch = 0 28 | end_epoch = 100 29 | start = 0.01 30 | end = 0.001 31 | log = True 32 | 33 | [save] 34 | log_dir = 'logs/mnist/U_U/' 35 | model_dir = 'models/mnist/U_U/' 36 | -------------------------------------------------------------------------------- /configs/moses/DD-VAE_gaussian_seed1.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 512 4 | latent_size = 64 5 | num_layers = 2 6 | proposal = 'gaussian' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = [20] 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 10 18 | checkpoint = 'epoch' 19 | mode = 'argmax' 20 | seed = 1 21 | 22 | [data] 23 | title = 'moses' 24 | train_path = 'data/moses/train.csv.gz' 25 | test_path = 'data/moses/test.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 20 29 | end_epoch = 200 30 | start = 0.0015 31 | end = 0.02 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 10 36 | start = 0.2 37 | end = 0.1 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/moses/DD-VAE_gaussian_seed1/' 42 | model_dir = 'models/moses/DD-VAE_gaussian_seed1/' 43 | -------------------------------------------------------------------------------- /configs/moses/DD-VAE_gaussian_seed2.ini: -------------------------------------------------------------------------------- 1 
| [model] 2 | embedding_size = 64 3 | hidden_size = 512 4 | latent_size = 64 5 | num_layers = 2 6 | proposal = 'gaussian' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = [20] 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 10 18 | checkpoint = 'epoch' 19 | mode = 'argmax' 20 | seed = 2 21 | 22 | [data] 23 | title = 'moses' 24 | train_path = 'data/moses/train.csv.gz' 25 | test_path = 'data/moses/test.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 20 29 | end_epoch = 200 30 | start = 0.0015 31 | end = 0.02 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 10 36 | start = 0.2 37 | end = 0.1 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/moses/DD-VAE_gaussian_seed2/' 42 | model_dir = 'models/moses/DD-VAE_gaussian_seed2/' 43 | -------------------------------------------------------------------------------- /configs/moses/DD-VAE_gaussian_seed3.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 512 4 | latent_size = 64 5 | num_layers = 2 6 | proposal = 'gaussian' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = [20] 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 10 18 | checkpoint = 'epoch' 19 | mode = 'argmax' 20 | seed = 3 21 | 22 | [data] 23 | title = 'moses' 24 | train_path = 'data/moses/train.csv.gz' 25 | test_path = 'data/moses/test.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 20 29 | end_epoch = 200 30 | start = 0.0015 31 | end = 0.02 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 10 36 | start = 0.2 37 | end = 0.1 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/moses/DD-VAE_gaussian_seed3/' 42 | model_dir = 'models/moses/DD-VAE_gaussian_seed3/' 43 | 
-------------------------------------------------------------------------------- /configs/moses/DD-VAE_triweight_seed1.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 512 4 | latent_size = 64 5 | num_layers = 2 6 | proposal = 'triweight' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = [20] 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 10 18 | checkpoint = 'epoch' 19 | mode = 'argmax' 20 | seed = 1 21 | 22 | [data] 23 | title = 'moses' 24 | train_path = 'data/moses/train.csv.gz' 25 | test_path = 'data/moses/test.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 20 29 | end_epoch = 200 30 | start = 0.0015 31 | end = 0.02 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 10 36 | start = 0.2 37 | end = 0.1 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/moses/DD-VAE_triweight_seed1/' 42 | model_dir = 'models/moses/DD-VAE_triweight_seed1/' 43 | -------------------------------------------------------------------------------- /configs/moses/DD-VAE_triweight_seed2.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 512 4 | latent_size = 64 5 | num_layers = 2 6 | proposal = 'triweight' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = [20] 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 10 18 | checkpoint = 'epoch' 19 | mode = 'argmax' 20 | seed = 2 21 | 22 | [data] 23 | title = 'moses' 24 | train_path = 'data/moses/train.csv.gz' 25 | test_path = 'data/moses/test.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 20 29 | end_epoch = 200 30 | start = 0.0015 31 | end = 0.02 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 10 36 | start = 0.2 37 | end = 0.1 
38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/moses/DD-VAE_triweight_seed2/' 42 | model_dir = 'models/moses/DD-VAE_triweight_seed2/' 43 | -------------------------------------------------------------------------------- /configs/moses/DD-VAE_triweight_seed3.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 512 4 | latent_size = 64 5 | num_layers = 2 6 | proposal = 'triweight' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = [20] 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 10 18 | checkpoint = 'epoch' 19 | mode = 'argmax' 20 | seed = 3 21 | 22 | [data] 23 | title = 'moses' 24 | train_path = 'data/moses/train.csv.gz' 25 | test_path = 'data/moses/test.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 20 29 | end_epoch = 200 30 | start = 0.0015 31 | end = 0.02 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 10 36 | start = 0.2 37 | end = 0.1 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/moses/DD-VAE_triweight_seed3/' 42 | model_dir = 'models/moses/DD-VAE_triweight_seed3/' 43 | -------------------------------------------------------------------------------- /configs/moses/VAE_gaussian_seed1.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 512 4 | latent_size = 64 5 | num_layers = 2 6 | proposal = 'gaussian' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = [20] 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 10 18 | checkpoint = 'epoch' 19 | mode = 'sample' 20 | seed = 1 21 | 22 | [data] 23 | title = 'moses' 24 | train_path = 'data/moses/train.csv.gz' 25 | test_path = 'data/moses/test.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 20 29 | 
end_epoch = 200 30 | start = 0.0005 31 | end = 0.01 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 10 36 | start = 0.2 37 | end = 0.1 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/moses/VAE_gaussian_seed1/' 42 | model_dir = 'models/moses/VAE_gaussian_seed1/' 43 | -------------------------------------------------------------------------------- /configs/moses/VAE_gaussian_seed2.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 512 4 | latent_size = 64 5 | num_layers = 2 6 | proposal = 'gaussian' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = [20] 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 10 18 | checkpoint = 'epoch' 19 | mode = 'sample' 20 | seed = 2 21 | 22 | [data] 23 | title = 'moses' 24 | train_path = 'data/moses/train.csv.gz' 25 | test_path = 'data/moses/test.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 20 29 | end_epoch = 200 30 | start = 0.0005 31 | end = 0.01 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 10 36 | start = 0.2 37 | end = 0.1 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/moses/VAE_gaussian_seed2/' 42 | model_dir = 'models/moses/VAE_gaussian_seed2/' 43 | -------------------------------------------------------------------------------- /configs/moses/VAE_gaussian_seed3.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 512 4 | latent_size = 64 5 | num_layers = 2 6 | proposal = 'gaussian' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = [20] 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 10 18 | checkpoint = 'epoch' 19 | mode = 'sample' 20 | seed = 3 21 | 22 | [data] 23 | title = 'moses' 24 | 
train_path = 'data/moses/train.csv.gz' 25 | test_path = 'data/moses/test.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 20 29 | end_epoch = 200 30 | start = 0.0005 31 | end = 0.01 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 10 36 | start = 0.2 37 | end = 0.1 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/moses/VAE_gaussian_seed3/' 42 | model_dir = 'models/moses/VAE_gaussian_seed3/' 43 | -------------------------------------------------------------------------------- /configs/moses/VAE_triweight_seed1.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 512 4 | latent_size = 64 5 | num_layers = 2 6 | proposal = 'triweight' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = [20] 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 10 18 | checkpoint = 'epoch' 19 | mode = 'sample' 20 | seed = 1 21 | 22 | [data] 23 | title = 'moses' 24 | train_path = 'data/moses/train.csv.gz' 25 | test_path = 'data/moses/test.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 20 29 | end_epoch = 200 30 | start = 0.0005 31 | end = 0.01 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 10 36 | start = 0.2 37 | end = 0.1 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/moses/VAE_triweight_seed1/' 42 | model_dir = 'models/moses/VAE_triweight_seed1/' 43 | -------------------------------------------------------------------------------- /configs/moses/VAE_triweight_seed2.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 512 4 | latent_size = 64 5 | num_layers = 2 6 | proposal = 'triweight' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = [20] 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 
'epoch' 17 | clamp = 10 18 | checkpoint = 'epoch' 19 | mode = 'sample' 20 | seed = 2 21 | 22 | [data] 23 | title = 'moses' 24 | train_path = 'data/moses/train.csv.gz' 25 | test_path = 'data/moses/test.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 20 29 | end_epoch = 200 30 | start = 0.0005 31 | end = 0.01 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 10 36 | start = 0.2 37 | end = 0.1 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/moses/VAE_triweight_seed2/' 42 | model_dir = 'models/moses/VAE_triweight_seed2/' 43 | -------------------------------------------------------------------------------- /configs/moses/VAE_triweight_seed3.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 512 4 | latent_size = 64 5 | num_layers = 2 6 | proposal = 'triweight' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = [20] 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 10 18 | checkpoint = 'epoch' 19 | mode = 'sample' 20 | seed = 3 21 | 22 | [data] 23 | title = 'moses' 24 | train_path = 'data/moses/train.csv.gz' 25 | test_path = 'data/moses/test.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 20 29 | end_epoch = 200 30 | start = 0.0005 31 | end = 0.01 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 10 36 | start = 0.2 37 | end = 0.1 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/moses/VAE_triweight_seed3/' 42 | model_dir = 'models/moses/VAE_triweight_seed3/' 43 | -------------------------------------------------------------------------------- /configs/synthetic/N_N.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 8 3 | hidden_size = 128 4 | latent_size = 2 5 | num_layers = 2 6 | proposal = 'gaussian' 7 | prior = 'gaussian' 8 | 9 | [train] 10 | epochs = 100 11 | lr_reduce_epochs = 20 
12 | lr_reduce_gamma = 0.5 13 | lr = 5e-3 14 | batch_size = 512 15 | verbose = 'epoch' 16 | clamp = 2 17 | checkpoint = 'single' 18 | fine_tune = 10 19 | mode = 'sample' 20 | 21 | [data] 22 | title = 'p0.2' 23 | train_path = 'data/synthetic/2d_map_0.2.csv.gz' 24 | 25 | [kl] 26 | start_epoch = 2 27 | end_epoch = 100 28 | start = 0 29 | end = 0.1 30 | 31 | [temperature] 32 | start_epoch = 0 33 | end_epoch = 100 34 | start = 1e-1 35 | end = 1e-3 36 | log = True 37 | 38 | [save] 39 | log_dir = 'logs/synthetic/normal_normal/' 40 | model_dir = 'models/synthetic/normal_normal/' -------------------------------------------------------------------------------- /configs/synthetic/U_T.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 8 3 | hidden_size = 128 4 | latent_size = 2 5 | num_layers = 2 6 | proposal = 'tricube' 7 | prior = 'uniform' 8 | 9 | [train] 10 | epochs = 100 11 | lr_reduce_epochs = 20 12 | lr_reduce_gamma = 0.5 13 | lr = 5e-3 14 | batch_size = 512 15 | verbose = 'epoch' 16 | clamp = 2 17 | checkpoint = 'single' 18 | fine_tune = 10 19 | mode = 'argmax' 20 | 21 | [data] 22 | title = 'p0.2' 23 | train_path = 'data/synthetic/2d_map_0.2.csv.gz' 24 | 25 | [kl] 26 | start_epoch = 2 27 | end_epoch = 100 28 | start = 0 29 | end = 0.1 30 | 31 | [temperature] 32 | start_epoch = 0 33 | end_epoch = 100 34 | start = 1e-1 35 | end = 1e-2 36 | log = True 37 | 38 | [save] 39 | log_dir = 'logs/synthetic/uniform_tricube/' 40 | model_dir = 'models/synthetic/uniform_tricube/' -------------------------------------------------------------------------------- /configs/synthetic/U_U.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 8 3 | hidden_size = 128 4 | latent_size = 2 5 | num_layers = 2 6 | proposal = 'uniform' 7 | prior = 'uniform' 8 | 9 | [train] 10 | epochs = 100 11 | lr_reduce_epochs = 20 12 | lr_reduce_gamma = 0.5 13 | lr = 5e-3 14 | 
batch_size = 512 15 | verbose = 'epoch' 16 | clamp = 2 17 | checkpoint = 'single' 18 | fine_tune = 10 19 | mode = 'argmax' 20 | 21 | [data] 22 | title = 'p0.2' 23 | train_path = 'data/synthetic/2d_map_0.2.csv.gz' 24 | 25 | [kl] 26 | start_epoch = 2 27 | end_epoch = 100 28 | start = 0 29 | end = 1 30 | 31 | [temperature] 32 | start_epoch = 0 33 | end_epoch = 100 34 | start = 1e-1 35 | end = 1e-3 36 | log = True 37 | 38 | [save] 39 | log_dir = 'logs/synthetic/uniform_uniform/' 40 | model_dir = 'models/synthetic/uniform_uniform/' -------------------------------------------------------------------------------- /configs/zinc/dd_vae_gaussian.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 1024 4 | latent_size = 64 5 | num_layers = 1 6 | proposal = 'gaussian' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = 50 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 2 18 | checkpoint = 'single' 19 | mode = 'argmax' 20 | seed = 1 21 | 22 | [data] 23 | title = 'zinc' 24 | train_path = 'data/zinc/train.csv.gz' 25 | test_path = 'data/zinc/valid.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 0 29 | end_epoch = 50 30 | start = 0.001 31 | end = 0.02 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 100 36 | start = 0.001 37 | end = 1e-4 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/zinc/DD_VAE_GAUSSIAN/' 42 | model_dir = 'models/zinc/DD_VAE_GAUSSIAN/' 43 | -------------------------------------------------------------------------------- /configs/zinc/dd_vae_tricube.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 1024 4 | latent_size = 64 5 | num_layers = 1 6 | proposal = 'tricube' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | 
lr_reduce_epochs = 50 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 2 18 | checkpoint = 'single' 19 | mode = 'argmax' 20 | seed = 1 21 | 22 | [data] 23 | title = 'zinc' 24 | train_path = 'data/zinc/train.csv.gz' 25 | test_path = 'data/zinc/valid.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 0 29 | end_epoch = 50 30 | start = 0.001 31 | end = 0.02 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 100 36 | start = 0.001 37 | end = 1e-4 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/zinc/DD_VAE_TRICUBE/' 42 | model_dir = 'models/zinc/DD_VAE_TRICUBE/' 43 | -------------------------------------------------------------------------------- /configs/zinc/vae_gaussian.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 1024 4 | latent_size = 64 5 | num_layers = 1 6 | proposal = 'gaussian' 7 | prior = 'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = 50 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 2 18 | checkpoint = 'single' 19 | mode = 'sample' 20 | seed = 1 21 | 22 | [data] 23 | title = 'zinc' 24 | train_path = 'data/zinc/train.csv.gz' 25 | test_path = 'data/zinc/valid.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 0 29 | end_epoch = 50 30 | start = 0.0001 31 | end = 0.0008 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 100 36 | start = 0.001 37 | end = 1e-4 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/zinc/VAE_GAUSSIAN/' 42 | model_dir = 'models/zinc/VAE_GAUSSIAN/' 43 | -------------------------------------------------------------------------------- /configs/zinc/vae_tricube.ini: -------------------------------------------------------------------------------- 1 | [model] 2 | embedding_size = 64 3 | hidden_size = 1024 4 | latent_size = 64 5 | num_layers = 1 6 | proposal = 'tricube' 7 | prior = 
'gaussian' 8 | use_embedding_input = True 9 | 10 | [train] 11 | epochs = 200 12 | lr_reduce_epochs = 50 13 | lr_reduce_gamma = 0.5 14 | lr = 5e-4 15 | batch_size = 512 16 | verbose = 'epoch' 17 | clamp = 2 18 | checkpoint = 'single' 19 | mode = 'sample' 20 | seed = 1 21 | 22 | [data] 23 | title = 'zinc' 24 | train_path = 'data/zinc/train.csv.gz' 25 | test_path = 'data/zinc/valid.csv.gz' 26 | 27 | [kl] 28 | start_epoch = 0 29 | end_epoch = 50 30 | start = 0.0001 31 | end = 0.0008 32 | 33 | [temperature] 34 | start_epoch = 0 35 | end_epoch = 100 36 | start = 0.001 37 | end = 1e-4 38 | log = True 39 | 40 | [save] 41 | log_dir = 'logs/zinc/VAE_TRICUBE/' 42 | model_dir = 'models/zinc/VAE_TRICUBE/' 43 | -------------------------------------------------------------------------------- /data/moses/test.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/data/moses/test.csv.gz -------------------------------------------------------------------------------- /data/moses/test_scaffolds.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/data/moses/test_scaffolds.csv.gz -------------------------------------------------------------------------------- /data/moses/test_scaffolds_stats.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/data/moses/test_scaffolds_stats.npz -------------------------------------------------------------------------------- /data/moses/test_stats.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/data/moses/test_stats.npz 
-------------------------------------------------------------------------------- /data/moses/train.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/data/moses/train.csv.gz -------------------------------------------------------------------------------- /data/synthetic/2d_map_0.2.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/data/synthetic/2d_map_0.2.csv.gz -------------------------------------------------------------------------------- /data/zinc/250k_rndm_zinc_drugs_clean.smi.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/data/zinc/250k_rndm_zinc_drugs_clean.smi.gz -------------------------------------------------------------------------------- /data/zinc/test.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/data/zinc/test.csv.gz -------------------------------------------------------------------------------- /data/zinc/train.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/data/zinc/train.csv.gz -------------------------------------------------------------------------------- /data/zinc/valid.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/data/zinc/valid.csv.gz -------------------------------------------------------------------------------- 
/dd_vae/__init__.py: -------------------------------------------------------------------------------- 1 | from .vae_rnn import VAE_RNN 2 | from .vae_mnist import VAE_MNIST 3 | 4 | __all__ = ['VAE_RNN', 'VAE_MNIST'] 5 | -------------------------------------------------------------------------------- /dd_vae/bo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/dd_vae/bo/__init__.py -------------------------------------------------------------------------------- /dd_vae/bo/gauss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import theano 3 | import theano.tensor as T 4 | 5 | 6 | def casting(x): 7 | return np.array(x).astype(theano.config.floatX) 8 | 9 | 10 | def compute_kernel(lls, lsf, x, z): 11 | ls = T.exp(lls) 12 | sf = T.exp(lsf) 13 | 14 | if x.ndim == 1: 15 | x = x[None, :] 16 | 17 | if z.ndim == 1: 18 | z = z[None, :] 19 | 20 | lsre = T.outer(T.ones_like(x[:, 0]), ls) 21 | 22 | r2 = T.outer(T.sum(x * x / lsre, 1), T.ones_like(z[:, 0: 1])) - \ 23 | np.float32(2) * T.dot(x / lsre, T.transpose(z)) + \ 24 | T.dot(np.float32(1.0) / lsre, T.transpose(z) ** 2) 25 | 26 | k = sf * T.exp(-np.float32(0.5) * r2) 27 | 28 | return k 29 | 30 | 31 | def compute_kernel_numpy(lls, lsf, x, z): 32 | ls = np.exp(lls) 33 | sf = np.exp(lsf) 34 | 35 | if x.ndim == 1: 36 | x = x[None, :] 37 | 38 | if z.ndim == 1: 39 | z = z[None, :] 40 | 41 | lsre = np.outer(np.ones(x.shape[0]), ls) 42 | 43 | r2 = np.outer(np.sum(x * x / lsre, 1), np.ones(z.shape[0])) - \ 44 | 2 * np.dot(x / lsre, z.T) + np.dot(1.0 / lsre, z.T ** 2) 45 | 46 | k = sf * np.exp(-0.5 * r2) 47 | 48 | return k 49 | 50 | 51 | ## 52 | # xmean and xvar can be vectors of input points 53 | # 54 | # This is the expected value of the kernel 55 | # 56 | 57 | def compute_psi1(lls, lsf, xmean, xvar, z): 58 | if xmean.ndim == 1: 
59 | xmean = xmean[None, :] 60 | 61 | ls = T.exp(lls) 62 | sf = T.exp(lsf) 63 | lspxvar = ls + xvar 64 | constterm1 = ls / lspxvar 65 | constterm2 = T.prod(T.sqrt(constterm1), 1) 66 | r2_psi1 = T.outer(T.sum(xmean * xmean / lspxvar, 1), 67 | T.ones_like(z[:, 0: 1])) - \ 68 | np.float32(2) * T.dot(xmean / lspxvar, T.transpose(z)) + \ 69 | T.dot(np.float32(1.0) / lspxvar, T.transpose(z) ** 2) 70 | psi1 = sf * T.outer(constterm2, T.ones_like(z[:, 0: 1])) * T.exp( 71 | -np.float32(0.5) * r2_psi1) 72 | 73 | return psi1 74 | 75 | 76 | def compute_psi1_numpy(lls, lsf, xmean, xvar, z): 77 | if xmean.ndim == 1: 78 | xmean = xmean[None, :] 79 | 80 | ls = np.exp(lls) 81 | sf = np.exp(lsf) 82 | lspxvar = ls + xvar 83 | constterm1 = ls / lspxvar 84 | constterm2 = np.prod(np.sqrt(constterm1), 1) 85 | r2_psi1 = np.outer(np.sum(xmean * xmean / lspxvar, 1), 86 | np.ones(z.shape[0])) - \ 87 | 2 * np.dot(xmean / lspxvar, z.T) + \ 88 | np.dot(1.0 / lspxvar, z.T ** 2) 89 | psi1 = sf * np.outer(constterm2, np.ones(z.shape[0])) * np.exp( 90 | -0.5 * r2_psi1) 91 | return psi1 92 | 93 | 94 | def compute_psi2(lls, lsf, z, input_means, input_vars): 95 | ls = T.exp(lls) 96 | sf = T.exp(lsf) 97 | b = ls / casting(2.0) 98 | term_1 = T.prod(T.sqrt(b / (b + input_vars)), 1) 99 | 100 | scale = T.sqrt(4 * (2 * b[None, :] + 0 * input_vars)) 101 | scaled_z = z[None, :, :] / scale[:, None, :] 102 | scaled_z_minus_m = scaled_z 103 | r2b = T.sum(scaled_z_minus_m ** 2, 2)[:, None, :] + \ 104 | T.sum(scaled_z_minus_m ** 2, 2)[:, :, None] - \ 105 | 2 * T.batched_dot(scaled_z_minus_m, 106 | np.transpose(scaled_z_minus_m, [0, 2, 1])) 107 | term_2 = T.exp(-r2b) 108 | 109 | scale = T.sqrt(4 * (2 * b[None, :] + 2 * input_vars)) 110 | scaled_z = z[None, :, :] / scale[:, None, :] 111 | scaled_m = input_means / scale 112 | scaled_m = T.tile(scaled_m[:, None, :], [1, z.shape[0], 1]) 113 | scaled_z_minus_m = scaled_z - scaled_m 114 | r2b = T.sum(scaled_z_minus_m ** 2, 2)[:, None, :] + \ 115 | T.sum(scaled_z_minus_m 
** 2, 2)[:, :, None] + \ 116 | 2 * T.batched_dot(scaled_z_minus_m, 117 | np.transpose(scaled_z_minus_m, [0, 2, 1])) 118 | term_3 = T.exp(-r2b) 119 | 120 | psi2_computed = sf ** casting(2.0) * \ 121 | term_1[:, None, None] * term_2 * term_3 122 | 123 | return T.transpose(psi2_computed, [1, 2, 0]) 124 | 125 | 126 | def compute_psi2_numpy(lls, lsf, z, input_means, input_vars): 127 | ls = np.exp(lls) 128 | sf = np.exp(lsf) 129 | b = ls / casting(2.0) 130 | term_1 = np.prod(np.sqrt(b / (b + input_vars)), 1) 131 | 132 | scale = np.sqrt(4 * (2 * b[None, :] + 0 * input_vars)) 133 | scaled_z = z[None, :, :] / scale[:, None, :] 134 | scaled_z_minus_m = scaled_z 135 | r2b = np.sum(scaled_z_minus_m ** 2, 2)[:, None, :] + \ 136 | np.sum(scaled_z_minus_m ** 2, 2)[:, :, None] - \ 137 | 2 * np.einsum('ijk,ikl->ijl', scaled_z_minus_m, 138 | np.transpose(scaled_z_minus_m, [0, 2, 1])) 139 | term_2 = np.exp(-r2b) 140 | 141 | scale = np.sqrt(4 * (2 * b[None, :] + 2 * input_vars)) 142 | scaled_z = z[None, :, :] / scale[:, None, :] 143 | scaled_m = input_means / scale 144 | scaled_m = np.tile(scaled_m[:, None, :], [1, z.shape[0], 1]) 145 | scaled_z_minus_m = scaled_z - scaled_m 146 | r2b = np.sum(scaled_z_minus_m ** 2, 2)[:, None, :] + \ 147 | np.sum(scaled_z_minus_m ** 2, 2)[:, :, None] + \ 148 | 2 * np.einsum('ijk,ikl->ijl', scaled_z_minus_m, 149 | np.transpose(scaled_z_minus_m, [0, 2, 1])) 150 | term_3 = np.exp(-r2b) 151 | 152 | psi2_computed = sf ** casting(2.0) * \ 153 | term_1[:, None, None] * term_2 * term_3 154 | psi2_computed = np.transpose(psi2_computed, [1, 2, 0]) 155 | 156 | return psi2_computed 157 | -------------------------------------------------------------------------------- /dd_vae/bo/psd_theano.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy 4 | import scipy.linalg as spla 5 | import theano 6 | from theano.gof import Op, Apply 7 | from theano.tensor import as_tensor_variable 8 | from 
theano.tensor.nlinalg import matrix_dot 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | __all__ = ['MatrixInversePSD', 'LogDetPSD'] 13 | 14 | 15 | def chol2inv(chol): 16 | return spla.cho_solve((chol, False), numpy.eye(chol.shape[0])) 17 | 18 | 19 | class MatrixInversePSD(Op): 20 | r"""Computes the inverse of a matrix :math:`A`. 21 | Given a square matrix :math:`A`, ``matrix_inverse`` returns a square 22 | matrix :math:`A_{inv}` such that the dot product :math:`A \cdot A_{inv}` 23 | and :math:`A_{inv} \cdot A` equals the identity matrix :math:`I`. 24 | Notes 25 | ----- 26 | When possible, the call to this op will be optimized to the call 27 | of ``solve``. 28 | """ 29 | 30 | __props__ = () 31 | 32 | def __init__(self): 33 | pass 34 | 35 | def make_node(self, x): 36 | x = as_tensor_variable(x) 37 | assert x.ndim == 2 38 | return Apply(self, [x], [x.type()]) 39 | 40 | def perform(self, node, inputs, outputs): 41 | (x,) = inputs 42 | (z,) = outputs 43 | z[0] = chol2inv(spla.cholesky(x, lower=False)).astype(x.dtype) 44 | 45 | def grad(self, inputs, g_outputs): 46 | r"""The gradient function should return 47 | .. math:: V\frac{\partial X^{-1}}{\partial X}, 48 | where :math:`V` corresponds to ``g_outputs`` and :math:`X` to 49 | ``inputs``. Using the `matrix cookbook 50 | `_, 51 | one can deduce that the relation corresponds to 52 | .. math:: (X^{-1} \cdot V^{T} \cdot X^{-1})^T. 53 | """ 54 | x, = inputs 55 | xi = self(x) 56 | gz, = g_outputs 57 | # TT.dot(gz.T,xi) 58 | return [-matrix_dot(xi, gz.T, xi).T] 59 | 60 | def R_op(self, inputs, eval_points): 61 | r"""The gradient function should return 62 | .. math:: \frac{\partial X^{-1}}{\partial X}V, 63 | where :math:`V` corresponds to ``g_outputs`` and :math:`X` to 64 | ``inputs``. Using the `matrix cookbook 65 | `_, 66 | one can deduce that the relation corresponds to 67 | .. math:: X^{-1} \cdot V \cdot X^{-1}. 
68 | """ 69 | x, = inputs 70 | xi = self(x) 71 | ev, = eval_points 72 | if ev is None: 73 | return [None] 74 | return [-matrix_dot(xi, ev, xi)] 75 | 76 | def infer_shape(self, node, shapes): 77 | return shapes 78 | 79 | 80 | matrix_inverse_psd = MatrixInversePSD() 81 | 82 | 83 | class LogDetPSD(Op): 84 | """ 85 | Matrix log determinant. Input should be a square matrix. 86 | """ 87 | 88 | __props__ = () 89 | 90 | def make_node(self, x): 91 | x = as_tensor_variable(x) 92 | assert x.ndim == 2 93 | o = theano.tensor.scalar(dtype=x.dtype) 94 | return Apply(self, [x], [o]) 95 | 96 | def perform(self, node, inputs, outputs): 97 | (x,) = inputs 98 | (z,) = outputs 99 | try: 100 | z[0] = numpy.asarray(2 * numpy.sum( 101 | numpy.log(numpy.diag(spla.cholesky(x, lower=False)))), 102 | dtype=x.dtype) 103 | except Exception: 104 | print('Failed to compute log determinant', x) 105 | raise 106 | 107 | def grad(self, inputs, g_outputs): 108 | gz, = g_outputs 109 | x, = inputs 110 | return [gz * matrix_inverse_psd(x).T] 111 | 112 | def infer_shape(self, node, shapes): 113 | return [()] 114 | 115 | def __str__(self): 116 | return "LogDetPSD" 117 | 118 | 119 | log_det_psd = LogDetPSD() 120 | -------------------------------------------------------------------------------- /dd_vae/bo/sparse_gp.py: -------------------------------------------------------------------------------- 1 | ## 2 | # This class represents a node within the network 3 | # 4 | 5 | import sys 6 | 7 | import numpy as np 8 | import scipy.optimize as spo 9 | import scipy.stats as sps 10 | import theano 11 | import theano.tensor as T 12 | from .sparse_gp_theano_internal import Sparse_GP 13 | 14 | 15 | def casting(x): 16 | return np.array(x).astype(theano.config.floatX) 17 | 18 | 19 | def global_optimization(grid, lower, upper, function_grid, function_scalar, 20 | function_scalar_gradient): 21 | grid_values = function_grid(grid) 22 | best = grid_values.argmin() 23 | 24 | # We solve the optimization problem 25 | 26 | 
X_initial = grid[best: (best + 1), :] 27 | 28 | def objective(X): 29 | X = casting(X) 30 | X = X.reshape((1, grid.shape[1])) 31 | value = function_scalar(X) 32 | gradient_value = function_scalar_gradient(X).flatten() 33 | return np.float(value), gradient_value.astype(np.float) 34 | 35 | lbfgs_bounds = list(zip(lower.tolist(), upper.tolist())) 36 | x_optimal, y_opt, opt_info = spo.fmin_l_bfgs_b(objective, X_initial, 37 | bounds=lbfgs_bounds, 38 | iprint=0, maxiter=150) 39 | x_optimal = x_optimal.reshape((1, grid.shape[1])) 40 | 41 | return x_optimal, y_opt 42 | 43 | 44 | def adam_theano(loss, all_params, learning_rate=0.001): 45 | b1 = 0.9 46 | b2 = 0.999 47 | e = 1e-8 48 | updates = [] 49 | all_grads = theano.grad(loss, all_params) 50 | alpha = learning_rate 51 | t = theano.shared(casting(1.0)) 52 | for theta_previous, g in zip(all_params, all_grads): 53 | m_previous = theano.shared(np.zeros(theta_previous.get_value().shape, 54 | dtype=theano.config.floatX)) 55 | v_previous = theano.shared(np.zeros(theta_previous.get_value().shape, 56 | dtype=theano.config.floatX)) 57 | m = b1 * m_previous + (1 - b1) * g 58 | v = b2 * v_previous + (1 - b2) * g ** 2 59 | m_hat = m / (1 - b1 ** t) 60 | v_hat = v / (1 - b2 ** t) 61 | theta = theta_previous - (alpha * m_hat) / ( 62 | T.sqrt(v_hat) + e) # (Update parameters) 63 | updates.append((m_previous, m)) 64 | updates.append((v_previous, v)) 65 | updates.append((theta_previous, theta)) 66 | updates.append((t, t + 1.)) 67 | return updates 68 | 69 | 70 | class SparseGP: 71 | """ 72 | The training_targets are the Y's which in the case of regression are 73 | real numbers in the case of binary classification are 1 or -1 and in 74 | the case of multiclass classification are 0, 1, 2,.. 
n_class - 1 75 | """ 76 | def __init__(self, input_means, input_vars, training_targets, 77 | n_inducing_points): 78 | 79 | self.input_means = theano.shared( 80 | value=input_means.astype(theano.config.floatX), borrow=True, 81 | name='X') 82 | self.input_vars = theano.shared( 83 | value=input_vars.astype(theano.config.floatX), borrow=True, 84 | name='X') 85 | self.original_training_targets = theano.shared( 86 | value=training_targets.astype(theano.config.floatX), borrow=True, 87 | name='y') 88 | self.training_targets = self.original_training_targets 89 | 90 | self.n_points = input_means.shape[0] 91 | self.d_input = input_means.shape[1] 92 | 93 | self.sparse_gp = Sparse_GP(n_inducing_points, self.n_points, 94 | self.d_input, self.input_means, 95 | self.input_vars, self.training_targets) 96 | 97 | self.set_for_prediction = False 98 | self.predict_function = None 99 | 100 | def initialize(self): 101 | self.sparse_gp.initialize() 102 | 103 | def setForTraining(self): 104 | self.sparse_gp.setForTraining() 105 | 106 | def setForPrediction(self): 107 | self.sparse_gp.setForPrediction() 108 | 109 | def get_params(self): 110 | return self.sparse_gp.get_params() 111 | 112 | def set_params(self, params): 113 | self.sparse_gp.set_params(params) 114 | 115 | def getEnergy(self): 116 | self.sparse_gp.compute_output() 117 | return self.sparse_gp.getContributionToEnergy()[0, 0] 118 | 119 | def predict(self, means_test, vars_test): 120 | 121 | self.setForPrediction() 122 | 123 | means_test = means_test.astype(theano.config.floatX) 124 | vars_test = vars_test.astype(theano.config.floatX) 125 | 126 | if self.predict_function is None: 127 | self.sparse_gp.compute_output() 128 | predictions = self.sparse_gp.getPredictedValues() 129 | 130 | X = T.matrix('X', dtype=theano.config.floatX) 131 | Z = T.matrix('Z', dtype=theano.config.floatX) 132 | 133 | self.predict_function = theano.function([X, Z], predictions, 134 | givens={ 135 | self.input_means: X, 136 | self.input_vars: Z}) 137 | 138 | 
predicted_values = self.predict_function(means_test, vars_test) 139 | 140 | self.setForTraining() 141 | 142 | return predicted_values 143 | 144 | # This trains the network via LBFGS as implemented in scipy 145 | # (slow but good for small datasets) 146 | 147 | def train_via_LBFGS(self, input_means, input_vars, training_targets, 148 | max_iterations=500): 149 | 150 | # We initialize the network and get the initial parameters 151 | 152 | input_means = input_means.astype(theano.config.floatX) 153 | input_vars = input_vars.astype(theano.config.floatX) 154 | training_targets = training_targets.astype(theano.config.floatX) 155 | self.input_means.set_value(input_means) 156 | self.input_vars.set_value(input_vars) 157 | self.original_training_targets.set_value(training_targets) 158 | 159 | self.initialize() 160 | self.setForTraining() 161 | 162 | X = T.matrix('X', dtype=theano.config.floatX) 163 | Z = T.matrix('Z', dtype=theano.config.floatX) 164 | y = T.matrix('y', dtype=theano.config.floatX) 165 | e = self.getEnergy() 166 | energy = theano.function([X, Z, y], e, givens={ 167 | self.input_means: X, 168 | self.input_vars: Z, 169 | self.training_targets: y}) 170 | all_params = self.get_params() 171 | energy_grad = theano.function([X, Z, y], T.grad(e, all_params), 172 | givens={self.input_means: X, 173 | self.input_vars: Z, 174 | self.training_targets: y}) 175 | 176 | initial_params = theano.function([], all_params)() 177 | 178 | params_shapes = [s.shape for s in initial_params] 179 | 180 | def de_vectorize_params(params): 181 | ret = [] 182 | for shape in params_shapes: 183 | if len(shape) == 2: 184 | ret.append(params[: np.prod(shape)].reshape(shape)) 185 | params = params[np.prod(shape):] 186 | elif len(shape) == 1: 187 | ret.append(params[: np.prod(shape)]) 188 | params = params[np.prod(shape):] 189 | else: 190 | ret.append(params[0]) 191 | params = params[1:] 192 | return ret 193 | 194 | def vectorize_params(params): 195 | return np.concatenate([s.flatten() for s in 
params]) 196 | 197 | def objective(params): 198 | 199 | params = de_vectorize_params(params) 200 | self.set_params(params) 201 | energy_value = energy(input_means, input_vars, training_targets) 202 | gradient_value = energy_grad(input_means, input_vars, 203 | training_targets) 204 | 205 | return -energy_value, -vectorize_params(gradient_value) 206 | 207 | # We create a theano function that evaluates the energy 208 | 209 | initial_params = vectorize_params(initial_params) 210 | x_opt, y_opt, opt_info = spo.fmin_l_bfgs_b(objective, initial_params, 211 | bounds=None, iprint=1, 212 | maxiter=max_iterations) 213 | 214 | self.set_params(de_vectorize_params(x_opt)) 215 | 216 | return y_opt 217 | 218 | def train_via_ADAM(self, input_means, input_vars, training_targets, 219 | input_means_test, input_vars_test, test_targets, 220 | max_iterations=500, minibatch_size=4000, 221 | learning_rate=1e-3, ignoroe_variances=True): 222 | 223 | input_means = input_means.astype(theano.config.floatX) 224 | input_vars = input_vars.astype(theano.config.floatX) 225 | training_targets = training_targets.astype(theano.config.floatX) 226 | n_data_points = input_means.shape[0] 227 | selected_points = np.random.choice(n_data_points, n_data_points, 228 | replace=False)[ 229 | 0: min(n_data_points, minibatch_size)] 230 | self.input_means.set_value(input_means[selected_points, :]) 231 | self.input_vars.set_value(input_vars[selected_points, :]) 232 | self.original_training_targets.set_value( 233 | training_targets[selected_points, :]) 234 | 235 | print('Initializing network') 236 | sys.stdout.flush() 237 | self.setForTraining() 238 | self.initialize() 239 | 240 | X = T.matrix('X', dtype=theano.config.floatX) 241 | Z = T.matrix('Z', dtype=theano.config.floatX) 242 | y = T.matrix('y', dtype=theano.config.floatX) 243 | 244 | e = self.getEnergy() 245 | 246 | all_params = self.get_params() 247 | 248 | print('Compiling adam updates') 249 | sys.stdout.flush() 250 | 251 | process_minibatch_adam = 
theano.function( 252 | [X, Z, y], -e, 253 | updates=adam_theano(-e, 254 | all_params, 255 | learning_rate), 256 | givens={self.input_means: X, 257 | self.input_vars: Z, 258 | self.original_training_targets: y} 259 | ) 260 | 261 | # Main loop of the optimization 262 | 263 | print('Training') 264 | sys.stdout.flush() 265 | n_batches = int(np.ceil(1.0 * n_data_points / minibatch_size)) 266 | for j in range(max_iterations): 267 | suffle = np.random.choice(n_data_points, n_data_points, 268 | replace=False) 269 | input_means = input_means[suffle, :] 270 | input_vars = input_vars[suffle, :] 271 | training_targets = training_targets[suffle, :] 272 | 273 | for i in range(n_batches): 274 | minibatch_data_means = input_means[i * minibatch_size: min( 275 | (i + 1) * minibatch_size, n_data_points), :] 276 | minibatch_data_vars = input_vars[i * minibatch_size: min( 277 | (i + 1) * minibatch_size, n_data_points), :] 278 | minibatch_targets = training_targets[i * minibatch_size: min( 279 | (i + 1) * minibatch_size, n_data_points), :] 280 | 281 | process_minibatch_adam(minibatch_data_means, 282 | minibatch_data_vars, 283 | minibatch_targets) 284 | 285 | pred, uncert = self.predict(input_means_test, input_vars_test) 286 | test_error = np.sqrt(np.mean((pred - test_targets) ** 2)) 287 | test_ll = np.mean( 288 | sps.norm.logpdf(pred - test_targets, scale=np.sqrt(uncert))) 289 | 290 | print(f'Epoch: {j}') 291 | print('Test error: {} Test ll: {}'.format(test_error, test_ll)) 292 | sys.stdout.flush() 293 | 294 | pred = np.zeros((0, 1)) 295 | uncert = np.zeros((0, uncert.shape[1])) 296 | for i in range(n_batches): 297 | minibatch_data_means = input_means[i * minibatch_size: min( 298 | (i + 1) * minibatch_size, n_data_points), :] 299 | minibatch_data_vars = input_vars[i * minibatch_size: min( 300 | (i + 1) * minibatch_size, n_data_points), :] 301 | pred_new, uncert_new = self.predict(minibatch_data_means, 302 | minibatch_data_vars) 303 | pred = np.concatenate((pred, pred_new), 0) 304 | 
uncert = np.concatenate((uncert, uncert_new), 0) 305 | 306 | training_error = np.sqrt(np.mean((pred - training_targets) ** 2)) 307 | training_ll = np.mean(sps.norm.logpdf(pred - training_targets, 308 | scale=np.sqrt(uncert))) 309 | 310 | print('Train error: {} Train ll: {}'.format(training_error, 311 | training_ll)) 312 | sys.stdout.flush() 313 | 314 | def get_incumbent(self, grid, lower, upper): 315 | 316 | self.sparse_gp.compute_output() 317 | m, v = self.sparse_gp.getPredictedValues() 318 | 319 | X = T.matrix('X', dtype=theano.config.floatX) 320 | function_grid = theano.function([X], m, 321 | givens={self.input_means: X, 322 | self.input_vars: 0 * X}) 323 | function_scalar = theano.function([X], m[0, 0], 324 | givens={self.input_means: X, 325 | self.input_vars: 0 * X}) 326 | function_scalar_gradient = theano.function( 327 | [X], T.grad(m[0, 0], self.input_means), 328 | givens={self.input_means: X, 329 | self.input_vars: 0 * X}) 330 | 331 | return global_optimization(grid, lower, upper, function_grid, 332 | function_scalar, 333 | function_scalar_gradient)[1] 334 | 335 | def optimize_ei(self, grid, lower, upper, incumbent): 336 | 337 | X = T.matrix('X', dtype=theano.config.floatX) 338 | log_ei = self.sparse_gp.compute_log_ei(X, incumbent) 339 | 340 | function_grid = theano.function([X], -log_ei) 341 | function_scalar = theano.function([X], -log_ei[0, 0]) 342 | function_scalar_gradient = theano.function([X], 343 | -T.grad(log_ei[0, 0], X)) 344 | 345 | return \ 346 | global_optimization(grid, lower, upper, function_grid, 347 | function_scalar, 348 | function_scalar_gradient)[0] 349 | 350 | def batched_greedy_ei(self, q, lower, upper, n_samples=1): 351 | 352 | self.setForPrediction() 353 | 354 | grid_size = 10000 355 | grid = casting( 356 | lower + np.random.rand(grid_size, self.d_input) * (upper - lower)) 357 | 358 | incumbent = self.get_incumbent(grid, lower, upper) 359 | X_numpy = self.optimize_ei(grid, lower, upper, incumbent) 360 | randomness_numpy = casting( 
361 | 0 * np.random.randn(X_numpy.shape[0], n_samples).astype( 362 | theano.config.floatX)) 363 | 364 | randomness = theano.shared( 365 | value=randomness_numpy.astype(theano.config.floatX), 366 | name='randomness', borrow=True) 367 | X = theano.shared(value=X_numpy.astype(theano.config.floatX), name='X', 368 | borrow=True) 369 | x = T.matrix('x', dtype=theano.config.floatX) 370 | log_ei = self.sparse_gp.compute_log_averaged_ei(x, X, randomness, 371 | incumbent) 372 | 373 | function_grid = theano.function([x], -log_ei) 374 | function_scalar = theano.function([x], -log_ei[0]) 375 | function_scalar_gradient = theano.function([x], -T.grad(log_ei[0], x)) 376 | 377 | # We optimize the ei in a greedy manner 378 | 379 | for i in range(1, q): 380 | new_point = global_optimization(grid, lower, upper, function_grid, 381 | function_scalar, 382 | function_scalar_gradient)[0] 383 | X_numpy = casting(np.concatenate([X_numpy, new_point], 0)) 384 | randomness_numpy = casting( 385 | 0 * np.random.randn(X_numpy.shape[0], n_samples).astype( 386 | theano.config.floatX)) 387 | X.set_value(X_numpy) 388 | randomness.set_value(randomness_numpy) 389 | 390 | m, v = self.predict(X_numpy, 0 * X_numpy) 391 | 392 | return X_numpy 393 | -------------------------------------------------------------------------------- /dd_vae/bo/sparse_gp_theano_internal.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import theano 5 | import theano.tensor as T 6 | from theano.tensor.slinalg import Cholesky as MatrixChol 7 | 8 | from .gauss import casting, compute_kernel, compute_psi1, compute_psi2 9 | from .psd_theano import MatrixInversePSD, LogDetPSD 10 | 11 | 12 | def n_pdf(x): 13 | return 1.0 / T.sqrt(2 * math.pi) * T.exp(-0.5 * x ** 2) 14 | 15 | 16 | def log_n_pdf(x): 17 | return -0.5 * T.log(2 * math.pi) - 0.5 * x ** 2 18 | 19 | 20 | def n_cdf(x): 21 | return 0.5 * (1.0 + T.erf(x / T.sqrt(2.0))) 22 | 23 | 24 | def 
def ratio(x):
    """
    Ratio n_cdf(x) / n_pdf(x) of the standard normal cdf and pdf.

    For x below -10 the quotient is replaced elementwise by the series
    -(1/x - 1/x^3 + 3/x^5 - 15/x^7).
    """
    in_lower_tail = T.lt(x, casting(-10))

    # Series used in the lower tail instead of the direct quotient
    tail_series = -(casting(1.0) / x - casting(1.0) / x ** 3 +
                    casting(3.0) / x ** 5 - casting(15.0) / x ** 7)

    direct_quotient = n_cdf(x) / n_pdf(x)

    return T.switch(in_lower_tail, tail_series, direct_quotient)
def compute_output(self):
    """
    Builds the symbolic predictive mean and variance of the node for the
    current inputs (self.input_means, self.input_vars) and stores them in
    self.output_means and self.output_vars. The intermediate expressions
    (Kzz, KzzInv, cavity and posterior moments, Kxz, B, ...) are kept as
    attributes because the log normalizer methods read them. Returns None.

    The value of self.set_for_training is read while the expressions are
    built, so it is baked into the resulting graph.
    """

    # We compute the output mean

    # Kernel matrix of the inducing points, with jitter (scaled by
    # exp(lsf)) added to the diagonal
    self.Kzz = compute_kernel(self.lls, self.lsf, self.z, self.z) + T.eye(
        self.z.shape[0]) * self.jitter * T.exp(self.lsf)
    self.KzzInv = MatrixInversePSD()(self.Kzz)
    LLt = T.dot(self.LParamPost, T.transpose(self.LParamPost))
    # Cavity: the LLt and mParamPost terms are scaled by
    # (n_points - set_for_training) / n_points. With set_for_training
    # equal to 0 that factor is 1 and the cavity coincides with the
    # posterior computed below.
    self.covCavityInv = self.KzzInv + LLt * casting(
        self.n_points - self.set_for_training) / casting(self.n_points)
    self.covCavity = MatrixInversePSD()(self.covCavityInv)
    self.meanCavity = T.dot(self.covCavity, casting(
        self.n_points - self.set_for_training) / casting(
        self.n_points) * self.mParamPost)
    self.KzzInvcovCavity = T.dot(self.KzzInv, self.covCavity)
    self.KzzInvmeanCavity = T.dot(self.KzzInv, self.meanCavity)
    # Posterior approximation: same expressions without the scaling
    self.covPosteriorInv = self.KzzInv + LLt
    self.covPosterior = MatrixInversePSD()(self.covPosteriorInv)
    self.meanPosterior = T.dot(self.covPosterior, self.mParamPost)
    # Cross kernel between the input means and the inducing points
    self.Kxz = compute_kernel(self.lls, self.lsf, self.input_means, self.z)
    # B = KzzInv covCavity KzzInv - KzzInv
    self.B = T.dot(self.KzzInvcovCavity, self.KzzInv) - self.KzzInv
    # exp(lsf) plus the row sums of Kxz * (Kxz B); multiplying by a column
    # of ones performs the row sum and yields one column entry per input
    v_out = T.exp(self.lsf) + T.dot(self.Kxz * T.dot(self.Kxz, self.B),
                                    T.ones_like(self.z[:, 0: 1]))

    if self.ignore_variances:

        # Deterministic inputs: only the input means are used
        self.output_means = T.dot(self.Kxz, self.KzzInvmeanCavity)
        # The second term is numerically zero.
        # NOTE(review): presumably it only keeps input_vars inside the
        # graph so that it can be replaced through `givens` -- confirm.
        self.output_vars = abs(v_out) + casting(0) * T.sum(self.input_vars)

    else:

        # Uncertain inputs: the mean uses compute_psi1 in place of Kxz
        self.EKxz = compute_psi1(self.lls, self.lsf, self.input_means,
                                 self.input_vars, self.z)
        self.output_means = T.dot(self.EKxz, self.KzzInvmeanCavity)

        # In other layers we have to compute the expected variance

        # Outer product of KzzInv meanCavity with itself
        self.B2 = T.outer(T.dot(self.KzzInv, self.meanCavity),
                          T.dot(self.KzzInv, self.meanCavity))

        # Hard-coded switch: the approximate branch below is never taken
        exact_output_vars = True

        if exact_output_vars:

            # We compute the exact output variance

            self.psi2 = compute_psi2(self.lls, self.lsf, self.z,
                                     self.input_means, self.input_vars)
            # Per-input outer products of the rows of EKxz (ll) and of
            # Kxz (kk), transposed so that the input index comes last
            ll = T.transpose(self.EKxz[:, None, :] * self.EKxz[:, :, None],
                             [1, 2, 0])
            kk = T.transpose(self.Kxz[:, None, :] * self.Kxz[:, :, None],
                             [1, 2, 0])
            # B2 and B are broadcast along the last axis and the first two
            # axes are summed out, leaving one column entry per input
            v1 = T.transpose(T.sum(
                T.sum(T.shape_padaxis(self.B2, 2) * (self.psi2 - ll), 0),
                0, keepdims=True))
            v2 = T.transpose(T.sum(
                T.sum(T.shape_padaxis(self.B, 2) * (self.psi2 - kk), 0), 0,
                keepdims=True))

        else:

            # We compute the approximate output variance using
            # the unscented kalman filter

            v1 = 0
            v2 = 0

            n = self.input_d
            for j in range(1, n + 1):
                # Shift the input means along input dimension j - 1 only,
                # by plus and minus sqrt(n * input_vars)
                mask = T.zeros_like(self.input_vars)
                mask = T.set_subtensor(mask[:, j - 1], 1)
                inc = mask * T.sqrt(casting(n) * self.input_vars)
                # Kernels at the shifted points, scaled by sqrt(1 / (2 n))
                self.kplus = T.sqrt(
                    casting(1.0) / casting(2 * n)) * compute_kernel(
                    self.lls, self.lsf, self.input_means + inc, self.z)
                self.kminus = T.sqrt(
                    casting(1.0) / casting(2 * n)) * compute_kernel(
                    self.lls, self.lsf, self.input_means - inc, self.z)

                v1 += T.dot(self.kplus * T.dot(self.kplus, self.B2),
                            T.ones_like(self.z[:, 0: 1]))
                v1 += T.dot(self.kminus * T.dot(self.kminus, self.B2),
                            T.ones_like(self.z[:, 0: 1]))
                v2 += T.dot(self.kplus * T.dot(self.kplus, self.B),
                            T.ones_like(self.z[:, 0: 1]))
                v2 += T.dot(self.kminus * T.dot(self.kminus, self.B),
                            T.ones_like(self.z[:, 0: 1]))

            # Subtract the same quadratic forms evaluated with EKxz and
            # with Kxz
            v1 -= T.dot(self.EKxz * T.dot(self.EKxz, self.B2),
                        T.ones_like(self.z[:, 0: 1]))
            v2 -= T.dot(self.Kxz * T.dot(self.Kxz, self.B),
                        T.ones_like(self.z[:, 0: 1]))

        # Absolute values keep every contribution non-negative
        self.output_vars = abs(v_out) + abs(v2) + abs(v1)

    # The noise variance exp(lvar_noise) is added in both cases
    self.output_vars = self.output_vars + T.exp(self.lvar_noise)

    return
def getContributionToEnergy(self):
    """
    Returns the contribution of the node to the energy (see the last
    equation of Sec. 4 in http://arxiv.org/pdf/1602.04133.pdf v1).

    compute_output must have been called before, since the cavity and
    posterior expressions are read here.
    """

    assert self.n_points is not None
    assert self.covCavity is not None
    assert self.covPosterior is not None
    assert self.input_means is not None

    logZpost = self.getLogNormalizerPosterior()
    logZprior = self.getLogNormalizerPrior()
    logZcav = self.getLogNormalizerCavity()

    # The normalizer terms are normalized by the total number of points
    # (n_points) ...

    n_total = casting(self.n_points)
    normalizer_term = ((logZcav - logZpost) +
                       logZpost / n_total -
                       logZprior / n_total)

    # ... and multiplied by the size of the minibatch

    minibatch_size = T.cast(self.input_means.shape[0], 'float32')

    return normalizer_term * minibatch_size + T.sum(self.getLogZ())
def setForTraining(self):
    """
    Undoes the work done by setForPrediction: switches the node back to
    training mode by restoring set_for_training to 1.

    Only the flag is changed. Expressions that were already built read
    the flag at construction time (see compute_output), so they have to
    be rebuilt for the change to take effect.
    """

    # We only do something if the node was set for prediction
    # instead of training

    if self.set_for_training == casting(0.0):
        # This used to be a comparison (==), which left the flag
        # untouched, so the node could never go back to training mode
        self.set_for_training = casting(1.0)
T.dot(self.LParamPost, T.transpose(self.LParamPost)) 363 | covCavityInv = KzzInv + LLt * casting( 364 | self.n_points - self.set_for_training) / casting(self.n_points) 365 | covCavity = MatrixInversePSD()(covCavityInv) 366 | meanCavity = T.dot(covCavity, casting( 367 | self.n_points - self.set_for_training) / casting( 368 | self.n_points) * self.mParamPost) 369 | KzzInvcovCavity = T.dot(KzzInv, covCavity) 370 | KzzInvmeanCavity = T.dot(KzzInv, meanCavity) 371 | Kxz = compute_kernel(self.lls, self.lsf, x, self.z) 372 | B = T.dot(KzzInvcovCavity, KzzInv) - KzzInv 373 | v_out = T.exp(self.lsf) + T.dot(Kxz * T.dot(Kxz, B), T.ones_like( 374 | self.z[:, 0: 1])) # + T.exp(self.lvar_noise) 375 | m_out = T.dot(Kxz, KzzInvmeanCavity) 376 | s = (incumbent - m_out) / T.sqrt(v_out) 377 | 378 | log_ei = T.log( 379 | (incumbent - m_out) * ratio(s) + T.sqrt(v_out)) + log_n_pdf(s) 380 | 381 | return log_ei 382 | 383 | def compute_log_averaged_ei(self, x, X, randomness, incumbent): 384 | 385 | # We compute the old predictive mean at x 386 | 387 | Kzz = compute_kernel(self.lls, self.lsf, self.z, self.z) + T.eye( 388 | self.z.shape[0]) * self.jitter * T.exp(self.lsf) 389 | KzzInv = MatrixInversePSD()(Kzz) 390 | LLt = T.dot(self.LParamPost, T.transpose(self.LParamPost)) 391 | covCavityInv = KzzInv + LLt * casting( 392 | self.n_points - self.set_for_training) / casting(self.n_points) 393 | covCavity = MatrixInversePSD()(covCavityInv) 394 | meanCavity = T.dot(covCavity, casting( 395 | self.n_points - self.set_for_training) / casting( 396 | self.n_points) * self.mParamPost) 397 | KzzInvmeanCavity = T.dot(KzzInv, meanCavity) 398 | Kxz = compute_kernel(self.lls, self.lsf, x, self.z) 399 | m_old_x = T.dot(Kxz, KzzInvmeanCavity) 400 | 401 | # We compute the old predictive mean at X 402 | 403 | KXz = compute_kernel(self.lls, self.lsf, X, self.z) 404 | 405 | # We compute the required cross covariance matrices 406 | 407 | KXX = compute_kernel(self.lls, self.lsf, X, X) - T.dot( 408 | T.dot(KXz, 
def max_ring_penalty(mol):
    """
    Penalty for large rings: minus the number of atoms by which the
    largest cycle in the cycle basis of the molecular graph exceeds six.
    The result is 0 when there are no cycles or when no cycle has more
    than six atoms.
    """
    molecular_graph = nx.Graph(rdmolops.GetAdjacencyMatrix(mol))
    largest_ring = max(
        (len(ring) for ring in nx.cycle_basis(molecular_graph)),
        default=0)
    return -max(0, largest_ring - 6)
class Proposal:
    '''
    Base class for proposal distributions.

    Subclasses implement `density` (the density of the standardized
    variable, as a function of numpy input) and the KL divergences `kl`
    and `kl_uniform`. The default `sample` draws standardized values on
    [-1, 1] by rejection sampling and then shifts and scales them.
    '''
    def __init__(self, eps=1e-9):
        # Accepted standardized samples that were not consumed yet
        self.buffer = np.zeros(0)
        # Small constant added inside logarithms by the subclasses
        self.eps = eps

    def density(self, z):
        '''
        Density of the standardized proposal at z
        '''
        raise NotImplementedError

    def kl(self, m=0, s=1):
        raise NotImplementedError

    def kl_uniform(self, m=0, s=1):
        raise NotImplementedError

    def sample(self, m, s):
        '''
        Returns a tensor `z * s + m` with the shape, dtype and device of
        `m`, where z is drawn from `density` by rejection sampling under
        a uniform envelope on [-1, 1] of height density(0). This assumes
        that the density is supported on [-1, 1] and is maximal at 0.

        `m` may have any shape, including zero dimensions.
        '''
        # np.prod of an empty shape is the float 1.0, which can not be
        # used as a slice bound, hence the explicit int
        batch_size = int(np.prod(m.shape))
        uniform_height = self.density(0)
        # Area under the density (1) over the area of the envelope
        acceptance_rate = 0.5 / uniform_height
        up_batch_size = int(batch_size / acceptance_rate) + 1
        while self.buffer.shape[0] < batch_size:
            sample, rejection_sample = np.split(
                np.random.rand(2*up_batch_size), 2
            )
            sample = sample * 2 - 1
            rejection_sample = rejection_sample * uniform_height
            density = self.density(sample)
            sample = sample[rejection_sample < density]
            self.buffer = np.concatenate((self.buffer, sample), 0)

        sample = self.buffer[:batch_size]
        self.buffer = self.buffer[batch_size:]
        # A tuple is passed so that a zero-dimensional `m` works as well
        sample = sample.reshape(tuple(m.shape))
        sample = torch.tensor(sample, dtype=m.dtype, device=m.device)
        return sample * s + m
class TriangularProposal(Proposal):
    '''
    Proposal with the triangular density 1 - |z|
    '''
    def density(self, z):
        return 1 - np.abs(z)

    def kl(self, m, s):
        '''
        KL term used with the gaussian prior, summed over dimension 1
        '''
        log_scale = torch.log(s+self.eps)
        gaussian_constant = 0.5*np.log(2*np.pi)
        per_dimension = (0.5 * m**2 + s**2/12 - log_scale +
                         gaussian_constant - 0.5)
        return per_dimension.sum(1)

    def kl_uniform(self, m, s):
        '''
        KL term used with the uniform prior, summed over dimension 1
        '''
        log_scale = torch.log(s+self.eps)
        per_dimension = -0.5 + np.log(2) - log_scale
        return per_dimension.sum(1)
def combine_loss(loss_components, weights):
    '''
    Returns the weighted sum of the loss components named in `weights`:
    the sum of loss_components[name] * weights[name] over all names.
    Components without a weight are ignored.

    Raises ValueError if `weights` is empty.
    '''
    if not len(weights):
        raise ValueError("Specify at least one weight")

    total = 0
    for name, weight in weights.items():
        total += loss_components[name] * weight
    return total
class SpecialTokens:
    """
    Marker tokens appended to every vocabulary. They have to be four
    distinct strings: with identical (for example empty) tokens the
    bos, eos, pad and unk ids of CharVocab would all coincide.
    """
    bos = '<bos>'
    eos = '<eos>'
    pad = '<pad>'
    unk = '<unk>'


class CharVocab:
    """
    Character-level vocabulary: maps characters to integer ids and back.
    The sorted characters come first, followed by the bos, eos, pad and
    unk tokens of `st`.
    """
    @classmethod
    def from_data(cls, data, *args, **kwargs):
        """
        Builds a vocabulary from all characters found in the strings
        of `data`; extra arguments are passed to the constructor
        """
        chars = set()
        for string in data:
            chars.update(string)

        return cls(chars, *args, **kwargs)

    def __init__(self, chars, st=SpecialTokens):
        # Any iterable of characters is accepted. Going through a set
        # removes duplicates (which would make c2i and i2c disagree) and
        # turns the checks below into membership tests even when `chars`
        # is a plain string
        chars = set(chars)
        special = [st.bos, st.eos, st.pad, st.unk]
        if any(token in chars for token in special):
            raise ValueError('SpecialTokens in chars')

        all_syms = sorted(chars) + special

        self.st = st
        self.c2i = {c: i for i, c in enumerate(all_syms)}
        self.i2c = dict(enumerate(all_syms))

    def __len__(self):
        return len(self.c2i)

    @property
    def bos(self):
        return self.c2i[self.st.bos]

    @property
    def eos(self):
        return self.c2i[self.st.eos]

    @property
    def pad(self):
        return self.c2i[self.st.pad]

    @property
    def unk(self):
        return self.c2i[self.st.unk]

    def char2id(self, char):
        """
        Id of `char`, or the unk id for unknown characters
        """
        return self.c2i.get(char, self.unk)

    def id2char(self, id):
        """
        Symbol with the given id, or the unk token for unknown ids
        """
        return self.i2c.get(id, self.st.unk)

    def string2ids(self, string, add_bos=False, add_eos=False):
        """
        Converts a string to a list of ids, optionally wrapped in
        bos / eos
        """
        ids = [self.char2id(c) for c in string]

        if add_bos:
            ids = [self.bos] + ids
        if add_eos:
            ids = ids + [self.eos]

        return ids

    def ids2string(self, ids, rem_bos=True, rem_eos=True):
        """
        Converts a sequence of ids to a string, optionally dropping a
        leading bos and a trailing eos
        """
        # The length is checked before every index access: removing bos
        # from a sequence that contains nothing else leaves it empty
        if rem_bos and len(ids) > 0 and ids[0] == self.bos:
            ids = ids[1:]
        if rem_eos and len(ids) > 0 and ids[-1] == self.eos:
            ids = ids[:-1]

        return ''.join([self.id2char(id) for id in ids])
def collate(batch, pad, return_data=False):
    '''
    Merges a list of (with_bos, with_eos, data) items into a batch.

    Items are reordered by decreasing length of the with_bos sequence
    and both kinds of sequences are padded with `pad` (sequence
    dimension first). Returns (with_bos, with_eos, lengths), followed
    by a numpy array with the reordered data if `return_data` is set.
    '''
    bos_sequences, eos_sequences, raw_data = zip(*batch)

    # Longest sequence first
    order = np.argsort([len(seq) for seq in bos_sequences])[::-1]
    sorted_lengths = [len(bos_sequences[i]) for i in order]

    pad_sequence = torch.nn.utils.rnn.pad_sequence
    padded_bos = pad_sequence(
        [bos_sequences[i] for i in order], padding_value=pad
    )
    padded_eos = pad_sequence(
        [eos_sequences[i] for i in order], padding_value=pad
    )

    result = (padded_bos, padded_eos, sorted_lengths)
    if return_data:
        result += (np.array(raw_data)[order],)
    return result
""" 32 | raise NotImplementedError 33 | 34 | def decode(self, batch, z=None, state=None): 35 | """ 36 | Decodes batch and returns logits and intermediate states 37 | """ 38 | raise NotImplementedError 39 | 40 | def compute_metrics(self, batch, logits): 41 | return {} 42 | 43 | def language_model_nll(self, with_eos, logits): 44 | loss = nn.NLLLoss(ignore_index=self.vocab.pad)( 45 | logits.transpose(1, 2), with_eos) 46 | return loss 47 | 48 | def encoder_parameters(self): 49 | raise NotImplementedError 50 | 51 | def decoder_parameters(self): 52 | raise NotImplementedError 53 | 54 | def get_mu_std(self, z): 55 | dim = z.shape[1] // 2 56 | mu, logstd = z.split(dim, 1) 57 | std = logstd.exp() 58 | 59 | if self.prior == 'uniform': 60 | left = torch.sigmoid(mu - std) 61 | right = torch.sigmoid(mu + std) 62 | mu = (right + left) - 1 63 | std = (right - left) 64 | 65 | return mu, std 66 | 67 | def sample_kl(self, z, mu_only=False): 68 | mu, std = self.get_mu_std(z) 69 | if self.prior == 'gaussian': 70 | kl_loss = self.proposal.kl(mu, std).mean() 71 | elif self.prior == 'uniform': 72 | kl_loss = self.proposal.kl_uniform(mu, std).mean() 73 | else: 74 | raise ValueError 75 | 76 | if mu_only: 77 | sample = mu 78 | else: 79 | sample = self.proposal.sample(mu, std) 80 | return sample, kl_loss 81 | 82 | def argmax_nll(self, batch, logits, temperature): 83 | raise NotImplementedError 84 | 85 | def sample_nll(self, batch, logits): 86 | raise NotImplementedError 87 | 88 | def get_loss_components(self, batch, temperature): 89 | batch = batch_to_device(batch, self.device) 90 | z = self.encode(batch) 91 | sample, kl_loss = self.sample_kl(z, not self.variational) 92 | logits, _ = self.decode(batch, z=sample) 93 | metrics = self.compute_metrics(batch, logits) 94 | language_model_nll = self.sample_nll(batch, logits) 95 | argmax_nll = self.argmax_nll(batch, logits, temperature) 96 | loss_components = { 97 | 'sample_nll': language_model_nll, 98 | 'kl_loss': kl_loss, 99 | 'argmax_nll': 
class VAE_MNIST(VAE):
    """
    Fully connected VAE for images with pixel values in [0, 1].

    The encoder flattens the image and outputs 2 * latent_size values
    per object; the decoder mirrors the encoder and ends with a sigmoid,
    so the "logits" handled by the methods below are probabilities.
    """
    def __init__(self, layer_sizes, latent_size,
                 proposal='tricube',
                 prior='gaussian',
                 variational=True,
                 image_size=28,
                 channels=1):
        super().__init__(prior=prior, proposal=proposal)
        self.config.update({
            'layer_sizes': layer_sizes,
            'latent_size': latent_size,
            'variational': variational,
            'image_size': image_size,
            'channels': channels
        })

        # The layer sizes follow image_size and channels instead of the
        # previously hard-coded 784 = 1 * 28 * 28. The defaults give
        # exactly the same network as before
        input_dim = channels * image_size * image_size

        self.encoder = nn.Sequential(
            Reshape(input_dim),
            *self.DNN(input_dim, *layer_sizes, 2*latent_size)
        )

        self.decoder = nn.Sequential(
            *self.DNN(latent_size, *layer_sizes[::-1], input_dim),
            Reshape(channels, image_size, image_size),
            nn.Sigmoid()
        )
        self.latent_size = latent_size
        self.variational = variational

    @staticmethod
    def DNN(*layers):
        """
        Returns a list of linear layers of the given sizes with
        LeakyReLU between them (no activation after the last layer)
        """
        net = []
        last = len(layers) - 2
        for i, (n_in, n_out) in enumerate(zip(layers, layers[1:])):
            net.append(nn.Linear(n_in, n_out))
            if i != last:
                net.append(nn.LeakyReLU())
        return net

    def compute_metrics(self, batch, logits):
        """
        Accuracy of the reconstruction thresholded at 0.5: per pixel,
        per image, and per image with fewer than 10 wrong pixels
        """
        images, _ = batch
        match = (images.long() == (logits > 0.5).long())
        match = match.view(match.shape[0], -1).float()
        return {
            'pixel_accuracy': match.mean(),
            'image_accuracy': match.min(1)[0].mean(),
            'image_accuracy@10': ((1-match).sum(1) < 10).float().mean()
        }

    def encoder_parameters(self):
        return self.encoder.parameters()

    def decoder_parameters(self):
        return self.decoder.parameters()

    def sample_nll(self, batch, logits):
        """
        Binary cross entropy between the reconstruction and the images
        """
        images, _ = batch
        return torch.nn.BCELoss()(logits, images)

    def encode(self, batch):
        image, _ = batch
        return self.encoder(image.float())

    def decode(self, batch, z=None, state=None):
        # `batch` and `state` are not used; None takes the place of the
        # state returned by decoders that have one
        return self.decoder(z), None

    def argmax_nll(self, batch, logits, temperature):
        """
        Mean smoothed log indicator of the margin between the
        probability of the correct pixel value and of the wrong one
        """
        images, _ = batch
        p_correct = logits*images + (1 - logits)*(1 - images)
        delta = p_correct - (1 - p_correct)
        reconstruction_loss = smoothed_log_indicator(delta, temperature).mean()
        return reconstruction_loss

    def sample(self, batch_size=1, z=None):
        """
        Decodes `z`, or `batch_size` codes drawn from the prior when
        `z` is None
        """
        if z is None:
            if self.prior == 'gaussian':
                z = torch.randn(batch_size, self.latent_size)
            elif self.prior == 'uniform':
                z = torch.rand(batch_size, self.latent_size)*2 - 1
        z = z.to(self.device)
        return self.decoder(z)
@staticmethod
def get_fc(layers, input_dim, output_dim, fc_norm):
    """
    Build the fully connected map from ``input_dim`` to ``output_dim``.

    Parameters:
        layers: None for a single Linear layer, otherwise a sequence
            (list or tuple) of hidden layer widths
        input_dim: input feature size
        output_dim: output feature size
        fc_norm: if True, add a LayerNorm after every hidden Linear layer

    Returns:
        ``nn.Linear`` when ``layers`` is None, otherwise an
        ``nn.Sequential`` of Linear (+ LayerNorm) + ELU hidden blocks
        followed by a final Linear layer.
    """
    if layers is None:
        return nn.Linear(input_dim, output_dim)
    # Unpacking accepts any sequence; `[input_dim] + layers` raised
    # TypeError for tuples, which is what a "(a, b)" config value
    # becomes after ast.literal_eval.
    sizes = [input_dim, *layers, output_dim]
    network = []
    for in_size, out_size in zip(sizes[:-2], sizes[1:-1]):
        network.append(nn.Linear(in_size, out_size))
        if fc_norm:
            network.append(nn.LayerNorm(out_size))
        network.append(nn.ELU())
    # The output layer has no normalization or activation.
    network.append(nn.Linear(sizes[-2], sizes[-1]))
    return nn.Sequential(*network)
def compute_metrics(self, batch, logits):
    """
    Compute reconstruction accuracies of the argmax predictions.

    Parameters:
        batch: ``(with_bos, with_eos, lengths)``; only ``with_eos`` (the
            target token ids) is used
        logits: per-token scores whose last (third) dimension indexes
            the vocabulary

    Returns:
        dict with 'string_accuracy' (share of sequences with every
        non-padding token predicted correctly) and 'character_accuracy'
        (share of non-padding tokens predicted correctly)
    """
    _, targets, _ = batch
    is_pad = targets == self.vocab.pad
    real_tokens = (~is_pad).float()
    hits = torch.argmax(logits, 2) == targets
    # Padding positions count as correct, so the minimum over dimension 0
    # is 1 only when every real token of a sequence matches
    # (dimension 0 is presumably the sequence axis; confirm with caller).
    per_string = (hits | is_pad).float().min(0)[0]
    matched_tokens = (hits.float() * real_tokens).sum()
    return {
        'string_accuracy': per_string.mean(),
        'character_accuracy': matched_tokens / real_tokens.sum()
    }
def sample(self, batch_size=1, max_len=100, mode='argmax',
           z=None, keep_stats=False,
           temperature=1):
    """
    Generate strings by decoding latent codes one token at a time.

    Parameters:
        batch_size: number of codes to draw when ``z`` is None; otherwise
            replaced by ``z.shape[0]``
        max_len: number of decoding steps
        mode: 'argmax' takes the most likely token at each step,
            'sample' draws from the predicted distribution
        z: optional latent codes; drawn from the prior when None
        keep_stats: if True, also return the per-step decoder outputs
        temperature: accepted for interface compatibility; NOTE(review):
            it is not used by the sampling code below

    Returns:
        list of generated strings cut at the first end-of-sequence
        character, or ``(strings, stats)`` when ``keep_stats`` is True

    Raises:
        ValueError: for an unsupported ``mode``, or when ``z`` is None and
            the prior is neither 'gaussian' nor 'uniform'
    """
    if mode not in ['sample', 'argmax']:
        raise ValueError("Can either sample or argmax")
    generated_sequence = []
    if z is None:
        if self.prior == 'gaussian':
            z = torch.randn(batch_size, self.latent_size)
        elif self.prior == 'uniform':
            z = torch.rand(batch_size, self.latent_size)*2 - 1
        else:
            # Previously fell through and failed on `None.shape`.
            raise ValueError(
                f"Cannot sample from unknown prior {self.prior!r}")
    batch_size = z.shape[0]
    # Every sequence starts from the begin-of-sequence token.
    character = [[self.vocab.bos for _ in range(batch_size)]]
    character = torch.tensor(character, dtype=torch.long,
                             device=self.device)
    # Same latent -> initial GRU state mapping as in `decode`.
    h = self.latent_to_decoder(z.to(self.device))
    h = h.view(
        h.shape[0], self.num_layers, -1
    ).transpose(0, 1).contiguous()
    stats = [] if keep_stats else None
    # One token per sequence is fed at each step.
    lengths = [1]*batch_size
    for _ in range(max_len):
        batch = (character, None, lengths)
        logits, h = self.decode(batch, state=h)
        if keep_stats:
            stats.append([logits.detach().cpu().numpy()])
        if mode == 'argmax':
            character = torch.argmax(logits[0], 1)
        else:
            # `decode` returns log-probabilities, hence the exp.
            character = torch.distributions.Categorical(
                torch.exp(logits[0])).sample()
        character = character.detach()[None, :]
        generated_sequence.append(character.cpu().numpy())
    # Stack the steps and transpose to one row of token ids per string.
    generated_sequence = np.concatenate(generated_sequence, 0).T
    samples = [self.vocab.ids2string(s) for s in generated_sequence]
    eos = self.vocab.i2c[self.vocab.eos]
    samples = [x.split(eos)[0] for x in samples]
    if keep_stats:
        return samples, stats
    return samples
" 43 | ] 44 | }, 45 | "metadata": { 46 | "needs_background": "light" 47 | }, 48 | "output_type": "display_data" 49 | } 50 | ], 51 | "source": [ 52 | "fig = plt.figure()\n", 53 | "ax = plt.subplot(111)\n", 54 | "\n", 55 | "x = np.linspace(-1, 1, 1000)\n", 56 | "i = 0\n", 57 | "for name in get_proposals():\n", 58 | " if name == 'gaussian':\n", 59 | " continue\n", 60 | " y = get_proposals()[name]().density(x)\n", 61 | " ax.plot(x, y, label=name.title(), linewidth=2)\n", 62 | " i += 1\n", 63 | "\n", 64 | "box = ax.get_position()\n", 65 | "ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])\n", 66 | "\n", 67 | "lgd = ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=16)\n", 68 | "plt.savefig('images/kernels.pdf', bbox_extra_artists=(lgd,), bbox_inches='tight')\n", 69 | "plt.show()" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 4, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "data": { 79 | "image/png": "\n", 80 | "text/plain": [ 81 | "
" 82 | ] 83 | }, 84 | "metadata": { 85 | "needs_background": "light" 86 | }, 87 | "output_type": "display_data" 88 | } 89 | ], 90 | "source": [ 91 | "x = np.linspace(-1, 1, 1000)\n", 92 | "plt.plot([0, 0], [0, 1], ':')\n", 93 | "for temp in [0.5, 0.1, 0.01]:\n", 94 | " plt.plot(x, np.exp(-smoothed_log_indicator(torch.tensor(x), temp)).numpy(), label=f'$\\\\tau$={temp}')\n", 95 | "plt.legend(fontsize=16)\n", 96 | "plt.tight_layout()\n", 97 | "plt.savefig('images/smoothed_indicator.pdf')" 98 | ] 99 | } 100 | ], 101 | "metadata": { 102 | "kernelspec": { 103 | "display_name": "Python 3", 104 | "language": "python", 105 | "name": "python3" 106 | }, 107 | "language_info": { 108 | "codemirror_mode": { 109 | "name": "ipython", 110 | "version": 3 111 | }, 112 | "file_extension": ".py", 113 | "mimetype": "text/x-python", 114 | "name": "python", 115 | "nbconvert_exporter": "python", 116 | "pygments_lexer": "ipython3", 117 | "version": "3.6.9" 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 2 122 | } 123 | -------------------------------------------------------------------------------- /images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/.DS_Store -------------------------------------------------------------------------------- /images/kernels.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/kernels.pdf -------------------------------------------------------------------------------- /images/mnist/latent_N_N.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/mnist/latent_N_N.png 
-------------------------------------------------------------------------------- /images/mnist/latent_U_U.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/mnist/latent_U_U.png -------------------------------------------------------------------------------- /images/moses_FCD.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/moses_FCD.pdf -------------------------------------------------------------------------------- /images/moses_SNN.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/moses_SNN.pdf -------------------------------------------------------------------------------- /images/smoothed_indicator.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/smoothed_indicator.pdf -------------------------------------------------------------------------------- /images/synthetic/N_N.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/synthetic/N_N.png -------------------------------------------------------------------------------- /images/synthetic/U_T.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/synthetic/U_T.png -------------------------------------------------------------------------------- /images/synthetic/U_U.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/synthetic/U_U.png -------------------------------------------------------------------------------- /images/zinc/DD_VAE_GAUSSIAN_molecule_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/DD_VAE_GAUSSIAN_molecule_0.pdf -------------------------------------------------------------------------------- /images/zinc/DD_VAE_GAUSSIAN_molecule_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/DD_VAE_GAUSSIAN_molecule_1.pdf -------------------------------------------------------------------------------- /images/zinc/DD_VAE_GAUSSIAN_molecule_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/DD_VAE_GAUSSIAN_molecule_2.pdf -------------------------------------------------------------------------------- /images/zinc/DD_VAE_GAUSSIAN_top50_molecules.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/DD_VAE_GAUSSIAN_top50_molecules.pdf -------------------------------------------------------------------------------- /images/zinc/DD_VAE_TRICUBE_molecule_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/DD_VAE_TRICUBE_molecule_0.pdf 
-------------------------------------------------------------------------------- /images/zinc/DD_VAE_TRICUBE_molecule_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/DD_VAE_TRICUBE_molecule_1.pdf -------------------------------------------------------------------------------- /images/zinc/DD_VAE_TRICUBE_molecule_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/DD_VAE_TRICUBE_molecule_2.pdf -------------------------------------------------------------------------------- /images/zinc/DD_VAE_TRICUBE_top50_molecules.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/DD_VAE_TRICUBE_top50_molecules.pdf -------------------------------------------------------------------------------- /images/zinc/VAE_GAUSSIAN_molecule_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/VAE_GAUSSIAN_molecule_0.pdf -------------------------------------------------------------------------------- /images/zinc/VAE_GAUSSIAN_molecule_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/VAE_GAUSSIAN_molecule_1.pdf -------------------------------------------------------------------------------- /images/zinc/VAE_GAUSSIAN_molecule_2.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/VAE_GAUSSIAN_molecule_2.pdf -------------------------------------------------------------------------------- /images/zinc/VAE_GAUSSIAN_top50_molecules.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/VAE_GAUSSIAN_top50_molecules.pdf -------------------------------------------------------------------------------- /images/zinc/VAE_TRICUBE_molecule_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/VAE_TRICUBE_molecule_0.pdf -------------------------------------------------------------------------------- /images/zinc/VAE_TRICUBE_molecule_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/VAE_TRICUBE_molecule_1.pdf -------------------------------------------------------------------------------- /images/zinc/VAE_TRICUBE_molecule_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/VAE_TRICUBE_molecule_2.pdf -------------------------------------------------------------------------------- /images/zinc/VAE_TRICUBE_top50_molecules.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insilicomedicine/DD-VAE/13498d098bae2c8177abec61ab80d8b618d274f3/images/zinc/VAE_TRICUBE_top50_molecules.pdf -------------------------------------------------------------------------------- /moses_prepare_metrics.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Distribution learning on MOSES dataset: calculating metrics\n", 8 | "\n", 9 | "You can calculate metrics from checkpoints using this notebook. Note that training the models takes ~30h per model on Titan X (Pascal); computing MOSES metrics for checkpoints takes ~40h per model.\n", 10 | "\n", 11 | "To reproduce the models and statistics, run the following bash script:\n", 12 | "```{bash}\n", 13 | "for SEED in 1 2 3\n", 14 | "do\n", 15 | " for PROPOSAL in gaussian triweight\n", 16 | " do\n", 17 | " python train.py --config configs/moses/VAE_$PROPOSAL\\_seed$SEED.ini --device cuda:0\n", 18 | " python train.py --config configs/moses/DD-VAE_$PROPOSAL\\_seed$SEED.ini --device cuda:0\n", 19 | " done\n", 20 | "done\n", 21 | "```\n", 22 | "\n", 23 | "This script will save models into `models/moses` folder and tensorboard logs into `logs/moses` folder.\n", 24 | "\n", 25 | "The notebook below will create files with all MOSES metrics for each checkpoint. `moses_plots.ipynb` will use logs and MOSES metrics to build final plots." 
26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 1, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stderr", 35 | "output_type": "stream", 36 | "text": [ 37 | "RDKit WARNING: [11:32:06] Enabling RDKit 2019.09.3 jupyter extensions\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "import os\n", 43 | "import glob\n", 44 | "import pickle\n", 45 | "import gc\n", 46 | "from time import sleep\n", 47 | "\n", 48 | "import rdkit\n", 49 | "import pandas as pd\n", 50 | "from tqdm.auto import tqdm\n", 51 | "import numpy as np\n", 52 | "import torch\n", 53 | "from moses.metrics import get_all_metrics\n", 54 | "\n", 55 | "from dd_vae.vae_rnn import VAE_RNN\n", 56 | "from dd_vae.utils import prepare_seed\n", 57 | "\n", 58 | "rdkit.rdBase.DisableLog('rdApp.*')" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 2, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "DEVICE = 'cuda:0'\n", 68 | "N_JOBS = 32\n", 69 | "\n", 70 | "def load_csv(path):\n", 71 | " df = pd.read_csv(path, compression='gzip', dtype='str', header=None)\n", 72 | " return list(df[0].values)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "test = load_csv('data/moses/test.csv.gz')\n", 82 | "test_scaffolds = load_csv('data/moses/test_scaffolds.csv.gz')\n", 83 | "train = load_csv('data/moses/train.csv.gz')\n", 84 | "\n", 85 | "test_stats = np.load('data/moses/test_stats.npz', allow_pickle=True)['stats'].item()\n", 86 | "test_scaffold_stats = np.load('data/moses/test_scaffolds_stats.npz', allow_pickle=True)['stats'].item()" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 4, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "def prepare_metrics(name, checkpoint_id, overwrite=False, device='cpu', n_jobs=1):\n", 96 | " path = f'models/moses/{name}/checkpoint_{checkpoint_id}.pt'\n", 97 | " output_path = 
f'metrics/{name}/{checkpoint_id}.pkl'\n", 98 | " os.makedirs(f'metrics/{name}/', exist_ok=True)\n", 99 | " if os.path.exists(output_path) and not overwrite:\n", 100 | " raise ValueError(f\"Metrics file {output_path} already exists\")\n", 101 | " model = VAE_RNN.load(path).to(device)\n", 102 | " prepare_seed(1)\n", 103 | " with torch.no_grad():\n", 104 | " smiles = sum([model.sample(100) for _ in tqdm(range(300))], [])\n", 105 | " model.to(device)\n", 106 | " del model\n", 107 | " torch.cuda.empty_cache()\n", 108 | " gc.collect()\n", 109 | " torch.cuda.empty_cache()\n", 110 | " if device == 'cpu':\n", 111 | " gpu = -1\n", 112 | " else:\n", 113 | " gpu = int(device.split(':')[1])\n", 114 | "\n", 115 | " metrics = get_all_metrics(\n", 116 | " test=test, gen=smiles,\n", 117 | " test_scaffolds=test_scaffolds, gpu=gpu, n_jobs=n_jobs,\n", 118 | " ptest=test_stats,\n", 119 | " ptest_scaffolds=test_scaffold_stats,\n", 120 | " train=train)\n", 121 | "\n", 122 | " with open(output_path, 'wb') as f:\n", 123 | " pickle.dump(metrics, f)\n", 124 | " torch.cuda.empty_cache()\n", 125 | " gc.collect()\n", 126 | " torch.cuda.empty_cache()" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 5, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "def get_epoch_id(path):\n", 136 | " try:\n", 137 | " return int(path.split('_')[-1][:-3])\n", 138 | " except ValueError:\n", 139 | " return None" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 6, 145 | "metadata": { 146 | "scrolled": false 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "checkpoints = glob.glob('models/moses/*/*.pt')\n", 151 | "checkpoints = [x for x in checkpoints if get_epoch_id(x) is not None]\n", 152 | "for checkpoint in checkpoints:\n", 153 | " try:\n", 154 | " epoch_id = int(checkpoint.split('_')[-1][:-3])\n", 155 | " except ValueError:\n", 156 | " continue\n", 157 | " config_id = checkpoint.split('/')[-2]\n", 158 | " print(f\"Processing 
{checkpoint}\")\n", 159 | " try:\n", 160 | " prepare_metrics(config_id, epoch_id, device=DEVICE, n_jobs=N_JOBS)\n", 161 | " except ValueError:\n", 162 | " pass" 163 | ] 164 | } 165 | ], 166 | "metadata": { 167 | "kernelspec": { 168 | "display_name": "Python 3", 169 | "language": "python", 170 | "name": "python3" 171 | }, 172 | "language_info": { 173 | "codemirror_mode": { 174 | "name": "ipython", 175 | "version": 3 176 | }, 177 | "file_extension": ".py", 178 | "mimetype": "text/x-python", 179 | "name": "python", 180 | "nbconvert_exporter": "python", 181 | "pygments_lexer": "ipython3", 182 | "version": "3.6.9" 183 | } 184 | }, 185 | "nbformat": 4, 186 | "nbformat_minor": 2 187 | } 188 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup( 5 | name='dd_vae', 6 | packages=find_packages(), 7 | python_requires='>=3.5.0', 8 | version='0.1', 9 | install_requires=[ 10 | 'tqdm', 'numpy', 11 | 'pandas', 'scipy', 12 | 'torch', 'networkx', 13 | 'Theano' 14 | ], 15 | description=('DD-VAE'), 16 | ) 17 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | import sys 4 | from functools import partial 5 | from configparser import ConfigParser 6 | from tqdm import tqdm 7 | import os 8 | import pandas as pd 9 | 10 | import torch 11 | from torch.utils.data import DataLoader 12 | from tensorboardX import SummaryWriter 13 | from torchvision.datasets import MNIST 14 | 15 | from dd_vae.vae_mnist import VAE_MNIST 16 | from dd_vae.vae_rnn import VAE_RNN 17 | from dd_vae.utils import CharVocab, collate, StringDataset, \ 18 | LinearGrowth, combine_loss, prepare_seed 19 | from torchvision import transforms 20 | from torch.optim.lr_scheduler import 
def parse_args(args):
    """
    Parse command line arguments of the training script.

    Parameters:
        args: full argument vector (e.g. ``sys.argv``); the first element
            is skipped

    Returns:
        argparse namespace with ``config`` and ``device`` attributes

    Raises:
        ValueError: if unknown arguments are present
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str, required=True,
                        help='Path to configuration file')
    # `required=True` made the 'cpu' default unreachable; the option is now
    # optional so the declared default actually applies.
    parser.add_argument('--device',
                        type=str, default='cpu',
                        help='Training device')
    parsed_args, errors = parser.parse_known_args(args[1:])
    if errors:
        raise ValueError(f"Unknown arguments {errors}")
    return parsed_args
def load_csv(path):
    """
    Load a list of strings, one per line, from a data file.

    Parameters:
        path: path ending in '.csv' (plain text, each line stripped of
            surrounding whitespace) or '.csv.gz' (gzip-compressed,
            header-less CSV whose first column is returned)

    Returns:
        list of strings

    Raises:
        ValueError: if the path has neither of the supported extensions
    """
    if path.endswith('.csv'):
        # Context manager closes the handle; the bare `open(path)` in a
        # comprehension left it to the garbage collector.
        with open(path) as f:
            return [x.strip() for x in f]
    if path.endswith('.csv.gz'):
        df = pd.read_csv(path, compression='gzip',
                         dtype='str', header=None)
        return list(df[0].values)
    raise ValueError("Unknown format")
test_data) 146 | test_loader = DataLoader( 147 | test_dataset, collate_fn=collate_pad, 148 | batch_size=train_config['batch_size']) 149 | else: 150 | test_loader = None 151 | return train_loader, test_loader, model 152 | 153 | 154 | def train(config_path, device): 155 | """ 156 | Trains a deterministic VAE model. 157 | 158 | Parameters: 159 | config_path: path to .ini file with model configuration 160 | device: device for training ('cpu' for CPU, 'cuda:n' for GPU #n) 161 | train_data: list of train dataset strings 162 | test_data: list of test dataset strings 163 | """ 164 | config = parse_config(config_path) 165 | prepare_seed(seed=config['train'].get('seed', 777)) 166 | 167 | data_config = config['data'] 168 | if data_config['title'].lower() == 'mnist': 169 | train_loader, test_loader, model = prepare_mnist(config) 170 | else: 171 | train_loader, test_loader, model = prepare_rnn(config) 172 | 173 | model = model.to(device) 174 | 175 | train_config = config['train'] 176 | save_config = config['save'] 177 | kl_config = config['kl'] 178 | temperature_config = config['temperature'] 179 | model_dir = save_config['model_dir'] 180 | os.makedirs(model_dir, exist_ok=True) 181 | os.makedirs(save_config['log_dir'], exist_ok=True) 182 | 183 | optimizer = { 184 | 'encoder': torch.optim.Adam(model.encoder_parameters(), 185 | lr=train_config['lr']), 186 | 'decoder': torch.optim.Adam(model.decoder_parameters(), 187 | lr=train_config['lr']) 188 | } 189 | scheduler_class = ( 190 | MultiStepLR 191 | if isinstance(train_config['lr_reduce_epochs'], (list, tuple)) 192 | else StepLR 193 | ) 194 | scheduler = { 195 | 'encoder': scheduler_class( 196 | optimizer['encoder'], 197 | train_config['lr_reduce_epochs'], 198 | train_config['lr_reduce_gamma']), 199 | 'decoder': scheduler_class( 200 | optimizer['decoder'], 201 | train_config['lr_reduce_epochs'], 202 | train_config['lr_reduce_gamma']) 203 | } 204 | 205 | logger = SummaryWriter(save_config['log_dir']) 206 | 207 | kl_weight = 
LinearGrowth(**kl_config) 208 | temperature = LinearGrowth(**temperature_config) 209 | epoch_verbose = train_config.get('verbose', None) == 'epoch' 210 | batch_verbose = not epoch_verbose 211 | 212 | pretrain = train_config.get('pretrain', 0) 213 | if pretrain != 0: 214 | pretrain_weight = LinearGrowth(0, 1, 0, pretrain) 215 | fine_tune = train_config.get('fune_tune', 0) 216 | for epoch in tqdm(range(train_config['epochs'] + pretrain + fine_tune), 217 | disable=not epoch_verbose): 218 | fine_tune = epoch >= train_config['epochs'] + pretrain 219 | current_temperature = temperature(epoch) 220 | if epoch < pretrain: 221 | w = pretrain_weight(epoch) 222 | loss_weights = {'argmax_nll': w, 223 | 'sample_nll': 1 - w} 224 | elif train_config['mode'] == 'argmax': 225 | loss_weights = {'argmax_nll': 1} 226 | logger.add_scalar('temperature', current_temperature, epoch) 227 | else: 228 | loss_weights = {'sample_nll': 1} 229 | loss_weights['kl_loss'] = kl_weight(epoch) 230 | logger.add_scalar('kl_weight', loss_weights['kl_loss'], epoch) 231 | 232 | scheduler['encoder'].step() 233 | scheduler['decoder'].step() 234 | 235 | train_epoch( 236 | model, loss_weights, epoch, train_loader, True, 237 | current_temperature, logger, 238 | optimizer, batch_verbose, 239 | clamp=train_config.get('clamp'), 240 | fine_tune=fine_tune 241 | ) 242 | 243 | if test_loader is not None: 244 | with torch.no_grad(): 245 | train_epoch( 246 | model, loss_weights, epoch, test_loader, False, 247 | current_temperature, logger, 248 | optimizer, batch_verbose, 249 | clamp=train_config.get('clamp'), 250 | fine_tune=fine_tune 251 | ) 252 | if train_config.get("checkpoint", "epoch") == "epoch": 253 | path = f"{model_dir}/checkpoint_{epoch+1}.pt" 254 | else: 255 | path = f"{model_dir}/checkpoint.pt" 256 | model.save(path) 257 | 258 | model.save(f"{model_dir}/checkpoint.pt") 259 | logger.close() 260 | 261 | 262 | if __name__ == "__main__": 263 | parsed_args = parse_args(sys.argv) 264 | train(parsed_args.config, 
class TestUtils(TestCase):
    """Monte Carlo checks of the closed-form KL terms of every proposal."""

    @staticmethod
    def _build_inputs(n_samples=100000, n_rows=32):
        """Create the seeded location/scale fixtures shared by the tests.

        Returns a tuple ``(m_rep, s_rep, m_t, s_t)`` of torch tensors: the
        location and scale columns tiled ``n_samples`` times along axis 1,
        followed by the untiled ``(n_rows, 1)`` columns.
        """
        np.random.seed(0)
        loc = np.random.randn(n_rows, 1) / 3
        scale = np.exp(np.random.randn(n_rows, 1)) / 3
        tiling = (1, n_samples)
        loc_tiled = torch.tensor(np.tile(loc, tiling))
        scale_tiled = torch.tensor(np.tile(scale, tiling))
        return loc_tiled, scale_tiled, torch.tensor(loc), torch.tensor(scale)

    def test_kl(self):
        """Check ``proposal.kl`` against a sampled estimate.

        The estimate averages ``log(q * sqrt(2 * pi)) + z ** 2 / 2`` over
        the drawn samples ``z``, where ``q`` is the proposal density.
        """
        m_rep, s_rep, m_t, s_t = self._build_inputs()
        gauss_norm = np.sqrt(2 * np.pi)
        for name, make_proposal in get_proposals().items():
            with self.subTest(name=name):
                proposal = make_proposal()
                samples = proposal.sample(m_rep, s_rep)
                standardized = (samples - m_rep) / s_rep
                # Rescale the standardized density by the scale.
                density = proposal.density(standardized) / s_rep
                log_ratio = np.log(density * gauss_norm) + samples**2 / 2
                kl_mc = log_ratio.mean(1)
                kl = proposal.kl(m_t, s_t)
                worst_gap = torch.abs(kl - kl_mc).max().item()
                self.assertLess(
                    worst_gap, 0.05,
                    f"Failed proposal {name} for Gaussian prior"
                )

    def test_kl_uniform(self):
        """Check ``proposal.kl_uniform`` against a sampled estimate.

        The estimate averages ``log(2 * q)`` over the drawn samples; the
        proposal registered as ``'gaussian'`` is skipped.
        """
        m_rep, s_rep, m_t, s_t = self._build_inputs()
        for name, make_proposal in get_proposals().items():
            if name == 'gaussian':
                continue
            with self.subTest(name=name):
                proposal = make_proposal()
                samples = proposal.sample(m_rep, s_rep)
                standardized = (samples - m_rep) / s_rep
                density = proposal.density(standardized) / s_rep
                kl_mc = np.log(density * 2).mean(1)
                kl = proposal.kl_uniform(m_t, s_t)
                worst_gap = torch.abs(kl - kl_mc).max().item()
                self.assertLess(
                    worst_gap, 0.05,
                    f"Failed proposal {name} for uniform prior"
                )
--------------------------------------------------------------------------------