├── doc ├── source │ ├── _static │ │ └── .placeholder │ ├── release-notes.rst │ ├── issues.rst │ ├── index.rst │ ├── desmiles.rst │ └── conf.py └── Makefile ├── lib-python └── desmiles │ ├── decoding │ ├── __init__.py │ ├── greedy.py │ └── astar.py │ ├── __init__.py │ ├── config.py │ ├── data.py │ ├── utils.py │ ├── scripts │ ├── read_saved_model.py │ ├── sample_variants_of_input.py │ └── finetune_model.py │ ├── learner.py │ └── models.py ├── patches ├── README.txt └── core.py ├── tests ├── utils │ ├── get_smiles_from_pairs.sh │ ├── apply_bpe.sh │ ├── convert_to_canon_smiles.py │ ├── convert_to_np.py │ ├── compute_drd2_probs.py │ └── get_fingerprints.py ├── download_drd2_dataset.sh └── test_desmiles.py ├── environment.yml ├── setup.py ├── Dockerfile ├── LICENSE.txt ├── README.md └── Notebooks ├── drd2.py ├── overview_of_DESMILES.ipynb └── intro_demo_of_DESMILES.ipynb /doc/source/_static/.placeholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib-python/desmiles/decoding/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /doc/source/release-notes.rst: -------------------------------------------------------------------------------- 1 | Release notes 2 | ============= 3 | 4 | 1.0 - Initial public release of DESMILES. 
#!/bin/bash
# Split a whitespace-separated pairs file into one word (SMILES) per line.
#
# Usage: get_smiles_from_pairs.sh PAIRS_FILE > OUTPUT
#
# Every whitespace-delimited word of every input line is written to stdout
# on its own line, in input order. Blank lines produce no output.

filename="$1"

# `read -a` splits on IFS without performing pathname expansion, so SMILES
# containing glob characters (e.g. `[nH]`, `*`) are emitted verbatim instead
# of being matched against files in the current directory, which the
# unquoted `for word in $line` form would do.
# The `|| (( ${#words[@]} ))` clause keeps a final line that lacks a
# trailing newline.
while read -r -a words || (( ${#words[@]} )); do
    if (( ${#words[@]} )); then
        # printf (unlike echo) never interprets a word such as "-n" as an option.
        printf '%s\n' "${words[@]}"
    fi
done < "$filename"
-------------------------------------------------------------------------------- /doc/source/issues.rst: -------------------------------------------------------------------------------- 1 | Open issues 2 | *********** 3 | 4 | Documentation 5 | ~~~~~~~~~~~~~ 6 | 7 | 1. include documentation for how to compile the documentation 8 | 2. sphinx generates lots of warnings about autosummary problems 9 | 10 | Code 11 | ~~~~ 12 | 13 | DESMILES has a dependency on an old garden version of fastai. It would be nice to get rid of it by including the minimal part of the fastai library in a small number of additional modules. 14 | 15 | 16 | Tests 17 | ~~~~~ 18 | 19 | 1. Separate the regression tests from the precheckin tests. 20 | 2. include additional tests for the finetuning executables 21 | -------------------------------------------------------------------------------- /doc/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Doc Template documentation master file, created by 2 | sphinx-quickstart on Wed Oct 5 12:53:58 2016. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | ``desmiles`` 7 | ============================================================================= 8 | 9 | A garden module for the desmiles project 10 | 11 | 12 | 13 | 14 | Contents: 15 | 16 | .. 
def parse_args(argv=None):
    """Parse the command line of the id-to-numpy converter.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        An ``argparse.Namespace`` whose ``input`` and ``output`` attributes
        are ``pathlib.Path`` objects.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input', type=Path, help="input file")
    parser.add_argument('output', type=Path, help="output file")
    return parser.parse_args(argv)


def main(argv=None):
    """Convert a text file of numeric ids into a zero-padded int16 matrix.

    Each input line holds whitespace-separated numbers. Row ``i`` of the
    saved array contains the numbers of line ``i``, right-padded with zeros
    to the length of the longest line. The array is written with ``np.save``.

    Args:
        argv: Optional argument list, forwarded to :func:`parse_args`.

    Raises:
        ValueError: If the input file contains no lines at all, or a token
            cannot be parsed as a number.
    """
    args = parse_args(argv)
    # Read inside a context manager so the handle is closed deterministically.
    with open(args.input) as infile:
        rows = [[float(token) for token in line.split()] for line in infile]
    if not rows:
        # Without this guard max() fails with an unhelpful
        # "max() arg is an empty sequence".
        raise ValueError(f"no encoded sequences found in {args.input}")
    maxlen = max(map(len, rows))
    chembl_enc = np.zeros((len(rows), maxlen), dtype=np.int16)
    for i, row in enumerate(rows):
        # Values are cast to int16 on assignment; the tail stays zero padding.
        chembl_enc[i, :len(row)] = row
    np.save(args.output, chembl_enc)
def load_model(path='./clf_py36.pkl'):
    """Load the pickled DRD2 classifier.

    Args:
        path: Location of the pickle file. Defaults to ``./clf_py36.pkl``
            (relative to the current working directory).

    Returns:
        The unpickled model object.
    """
    with open(path, 'rb') as infile:
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at a classifier file obtained from a trusted source.
        model = pickle.load(infile)
    return model


def get_score(smile, model):
    """Return the model's positive-class probability for one SMILES string.

    Args:
        smile: SMILES string of the molecule to score.
        model: Object exposing ``predict_proba``; column 1 of its output is
            used as the score.

    Returns:
        The score as a Python float, or ``0.0`` when the SMILES cannot be
        parsed into a molecule.
    """
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return 0.0
    fp = fingerprints_from_mol(mol)
    score = model.predict_proba(fp)[:, 1]
    # Index the single element explicitly: float() on a 1-element array is
    # deprecated in recent NumPy releases.
    return float(score[0])


def fingerprints_from_mol(mol, size=2048):
    """Build a folded, count-based Morgan fingerprint for ``mol``.

    The radius-3 feature-based Morgan fingerprint (``useCounts=True``,
    ``useFeatures=True``) is folded into ``size`` bins by taking each feature
    index modulo ``size`` and summing the counts that collide.

    Args:
        mol: RDKit molecule.
        size: Number of bins in the folded fingerprint.

    Returns:
        ``numpy`` array of shape ``(1, size)`` and dtype ``int32``.
    """
    fp = AllChem.GetMorganFingerprint(mol, 3, useCounts=True, useFeatures=True)
    nfp = np.zeros((1, size), np.int32)
    for idx, count in fp.GetNonzeroElements().items():
        nfp[0, idx % size] += int(count)
    return nfp


def parse_args(argv=None):
    """Parse the command line of the DRD2 scoring script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with string attributes ``input_file`` and
        ``output_file``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input_file')
    parser.add_argument('output_file')
    return parser.parse_args(argv)
def smiles_to_fingerprint(smiles_str, sparse=False):
    """Compute the concatenated radius-2 and radius-3 Morgan bit fingerprints.

    Args:
        smiles_str: SMILES string of the molecule.
        sparse: When true, return the fingerprint wrapped in a
            ``scipy.sparse.csr_matrix`` instead of a dense array.

    Returns:
        A 1-D ``int8`` numpy array (radius-2 bits followed by radius-3 bits,
        both computed with ``useChirality=True``), or its CSR form when
        ``sparse`` is true.

    Raises:
        ValueError: If ``smiles_str`` cannot be parsed into a molecule.
    """
    rdmol = Chem.MolFromSmiles(smiles_str)
    if rdmol is None:
        # Fail with the offending input instead of the opaque argument error
        # the fingerprint call raises when handed None.
        raise ValueError(f"could not parse SMILES: {smiles_str!r}")
    fp = np.concatenate([
        np.asarray(GetMorganFingerprintAsBitVect(rdmol, radius, useChirality=True), dtype=np.int8)
        for radius in (2, 3)
    ])
    if sparse:
        return scipy.sparse.csr_matrix(fp)
    return fp


def smiles_to_fingerprints(many_smiles_strs):
    """Stack the sparse fingerprints of several SMILES strings row-wise.

    Args:
        many_smiles_strs: Iterable of SMILES strings.

    Returns:
        A scipy sparse matrix with one row per input string.
    """
    return scipy.sparse.vstack([smiles_to_fingerprint(s, sparse=True) for s in many_smiles_strs])


def parse_args(argv=None):
    """Parse the command line of the fingerprint script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` whose ``input`` and ``output`` attributes
        are ``pathlib.Path`` objects.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input', type=Path, help="input file")
    parser.add_argument('output', type=Path, help="output file")
    return parser.parse_args(argv)


def chunks(l, n):
    """Yield consecutive slices of ``l`` of length ``n`` (the last may be shorter).

    Args:
        l: A sliceable sequence with a length.
        n: Chunk size.

    Returns:
        A generator of slices ``l[0:n]``, ``l[n:2n]``, ...
    """
    return (l[start:start + n] for start in range(0, len(l), n))
 return 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | 52 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | FROM nvcr.io/nvidia/cuda:10.1-devel-centos7 3 | LABEL name="desmiles" 4 | LABEL version="1.0" 5 | 6 | RUN yum -y install git 7 | 8 | RUN git clone https://github.com/DEShawResearch/DESMILES /opt/DESMILES 9 | # alternatively, if you already have a local repo, 10 | # assuming you are starting one level above the local repo, 11 | # you can use: 12 | # COPY DESMILES /opt/DESMILES 13 | 14 | # set up conda 15 | ENV CONDA_DIR /opt/conda 16 | RUN curl -s https://repo.anaconda.com/miniconda/Miniconda3-py37_4.10.3-Linux-x86_64.sh \ 17 | -o /opt/Miniconda3-py37_5.10.3-Linux-x86_64.sh 18 | RUN /bin/bash /opt/Miniconda3-py37_5.10.3-Linux-x86_64.sh -b -p /opt/conda 19 | ENV PATH=$CONDA_DIR/bin:$PATH 20 | 21 | # setup conda env 22 | # see above about where environment.yml will be found 23 | RUN conda env create --file /opt/DESMILES/environment.yml 24 | # make useable in build 25 | RUN source activate desmiles 26 | 27 | RUN echo "source activate desmiles" > ~/.bashrc 28 | # for 0.4rc install/source/run as user 29 | 30 | # DESMILES requires this environment variable. 31 | # set to recommended mountpoint for desmiles. 32 | ENV DESMILES_DATA_DIR=/desmiles/data 33 | 34 | # If your local data directory is PATH_TO_DATA_DIR, 35 | # add the option -v PATH_TO_DATA_DIR:/desmiles/data 36 | 37 | # If you choose to mount the data elsewhere, 38 | # set `-e DESMILES_DATA_DIR` accordingly. 39 | 40 | # If you know how to use docker, the above information will get you started. 
#!/bin/bash

# This script will create a preprocessed dataset for the graph-to-graph translation of DRD2.
# If you want to run the DESMILES overview notebook,
# please place this dataset to ${DESMILES_DATA_DIR}/notebooks/
#
# Usage: download_drd2_dataset.sh [TARGET_DIR]
#   TARGET_DIR  existing directory that receives the DRD2 folder
#               (defaults to the current directory).

# Stop at the first failing download or preprocessing step instead of
# silently continuing with missing or partial files.
set -euo pipefail

# Keep track of desired install path; move the dataset here in the end.
# The path is made absolute up front because the script changes directory
# below; a relative TARGET_DIR would otherwise be resolved from the
# temporary directory when the final `mv` runs.
if [[ -d "${1:-}" ]]; then
    data_path=$(readlink -f "$1")
    echo "Downloading to ${data_path}"
else
    echo "Downloading to local directory"
    data_path=$(pwd)
fi

# Get the path for the utility scripts
utils="$(readlink -f "$(dirname "$0")")/utils"

# Create a temporary directory structure for the downloads and preprocessing
dir=$(mktemp -d)
# Clean up the tmp directory on every exit path, including failures.
trap 'rm -rf "${dir}"' EXIT

drd2_dir=${dir}/DRD2
val_dir=${drd2_dir}/Validation
test_dir=${drd2_dir}/Testing

mkdir -p "${drd2_dir}"
mkdir -p "${val_dir}"
mkdir -p "${test_dir}"


# We are using the specific repo of the paper published by Wengong Jin to download the data
DRD2_repo="https://raw.githubusercontent.com/wengong-jin/iclr19-graph2graph/691e28c12d9753c53b765932100d667885376d34"

cd "${test_dir}"
curl -f -O "${DRD2_repo}/data/drd2/test.txt"

cd "${val_dir}"
curl -f -O "${DRD2_repo}/data/drd2/valid.txt"

cd "${drd2_dir}"
curl -f -O "${DRD2_repo}/data/drd2/train_pairs.txt"

# Perform a series of preprocessing steps for DESMILES
"${utils}/get_smiles_from_pairs.sh" train_pairs.txt > drd2.smi
"${utils}/apply_bpe.sh" drd2.smi drd2.enc8000
"${utils}/convert_to_np.py" drd2.enc8000 drd2.enc8000.npy
"${utils}/get_fingerprints.py" drd2.smi fps_drd2.npz

# Now download the pickled scoring function and use it to evaluate the scores:
curl -f -O "${DRD2_repo}/props/clf_py36.pkl"
"${utils}/compute_drd2_probs.py" drd2.smi drd2_probs.npy
"${utils}/convert_to_canon_smiles.py" drd2.smi drd2_canon_smiles.smi

# Move the processed DRD2 data directory to the requested location.
mv "${drd2_dir}" "${data_path}"
def try_import(module):
    "Try to import `module`. Returns module's object on success, None on failure"
    try:
        return importlib.import_module(module)
    except Exception:
        # Best-effort by design: any import-time error yields None, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        return None


def have_min_pkg_version(package, version):
    "Check whether we have at least `version` of `package`. Returns True on success, False otherwise."
    try:
        pkg_resources.require(f"{package}>={version}")
    except Exception:
        # Missing distribution, version conflict or unparsable requirement
        # all mean "not available"; interpreter-exit signals still propagate.
        return False
    return True
13 | 14 | * Redistributions in binary form must reproduce the above copyright 15 | notice, this list of conditions, and the following disclaimer in the 16 | documentation and/or other materials provided with the distribution. 17 | 18 | Neither the name of D. E. Shaw Research nor the names of its contributors may 19 | be used to endorse or promote products derived from this software without 20 | specific prior written permission. 21 | 22 | THIS SOFTWARE AND DATA ARE PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 23 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 24 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 25 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE 26 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 27 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 28 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 29 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 30 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE AND/OR DATA, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | 34 | 35 | Binary distributions of this software made available by D.E. Shaw Research 36 | include the following set of third party software libraries, 37 | the use of which are governed by license agreements listed below. 38 | 39 | * Fastai 40 | 41 | Copyright 2017 onwards, fast.ai, Inc 42 | 43 | Licensed under the Apache License, Version 2.0 (the "License"); 44 | you may not use this file except in compliance with the License. 
def beam_search(rnn_desmiles, fp, beam_sz=100, max_tokens=30):
    """Decode token sequences for one fingerprint with a fixed-width beam search.

    Starting from the root token ``3``, every step expands each kept node by
    its ``beam_sz`` most likely next tokens, scores the candidates by their
    accumulated negative log-probability, and keeps the ``beam_sz`` best.
    Nodes whose last token is ``1`` or ``2`` are recorded as leaf nodes.

    Args:
        rnn_desmiles: Model object. The code reads ``embedding.weight``, and
            calls ``embed_fingerprints``, ``select_hidden`` and the model
            itself on a token tensor.
        fp: Fingerprint tensor whose first dimension must be 1.
        beam_sz: Number of nodes kept per step; must not exceed the number of
            rows of the embedding matrix.
        max_tokens: Number of decoding steps to run.

    Returns:
        Tuple ``(leaf_nodes, scores)``: the recorded leaf sequences and their
        negative log-probabilities as numpy arrays, sorted by ascending score
        and truncated to at most ``beam_sz`` entries.

    NOTE(review): if no leaf node is recorded within ``max_tokens`` steps,
    ``zip(*leaf_nodes)`` yields nothing and the tuple unpacking fails.
    NOTE(review): leaves recorded at different steps have different lengths,
    so ``np.asarray(leaf_nodes)`` may be ragged - confirm callers expect that.
    """
    assert beam_sz <= rnn_desmiles.embedding.weight.shape[0], "Beam Size must be smaller than number of tokens"
    assert fp.shape[0] == 1, "Currently must be tensor of size (1, fp_size)"
    # NOTE(review): device follows CUDA availability, not the device of `fp`
    # or of the model - presumably they are expected to match; confirm.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    leaf_nodes = []
    with torch.no_grad():
        rnn_desmiles.embed_fingerprints(fp)  # initialize rnn_desmiles by embedding the fingerprint
        nodes = torch.tensor([[3]]).to(device)  # root node [3]
        xb = nodes.clone()  # the last token of each node; this is what actually runs through the model
        scores = torch.tensor([0.0]).to(device)  # scores for the nodes (negative log probabilities)
        for t in range(max_tokens):
            # run the last token of each node through the model and take log-probabilities of the next token
            out = F.log_softmax(rnn_desmiles(xb.transpose(0,1))[:,-1], dim=-1)
            # top beam_sz continuations per node; only beam_sz nodes are kept in total,
            # so no node can contribute more than that
            values, indices = out.topk(beam_sz, dim=-1)
            # add the new conditional negative log-probability to each parent's score and flatten
            scores = (-values + scores[:,None]).view(-1)
            # for every flattened candidate, the row index of the parent node it extends
            # NOTE(review): created without an explicit device, unlike `nodes`/`scores` - confirm on GPU
            indices_idx = torch.arange(0,nodes.size(0))[:,None].expand(nodes.size(0), beam_sz).contiguous().view(-1).long()
            # positions of the beam_sz lowest (best) scores among all candidates
            sort_idx = torch.sort(scores)[1][:beam_sz]
            scores = scores[sort_idx]  # keep only the scores of the surviving candidates

            # append each candidate token to a copy of its parent sequence
            nodes = torch.cat([nodes[:,None].expand(nodes.size(0),beam_sz,nodes.size(1)),
                               indices[:,:,None].expand(nodes.size(0),beam_sz,1),], dim=2)
            nodes = nodes.view(-1, nodes.size(2))[sort_idx]  # flatten them out and grab the top beam_sz
            rnn_desmiles.select_hidden(indices_idx[sort_idx])  # keep the hidden states of the parents of the surviving nodes
            xb = nodes[:,-1][:,None]  # update xb to be the last token of the nodes
            if nodes.shape[1] > 2:
                # a node is a leaf when its last token is 2 or 1
                # (presumably end-of-sequence / padding ids - TODO confirm against the vocabulary)
                are_leaf_nodes = (nodes[:,-1] == 2) | (nodes[:,-1] == 1)
                # NOTE(review): recorded leaves are not removed from `nodes`, so they stay in the beam
                for n,s in zip(nodes[are_leaf_nodes], scores[are_leaf_nodes]):  # grab the leaf nodes and scores
                    leaf_nodes.append((n,s))
    # move the recorded sequences to plain Python lists, then split sequences from scores
    leaf_nodes = [(n.cpu().numpy().tolist(),s) for (n,s) in leaf_nodes]
    leaf_nodes, scores = zip(*leaf_nodes)
    leaf_nodes = np.asarray(leaf_nodes)
    scores = np.asarray([s.item() for s in scores])
    # order all recorded leaves by ascending negative log-probability and keep the best beam_sz
    sort_idx = np.argsort(scores)[:beam_sz]
    scores = scores[sort_idx]
    leaf_nodes = leaf_nodes[sort_idx]
    return leaf_nodes, scores
finetune_model.py will consume a CSV file of matched pairs and dump a learner that is finetuned to preferentially optimize the resulting molecules. The typical CSV file has format: "SMILES_1, SMILES_2, similarity, delta_PI, P1, P2" with delta_PI the difference in the log10 of IC50 and delta_PI positive. The code will only read the entries: "SMILES_1", "SMILES_2", and assumes that all postprocessing to bring the CSV file to the correct format was done beforehand. (FIXME: Share an example notebook that creates and inspects such a dataset of pairs from a single list of molecules with IC50 values.) 13 | 14 | 15 | .. argparse:: 16 | :module: finetune_model 17 | :func: get_parser 18 | :prog: finetune_model.py 19 | 20 | 21 | Using finetuned model 22 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 23 | 24 | read_saved_model.py will read the learner that was generated by finetune_model and will apply it to a new set of small molecules. The small molecules will first be converted to fingerprints and then regenerated in optimized versions using the finetuned model. The parameter training_pairs (same as in finetune_model) is only used to separate novel molecules from those seen in the training set, if any such generated molecules exist. 25 | 26 | 27 | .. argparse:: 28 | :module: read_saved_model 29 | :func: get_parser 30 | :prog: read_saved_model.py 31 | 32 | 33 | Additional processing of results 34 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 35 | 36 | 37 | .. argparse:: 38 | :module: process_samples 39 | :func: get_parser 40 | :prog: process_samples.py 41 | 42 | 43 | 44 | Python API 45 | ********** 46 | 47 | 48 | desmiles.models 49 | ~~~~~~~~~~~~~~~ 50 | 51 | 52 | .. automodule:: desmiles.models 53 | :members: 54 | :undoc-members: 55 | :show-inheritance: 56 | 57 | 58 | desmiles.learner 59 | ~~~~~~~~~~~~~~~~ 60 | 61 | 62 | .. automodule:: desmiles.learner 63 | :members: 64 | :undoc-members: 65 | :show-inheritance: 66 | 67 | 68 | desmiles.data 69 | ~~~~~~~~~~~~~ 70 | 71 | 72 | .. 
automodule:: desmiles.data 73 | :members: 74 | :undoc-members: 75 | :show-inheritance: 76 | 77 | 78 | desmiles.decoding 79 | ~~~~~~~~~~~~~~~~~ 80 | 81 | 82 | .. automodule:: desmiles.decoding.astar 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | 88 | .. automodule:: desmiles.decoding.greedy 89 | :members: 90 | :undoc-members: 91 | :show-inheritance: 92 | 93 | 94 | desmiles.utils 95 | ~~~~~~~~~~~~~~ 96 | 97 | 98 | .. automodule:: desmiles.utils 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## **Reference** 2 | This repository contains the code associated with our paper: 3 | 4 | Paul Maragakis, Hunter Nisonoff, Brian Cole, and David E. Shaw, "A Deep-Learning View of Chemical Space Designed to Facilitate Drug Discovery," *Journal of Chemical Information and Modeling*, vol. 60, no. 10, 2020, pp. 4487–4496. [Text](https://doi.org/10.1021/acs.jcim.0c00321) 5 | 6 | The corresponding dataset can be found on our [web site](https://www.deshawresearch.com/downloads/download_desmiles.cgi/). 7 | 8 | ## **Installing and setup** 9 | 10 | To train your own data you'll need a GPU compatible with CUDA 10.1. You can run without a GPU using our pre-trained data set, but performance will be lower than with a GPU. 11 | 12 | You will need a compatible cuDNN and 13 | 14 | ``` 15 | python >=3.7, <=3.9 16 | pytorch >= 1.0 17 | fastai == 1.0.55 18 | scipy 19 | numpy 20 | rdkit 21 | ``` 22 | The tests require `pytest` and `hypothesis`. 
23 | 24 | For the provided sample notebooks you will also need: 25 | ``` 26 | jupyter 27 | matplotlib 28 | seaborn 29 | sentencepiece 30 | ``` 31 | 32 | **note:** The script [download_drd2_dataset.sh](https://github.com/DEShawResearch/DESMILES/blob/master/tests/download_drd2_dataset.sh) requires `rdkit==2018.09.01` and `scikit-learn==0.19.2`. 33 | 34 | 35 | **Conda** 36 | 37 | For easy installation, we've also provided a conda [environment.yml](environment.yml). Refer to the [miniconda documentation](https://docs.conda.io/en/latest/miniconda.html) for instructions for installing conda. The conda environment is all that is required for CPU applications. For GPU applications, the environment is limited to CUDA 10. Running DESMILES on GPUs not compatible with CUDA 10 requires building pytorch 1.0.0 from source. 38 | 39 | **Containers** 40 | 41 | We're including a [Dockerfile](Dockerfile) to build a containerized, GPU-enabled version of DESMILES. You can build a docker image by running: 42 | 43 | `docker build -t desmiles:1.0 https://github.com/DEShawResearch/DESMILES.git#1.0` 44 | 45 | ## **Using** 46 | 47 | DESMILES identifies the data directory with the environment variable DESMILES_DATA_DIR. Set it with 48 | ``` 49 | export DESMILES_DATA_DIR=/path/to/DESMILES/data 50 | ``` 51 | Where `DESMILES/data` is the unpacked form of the [data set](https://www.deshawresearch.com/downloads/download_desmiles.cgi/). 52 | 53 | If you are using a container, you'll need to make the data directory visible within the container and correspond to the environment variable `DESMILES_DATA_DIR` within the container. The provided Dockerfile defaults `DESMILES_DATA_DIR` such that the following bind mount will work: 54 | `-v /path/to/DESMILES/data:/desmiles/data` 55 | 56 | 57 | ### **Jupyter notebooks** 58 | 59 | We provide two demo Jupyter notebooks: 60 | 61 | * [intro_demo_of_DESMILES.ipynb](Notebooks/intro_demo_of_DESMILES.ipynb) shows simple ways to use a pretrained model to generate molecules. 
def load_drd2_data():
    """Read the raw DRD2 benchmark files from ``DATA_DIR/notebooks/DRD2``.

    Returns a 5-tuple ``(smiles, encoded, fps, pairs, probs)``:

    * ``smiles``  -- array of the stripped lines of ``drd2.smi``
    * ``encoded`` -- array loaded from ``drd2.enc8000.npy``
    * ``fps``     -- scipy sparse matrix loaded from ``fps_drd2.npz``
    * ``pairs``   -- values of the space-separated, header-less
      ``train_pairs.txt``
    * ``probs``   -- array loaded from ``drd2_probs.npy``
    """
    # Build the directory with pathlib only, so this function does not
    # depend on `os` leaking in through a star import.
    datadir = Path(DATA_DIR) / 'notebooks' / 'DRD2'
    # Context manager so the SMILES file handle is closed deterministically.
    with open(datadir / "drd2.smi") as smiles_file:
        smiles = np.asarray([line.strip() for line in smiles_file])
    encoded = np.load(datadir / "drd2.enc8000.npy")
    fps = scipy.sparse.load_npz(datadir / "fps_drd2.npz")
    pairs = pd.read_csv(datadir / 'train_pairs.txt', header=None, sep=" ").values
    probs = np.load(datadir / 'drd2_probs.npy')
    return smiles, encoded, fps, pairs, probs
class FpSmiles(ItemBase):
    "Item holding an encoded SMILES sequence, its text form and its fingerprint."
    def __init__(self, ids, text, fp):
        # `data` is the (token ids, fingerprint) pair; `text` is kept
        # separately for `__repr__` and `show`.
        self.data = (ids, fp)
        self.text = text

    def __repr__(self): return str(self.text)

    def show(self):
        "Return a one-molecule grid image drawn from `self.text`."
        # The helper is `image_of_mols` in desmiles.utils; the previously
        # referenced `imageOfMols` is not defined anywhere in this module.
        # Imported here rather than at module level because desmiles.utils
        # itself imports from this module.
        from .utils import image_of_mols
        return image_of_mols([self.text], molsPerRow=1)

    def apply_tfms(self, tfms:Collection, **kwargs):
        "Transforms are not supported: raise if any are given, else return `self`."
        if tfms: raise Exception('Not implemented')
        return self
22 | _bunch = DataBunch 23 | _label_cls = EmptyLabel 24 | 25 | def __init__(self, ids:NPArrayList, fps:'Scipy.Sparse', vocab:Vocab=None, pad_idx:int=0, **kwargs): 26 | super().__init__(ids, **kwargs) 27 | self.ids = ids 28 | self.fps = fps 29 | self.vocab,self.pad_idx = vocab,pad_idx 30 | self.copy_new += ['vocab', 'pad_idx'] 31 | self.loss_func = CrossEntropyFlat() 32 | 33 | def get(self, i): 34 | ids = self.ids[i] 35 | fp = self.fps[i] 36 | return FpSmiles(ids, self.vocab.textify(ids, sep=''), fp) 37 | 38 | def reconstruct(self, ids:Tensor, fp:Tensor): 39 | return FpSmiles(ids, self.vocab.textify(ids), fp) 40 | 41 | class DesmilesLoader(): 42 | "Create a dataloader for desmiles." 43 | def __init__(self, dataset:FpSmilesList, bs:int=64, vocab=None, sampler=None, shuffle=False, drop_last=False): 44 | # shuffle an drop_last are required by fastai as arguments. We don't use them. 45 | self.dataset,self.bs = dataset,bs 46 | self.first,self.i,self.iter = True,0,0 47 | self.nb = dataset.ids.shape[0] // bs 48 | self.batch_size=bs 49 | self.n = len(self.dataset) 50 | self.batch_first=False 51 | self.vocab=vocab 52 | self.sampler=sampler 53 | self.init_kwargs = dict(bs=bs, vocab=vocab, sampler=sampler) 54 | 55 | def __iter__(self): 56 | self.i,self.iter = 0,0 57 | inds = np.arange(len(self.dataset)) 58 | if self.sampler is None: 59 | inds = np.random.permutation(inds) 60 | # Get index of largest idsuence 61 | largest_seq_ind = np.argmax(np.sum(self.dataset.ids > 0, axis=1)) 62 | # Find where this index got randomly permuted to 63 | current_ind = np.where(inds == largest_seq_ind)[0][0] 64 | # Get the current first index 65 | zero_ind = inds[0] 66 | # Flip the first index with the largest sequence index 67 | inds[current_ind] = zero_ind 68 | inds[0] = largest_seq_ind 69 | else: 70 | seq_lengths = (self.dataset.ids > 0).sum(axis=1) 71 | inds = np.asarray([x for x in self.sampler(seq_lengths, key=lambda x: seq_lengths[x], bs=self.bs)]) 72 | inds = 
def load_pretrained_desmiles(path,
                             return_learner=False,
                             return_rnn=False,
                             fp_emb_sz=2000,
                             emb_sz=400,
                             nh=2000,
                             nl=5,
                             itos_fn=os.path.join(DATA_DIR, 'pretrained', 'id.dec8000'),
                             bs=200,
                             device=None,
                             with_opt=True):
    """Load saved DESMILES weights and return them in the requested form.

    Parameters
    ----------
    path : path of the saved model; its final suffix is stripped before
        the path is passed to ``learner.load``.
    return_learner : if true, return the whole learner.
    return_rnn : if true (and ``return_learner`` is false), return the model
        wrapped in ``RecurrentDESMILES`` and switched to eval mode.
    fp_emb_sz, emb_sz, nh, nl : architecture sizes forwarded to
        ``desmiles_model_learner``.
    itos_fn : vocabulary file, one token per line.
    bs : batch size given to the placeholder data loaders.
    device : target device; defaults to ``"cuda"`` when available,
        otherwise ``"cpu"``.
    with_opt : forwarded to ``learner.load``.

    Returns
    -------
    The learner, the ``RecurrentDESMILES`` wrapper, or the bare model,
    in each case with the model moved to ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    path = Path(path)
    path = path.parent / path.stem
    # Read the vocabulary inside a context manager so the handle is closed
    # (same pattern as load_old_pretrained_desmiles).
    with open(itos_fn, encoding='utf-8') as itos_file:
        itos = [s.strip() for s in itos_file]
    vocab = Vocab(itos)

    # Empty placeholder datasets: only needed to build the DataBunch that
    # the learner constructor takes.
    trn_ds = FpSmilesList(np.asarray([]), np.asarray([]))
    val_ds = FpSmilesList(np.asarray([]), np.asarray([]))

    trn_dl = DesmilesLoader(trn_ds, bs=bs, vocab=vocab)
    val_dl = DesmilesLoader(val_ds, bs=bs, vocab=vocab)
    db = DataBunch(trn_dl, val_dl)
    learner = desmiles_model_learner(db, drop_mult=0.7, fp_emb_sz=fp_emb_sz, emb_sz=emb_sz, nh=nh, nl=nl, pad_token=0, bias=False)
    # purge makes loading really slow!!
    learner.load(path, purge=False, device=device, with_opt=with_opt)
    if return_learner:
        learner.model = learner.model.to(device)
        return learner
    if return_rnn:
        rnn_desmiles = RecurrentDESMILES(learner.model).eval()
        return rnn_desmiles.to(device)
    model = learner.model
    return model.to(device)
def smiles_to_fingerprint(smiles_str, sparse=False, as_tensor=False):
    """Return the DESMILES fingerprint of a SMILES string.

    The fingerprint is the concatenation of the radius-2 and radius-3
    Morgan bit vectors of the molecule, both computed with chirality.

    Parameters
    ----------
    smiles_str : SMILES string of the molecule.
    sparse : if true, return a ``scipy.sparse.csr_matrix``.
    as_tensor : if true (and ``sparse`` is false), return a float32 torch
        tensor placed on CUDA when available, otherwise on the CPU.

    Returns
    -------
    An ``int8`` numpy array unless ``sparse`` or ``as_tensor`` is set.

    Raises
    ------
    ValueError : if RDKit cannot parse ``smiles_str``.
    """
    rdmol = Chem.MolFromSmiles(smiles_str)
    if rdmol is None:
        # MolFromSmiles signals a parse failure by returning None; fail here
        # with a readable message instead of inside the fingerprint call.
        raise ValueError(f"Could not parse SMILES string: {smiles_str!r}")
    fp = np.concatenate([
        np.asarray(GetMorganFingerprintAsBitVect(rdmol, radius, useChirality=True), dtype=np.int8)
        for radius in (2, 3)
    ])
    if sparse:
        return scipy.sparse.csr_matrix(fp)
    if as_tensor:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return torch.tensor(fp.astype(np.float32)).to(device)
    return fp
def decoder(idx_vec, itos):
    """Return a SMILES string from an index vector (deals with reversal).

    The second element of `idx_vec` is the direction marker: 2 means the
    sequence was written backwards and is reversed before decoding; any
    other value is decoded in the given order. Indices 0--3 are skipped.
    Vectors with fewer than two elements decode to the empty string.
    """
    if len(idx_vec) < 2:
        return ""
    is_backward = idx_vec[1] == 2
    ordered = idx_vec[::-1] if is_backward else idx_vec
    pieces = [itos[token] for token in ordered if token > 3]
    return ''.join(pieces)
def simple_smiles_fail(sm):
    """Cheap pre-filter for really bad SMILES, run before any RDKit parsing.

    Returns True when the round or square brackets are unbalanced by count,
    or when the string is empty after stripping whitespace.
    """
    if sm.count("(") != sm.count(")"):
        return True
    if sm.count("[") != sm.count("]"):
        return True
    return not sm.strip()
def read_enamine_real_smiles(fname):
    """Return the first whitespace-separated token of every line in `fname`.

    The file is expected to hold one molecule per line with the SMILES as
    the first element and no header. Blank or whitespace-only lines are
    skipped rather than raising IndexError.
    """
    # Context manager so the handle is closed once the list is built.
    with open(fname) as handle:
        tokenized = (line.split() for line in handle)
        return [tokens[0] for tokens in tokenized if tokens]
def get_parser():
    """Build the command-line parser for this script.

    Path-like options are required and normalised with `os.path.abspath`;
    the three numeric options are optional integers with defaults.
    """
    path_option = dict(type=os.path.abspath, required=True)
    option_table = (
        # Directory where all output goes.
        # Will create 3 output files: samples.csv, uniques.csv, novel.csv
        ('-w', '--workdir', dict(help="directory with output", **path_option)),
        ('-l', '--learner', dict(help="name of saved model", **path_option)),
        ('-n', '--num_return',
         dict(help="molecules to output for each input molecule", type=int, default=30)),
        ('-m', '--num_max_try',
         dict(help="maximal number of astar iterations", type=int, default=1000)),
        ('-x', '--num_expand',
         dict(help="batch expansions to try on GPU astar", type=int, default=1000)),
        # The list of input molecules, one smiles per line, with the smiles as first element.
        ('-i', '--input_smiles', dict(help="list of input smiles; no header", **path_option)),
        # In principle the next argument is optional.
        # We use the training molecules to eliminate them from the output file novel.csv
        ('-t', '--training_pairs', dict(help="list of training molecules", **path_option)),
    )
    parser = argparse.ArgumentParser()
    for short_flag, long_flag, options in option_table:
        parser.add_argument(short_flag, long_flag, **options)
    return parser
Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 21 | 22 | .PHONY: help 23 | help: 24 | @echo "Please use \`make ' where is one of" 25 | @echo " html to make standalone HTML files" 26 | @echo " dirhtml to make HTML files named index.html in directories" 27 | @echo " singlehtml to make a single large HTML file" 28 | @echo " pickle to make pickle files" 29 | @echo " json to make JSON files" 30 | @echo " htmlhelp to make HTML files and a HTML help project" 31 | @echo " qthelp to make HTML files and a qthelp project" 32 | @echo " applehelp to make an Apple Help Book" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all 
external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | @echo " coverage to run coverage check of the documentation (if enabled)" 49 | 50 | .PHONY: clean 51 | clean: 52 | rm -rf $(BUILDDIR)/* 53 | 54 | .PHONY: html 55 | html: 56 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 57 | @echo 58 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 59 | 60 | .PHONY: dirhtml 61 | dirhtml: 62 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 63 | @echo 64 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 65 | 66 | .PHONY: singlehtml 67 | singlehtml: 68 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 69 | @echo 70 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 71 | 72 | .PHONY: pickle 73 | pickle: 74 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 75 | @echo 76 | @echo "Build finished; now you can process the pickle files." 77 | 78 | .PHONY: json 79 | json: 80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 81 | @echo 82 | @echo "Build finished; now you can process the JSON files." 83 | 84 | .PHONY: htmlhelp 85 | htmlhelp: 86 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 87 | @echo 88 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 89 | ".hhp project file in $(BUILDDIR)/htmlhelp." 
90 | 91 | .PHONY: qthelp 92 | qthelp: 93 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 94 | @echo 95 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 96 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 97 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/DocTemplate.qhcp" 98 | @echo "To view the help file:" 99 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/DocTemplate.qhc" 100 | 101 | .PHONY: applehelp 102 | applehelp: 103 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 104 | @echo 105 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 106 | @echo "N.B. You won't be able to view it unless you put it in" \ 107 | "~/Library/Documentation/Help or install it in your application" \ 108 | "bundle." 109 | 110 | .PHONY: devhelp 111 | devhelp: 112 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 113 | @echo 114 | @echo "Build finished." 115 | @echo "To view the help file:" 116 | @echo "# mkdir -p $$HOME/.local/share/devhelp/DocTemplate" 117 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/DocTemplate" 118 | @echo "# devhelp" 119 | 120 | .PHONY: epub 121 | epub: 122 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 123 | @echo 124 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 125 | 126 | .PHONY: latex 127 | latex: 128 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 129 | @echo 130 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 131 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 132 | "(use \`make latexpdf' here to do that automatically)." 133 | 134 | .PHONY: latexpdf 135 | latexpdf: 136 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 137 | @echo "Running LaTeX files through pdflatex..." 138 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 139 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 
# Build the Texinfo sources, then run them through makeinfo.
# The sub-make is invoked through $(MAKE), as the latexpdf targets do, so
# that command-line flags and the jobserver are passed down to it.
.PHONY: info
info:
	$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
	@echo "Running Texinfo files through makeinfo..."
	$(MAKE) -C $(BUILDDIR)/texinfo info
	@echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
193 | 194 | .PHONY: doctest 195 | doctest: 196 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 197 | @echo "Testing of doctests in the sources finished, look at the " \ 198 | "results in $(BUILDDIR)/doctest/output.txt." 199 | 200 | .PHONY: coverage 201 | coverage: 202 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 203 | @echo "Testing of coverage in the sources finished, look at the " \ 204 | "results in $(BUILDDIR)/coverage/python.txt." 205 | 206 | .PHONY: xml 207 | xml: 208 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 209 | @echo 210 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 211 | 212 | .PHONY: pseudoxml 213 | pseudoxml: 214 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 215 | @echo 216 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 217 | -------------------------------------------------------------------------------- /lib-python/desmiles/learner.py: -------------------------------------------------------------------------------- 1 | 'Model training for NLP' 2 | from fastai.core import * 3 | from fastai.torch_core import * 4 | from fastai.basic_data import DataBunch 5 | from fastai.basic_train import LearnerCallback, Learner 6 | from fastai.text.learner import RNNLearner 7 | from fastai.callbacks import annealing_linear, Scheduler 8 | from .models import get_desmiles_model 9 | from functools import partial 10 | import torch.nn.functional as F 11 | 12 | __all__ = ['desmiles_split', 'desmiles_model_learner', 'DesmilesLearner', 'OriginalFastaiOneCycleScheduler'] 13 | 14 | 15 | def desmiles_split(model:nn.Module) -> List[nn.Module]: 16 | ''' 17 | Split a DESMILES `model` in groups for differential learning rates. 
class DesmilesLearner(RNNLearner):
    "`RNNLearner` that can additionally load weights saved by the old DESMILES code."

    def load_old(self, name:PathOrStr, device:torch.device=None, strict:bool=True, with_opt:bool=None, verbose:bool=False):
        '''
        Load the weights from an old (i.e. python3.6 and pytorch 0.4 DESMILES model)

        The file read is `self.path/self.model_dir/{name}.h5`.  If it holds
        exactly the keys 'model' and 'opt', the model state is loaded as is
        and, unless `with_opt` is false, the optimizer state is restored on a
        best-effort basis.  Otherwise the file is treated as a bare state dict
        whose keys are translated with `replace_name` before loading.

        `device` defaults to `self.data.device`; `strict` is forwarded to
        `load_state_dict`; `verbose` prints every current model key that is
        missing from the translated state dict.  Returns `self`.
        '''
        if device is None: device = self.data.device
        state = torch.load(self.path/self.model_dir/f'{name}.h5', map_location=device)
        if set(state.keys()) == {'model', 'opt'}:
            get_model(self.model).load_state_dict(state['model'], strict=strict)

            if ifnone(with_opt, True):
                if not hasattr(self, 'opt'): self.create_opt(defaults.lr, self.wd)
                # Restoring the optimizer is deliberately best-effort, but only
                # ordinary errors are swallowed (not KeyboardInterrupt/SystemExit).
                try: self.opt.load_state_dict(state['opt'])
                except Exception: pass
        else:
            if with_opt: warn("Saved filed doesn't contain an optimizer state.")
            sd = OrderedDict((self.replace_name(k), v) for k, v in state.items())
            if verbose:
                for k in self.model.state_dict().keys():
                    if k not in sd:
                        print(k, "not found")
            get_model(self.model).load_state_dict(sd, strict=strict)
        return self

    def replace_name(self, name):
        '''
        Translate one parameter name of an old checkpoint to the current layout.

        Names that match none of the special cases are returned as produced by
        `replace_sequential_name`, which raises ValueError for unknown prefixes.
        '''
        import re
        name = self.replace_sequential_name(name)
        if "encoder_with_dropout.embed" in name:
            return name.replace("encoder_with_dropout.embed", "encoder_dp.emb")
        if ".module.weight_hh_l0_raw" in name:
            return name.replace(".module", "")
        if "linear_fp" not in name:
            return name

        def layer_number():
            # Old names number the fingerprint layers from 1 (`linear_fp_<n>`).
            # Raw string avoids the invalid-escape warning for `\d`; `\d+`
            # also copes with ten or more layers.
            return int(re.search(r"linear_fp_(\d+)", name).group(1))

        if "lin.weight" in name or "lin.bias" in name:
            number = layer_number()
            old_prefix = f"linear_fp_{number}.lin."
            # The linear layer sits at position 1 or 2 of the new block; use
            # position 1 if the current model has such a key, otherwise 2.
            # The state dict is only built when this branch is reached.
            candidate = name.replace(old_prefix, f"linear_fp.{number - 1}.1.")
            if candidate in self.model.state_dict():
                return candidate
            return name.replace(old_prefix, f"linear_fp.{number - 1}.2.")
        # Batch-norm tensors map onto position 0 of the new block.
        for suffix in ("weight", "bias", "running_mean", "running_var"):
            if f"bn.{suffix}" in name:
                return f"desmiles_rnn_core.linear_fp.{layer_number() - 1}.0.{suffix}"
        return name

    @staticmethod
    def replace_sequential_name(name):
        '''
        Replace the positional prefix of an old name ("0.encoder", "0.rnns",
        "0.linear", "1.decoder") with the corresponding named submodule.
        Raises ValueError if none of the prefixes occurs in `name`.
        '''
        replacements = (
            ("0.encoder", "desmiles_rnn_core.encoder"),
            ("0.rnns", "desmiles_rnn_core.rnns"),
            ("0.linear", "desmiles_rnn_core.linear"),
            ("1.decoder", "linear_decoder.decoder"),
        )
        # First matching prefix wins, in the order listed above.
        for old, new in replacements:
            if old in name:
                return name.replace(old, new)
        raise ValueError("Could not find replacement for name")
self.start_epoch, self.tot_epochs = start_epoch, tot_epochs 121 | 122 | def steps(self, *steps_cfg:StartOptEnd): 123 | "Build anneal schedule for all of the parameters." 124 | return [Scheduler(step, n_iter, func=func) 125 | for (step,(n_iter,func)) in zip(steps_cfg, self.phases)] 126 | 127 | def on_train_begin(self, n_epochs:int, epoch:int, **kwargs:Any)->None: 128 | "Initialize our optimization params based on our annealing schedule." 129 | res = {'epoch':self.start_epoch} if self.start_epoch is not None else None 130 | self.start_epoch = ifnone(self.start_epoch, epoch) 131 | self.tot_epochs = ifnone(self.tot_epochs, n_epochs) 132 | n = len(self.learn.data.train_dl) * self.tot_epochs 133 | a1 = int(n * self.frac_inc) 134 | a2 = int(n * self.frac_dec) 135 | a3 = n - (a2 + a1) 136 | self.phases = ((a1, annealing_linear), (a2, annealing_linear), (a3, annealing_linear)) 137 | low_lr = self.lr_max/self.div_factor 138 | self.lr_scheds = self.steps((low_lr, self.lr_max), (self.lr_max, low_lr), (low_lr, low_lr/(self.div_factor**2))) 139 | self.mom_scheds = self.steps(self.moms, (self.moms[1], self.moms[0]), (self.moms[0], self.moms[0])) 140 | self.opt = self.learn.opt 141 | self.opt.lr,self.opt.mom = self.lr_scheds[0].start,self.mom_scheds[0].start 142 | self.idx_s = 0 143 | return res 144 | 145 | def jump_to_epoch(self, epoch:int)->None: 146 | for _ in range(len(self.learn.data.train_dl) * epoch): 147 | self.on_batch_end(True) 148 | 149 | def on_batch_end(self, train, **kwargs:Any)->None: 150 | "Take one step forward on the annealing schedule for the optim params." 151 | if train: 152 | if self.idx_s >= len(self.lr_scheds): return {'stop_training': True, 'stop_epoch': True} 153 | self.opt.lr = self.lr_scheds[self.idx_s].step() 154 | self.opt.mom = self.mom_scheds[self.idx_s].step() 155 | # when the current schedule is complete we move onto the next 156 | # schedule. 
(in 1-cycle there are two schedules) 157 | if self.lr_scheds[self.idx_s].is_done: 158 | self.idx_s += 1 159 | 160 | def on_epoch_end(self, epoch, **kwargs:Any)->None: 161 | "Tell Learner to stop if the cycle is finished." 162 | if epoch > self.tot_epochs: return {'stop_training': True} 163 | -------------------------------------------------------------------------------- /lib-python/desmiles/scripts/sample_variants_of_input.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Sample a number of variants of each molecule in a CSV file 5 | using a pretrained DESMILES model. Example usage for getting 6 | 100 molecules for each input molecule: 7 | 8 | time ./sample_variants_of_input.py --max_try 1000 --num_expand 500 -n 100 -i ml_results.batch4.csv --verbose --dont-dump-model -w variants_sample_1080ti_10K_20K --min-row=10000 --max-row=20000 2>&1 >& 10K-20K_1080ti.log 9 | 10 | 11 | The output is placed in the file variants_sample_1080ti_10K_20K/samples.csv 12 | """ 13 | 14 | 15 | import sys 16 | import socket 17 | import argparse 18 | import time 19 | import os 20 | from pathlib import Path 21 | import pandas as pd 22 | import numpy as np 23 | import rdkit 24 | from rdkit import Chem 25 | from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect 26 | import pickle 27 | import fastai 28 | from desmiles.data import Vocab, FpSmilesList, DesmilesLoader, DataBunch 29 | from desmiles.config import DATA_DIR, MODEL_train_val1 30 | from desmiles.utils import load_old_pretrained_desmiles, load_pretrained_desmiles, decoder 31 | from desmiles.decoding.astar import AstarTreeParallelHybrid as AstarTree 32 | from desmiles.models import Desmiles, RecurrentDESMILES 33 | 34 | 35 | def smiles_to_fingerprint(smiles_str, sparse=False, as_tensor=False): 36 | "Return the desmiles fp" 37 | rdmol = Chem.MolFromSmiles(smiles_str) 38 | fp = np.concatenate([ 39 | np.asarray(GetMorganFingerprintAsBitVect(rdmol, 2, 
useChirality=True), dtype=np.int8), 40 | np.asarray(GetMorganFingerprintAsBitVect(rdmol, 3, useChirality=True), dtype=np.int8)]) 41 | if sparse: 42 | return scipy.sparse.csr_matrix(fp) 43 | if as_tensor: 44 | import torch 45 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 46 | return torch.tensor(fp.astype(np.float32)).to(device) 47 | return fp 48 | 49 | 50 | def canon_smiles(x): 51 | "Shortcut to return canonical smiles with chiral info" 52 | return Chem.CanonSmiles(x, useChiral=True) 53 | 54 | 55 | def load_good_model(model_fn, itos_fn): 56 | "Load an old desmiles model with default hyperparameters" 57 | model = load_old_pretrained_desmiles(model_fn, return_learner=True, itos_fn=itos_fn) 58 | return model 59 | 60 | 61 | def simple_smiles_fail(sm): 62 | # faster and safer processing of really bad SMILES 63 | return ((sm.count("(") != sm.count(")")) | 64 | (sm.count("[") != sm.count("]")) | 65 | (len(sm.strip()) == 0)) 66 | 67 | 68 | def get_my_decoder(itos_fn): 69 | from functools import partial 70 | itos = [s.strip() for i,s in enumerate(open(itos_fn, encoding='utf-8'))] 71 | return partial(decoder, itos=itos) 72 | 73 | 74 | def get_smiles_idx_to_string(itos_fn): 75 | my_decoder = get_my_decoder(itos_fn) 76 | def smiles_idx_to_string(smiles_idx): 77 | return my_decoder(smiles_idx[smiles_idx > 0].tolist()) 78 | return smiles_idx_to_string 79 | 80 | 81 | # Return num_return molecules, if possible within num_max_try iterations of the algorithm, 82 | # otherwise return as many as you got. 
83 | def sample_astar(model, smiles, smiles_idx_to_string, fp=None, num_return=20, cutoff=0, num_expand=2000, num_max_try=1000): 84 | "sample using parallel hybrid astar" 85 | if fp is None: 86 | fp = smiles_to_fingerprint(smiles, as_tensor=True) 87 | astar = AstarTree(fp, model, num_expand=num_expand) 88 | results = set() 89 | for i in range(num_max_try): 90 | nlp, generated_smiles_idx = next(astar) 91 | generated_smiles = smiles_idx_to_string(generated_smiles_idx) 92 | if simple_smiles_fail(generated_smiles): 93 | continue 94 | print(i, generated_smiles) 95 | try: 96 | mol = Chem.MolFromSmiles(generated_smiles) 97 | print(i, mol) 98 | if mol is not None: 99 | results.add(canon_smiles(generated_smiles)) # keep set of canonical smiles 100 | except: 101 | pass 102 | if len(results) >= num_return: 103 | return results 104 | print("NOTE: sample_astar didn't return enough molecules") 105 | return results 106 | 107 | 108 | def main(): 109 | """ 110 | Get num_return molecules for each of the input smiles 111 | """ 112 | t_start = time.time() 113 | args = parse_args() 114 | df = pd.read_csv(args.input_smiles) 115 | df = df.iloc[args.min_row:args.max_row] 116 | df['canon_SMILES'] = df['SMILES'].map(canon_smiles) 117 | smiles = [canon_smiles(x) for x in df['canon_SMILES']] 118 | 119 | if args.verbose: 120 | print(f'Canonicalized the input smiles. Time so far {time.time() - t_start}') 121 | os.makedirs(args.workdir, mode=0o755, exist_ok=True) 122 | os.chdir(args.workdir) 123 | if Path(args.fpname).exists(): 124 | fps = pickle.load(open(args.fpname, 'rb')) 125 | else: 126 | fps = [smiles_to_fingerprint(x) for x in smiles] 127 | pickle.dump(fps, open(args.fpname, 'wb')) 128 | if args.verbose: 129 | print(f'Got fingerprints. 
Time so far {time.time() - t_start}') 130 | model_dump_fname = os.path.join("models", args.modeldump + ".pth") 131 | if Path(model_dump_fname).exists(): 132 | if args.verbose: 133 | print(f'Loading dumped model {model_dump_fname}') 134 | model = load_pretrained_desmiles(args.modeldump, return_learner=True, itos_fn=args.itos_fn) 135 | else: 136 | if args.verbose: 137 | print(f'loading default model {args.model}') 138 | model = load_good_model(args.model, itos_fn=args.itos_fn) 139 | if args.verbose: 140 | print(f'Loaded model. Time so far {time.time() - t_start}') 141 | if not Path(model_dump_fname).exists() and not args.dont_dump_model: 142 | model.save(args.modeldump) 143 | if args.verbose: 144 | print(f'Dumped local model. Time so far {time.time() - t_start}') 145 | rnn_desmiles = RecurrentDESMILES(model.model).eval() 146 | if args.verbose: 147 | print(f'Got recurrent model. Time so far {time.time() - t_start}') 148 | smiles_idx_to_string = get_smiles_idx_to_string(args.itos_fn) 149 | with open('samples.csv', 'w') as out: 150 | out.write('SMILES, SMILES_out\n') 151 | total = set() 152 | for s in smiles: 153 | results = sample_astar(rnn_desmiles, s, smiles_idx_to_string, num_return=args.num_return, 154 | num_expand=args.num_expand, num_max_try=args.max_try) 155 | total.update(results) 156 | for x in results: 157 | out.write(f'{s},{x}\n') 158 | t_end = time.time() 159 | 160 | if args.verbose: 161 | print(f'total walltime was: {t_end - t_start}') 162 | print(f'total different smiles generated: {len(total)}') 163 | 164 | def parse_args(): 165 | parser = argparse.ArgumentParser() 166 | parser.add_argument('-w', '--workdir', 167 | help="directory with output", 168 | type=os.path.abspath, required=True) 169 | parser.add_argument('-l', '--model', 170 | help="name of saved model", 171 | type=os.path.abspath, 172 | default=MODEL_train_val1) 173 | parser.add_argument('--itos_fn', 174 | help="filename of int to string mapping for BPE", 175 | type=os.path.abspath, 176 | 
default=os.path.join(DATA_DIR, 'pretrained', 'id.dec8000')) 177 | parser.add_argument('--dont-dump-model', 178 | help='do not keep a backup of the loaded model in workdir', 179 | action='store_true') 180 | parser.add_argument('-n', '--num_return', 181 | help="molecules to output for each input molecule", 182 | type=int, default=100) 183 | parser.add_argument('-m', '--max_try', 184 | help="maximal number of astar iterations", 185 | type=int, default=2000) 186 | parser.add_argument('-x', '--num_expand', 187 | help="batch expansions to try on GPU astar", 188 | type=int, default=1000) 189 | parser.add_argument('--fpname', 190 | help="pickle of the fingerprint file", 191 | type=str, default="temp_fingerprints.pkl") 192 | parser.add_argument('--modeldump', 193 | help="pickle of the particular model", 194 | type=str, default="model_dump") 195 | parser.add_argument('--verbose', 196 | help="print a few timining log messages", 197 | action='store_true') 198 | # The list of input molecules, one smiles per line, with the smiles as first element. 199 | parser.add_argument('-i', '--input_smiles', 200 | help="csv file with input smiles in column SMILES", 201 | type=os.path.abspath, required=True) 202 | parser.add_argument('--min-row', 203 | help='minimum row to keep in CSV input file', 204 | default=0, type=int) 205 | parser.add_argument('--max-row', 206 | help='maximal row (exclusive, or -1) to keep in CSV input file', 207 | default=-1, type=int) 208 | args = parser.parse_args() 209 | return args 210 | 211 | 212 | if __name__ == "__main__": 213 | maindoc = """ 214 | Resurrect a finetuned model and apply it to new molecules. 
215 | """ 216 | main() 217 | -------------------------------------------------------------------------------- /lib-python/desmiles/scripts/finetune_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | import os 6 | import argparse 7 | import tempfile 8 | import multiprocessing 9 | import functools 10 | import subprocess 11 | from pathlib import Path 12 | 13 | 14 | import numpy as np 15 | import pandas as pd 16 | import scipy 17 | 18 | 19 | from rdkit import Chem 20 | from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect 21 | 22 | import desmiles 23 | from desmiles.data import Vocab, FpSmilesList, DesmilesLoader, DataBunch 24 | from desmiles.config import DATA_DIR, MODEL_train_val1 25 | from desmiles.learner import desmiles_model_learner 26 | from desmiles.models import Desmiles, RecurrentDESMILES 27 | from desmiles.models import get_fp_to_embedding_model, get_embedded_fp_to_smiles_model 28 | from desmiles.utils import load_old_pretrained_desmiles, load_pretrained_desmiles 29 | from desmiles.utils import accuracy4 30 | from desmiles.learner import OriginalFastaiOneCycleScheduler, Learner 31 | 32 | 33 | def load_pairs(csv_fname, col1="SMILES_1", col2="SMILES_2"): 34 | "Load pairs of SMILES from columns SMILES_1, SMILES_2" 35 | df = pd.read_csv(csv_fname) 36 | return df.loc[:, df.columns.isin((col1, col2))].copy() 37 | 38 | 39 | def canon_smiles(x): 40 | return Chem.CanonSmiles(x, useChiral=True) 41 | 42 | 43 | def smiles_list_to_canon(slist): 44 | "convert a list of smiles to a list of rdkit canonical chiral smiles" 45 | with multiprocessing.Pool() as p: 46 | result = p.map(canon_smiles, slist) 47 | return result 48 | 49 | 50 | def pairs_to_canon_dict(pairs): 51 | "Take pairs of smiles, get set of all smiles and return dict smile: canon" 52 | s1 = pairs['SMILES_1'].values.tolist() 53 | s2 = pairs['SMILES_2'].values.tolist() 54 | smiles = list(set(s1) | set(s2)) 55 | 
canon = smiles_list_to_canon(smiles) 56 | return {x: y for x, y in zip(smiles, canon)} 57 | 58 | 59 | def canon2bpe(canon): 60 | "Convert a list of smiles (typically canonical) to their BPE, return dict canon: BPE" 61 | # Work in a temporary directory and cleanup in case of errors 62 | with tempfile.TemporaryDirectory() as tmpdirname: 63 | # Write smiles into temporary smiles.smi file 64 | sm_fname = os.path.join(tmpdirname, 'smiles.smi') 65 | with open(sm_fname, 'w') as f: 66 | for s in canon: 67 | f.write(s + '\n') 68 | # Write corresponding BPE into smiles.bpe file 69 | bpe_fname = os.path.join(tmpdirname, 'smiles.bpe') 70 | cmd = f"spm_encode --model {DATA_DIR}/pretrained/bpe_v8000.model --output {bpe_fname} {sm_fname} --output_format=id --extra_options=bos:eos" 71 | run = subprocess.run(cmd.split()) 72 | assert run.returncode == 0 73 | # Create pbe list (of lists) before the temporary directory is deleted 74 | with open(bpe_fname) as bpe_file: 75 | bpe = [[int(x) for x in y.strip().split()] for y in bpe_file] 76 | 77 | # Return a dictionary from canonical smiles to byte pair encoding 78 | return {x: y for x, y in zip(canon, bpe)} 79 | 80 | 81 | def c2b2rectbpe(canon2bpe): 82 | "Convert a map of canon: bpe to a map that has BPE as numpy int16 arrays of equal length" 83 | smiles, bpe = map(list, zip(*canon2bpe.items())) 84 | maxlen = max(list(map(len, bpe))) 85 | desmiles_enc = np.zeros((len(smiles), maxlen), dtype=np.int16) 86 | for i, x in enumerate(bpe): 87 | desmiles_enc[i, :len(x)] = x 88 | return {x: y for x, y in zip(smiles, desmiles_enc)} 89 | 90 | 91 | def smiles_to_fingerprint(smiles_str, sparse=False, as_tensor=False): 92 | "Return the desmiles fp" 93 | rdmol = Chem.MolFromSmiles(smiles_str) 94 | fp = np.concatenate([ 95 | np.asarray(GetMorganFingerprintAsBitVect(rdmol, 2, useChirality=True), dtype=np.int8), 96 | np.asarray(GetMorganFingerprintAsBitVect(rdmol, 3, useChirality=True), dtype=np.int8)]) 97 | if sparse: 98 | return 
scipy.sparse.csr_matrix(fp) 99 | if as_tensor: 100 | import torch 101 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 102 | return torch.tensor(fp.astype(np.float32)).to(device) 103 | return fp 104 | 105 | 106 | def canon2fp(canon): 107 | "Convert a list of (canonical) smiles to desmiles fp" 108 | with multiprocessing.Pool() as p: 109 | f = functools.partial(smiles_to_fingerprint, sparse=True) 110 | result = p.map(f, canon) 111 | return {x: y for x, y in zip(canon, result)} 112 | 113 | 114 | ####### 115 | 116 | 117 | def load_top_pretrained_learner(db): 118 | model_fn = Path(MODEL_train_val1) 119 | learner = load_old_pretrained_desmiles(model_fn, return_learner=True) 120 | learner.metrics = [accuracy4] 121 | learner.data = db 122 | return learner 123 | 124 | 125 | ############ 126 | 127 | 128 | def create_my_db( 129 | fp, 130 | enc, 131 | bs, 132 | true_validation = False, 133 | start_validation = 0.8, 134 | itos_fn=os.path.join(DATA_DIR, 'pretrained', 'id.dec8000')): 135 | "Create a databunch for transfer learning" 136 | 137 | itos = [s.strip() for i, s in enumerate(open(itos_fn, encoding='utf-8'))] 138 | vocab = Vocab(itos) 139 | 140 | n = len(enc) 141 | enc = np.array(enc) 142 | inds = np.arange(n) 143 | inds = np.random.permutation(inds) 144 | val_inds = inds[int(start_validation*n):] 145 | if true_validation: 146 | trn_inds = inds[:int(start_validation*n)] 147 | trn_ds = FpSmilesList(enc[trn_inds], fp[trn_inds], vocab=vocab) 148 | else: 149 | trn_ds = FpSmilesList(enc, fp, vocab=vocab) 150 | val_ds = FpSmilesList(enc[val_inds], fp[val_inds], vocab=vocab) 151 | 152 | trn_dl = DesmilesLoader(trn_ds, bs=bs, vocab=vocab) 153 | val_dl = DesmilesLoader(val_ds, bs=bs, vocab=vocab) 154 | db = DataBunch(trn_dl, val_dl) 155 | return db 156 | 157 | 158 | def main(): 159 | """ 160 | Finetune the best DESMILES model. 
161 | The conceptual process is the following: 162 | We start with a list of pairs of smiles (A->B with A, B neighbors, and B better than A) 163 | and we want to convert it to a list of pairs of fingerprints (of A) and BPE (of B). 164 | Schematically, this process works as follows: 165 | 166 | pairs(smiles) -> set(canon_smiles) -> canon: BPE -> canon: rectBPE 167 | | 168 | -> canon: fp 169 | 170 | pairs(smiles) -> pairs(canon_smiles) -> pairs(fp: rectBPE) 171 | """ 172 | 173 | args = get_parser().parse_args() 174 | # Load the pairs of SMILES for fine tuning. 175 | p = load_pairs(args.training_pairs) 176 | print(f"FINETUNE: Loaded pairs dataset that looks like:\n{p[:2]}") 177 | print(f'FINETUNE: Total length of dataset is: {len(p)}') 178 | # Dictionary that transforms a SMILES from those pairs to a canonical SMILES 179 | s2c = pairs_to_canon_dict(p) 180 | # The lists of unique SMILES and corresponding canonical SMILES 181 | smiles, canon = map(list, zip(*s2c.items())) 182 | print(f'FINETUNE: Total number of SMILES: {len(smiles)}') 183 | print(f'FINETUNE: Total number of canonical SMILES: {len(set(canon))}') 184 | # How many SMILES were already canonicalized in input (often all or none) 185 | print('FINETUNE: Number of canonical smiles in input:', sum([x == y for x, y in zip(smiles, canon)])) 186 | # Get the byte-pair encoding for each unique smiles 187 | c2b = canon2bpe(canon) 188 | print(f'FINETUNE: got the Byte Pair Encoding of {len(c2b)} canonical SMILES') 189 | for c, b in list(c2b.items())[:2]: 190 | print(c, b) 191 | # Get the rectangular form of the dictionary 192 | c2r = c2b2rectbpe(c2b) 193 | # Get the desmiles fingerprints for each unique smiles 194 | c2fp = canon2fp(canon) 195 | print(f'FINETUNE: got the fingeprints for {len(c2fp)} canonical SMILES') 196 | for c, fp in list(c2fp.items())[:2]: 197 | print(c, fp.shape) 198 | # Get a set of pairs of fp -> rectBPE for each of the original pairs 199 | fps = [] 200 | enc = [] 201 | for s1, s2 in 
p.values.tolist(): 202 | c1 = s2c[s1] 203 | c2 = s2c[s2] 204 | fps.append(c2fp[c1]) 205 | enc.append(c2r[c2]) 206 | fps = scipy.sparse.vstack(fps) 207 | print(f'FINETUNE: Got the lists of fingerprints: {fps.shape}') 208 | print(f'FINETUNE: Got the lists of encodings: {len(enc)}') 209 | print("FINETUNE: DONE with setup") 210 | print("#################################") 211 | 212 | db = create_my_db(fps, enc, 213 | bs=args.batch_size, 214 | true_validation=args.true_validation, 215 | start_validation=args.start_validation) 216 | 217 | learner = load_top_pretrained_learner(db) 218 | 219 | num_epochs = args.num_epochs # 1 220 | max_lr = args.max_lr # 0.001 221 | div_factor = args.div_factor # 7 222 | 223 | one_cycle_linear_cb = OriginalFastaiOneCycleScheduler(learner, max_lr, div_factor=div_factor) 224 | if num_epochs > 0: 225 | learner.fit(num_epochs, callbacks=[one_cycle_linear_cb]) 226 | 227 | learner.save(args.learner) 228 | 229 | 230 | def get_parser(): 231 | parser = argparse.ArgumentParser( 232 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 233 | # finetuned model name 234 | parser.add_argument('-t', '--training_pairs', 235 | help="filename of the training pairs (csv with SMILES_1, SMILES_2)", 236 | type=os.path.abspath, required=True) 237 | parser.add_argument('-l', '--learner', 238 | help="filename of the finetuned model", 239 | type=os.path.abspath, required=True) 240 | parser.add_argument('-e', '--num_epochs', 241 | help="number of epochs (0 for none)", 242 | type=int, default=1) 243 | parser.add_argument('-m', '--max_lr', 244 | help="maximal learning rate", 245 | type=float, default=0.001) 246 | parser.add_argument('-d', '--div_factor', 247 | help="div factor in one-cycle training", 248 | type=int, default=7) 249 | parser.add_argument('-b', '--batch_size', 250 | help="batch size", 251 | type=int, default=200) 252 | parser.add_argument('-v', '--true_validation', 253 | help="dont train on validation subset", 254 | type=bool, default=False) 255 | 
parser.add_argument('-s', '--start_validation', 256 | help="fraction of data at start of validation", 257 | type=float, default=0.8) 258 | return parser 259 | 260 | 261 | if __name__ == "__main__": 262 | maindoc = """ 263 | finetune a DESMILES model using pairs of A->B, with B improved. 264 | """ 265 | main() 266 | -------------------------------------------------------------------------------- /doc/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Doc Template documentation build configuration file, created by 4 | # sphinx-quickstart on Wed Oct 5 12:53:58 2016. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import sys 16 | import os 17 | 18 | 19 | # # adding in support for .md files 20 | # from recommonmark.parser import CommonMarkParser 21 | # source_parsers = {'.md': CommonMarkParser} 22 | # source_suffix = ['.rst', '.md'] 23 | 24 | # If extensions (or modules to document with autodoc) are in another directory, 25 | # add these directories to sys.path here. If the directory is relative to the 26 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
27 | #sys.path.insert(0, os.path.abspath('.')) 28 | sys.path.insert(0, os.path.abspath('../../lib-python/desmiles/scripts')) 29 | sys.path.insert(0, os.path.abspath('../../lib-python')) 30 | HERE = os.path.dirname(__file__) 31 | TOP_BINDIR = os.path.normpath(os.path.join(HERE, "../../lib-python/desmiles/scripts")) 32 | os.environ["PATH"] = os.pathsep.join([ 33 | os.path.abspath(TOP_BINDIR), os.environ.get("PATH", "")]) 34 | 35 | 36 | # -- General configuration ------------------------------------------------ 37 | 38 | # If your documentation needs a minimal Sphinx version, state it here. 39 | #needs_sphinx = '1.0' 40 | 41 | # Add any Sphinx extension module names here, as strings. They can be 42 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 43 | # ones. 44 | extensions = [ 45 | 'sphinx.ext.autodoc', 46 | 'sphinx.ext.doctest', 47 | 'sphinx.ext.intersphinx', 48 | 'sphinx.ext.todo', 49 | 'sphinx.ext.coverage', 50 | 'sphinx.ext.mathjax', 51 | 'sphinx.ext.ifconfig', 52 | 'sphinx.ext.viewcode', 53 | ] 54 | 55 | extensions += ['numpydoc'] 56 | extensions += ['sphinxcontrib.programoutput'] 57 | extensions += ['sphinxarg.ext'] 58 | 59 | # Add any paths that contain templates here, relative to this directory. 60 | templates_path = ['_templates'] 61 | 62 | 63 | # The suffix(es) of source filenames. 64 | # You can specify multiple suffix as a list of string: 65 | # source_suffix = ['.rst', '.md'] 66 | #source_suffix = '.rst' 67 | 68 | # The encoding of source files. 69 | #source_encoding = 'utf-8-sig' 70 | 71 | # The master toctree document. 72 | master_doc = 'index' 73 | 74 | # General information about the project. 75 | project = u'desmiles' 76 | copyright = u'2018, 2019, 2020, D. E. Shaw Research' 77 | author = u'DESRES' 78 | 79 | # The version info for the project you're documenting, acts as replacement for 80 | # |version| and |release|, also used in various other places throughout the 81 | # built documents. 82 | # 83 | # The short X.Y version. 
84 | version = os.environ.get('VERSION', 'BETA') 85 | # The full version, including alpha/beta/rc tags. 86 | release = u'' 87 | 88 | # The language for content autogenerated by Sphinx. Refer to documentation 89 | # for a list of supported languages. 90 | # 91 | # This is also used if you do content translation via gettext catalogs. 92 | # Usually you set "language" from the command line for these cases. 93 | language = None 94 | 95 | # There are two options for replacing |today|: either, you set today to some 96 | # non-false value, then it is used: 97 | #today = '' 98 | # Else, today_fmt is used as the format for a strftime call. 99 | #today_fmt = '%B %d, %Y' 100 | 101 | # List of patterns, relative to source directory, that match files and 102 | # directories to ignore when looking for source files. 103 | exclude_patterns = [] 104 | 105 | # The reST default role (used for this markup: `text`) to use for all 106 | # documents. 107 | #default_role = None 108 | 109 | # If true, '()' will be appended to :func: etc. cross-reference text. 110 | #add_function_parentheses = True 111 | 112 | # If true, the current module name will be prepended to all description 113 | # unit titles (such as .. function::). 114 | #add_module_names = True 115 | 116 | # If true, sectionauthor and moduleauthor directives will be shown in the 117 | # output. They are ignored by default. 118 | #show_authors = False 119 | 120 | # The name of the Pygments (syntax highlighting) style to use. 121 | pygments_style = 'sphinx' 122 | 123 | # A list of ignored prefixes for module index sorting. 124 | #modindex_common_prefix = [] 125 | 126 | # If true, keep warnings as "system message" paragraphs in the built documents. 127 | #keep_warnings = False 128 | 129 | # If true, `todo` and `todoList` produce output, else they produce nothing. 
130 | todo_include_todos = True 131 | 132 | 133 | # -- Options for HTML output ---------------------------------------------- 134 | 135 | # The theme to use for HTML and HTML Help pages. See the documentation for 136 | # a list of builtin themes. 137 | html_theme = 'sphinx_rtd_theme' 138 | 139 | # Theme options are theme-specific and customize the look and feel of a theme 140 | # further. For a list of options available for each theme, see the 141 | # documentation. 142 | #html_theme_options = {} 143 | 144 | # Add any paths that contain custom themes here, relative to this directory. 145 | #html_theme_path = [] 146 | 147 | # The name for this set of Sphinx documents. If None, it defaults to 148 | # "{project} v{release} documentation". 149 | #html_title = None 150 | 151 | # A shorter title for the navigation bar. Default is the same as html_title. 152 | #html_short_title = None 153 | 154 | # The name of an image file (relative to this directory) to place at the top 155 | # of the sidebar. 156 | #html_logo = None 157 | 158 | # The name of an image file (within the static path) to use as favicon of the 159 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 160 | # pixels large. 161 | #html_favicon = None 162 | 163 | # Add any paths that contain custom static files (such as style sheets) here, 164 | # relative to this directory. They are copied after the builtin static files, 165 | # so a file named "default.css" will overwrite the builtin "default.css". 166 | html_static_path = ['_static'] 167 | 168 | # Add any extra paths that contain custom files (such as robots.txt or 169 | # .htaccess) here, relative to this directory. These files are copied 170 | # directly to the root of the documentation. 171 | #html_extra_path = [] 172 | 173 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 174 | # using the given strftime format.
175 | #html_last_updated_fmt = '%b %d, %Y' 176 | 177 | # If true, SmartyPants will be used to convert quotes and dashes to 178 | # typographically correct entities. 179 | #html_use_smartypants = True 180 | 181 | # Custom sidebar templates, maps document names to template names. 182 | #html_sidebars = {} 183 | 184 | # Additional templates that should be rendered to pages, maps page names to 185 | # template names. 186 | #html_additional_pages = {} 187 | 188 | # If false, no module index is generated. 189 | #html_domain_indices = True 190 | 191 | # If false, no index is generated. 192 | #html_use_index = True 193 | 194 | # If true, the index is split into individual pages for each letter. 195 | #html_split_index = False 196 | 197 | # If true, links to the reST sources are added to the pages. 198 | #html_show_sourcelink = True 199 | 200 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 201 | #html_show_sphinx = True 202 | 203 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 204 | #html_show_copyright = True 205 | 206 | # If true, an OpenSearch description file will be output, and all pages will 207 | # contain a "link" tag referring to it. The value of this option must be the 208 | # base URL from which the finished HTML is served. 209 | #html_use_opensearch = '' 210 | 211 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 212 | #html_file_suffix = None 213 | 214 | # Language to be used for generating the HTML full-text search index. 215 | # Sphinx supports the following languages: 216 | # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' 217 | # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' 218 | #html_search_language = 'en' 219 | 220 | # A dictionary with options for the search language support, empty by default.
221 | # Now only 'ja' uses this config value 222 | #html_search_options = {'type': 'default'} 223 | 224 | # The name of a javascript file (relative to the configuration directory) that 225 | # implements a search results scorer. If empty, the default will be used. 226 | #html_search_scorer = 'scorer.js' 227 | 228 | # Output file base name for HTML help builder. 229 | htmlhelp_basename = 'DocTemplatedoc' 230 | 231 | # -- Options for LaTeX output --------------------------------------------- 232 | 233 | latex_elements = { 234 | # The paper size ('letterpaper' or 'a4paper'). 235 | #'papersize': 'letterpaper', 236 | 237 | # The font size ('10pt', '11pt' or '12pt'). 238 | #'pointsize': '10pt', 239 | 240 | # Additional stuff for the LaTeX preamble. 241 | #'preamble': '', 242 | 243 | # Latex figure (float) alignment 244 | #'figure_align': 'htbp', 245 | } 246 | 247 | # Grouping the document tree into LaTeX files. List of tuples 248 | # (source start file, target name, title, 249 | # author, documentclass [howto, manual, or own class]). 250 | latex_documents = [ 251 | (master_doc, 'DocTemplate.tex', u'Doc Template Documentation', 252 | u'DESRES', 'manual'), 253 | ] 254 | 255 | # The name of an image file (relative to this directory) to place at the top of 256 | # the title page. 257 | #latex_logo = None 258 | 259 | # For "manual" documents, if this is true, then toplevel headings are parts, 260 | # not chapters. 261 | #latex_use_parts = False 262 | 263 | # If true, show page references after internal links. 264 | #latex_show_pagerefs = False 265 | 266 | # If true, show URL addresses after external links. 267 | #latex_show_urls = False 268 | 269 | # Documents to append as an appendix to all manuals. 270 | #latex_appendices = [] 271 | 272 | # If false, no module index is generated. 273 | #latex_domain_indices = True 274 | 275 | 276 | # -- Options for manual page output --------------------------------------- 277 | 278 | # One entry per manual page. 
List of tuples 279 | # (source start file, name, description, authors, manual section). 280 | man_pages = [ 281 | (master_doc, 'doctemplate', u'Doc Template Documentation', 282 | [author], 1) 283 | ] 284 | 285 | # If true, show URL addresses after external links. 286 | #man_show_urls = False 287 | 288 | 289 | # -- Options for Texinfo output ------------------------------------------- 290 | 291 | # Grouping the document tree into Texinfo files. List of tuples 292 | # (source start file, target name, title, author, 293 | # dir menu entry, description, category) 294 | texinfo_documents = [ 295 | (master_doc, 'DocTemplate', u'Doc Template Documentation', 296 | author, 'DocTemplate', 'One line description of project.', 297 | 'Miscellaneous'), 298 | ] 299 | 300 | # Documents to append as an appendix to all manuals. 301 | #texinfo_appendices = [] 302 | 303 | # If false, no module index is generated. 304 | #texinfo_domain_indices = True 305 | 306 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 307 | #texinfo_show_urls = 'footnote' 308 | 309 | # If true, do not generate a @detailmenu in the "Top" node's menu. 310 | #texinfo_no_detailmenu = False 311 | 312 | 313 | # -- Options for Epub output ---------------------------------------------- 314 | 315 | # Bibliographic Dublin Core info. 316 | epub_title = project 317 | epub_author = author 318 | epub_publisher = author 319 | epub_copyright = copyright 320 | 321 | # The basename for the epub file. It defaults to the project name. 322 | #epub_basename = project 323 | 324 | # The HTML theme for the epub output. Since the default themes are not 325 | # optimized for small screen space, using the same theme for HTML and epub 326 | # output is usually not wise. This defaults to 'epub', a theme designed to save 327 | # visual space. 328 | #epub_theme = 'epub' 329 | 330 | # The language of the text. It defaults to the language option 331 | # or 'en' if the language is not set. 
332 | #epub_language = '' 333 | 334 | # The scheme of the identifier. Typical schemes are ISBN or URL. 335 | #epub_scheme = '' 336 | 337 | # The unique identifier of the text. This can be a ISBN number 338 | # or the project homepage. 339 | #epub_identifier = '' 340 | 341 | # A unique identification for the text. 342 | #epub_uid = '' 343 | 344 | # A tuple containing the cover image and cover page html template filenames. 345 | #epub_cover = () 346 | 347 | # A sequence of (type, uri, title) tuples for the guide element of content.opf. 348 | #epub_guide = () 349 | 350 | # HTML files that should be inserted before the pages created by sphinx. 351 | # The format is a list of tuples containing the path and title. 352 | #epub_pre_files = [] 353 | 354 | # HTML files that should be inserted after the pages created by sphinx. 355 | # The format is a list of tuples containing the path and title. 356 | #epub_post_files = [] 357 | 358 | # A list of files that should not be packed into the epub file. 359 | epub_exclude_files = ['search.html'] 360 | 361 | # The depth of the table of contents in toc.ncx. 362 | #epub_tocdepth = 3 363 | 364 | # Allow duplicate toc entries. 365 | #epub_tocdup = True 366 | 367 | # Choose between 'default' and 'includehidden'. 368 | #epub_tocscope = 'default' 369 | 370 | # Fix unsupported image types using the Pillow. 371 | #epub_fix_images = False 372 | 373 | # Scale large images. 374 | #epub_max_image_width = 0 375 | 376 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 377 | #epub_show_urls = 'inline' 378 | 379 | # If false, no index is generated. 380 | #epub_use_index = True 381 | 382 | 383 | # Example configuration for intersphinx: refer to the Python standard library. 
def get_recovered_smiles(num_smiles=10):
    """Return the first `num_smiles` SMILES of the 'recovered' regression set.

    Reads ``recovered_2000_400_2000_5.smi`` under ``REGRESSION_DIR`` (one
    SMILES per line), strips surrounding whitespace from every line and
    returns the leading `num_smiles` entries as a list of strings.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(REGRESSION_DIR / 'recovered_2000_400_2000_5.smi') as smiles_file:
        all_smiles = [s.strip() for s in smiles_file]
    # Slice after reading so every slice value accepted before (None,
    # negative numbers, ...) keeps behaving the same.
    return all_smiles[:num_smiles]
def get_random_input(batch_size=300, max_len=26, vocab_sz=8000, num_bits=4096):
    """Build a random (token sequences, fingerprints, lengths) batch.

    Parameters
    ----------
    batch_size : number of sequences in the batch (default 300).
    max_len : width of the padded sequence tensor and the longest possible
        sequence (default 26); must be at least 5.
    vocab_sz : exclusive upper bound for the sampled token ids (default 8000).
    num_bits : number of fingerprint bits per molecule (default 4096).

    Returns
    -------
    input_seq : int64 tensor of shape (batch_size, max_len), zero padded,
        moved to CUDA when available.
    fps : float tensor of shape (batch_size, num_bits) holding random 0/1
        bits, created on the same device as `input_seq`.
    lengths : per-row count of non-zero tokens, computed before `input_seq`
        is moved and therefore left on the CPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_seq = torch.zeros(batch_size, max_len, dtype=torch.int64)
    input_seq[:, 0] = 3  # first token is always 3
    input_seq[:, 1] = torch.randint(1, 3, (batch_size,))  # 1 = forward, 2 = backward
    # Each row's length is sampled from [5, previous row's length], so the
    # batch comes out ordered by non-increasing length.
    max_length = max_len
    for i in range(batch_size):
        max_length = torch.randint(5, max_length + 1, (1,)).item()
        # Token ids start at 1 so padding (0) is never sampled.
        input_seq[i, 2:max_length] = torch.randint(1, vocab_sz, (max_length - 2,))
    lengths = (input_seq > 0).sum(dim=1)
    # randint's upper bound is exclusive: a bound of 1 would yield only
    # zeros, so 2 is required to actually get random 0/1 fingerprint bits.
    fps = torch.randint(2, (batch_size, num_bits), device=device).type(torch.float)
    input_seq = input_seq.to(device)
    return input_seq, fps, lengths
def test_split_desmiles_first_layer():
    """Check that DESMILES split after the first fingerprint layer matches the full model.

    The fingerprint is embedded by the first-layer embedder, the embedding is
    fed to the embedding-to-SMILES model, and every returned tensor must agree
    with the unsplit model to within 1e-4.
    """
    tolerance = 1e-4
    full_model = get_pretrained_model()
    fp_embedder = get_fp_to_embedding_model(full_model, first_layer=True)
    smiles_head = get_embedded_fp_to_smiles_model(full_model, first_layer=True)
    input_seq, fps, lengths = get_random_input()
    seq_first = input_seq.transpose(0, 1)
    with torch.no_grad():
        reference = full_model(seq_first, fps, lengths)
        split = smiles_head(seq_first, fp_embedder(fps), lengths)
    # Element 0 is a single tensor; elements 1 and 2 are sequences of tensors.
    assert (reference[0] - split[0]).abs().max() < tolerance
    for group in (1, 2):
        for expected, actual in zip(reference[group], split[group]):
            assert (expected - actual).abs().max() < tolerance
204 | model = get_pretrained_model() 205 | fp_to_embedding = get_fp_to_embedding_model(model, first_layer=True) 206 | # use only the second linear layer to precalculate fingerprint embedding 207 | rdesmiles = RecurrentDESMILES(model, fp_embedding_layers=(1,)) 208 | input_seq, fps, lengths = get_random_input() 209 | with torch.no_grad(): 210 | output = model(input_seq.transpose(0,1), fps, lengths) 211 | embedding = fp_to_embedding(fps) 212 | rdesmiles.embed_fingerprints(embedding) 213 | output_2 = rdesmiles(input_seq.transpose(0,1)) 214 | # Output of sequence-input recurrent astar model is BS x SL x vocab_size (8000). 215 | # Output of fp-sequence-input model is -1 (SLxBS) x vocab_size (8000). 216 | assert(torch.abs(output_2.transpose(0,1).reshape(-1, 8000) - output[0]).max() < 1e-5) 217 | 218 | def test_recurrent_desmiles_from_embedding_second_layer(): 219 | "Test recurrent version of DESMILES used in beam and astar search." 220 | model = get_pretrained_model() 221 | fp_to_embedding = get_fp_to_embedding_model(model, first_layer=False) 222 | # the fingerprint is already fully embedded so fp_embedding_layers is empty 223 | rdesmiles = RecurrentDESMILES(model, fp_embedding_layers=()) 224 | input_seq, fps, lengths = get_random_input() 225 | with torch.no_grad(): 226 | output = model(input_seq.transpose(0,1), fps, lengths) 227 | embedding = fp_to_embedding(fps) 228 | rdesmiles.embed_fingerprints(embedding) 229 | output_2 = rdesmiles(input_seq.transpose(0,1)) 230 | # Output of sequence-input recurrent astar model is BS x SL x vocab_size (8000). 231 | # Output of fp-sequence-input model is -1 (SLxBS) x vocab_size (8000). 
def test_against_sequential_astar_mem_safe():
    """Compare AstarTreeParallel against stored results of the old sequential A*.

    For every SMILES in the stored regression data, the parallel A* is asked
    for as many results as the old implementation produced. Each result
    (except the last) must match the stored result at the same rank or at a
    directly neighbouring rank, because numerical noise can swap adjacent
    entries.
    """
    from desmiles.utils import smiles_to_fingerprint
    import desmiles.decoding
    from desmiles.decoding.astar import AstarTreeParallel
    device = "cuda" if torch.cuda.is_available() else "cpu"
    old_astar_results = get_old_astar_results()
    for smiles, sequential_results in old_astar_results.items():
        print(f"Testing A* with smiles: {smiles}")
        # get fingerprint: output of smiles_to_fingerprint is fp_size (4096),
        # which is reshaped to BS x fp_size,
        # which is the input to the sequence-input (recurrent) DESMILES
        fp = torch.tensor(smiles_to_fingerprint(smiles).astype(float).reshape(1,-1), dtype=torch.float, device=device)
        # NOTE(review): the model is reloaded per SMILES as before; hoisting it
        # out of the loop would be faster but needs confirmation that the
        # search leaves no state behind in the model.
        model = get_pretrained_rnn()
        astar_par = AstarTreeParallel(fp, model, num_expand=100)
        parallel_results = []
        for _ in range(len(sequential_results)):
            score, smiles_vector = next(astar_par)
            # The new code keeps the padding in the output, removed here:
            smiles_vector = [x for x in smiles_vector.tolist() if x > 0]
            smiles_string = decoder(smiles_vector)
            parallel_results.append([score, smiles_string, smiles_vector])
        # Numerical accuracy is such that sometimes the order is swapped.
        # So this check tests one before and one after.
        # only check input molecule up to the penultimate SMILES.
        for i, pr in enumerate(parallel_results[:-1]):
            # Clamp the window start at 0: a plain ``i - 1`` would wrap to the
            # last stored result for the first entry and let it pass wrongly.
            neighbours = sequential_results[max(i - 1, 0):i + 2]
            assert any(pr[0] == approx(candidate[0], rel=1e-2) for candidate in neighbours)
            assert any(pr[1:] == candidate[1:] for candidate in neighbours)
    return
def test_one_cycle_learner():
    """Check that the one-cycle scheduler honours `div_factor`.

    Trains a small MLP for one epoch on a noisy ``y = x**2`` regression with
    ``OriginalFastaiOneCycleScheduler`` and verifies that the ratio between the
    largest recorded learning rate and the first recorded one equals
    `div_factor`.
    """
    from desmiles.learner import OriginalFastaiOneCycleScheduler, Learner
    from fastai.basic_data import DataBunch
    n=1000
    sigma=0.1
    x = np.linspace(-1,1, n) + (np.random.randn(n) * sigma)
    y = x**2
    x_t = torch.tensor(x, dtype=torch.float).unsqueeze(1)
    y_t = torch.tensor(y, dtype=torch.float).unsqueeze(1)

    # The same tensors serve as training and validation data; only the
    # learning-rate schedule is under test, not generalisation.
    trn_ds = torch.utils.data.TensorDataset(x_t, y_t)
    val_ds = torch.utils.data.TensorDataset(x_t, y_t)
    trn_loader = torch.utils.data.DataLoader(trn_ds, batch_size=10, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_ds, batch_size=10, shuffle=False)
    db = DataBunch(trn_loader, val_loader)
    model = torch.nn.Sequential(torch.nn.Linear(1,100), torch.nn.ReLU(), torch.nn.Linear(100,1), torch.nn.ReLU())
    learner = Learner(db, model, loss_func=torch.nn.functional.mse_loss)
    div_factor=2
    one_cycle_linear_cb = OriginalFastaiOneCycleScheduler(learner, 0.002, div_factor=div_factor)
    learner.fit(1, callbacks=[one_cycle_linear_cb])
    # The ratio is the result of floating-point arithmetic, so compare with a
    # tolerance instead of exact equality.
    assert max(learner.recorder.lrs) / learner.recorder.lrs[0] == approx(div_factor)
class DesmilesCore(nn.Module):
    '''Core piece of DESMILES.

    Does everything except the final softmax layer (LinearDecoder): embeds the
    fingerprint through two linear blocks, embeds the token sequence, and runs
    a stack of weight-dropped LSTMs whose first layer is initialised with the
    fingerprint embedding.
    '''

    initrange=0.1

    def __init__(self, vocab_sz:int, fp_emb_sz:int, emb_sz:int, n_hid:int, n_layers:int, pad_token:int=0, bidir:bool=False,
                 hidden_p:float=0.2, input_p:float=0.6, embed_p:float=0.1, weight_p:float=0.5, qrnn:bool=False, num_bits=4096):
        """Build the token embedding, the LSTM stack and the fingerprint embedder.

        `qrnn` is only stored on the instance; the layers built here are always
        `nn.LSTM`. `num_bits` is the input width of the first fingerprint block.
        """
        super().__init__()
        self.bs,self.qrnn,self.ndir = 1, qrnn,(2 if bidir else 1)
        self.emb_sz,self.n_hid,self.n_layers = emb_sz,n_hid,n_layers
        self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)
        self.encoder_dp = EmbeddingDropout(self.encoder, embed_p)
        # First layer consumes token embeddings, the last one emits emb_sz
        # features (so the decoder can share the embedding weights); every
        # other width is n_hid.
        self.rnns = [nn.LSTM(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz)//self.ndir,
                             1, bidirectional=bidir, batch_first=False) for l in range(n_layers)]
        self.rnns = [WeightDropout(rnn, weight_p) for rnn in self.rnns]
        self.rnns = torch.nn.ModuleList(self.rnns)
        self.encoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.input_dp = RNNDropout(input_p)
        self.hidden_dps = nn.ModuleList([RNNDropout(hidden_p) for l in range(n_layers)])
        # Two blocks: num_bits -> fp_emb_sz -> n_hid, each followed by Tanh.
        self.linear_fp = [nn.Sequential(*bn_drop_lin(num_bits, fp_emb_sz, p=0.0, actn=torch.nn.Tanh())), nn.Sequential(*bn_drop_lin(fp_emb_sz, n_hid, p=0.1, actn=torch.nn.Tanh()))]
        self.linear_fp = torch.nn.ModuleList(self.linear_fp)


    def forward(self, input_seq, input_fp, lengths):
        """Run the LSTM stack over `input_seq`, conditioned on `input_fp`.

        `input_seq` is a 2-D (sequence length, batch size) token tensor,
        `input_fp` the fingerprint batch, and `lengths` the per-sequence
        lengths handed to `pack_padded_sequence`.

        Returns `(raw_outputs, outputs)`: the padded output of every layer
        before hidden dropout, and a one-element list holding the last
        layer's output.
        """
        _, bs = input_seq.size()
        # Also (re)build the hidden state when it was never created: `self.bs`
        # starts at 1, so a first batch of size 1 would otherwise reach
        # `self.hidden[l]` below before `reset()` has ever run.
        if bs!=self.bs or not hasattr(self, 'hidden'):
            self.bs=bs
            self.reset()
        # Apply LinearBlocks on input_fp
        emb_fp = input_fp
        for linear_fp in self.linear_fp:
            emb_fp = linear_fp(emb_fp)
        raw_output = self.input_dp(self.encoder_dp(input_seq))
        raw_outputs,outputs = [],[]
        for l, (rnn,hid_dp) in enumerate(zip(self.rnns, self.hidden_dps)):
            raw_output = pack_padded_sequence(raw_output, lengths, batch_first=False)
            if l == 0:
                # The fingerprint embedding seeds both h0 and c0 of layer 0.
                raw_output, _ = rnn(raw_output, (emb_fp.unsqueeze(0), emb_fp.unsqueeze(0)))
            else:
                raw_output, _ = rnn(raw_output, self.hidden[l])
            raw_output, lengths = pad_packed_sequence(raw_output, batch_first=False)
            raw_outputs.append(raw_output)
            if l != self.n_layers - 1: raw_output = hid_dp(raw_output)
            if l == self.n_layers - 1: outputs.append(raw_output)
        # The new hidden states are deliberately not stored: `self.hidden`
        # keeps the zero tensors created by `reset()`.
        return raw_outputs, outputs

    def _one_hidden(self, l:int)->Tensor:
        "Return one zero-filled hidden state for layer `l`."
        nh = (self.n_hid if l != self.n_layers - 1 else self.emb_sz)//self.ndir
        return self.weights.new(self.ndir, self.bs, nh).zero_()

    def reset(self):
        "Reset the hidden states to zeros sized for the current batch size."
        for r in self.rnns:
            if hasattr(r, 'reset'):
                r.reset()
        # Kept only as a template tensor for `_one_hidden`'s `.new(...)` call.
        self.weights = next(self.parameters()).data
        self.hidden = [(self._one_hidden(l), self._one_hidden(l)) for l in range(self.n_layers)]
def get_desmiles_model(vocab_sz:int, fp_emb_sz:int, emb_sz:int, n_hid:int, n_layers:int, pad_token:int=0, tie_weights:bool=True,
                       qrnn:bool=False, bias:bool=True, bidir:bool=False, output_p:float=0.4, hidden_p:float=0.2, input_p:float=0.6,
                       embed_p:float=0.1, weight_p:float=0.5, num_bits=4096)->nn.Module:
    """Create a full DESMILES model.

    Builds a `DesmilesCore`, puts a `LinearDecoder` on top of it, wraps both in
    a `Desmiles` module and resets its hidden state before returning it.

    When `tie_weights` is true the decoder shares its weight with the core's
    token embedding. `num_bits` is the fingerprint width expected by the
    core's first fingerprint layer; the remaining arguments are forwarded to
    `DesmilesCore` / `LinearDecoder` unchanged.
    """
    # Forward the caller's `num_bits` instead of a hard-coded 4096 so that
    # non-default fingerprint sizes produce a matching input layer.
    rnn_enc = DesmilesCore(vocab_sz, fp_emb_sz, emb_sz, n_hid=n_hid, n_layers=n_layers, pad_token=pad_token, qrnn=qrnn, bidir=bidir,
                           hidden_p=hidden_p, input_p=input_p, embed_p=embed_p, weight_p=weight_p, num_bits=num_bits)
    enc = rnn_enc.encoder if tie_weights else None
    model = Desmiles(rnn_enc, LinearDecoder(vocab_sz, emb_sz, output_p, tie_encoder=enc, bias=bias))
    model.reset()
    return model
def one_hidden(self, l):
    """Return a zero hidden-state tensor of shape (1, bs, width) for layer `l`.

    The width is `self.emb_sz` for the last layer and `self.nhid` for every
    other layer; `self.bs` is the current batch size.
    """
    # Allocate on the wrapped model's device. Picking "cuda" whenever it is
    # available breaks a model that was explicitly moved to the CPU, and
    # `embed_fingerprints` already requires inputs to live on this device.
    device = next(self.desmiles.parameters()).device
    nh = (self.nhid if l != self.nlayers - 1 else self.emb_sz)
    return torch.zeros(1, self.bs, nh, device=device)
class EmbeddingToSMILES(nn.Module):
    """Map an already-embedded fingerprint (plus a token sequence) to SMILES logits.

    Mirrors the layout of the full DESMILES model -- an RNN core followed by a
    linear decoder -- so pretrained weights can be loaded unchanged; only the
    core that is plugged in differs.
    """

    def __init__(self, desmiles_rnn_core, linear_decoder):
        super().__init__()
        self.desmiles_rnn_core = desmiles_rnn_core
        self.linear_decoder = linear_decoder

    def reset(self):
        "Call `reset()` on each sub-module that provides one."
        for part in (self.desmiles_rnn_core, self.linear_decoder):
            if not hasattr(part, 'reset'):
                continue
            part.reset()

    def forward(self, input_seq, input_fp, lengths):
        "Run the core (with `lengths` detached and moved to the CPU), then the decoder."
        cpu_lengths = lengths.detach().cpu()
        core_output = self.desmiles_rnn_core(input_seq, input_fp, cpu_lengths)
        return self.linear_decoder(core_output)

    def __getitem__(self,key):
        "Index 0 is the RNN core, index 1 the linear decoder."
        if key == 0:
            return self.desmiles_rnn_core
        if key == 1:
            return self.linear_decoder
        raise ValueError("Indexing only supports 0 or 1")
space of fingerprints to SMILES 257 | For simplicity of loading weights from DESMILES, I create a copy of DESMILSE and just replace 258 | the forward() method 259 | """ 260 | 261 | initrange=0.1 262 | 263 | def __init__(self, vocab_sz:int, fp_emb_sz:int, emb_sz:int, n_hid:int, n_layers:int, pad_token:int=0, bidir:bool=False, 264 | hidden_p:float=0.2, input_p:float=0.6, embed_p:float=0.1, weight_p:float=0.5, qrnn:bool=False, num_bits=4096, first_layer=True): 265 | super().__init__() 266 | self.first_layer = first_layer 267 | self.bs,self.qrnn,self.ndir = 1, qrnn,(2 if bidir else 1) 268 | self.emb_sz,self.n_hid,self.n_layers = emb_sz,n_hid,n_layers 269 | self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token) 270 | self.encoder_dp = EmbeddingDropout(self.encoder, embed_p) 271 | self.rnns = [nn.LSTM(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz)//self.ndir, 272 | 1, bidirectional=bidir, batch_first=False) for l in range(n_layers)] 273 | self.rnns = [WeightDropout(rnn, weight_p) for rnn in self.rnns] 274 | self.rnns = torch.nn.ModuleList(self.rnns) 275 | self.encoder.weight.data.uniform_(-self.initrange, self.initrange) 276 | self.input_dp = RNNDropout(input_p) 277 | self.hidden_dps = nn.ModuleList([RNNDropout(hidden_p) for l in range(n_layers)]) 278 | self.linear_fp = [nn.Sequential(*bn_drop_lin(num_bits, fp_emb_sz, p=0.0, actn=torch.nn.Tanh())), nn.Sequential(*bn_drop_lin(fp_emb_sz, n_hid, p=0.1, actn=torch.nn.Tanh()))] 279 | self.linear_fp = torch.nn.ModuleList(self.linear_fp) 280 | 281 | 282 | def forward(self, input_seq, input_fp, lengths): 283 | 284 | sl,bs = input_seq.size() 285 | if bs!=self.bs: 286 | self.bs=bs 287 | self.reset() 288 | # Apply LinearBlocks on input_fp 289 | emb_fp = input_fp 290 | if self.first_layer: 291 | for linear_fp in self.linear_fp[1:]: 292 | emb_fp = linear_fp(emb_fp) 293 | raw_output = self.input_dp(self.encoder_dp(input_seq)) 294 | raw_outputs,outputs = [],[] 295 | for l, (rnn,hid_dp) in 
enumerate(zip(self.rnns, self.hidden_dps)): 296 | raw_output = pack_padded_sequence(raw_output, lengths, batch_first=False) 297 | if l == 0: 298 | raw_output, new_h = rnn(raw_output, (emb_fp.unsqueeze(0), emb_fp.unsqueeze(0))) 299 | else: 300 | raw_output, new_h = rnn(raw_output, self.hidden[l]) 301 | raw_output, lengths = pad_packed_sequence(raw_output, batch_first=False) 302 | raw_outputs.append(raw_output) 303 | if l != self.n_layers - 1: raw_output = hid_dp(raw_output) 304 | if l == self.n_layers - 1: outputs.append(raw_output) 305 | return raw_outputs, outputs 306 | 307 | def _one_hidden(self, l:int)->Tensor: 308 | "Return one hidden state." 309 | nh = (self.n_hid if l != self.n_layers - 1 else self.emb_sz)//self.ndir 310 | return self.weights.new(self.ndir, self.bs, nh).zero_() 311 | 312 | def reset(self): 313 | "Reset the hidden states." 314 | [r.reset() for r in self.rnns if hasattr(r, 'reset')] 315 | self.weights = next(self.parameters()).data 316 | self.hidden = [(self._one_hidden(l), self._one_hidden(l)) for l in range(self.n_layers)] 317 | 318 | def get_fp_to_embedding_model(desmiles, first_layer=True, device=None): 319 | if device is None: 320 | device = "cuda" if torch.cuda.is_available() else "cpu" 321 | fp_emb_sz = desmiles[0].linear_fp[0][1].out_features 322 | emb_sz = desmiles[0].emb_sz 323 | num_bits = desmiles[0].linear_fp[0][1].in_features 324 | n_tok = desmiles[0].encoder.weight.size()[0] 325 | nhid = desmiles[0].rnns[0].module.hidden_size 326 | nlayers = len(desmiles[0].rnns) 327 | pad_token = desmiles[0].encoder.padding_idx 328 | fp_to_embedding = FPEmbedder(n_tok, fp_emb_sz, emb_sz, nhid, nlayers, pad_token=pad_token, first_layer=first_layer) 329 | fp_to_embedding_dict = fp_to_embedding.state_dict() 330 | pretrained_dict = desmiles[0].state_dict() 331 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in fp_to_embedding_dict} 332 | assert len(pretrained_dict.keys()) == len(fp_to_embedding_dict.keys()) 333 | 
fp_to_embedding_dict.update(pretrained_dict) 334 | fp_to_embedding.load_state_dict(fp_to_embedding_dict) 335 | fp_to_embedding.eval() 336 | return fp_to_embedding.to(device) 337 | 338 | def get_embedded_fp_to_smiles_model(desmiles, first_layer=True, device=None): 339 | if device is None: 340 | device = "cuda" if torch.cuda.is_available() else "cpu" 341 | fp_emb_sz = desmiles[0].linear_fp[0][1].out_features 342 | emb_sz = desmiles[0].emb_sz 343 | num_bits = desmiles[0].linear_fp[0][1].in_features 344 | n_tok = desmiles[0].encoder.weight.size()[0] 345 | nhid = desmiles[0].rnns[0].module.hidden_size 346 | nlayers = len(desmiles[0].rnns) 347 | pad_token = desmiles[0].encoder.padding_idx 348 | embedded_fp_to_encoded_core = FingerprintEmbedderCore(n_tok, fp_emb_sz, emb_sz, nhid, nlayers, pad_token=pad_token, first_layer=first_layer) 349 | trained_dict = desmiles[0].state_dict() 350 | embedded_fp_to_encoded_core_dict = embedded_fp_to_encoded_core.state_dict() 351 | pretrained_dict = desmiles[0].state_dict() 352 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in embedded_fp_to_encoded_core_dict} 353 | assert len(pretrained_dict.keys()) == len(embedded_fp_to_encoded_core_dict.keys()) 354 | embedded_fp_to_encoded_core_dict.update(pretrained_dict) 355 | embedded_fp_to_encoded_core.load_state_dict(embedded_fp_to_encoded_core_dict) 356 | linear_decoder = desmiles[1] 357 | embedded_fp_to_encoded = EmbeddingToSMILES(embedded_fp_to_encoded_core, linear_decoder) 358 | embedded_fp_to_encoded.eval() 359 | return embedded_fp_to_encoded.to(device) 360 | -------------------------------------------------------------------------------- /Notebooks/overview_of_DESMILES.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "autoscroll": false, 8 | "ein.hycell": false, 9 | "ein.tags": "worksheet-0", 10 | "slideshow": { 11 | "slide_type": 
"-" 12 | } 13 | }, 14 | "outputs": [], 15 | "source": [ 16 | "%reload_ext autoreload\n", 17 | "%autoreload 2\n", 18 | "use_gpu = True\n", 19 | "if not use_gpu:\n", 20 | " import os\n", 21 | " os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"-1\"\n", 22 | "import seaborn as sns" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": { 29 | "autoscroll": false, 30 | "ein.hycell": false, 31 | "ein.tags": "worksheet-0", 32 | "slideshow": { 33 | "slide_type": "-" 34 | } 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "import sys" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "autoscroll": false, 46 | "ein.hycell": false, 47 | "ein.tags": "worksheet-0", 48 | "slideshow": { 49 | "slide_type": "-" 50 | } 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "import desmiles\n", 55 | "from desmiles.data import Vocab, FpSmilesList, DesmilesLoader, DataBunch\n", 56 | "from desmiles.learner import desmiles_model_learner\n", 57 | "from desmiles.models import Desmiles, RecurrentDESMILES, get_fp_to_embedding_model, get_embedded_fp_to_smiles_model\n", 58 | "from desmiles.utils import load_old_pretrained_desmiles, load_pretrained_desmiles" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": { 65 | "autoscroll": false, 66 | "ein.hycell": false, 67 | "ein.tags": "worksheet-0", 68 | "slideshow": { 69 | "slide_type": "-" 70 | } 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "from pathlib import Path\n", 75 | "import numpy as np\n", 76 | "import torch\n", 77 | "import pandas as pd\n", 78 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "device" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "from desmiles.config 
import DATA_DIR" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": { 103 | "autoscroll": false, 104 | "ein.hycell": false, 105 | "ein.tags": "worksheet-0", 106 | "slideshow": { 107 | "slide_type": "-" 108 | } 109 | }, 110 | "outputs": [], 111 | "source": [ 112 | "model_fn = Path(DATA_DIR) / 'pretrained/model_2000_400_2000_5.h5'\n", 113 | "architecture = {'fp_emb_sz': 2000, 'emb_sz': 400, 'nh': 2000, 'nl': 5, 'clip':0.3, 'alpha':2., 'beta':1.}\n", 114 | "# load fastai learner\n", 115 | "learner = load_old_pretrained_desmiles(model_fn, return_learner=True, **architecture)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": { 122 | "autoscroll": false, 123 | "ein.hycell": false, 124 | "ein.tags": "worksheet-0", 125 | "slideshow": { 126 | "slide_type": "-" 127 | } 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "from desmiles.learner import OriginalFastaiOneCycleScheduler, Learner\n", 132 | "\n", 133 | "# generate training data\n", 134 | "n=1000\n", 135 | "sigma=0.1\n", 136 | "# learn function y = x**2 + noise\n", 137 | "x = np.linspace(-1,1, n)\n", 138 | "y = x**2 + (np.random.randn(n) * sigma)\n", 139 | "x_t = torch.tensor(x, dtype=torch.float).unsqueeze(1)\n", 140 | "y_t = torch.tensor(y, dtype=torch.float).unsqueeze(1)\n", 141 | "\n", 142 | "# create databunch\n", 143 | "trn_ds = torch.utils.data.TensorDataset(x_t, y_t)\n", 144 | "val_ds = torch.utils.data.TensorDataset(x_t, y_t)\n", 145 | "trn_loader = torch.utils.data.DataLoader(trn_ds, batch_size=10, shuffle=True)\n", 146 | "val_loader = torch.utils.data.DataLoader(val_ds, batch_size=10, shuffle=False)\n", 147 | "db = DataBunch(trn_loader, val_loader)\n", 148 | "\n", 149 | "# train model \n", 150 | "model = torch.nn.Sequential(torch.nn.Linear(1,100), torch.nn.ReLU(), torch.nn.Linear(100,1), torch.nn.ReLU())\n", 151 | "learner = Learner(db, model, loss_func=torch.nn.functional.mse_loss)\n", 152 | 
"div_factor=10\n", 153 | "# Use the old fastai one cycle training policy\n", 154 | "one_cycle_linear_cb = OriginalFastaiOneCycleScheduler(learner, 0.002, div_factor=div_factor)\n", 155 | "learner.fit(5, callbacks=[one_cycle_linear_cb])\n", 156 | "learner.recorder.plot_lr()" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": { 163 | "autoscroll": false, 164 | "ein.hycell": false, 165 | "ein.tags": "worksheet-0", 166 | "slideshow": { 167 | "slide_type": "-" 168 | } 169 | }, 170 | "outputs": [], 171 | "source": [ 172 | "# The current fastai library uses the following for their one cycle training policy\n", 173 | "learner = Learner(db, model, loss_func=torch.nn.functional.mse_loss)\n", 174 | "learner.fit_one_cycle(5, 0.002)\n", 175 | "learner.recorder.plot_lr()" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": { 182 | "autoscroll": false, 183 | "ein.hycell": false, 184 | "ein.tags": "worksheet-0", 185 | "slideshow": { 186 | "slide_type": "-" 187 | } 188 | }, 189 | "outputs": [], 190 | "source": [ 191 | "import scipy.sparse\n", 192 | "import os\n", 193 | "MYDATA=os.path.join(DATA_DIR, 'notebooks')\n", 194 | " \n", 195 | "trn_smiles = np.load(os.path.join(MYDATA, 'training.enc8000.split_0.npy'))\n", 196 | "trn_fps = scipy.sparse.load_npz(os.path.join(MYDATA, 'training_fp.split_0.npz'))\n", 197 | "\n", 198 | "val_smiles = np.load(os.path.join(MYDATA, 'validation.enc8000.npy'))\n", 199 | "val_fps = scipy.sparse.load_npz(os.path.join(MYDATA,'validation_fp.npz'))\n", 200 | "\n", 201 | "itos_fn=os.path.join(DATA_DIR, 'pretrained', 'id.dec8000')\n", 202 | "itos = [s.strip() for i,s in enumerate(open(itos_fn, encoding='utf-8'))]\n", 203 | "vocab = Vocab(itos)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": { 210 | "autoscroll": false, 211 | "ein.hycell": false, 212 | "ein.tags": "worksheet-0", 213 | "slideshow": { 214 | 
"slide_type": "-" 215 | } 216 | }, 217 | "outputs": [], 218 | "source": [ 219 | "#Let's train DESMILSE on 1% of 1/4 the data\n", 220 | "\n", 221 | "num_trn_smiles = trn_smiles.shape[0] \n", 222 | "trn_inds = np.random.permutation(np.arange(num_trn_smiles))\n", 223 | "num_to_keep = int(num_trn_smiles*0.01)\n", 224 | "trn_inds = trn_inds[:num_to_keep]\n", 225 | "\n", 226 | "num_val_smiles = val_smiles.shape[0] \n", 227 | "val_inds = np.random.permutation(np.arange(num_val_smiles))\n", 228 | "num_to_keep = int(num_val_smiles*0.01)\n", 229 | "val_inds = val_inds[:num_to_keep]" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": { 236 | "autoscroll": false, 237 | "ein.hycell": false, 238 | "ein.tags": "worksheet-0", 239 | "slideshow": { 240 | "slide_type": "-" 241 | } 242 | }, 243 | "outputs": [], 244 | "source": [ 245 | "# create data bunch\n", 246 | "bs=200\n", 247 | "trn_ds = FpSmilesList(trn_smiles[trn_inds], trn_fps[trn_inds], vocab)\n", 248 | "val_ds = FpSmilesList(val_smiles[val_inds], val_fps[val_inds], vocab)\n", 249 | "trn_dl = DesmilesLoader(trn_ds, bs=bs, vocab=vocab)\n", 250 | "val_dl = DesmilesLoader(val_ds, bs=bs, vocab=vocab)\n", 251 | "db = DataBunch(trn_dl, val_dl)" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": { 258 | "autoscroll": false, 259 | "ein.hycell": false, 260 | "ein.tags": "worksheet-0", 261 | "scrolled": true, 262 | "slideshow": { 263 | "slide_type": "-" 264 | } 265 | }, 266 | "outputs": [], 267 | "source": [ 268 | "from desmiles.utils import accuracy4\n", 269 | "architecture = {'fp_emb_sz': 200, 'emb_sz': 200, 'nh': 200, 'nl': 1}\n", 270 | "regularization = {'clip':0.3, 'alpha':2., 'beta':1.}\n", 271 | "\n", 272 | "# Training parameters\n", 273 | "max_lr = 0.001\n", 274 | "div_factor = 10.\n", 275 | "\n", 276 | "# 1) Create learner object\n", 277 | "learner = desmiles_model_learner(db, **architecture, **regularization)\n", 278 | "\n", 
279 | "learner.metrics = [accuracy4]\n", 280 | "# 2) Specify training schedule\n", 281 | "one_cycle_linear_cb = OriginalFastaiOneCycleScheduler(learner, max_lr, div_factor=div_factor)\n", 282 | "# 3) Train \n", 283 | "learner.fit(50, callbacks=[one_cycle_linear_cb])\n", 284 | "# 4) Save model\n", 285 | "#learner.save('model_1')" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": { 292 | "autoscroll": false, 293 | "ein.hycell": false, 294 | "ein.tags": "worksheet-0", 295 | "slideshow": { 296 | "slide_type": "-" 297 | } 298 | }, 299 | "outputs": [], 300 | "source": [ 301 | "learner.recorder.plot_lr()" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "metadata": { 308 | "autoscroll": false, 309 | "ein.hycell": false, 310 | "ein.tags": "worksheet-0", 311 | "slideshow": { 312 | "slide_type": "-" 313 | } 314 | }, 315 | "outputs": [], 316 | "source": [ 317 | "from desmiles.utils import decoder, image_of_mols\n", 318 | "from functools import partial\n", 319 | "\n", 320 | "# helper function to map from one-hot-encoded vector to smiles string\n", 321 | "decoder = partial(decoder, itos=itos)\n", 322 | "def smiles_idx_to_string(smiles_idx, decoder=decoder):\n", 323 | " return decoder(smiles_idx[smiles_idx > 0].tolist())" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "metadata": { 330 | "autoscroll": false, 331 | "ein.hycell": false, 332 | "ein.tags": "worksheet-0", 333 | "slideshow": { 334 | "slide_type": "-" 335 | } 336 | }, 337 | "outputs": [], 338 | "source": [ 339 | "from desmiles.decoding.astar import AstarTreeParallelHybrid as AstarTree\n", 340 | "\n", 341 | "#learner.load('model_1')\n", 342 | "\n", 343 | "# Lets see if we at least fit our training set a bit\n", 344 | "(smiles_idx, fps, lengths), y = next(iter(trn_dl))\n", 345 | "test_smiles_idx = smiles_idx[:,-1]\n", 346 | "test_fp = fps[-1]\n", 347 | "test_smiles = 
smiles_idx_to_string(test_smiles_idx)\n", 348 | "image_of_mols([test_smiles])" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": null, 354 | "metadata": { 355 | "autoscroll": false, 356 | "ein.hycell": false, 357 | "ein.tags": "worksheet-0", 358 | "scrolled": true, 359 | "slideshow": { 360 | "slide_type": "-" 361 | } 362 | }, 363 | "outputs": [], 364 | "source": [ 365 | "model = learner.model\n", 366 | "model.eval()\n", 367 | "rdesmiles = RecurrentDESMILES(model)\n", 368 | "\n", 369 | "astar = AstarTree(test_fp.unsqueeze(0).to('cuda'), rdesmiles, num_expand=100)\n", 370 | "neg_log_prob, smiles_idx = next(astar)\n", 371 | "smiles = smiles_idx_to_string(smiles_idx)\n", 372 | "image_of_mols([smiles, test_smiles])" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "metadata": { 379 | "scrolled": true 380 | }, 381 | "outputs": [], 382 | "source": [ 383 | "from desmiles.decoding.astar import AstarTreeParallelHybrid as AstarTree\n", 384 | "\n", 385 | "#learner.load('model_1')\n", 386 | "\n", 387 | "# Lets see if we at least fit our training set a bit\n", 388 | "(smiles_idx, fps, lengths), y = next(iter(val_dl))\n", 389 | "test_smiles_idx = smiles_idx[:,-1]\n", 390 | "test_fp = fps[-1]\n", 391 | "test_smiles = smiles_idx_to_string(test_smiles_idx)\n", 392 | "image_of_mols([test_smiles])" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "metadata": { 399 | "scrolled": true 400 | }, 401 | "outputs": [], 402 | "source": [ 403 | "model = learner.model\n", 404 | "model.eval()\n", 405 | "rdesmiles = RecurrentDESMILES(model)\n", 406 | "\n", 407 | "astar = AstarTree(test_fp.unsqueeze(0).to('cuda'), rdesmiles, num_expand=100)\n", 408 | "neg_log_prob, smiles_idx = next(astar)\n", 409 | "smiles = smiles_idx_to_string(smiles_idx)\n", 410 | "image_of_mols([smiles, test_smiles])" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": null, 416 | "metadata": { 
417 | "autoscroll": false, 418 | "ein.hycell": false, 419 | "ein.tags": "worksheet-0", 420 | "slideshow": { 421 | "slide_type": "-" 422 | } 423 | }, 424 | "outputs": [], 425 | "source": [ 426 | "model_fn = Path(DATA_DIR) / 'pretrained/model_2000_400_2000_5.h5'\n", 427 | "architecture = {'fp_emb_sz': 2000, 'emb_sz': 400, 'nh': 2000, 'nl': 5, 'clip':0.3, 'alpha':2., 'beta':1.}\n", 428 | "learner = load_old_pretrained_desmiles(model_fn, return_learner=True, **architecture)\n", 429 | "model = learner.model\n", 430 | "model.eval()\n", 431 | "# make a RecurrentDESMILES model\n", 432 | "model = RecurrentDESMILES(model)" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": null, 438 | "metadata": { 439 | "autoscroll": false, 440 | "ein.hycell": false, 441 | "ein.tags": "worksheet-0", 442 | "slideshow": { 443 | "slide_type": "-" 444 | } 445 | }, 446 | "outputs": [], 447 | "source": [ 448 | "from desmiles.utils import smiles_to_fingerprint\n", 449 | "validation_smiles = [s.strip() for s in open(os.path.join(DATA_DIR, 'pretrained', 'validation_smiles_10k.smi'))]\n", 450 | "inds = np.random.permutation(np.arange(len(validation_smiles)))\n", 451 | "i = 0\n", 452 | "smiles_to_invert = validation_smiles[inds[i]]\n", 453 | "fp = smiles_to_fingerprint(smiles_to_invert, as_tensor=True)\n", 454 | "image_of_mols([smiles_to_invert])" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "metadata": { 461 | "autoscroll": false, 462 | "ein.hycell": false, 463 | "ein.tags": "worksheet-0", 464 | "slideshow": { 465 | "slide_type": "-" 466 | } 467 | }, 468 | "outputs": [], 469 | "source": [ 470 | "astar = AstarTree(fp.unsqueeze(0), model, num_expand=100)\n", 471 | "nlp, smiles_idx = next(astar)\n", 472 | "smiles = smiles_idx_to_string(smiles_idx)\n", 473 | "image_of_mols([smiles, smiles_to_invert])" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": null, 479 | "metadata": { 480 | "autoscroll": false, 481 | 
"ein.hycell": false, 482 | "ein.tags": "worksheet-0", 483 | "slideshow": { 484 | "slide_type": "-" 485 | } 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "# model size\n", 490 | "np.sum([np.prod(p.shape) for p in model.parameters()])" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "metadata": { 497 | "autoscroll": false, 498 | "ein.hycell": false, 499 | "ein.tags": "worksheet-0", 500 | "slideshow": { 501 | "slide_type": "-" 502 | } 503 | }, 504 | "outputs": [], 505 | "source": [ 506 | "%%time\n", 507 | "# Lets use the fast variant of A* to get 100 top solution\n", 508 | "astar = AstarTree(fp.unsqueeze(0), model, num_expand=1000, max_branches=5000)\n", 509 | "from collections import defaultdict\n", 510 | "scores = defaultdict(float)\n", 511 | "all_leaf_nodes = []\n", 512 | "for _ in range(1000):\n", 513 | " nlp, smiles_idx = next(astar)\n", 514 | " smiles = smiles_idx_to_string(smiles_idx)\n", 515 | " print(smiles, np.exp(-nlp))\n", 516 | " scores[smiles] += np.exp(-nlp)\n", 517 | " all_leaf_nodes.append(smiles)" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": null, 523 | "metadata": { 524 | "autoscroll": false, 525 | "ein.hycell": false, 526 | "ein.tags": "worksheet-0", 527 | "slideshow": { 528 | "slide_type": "-" 529 | } 530 | }, 531 | "outputs": [], 532 | "source": [ 533 | "sorted(scores.items(), key=lambda x: -x[1])" 534 | ] 535 | }, 536 | { 537 | "cell_type": "markdown", 538 | "metadata": {}, 539 | "source": [ 540 | "Don't forget to regenerate the DRD2 dataset if you haven't already done it.\n", 541 | "To do so, please run DESMILES/tests/download_drd2_dataset.sh /DESMILES/data/notebooks" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": null, 547 | "metadata": { 548 | "autoscroll": false, 549 | "ein.hycell": false, 550 | "ein.tags": "worksheet-0", 551 | "slideshow": { 552 | "slide_type": "-" 553 | } 554 | }, 555 | "outputs": [], 556 | "source": [ 557 | "from 
drd2 import *" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": null, 563 | "metadata": { 564 | "autoscroll": false, 565 | "ein.hycell": false, 566 | "ein.tags": "worksheet-0", 567 | "slideshow": { 568 | "slide_type": "-" 569 | } 570 | }, 571 | "outputs": [], 572 | "source": [ 573 | "bs=200\n", 574 | "original_smile, train_fp, train_enc = load_training_data(raise_prob=True)\n", 575 | "db = create_databunch(train_fp, train_enc, itos_fn, bs)" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": null, 581 | "metadata": { 582 | "autoscroll": false, 583 | "ein.hycell": false, 584 | "ein.tags": "worksheet-0", 585 | "slideshow": { 586 | "slide_type": "-" 587 | } 588 | }, 589 | "outputs": [], 590 | "source": [ 591 | "model_fn = Path(os.path.join(DATA_DIR, 'pretrained', 'model_2000_400_2000_5.h5'))\n", 592 | "learner = load_old_pretrained_desmiles(model_fn, return_learner=True)\n", 593 | "learner.metrics = [accuracy4]\n", 594 | "learner.data = db\n", 595 | "\n", 596 | "num_epochs = 5\n", 597 | "max_lr = 0.001\n", 598 | "div_factor = 7\n", 599 | "\n", 600 | "one_cycle_linear_cb = OriginalFastaiOneCycleScheduler(learner, max_lr, div_factor=div_factor)\n", 601 | "learner.fit(num_epochs, callbacks=[one_cycle_linear_cb])" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": null, 607 | "metadata": { 608 | "autoscroll": false, 609 | "ein.hycell": false, 610 | "ein.tags": "worksheet-0", 611 | "slideshow": { 612 | "slide_type": "slide" 613 | } 614 | }, 615 | "outputs": [], 616 | "source": [ 617 | "(val_smiles_idx, val_fps, _), _ = next(iter(db.valid_dl))" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": null, 623 | "metadata": { 624 | "autoscroll": false, 625 | "ein.hycell": false, 626 | "ein.tags": "worksheet-0", 627 | "slideshow": { 628 | "slide_type": "-" 629 | } 630 | }, 631 | "outputs": [], 632 | "source": [ 633 | "model = learner.model\n", 634 | "model.eval()\n", 635 | 
"model = RecurrentDESMILES(model)\n", 636 | "astar = AstarTree(val_fps[0].unsqueeze(0), model, num_expand=100)\n", 637 | "all_leaf_nodes = []\n", 638 | "for _ in range(100):\n", 639 | " nlp, smiles_idx = next(astar)\n", 640 | " smiles = smiles_idx_to_string(smiles_idx)\n", 641 | " print(smiles)\n", 642 | " all_leaf_nodes.append(smiles)" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": null, 648 | "metadata": { 649 | "autoscroll": false, 650 | "ein.hycell": false, 651 | "ein.tags": "worksheet-0", 652 | "slideshow": { 653 | "slide_type": "-" 654 | } 655 | }, 656 | "outputs": [], 657 | "source": [] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": null, 662 | "metadata": {}, 663 | "outputs": [], 664 | "source": [] 665 | }, 666 | { 667 | "cell_type": "code", 668 | "execution_count": null, 669 | "metadata": {}, 670 | "outputs": [], 671 | "source": [] 672 | } 673 | ], 674 | "metadata": { 675 | "kernelspec": { 676 | "display_name": "Python 3 (ipykernel)", 677 | "language": "python", 678 | "name": "python3" 679 | }, 680 | "language_info": { 681 | "codemirror_mode": { 682 | "name": "ipython", 683 | "version": 3 684 | }, 685 | "file_extension": ".py", 686 | "mimetype": "text/x-python", 687 | "name": "python", 688 | "nbconvert_exporter": "python", 689 | "pygments_lexer": "ipython3", 690 | "version": "3.7.12" 691 | }, 692 | "name": "DESMILES.ipynb" 693 | }, 694 | "nbformat": 4, 695 | "nbformat_minor": 2 696 | } 697 | -------------------------------------------------------------------------------- /lib-python/desmiles/decoding/astar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import queue 4 | from itertools import count 5 | import numpy as np 6 | from ..models import RecurrentDESMILES 7 | 8 | class AstarTree: 9 | def __init__(self, fp, desmiles, max_length=30, max_branches=5000): 10 | assert type(fp) is torch.Tensor 11 | assert 
type(desmiles) is RecurrentDESMILES 12 | assert next(desmiles.parameters()).device == fp.device 13 | if fp.dim() == 1: 14 | fp = fp.unsqueeze(0) 15 | self.fp = fp 16 | self.desmiles = desmiles 17 | 18 | self.max_length = max_length 19 | self.max_branches = max_branches 20 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 21 | self.unique_identifier = count() # counter breaks ties between equivalent probabilities 22 | self.non_leaf_queue = queue.PriorityQueue() 23 | self.leaf_queue = queue.PriorityQueue() 24 | 25 | self.desmiles.embed_fingerprints(fp) # initialize rnn_desmiles by embedding the fingerprint 26 | root = self.initialize_root_node() 27 | 28 | self.non_leaf_queue.put((torch.tensor(0.0, device=self.device), next(self.unique_identifier), root)) 29 | self.leaf_queue.put((torch.tensor(float("inf"), device=self.device), next(self.unique_identifier), root)) 30 | self.num_expand = 1 31 | self.num_branches = 0 32 | 33 | 34 | @classmethod 35 | def create_from_astar_tree(cls, other): 36 | astar_tree = cls(other.fp, other.desmiles, max_length=other.max_length, max_branches=other.max_branches) 37 | astar_tree.non_leaf_queue = other.non_leaf_queue 38 | astar_tree.leaf_queue = other.leaf_queue 39 | return astar_tree 40 | 41 | def initialize_root_node(self): 42 | root = torch.zeros(self.max_length, dtype=torch.long, device=self.device) 43 | root[0] = 3 44 | return root 45 | 46 | def return_leaf_node(self): 47 | score, _, seq = self.leaf_queue.get() # pop the top, decode and yield. 
48 | self.last_leaf_node = seq 49 | return score, seq 50 | 51 | def __next__(self): 52 | # If no more room for expansion 53 | while True: 54 | if self.num_branches > self.max_branches: 55 | if self.leaf_queue.qsize() > 0: 56 | return self.return_leaf_node() 57 | else: 58 | return np.inf, self.last_leaf_node 59 | # Return a leaf node if it has the best score 60 | if self.leaf_queue.queue[0][0] < self.non_leaf_queue.queue[0][0]: 61 | return self.return_leaf_node() 62 | # Otherwise, purge queue and branch 63 | self.purge_queue() # purge non-leaf queue if necessary 64 | self.branch_and_bound() 65 | 66 | def branch_and_bound(self): 67 | # get nodes to expand 68 | seqs, scores = self.branch() 69 | #seqs, scores = self.post_branch_callback(seqs, scores) 70 | self.bound(seqs, scores) 71 | self.num_branches += 1 72 | 73 | 74 | def branch(self): 75 | # Create scores and sequences for the num_expand nodes we will expand 76 | # it doesn't actually do the branching. 77 | scores = [] 78 | seqs = [] 79 | num_nodes = self.non_leaf_queue.qsize() 80 | while(len(scores) < min(self.num_expand, num_nodes)): 81 | score, _, seq = self.non_leaf_queue.get() 82 | scores.append(score) 83 | seqs.append(seq) 84 | scores = torch.tensor(scores, device=self.device) 85 | seqs = torch.stack(seqs) 86 | return seqs, scores 87 | 88 | def bound(self, seqs, scores): 89 | # Perform the branch operation and then bound the results: 90 | with torch.no_grad(): 91 | # Clone hidden states which are the embedded fingerprints only. 92 | # This only needs as argument the len(seqs) [or seq.shape[0]] 93 | hiddens = self.clone_hidden_states(seqs) 94 | # Get the probabilities for all the children 95 | # The call to get_log_probabilities will overwrite the hidden states based on the sequences. 96 | log_probs = self.get_log_probabilities(seqs) 97 | # reset the hidden states to those of the embedded fingerprints. 
98 | self.desmiles.hiddens = hiddens 99 | seqs, scores = self.get_children(seqs, log_probs, scores) 100 | self.add_children_to_queue(seqs, scores) 101 | 102 | 103 | def add_children_to_queue(self, seqs, scores): 104 | # sort scores and grab the first max_branches to add to the two separate queues. 105 | sort_idx = self.sort_scores(scores, self.max_branches) 106 | scores = scores[sort_idx] 107 | seqs = seqs[sort_idx] # this is a 2D tensor with dimensions: (8000 x num_expanded_children) x 30 (padded sequences) 108 | is_leaf_node = self.are_children_leaf_nodes(seqs) 109 | for i, (score, child) in enumerate(zip(scores[is_leaf_node].tolist(), seqs[is_leaf_node])): 110 | self.leaf_queue.put((score, next(self.unique_identifier), child)) 111 | for i, (score, child) in enumerate(zip(scores[~is_leaf_node].tolist(), seqs[~is_leaf_node])): 112 | self.non_leaf_queue.put((score, next(self.unique_identifier), child)) 113 | 114 | def are_children_leaf_nodes(self, children): 115 | # the -1 is a hack for using the index in the last_chars line. Could revert back to actual legths. 116 | lengths = (children > 0).sum(dim=1) - 1 117 | last_chars = torch.tensor([child[length] for child, length in zip(children, lengths)], device=self.device) 118 | # last_chars == 0 means a pad character was chosen. for now I call this a leaf node so that it is not expanded further 119 | # The special characters are hard coded here: 120 | last_char_is_stop = (last_chars == 1) | (last_chars == 2) | (last_chars == 0) 121 | # this leaf node check only runs up to 30 elements, hardcoded (differs from earlier leaf_node check). 
122 | is_leaf_node = (((lengths + 1) > 3) & last_char_is_stop) | ((lengths + 1) == 30) 123 | return is_leaf_node 124 | 125 | @staticmethod 126 | def sort_scores(scores, max_branches): 127 | return torch.sort(scores)[1][:max_branches] 128 | 129 | def get_children(self, seqs, log_probs, parent_nlps): 130 | lengths = (seqs > 0).sum(dim=1) 131 | children = torch.arange(log_probs.size(1), device=self.device, dtype=torch.long)[None,:].expand(seqs.shape[0], log_probs.size(1)) 132 | new_seqs = [] 133 | for seq, child, length in zip(seqs, children, lengths): 134 | seq = seq.expand(child.size(0), seq.size(0)) 135 | seq=torch.cat((seq[:,:length], child[:,None], seq[:,length:-1]), dim=1) 136 | new_seqs.append(seq) 137 | new_seqs = torch.stack(new_seqs) 138 | new_seqs = new_seqs.reshape(-1, new_seqs.shape[-1]) 139 | scores = (parent_nlps[:,None] - log_probs).reshape(-1) 140 | return new_seqs, scores 141 | 142 | def clone_hidden_states(self, seqs): 143 | num_sequences = seqs.shape[0] 144 | if num_sequences != self.desmiles.bs: 145 | self.desmiles.select_hidden(torch.zeros(num_sequences, dtype=torch.long)) 146 | self.desmiles.bs = num_sequences 147 | hiddens = [(h[0].clone(),h[1].clone()) for h in self.desmiles.hiddens] 148 | return hiddens 149 | 150 | def get_log_probabilities(self, seqs): 151 | logits = self.desmiles(seqs.transpose(0,1)) 152 | lengths = (seqs > 0).sum(dim=1) - 1 153 | # get the energy corresponding to the next token (specified by lengths) 154 | logits = torch.stack([logits[i,l] for i,l in zip(np.arange(lengths.shape[0]), lengths.tolist())]) 155 | assert(logits.dim() == 2) 156 | return F.log_softmax(logits, dim=1) 157 | 158 | def purge_queue(self, downto=10000, maxsize=1000000): ## will need to update to also keep mol_queue mols. 
159 | if self.non_leaf_queue.qsize() > maxsize: 160 | print("PURGING") 161 | q2 = queue.PriorityQueue() 162 | ## Get top elements into new queue 163 | for i in range(downto): 164 | q2.put(self.non_leaf_queue.get()) 165 | self.non_leaf_queue = q2 166 | 167 | 168 | class AstarTreeParallel(AstarTree): 169 | def __init__(self, fp, desmiles, max_length=30, max_branches=5000, num_expand=1): 170 | super().__init__(fp, desmiles, max_length=30, max_branches=5000) 171 | self.num_expand = num_expand 172 | 173 | def branch(self): 174 | seqs, scores = super().branch() 175 | return self.sort_by_length(seqs, scores) 176 | 177 | @staticmethod 178 | def sort_by_length(seqs, scores): 179 | lengths = (seqs > 0).sum(dim=1) 180 | length_idx = AstarTreeParallel.sort_lengths(lengths) 181 | seqs = seqs[length_idx] 182 | scores = scores[length_idx] 183 | return seqs, scores 184 | 185 | @staticmethod 186 | def sort_lengths(lengths): 187 | return torch.sort(-lengths)[1] 188 | 189 | 190 | 191 | class AstarTreeParallelNotSafe: 192 | def __init__(self, fp, desmiles, max_length=30, max_branches=5000, num_expand=1): 193 | assert type(fp) is torch.Tensor 194 | assert type(desmiles) is RecurrentDESMILES 195 | assert next(desmiles.parameters()).device == fp.device 196 | if fp.dim() == 1: 197 | fp = fp.unsqueeze(0) 198 | self.fp = fp 199 | self.desmiles = desmiles 200 | 201 | self.max_length = max_length 202 | self.max_branches = max_branches 203 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 204 | self.unique_identifier = count() # counter breaks ties between equivalent probabilities 205 | self.non_leaf_queue = queue.PriorityQueue() 206 | self.leaf_queue = queue.PriorityQueue() 207 | 208 | self.desmiles.embed_fingerprints(fp) # initialize rnn_desmiles by embedding the fingerprint 209 | root = self.initialize_root_node() 210 | 211 | self.non_leaf_queue.put((torch.tensor(0.0, device=self.device), next(self.unique_identifier), root)) 212 | self.leaf_queue.put((torch.tensor(float("inf"), 
device=self.device), next(self.unique_identifier), root)) 213 | self.num_expand = 1 214 | 215 | self.num_expand = num_expand 216 | self.node_to_hiddens = {} 217 | score, ident, root = self.non_leaf_queue.get() 218 | self.node_to_hiddens[root] = (self.clone_hiddens(), 0) 219 | self.non_leaf_queue.put((score, ident, root)) 220 | self.num_expand = num_expand 221 | self.num_branches = 0 222 | 223 | def initialize_root_node(self): 224 | # The sequences are kept in reverse order from right to left starting from the last element; 225 | # This is only happening for the AstarTreeParallelNotSafe version. 226 | # In this way we don't need to look at the length vectors since some sequences will have differing lengths. 227 | # In the memory safe way, we can keep the sequences right padded. 228 | root = torch.zeros(self.max_length, dtype=torch.long, device=self.device) 229 | root[-1] = 3 230 | return root 231 | 232 | def clone_hiddens(self): 233 | return [(h[0].clone(),h[1].clone()) for h in self.desmiles.hiddens] 234 | 235 | def __next__(self): 236 | # If no more room for expansion 237 | while True: 238 | # In this version of Astar we don't keep expanding if we've reached the max branches. 239 | # This differs from the logic of the early astar algorithm used in the paper. 
240 | if self.num_branches > self.max_branches: 241 | if self.leaf_queue.qsize() > 0: 242 | return self.return_leaf_node() 243 | else: 244 | return np.inf, self.last_leaf_node 245 | # Return a leaf node if it has the best score 246 | if self.leaf_queue.queue[0][0] < self.non_leaf_queue.queue[0][0]: 247 | return self.return_leaf_node() 248 | # Otherwise, purge queue and branch 249 | self.purge_queue() # purge non-leaf queue if necessary 250 | self.branch_and_bound() 251 | 252 | def branch_and_bound(self): 253 | # get nodes to expand 254 | seqs, scores = self.branch() 255 | #seqs, scores = self.post_branch_callback(seqs, scores) 256 | self.bound(seqs, scores) 257 | self.num_branches += 1 258 | 259 | 260 | def branch(self): 261 | # Prepare for the branching operation 262 | # This branch is more complex than the plain Astar because it handles hidden states. 263 | # Collect the scores, seqs, hiddens, etc for up to num_expand nodes 264 | scores = [] 265 | seqs = [] 266 | hiddens = [] 267 | # When you branch a node every child has its own hidden state branching from the same parent hidden state 268 | # this index maps back to the hidden state of the parent. 269 | hidden_idxs = [] 270 | num_nodes = self.non_leaf_queue.qsize() 271 | # This is a dictionary from the node tensor to a hidden state and a hidden_idx. 272 | node_to_hiddens = self.node_to_hiddens 273 | while(len(scores) < min(self.num_expand, num_nodes)): 274 | score, _, seq = self.non_leaf_queue.get() 275 | scores.append(score) 276 | seqs.append(seq) 277 | hidden, idx = node_to_hiddens[seq] 278 | hiddens.append(hidden) 279 | hidden_idxs.append(idx) 280 | scores = torch.tensor(scores, device=self.device) 281 | seqs = torch.stack(seqs) 282 | # select all the parent hidden states 283 | hiddens = AstarTreeParallelNotSafe.select_all_hiddens(hiddens, hidden_idxs) 284 | # concatenate all the states so we can batch evaluate them. 
285 | hiddens = AstarTreeParallelNotSafe.concat_hiddens(hiddens) 286 | # set the hidden states 287 | self.desmiles.hiddens = hiddens 288 | # set the batch size (seqs legnth is same as length of hiddens[0][0]; first layer; cell state of hidden) 289 | self.desmiles.bs = seqs.shape[0] 290 | return seqs, scores 291 | 292 | def bound(self, seqs, scores): 293 | with torch.no_grad(): 294 | # hidden states are ready, so this only passes the last elements to the desmiles model 295 | log_probs = self.get_log_probabilities(seqs[:,-1].unsqueeze(0)) 296 | # Get an index for what parent that child was from. 297 | # Important for getting the hidden states back. 298 | seq_idx = self.get_seq_idx(log_probs) 299 | # Make the sequences and the scores of the children 300 | seqs, scores = self.get_children(seqs, log_probs, scores) 301 | self.add_children_to_queue(seqs, scores, seq_idx) 302 | 303 | def get_log_probabilities(self, seqs): 304 | logits = self.desmiles(seqs) 305 | return F.log_softmax(logits[:,-1], dim=-1) 306 | 307 | 308 | def add_children_to_queue(self, seqs, scores, seq_idxs): 309 | node_to_hiddens = self.node_to_hiddens 310 | hiddens = self.clone_hiddens() 311 | sort_idx = AstarTreeParallelNotSafe.sort_scores(scores, self.max_branches) 312 | scores = scores[sort_idx] 313 | seqs = seqs[sort_idx] 314 | seq_idxs = seq_idxs[sort_idx] 315 | is_leaf_node = self.are_children_leaf_nodes(seqs) 316 | for i, (score, child, seq_idx) in enumerate(zip(scores[is_leaf_node].tolist(), seqs[is_leaf_node], seq_idxs[is_leaf_node].tolist())): 317 | node_to_hiddens[child] = (hiddens, seq_idx) 318 | self.leaf_queue.put((score, next(self.unique_identifier), child)) 319 | for i, (score, child, seq_idx) in enumerate(zip(scores[~is_leaf_node].tolist(), seqs[~is_leaf_node], seq_idxs[~is_leaf_node].tolist())): 320 | node_to_hiddens[child] = (hiddens, seq_idx) 321 | self.non_leaf_queue.put((score, next(self.unique_identifier), child)) 322 | 323 | @staticmethod 324 | def sort_scores(scores, 
max_branches): 325 | return torch.sort(scores)[1][:max_branches] 326 | 327 | 328 | def get_children(self, seqs, log_probs, parent_nlps): 329 | children = torch.arange(log_probs.size(1), device=self.device, dtype=torch.long)[None,:].expand(seqs.shape[0], log_probs.size(1))[:,:,None] 330 | children = torch.cat([seqs[:,None,:].expand(seqs.size(0), log_probs.size(1), seqs.size(1)), children],dim=2) 331 | children = children.reshape(-1, children.shape[-1])[:,1:] 332 | scores = (parent_nlps[:,None] - log_probs).reshape(-1) 333 | return children, scores 334 | 335 | def are_children_leaf_nodes(self, children): 336 | lengths = (children > 0).sum(dim=1) - 1 337 | last_chars = children[:,-1] 338 | last_char_is_stop = (last_chars == 1) | (last_chars == 2) 339 | is_leaf_node = (((lengths + 1) > 3) & last_char_is_stop) | ((lengths + 1) == 30) | (last_chars == 0) 340 | return is_leaf_node 341 | 342 | @staticmethod 343 | def select_all_hiddens(hiddens, idxs): 344 | return [AstarTreeParallelNotSafe.select_hiddens(hidden, idx) for hidden, idx in zip(hiddens, idxs)] 345 | 346 | @staticmethod 347 | def select_hiddens(hiddens, idx, move_to_device=False): 348 | return [(h[0][:,idx:idx+1,:],h[1][:,idx:idx+1,:]) for h in hiddens] 349 | 350 | def get_seq_idx(self, log_probs): 351 | return torch.arange(log_probs.size(0), dtype=torch.long, device=self.device).unsqueeze(1).expand(log_probs.shape[0], log_probs.shape[1]).reshape(-1) 352 | 353 | 354 | @staticmethod 355 | def concat_hiddens(hiddens): 356 | layers = [] 357 | for layer in range(len(hiddens[0])): 358 | hidden_states = torch.cat([h[layer][0] for h in hiddens], dim=1) 359 | cell_states = torch.cat([h[layer][1] for h in hiddens], dim=1) 360 | layers.append((hidden_states, cell_states)) 361 | return layers 362 | 363 | def purge_queue(self, downto=10000, maxsize=1000000): 364 | if self.non_leaf_queue.qsize() > maxsize: 365 | q2 = queue.PriorityQueue() 366 | ## Get top elements into new queue 367 | for i in range(downto): 368 | 
q2.put(self.non_leaf_queue.get()) 369 | self.non_leaf_queue = q2 370 | 371 | def return_leaf_node(self): 372 | score, _, seq = self.leaf_queue.get() # pop the top, decode and yield. 373 | self.last_leaf_node = seq 374 | return score, seq 375 | 376 | def left_pad_to_right_pad(seqs): 377 | lengths = (seqs > 0).sum(dim=1) 378 | ns, sl = seqs.shape 379 | new_seqs = torch.zeros_like(seqs) 380 | for i,(s,l) in enumerate(zip(seqs, lengths)): 381 | new_seqs[i,:l] = s[s>0] 382 | return new_seqs 383 | 384 | class AstarTreeParallelHybrid(AstarTreeParallelNotSafe): 385 | @staticmethod 386 | def make_queue_with_right_padding(queue): 387 | scores = [] 388 | ids = [] 389 | tensors = [] 390 | for score, i, t in queue.queue: 391 | scores.append(score) 392 | ids.append(i) 393 | tensors.append(t) 394 | tensors = left_pad_to_right_pad(torch.stack(tensors)) 395 | new_queue = [(s,i,t) for s,i,t in zip(scores, ids, tensors)] 396 | queue.queue = new_queue 397 | return queue 398 | 399 | def __next__(self): 400 | safe_branch_thresh = 500 401 | # If no more room for expansion 402 | while True: 403 | if self.num_branches == safe_branch_thresh: 404 | print("Switching to memory safe queue!") 405 | self.purge_queue(maxsize=10000) 406 | self.non_leaf_queue = AstarTreeParallelHybrid.make_queue_with_right_padding(self.non_leaf_queue) 407 | self.leaf_queue = AstarTreeParallelHybrid.make_queue_with_right_padding(self.leaf_queue) 408 | 409 | self.safe_astar = AstarTreeParallel.create_from_astar_tree(self) 410 | self.safe_astar.num_expand = self.num_expand 411 | self.safe_astar.desmiles.embed_fingerprints(self.fp) # initialize rnn_desmiles by embedding the fingerprint 412 | del self.node_to_hiddens 413 | torch.cuda.empty_cache() 414 | if self.num_branches < safe_branch_thresh: 415 | if self.num_branches > self.max_branches: 416 | if self.leaf_queue.qsize() > 0: 417 | return self.return_leaf_node() 418 | else: 419 | return np.inf, self.last_leaf_node 420 | # Return a leaf node if it has the best score 
421 | if self.leaf_queue.queue[0][0] < self.non_leaf_queue.queue[0][0]: 422 | return self.return_leaf_node() 423 | # Otherwise, purge queue and branch 424 | self.purge_queue() # purge non-leaf queue if necessary 425 | self.branch_and_bound() 426 | else: 427 | self.num_branches += 1 428 | return next(self.safe_astar) 429 | 430 | ######################## 431 | ### OLD ASTAR CODE 432 | ######################## 433 | 434 | def get_astar_tree(fp, model, decoder, isLeafNode, max_length=30, max_branches=5000): 435 | assert type(fp) is torch.Tensor 436 | #import pdb; pdb.set_trace() 437 | partial_mol_queue = queue.PriorityQueue() 438 | # Initialize string with (3) 439 | current_string = [3] 440 | partial_mol_queue.put((0, tuple(current_string))) 441 | # Add and initialize a Complete molecule queue 442 | mol_queue = queue.PriorityQueue() 443 | mol_queue.put((np.inf, tuple(current_string))) 444 | num_branches = 1 445 | last_smiles_string = '' 446 | last_smiles_vector = None 447 | device = "cuda" if torch.cuda.is_available() else "cpu" 448 | while True: 449 | if num_branches > max_branches: # no more room for expansion 450 | if mol_queue.qsize() > 4: # if we have enough addional molecules (which we should have) 451 | score, smiles_tup = mol_queue.get() # pop the top, decode and yield. 
452 | smiles_string = decoder(list(smiles_tup)) 453 | last_smiles_string = smiles_string 454 | last_smiles_vector = np.asarray(smiles_tup) 455 | yield score, smiles_string, last_smiles_vector, num_branches 456 | else: 457 | yield np.inf, last_smiles_string, last_smiles_vector, num_branches 458 | partial_mol_queue = purgedQueue(partial_mol_queue) 459 | # grab the queue with the higher priority 460 | que_to_pop = partial_mol_queue if partial_mol_queue.queue[0][0] < mol_queue.queue[0][0] else mol_queue 461 | score, smiles_tup = que_to_pop.get() 462 | if isLeafNode(smiles_tup): 463 | smiles_string = decoder(list(smiles_tup)) 464 | last_smiles_string = smiles_string 465 | last_smiles_vector = np.asarray(smiles_tup) 466 | yield score, smiles_string, last_smiles_vector, num_branches 467 | else: 468 | num_branches += 1 469 | children = getChildren(fp, list(smiles_tup), model, score, device=device) 470 | for i, (score, child) in enumerate(children): 471 | if (i==1) or (i==2): 472 | # change if (len(child) < 3) or (child[0] == child[-1]): 473 | if (len(child) < 4) or (child[1] == child[-1]): 474 | partial_mol_queue.put((score, tuple(child))) 475 | continue 476 | mol_queue.put((score, tuple(child))) # the begin and end subSMILES 477 | else: 478 | partial_mol_queue.put((score, tuple(child))) 479 | 480 | 481 | def getChildren(fp, smiles_list, model, parent_score, max_length=30, device="cpu"): 482 | import torch.nn.functional as F 483 | lengths = [len(smiles_list)] 484 | inp_smiles = np.zeros((max_length, 1), dtype=np.int32) 485 | inp_smiles[:len(smiles_list), 0] = np.asarray(smiles_list) 486 | fp = fp[None] 487 | with torch.no_grad(): 488 | energies, _, _ = model(torch.tensor(inp_smiles, dtype=torch.long, device=device), fp, torch.tensor(lengths, dtype=torch.long, device=device)) 489 | log_probs = F.log_softmax(energies[-1],dim=-1).data 490 | children = [] 491 | for i, log_prob in enumerate(log_probs): 492 | child = smiles_list.copy() 493 | child.append(i) 494 | 
children.append((parent_score - log_prob, child)) 495 | return children 496 | 497 | 498 | def purgedQueue(q, downto=10000, maxsize=1000000): ## will need to update to also keep mol_queue mols. 499 | if q.qsize() > maxsize: 500 | q2 = queue.PriorityQueue() 501 | ## Get top elements into new queue 502 | for i in range(downto): 503 | q2.put(q.get()) 504 | return q2 505 | return q 506 | -------------------------------------------------------------------------------- /Notebooks/intro_demo_of_DESMILES.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "*start this notebook server on a machine with a GPU; optionally, use a modern rdkit version* " 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from IPython.display import Image\n", 17 | "import os\n", 18 | "from desmiles.config import DATA_DIR\n", 19 | "\n", 20 | "fig_dir = os.path.join(DATA_DIR, 'notebooks', 'Figures')" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "The high level idea behind DESMILES is that if we can learn to generate small molecules from a reduced small molecule representation that has been very successful in modelling structure activity relationships, then we will be able to generate useful molecules for a variety of practical tasks in drug discovery. Furthermore, if this representation learning is encoding the chemical similarity, then we will be able to easily generate chemically similar molecules starting from any molecule. The following two images from the DESMILES publication show the outline of the model and the representation of a slice of chemical space. The notebook below demonstrates the basic functionality of the model." 
28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "img = Image(filename=f\"{fig_dir}/deep learn chem space (desmiles)__extended data fig 5__3__mcgillen__2019__.png\")\n", 37 | "display(img)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": { 44 | "scrolled": false 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "img = Image(filename=f\"{fig_dir}/deep learn chem space (desmiles)__extended data fig 1__1__maragakis__2019__.png\")\n", 49 | "display(img)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "# Demo of DESMILES" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "The next couple of cells define some high level code for generating and displaying new molecules." 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "import sys\n", 73 | "import os\n", 74 | "import argparse\n", 75 | "import tempfile\n", 76 | "import multiprocessing\n", 77 | "import functools\n", 78 | "import subprocess\n", 79 | "from pathlib import Path\n", 80 | "\n", 81 | "\n", 82 | "import numpy as np\n", 83 | "import pandas as pd\n", 84 | "import scipy\n", 85 | "\n", 86 | "\n", 87 | "from rdkit import Chem\n", 88 | "from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect\n", 89 | "\n", 90 | "import desmiles\n", 91 | "from desmiles.data import Vocab, FpSmilesList, DesmilesLoader, DataBunch\n", 92 | "from desmiles.learner import desmiles_model_learner\n", 93 | "from desmiles.models import Desmiles, RecurrentDESMILES\n", 94 | "from desmiles.models import get_fp_to_embedding_model, get_embedded_fp_to_smiles_model\n", 95 | "from desmiles.utils import load_old_pretrained_desmiles, load_pretrained_desmiles\n", 96 | "from desmiles.utils import accuracy4\n", 97 | "from desmiles.utils
import smiles_idx_to_string\n", 98 | "from desmiles.learner import OriginalFastaiOneCycleScheduler, Learner\n", 99 | "from desmiles.decoding.astar import AstarTreeParallelHybrid as AstarTree\n", 100 | "from desmiles.config import DATA_DIR\n", 101 | "\n", 102 | "import torch" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "from functools import partial\n", 112 | "from IPython.display import SVG, display, HTML\n", 113 | "import rdkit.Chem\n", 114 | "from rdkit.Chem import Draw,rdMolDescriptors,AllChem, rdDepictor\n", 115 | "from rdkit.Chem.Draw import IPythonConsole, rdMolDraw2D\n", 116 | "import matplotlib.pyplot as plt\n", 117 | "\n", 118 | "\n", 119 | "table_from_to_template='
{} --> {}
'\n", 120 | "table_4_template='
{} {} {} {}
'\n", 121 | "\n", 122 | "\n", 123 | "def canon_smiles(x):\n", 124 | " s = ''\n", 125 | " try: \n", 126 | " s = Chem.CanonSmiles(x, useChiral=True)\n", 127 | " except:\n", 128 | " pass\n", 129 | " return s\n", 130 | "\n", 131 | "\n", 132 | "def get_itos_8k():\n", 133 | " return np.load(os.path.join(DATA_DIR, 'pretrained', \"itos.npy\"))\n", 134 | "\n", 135 | "\n", 136 | "def vec_to_smiles(idx_vec_inp, itos):\n", 137 | " \"\"\"Return a SMILES string from an index vector (deals with reversal)\"\"\"\n", 138 | " ##HACK TO WORK WITH NEWER VERSION 2020-06-08\n", 139 | " if idx_vec_inp[0] == 3:\n", 140 | " idx_vec = idx_vec_inp[1:]\n", 141 | " else:\n", 142 | " idx_vec = idx_vec_inp\n", 143 | " ##\n", 144 | " if idx_vec[0] == 1: # SMILES string is in fwd direction\n", 145 | " return ''.join(itos[x] for x in idx_vec if x > 3)\n", 146 | " if idx_vec[0] == 2: # SMILES string is in bwd direction\n", 147 | " #despot.Print(\"decoder: bwd direction\")\n", 148 | " return ''.join(itos[x] for x in idx_vec[::-1] if x > 3)\n", 149 | " else: # don't know how to deal with it---do your best\n", 150 | " print(\"decoder received an invalid start to the SMILES\", idx_vec)\n", 151 | " return ''.join(itos[x] for x in idx_vec if x > 3)\n", 152 | "\n", 153 | " \n", 154 | "def smiles_to_fingerprint(smiles_str, sparse=False, as_tensor=True):\n", 155 | " \"Return the desmiles fp\"\n", 156 | " rdmol = Chem.MolFromSmiles(smiles_str)\n", 157 | " fp = np.concatenate([\n", 158 | " np.asarray(GetMorganFingerprintAsBitVect(rdmol, 2, useChirality=True), dtype=np.uint8),\n", 159 | " np.asarray(GetMorganFingerprintAsBitVect(rdmol, 3, useChirality=True), dtype=np.uint8)])\n", 160 | " if sparse:\n", 161 | " return scipy.sparse.csr_matrix(fp)\n", 162 | " if as_tensor:\n", 163 | " import torch\n", 164 | " device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 165 | " return torch.tensor(fp.astype(np.float32)).to(device)\n", 166 | " return fp\n", 167 | "\n", 168 | "\n", 169 | "def barcode_fp(fp, width=8, 
height=0.5):\n", 170 | " fig = plt.figure()\n", 171 | " ax2 = fig.add_axes([0, 0, width, height], xticks=[], yticks=[])\n", 172 | " barprops = dict(aspect='auto', cmap=plt.cm.binary, interpolation='nearest')\n", 173 | " return ax2.imshow(fp.reshape((1, -1)), **barprops)\n", 174 | " \n", 175 | " \n", 176 | "def moltosvg(rdkit_mol, size_x=450, size_y=150):\n", 177 | " try:\n", 178 | " rdkit_mol.GetAtomWithIdx(0).GetExplicitValence()\n", 179 | " except RuntimeError:\n", 180 | " rdkit_mol.UpdatePropertyCache(False)\n", 181 | " try:\n", 182 | " mc_mol = rdMolDraw2D.PrepareMolForDrawing(rdkit_mol, kekulize=True)\n", 183 | " except ValueError: # <- can happen on a kekulization failure \n", 184 | " mc_mol = rdMolDraw2D.PrepareMolForDrawing(rdkit_mol, kekulize=False)\n", 185 | " drawer = rdMolDraw2D.MolDraw2DSVG(size_x, size_y)\n", 186 | " drawer.DrawMolecule(mc_mol)\n", 187 | " drawer.FinishDrawing()\n", 188 | " svg = drawer.GetDrawingText()\n", 189 | " # It seems that the svg renderer used doesn't quite hit the spec.\n", 190 | " # Here are some fixes to make it work in the notebook, although I think\n", 191 | " # the underlying issue needs to be resolved at the generation step\n", 192 | " return svg.replace('svg:','')\n", 193 | "\n", 194 | "\n", 195 | "def displayTable4SMILES(smiles, size_x=225, size_y=150, width=980):\n", 196 | " assert(len(smiles)==4)\n", 197 | " svgs = map(lambda x: moltosvg(Chem.MolFromSmiles(x), size_x=size_x, size_y=size_y), smiles)\n", 198 | " display(HTML(table_4_template.format(width, *svgs)))\n", 199 | " \n", 200 | " \n", 201 | "def procSMILES(sm):\n", 202 | " m = Chem.MolFromSmiles(sm)\n", 203 | " AllChem.Compute2DCoords(m)\n", 204 | " return m\n", 205 | "\n", 206 | "\n", 207 | "def imageOfMols(smiles_list, molsPerRow=4, subImgSize=(240,200), labels=None):\n", 208 | " mols = [procSMILES(sm) for sm in smiles_list]\n", 209 | " if labels is not None:\n", 210 | " labels = [str(x) for x in labels]\n", 211 | " img = Draw.MolsToGridImage(mols, 
molsPerRow=molsPerRow, subImgSize=subImgSize, useSVG=True, legends=labels)\n", 212 | " return img\n", 213 | "\n", 214 | "\n", 215 | "def imageOfMolsLabels(smiles_labels_list, molsPerRow=5, subImgSize=(200,200)):\n", 216 | " mols = [procSMILES(sm[0]) for sm in smiles_labels_list]\n", 217 | " labels = [str(sm[1]) for sm in smiles_labels_list]\n", 218 | " img = Draw.MolsToGridImage(mols, molsPerRow=molsPerRow, subImgSize=subImgSize, useSVG=True, legends=labels)\n", 219 | " return img\n", 220 | "\n", 221 | "\n", 222 | "from itertools import zip_longest\n", 223 | "def grouper(iterable, n, fillvalue=None):\n", 224 | " args = [iter(iterable)] * n\n", 225 | " return zip_longest(*args, fillvalue=fillvalue)\n" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "#from astar_purge import get_astar_tree, isLeafNode\n", 235 | "def get_most_probable_smiles(emb_fp, emfp_to_smiles, max_branches=50, num_expand=2000):\n", 236 | " with torch.no_grad():\n", 237 | " astar_tree = AstarTree(emb_fp, \n", 238 | " emfp_to_smiles, \n", 239 | " max_branches=max_branches, num_expand=num_expand)\n", 240 | " score, smile_idx = next(astar_tree)\n", 241 | " return smiles_idx_to_string(smile_idx)\n", 242 | "\n", 243 | "\n", 244 | "\n", 245 | "def get_upto_n_most_probable_valid_smiles(n, emb_fp, emfp_to_smiles, max_branches=50, num_expand=2000):\n", 246 | " with torch.no_grad():\n", 247 | " astar_tree = AstarTree(emb_fp, #.astype(np.float32), \n", 248 | " emfp_to_smiles, \n", 249 | " max_branches=max_branches, num_expand=num_expand)\n", 250 | " for _ in range(n):\n", 251 | " smile = smiles_idx_to_string(next(astar_tree)[1])\n", 252 | " if Chem.MolFromSmiles(smile):\n", 253 | " yield smile\n", 254 | " else:\n", 255 | " yield \"\"\n", 256 | "\n", 257 | " \n", 258 | "def get_first_n_most_probable_valid_smiles(n, emb_fp, emfp_to_smiles, \n", 259 | " max_branches=50, max_search=200,\n", 260 | " 
num_expand=2000,\n", 261 | " verbose=False):\n", 262 | " results = set()\n", 263 | " with torch.no_grad():\n", 264 | " astar_tree = AstarTree(emb_fp, \n", 265 | " emfp_to_smiles, \n", 266 | " max_branches=max_branches, num_expand=num_expand)\n", 267 | " for i in range(max_search):\n", 268 | " if len(results) == n:\n", 269 | " break\n", 270 | " score, smile = next(astar_tree)\n", 271 | " smile = smiles_idx_to_string(smile)\n", 272 | " if verbose:\n", 273 | " print(score, smile)\n", 274 | " if (smile is not None) and Chem.MolFromSmiles(smile):\n", 275 | " smile = Chem.CanonSmiles(smile)\n", 276 | " results.add(smile)\n", 277 | " yield smile\n", 278 | " else:\n", 279 | " yield \"\"\n", 280 | "\n", 281 | " \n", 282 | "def set_of_upto_n_most_probable_valid_smiles(n, emb_fp, embfp_to_smiles, max_branches=50, num_expand=2000):\n", 283 | " result = list(get_upto_n_most_probable_valid_smiles(n, emb_fp, embfp_to_smiles, \n", 284 | " max_branches, num_expand))\n", 285 | " return set([Chem.CanonSmiles(x) for x in result if x != \"\"])\n", 286 | "\n", 287 | " \n", 288 | "def get_most_probable_valid_smiles(emb_fp, emfp_to_smiles, max_branches=50, num_expand=2000):\n", 289 | " counter = 0\n", 290 | " with torch.no_grad():\n", 291 | " astar_tree = AstarTree(emb_fp, emfp_to_smiles,\n", 292 | " max_branches=max_branches,\n", 293 | " num_expand=num_expand)\n", 294 | " while counter < 5:\n", 295 | " smile = smiles_idx_to_string(next(astar_tree)[1])\n", 296 | " if Chem.MolFromSmiles(smile):\n", 297 | " return smile\n", 298 | " counter += 1\n", 299 | " return smile\n", 300 | " \n", 301 | " \n", 302 | "def dedup(seq):\n", 303 | " seen = set()\n", 304 | " seen_add = seen.add\n", 305 | " return [x for x in seq if not (x in seen or seen_add(x))]\n", 306 | "\n", 307 | "\n", 308 | "def clean(seq):\n", 309 | " return [x for x in seq if x]\n", 310 | "\n", 311 | "\n", 312 | "def dedup_clean(seq):\n", 313 | " return dedup(clean(seq))\n", 314 | "\n", 315 | "\n", 316 | "def to_numpy_int(f1):\n", 
317 | " return f1.cpu().numpy().astype(np.int32) if isinstance(f1, torch.Tensor) else f1\n", 318 | "\n", 319 | "\n", 320 | "def tanimoto(f1, f2):\n", 321 | " s1 = to_numpy_int(f1)\n", 322 | " s2 = to_numpy_int(f2)\n", 323 | " return np.sum(s1 & s2) / np.sum(s1 | s2)\n", 324 | "\n", 325 | "\n", 326 | "def displayTable4labels(labels, width=980):\n", 327 | " display(HTML(table_4_template.format(width, *labels)))\n", 328 | "\n" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "metadata": {}, 341 | "source": [ 342 | "##
Load some data
\n", 343 | "\n", 344 | "First load the encoded version of the training set and the validation set" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "itos = get_itos_8k()\n", 354 | "vec_to_smiles_8k = partial(vec_to_smiles, itos=itos)\n", 355 | "\n", 356 | "def random_smiles_enc8k(enc_table, n=2):\n", 357 | " idx = np.random.randint(0, len(enc_table), n)\n", 358 | " return [(vec_to_smiles_8k(enc_table[i]), enc_table[i][enc_table[i]>0]) for i in idx]\n", 359 | "\n", 360 | "def random_smiles(enc_table, n=10):\n", 361 | " return [s for s,e in random_smiles_enc8k(enc_table, n)]" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": { 368 | "autoscroll": false, 369 | "ein.hycell": false, 370 | "ein.tags": "worksheet-0", 371 | "slideshow": { 372 | "slide_type": "-" 373 | } 374 | }, 375 | "outputs": [], 376 | "source": [ 377 | "training_smiles_enc8k = np.load(os.path.join(DATA_DIR, 'pretrained', 'training.enc8000.npy'))\n", 378 | "val2_smiles_enc8k = np.load(os.path.join(DATA_DIR, 'pretrained', 'val2.enc8000.npy'))\n", 379 | "v2samples = list(pd.read_csv(os.path.join(DATA_DIR, 'notebooks', \"fast_val2_molecules.csv\"))[\"SMILES\"])" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": null, 385 | "metadata": { 386 | "scrolled": true 387 | }, 388 | "outputs": [], 389 | "source": [ 390 | "print('\\n'.join(random_smiles(training_smiles_enc8k, 5)))" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "metadata": { 397 | "scrolled": true 398 | }, 399 | "outputs": [], 400 | "source": [ 401 | "len(training_smiles_enc8k)" 402 | ] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "metadata": {}, 407 | "source": [ 408 | "Hit Ctrl-Enter on the following cell several times to explore random samples from the training set. 
\n", 409 | "Change \"training_smiles_enc8k\" to \"val2_smiles_enc8k\" to explore the validation set." 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": null, 415 | "metadata": { 416 | "scrolled": false 417 | }, 418 | "outputs": [], 419 | "source": [ 420 | "import rdkit.Chem.Descriptors\n", 421 | "smiles = random_smiles(training_smiles_enc8k, 4) # pick training_smiles... or val2_smiles...\n", 422 | "displayTable4SMILES(smiles)\n", 423 | "for x in smiles: \n", 424 | " m = Chem.MolFromSmiles(x)\n", 425 | " print( m.GetNumAtoms(), m.GetNumAtoms(onlyExplicit=False), \n", 426 | " Chem.Descriptors.NumAromaticRings(m), \n", 427 | " np.round(Chem.Descriptors.TPSA(m), 2),\n", 428 | " np.round(Chem.Descriptors.MolWt(m), 2), \n", 429 | " np.round(Chem.Descriptors.MolLogP(m), 2) )\n", 430 | "for x in smiles: barcode_fp(smiles_to_fingerprint(x), height=0.3)" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": null, 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "metadata": {}, 444 | "outputs": [], 445 | "source": [] 446 | }, 447 | { 448 | "cell_type": "markdown", 449 | "metadata": {}, 450 | "source": [ 451 | "##
subSMILES and DESMILES
" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "metadata": {}, 457 | "source": [ 458 | "Each molecule is made out of up to 26 subSMILES (byte-pair encoded symbols).\n", 459 | "These are represented as integer numbers in the encoded tables, and 'itos' converts them to strings.\n", 460 | "The following example decomposes a random molecule from the training set." 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": null, 466 | "metadata": {}, 467 | "outputs": [], 468 | "source": [ 469 | "training_smiles_enc8k.shape" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "metadata": { 476 | "scrolled": false 477 | }, 478 | "outputs": [], 479 | "source": [ 480 | "rsamples = random_smiles_enc8k(training_smiles_enc8k, 1)\n", 481 | "print(rsamples)\n", 482 | "display(imageOfMolsLabels(rsamples, subImgSize=(600,300), molsPerRow=1))\n", 483 | "barcode_fp(smiles_to_fingerprint(rsamples[0][0]));" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "metadata": {}, 490 | "outputs": [], 491 | "source": [ 492 | "rsamples[0][0], [(x, itos[x]) for s, e in rsamples for x in e]" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": null, 498 | "metadata": {}, 499 | "outputs": [], 500 | "source": [] 501 | }, 502 | { 503 | "cell_type": "markdown", 504 | "metadata": {}, 505 | "source": [ 506 | "### Load a pretrained model" 507 | ] 508 | }, 509 | { 510 | "cell_type": "markdown", 511 | "metadata": {}, 512 | "source": [ 513 | "Let's load a model that was trained on molecules from both the validation and the training set. This would be the model to use in future applications of DESMILES. The model parameters came as the result of the hyperoptimization discussed in the publication." 
514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": null, 519 | "metadata": {}, 520 | "outputs": [], 521 | "source": [ 522 | "get_model = desmiles.utils.load_old_pretrained_desmiles\n", 523 | "fp_to_smiles_5layer = get_model(os.path.join(DATA_DIR, 'pretrained', 'train_val1_val2','model_2000_400_2000_5'))\n", 524 | "rmodel = RecurrentDESMILES(fp_to_smiles_5layer)" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": null, 530 | "metadata": {}, 531 | "outputs": [], 532 | "source": [ 533 | "rmodel" 534 | ] 535 | }, 536 | { 537 | "cell_type": "markdown", 538 | "metadata": {}, 539 | "source": [ 540 | "##
First tests: create a fragment from its fingerprint
" 541 | ] 542 | }, 543 | { 544 | "cell_type": "markdown", 545 | "metadata": {}, 546 | "source": [ 547 | "The simplest application of the model is to create a molecule from its fingerprint.\n", 548 | "Here is the example of a fragment that is outside of the original library." 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "execution_count": null, 554 | "metadata": { 555 | "autoscroll": false, 556 | "ein.hycell": false, 557 | "ein.tags": "worksheet-0", 558 | "scrolled": true, 559 | "slideshow": { 560 | "slide_type": "-" 561 | } 562 | }, 563 | "outputs": [], 564 | "source": [ 565 | "fragment = 'Nc1ncc(C(F)(F)F)cc1F' \n", 566 | "display(imageOfMols([fragment], labels=[fragment]))\n", 567 | "fp = smiles_to_fingerprint(fragment, as_tensor=True)\n", 568 | "barcode_fp(fp);" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": null, 574 | "metadata": {}, 575 | "outputs": [], 576 | "source": [ 577 | "fp, fp.size(), set(fp.cpu().numpy()), sum(fp)" 578 | ] 579 | }, 580 | { 581 | "cell_type": "markdown", 582 | "metadata": {}, 583 | "source": [ 584 | "We can invert this fingerprint to generate a small molecule. " 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": null, 590 | "metadata": { 591 | "autoscroll": false, 592 | "ein.hycell": false, 593 | "ein.tags": "worksheet-0", 594 | "scrolled": true, 595 | "slideshow": { 596 | "slide_type": "-" 597 | } 598 | }, 599 | "outputs": [], 600 | "source": [ 601 | "%%time\n", 602 | "smiles = get_most_probable_smiles(fp, rmodel)" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": null, 608 | "metadata": { 609 | "scrolled": false 610 | }, 611 | "outputs": [], 612 | "source": [ 613 | "imageOfMols([smiles])" 614 | ] 615 | }, 616 | { 617 | "cell_type": "markdown", 618 | "metadata": {}, 619 | "source": [ 620 | "##
Generate a collection of fragments with Astar
" 621 | ] 622 | }, 623 | { 624 | "cell_type": "markdown", 625 | "metadata": {}, 626 | "source": [ 627 | "Often one wants to generate a whole bunch of variations of a single molecule.\n", 628 | "\n", 629 | "The example below shows variations of this simple fragment outside the training/validation set, \n", 630 | "together with a measure of the fingerprint similarity (higher is better; 1.0 is a perfect match of fingerprints).\n", 631 | "\n", 632 | "Sometimes the model will go through a number of invalid intermediate attempts, before finding the next example.\n", 633 | "The function get_first_n_most_probable_valid_smiles will only return valid molecules.\n", 634 | "The parameter max_branches limits the search; for very complicated molecules, \n", 635 | "the search might be exhausted before the optimal molecules get returned, \n", 636 | "so higher values of max_branches might get \"better\" molecules \n", 637 | "but the search will take up more GPU memory.\n" 638 | ] 639 | }, 640 | { 641 | "cell_type": "code", 642 | "execution_count": null, 643 | "metadata": {}, 644 | "outputs": [], 645 | "source": [ 646 | "%%time\n", 647 | "%%capture --no-stdout --no-display\n", 648 | "smiles = dedup_clean([Chem.CanonSmiles(x) \n", 649 | " for x in get_first_n_most_probable_valid_smiles(8, fp, rmodel, \n", 650 | " max_branches=100)])" 651 | ] 652 | }, 653 | { 654 | "cell_type": "code", 655 | "execution_count": null, 656 | "metadata": {}, 657 | "outputs": [], 658 | "source": [ 659 | "labels = np.round([tanimoto(fp, smiles_to_fingerprint(x)) for x in smiles], 2)\n", 660 | "for x in smiles: barcode_fp(smiles_to_fingerprint(x), height=0.2)\n", 661 | "display(imageOfMols(smiles, molsPerRow=4, subImgSize=(240, 200), labels=labels));" 662 | ] 663 | }, 664 | { 665 | "cell_type": "markdown", 666 | "metadata": {}, 667 | "source": [ 668 | "### Check out some more complex and random molecules (subset of validation 2)" 669 | ] 670 | }, 671 | { 672 | "cell_type": "markdown", 673 | "metadata": {}, 674 | 
"source": [ 675 | "Below are some complicated molecules and their top 3 variants according to DESMILES. For this demonstration we picked as inputs a subset of the validation molecules that decoded rather quickly. \n", 676 | "\n", 677 | "Please click Ctrl-Enter on the next cell a couple of times until you see some interesting molecules." 678 | ] 679 | }, 680 | { 681 | "cell_type": "code", 682 | "execution_count": null, 683 | "metadata": { 684 | "scrolled": true 685 | }, 686 | "outputs": [], 687 | "source": [ 688 | "rsamples = list(np.random.choice(v2samples, 4))\n", 689 | "print(rsamples)\n", 690 | "displayTable4SMILES(rsamples)" 691 | ] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "execution_count": null, 696 | "metadata": { 697 | "scrolled": false 698 | }, 699 | "outputs": [], 700 | "source": [ 701 | "%%time\n", 702 | "%%capture --no-stdout --no-display\n", 703 | "for s in rsamples:\n", 704 | " newfp = smiles_to_fingerprint(s)\n", 705 | " newsmiles = dedup_clean([Chem.CanonSmiles(x) for x in \n", 706 | " get_first_n_most_probable_valid_smiles(4,newfp,rmodel, max_branches=200)])\n", 707 | " newsmiles.extend([\"\", \"\", \"\", \"\"])\n", 708 | " newsmiles = newsmiles[:4]\n", 709 | " displayTable4SMILES(newsmiles)\n", 710 | " labels = [np.round(tanimoto(newfp, smiles_to_fingerprint(s)), 2) for s in newsmiles]\n", 711 | " displayTable4labels(labels)" 712 | ] 713 | }, 714 | { 715 | "cell_type": "code", 716 | "execution_count": null, 717 | "metadata": {}, 718 | "outputs": [], 719 | "source": [] 720 | }, 721 | { 722 | "cell_type": "markdown", 723 | "metadata": {}, 724 | "source": [ 725 | "###
Perturbations of a molecule
" 726 | ] 727 | }, 728 | { 729 | "cell_type": "markdown", 730 | "metadata": {}, 731 | "source": [ 732 | "Let's add a little noise to the fingerprints of our little fragment by turning some random bits on." 733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "execution_count": null, 738 | "metadata": {}, 739 | "outputs": [], 740 | "source": [ 741 | "%%time\n", 742 | "smiles = []\n", 743 | "extra_fp_on = [30, 40, 50, 60]\n", 744 | "torch.random.manual_seed(314)\n", 745 | "for num_bits_on in extra_fp_on:\n", 746 | " random_indices = torch.randint(0, fp.size().numel(), torch.Size([num_bits_on]))\n", 747 | " fp_add = torch.zeros_like(fp)\n", 748 | " fp_add[random_indices] = 1\n", 749 | " fp_pert = fp + fp_add\n", 750 | " s = get_most_probable_valid_smiles(fp_pert, rmodel, max_branches=100)\n", 751 | " smiles.append(s)\n", 752 | "display(imageOfMols(smiles, labels=extra_fp_on))" 753 | ] 754 | }, 755 | { 756 | "cell_type": "code", 757 | "execution_count": null, 758 | "metadata": {}, 759 | "outputs": [], 760 | "source": [] 761 | }, 762 | { 763 | "cell_type": "markdown", 764 | "metadata": {}, 765 | "source": [ 766 | "##
Intro to algebra of molecules
\n", 767 | "We can \"add\" two molecules by mixing their fingerprints \n", 768 | "(or by mixing their embeddings, or other internal layers).\n", 769 | "Even though the model is highly nonlinear, the \"addition\" is often intuitive, \n", 770 | "for example when the model is able to combine the fingerprints \n", 771 | "and create a molecule that matches both inputs." 772 | ] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "execution_count": null, 777 | "metadata": {}, 778 | "outputs": [], 779 | "source": [ 780 | "fragment2 = \"c1cccnc1N2CCCCC2\"\n", 781 | "\n", 782 | "imageOfMols([fragment, fragment2])" 783 | ] 784 | }, 785 | { 786 | "cell_type": "code", 787 | "execution_count": null, 788 | "metadata": { 789 | "scrolled": false 790 | }, 791 | "outputs": [], 792 | "source": [ 793 | "%%time\n", 794 | "%%capture --no-stdout --no-display\n", 795 | "fps = [smiles_to_fingerprint(x, as_tensor=False) for x in [fragment, fragment2]]\n", 796 | "ftarget = (fps[0] | fps[1])\n", 797 | "ftarget = torch.Tensor(ftarget).cuda()\n", 798 | "smiles = dedup_clean(get_first_n_most_probable_valid_smiles(6, ftarget, \n", 799 | " rmodel))\n", 800 | "labels = [tanimoto(smiles_to_fingerprint(x), ftarget) for x in smiles]\n", 801 | "labels = [\"A\", \"B\"] + [str(x) for x in np.round(labels, 2)]\n", 802 | "display(imageOfMols([fragment, fragment2, *smiles], labels=labels))" 803 | ] 804 | }, 805 | { 806 | "cell_type": "markdown", 807 | "metadata": {}, 808 | "source": [ 809 | "Let's get rid of the trifluoromethyl group." 810 | ] 811 | }, 812 | { 813 | "cell_type": "code", 814 | "execution_count": null, 815 | "metadata": { 816 | "scrolled": true 817 | }, 818 | "outputs": [], 819 | "source": [ 820 | "fragment3 = \"c1ccccc1C(F)(F)F\"\n", 821 | "imageOfMols([fragment, fragment2, fragment3], labels=[\"A\", \"+ B\", \"- C\"])" 822 | ] 823 | }, 824 | { 825 | "cell_type": "code", 826 | "execution_count": null, 827 | "metadata": { 828 | "scrolled": false 829 | }, 830 | "outputs": [], 831 | "source": [ 
832 | "%%time\n", 833 | "%%capture --no-stdout --no-display\n", 834 | "fps = [smiles_to_fingerprint(x, as_tensor=False) for x in [fragment, fragment2, fragment3]]\n", 835 | "f3target = (fps[0] | fps[1]) - fps[2]\n", 836 | "f3target = np.clip(f3target, 0, 1)\n", 837 | "f3target = torch.Tensor(f3target).cuda()\n", 838 | "smiles = dedup_clean(get_first_n_most_probable_valid_smiles(5, f3target, rmodel))\n", 839 | "labels = [tanimoto(smiles_to_fingerprint(x), f3target) for x in smiles]\n", 840 | "labels = [\"A\", \"+ B\", \"- C\"] + [str(x) for x in np.round(labels, 2)]\n", 841 | "display(imageOfMols([fragment, fragment2, fragment3, *smiles], labels=labels))" 842 | ] 843 | }, 844 | { 845 | "cell_type": "markdown", 846 | "metadata": {}, 847 | "source": [ 848 | "##
Fine tuning applications
\n" 849 | ] 850 | }, 851 | { 852 | "cell_type": "markdown", 853 | "metadata": {}, 854 | "source": [ 855 | "A promising way to generate new molecules is to finetune the DESMILES model to improve the outputs using training inputs from matched pairs. This somewhat more complicated application of the model was described in the publication, and was used to generate the figures below." 856 | ] 857 | }, 858 | { 859 | "cell_type": "code", 860 | "execution_count": null, 861 | "metadata": { 862 | "scrolled": false 863 | }, 864 | "outputs": [], 865 | "source": [ 866 | "img = Image(filename=f\"{fig_dir}/deep learn chem space (desmiles)__extended data fig 5__2__nisonoff__2019__.png\")\n", 867 | "display(img)" 868 | ] 869 | }, 870 | { 871 | "cell_type": "code", 872 | "execution_count": null, 873 | "metadata": { 874 | "scrolled": false 875 | }, 876 | "outputs": [], 877 | "source": [ 878 | "img = Image(filename=f\"{fig_dir}/deep learn chem space (desmiles)__extended data fig 4__2__nisonoff__2019__.png\")\n", 879 | "display(img)" 880 | ] 881 | } 882 | ], 883 | "metadata": { 884 | "kernelspec": { 885 | "display_name": "Python 3 (ipykernel)", 886 | "language": "python", 887 | "name": "python3" 888 | }, 889 | "language_info": { 890 | "codemirror_mode": { 891 | "name": "ipython", 892 | "version": 3 893 | }, 894 | "file_extension": ".py", 895 | "mimetype": "text/x-python", 896 | "name": "python", 897 | "nbconvert_exporter": "python", 898 | "pygments_lexer": "ipython3", 899 | "version": "3.7.12" 900 | }, 901 | "name": "FPEmbedding_Paul.ipynb" 902 | }, 903 | "nbformat": 4, 904 | "nbformat_minor": 2 905 | } 906 | --------------------------------------------------------------------------------