├── .flake8 ├── .gitignore ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── dockers ├── Dockerfile └── requirements.txt ├── guacamol ├── __init__.py ├── assess_distribution_learning.py ├── assess_goal_directed_generation.py ├── benchmark_suites.py ├── common_scoring_functions.py ├── data │ ├── __init__.py │ ├── get_data.py │ └── holdout_set_gcm_v1.smiles ├── distribution_learning_benchmark.py ├── distribution_matching_generator.py ├── frechet_benchmark.py ├── goal_directed_benchmark.py ├── goal_directed_generator.py ├── goal_directed_score_contributions.py ├── py.typed ├── score_modifier.py ├── scoring_function.py ├── standard_benchmarks.py └── utils │ ├── __init__.py │ ├── chemistry.py │ ├── data.py │ ├── descriptors.py │ ├── fingerprints.py │ ├── helpers.py │ ├── math.py │ └── sampling_helpers.py ├── mypy.ini ├── setup.py └── tests ├── __init__.py ├── mock_generator.py ├── test_distribution_learning_benchmarks.py ├── test_goal_directed_benchmark.py ├── test_sampling_helpers.py ├── test_score_modifier.py ├── test_scoring_functions.py └── utils ├── test_chemistry.py ├── test_data.py └── test_descriptors.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501, E731 3 | exclude = .git,__pycache__ 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *pyc 2 | .idea 3 | .cache 4 | .mypy_cache 5 | .pytest_cache 6 | __pycache__ 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | python: 4 | - "3.6" 5 | 6 | before_install: 7 | # download and install miniconda (for convenience) 8 | - wget http://repo.continuum.io/miniconda/Miniconda3-4.1.11-Linux-x86_64.sh -O miniconda.sh; 9 | - bash miniconda.sh -b -p 
$HOME/conda 10 | - export PATH="$HOME/conda/bin:$PATH" 11 | - conda config --set always_yes yes --set changeps1 no 12 | - conda update -q conda 13 | 14 | - conda create -n test_env python=$TRAVIS_PYTHON_VERSION pip cmake 15 | - source activate test_env 16 | 17 | install: 18 | - pip install -r dockers/requirements.txt # install testing requirements 19 | - pip install . # install guacamol benchmark 20 | 21 | script: 22 | # Style guide enforcement 23 | - flake8 guacamol && flake8 tests 24 | # Static typing enforcement 25 | - mypy guacamol && mypy tests 26 | # Test suite 27 | - python -m pytest tests -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 BenevolentAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Include the file holdout SMILES strings when the guacamol package is generated 2 | include guacamol/data/holdout_set_gcm_v1.smiles 3 | # Marker file to say that guacamol supports type checking 4 | include guacamol/py.typed 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GuacaMol 2 | 3 | [![Build Status](https://travis-ci.com/BenevolentAI/guacamol.svg?branch=master)](https://travis-ci.com/BenevolentAI/guacamol) 4 | 5 | **GuacaMol** is an open source Python package for benchmarking of models for 6 | *de novo* molecular design. 7 | 8 | For an in-depth explanation of the types of benchmarks and baseline scores, 9 | please consult our paper 10 | [Benchmarking Models for De Novo Molecular Design](https://arxiv.org/abs/1811.09621) 11 | 12 | ## Installation 13 | 14 | The easiest way to install `guacamol` is with `pip`: 15 | ```bash 16 | pip install guacamol 17 | ``` 18 | 19 | Dependencies: 20 | - `guacamol` requires the [RDKit library](http://rdkit.org/) (version `2018.09.1.0` or newer). 21 | - We also depend on the [FCD](https://github.com/bioinf-jku/FCD) library (version `1.1`) for the calculation of the Fréchet ChemNet Distance. 22 | 23 | #### Unit testing suite 24 | 25 | You can test your installation of the guacamol benchmarking library by running the unit tests from this directory: 26 | ```bash 27 | pytest . 28 | ``` 29 | 30 | 31 | ## Benchmarking models 32 | 33 | For the distribution-learning benchmarks, specialize `DistributionMatchingGenerator` 34 | (from `guacamol.distribution_matching_generator`) for your model. 35 | Instances of this class must be able to generate molecules similar to the training set. 
36 | For the actual benchmarks, call `assess_distribution_learning` 37 | (from `guacamol.assess_distribution_learning`) with an instance of your class. 38 | You must also provide the location of the training set file (see the "Data" section below). 39 | 40 | For the goal-directed benchmarks, specialize `GoalDirectedGenerator` 41 | (from `guacamol.goal_directed_generator`) for your model. 42 | Instances of this class must be able to generate a specified number of molecules 43 | that achieve high scores for a given scoring function. 44 | For the actual benchmarks, call `assess_goal_directed_generation` 45 | (from `guacamol.assess_goal_directed_generation`) with an instance of your class. 46 | 47 | Example implementations for baseline methods are available from https://github.com/BenevolentAI/guacamol_baselines. 48 | 49 | In [guacamol_baselines](https://github.com/BenevolentAI/guacamol_baselines), 50 | we provide a `Dockerfile` with an example environment for developing generative models and running guacamol. 51 | 52 | ## Data 53 | 54 | For fairness in the evaluation of the benchmarks and comparability of the results, 55 | you should use a training set containing molecules from the ChEMBL dataset. 56 | Follow the procedure described below to get standardized datasets.
57 | 58 | 59 | ### Download 60 | 61 | You can download pre-built datasets [here](https://figshare.com/projects/GuacaMol/56639): 62 | 63 | md5 `05ad85d871958a05c02ab51a4fde8530` [training](https://ndownloader.figshare.com/files/13612760) 64 | md5 `e53db4bff7dc4784123ae6df72e3b1f0` [validation](https://ndownloader.figshare.com/files/13612766) 65 | md5 `677b757ccec4809febd83850b43e1616` [test](https://ndownloader.figshare.com/files/13612757) 66 | md5 `7d45bc95c33c10cb96ef5e78c38ac0b6` [all](https://ndownloader.figshare.com/files/13612745) 67 | 68 | 69 | ### Generation 70 | 71 | To generate the training data yourself, run 72 | ``` 73 | python -m guacamol.data.get_data -o [output_directory] 74 | ``` 75 | which will download and process ChEMBL for you, writing the results to `[output_directory]` (the current folder if `-o` is omitted). 76 | 77 | This script will use the molecules from 78 | [`holdout_set_gcm_v1.smiles`](https://github.com/BenevolentAI/guacamol/blob/master/guacamol/data/holdout_set_gcm_v1.smiles) 79 | as a holdout set, and will exclude molecules very similar to these. 80 | 81 | Different versions of your Python packages may lead to differences in the generated dataset; the output will then no longer match the expected hashes, and the script will fail. 82 | See the section below ("Docker") to reproducibly generate the standardized dataset with the hashes given above. 83 | 84 | ### Docker 85 | 86 | To be sure that you have the right dependencies, you can build a Docker image by running the following from the top-level directory: 87 | ``` 88 | docker build -t guacamol-deps -f dockers/Dockerfile . 
89 | ``` 90 | Then you can run: 91 | ``` 92 | docker run --rm -it -v `pwd`:/guacamol -w /guacamol guacamol-deps python -m guacamol.data.get_data -o guacamol/data 93 | ``` 94 | 95 | ## Change log 96 | - 1 May 2020: update version of FCD dependency 97 | - 15 Oct 2020: pin dependencies since FCD does not 98 | - 10 Nov 2021: relax pinned versions of keras, tensorflow & h5py dependencies 99 | - 20 Dec 2021: expose forbidden symbols argument for custom smiles dataset filtering 100 | 101 | ## Leaderboard 102 | 103 | See [https://www.benevolent.com/guacamol](https://www.benevolent.com/guacamol). 104 | -------------------------------------------------------------------------------- /dockers/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | 3 | # metainformation 4 | LABEL org.opencontainers.image.version = "0.5.5" 5 | LABEL org.opencontainers.image.authors = "BenevolentAI" 6 | LABEL org.opencontainers.image.source = "https://github.com/BenevolentAI/guacamol" 7 | LABEL org.opencontainers.image.licenses = "MIT" 8 | LABEL org.opencontainers.image.base.name="docker.io/library/ubuntu:18.04" 9 | 10 | RUN apt-get update && apt-get install -y --no-install-recommends \ 11 | build-essential \ 12 | cmake ca-certificates \ 13 | libglib2.0-0 libxext6 libsm6 libxrender1 \ 14 | wget curl bash bzip2 && \ 15 | apt-get clean && \ 16 | rm -rf /var/lib/apt/lists/* 17 | 18 | # MiniConda 19 | RUN curl -LO --silent https://repo.continuum.io/miniconda/Miniconda3-4.5.11-Linux-x86_64.sh && \ 20 | bash Miniconda3-4.5.11-Linux-x86_64.sh -p /miniconda -b && \ 21 | rm Miniconda3-4.5.11-Linux-x86_64.sh 22 | 23 | ENV PATH=/miniconda/bin:${PATH} 24 | 25 | # Add the source code 26 | RUN mkdir -p /app 27 | ADD . 
/app 28 | 29 | # python deps for running tests 30 | RUN pip install --upgrade pip && pip install --no-cache-dir -r /app/dockers/requirements.txt 31 | 32 | # install guacamol 33 | RUN pip install --upgrade pip && pip install --no-cache-dir /app/ 34 | 35 | # Launch inside the folder 36 | WORKDIR /app/ 37 | -------------------------------------------------------------------------------- /dockers/requirements.txt: -------------------------------------------------------------------------------- 1 | flake8>=3.5.0 2 | mypy>=0.630 3 | pytest>=3.8.2 -------------------------------------------------------------------------------- /guacamol/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.5.5" 2 | -------------------------------------------------------------------------------- /guacamol/assess_distribution_learning.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import logging 4 | from collections import OrderedDict 5 | from typing import List, Dict, Any 6 | 7 | import guacamol 8 | from guacamol.distribution_learning_benchmark import DistributionLearningBenchmark, DistributionLearningBenchmarkResult 9 | from guacamol.distribution_matching_generator import DistributionMatchingGenerator 10 | from guacamol.benchmark_suites import distribution_learning_benchmark_suite 11 | from guacamol.utils.data import get_time_string 12 | 13 | logger = logging.getLogger(__name__) 14 | logger.addHandler(logging.NullHandler()) 15 | 16 | 17 | def assess_distribution_learning(model: DistributionMatchingGenerator, 18 | chembl_training_file: str, 19 | json_output_file='output_distribution_learning.json', 20 | benchmark_version='v1') -> None: 21 | """ 22 | Assesses a distribution-matching model for de novo molecule design. 
23 | 24 | Args: 25 | model: Model to evaluate 26 | chembl_training_file: path to ChEMBL training set, necessary for some benchmarks 27 | json_output_file: Name of the file where to save the results in JSON format 28 | benchmark_version: which benchmark suite to execute 29 | """ 30 | _assess_distribution_learning(model=model, 31 | chembl_training_file=chembl_training_file, 32 | json_output_file=json_output_file, 33 | benchmark_version=benchmark_version, 34 | number_samples=10000) 35 | 36 | 37 | def _assess_distribution_learning(model: DistributionMatchingGenerator, 38 | chembl_training_file: str, 39 | json_output_file: str, 40 | benchmark_version: str, 41 | number_samples: int) -> None: 42 | """ 43 | Internal equivalent to assess_distribution_learning, but allows for a flexible number of samples. 44 | To call directly only for testing. 45 | """ 46 | logger.info(f'Benchmarking distribution learning, version {benchmark_version}') 47 | benchmarks = distribution_learning_benchmark_suite(chembl_file_path=chembl_training_file, 48 | version_name=benchmark_version, 49 | number_samples=number_samples) 50 | 51 | results = _evaluate_distribution_learning_benchmarks(model=model, benchmarks=benchmarks) 52 | 53 | benchmark_results: Dict[str, Any] = OrderedDict() 54 | benchmark_results['guacamol_version'] = guacamol.__version__ 55 | benchmark_results['benchmark_suite_version'] = benchmark_version 56 | benchmark_results['timestamp'] = get_time_string() 57 | benchmark_results['samples'] = model.generate(100) 58 | benchmark_results['results'] = [vars(result) for result in results] 59 | 60 | logger.info(f'Save results to file {json_output_file}') 61 | with open(json_output_file, 'wt') as f: 62 | f.write(json.dumps(benchmark_results, indent=4)) 63 | 64 | 65 | def _evaluate_distribution_learning_benchmarks(model: DistributionMatchingGenerator, 66 | benchmarks: List[DistributionLearningBenchmark] 67 | ) -> List[DistributionLearningBenchmarkResult]: 68 | """ 69 | Evaluate a model with the 
given benchmarks. 70 | Should not be called directly except for testing purposes. 71 | 72 | Args: 73 | model: model to assess 74 | benchmarks: list of benchmarks to evaluate 75 | json_output_file: Name of the file where to save the results in JSON format 76 | """ 77 | 78 | logger.info(f'Number of benchmarks: {len(benchmarks)}') 79 | 80 | results = [] 81 | for i, benchmark in enumerate(benchmarks, 1): 82 | logger.info(f'Running benchmark {i}/{len(benchmarks)}: {benchmark.name}') 83 | result = benchmark.assess_model(model) 84 | logger.info(f'Results for the benchmark "{result.benchmark_name}":') 85 | logger.info(f' Score: {result.score:.6f}') 86 | logger.info(f' Sampling time: {str(datetime.timedelta(seconds=int(result.sampling_time)))}') 87 | logger.info(f' Metadata: {result.metadata}') 88 | results.append(result) 89 | 90 | logger.info('Finished execution of the benchmarks') 91 | 92 | return results 93 | -------------------------------------------------------------------------------- /guacamol/assess_goal_directed_generation.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import logging 4 | from collections import OrderedDict 5 | from typing import List, Any, Dict 6 | 7 | import guacamol 8 | from guacamol.goal_directed_benchmark import GoalDirectedBenchmark, GoalDirectedBenchmarkResult 9 | from guacamol.goal_directed_generator import GoalDirectedGenerator 10 | from guacamol.benchmark_suites import goal_directed_benchmark_suite 11 | from guacamol.utils.data import get_time_string 12 | 13 | logger = logging.getLogger(__name__) 14 | logger.addHandler(logging.NullHandler()) 15 | 16 | 17 | def assess_goal_directed_generation(goal_directed_molecule_generator: GoalDirectedGenerator, 18 | json_output_file='output_goal_directed.json', 19 | benchmark_version='v1') -> None: 20 | """ 21 | Assesses a distribution-matching model for de novo molecule design. 
22 | 23 | Args: 24 | goal_directed_molecule_generator: Model to evaluate 25 | json_output_file: Name of the file where to save the results in JSON format 26 | benchmark_version: which benchmark suite to execute 27 | """ 28 | logger.info(f'Benchmarking goal-directed molecule generation, version {benchmark_version}') 29 | benchmarks = goal_directed_benchmark_suite(version_name=benchmark_version) 30 | 31 | results = _evaluate_goal_directed_benchmarks( 32 | goal_directed_molecule_generator=goal_directed_molecule_generator, 33 | benchmarks=benchmarks) 34 | 35 | benchmark_results: Dict[str, Any] = OrderedDict() 36 | benchmark_results['guacamol_version'] = guacamol.__version__ 37 | benchmark_results['benchmark_suite_version'] = benchmark_version 38 | benchmark_results['timestamp'] = get_time_string() 39 | benchmark_results['results'] = [vars(result) for result in results] 40 | 41 | logger.info(f'Save results to file {json_output_file}') 42 | with open(json_output_file, 'wt') as f: 43 | f.write(json.dumps(benchmark_results, indent=4)) 44 | 45 | 46 | def _evaluate_goal_directed_benchmarks(goal_directed_molecule_generator: GoalDirectedGenerator, 47 | benchmarks: List[GoalDirectedBenchmark] 48 | ) -> List[GoalDirectedBenchmarkResult]: 49 | """ 50 | Evaluate a model with the given benchmarks. 51 | Should not be called directly except for testing purposes. 
52 | 53 | Args: 54 | goal_directed_molecule_generator: model to assess 55 | benchmarks: list of benchmarks to evaluate 56 | json_output_file: Name of the file where to save the results in JSON format 57 | """ 58 | 59 | logger.info(f'Number of benchmarks: {len(benchmarks)}') 60 | 61 | results = [] 62 | for i, benchmark in enumerate(benchmarks, 1): 63 | logger.info(f'Running benchmark {i}/{len(benchmarks)}: {benchmark.name}') 64 | result = benchmark.assess_model(goal_directed_molecule_generator) 65 | logger.info(f'Results for the benchmark "{result.benchmark_name}":') 66 | logger.info(f' Score: {result.score:.6f}') 67 | logger.info(f' Execution time: {str(datetime.timedelta(seconds=int(result.execution_time)))}') 68 | logger.info(f' Metadata: {result.metadata}') 69 | results.append(result) 70 | 71 | logger.info('Finished execution of the benchmarks') 72 | 73 | return results 74 | -------------------------------------------------------------------------------- /guacamol/benchmark_suites.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from guacamol.distribution_learning_benchmark import DistributionLearningBenchmark, ValidityBenchmark, \ 4 | UniquenessBenchmark 5 | from guacamol.goal_directed_benchmark import GoalDirectedBenchmark 6 | from guacamol.scoring_function import ArithmeticMeanScoringFunction 7 | from guacamol.standard_benchmarks import hard_cobimetinib, similarity, logP_benchmark, cns_mpo, \ 8 | qed_benchmark, median_camphor_menthol, novelty_benchmark, isomers_c11h24, isomers_c7h8n2o2, isomers_c9h10n2o2pf2cl, \ 9 | frechet_benchmark, tpsa_benchmark, hard_osimertinib, hard_fexofenadine, weird_physchem, start_pop_ranolazine, \ 10 | kldiv_benchmark, perindopril_rings, amlodipine_rings, sitagliptin_replacement, zaleplon_with_other_formula, valsartan_smarts, \ 11 | median_tadalafil_sildenafil, decoration_hop, scaffold_hop, ranolazine_mpo, pioglitazone_mpo 12 | 13 | 14 | def 
goal_directed_benchmark_suite(version_name: str) -> List[GoalDirectedBenchmark]: 15 | if version_name == 'v1': 16 | return goal_directed_suite_v1() 17 | if version_name == 'v2': 18 | return goal_directed_suite_v2() 19 | if version_name == 'trivial': 20 | return goal_directed_suite_trivial() 21 | 22 | raise Exception(f'Goal-directed benchmark suite "{version_name}" does not exist.') 23 | 24 | 25 | def distribution_learning_benchmark_suite(chembl_file_path: str, 26 | version_name: str, 27 | number_samples: int) -> List[DistributionLearningBenchmark]: 28 | """ 29 | Returns a suite of benchmarks for a specified benchmark version 30 | 31 | Args: 32 | chembl_file_path: path to ChEMBL training set, necessary for some benchmarks 33 | version_name: benchmark version 34 | 35 | Returns: 36 | List of benchmaks 37 | """ 38 | 39 | # For distribution-learning, v1 and v2 are identical 40 | if version_name == 'v1' or version_name == 'v2': 41 | return distribution_learning_suite_v1(chembl_file_path=chembl_file_path, number_samples=number_samples) 42 | 43 | raise Exception(f'Distribution-learning benchmark suite "{version_name}" does not exist.') 44 | 45 | 46 | def goal_directed_suite_v1() -> List[GoalDirectedBenchmark]: 47 | max_logP = 6.35584 48 | return [ 49 | isomers_c11h24(mean_function='arithmetic'), 50 | isomers_c7h8n2o2(mean_function='arithmetic'), 51 | isomers_c9h10n2o2pf2cl(mean_function='arithmetic', n_samples=100), 52 | 53 | hard_cobimetinib(max_logP=max_logP), 54 | hard_osimertinib(ArithmeticMeanScoringFunction), 55 | hard_fexofenadine(ArithmeticMeanScoringFunction), 56 | weird_physchem(), 57 | 58 | # start pop benchmark 59 | # e.g. 
60 | start_pop_ranolazine(), 61 | 62 | # similarity Benchmarks 63 | 64 | # explicit rediscovery 65 | similarity(smiles='CC1=CC=C(C=C1)C1=CC(=NN1C1=CC=C(C=C1)S(N)(=O)=O)C(F)(F)F', name='Celecoxib', 66 | fp_type='ECFP4', threshold=1.0, rediscovery=True), 67 | similarity(smiles='Cc1c(C)c2OC(C)(COc3ccc(CC4SC(=O)NC4=O)cc3)CCc2c(C)c1O', name='Troglitazone', 68 | fp_type='ECFP4', threshold=1.0, rediscovery=True), 69 | similarity(smiles='CN(C)S(=O)(=O)c1ccc2Sc3ccccc3C(=CCCN4CCN(C)CC4)c2c1', name='Thiothixene', 70 | fp_type='ECFP4', threshold=1.0, rediscovery=True), 71 | 72 | # generate similar stuff 73 | similarity(smiles='Clc4cccc(N3CCN(CCCCOc2ccc1c(NC(=O)CC1)c2)CC3)c4Cl', 74 | name='Aripiprazole', fp_type='FCFP4', threshold=0.75), 75 | similarity(smiles='CC(C)(C)NCC(O)c1ccc(O)c(CO)c1', name='Albuterol', 76 | fp_type='FCFP4', threshold=0.75), 77 | similarity(smiles='COc1ccc2[C@H]3CC[C@@]4(C)[C@@H](CC[C@@]4(O)C#C)[C@@H]3CCc2c1', name='Mestranol', 78 | fp_type='AP', threshold=0.75), 79 | 80 | logP_benchmark(target=-1.0), 81 | logP_benchmark(target=8.0), 82 | tpsa_benchmark(target=150.0), 83 | 84 | cns_mpo(max_logP=max_logP), 85 | qed_benchmark(), 86 | median_camphor_menthol(ArithmeticMeanScoringFunction) 87 | ] 88 | 89 | 90 | def goal_directed_suite_v2() -> List[GoalDirectedBenchmark]: 91 | return [ 92 | # explicit rediscovery 93 | similarity(smiles='CC1=CC=C(C=C1)C1=CC(=NN1C1=CC=C(C=C1)S(N)(=O)=O)C(F)(F)F', name='Celecoxib', fp_type='ECFP4', 94 | threshold=1.0, rediscovery=True), 95 | similarity(smiles='Cc1c(C)c2OC(C)(COc3ccc(CC4SC(=O)NC4=O)cc3)CCc2c(C)c1O', name='Troglitazone', fp_type='ECFP4', 96 | threshold=1.0, rediscovery=True), 97 | similarity(smiles='CN(C)S(=O)(=O)c1ccc2Sc3ccccc3C(=CCCN4CCN(C)CC4)c2c1', name='Thiothixene', fp_type='ECFP4', 98 | threshold=1.0, rediscovery=True), 99 | 100 | # generate similar stuff 101 | similarity(smiles='Clc4cccc(N3CCN(CCCCOc2ccc1c(NC(=O)CC1)c2)CC3)c4Cl', name='Aripiprazole', fp_type='ECFP4', 102 | threshold=0.75), 103 | 
similarity(smiles='CC(C)(C)NCC(O)c1ccc(O)c(CO)c1', name='Albuterol', fp_type='FCFP4', threshold=0.75), 104 | similarity(smiles='COc1ccc2[C@H]3CC[C@@]4(C)[C@@H](CC[C@@]4(O)C#C)[C@@H]3CCc2c1', name='Mestranol', 105 | fp_type='AP', threshold=0.75), 106 | 107 | # isomers 108 | isomers_c11h24(), 109 | isomers_c9h10n2o2pf2cl(), 110 | 111 | # median molecules 112 | median_camphor_menthol(), 113 | median_tadalafil_sildenafil(), 114 | 115 | # all other MPOs 116 | hard_osimertinib(), 117 | hard_fexofenadine(), 118 | ranolazine_mpo(), 119 | perindopril_rings(), 120 | amlodipine_rings(), 121 | sitagliptin_replacement(), 122 | zaleplon_with_other_formula(), 123 | valsartan_smarts(), 124 | decoration_hop(), 125 | scaffold_hop(), 126 | ] 127 | 128 | 129 | def goal_directed_suite_trivial() -> List[GoalDirectedBenchmark]: 130 | """ 131 | Trivial goal-directed benchmarks from the paper. 132 | """ 133 | return [ 134 | logP_benchmark(target=-1.0), 135 | logP_benchmark(target=8.0), 136 | tpsa_benchmark(target=150.0), 137 | cns_mpo(), 138 | qed_benchmark(), 139 | isomers_c7h8n2o2(), 140 | pioglitazone_mpo(), 141 | ] 142 | 143 | 144 | def distribution_learning_suite_v1(chembl_file_path: str, number_samples: int = 10000) -> \ 145 | List[DistributionLearningBenchmark]: 146 | """ 147 | Suite of distribution learning benchmarks, v1. 
148 | 149 | Args: 150 | chembl_file_path: path to the file with the reference ChEMBL molecules 151 | 152 | Returns: 153 | List of benchmarks, version 1 154 | """ 155 | return [ 156 | ValidityBenchmark(number_samples=number_samples), 157 | UniquenessBenchmark(number_samples=number_samples), 158 | novelty_benchmark(training_set_file=chembl_file_path, number_samples=number_samples), 159 | kldiv_benchmark(training_set_file=chembl_file_path, number_samples=number_samples), 160 | frechet_benchmark(training_set_file=chembl_file_path, number_samples=number_samples) 161 | ] 162 | -------------------------------------------------------------------------------- /guacamol/common_scoring_functions.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List 2 | 3 | from rdkit import Chem 4 | from rdkit.DataStructs.cDataStructs import TanimotoSimilarity 5 | 6 | from guacamol.utils.descriptors import mol_weight, logP, num_H_donors, tpsa, num_atoms, AtomCounter 7 | from guacamol.utils.fingerprints import get_fingerprint 8 | from guacamol.score_modifier import ScoreModifier, MinGaussianModifier, MaxGaussianModifier, GaussianModifier 9 | from guacamol.scoring_function import ScoringFunctionBasedOnRdkitMol, MoleculewiseScoringFunction 10 | from guacamol.utils.chemistry import smiles_to_rdkit_mol, parse_molecular_formula 11 | from guacamol.utils.math import arithmetic_mean, geometric_mean 12 | 13 | 14 | class RdkitScoringFunction(ScoringFunctionBasedOnRdkitMol): 15 | """ 16 | Scoring function wrapping RDKit descriptors. 
17 | """ 18 | 19 | def __init__(self, descriptor: Callable[[Chem.Mol], float], score_modifier: ScoreModifier = None) -> None: 20 | """ 21 | Args: 22 | descriptor: molecular descriptors, such as the ones in descriptors.py 23 | score_modifier: score modifier 24 | """ 25 | super().__init__(score_modifier=score_modifier) 26 | self.descriptor = descriptor 27 | 28 | def score_mol(self, mol: Chem.Mol) -> float: 29 | return self.descriptor(mol) 30 | 31 | 32 | class TanimotoScoringFunction(ScoringFunctionBasedOnRdkitMol): 33 | """ 34 | Scoring function that looks at the fingerprint similarity against a target molecule. 35 | """ 36 | 37 | def __init__(self, target, fp_type, score_modifier: ScoreModifier = None) -> None: 38 | """ 39 | Args: 40 | target: target molecule 41 | fp_type: fingerprint type 42 | score_modifier: score modifier 43 | """ 44 | super().__init__(score_modifier=score_modifier) 45 | 46 | self.target = target 47 | self.fp_type = fp_type 48 | target_mol = smiles_to_rdkit_mol(target) 49 | if target_mol is None: 50 | raise RuntimeError(f'The similarity target {target} is not a valid molecule.') 51 | 52 | self.ref_fp = get_fingerprint(target_mol, self.fp_type) 53 | 54 | def score_mol(self, mol: Chem.Mol) -> float: 55 | fp = get_fingerprint(mol, self.fp_type) 56 | return TanimotoSimilarity(fp, self.ref_fp) 57 | 58 | 59 | class CNS_MPO_ScoringFunction(ScoringFunctionBasedOnRdkitMol): 60 | """ 61 | CNS MPO scoring function 62 | """ 63 | 64 | def __init__(self, max_logP=5.0, maxMW=360, min_tpsa=40, max_tpsa=90, max_hbd=0) -> None: 65 | super().__init__() 66 | 67 | self.logP_gauss = MinGaussianModifier(max_logP, 1) 68 | self.molW_gauss = MinGaussianModifier(maxMW, 60) 69 | self.tpsa_maxgauss = MaxGaussianModifier(min_tpsa, 20) 70 | self.tpsa_mingauss = MinGaussianModifier(max_tpsa, 30) 71 | self.hbd_gauss = MinGaussianModifier(max_hbd, 2.0) 72 | 73 | def score_mol(self, mol: Chem.Mol) -> float: 74 | mw = mol_weight(mol) 75 | lp = logP(mol) 76 | hbd = num_H_donors(mol) 
77 | mol_tpsa = tpsa(mol) 78 | 79 | o1 = self.tpsa_mingauss(mol_tpsa) 80 | o2 = self.tpsa_maxgauss(mol_tpsa) 81 | o3 = self.hbd_gauss(hbd) 82 | o4 = self.logP_gauss(lp) 83 | o5 = self.molW_gauss(mw) 84 | 85 | return 0.2 * (o1 + o2 + o3 + o4 + o5) 86 | 87 | 88 | class IsomerScoringFunction(MoleculewiseScoringFunction): 89 | """ 90 | Scoring function for closeness to a molecular formula. 91 | 92 | The score penalizes deviations from the required number of atoms for each element type, and for the total 93 | number of atoms. 94 | 95 | F.i., if the target formula is C2H4, the scoring function is the average of three contributions: 96 | - number of C atoms with a Gaussian modifier with mu=2, sigma=1 97 | - number of H atoms with a Gaussian modifier with mu=4, sigma=1 98 | - total number of atoms with a Gaussian modifier with mu=6, sigma=2 99 | """ 100 | 101 | def __init__(self, molecular_formula: str, mean_function='geometric') -> None: 102 | """ 103 | Args: 104 | molecular_formula: target molecular formula 105 | mean_function: which function to use for averaging: 'arithmetic' or 'geometric' 106 | """ 107 | super().__init__() 108 | 109 | self.mean_function = self.determine_mean_function(mean_function) 110 | self.scoring_functions = self.determine_scoring_functions(molecular_formula) 111 | 112 | @staticmethod 113 | def determine_mean_function(mean_function: str) -> Callable[[List[float]], float]: 114 | if mean_function == 'arithmetic': 115 | return arithmetic_mean 116 | if mean_function == 'geometric': 117 | return geometric_mean 118 | raise ValueError(f'Invalid mean function: "{mean_function}"') 119 | 120 | @staticmethod 121 | def determine_scoring_functions(molecular_formula: str) -> List[RdkitScoringFunction]: 122 | element_occurrences = parse_molecular_formula(molecular_formula) 123 | 124 | total_number_atoms = sum(element_tuple[1] for element_tuple in element_occurrences) 125 | 126 | # scoring functions for each element 127 | functions = 
[RdkitScoringFunction(descriptor=AtomCounter(element), 128 | score_modifier=GaussianModifier(mu=n_atoms, sigma=1.0)) 129 | for element, n_atoms in element_occurrences] 130 | 131 | # scoring functions for the total number of atoms 132 | functions.append(RdkitScoringFunction(descriptor=num_atoms, 133 | score_modifier=GaussianModifier(mu=total_number_atoms, sigma=2.0))) 134 | 135 | return functions 136 | 137 | def raw_score(self, smiles: str) -> float: 138 | # return the average of all scoring functions 139 | scores = [f.score(smiles) for f in self.scoring_functions] 140 | if self.corrupt_score in scores: 141 | return self.corrupt_score 142 | return self.mean_function(scores) 143 | 144 | 145 | class SMARTSScoringFunction(ScoringFunctionBasedOnRdkitMol): 146 | """ 147 | Tests for SMARTS which should be or should not be present in the compound. 148 | 149 | 150 | """ 151 | 152 | def __init__(self, target: str, inverse=False) -> None: 153 | """ 154 | 155 | :param target: The SMARTS string to match. 
156 | :param inverse: Specifies whether the SMARTS is desired (False) or an antipattern, which we don't want to see 157 | in the molecules (inverse=False) 158 | """ 159 | super().__init__() 160 | self.inverse = inverse 161 | self.smarts = target 162 | self.target = Chem.MolFromSmarts(target) 163 | 164 | assert target is not None 165 | 166 | def score_mol(self, mol: Chem.Mol) -> float: 167 | 168 | matches = mol.GetSubstructMatches(self.target) 169 | 170 | if len(matches) > 0: 171 | if self.inverse: 172 | return 0.0 173 | else: 174 | return 1.0 175 | else: 176 | if self.inverse: 177 | return 1.0 178 | else: 179 | return 0.0 180 | -------------------------------------------------------------------------------- /guacamol/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenevolentAI/guacamol/60ebe1f6a396f16e08b834dce448e9343d259feb/guacamol/data/__init__.py -------------------------------------------------------------------------------- /guacamol/data/get_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gzip 3 | import hashlib 4 | import logging 5 | import numpy as np 6 | import os.path 7 | import pkgutil 8 | import platform 9 | from joblib import Parallel, delayed 10 | from typing import List, Iterable, Optional, Set 11 | 12 | from guacamol.utils.chemistry import canonicalize_list, filter_and_canonicalize, \ 13 | initialise_neutralisation_reactions, split_charged_mol, get_fingerprints_from_smileslist 14 | from guacamol.utils.data import download_if_not_present 15 | from guacamol.utils.helpers import setup_default_logger 16 | 17 | logger = logging.getLogger(__name__) 18 | logger.addHandler(logging.NullHandler()) 19 | 20 | TRAIN_HASH = '05ad85d871958a05c02ab51a4fde8530' 21 | VALID_HASH = 'e53db4bff7dc4784123ae6df72e3b1f0' 22 | TEST_HASH = '677b757ccec4809febd83850b43e1616' 23 | 24 | CHEMBL_URL = 
'ftp://ftp.ebi.ac.uk/pub/databases/chembl/ChEMBLdb/releases/chembl_24_1/chembl_24_1_chemreps.txt.gz' 25 | CHEMBL_FILE_NAME = 'chembl_24_1_chemreps.txt.gz' 26 | 27 | # Threshold to remove molecules too similar to the holdout set 28 | TANIMOTO_CUTOFF = 0.323 29 | 30 | 31 | def get_argparser(): 32 | parser = argparse.ArgumentParser(description='Data Preparation for GuacaMol', 33 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 34 | parser.add_argument('-o', '--destination', default='.', help='Download and Output location') 35 | parser.add_argument('--n_jobs', default=8, type=int, help='Number of cores to use') 36 | return parser 37 | 38 | 39 | def extract_chembl(line) -> str: 40 | """ 41 | Extract smiles from chembl tsv 42 | 43 | Returns: 44 | SMILES string 45 | """ 46 | return line.split('\t')[1] 47 | 48 | 49 | def extract_smilesfile(line) -> str: 50 | """ 51 | Extract smiles from SMILES file 52 | 53 | Returns: 54 | SMILES string 55 | """ 56 | return line.split(' ')[0].strip() 57 | 58 | 59 | class AllowedSmilesCharDictionary(object): 60 | """ 61 | A fixed dictionary for druglike SMILES. 
62 | """ 63 | 64 | def __init__(self, forbidden_symbols: Optional[Set[str]] = None) -> None: 65 | if forbidden_symbols is None: 66 | forbidden_symbols = {'Ag', 'Al', 'Am', 'Ar', 'At', 'Au', 'D', 'E', 'Fe', 'G', 'K', 'L', 'M', 'Ra', 'Re', 67 | 'Rf', 'Rg', 'Rh', 'Ru', 'T', 'U', 'V', 'W', 'Xe', 68 | 'Y', 'Zr', 'a', 'd', 'f', 'g', 'h', 'k', 'm', 'si', 't', 'te', 'u', 'v', 'y'} 69 | self.forbidden_symbols = forbidden_symbols 70 | 71 | def allowed(self, smiles: str) -> bool: 72 | """ 73 | Determine if SMILES string has illegal symbols 74 | 75 | Args: 76 | smiles: SMILES string 77 | 78 | Returns: 79 | True if all legal 80 | """ 81 | for symbol in self.forbidden_symbols: 82 | if symbol in smiles: 83 | print('Forbidden symbol {:<2} in {}'.format(symbol, smiles)) 84 | return False 85 | return True 86 | 87 | 88 | def get_raw_smiles(file_name, smiles_char_dict, open_fn, extract_fn) -> List[str]: 89 | """ 90 | Extracts the raw smiles from an input file. 91 | open_fn will open the file to iterate over it (e.g. use open_fn=open or open_fn=filegzip.open) 92 | extract_fn specifies how to process the lines, choose from 93 | Pre-filter molecules of 5 <= length <= 200, because processing larger molecules (e.g. peptides) takes very long. 94 | 95 | Returns: 96 | a list of SMILES strings 97 | """ 98 | data = [] 99 | # open the gzipped chembl filegzip.open 100 | with open_fn(file_name, 'rt') as f: 101 | 102 | line_count = 0 103 | for line in f: 104 | 105 | line_count += 1 106 | # extract the canonical smiles column 107 | if platform.system() == "Windows": 108 | line = line.decode("utf-8") 109 | 110 | # smiles = line.split('\t')[1] 111 | 112 | smiles = extract_fn(line) 113 | 114 | # only keep reasonably sized molecules 115 | if 5 <= len(smiles) <= 200: 116 | 117 | smiles = split_charged_mol(smiles) 118 | 119 | if smiles_char_dict.allowed(smiles): 120 | # check whether the molecular graph consists of 121 | # multiple connected components (eg. 
in salts) 122 | # if so, just keep the largest one 123 | 124 | data.append(smiles) 125 | 126 | print(f'Processed {len(data)} molecules from {line_count} lines in the input file.') 127 | 128 | return data 129 | 130 | 131 | def write_smiles(dataset: Iterable[str], filename: str): 132 | """ 133 | Dumps a list of SMILES into a file, one per line 134 | """ 135 | n_lines = 0 136 | with open(filename, 'w') as out: 137 | for smiles_str in dataset: 138 | out.write('%s\n' % smiles_str) 139 | n_lines += 1 140 | print(f'{filename} contains {n_lines} molecules') 141 | 142 | 143 | def compare_hash(output_file: str, correct_hash: str) -> bool: 144 | """ 145 | Computes the md5 hash of a SMILES file and check it against a given one 146 | Returns false if hashes are different 147 | """ 148 | output_hash = hashlib.md5(open(output_file, 'rb').read()).hexdigest() 149 | if output_hash != correct_hash: 150 | logger.error(f'{output_file} file has different hash, {output_hash}, than expected, {correct_hash}!') 151 | return False 152 | 153 | return True 154 | 155 | 156 | def main(): 157 | """ Get Chembl-24. 158 | 159 | Preprocessing steps: 160 | 161 | 1) filter SMILES shorter than 5 and longer than 200 chars and those with forbidden symbols 162 | 2) canonicalize, neutralize, only permit smiles shorter than 100 chars 163 | 3) shuffle, write files, check if they are consistently hashed. 
164 | """ 165 | setup_default_logger() 166 | 167 | argparser = get_argparser() 168 | args = argparser.parse_args() 169 | 170 | # Set constants 171 | np.random.seed(1337) 172 | neutralization_rxns = initialise_neutralisation_reactions() 173 | smiles_dict = AllowedSmilesCharDictionary() 174 | 175 | print('Preprocessing ChEMBL molecules...') 176 | 177 | chembl_file = os.path.join(args.destination, CHEMBL_FILE_NAME) 178 | 179 | # read holdout set and decode it 180 | raw_data = pkgutil.get_data('guacamol.data', 'holdout_set_gcm_v1.smiles') 181 | assert raw_data is not None 182 | data = raw_data.decode('utf-8').splitlines() 183 | 184 | holdout_mols = [i.split(' ')[0] for i in data] 185 | holdout_set = set(canonicalize_list(holdout_mols, False)) 186 | holdout_fps = get_fingerprints_from_smileslist(holdout_set) 187 | 188 | # Download Chembl24 if needed. 189 | download_if_not_present(chembl_file, 190 | uri=CHEMBL_URL) 191 | raw_smiles = get_raw_smiles(chembl_file, smiles_char_dict=smiles_dict, open_fn=gzip.open, 192 | extract_fn=extract_chembl) 193 | 194 | file_prefix = 'chembl24_canon' 195 | 196 | print(f'and standardizing {len(raw_smiles)} molecules using {args.n_jobs} cores, ' 197 | f'and excluding molecules based on ECFP4 similarity of > {TANIMOTO_CUTOFF} to the holdout set.') 198 | 199 | # Process all the SMILES in parallel 200 | runner = Parallel(n_jobs=args.n_jobs, verbose=2) 201 | 202 | joblist = (delayed(filter_and_canonicalize)(smiles_str, 203 | holdout_set, 204 | holdout_fps, 205 | neutralization_rxns, 206 | TANIMOTO_CUTOFF, 207 | False) 208 | for smiles_str in raw_smiles) 209 | 210 | output = runner(joblist) 211 | 212 | # Put all nonzero molecules in a list, remove duplicates, sort and shuffle 213 | 214 | all_good_mols = sorted(list(set([item[0] for item in output if item]))) 215 | np.random.shuffle(all_good_mols) 216 | print(f'Ended up with {len(all_good_mols)} molecules. 
Preparing splits...') 217 | 218 | # Split into train-dev-test 219 | # Check whether the md5-hashes of the generated smiles files match 220 | # the precomputed hashes, this ensures everyone works with the same splits. 221 | 222 | VALID_SIZE = int(0.05 * len(all_good_mols)) 223 | TEST_SIZE = int(0.15 * len(all_good_mols)) 224 | 225 | dev_set = all_good_mols[0:VALID_SIZE] 226 | dev_path = os.path.join(args.destination, f'{file_prefix}_dev-valid.smiles') 227 | write_smiles(dev_set, dev_path) 228 | 229 | test_set = all_good_mols[VALID_SIZE:VALID_SIZE + TEST_SIZE] 230 | test_path = os.path.join(args.destination, f'{file_prefix}_test.smiles') 231 | write_smiles(test_set, test_path) 232 | 233 | train_set = all_good_mols[VALID_SIZE + TEST_SIZE:] 234 | train_path = os.path.join(args.destination, f'{file_prefix}_train.smiles') 235 | write_smiles(train_set, train_path) 236 | 237 | # check the hashes 238 | valid_hashes = [ 239 | compare_hash(train_path, TRAIN_HASH), 240 | compare_hash(dev_path, VALID_HASH), 241 | compare_hash(test_path, TEST_HASH), 242 | ] 243 | 244 | if not all(valid_hashes): 245 | raise SystemExit('Invalid hashes for the dataset files') 246 | 247 | print('Dataset generation successful. 
You are ready to go.') 248 | 249 | 250 | if __name__ == '__main__': 251 | main() 252 | -------------------------------------------------------------------------------- /guacamol/data/holdout_set_gcm_v1.smiles: -------------------------------------------------------------------------------- 1 | CC1=CC=C(C=C1)C1=CC(=NN1C1=CC=C(C=C1)S(N)(=O)=O)C(F)(F)F Celecoxib 2 | Clc4cccc(N3CCN(CCCCOc2ccc1c(NC(=O)CC1)c2)CC3)c4Cl Aripiprazole 3 | OC1(CN(C1)C(=O)C1=C(NC2=C(F)C=C(I)C=C2)C(F)=C(F)C=C1)C1CCCCN1 Cobimetinib 4 | COc1cc(N(C)CCN(C)C)c(NC(=O)C=C)cc1Nc2nccc(n2)c3cn(C)c4ccccc34 Osimertinib 5 | Cc1c(C)c2OC(C)(COc3ccc(CC4SC(=O)NC4=O)cc3)CCc2c(C)c1O Troglitazone 6 | COc1ccccc1OCC(O)CN2CCN(CC(=O)Nc3c(C)cccc3C)CC2 Ranolazine 7 | CN(C)S(=O)(=O)c1ccc2Sc3ccccc3C(=CCCN4CCN(C)CC4)c2c1 Thiothixene 8 | CC(C)(C)NCC(O)c1ccc(O)c(CO)c1 Albuterol 9 | CC(C)(C(=O)O)c1ccc(cc1)C(O)CCCN2CCC(CC2)C(O)(c3ccccc3)c4ccccc4 Fexofenadine 10 | COc1ccc2[C@H]3CC[C@@]4(C)[C@@H](CC[C@@]4(O)C#C)[C@@H]3CCc2c1 Mestranol -------------------------------------------------------------------------------- /guacamol/distribution_learning_benchmark.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from abc import abstractmethod 4 | from typing import Dict, Any, Iterable, List 5 | import numpy as np 6 | 7 | from guacamol.utils.chemistry import canonicalize_list, is_valid, calculate_pc_descriptors, continuous_kldiv, \ 8 | discrete_kldiv, calculate_internal_pairwise_similarities 9 | from guacamol.distribution_matching_generator import DistributionMatchingGenerator 10 | from guacamol.utils.data import get_random_subset 11 | from guacamol.utils.sampling_helpers import sample_valid_molecules, sample_unique_molecules 12 | 13 | logger = logging.getLogger(__name__) 14 | logger.addHandler(logging.NullHandler()) 15 | 16 | 17 | class DistributionLearningBenchmarkResult: 18 | """ 19 | Contains the results of a distribution learning benchmark. 
20 | 21 | NB: timing does not make sense since training happens outside of DistributionLearningBenchmark. 22 | """ 23 | 24 | def __init__(self, benchmark_name: str, score: float, sampling_time: float, metadata: Dict[str, Any]) -> None: 25 | """ 26 | Args: 27 | benchmark_name: name of the distribution-learning benchmark 28 | score: benchmark score 29 | sampling_time: time for sampling the molecules in seconds 30 | metadata: benchmark-specific information 31 | """ 32 | self.benchmark_name = benchmark_name 33 | self.score = score 34 | self.sampling_time = sampling_time 35 | self.metadata = metadata 36 | 37 | 38 | class DistributionLearningBenchmark: 39 | """ 40 | Base class for assessing how well a model is able to generate molecules matching a molecule distribution. 41 | 42 | Derived classes should implement the assess_model method (NB: this class does not use ABCMeta, so @abstractmethod is not enforced at instantiation). 43 | """ 44 | 45 | def __init__(self, name: str, number_samples: int) -> None: 46 | self.name = name 47 | self.number_samples = number_samples 48 | 49 | @abstractmethod 50 | def assess_model(self, model: DistributionMatchingGenerator) -> DistributionLearningBenchmarkResult: 51 | """ 52 | Assess a distribution-matching generator model. 53 | 54 | Args: 55 | model: model to assess 56 | """ 57 | 58 | 59 | class ValidityBenchmark(DistributionLearningBenchmark): 60 | """ 61 | Assesses what percentage of molecules generated by a model are valid molecules (i.e.
have a valid SMILES string) 62 | """ 63 | 64 | def __init__(self, number_samples) -> None: 65 | super().__init__(name='Validity', number_samples=number_samples) 66 | 67 | def assess_model(self, model: DistributionMatchingGenerator) -> DistributionLearningBenchmarkResult: 68 | start_time = time.time() 69 | molecules = model.generate(number_samples=self.number_samples) 70 | end_time = time.time() 71 | 72 | if len(molecules) != self.number_samples: 73 | raise Exception('The model did not generate the correct number of molecules') 74 | 75 | number_valid = sum(1 if is_valid(smiles) else 0 for smiles in molecules) 76 | validity_ratio = number_valid / self.number_samples 77 | metadata = { 78 | 'number_samples': self.number_samples, 79 | 'number_valid': number_valid, 80 | } 81 | 82 | return DistributionLearningBenchmarkResult(benchmark_name=self.name, 83 | score=validity_ratio, 84 | sampling_time=end_time - start_time, 85 | metadata=metadata) 86 | 87 | 88 | class UniquenessBenchmark(DistributionLearningBenchmark): 89 | """ 90 | Assesses what percentage of molecules generated by a model are unique. 91 | """ 92 | 93 | def __init__(self, number_samples) -> None: 94 | super().__init__(name='Uniqueness', number_samples=number_samples) 95 | 96 | def assess_model(self, model: DistributionMatchingGenerator) -> DistributionLearningBenchmarkResult: 97 | start_time = time.time() 98 | molecules = sample_valid_molecules(model=model, number_molecules=self.number_samples) 99 | end_time = time.time() 100 | 101 | if len(molecules) != self.number_samples: 102 | logger.warning('The model could not generate enough valid molecules. 
The score will be penalized.') 103 | 104 | # canonicalize_list removes duplicates (and invalid molecules, but there shouldn't be any) 105 | unique_molecules = canonicalize_list(molecules, include_stereocenters=False) 106 | 107 | unique_ratio = len(unique_molecules) / self.number_samples 108 | metadata = { 109 | 'number_samples': self.number_samples, 110 | 'number_unique': len(unique_molecules) 111 | } 112 | 113 | return DistributionLearningBenchmarkResult(benchmark_name=self.name, 114 | score=unique_ratio, 115 | sampling_time=end_time - start_time, 116 | metadata=metadata) 117 | 118 | 119 | class NoveltyBenchmark(DistributionLearningBenchmark): 120 | def __init__(self, number_samples: int, training_set: Iterable[str]) -> None: 121 | """ 122 | Args: 123 | number_samples: number of samples to generate from the model 124 | training_set: molecules from the training set 125 | """ 126 | super().__init__(name='Novelty', number_samples=number_samples) 127 | self.training_set_molecules = set(canonicalize_list(training_set, include_stereocenters=False)) 128 | 129 | def assess_model(self, model: DistributionMatchingGenerator) -> DistributionLearningBenchmarkResult: 130 | """ 131 | Assess a distribution-matching generator model. 132 | 133 | Args: 134 | model: model to assess 135 | """ 136 | start_time = time.time() 137 | molecules = sample_unique_molecules(model=model, number_molecules=self.number_samples, max_tries=2) 138 | end_time = time.time() 139 | 140 | if len(molecules) != self.number_samples: 141 | logger.warning('The model could not generate enough unique molecules. 
The score will be penalized.') 142 | 143 | # canonicalize_list in order to remove stereo information (also removes duplicates and invalid molecules, but there shouldn't be any) 144 | unique_molecules = set(canonicalize_list(molecules, include_stereocenters=False)) 145 | 146 | novel_molecules = unique_molecules.difference(self.training_set_molecules) 147 | 148 | novel_ratio = len(novel_molecules) / self.number_samples 149 | 150 | metadata = { 151 | 'number_samples': self.number_samples, 152 | 'number_novel': len(novel_molecules) 153 | } 154 | 155 | return DistributionLearningBenchmarkResult(benchmark_name=self.name, 156 | score=novel_ratio, 157 | sampling_time=end_time - start_time, 158 | metadata=metadata) 159 | 160 | 161 | class KLDivBenchmark(DistributionLearningBenchmark): 162 | """ 163 | Computes the KL divergence between a number of samples and the training set for physchem descriptors 164 | """ 165 | 166 | def __init__(self, number_samples: int, training_set: List[str]) -> None: 167 | """ 168 | Args: 169 | number_samples: number of samples to generate from the model 170 | training_set: molecules from the training set 171 | """ 172 | super().__init__(name='KL divergence', number_samples=number_samples) 173 | self.training_set_molecules = canonicalize_list(get_random_subset(training_set, self.number_samples, seed=42), 174 | include_stereocenters=False) 175 | self.pc_descriptor_subset = [ 176 | 'BertzCT', 177 | 'MolLogP', 178 | 'MolWt', 179 | 'TPSA', 180 | 'NumHAcceptors', 181 | 'NumHDonors', 182 | 'NumRotatableBonds', 183 | 'NumAliphaticRings', 184 | 'NumAromaticRings' 185 | ] 186 | 187 | def assess_model(self, model: DistributionMatchingGenerator) -> DistributionLearningBenchmarkResult: 188 | """ 189 | Assess a distribution-matching generator model. 
190 | 191 | Args: 192 | model: model to assess 193 | """ 194 | start_time = time.time() 195 | molecules = sample_unique_molecules(model=model, number_molecules=self.number_samples, max_tries=2) 196 | end_time = time.time() 197 | 198 | if len(molecules) != self.number_samples: 199 | logger.warning('The model could not generate enough unique molecules. The score will be penalized.') 200 | 201 | # canonicalize_list in order to remove stereo information (also removes duplicates and invalid molecules, but there shouldn't be any) 202 | unique_molecules = set(canonicalize_list(molecules, include_stereocenters=False)) 203 | 204 | # first we calculate the descriptors, which are np.arrays of size n_samples x n_descriptors 205 | d_sampled = calculate_pc_descriptors(unique_molecules, self.pc_descriptor_subset) 206 | d_chembl = calculate_pc_descriptors(self.training_set_molecules, self.pc_descriptor_subset) 207 | 208 | kldivs = {} 209 | 210 | # now we calculate the kl divergence for the float valued descriptors ... 211 | for i in range(4): 212 | kldiv = continuous_kldiv(X_baseline=d_chembl[:, i], X_sampled=d_sampled[:, i]) 213 | kldivs[self.pc_descriptor_subset[i]] = kldiv 214 | 215 | # ... and for the int valued ones. 216 | for i in range(4, 9): 217 | kldiv = discrete_kldiv(X_baseline=d_chembl[:, i], X_sampled=d_sampled[:, i]) 218 | kldivs[self.pc_descriptor_subset[i]] = kldiv 219 | 220 | # pairwise similarity 221 | 222 | chembl_sim = calculate_internal_pairwise_similarities(self.training_set_molecules) 223 | chembl_sim = chembl_sim.max(axis=1) 224 | 225 | sampled_sim = calculate_internal_pairwise_similarities(unique_molecules) 226 | sampled_sim = sampled_sim.max(axis=1) 227 | 228 | kldiv_int_int = continuous_kldiv(X_baseline=chembl_sim, X_sampled=sampled_sim) 229 | kldivs['internal_similarity'] = kldiv_int_int 230 | 231 | # for some reason, this runs into problems when both sets are identical. 
232 | # cross_set_sim = calculate_pairwise_similarities(self.training_set_molecules, unique_molecules) 233 | # cross_set_sim = cross_set_sim.max(axis=1) 234 | # 235 | # kldiv_ext = discrete_kldiv(chembl_sim, cross_set_sim) 236 | # kldivs['external_similarity'] = kldiv_ext 237 | # kldiv_sum += kldiv_ext 238 | 239 | metadata = { 240 | 'number_samples': self.number_samples, 241 | 'kl_divs': kldivs 242 | } 243 | 244 | # Each KL divergence value is transformed to be in [0, 1]. 245 | # Then their average delivers the final score. 246 | partial_scores = [np.exp(-score) for score in kldivs.values()] 247 | score = sum(partial_scores) / len(partial_scores) 248 | 249 | return DistributionLearningBenchmarkResult(benchmark_name=self.name, 250 | score=score, 251 | sampling_time=end_time - start_time, 252 | metadata=metadata) 253 | -------------------------------------------------------------------------------- /guacamol/distribution_matching_generator.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import List 3 | 4 | 5 | class DistributionMatchingGenerator(metaclass=ABCMeta): 6 | """ 7 | Interface for molecule generators. 8 | """ 9 | 10 | @abstractmethod 11 | def generate(self, number_samples: int) -> List[str]: 12 | """ 13 | Samples SMILES strings from a molecule generator. 14 | 15 | Args: 16 | number_samples: number of molecules to generate 17 | 18 | Returns: 19 | A list of SMILES strings. 
20 | """ 21 | -------------------------------------------------------------------------------- /guacamol/frechet_benchmark.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pkgutil 4 | import tempfile 5 | import time 6 | from typing import List 7 | 8 | import fcd 9 | import numpy as np 10 | 11 | from guacamol.distribution_learning_benchmark import DistributionLearningBenchmark, DistributionLearningBenchmarkResult 12 | from guacamol.distribution_matching_generator import DistributionMatchingGenerator 13 | from guacamol.utils.data import get_random_subset 14 | from guacamol.utils.sampling_helpers import sample_valid_molecules 15 | 16 | logger = logging.getLogger(__name__) 17 | logger.addHandler(logging.NullHandler()) 18 | 19 | 20 | class FrechetBenchmark(DistributionLearningBenchmark): 21 | """ 22 | Calculates the Fréchet ChemNet Distance. 23 | 24 | See http://dx.doi.org/10.1021/acs.jcim.8b00234 for the publication. 25 | """ 26 | 27 | def __init__(self, training_set: List[str], 28 | chemnet_model_filename='ChemNet_v0.13_pretrained.h5', 29 | sample_size=10000) -> None: 30 | """ 31 | Args: 32 | training_set: molecules from the training set 33 | chemnet_model_filename: name of the file for trained ChemNet model. 34 | Must be present in the 'fcd' package, since it will be loaded directly from there. 
35 | sample_size: how many molecules to generate the distribution statistics from (both reference data and model) 36 | """ 37 | self.chemnet_model_filename = chemnet_model_filename 38 | self.sample_size = sample_size 39 | super().__init__(name='Frechet ChemNet Distance', number_samples=self.sample_size) 40 | 41 | self.reference_molecules = get_random_subset(training_set, self.sample_size, seed=42) 42 | 43 | def assess_model(self, model: DistributionMatchingGenerator) -> DistributionLearningBenchmarkResult: 44 | chemnet = self._load_chemnet() 45 | 46 | start_time = time.time() 47 | generated_molecules = sample_valid_molecules(model=model, number_molecules=self.number_samples) 48 | end_time = time.time() 49 | 50 | if len(generated_molecules) != self.number_samples: 51 | logger.warning('The model could not generate enough valid molecules.') 52 | 53 | mu_ref, cov_ref = self._calculate_distribution_statistics(chemnet, self.reference_molecules) 54 | mu, cov = self._calculate_distribution_statistics(chemnet, generated_molecules) 55 | 56 | FCD = fcd.calculate_frechet_distance(mu1=mu_ref, mu2=mu, 57 | sigma1=cov_ref, sigma2=cov) 58 | score = np.exp(-0.2 * FCD) 59 | 60 | metadata = { 61 | 'number_reference_molecules': len(self.reference_molecules), 62 | 'number_generated_molecules': len(generated_molecules), 63 | 'FCD': FCD 64 | } 65 | 66 | return DistributionLearningBenchmarkResult(benchmark_name=self.name, 67 | score=score, 68 | sampling_time=end_time - start_time, 69 | metadata=metadata) 70 | 71 | def _load_chemnet(self): 72 | """ 73 | Load the ChemNet model from the file specified in the init function. 74 | 75 | This file lives inside a package but to use it, it must always be an actual file. 76 | The safest way to proceed is therefore: 77 | 1. read the file with pkgutil 78 | 2. save it to a temporary file 79 | 3. 
load the model from the temporary file 80 | """ 81 | model_bytes = pkgutil.get_data('fcd', self.chemnet_model_filename) 82 | assert model_bytes is not None 83 | 84 | tmpdir = tempfile.gettempdir() 85 | model_path = os.path.join(tmpdir, self.chemnet_model_filename) 86 | 87 | with open(model_path, 'wb') as f: 88 | f.write(model_bytes) 89 | 90 | logger.info(f'Saved ChemNet model to \'{model_path}\'') 91 | 92 | return fcd.load_ref_model(model_path) 93 | 94 | def _calculate_distribution_statistics(self, model, molecules: List[str]): 95 | sample_std = fcd.canonical_smiles(molecules) 96 | gen_mol_act = fcd.get_predictions(model, sample_std) 97 | 98 | mu = np.mean(gen_mol_act, axis=0) 99 | cov = np.cov(gen_mol_act.T) 100 | return mu, cov 101 | -------------------------------------------------------------------------------- /guacamol/goal_directed_benchmark.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from typing import Any, Dict, List, Tuple, Optional 4 | 5 | import numpy as np 6 | 7 | from guacamol.goal_directed_score_contributions import ScoreContributionSpecification, compute_global_score 8 | from guacamol.scoring_function import ScoringFunction, ScoringFunctionWrapper 9 | from guacamol.goal_directed_generator import GoalDirectedGenerator 10 | from guacamol.utils.chemistry import canonicalize_list, remove_duplicates, calculate_internal_pairwise_similarities 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.addHandler(logging.NullHandler()) 14 | 15 | 16 | class GoalDirectedBenchmarkResult: 17 | """ 18 | Contains the results of a goal-directed benchmark. 
19 | """ 20 | 21 | def __init__(self, benchmark_name: str, score: float, optimized_molecules: List[Tuple[str, float]], 22 | execution_time: float, number_scoring_function_calls: int, metadata: Dict[str, Any]) -> None: 23 | """ 24 | Args: 25 | benchmark_name: name of the goal-directed benchmark 26 | score: benchmark score 27 | optimized_molecules: generated molecules, given as a list of (SMILES string, molecule score) tuples 28 | execution_time: execution time for the benchmark in seconds 29 | number_scoring_function_calls: number of calls to the scoring function 30 | metadata: benchmark-specific information 31 | """ 32 | self.benchmark_name = benchmark_name 33 | self.score = score 34 | self.optimized_molecules = optimized_molecules 35 | self.execution_time = execution_time 36 | self.number_scoring_function_calls = number_scoring_function_calls 37 | self.metadata = metadata 38 | 39 | 40 | class GoalDirectedBenchmark: 41 | """ 42 | This class assesses how well a model is able to generate molecules satisfying a given objective. 43 | """ 44 | 45 | def __init__(self, name: str, objective: ScoringFunction, 46 | contribution_specification: ScoreContributionSpecification, 47 | starting_population: Optional[List[str]] = None) -> None: 48 | """ 49 | Args: 50 | name: Benchmark name 51 | objective: Objective for the goal-directed optimization 52 | contribution_specification: Specifies how to calculate the global benchmark score 53 | """ 54 | self.name = name 55 | self.objective = objective 56 | self.wrapped_objective = ScoringFunctionWrapper(scoring_function=objective) 57 | self.contribution_specification = contribution_specification 58 | self.starting_population = starting_population 59 | 60 | def assess_model(self, model: GoalDirectedGenerator) -> GoalDirectedBenchmarkResult: 61 | """ 62 | Assess the given model by asking it to generate molecules optimizing a scoring function. 
63 | The number of molecules to generate is determined automatically from the score contribution specification. 64 | 65 | Args: 66 | model: model to assess 67 | """ 68 | number_molecules_to_generate = max(self.contribution_specification.top_counts) 69 | start_time = time.time() 70 | molecules = model.generate_optimized_molecules(scoring_function=self.wrapped_objective, 71 | number_molecules=number_molecules_to_generate, 72 | starting_population=self.starting_population 73 | ) 74 | end_time = time.time() 75 | 76 | canonicalized_molecules = canonicalize_list(molecules, include_stereocenters=False) 77 | unique_molecules = remove_duplicates(canonicalized_molecules) 78 | scores = self.objective.score_list(unique_molecules) 79 | 80 | if len(unique_molecules) != number_molecules_to_generate: 81 | number_missing = number_molecules_to_generate - len(unique_molecules) 82 | logger.warning(f'An incorrect number of distinct molecules was generated: ' 83 | f'{len(unique_molecules)} instead of {number_molecules_to_generate}. 
' 84 | f'Padding scores with {number_missing} zeros...') 85 | scores.extend([0.0] * number_missing) 86 | 87 | global_score, top_x_dict = compute_global_score(self.contribution_specification, scores) 88 | 89 | scored_molecules = zip(unique_molecules, scores) 90 | sorted_scored_molecules = sorted(scored_molecules, key=lambda x: (x[1], x[0]), reverse=True) 91 | 92 | internal_similarities = calculate_internal_pairwise_similarities(unique_molecules) 93 | 94 | # accumulate internal_similarities in metadata 95 | int_simi_histogram = np.histogram(internal_similarities, bins=10, range=(0, 1), density=True) 96 | 97 | metadata: Dict[str, Any] = {} 98 | metadata.update(top_x_dict) 99 | metadata['internal_similarity_max'] = internal_similarities.max() 100 | metadata['internal_similarity_mean'] = internal_similarities.mean() 101 | metadata["internal_similarity_histogram_density"] = int_simi_histogram[0].tolist(), 102 | metadata["internal_similarity_histogram_bins"] = int_simi_histogram[1].tolist(), 103 | 104 | return GoalDirectedBenchmarkResult(benchmark_name=self.name, 105 | score=global_score, 106 | optimized_molecules=sorted_scored_molecules, 107 | execution_time=end_time - start_time, 108 | number_scoring_function_calls=self.wrapped_objective.evaluations, 109 | metadata=metadata) 110 | -------------------------------------------------------------------------------- /guacamol/goal_directed_generator.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import List, Optional 3 | 4 | from guacamol.scoring_function import ScoringFunction 5 | 6 | 7 | class GoalDirectedGenerator(metaclass=ABCMeta): 8 | """ 9 | Interface for goal-directed molecule generators. 
10 | """ 11 | 12 | @abstractmethod 13 | def generate_optimized_molecules(self, scoring_function: ScoringFunction, number_molecules: int, 14 | starting_population: Optional[List[str]] = None) -> List[str]: 15 | """ 16 | Given an objective function, generate molecules that score as high as possible. 17 | 18 | Args: 19 | scoring_function: scoring function 20 | number_molecules: number of molecules to generate 21 | starting_population: molecules to start the optimization from (optional) 22 | 23 | Returns: 24 | A list of SMILES strings for the generated molecules. 25 | """ 26 | -------------------------------------------------------------------------------- /guacamol/goal_directed_score_contributions.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Dict 2 | 3 | 4 | class ScoreContributionSpecification: 5 | """ 6 | Specifies how to calculate the score of a goal-directed benchmark. 7 | 8 | The global score will be a weighted average of top-x scores. 9 | This class specifies which top-x to consider and what the corresponding weights are. 
10 | """ 11 | 12 | def __init__(self, contributions: List[Tuple[int, float]]) -> None: 13 | """ 14 | Args: 15 | contributions: List of tuples (top_count, weight) for the score contributions 16 | """ 17 | self.contributions = contributions 18 | 19 | @property 20 | def top_counts(self) -> List[int]: 21 | return [x[0] for x in self.contributions] 22 | 23 | @property 24 | def weights(self) -> List[float]: 25 | return [x[1] for x in self.contributions] 26 | 27 | 28 | def uniform_specification(*top_counts: int) -> ScoreContributionSpecification: 29 | """ 30 | Creates an instance of ScoreContributionSpecification where all the top-x contributions have equal weight 31 | 32 | Args: 33 | top_counts: list of values, where each value x will correspond to the top-x contribution 34 | """ 35 | contributions = [(x, 1.0) for x in top_counts] 36 | return ScoreContributionSpecification(contributions=contributions) 37 | 38 | 39 | def compute_global_score(contribution_specification: ScoreContributionSpecification, 40 | scores: List[float]) -> Tuple[float, Dict[str, float]]: 41 | """ 42 | Computes the global score according to the contribution specification. 
43 | 44 | Args: 45 | contribution_specification: Score contribution specification 46 | scores: List of all scores - list must be long enough for all top_counts in contribution_specification 47 | 48 | Returns: 49 | Tuple with the global score and a dict with the considered top-x scores 50 | """ 51 | sorted_scores = sorted(scores, reverse=True) 52 | 53 | global_score = 0.0 54 | top_x_dict = {} 55 | 56 | for top_count, weight in contribution_specification.contributions: 57 | score = sum(sorted_scores[:top_count]) / top_count 58 | top_x_dict[f'top_{top_count}'] = score 59 | global_score += score * weight 60 | 61 | global_score /= sum(contribution_specification.weights) 62 | 63 | return global_score, top_x_dict 64 | -------------------------------------------------------------------------------- /guacamol/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenevolentAI/guacamol/60ebe1f6a396f16e08b834dce448e9343d259feb/guacamol/py.typed -------------------------------------------------------------------------------- /guacamol/score_modifier.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import partial 3 | from typing import List 4 | 5 | import numpy as np 6 | 7 | 8 | class ScoreModifier: 9 | """ 10 | Interface for score modifiers. 11 | """ 12 | 13 | @abstractmethod 14 | def __call__(self, x): 15 | """ 16 | Apply the modifier on x. 17 | 18 | Args: 19 | x: float or np.array to modify 20 | 21 | Returns: 22 | float or np.array (depending on the type of x) after application of the distance function. 
23 | """ 24 | 25 | 26 | class ChainedModifier(ScoreModifier): 27 | """ 28 | Calls several modifiers one after the other, for instance: 29 | score = modifier3(modifier2(modifier1(raw_score))) 30 | """ 31 | 32 | def __init__(self, modifiers: List[ScoreModifier]) -> None: 33 | """ 34 | Args: 35 | modifiers: modifiers to call in sequence. 36 | The modifier applied last (and delivering the final score) is the last one in the list. 37 | """ 38 | self.modifiers = modifiers 39 | 40 | def __call__(self, x): 41 | score = x 42 | for modifier in self.modifiers: 43 | score = modifier(score) 44 | return score 45 | 46 | 47 | class LinearModifier(ScoreModifier): 48 | """ 49 | Score modifier that multiplies the score by a scalar (default: 1, i.e. do nothing). 50 | """ 51 | 52 | def __init__(self, slope=1.0): 53 | self.slope = slope 54 | 55 | def __call__(self, x): 56 | return self.slope * x 57 | 58 | 59 | class SquaredModifier(ScoreModifier): 60 | """ 61 | Score modifier that has a maximum at a given target value, and decreases 62 | quadratically with increasing distance from the target value. 63 | """ 64 | 65 | def __init__(self, target_value: float, coefficient=1.0) -> None: 66 | self.target_value = target_value 67 | self.coefficient = coefficient 68 | 69 | def __call__(self, x): 70 | return 1.0 - self.coefficient * np.square(self.target_value - x) 71 | 72 | 73 | class AbsoluteScoreModifier(ScoreModifier): 74 | """ 75 | Score modifier that has a maximum at a given target value, and decreases 76 | linearly with increasing distance from the target value. 77 | """ 78 | 79 | def __init__(self, target_value: float) -> None: 80 | self.target_value = target_value 81 | 82 | def __call__(self, x): 83 | return 1. - np.abs(self.target_value - x) 84 | 85 | 86 | class GaussianModifier(ScoreModifier): 87 | """ 88 | Score modifier that reproduces a Gaussian bell shape. 
89 | """ 90 | 91 | def __init__(self, mu: float, sigma: float) -> None: 92 | self.mu = mu 93 | self.sigma = sigma 94 | 95 | def __call__(self, x): 96 | return np.exp(-0.5 * np.power((x - self.mu) / self.sigma, 2.)) 97 | 98 | 99 | class MinMaxGaussianModifier(ScoreModifier): 100 | """ 101 | Score modifier that reproduces a half Gaussian bell shape. 102 | For minimize==True, the function is 1.0 for x <= mu and decreases to zero for x > mu. 103 | For minimize==False, the function is 1.0 for x >= mu and decreases to zero for x < mu. 104 | """ 105 | 106 | def __init__(self, mu: float, sigma: float, minimize=False) -> None: 107 | self.mu = mu 108 | self.sigma = sigma 109 | self.minimize = minimize 110 | self._full_gaussian = GaussianModifier(mu=mu, sigma=sigma) 111 | 112 | def __call__(self, x): 113 | if self.minimize: 114 | mod_x = np.maximum(x, self.mu) 115 | else: 116 | mod_x = np.minimum(x, self.mu) 117 | return self._full_gaussian(mod_x) 118 | 119 | 120 | MinGaussianModifier = partial(MinMaxGaussianModifier, minimize=True) 121 | MaxGaussianModifier = partial(MinMaxGaussianModifier, minimize=False) 122 | 123 | 124 | class ClippedScoreModifier(ScoreModifier): 125 | r""" 126 | Clips a score between specified low and high scores, and does a linear interpolation in between. 127 | 128 | The function looks like this: 129 | 130 | upper_x < lower_x lower_x < upper_x 131 | __________ ____________ 132 | \ / 133 | \ / 134 | \__________ _________/ 135 | 136 | This class works as follows: 137 | First the input is mapped onto a linear interpolation between both specified points. 138 | Then the generated values are clipped between low and high scores. 
139 | """ 140 | 141 | def __init__(self, upper_x: float, lower_x=0.0, high_score=1.0, low_score=0.0) -> None: 142 | """ 143 | Args: 144 | upper_x: x-value from which (or until which if smaller than lower_x) the score is maximal 145 | lower_x: x-value until which (or from which if larger than upper_x) the score is minimal 146 | high_score: maximal score to clip to 147 | low_score: minimal score to clip to 148 | """ 149 | assert low_score < high_score 150 | 151 | self.upper_x = upper_x 152 | self.lower_x = lower_x 153 | self.high_score = high_score 154 | self.low_score = low_score 155 | 156 | self.slope = (high_score - low_score) / (upper_x - lower_x) 157 | self.intercept = high_score - self.slope * upper_x 158 | 159 | def __call__(self, x): 160 | y = self.slope * x + self.intercept 161 | return np.clip(y, self.low_score, self.high_score) 162 | 163 | 164 | class SmoothClippedScoreModifier(ScoreModifier): 165 | """ 166 | Smooth variant of ClippedScoreModifier. 167 | 168 | Implemented as a logistic function that has the same steepness as ClippedScoreModifier in the 169 | center of the logistic function. 
170 | """ 171 | 172 | def __init__(self, upper_x: float, lower_x=0.0, high_score=1.0, low_score=0.0) -> None: 173 | """ 174 | Args: 175 | upper_x: x-value from which (or until which if smaller than lower_x) the score approaches high_score 176 | lower_x: x-value until which (or from which if larger than upper_x) the score approaches low_score 177 | high_score: maximal score (reached at +/- infinity) 178 | low_score: minimal score (reached at -/+ infinity) 179 | """ 180 | assert low_score < high_score 181 | 182 | self.upper_x = upper_x 183 | self.lower_x = lower_x 184 | self.high_score = high_score 185 | self.low_score = low_score 186 | 187 | # Slope of a standard logistic function in the middle is 0.25 -> rescale k accordingly 188 | self.k = 4.0 / (upper_x - lower_x) 189 | self.middle_x = (upper_x + lower_x) / 2 190 | self.L = high_score - low_score 191 | 192 | def __call__(self, x): 193 | return self.low_score + self.L / (1 + np.exp(-self.k * (x - self.middle_x))) 194 | 195 | 196 | class ThresholdedLinearModifier(ScoreModifier): 197 | """ 198 | Returns a value of min(input, threshold)/threshold. 
199 | """ 200 | 201 | def __init__(self, threshold: float) -> None: 202 | self.threshold = threshold 203 | 204 | def __call__(self, x): 205 | return np.minimum(x, self.threshold) / self.threshold 206 | -------------------------------------------------------------------------------- /guacamol/scoring_function.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import logging 3 | from typing import List, Optional 4 | 5 | import numpy as np 6 | from rdkit import Chem 7 | 8 | from guacamol.utils.chemistry import smiles_to_rdkit_mol 9 | from guacamol.score_modifier import ScoreModifier, LinearModifier 10 | from guacamol.utils.math import geometric_mean 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.addHandler(logging.NullHandler()) 14 | 15 | 16 | class InvalidMolecule(Exception): 17 | pass 18 | 19 | 20 | class ScoringFunction: 21 | """ 22 | Base class for an objective function. 23 | 24 | In general, do not inherit directly from this class. Prefer `MoleculewiseScoringFunction` or `BatchScoringFunction`. 25 | """ 26 | 27 | def __init__(self, score_modifier: ScoreModifier = None) -> None: 28 | """ 29 | Args: 30 | score_modifier: Modifier to apply to the score. 
If None, will be LinearModifier() 31 | """ 32 | self.score_modifier = score_modifier 33 | self.corrupt_score = -1.0 34 | 35 | @property 36 | def score_modifier(self): 37 | return self._score_modifier 38 | 39 | @score_modifier.setter 40 | def score_modifier(self, modifier: Optional[ScoreModifier]): 41 | self._score_modifier = LinearModifier() if modifier is None else modifier 42 | 43 | def modify_score(self, raw_score: float) -> float: 44 | return self._score_modifier(raw_score) 45 | 46 | @abstractmethod 47 | def score(self, smiles: str) -> float: 48 | """ 49 | Score a single molecule as smiles 50 | """ 51 | raise NotImplementedError 52 | 53 | @abstractmethod 54 | def score_list(self, smiles_list: List[str]) -> List[float]: 55 | """ 56 | Score a list of smiles. 57 | 58 | Args: 59 | smiles_list: list of smiles [smiles1, smiles2,...] 60 | 61 | Returns: a list of scores 62 | 63 | the order of the input smiles is matched in the output. 64 | 65 | """ 66 | raise NotImplementedError 67 | 68 | 69 | class MoleculewiseScoringFunction(ScoringFunction): 70 | """ 71 | Objective function that is implemented by calculating the score molecule after molecule. 72 | Rather use `BatchScoringFunction` than this if your objective function can process a batch of molecules 73 | more efficiently than by trivially parallelizing the `score` function. 74 | 75 | Derived classes must only implement the `raw_score` function. 76 | """ 77 | 78 | def __init__(self, score_modifier: ScoreModifier = None) -> None: 79 | """ 80 | Args: 81 | score_modifier: Modifier to apply to the score. 
If None, will be LinearModifier() 82 | """ 83 | super().__init__(score_modifier=score_modifier) 84 | 85 | def score(self, smiles: str) -> float: 86 | try: 87 | return self.modify_score(self.raw_score(smiles)) 88 | except InvalidMolecule: 89 | return self.corrupt_score 90 | except Exception: 91 | logger.warning(f'Unknown exception thrown during scoring of {smiles}') 92 | return self.corrupt_score 93 | 94 | def score_list(self, smiles_list: List[str]) -> List[float]: 95 | return [self.score(smiles) for smiles in smiles_list] 96 | 97 | @abstractmethod 98 | def raw_score(self, smiles: str) -> float: 99 | """ 100 | Get the objective score before application of the modifier. 101 | 102 | For invalid molecules, `InvalidMolecule` should be raised. 103 | For unsuccessful score calculations, `ScoreCannotBeCalculated` should be raised. 104 | """ 105 | raise NotImplementedError 106 | 107 | 108 | class BatchScoringFunction(ScoringFunction): 109 | """ 110 | Objective function that is implemented by calculating the scores of molecules in batches. 111 | Rather use `MoleculewiseScoringFunction` than this if processing a batch is not faster than 112 | trivially parallelizing the `score` function for the distinct molecules. 113 | 114 | Derived classes must only implement the `raw_score_list` function. 115 | """ 116 | 117 | def __init__(self, score_modifier: ScoreModifier = None) -> None: 118 | """ 119 | Args: 120 | score_modifier: Modifier to apply to the score. 
If None, will be LinearModifier() 121 | """ 122 | super().__init__(score_modifier=score_modifier) 123 | 124 | def score(self, smiles: str) -> float: 125 | return self.score_list([smiles])[0] 126 | 127 | def score_list(self, smiles_list: List[str]) -> List[float]: 128 | raw_scores = self.raw_score_list(smiles_list) 129 | 130 | scores = [self.corrupt_score if raw_score is None 131 | else self.modify_score(raw_score) 132 | for raw_score in raw_scores] 133 | 134 | return scores 135 | 136 | @abstractmethod 137 | def raw_score_list(self, smiles_list: List[str]) -> List[float]: 138 | """ 139 | Calculate the objective score before application of the modifier for a batch of molecules. 140 | 141 | Args: 142 | smiles_list: list of SMILES strings to process 143 | 144 | Returns: 145 | A list of scores. For unsuccessful calculations or invalid molecules, `None` should be given as a value for 146 | the corresponding molecule. 147 | """ 148 | raise NotImplementedError 149 | 150 | 151 | class ScoringFunctionBasedOnRdkitMol(MoleculewiseScoringFunction): 152 | """ 153 | Base class for scoring functions that calculate scores based on rdkit.Chem.Mol instances. 154 | 155 | Derived classes must implement the `score_mol` function. 156 | """ 157 | 158 | def raw_score(self, smiles: str) -> float: 159 | mol = smiles_to_rdkit_mol(smiles) 160 | 161 | if mol is None: 162 | raise InvalidMolecule 163 | 164 | return self.score_mol(mol) 165 | 166 | @abstractmethod 167 | def score_mol(self, mol: Chem.Mol) -> float: 168 | """ 169 | Calculate the molecule score based on a RDKit molecule 170 | 171 | Args: 172 | mol: RDKit molecule 173 | """ 174 | raise NotImplementedError 175 | 176 | 177 | class ArithmeticMeanScoringFunction(BatchScoringFunction): 178 | """ 179 | Scoring function that combines multiple scoring functions linearly. 
180 | """ 181 | 182 | def __init__(self, scoring_functions: List[ScoringFunction], weights=None) -> None: 183 | """ 184 | Args: 185 | scoring_functions: scoring functions to combine 186 | weights: weight for the corresponding scoring functions. If None, all will have the same weight. 187 | """ 188 | super().__init__() 189 | 190 | self.scoring_functions = scoring_functions 191 | number_scoring_functions = len(scoring_functions) 192 | 193 | self.weights = np.ones(number_scoring_functions) if weights is None else weights 194 | assert number_scoring_functions == len(self.weights) 195 | 196 | def raw_score_list(self, smiles_list: List[str]) -> List[float]: 197 | scores = [] 198 | 199 | for function, weight in zip(self.scoring_functions, self.weights): 200 | res = function.score_list(smiles_list) 201 | scores.append(weight * np.array(res)) 202 | 203 | scores = np.array(scores).sum(axis=0) / np.sum(self.weights) 204 | 205 | return list(scores) 206 | 207 | 208 | class GeometricMeanScoringFunction(MoleculewiseScoringFunction): 209 | """ 210 | Scoring function that combines multiple scoring functions multiplicatively. 211 | """ 212 | 213 | def __init__(self, scoring_functions: List[ScoringFunction]) -> None: 214 | """ 215 | Args: 216 | scoring_functions: scoring functions to combine 217 | """ 218 | super().__init__() 219 | 220 | self.scoring_functions = scoring_functions 221 | 222 | def raw_score(self, smiles: str) -> float: 223 | partial_scores = [f.score(smiles) for f in self.scoring_functions] 224 | if self.corrupt_score in partial_scores: 225 | return self.corrupt_score 226 | 227 | return geometric_mean(partial_scores) 228 | 229 | 230 | class ScoringFunctionWrapper(ScoringFunction): 231 | """ 232 | Wraps a scoring function to store the number of calls to it. 
233 | """ 234 | 235 | def __init__(self, scoring_function: ScoringFunction) -> None: 236 | super().__init__() 237 | self.scoring_function = scoring_function 238 | self.evaluations = 0 239 | 240 | def score(self, smiles): 241 | self._increment_evaluation_count(1) 242 | return self.scoring_function.score(smiles) 243 | 244 | def score_list(self, smiles_list): 245 | self._increment_evaluation_count(len(smiles_list)) 246 | return self.scoring_function.score_list(smiles_list) 247 | 248 | def _increment_evaluation_count(self, n: int): 249 | # Ideally, this should be protected by a lock in order to allow for multithreading. 250 | # However, adding a threading.Lock member variable makes the class non-pickle-able, which prevents any multithreading. 251 | # Therefore, in the current implementation there cannot be a guarantee that self.evaluations will be calculated correctly. 252 | self.evaluations += n 253 | -------------------------------------------------------------------------------- /guacamol/standard_benchmarks.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | 3 | from guacamol.common_scoring_functions import TanimotoScoringFunction, RdkitScoringFunction, CNS_MPO_ScoringFunction, \ 4 | IsomerScoringFunction, SMARTSScoringFunction 5 | from guacamol.distribution_learning_benchmark import DistributionLearningBenchmark, NoveltyBenchmark, KLDivBenchmark 6 | from guacamol.frechet_benchmark import FrechetBenchmark 7 | from guacamol.goal_directed_benchmark import GoalDirectedBenchmark 8 | from guacamol.goal_directed_score_contributions import uniform_specification 9 | from guacamol.score_modifier import MinGaussianModifier, MaxGaussianModifier, ClippedScoreModifier, GaussianModifier 10 | from guacamol.scoring_function import ArithmeticMeanScoringFunction, GeometricMeanScoringFunction, ScoringFunction 11 | from guacamol.utils.descriptors import num_rotatable_bonds, num_aromatic_rings, logP, qed, tpsa, bertz, 
mol_weight, \ 12 | AtomCounter, num_rings 13 | 14 | 15 | def isomers_c11h24(mean_function='geometric') -> GoalDirectedBenchmark: 16 | """ 17 | Benchmark to try and get all C11H24 molecules there are. 18 | There should be 159 if one ignores stereochemistry. 19 | 20 | Args: 21 | mean_function: 'arithmetic' or 'geometric' 22 | """ 23 | 24 | specification = uniform_specification(159) 25 | 26 | return GoalDirectedBenchmark(name='C11H24', 27 | objective=IsomerScoringFunction('C11H24', mean_function=mean_function), 28 | contribution_specification=specification) 29 | 30 | 31 | def isomers_c7h8n2o2(mean_function='geometric') -> GoalDirectedBenchmark: 32 | """ 33 | Benchmark to try and get 100 isomers for C7H8N2O2. 34 | 35 | Args: 36 | mean_function: 'arithmetic' or 'geometric' 37 | """ 38 | 39 | specification = uniform_specification(100) 40 | 41 | return GoalDirectedBenchmark(name='C7H8N2O2', 42 | objective=IsomerScoringFunction('C7H8N2O2', mean_function=mean_function), 43 | contribution_specification=specification) 44 | 45 | 46 | def isomers_c9h10n2o2pf2cl(mean_function='geometric', n_samples=250) -> GoalDirectedBenchmark: 47 | """ 48 | Benchmark to try and get 100 isomers for C9H10N2O2PF2Cl. 
49 | 50 | Args: 51 | mean_function: 'arithmetic' or 'geometric' 52 | """ 53 | 54 | specification = uniform_specification(n_samples) 55 | 56 | return GoalDirectedBenchmark(name='C9H10N2O2PF2Cl', 57 | objective=IsomerScoringFunction('C9H10N2O2PF2Cl', mean_function=mean_function), 58 | contribution_specification=specification) 59 | 60 | 61 | def hard_cobimetinib(max_logP=5.0) -> GoalDirectedBenchmark: 62 | smiles = 'OC1(CN(C1)C(=O)C1=C(NC2=C(F)C=C(I)C=C2)C(F)=C(F)C=C1)C1CCCCN1' 63 | 64 | modifier = ClippedScoreModifier(upper_x=0.7) 65 | os_tf = TanimotoScoringFunction(smiles, fp_type='FCFP4', score_modifier=modifier) 66 | os_ap = TanimotoScoringFunction(smiles, fp_type='ECFP6', 67 | score_modifier=MinGaussianModifier(mu=0.75, sigma=0.1)) 68 | 69 | rot_b = RdkitScoringFunction(descriptor=num_rotatable_bonds, 70 | score_modifier=MinGaussianModifier(mu=3, sigma=1)) 71 | 72 | rings = RdkitScoringFunction(descriptor=num_aromatic_rings, 73 | score_modifier=MaxGaussianModifier(mu=3, sigma=1)) 74 | 75 | t_cns = ArithmeticMeanScoringFunction([os_tf, os_ap, rot_b, rings, CNS_MPO_ScoringFunction(max_logP=max_logP)]) 76 | 77 | specification = uniform_specification(1, 10, 100) 78 | 79 | return GoalDirectedBenchmark(name='Cobimetinib MPO', 80 | objective=t_cns, 81 | contribution_specification=specification) 82 | 83 | 84 | def hard_osimertinib(mean_cls=GeometricMeanScoringFunction) -> GoalDirectedBenchmark: 85 | smiles = 'COc1cc(N(C)CCN(C)C)c(NC(=O)C=C)cc1Nc2nccc(n2)c3cn(C)c4ccccc34' 86 | 87 | modifier = ClippedScoreModifier(upper_x=0.8) 88 | similar_to_osimertinib = TanimotoScoringFunction(smiles, fp_type='FCFP4', score_modifier=modifier) 89 | 90 | but_not_too_similar = TanimotoScoringFunction(smiles, fp_type='ECFP6', 91 | score_modifier=MinGaussianModifier(mu=0.85, sigma=0.1)) 92 | 93 | tpsa_over_100 = RdkitScoringFunction(descriptor=tpsa, 94 | score_modifier=MaxGaussianModifier(mu=100, sigma=10)) 95 | 96 | logP_scoring = RdkitScoringFunction(descriptor=logP, 97 | 
score_modifier=MinGaussianModifier(mu=1, sigma=1)) 98 | 99 | make_osimertinib_great_again = mean_cls( 100 | [similar_to_osimertinib, but_not_too_similar, tpsa_over_100, logP_scoring]) 101 | 102 | specification = uniform_specification(1, 10, 100) 103 | 104 | return GoalDirectedBenchmark(name='Osimertinib MPO', 105 | objective=make_osimertinib_great_again, 106 | contribution_specification=specification) 107 | 108 | 109 | def hard_fexofenadine(mean_cls=GeometricMeanScoringFunction) -> GoalDirectedBenchmark: 110 | """ 111 | make fexofenadine less greasy 112 | :return: 113 | """ 114 | smiles = 'CC(C)(C(=O)O)c1ccc(cc1)C(O)CCCN2CCC(CC2)C(O)(c3ccccc3)c4ccccc4' 115 | 116 | modifier = ClippedScoreModifier(upper_x=0.8) 117 | similar_to_fexofenadine = TanimotoScoringFunction(smiles, fp_type='AP', score_modifier=modifier) 118 | 119 | tpsa_over_90 = RdkitScoringFunction(descriptor=tpsa, 120 | score_modifier=MaxGaussianModifier(mu=90, sigma=10)) 121 | 122 | logP_under_4 = RdkitScoringFunction(descriptor=logP, 123 | score_modifier=MinGaussianModifier(mu=4, sigma=1)) 124 | 125 | optimize_fexofenadine = mean_cls( 126 | [similar_to_fexofenadine, tpsa_over_90, logP_under_4]) 127 | 128 | specification = uniform_specification(1, 10, 100) 129 | 130 | return GoalDirectedBenchmark(name='Fexofenadine MPO', 131 | objective=optimize_fexofenadine, 132 | contribution_specification=specification) 133 | 134 | 135 | def start_pop_ranolazine() -> GoalDirectedBenchmark: 136 | ranolazine = 'COc1ccccc1OCC(O)CN2CCN(CC(=O)Nc3c(C)cccc3C)CC2' 137 | 138 | modifier = ClippedScoreModifier(upper_x=0.7) 139 | similar_to_ranolazine = TanimotoScoringFunction(ranolazine, fp_type='AP', score_modifier=modifier) 140 | 141 | logP_under_4 = RdkitScoringFunction(descriptor=logP, 142 | score_modifier=MaxGaussianModifier(mu=7, sigma=1)) 143 | 144 | aroma = RdkitScoringFunction(descriptor=num_aromatic_rings, 145 | score_modifier=MinGaussianModifier(mu=1, sigma=1)) 146 | 147 | fluorine = 
RdkitScoringFunction(descriptor=AtomCounter('F'), 148 | score_modifier=GaussianModifier(mu=1, sigma=1.0)) 149 | 150 | optimize_ranolazine = ArithmeticMeanScoringFunction([similar_to_ranolazine, logP_under_4, fluorine, aroma]) 151 | 152 | specification = uniform_specification(1, 10, 100) 153 | 154 | return GoalDirectedBenchmark(name='Ranolazine MPO', 155 | objective=optimize_ranolazine, 156 | contribution_specification=specification, 157 | starting_population=[ranolazine]) 158 | 159 | 160 | def weird_physchem() -> GoalDirectedBenchmark: 161 | min_bertz = RdkitScoringFunction(descriptor=bertz, 162 | score_modifier=MaxGaussianModifier(mu=1500, sigma=200)) 163 | 164 | mol_under_400 = RdkitScoringFunction(descriptor=mol_weight, 165 | score_modifier=MinGaussianModifier(mu=400, sigma=40)) 166 | 167 | aroma = RdkitScoringFunction(descriptor=num_aromatic_rings, 168 | score_modifier=MinGaussianModifier(mu=3, sigma=1)) 169 | 170 | fluorine = RdkitScoringFunction(descriptor=AtomCounter('F'), 171 | score_modifier=GaussianModifier(mu=6, sigma=1.0)) 172 | 173 | opt_weird = ArithmeticMeanScoringFunction( 174 | [min_bertz, mol_under_400, aroma, fluorine]) 175 | 176 | specification = uniform_specification(1, 10, 100) 177 | 178 | return GoalDirectedBenchmark(name='Physchem MPO', 179 | objective=opt_weird, 180 | contribution_specification=specification) 181 | 182 | 183 | def similarity_cns_mpo(smiles, molecule_name, max_logP=5.0) -> GoalDirectedBenchmark: 184 | benchmark_name = f'{molecule_name}' 185 | os_tf = TanimotoScoringFunction(smiles, fp_type='FCFP4') 186 | os_ap = TanimotoScoringFunction(smiles, fp_type='AP') 187 | anti_fp = TanimotoScoringFunction(smiles, fp_type='ECFP6', 188 | score_modifier=MinGaussianModifier(mu=0.70, sigma=0.1)) 189 | 190 | t_cns = ArithmeticMeanScoringFunction([os_tf, os_ap, anti_fp, CNS_MPO_ScoringFunction(max_logP=max_logP)]) 191 | 192 | specification = uniform_specification(1, 10, 100) 193 | 194 | return GoalDirectedBenchmark(name=benchmark_name, 195 
| objective=t_cns, 196 | contribution_specification=specification) 197 | 198 | 199 | def similarity(smiles: str, name: str, fp_type: str = 'ECFP4', threshold: float = 0.7, 200 | rediscovery: bool = False) -> GoalDirectedBenchmark: 201 | category = 'rediscovery' if rediscovery else 'similarity' 202 | benchmark_name = f'{name} {category}' 203 | 204 | modifier = ClippedScoreModifier(upper_x=threshold) 205 | scoring_function = TanimotoScoringFunction(target=smiles, fp_type=fp_type, score_modifier=modifier) 206 | if rediscovery: 207 | specification = uniform_specification(1) 208 | else: 209 | specification = uniform_specification(1, 10, 100) 210 | 211 | return GoalDirectedBenchmark(name=benchmark_name, 212 | objective=scoring_function, 213 | contribution_specification=specification) 214 | 215 | 216 | def logP_benchmark(target: float) -> GoalDirectedBenchmark: 217 | benchmark_name = f'logP (target: {target})' 218 | objective = RdkitScoringFunction(descriptor=logP, 219 | score_modifier=GaussianModifier(mu=target, sigma=1)) 220 | 221 | specification = uniform_specification(1, 10, 100) 222 | 223 | return GoalDirectedBenchmark(name=benchmark_name, 224 | objective=objective, 225 | contribution_specification=specification) 226 | 227 | 228 | def tpsa_benchmark(target: float) -> GoalDirectedBenchmark: 229 | benchmark_name = f'TPSA (target: {target})' 230 | objective = RdkitScoringFunction(descriptor=tpsa, 231 | score_modifier=GaussianModifier(mu=target, sigma=20.0)) 232 | 233 | specification = uniform_specification(1, 10, 100) 234 | 235 | return GoalDirectedBenchmark(name=benchmark_name, 236 | objective=objective, 237 | contribution_specification=specification) 238 | 239 | 240 | def cns_mpo(max_logP=5.0) -> GoalDirectedBenchmark: 241 | specification = uniform_specification(1, 10, 100) 242 | return GoalDirectedBenchmark(name='CNS MPO', objective=CNS_MPO_ScoringFunction(max_logP=max_logP), 243 | contribution_specification=specification) 244 | 245 | 246 | def qed_benchmark() -> 
GoalDirectedBenchmark: 247 | specification = uniform_specification(1, 10, 100) 248 | return GoalDirectedBenchmark(name='QED', 249 | objective=RdkitScoringFunction(descriptor=qed), 250 | contribution_specification=specification) 251 | 252 | 253 | def median_camphor_menthol(mean_cls=GeometricMeanScoringFunction) -> GoalDirectedBenchmark: 254 | t_camphor = TanimotoScoringFunction('CC1(C)C2CCC1(C)C(=O)C2', fp_type='ECFP4') 255 | t_menthol = TanimotoScoringFunction('CC(C)C1CCC(C)CC1O', fp_type='ECFP4') 256 | median = mean_cls([t_menthol, t_camphor]) 257 | 258 | specification = uniform_specification(1, 10, 100) 259 | 260 | return GoalDirectedBenchmark(name='Median molecules 1', 261 | objective=median, 262 | contribution_specification=specification) 263 | 264 | 265 | def novelty_benchmark(training_set_file: str, number_samples: int) -> DistributionLearningBenchmark: 266 | smiles_list = [s.strip() for s in open(training_set_file).readlines()] 267 | return NoveltyBenchmark(number_samples=number_samples, training_set=smiles_list) 268 | 269 | 270 | def kldiv_benchmark(training_set_file: str, number_samples: int) -> DistributionLearningBenchmark: 271 | smiles_list = [s.strip() for s in open(training_set_file).readlines()] 272 | return KLDivBenchmark(number_samples=number_samples, training_set=smiles_list) 273 | 274 | 275 | def frechet_benchmark(training_set_file: str, number_samples: int) -> DistributionLearningBenchmark: 276 | smiles_list = [s.strip() for s in open(training_set_file).readlines()] 277 | return FrechetBenchmark(training_set=smiles_list, sample_size=number_samples) 278 | 279 | 280 | def perindopril_rings() -> GoalDirectedBenchmark: 281 | # perindopril with two aromatic rings 282 | perindopril = TanimotoScoringFunction('O=C(OCC)C(NC(C(=O)N1C(C(=O)O)CC2CCCCC12)C)CCC', 283 | fp_type='ECFP4') 284 | arom_rings = RdkitScoringFunction(descriptor=num_aromatic_rings, 285 | score_modifier=GaussianModifier(mu=2, sigma=0.5)) 286 | 287 | specification = 
uniform_specification(1, 10, 100) 288 | 289 | return GoalDirectedBenchmark(name='Perindopril MPO', 290 | objective=GeometricMeanScoringFunction([perindopril, arom_rings]), 291 | contribution_specification=specification) 292 | 293 | 294 | def amlodipine_rings() -> GoalDirectedBenchmark: 295 | # amlodipine with 3 rings 296 | amlodipine = TanimotoScoringFunction(r'Clc1ccccc1C2C(=C(/N/C(=C2/C(=O)OCC)COCCN)C)\C(=O)OC', fp_type='ECFP4') 297 | rings = RdkitScoringFunction(descriptor=num_rings, 298 | score_modifier=GaussianModifier(mu=3, sigma=0.5)) 299 | 300 | specification = uniform_specification(1, 10, 100) 301 | 302 | return GoalDirectedBenchmark(name='Amlodipine MPO', 303 | objective=GeometricMeanScoringFunction([amlodipine, rings]), 304 | contribution_specification=specification) 305 | 306 | 307 | def sitagliptin_replacement() -> GoalDirectedBenchmark: 308 | # Find a molecule dissimilar to sitagliptin, but with the same properties 309 | smiles = 'Fc1cc(c(F)cc1F)CC(N)CC(=O)N3Cc2nnc(n2CC3)C(F)(F)F' 310 | sitagliptin = Chem.MolFromSmiles(smiles) 311 | target_logp = logP(sitagliptin) 312 | target_tpsa = tpsa(sitagliptin) 313 | 314 | similarity = TanimotoScoringFunction(smiles, fp_type='ECFP4', 315 | score_modifier=GaussianModifier(mu=0, sigma=0.1)) 316 | lp = RdkitScoringFunction(descriptor=logP, 317 | score_modifier=GaussianModifier(mu=target_logp, sigma=0.2)) 318 | tp = RdkitScoringFunction(descriptor=tpsa, 319 | score_modifier=GaussianModifier(mu=target_tpsa, sigma=5)) 320 | isomers = IsomerScoringFunction('C16H15F6N5O') 321 | 322 | specification = uniform_specification(1, 10, 100) 323 | 324 | return GoalDirectedBenchmark(name='Sitagliptin MPO', 325 | objective=GeometricMeanScoringFunction([similarity, lp, tp, isomers]), 326 | contribution_specification=specification) 327 | 328 | 329 | def zaleplon_with_other_formula() -> GoalDirectedBenchmark: 330 | # zaleplon_with_other_formula with other formula 331 | zaleplon = 
TanimotoScoringFunction('O=C(C)N(CC)C1=CC=CC(C2=CC=NC3=C(C=NN23)C#N)=C1', 332 | fp_type='ECFP4') 333 | formula = IsomerScoringFunction('C19H17N3O2') 334 | 335 | specification = uniform_specification(1, 10, 100) 336 | 337 | return GoalDirectedBenchmark(name='Zaleplon MPO', 338 | objective=GeometricMeanScoringFunction([zaleplon, formula]), 339 | contribution_specification=specification) 340 | 341 | 342 | def smarts_with_other_target(smarts: str, other_molecule: str) -> ScoringFunction: 343 | smarts_scoring_function = SMARTSScoringFunction(target=smarts) 344 | other_mol = Chem.MolFromSmiles(other_molecule) 345 | target_logp = logP(other_mol) 346 | target_tpsa = tpsa(other_mol) 347 | target_bertz = bertz(other_mol) 348 | 349 | lp = RdkitScoringFunction(descriptor=logP, 350 | score_modifier=GaussianModifier(mu=target_logp, sigma=0.2)) 351 | tp = RdkitScoringFunction(descriptor=tpsa, 352 | score_modifier=GaussianModifier(mu=target_tpsa, sigma=5)) 353 | bz = RdkitScoringFunction(descriptor=bertz, 354 | score_modifier=GaussianModifier(mu=target_bertz, sigma=30)) 355 | 356 | return GeometricMeanScoringFunction([smarts_scoring_function, lp, tp, bz]) 357 | 358 | 359 | def valsartan_smarts() -> GoalDirectedBenchmark: 360 | # valsartan smarts with sitagliptin properties 361 | sitagliptin_smiles = 'NC(CC(=O)N1CCn2c(nnc2C(F)(F)F)C1)Cc1cc(F)c(F)cc1F' 362 | valsartan_smarts = 'CN(C=O)Cc1ccc(c2ccccc2)cc1' 363 | specification = uniform_specification(1, 10, 100) 364 | return GoalDirectedBenchmark(name='Valsartan SMARTS', 365 | objective=smarts_with_other_target(valsartan_smarts, sitagliptin_smiles), 366 | contribution_specification=specification) 367 | 368 | 369 | def median_tadalafil_sildenafil() -> GoalDirectedBenchmark: 370 | # median mol between tadalafil and sildenafil 371 | m1 = TanimotoScoringFunction('O=C1N(CC(N2C1CC3=C(C2C4=CC5=C(OCO5)C=C4)NC6=C3C=CC=C6)=O)C', fp_type='ECFP6') 372 | m2 = 
TanimotoScoringFunction('CCCC1=NN(C2=C1N=C(NC2=O)C3=C(C=CC(=C3)S(=O)(=O)N4CCN(CC4)C)OCC)C', fp_type='ECFP6') 373 | median = GeometricMeanScoringFunction([m1, m2]) 374 | 375 | specification = uniform_specification(1, 10, 100) 376 | 377 | return GoalDirectedBenchmark(name='Median molecules 2', 378 | objective=median, 379 | contribution_specification=specification) 380 | 381 | 382 | def pioglitazone_mpo() -> GoalDirectedBenchmark: 383 | # pioglitazone with same mw but less rotatable bonds 384 | smiles = 'O=C1NC(=O)SC1Cc3ccc(OCCc2ncc(cc2)CC)cc3' 385 | pioglitazone = Chem.MolFromSmiles(smiles) 386 | target_molw = mol_weight(pioglitazone) 387 | 388 | similarity = TanimotoScoringFunction(smiles, fp_type='ECFP4', 389 | score_modifier=GaussianModifier(mu=0, sigma=0.1)) 390 | mw = RdkitScoringFunction(descriptor=mol_weight, 391 | score_modifier=GaussianModifier(mu=target_molw, sigma=10)) 392 | rb = RdkitScoringFunction(descriptor=num_rotatable_bonds, 393 | score_modifier=GaussianModifier(mu=2, sigma=0.5)) 394 | 395 | specification = uniform_specification(1, 10, 100) 396 | 397 | return GoalDirectedBenchmark(name='Pioglitazone MPO', 398 | objective=GeometricMeanScoringFunction([similarity, mw, rb]), 399 | contribution_specification=specification) 400 | 401 | 402 | def decoration_hop() -> GoalDirectedBenchmark: 403 | smiles = 'CCCOc1cc2ncnc(Nc3ccc4ncsc4c3)c2cc1S(=O)(=O)C(C)(C)C' 404 | 405 | pharmacophor_sim = TanimotoScoringFunction(smiles, fp_type='PHCO', 406 | score_modifier=ClippedScoreModifier(upper_x=0.85)) 407 | # change deco 408 | deco1 = SMARTSScoringFunction('CS([#6])(=O)=O', inverse=True) 409 | deco2 = SMARTSScoringFunction('[#7]-c1ccc2ncsc2c1', inverse=True) 410 | 411 | # keep scaffold 412 | scaffold = SMARTSScoringFunction('[#7]-c1n[c;h1]nc2[c;h1]c(-[#8])[c;h0][c;h1]c12', inverse=False) 413 | 414 | deco_hop1_fn = ArithmeticMeanScoringFunction([pharmacophor_sim, deco1, deco2, scaffold]) 415 | 416 | specification = uniform_specification(1, 10, 100) 417 | 418 | return 
GoalDirectedBenchmark(name='Deco Hop', 419 | objective=deco_hop1_fn, 420 | contribution_specification=specification) 421 | 422 | 423 | def scaffold_hop() -> GoalDirectedBenchmark: 424 | """ 425 | Keep the decoration, and similarity to start point, but change the scaffold. 426 | """ 427 | 428 | smiles = 'CCCOc1cc2ncnc(Nc3ccc4ncsc4c3)c2cc1S(=O)(=O)C(C)(C)C' 429 | 430 | pharmacophor_sim = TanimotoScoringFunction(smiles, fp_type='PHCO', 431 | score_modifier=ClippedScoreModifier(upper_x=0.75)) 432 | 433 | deco = SMARTSScoringFunction('[#6]-[#6]-[#6]-[#8]-[#6]~[#6]~[#6]~[#6]~[#6]-[#7]-c1ccc2ncsc2c1', inverse=False) 434 | 435 | # anti scaffold 436 | scaffold = SMARTSScoringFunction('[#7]-c1n[c;h1]nc2[c;h1]c(-[#8])[c;h0][c;h1]c12', inverse=True) 437 | 438 | scaffold_hop_obj = ArithmeticMeanScoringFunction([pharmacophor_sim, deco, scaffold]) 439 | 440 | specification = uniform_specification(1, 10, 100) 441 | 442 | return GoalDirectedBenchmark(name='Scaffold Hop', 443 | objective=scaffold_hop_obj, 444 | contribution_specification=specification) 445 | 446 | 447 | def ranolazine_mpo() -> GoalDirectedBenchmark: 448 | """ 449 | Make start_pop_ranolazine more polar and add a fluorine 450 | """ 451 | ranolazine = 'COc1ccccc1OCC(O)CN2CCN(CC(=O)Nc3c(C)cccc3C)CC2' 452 | 453 | modifier = ClippedScoreModifier(upper_x=0.7) 454 | similar_to_ranolazine = TanimotoScoringFunction(ranolazine, fp_type='AP', score_modifier=modifier) 455 | 456 | logP_under_4 = RdkitScoringFunction(descriptor=logP, score_modifier=MaxGaussianModifier(mu=7, sigma=1)) 457 | 458 | tpsa_f = RdkitScoringFunction(descriptor=tpsa, score_modifier=MaxGaussianModifier(mu=95, sigma=20)) 459 | 460 | fluorine = RdkitScoringFunction(descriptor=AtomCounter('F'), score_modifier=GaussianModifier(mu=1, sigma=1.0)) 461 | 462 | optimize_ranolazine = GeometricMeanScoringFunction([similar_to_ranolazine, logP_under_4, fluorine, tpsa_f]) 463 | 464 | specification = uniform_specification(1, 10, 100) 465 | 466 | return 
GoalDirectedBenchmark(name='Ranolazine MPO', 467 | objective=optimize_ranolazine, 468 | contribution_specification=specification, 469 | starting_population=[ranolazine]) 470 | -------------------------------------------------------------------------------- /guacamol/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenevolentAI/guacamol/60ebe1f6a396f16e08b834dce448e9343d259feb/guacamol/utils/__init__.py -------------------------------------------------------------------------------- /guacamol/utils/chemistry.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from typing import Optional, List, Iterable, Collection, Tuple 4 | 5 | import numpy as np 6 | from rdkit import Chem 7 | from rdkit import RDLogger, DataStructs 8 | from rdkit.Chem import AllChem 9 | from rdkit.ML.Descriptors import MoleculeDescriptors 10 | from scipy import histogram 11 | from scipy.stats import entropy, gaussian_kde 12 | 13 | from guacamol.utils.data import remove_duplicates 14 | 15 | # Mute RDKit logger 16 | RDLogger.logger().setLevel(RDLogger.CRITICAL) 17 | 18 | logger = logging.getLogger(__name__) 19 | logger.addHandler(logging.NullHandler()) 20 | 21 | 22 | def is_valid(smiles: str): 23 | """ 24 | Verifies whether a SMILES string corresponds to a valid molecule. 25 | 26 | Args: 27 | smiles: SMILES string 28 | 29 | Returns: 30 | True if the SMILES string corresponds to a valid, non-empty molecule. 31 | """ 32 | 33 | mol = Chem.MolFromSmiles(smiles) 34 | 35 | return smiles != '' and mol is not None and mol.GetNumAtoms() > 0 36 | 37 | 38 | def canonicalize(smiles: str, include_stereocenters=True) -> Optional[str]: 39 | """ 40 | Canonicalize the SMILES string with RDKit.
41 | 42 | The algorithm is detailed under https://pubs.acs.org/doi/full/10.1021/acs.jcim.5b00543 43 | 44 | Args: 45 | smiles: SMILES string to canonicalize 46 | include_stereocenters: whether to keep the stereochemical information in the canonical SMILES string 47 | 48 | Returns: 49 | Canonicalized SMILES string, None if the molecule is invalid. 50 | """ 51 | 52 | mol = Chem.MolFromSmiles(smiles) 53 | 54 | if mol is not None: 55 | return Chem.MolToSmiles(mol, isomericSmiles=include_stereocenters) 56 | else: 57 | return None 58 | 59 | 60 | def canonicalize_list(smiles_list: Iterable[str], include_stereocenters=True) -> List[str]: 61 | """ 62 | Canonicalize a list of smiles. Filters out repetitions and removes corrupted molecules. 63 | 64 | Args: 65 | smiles_list: molecules as SMILES strings 66 | include_stereocenters: whether to keep the stereochemical information in the canonical SMILES strings 67 | 68 | Returns: 69 | The canonicalized and filtered input smiles. 70 | """ 71 | 72 | canonicalized_smiles = [canonicalize(smiles, include_stereocenters) for smiles in smiles_list] 73 | 74 | # Remove None elements 75 | canonicalized_smiles = [s for s in canonicalized_smiles if s is not None] 76 | 77 | return remove_duplicates(canonicalized_smiles) 78 | 79 | 80 | def smiles_to_rdkit_mol(smiles: str) -> Optional[Chem.Mol]: 81 | """ 82 | Converts a SMILES string to a RDKit molecule. 
83 | 84 | Args: 85 | smiles: SMILES string of the molecule 86 | 87 | Returns: 88 | RDKit Mol, None if the SMILES string is invalid 89 | """ 90 | mol = Chem.MolFromSmiles(smiles) 91 | 92 | # Sanitization check (detects invalid valence) 93 | if mol is not None: 94 | try: 95 | Chem.SanitizeMol(mol) 96 | except ValueError: 97 | return None 98 | 99 | return mol 100 | 101 | 102 | def split_charged_mol(smiles: str) -> str: 103 | if smiles.count('.') > 0: 104 | largest = '' 105 | largest_len = -1 106 | split = smiles.split('.') 107 | for i in split: 108 | if len(i) > largest_len: 109 | largest = i 110 | largest_len = len(i) 111 | return largest 112 | 113 | else: 114 | return smiles 115 | 116 | 117 | def initialise_neutralisation_reactions(): 118 | patts = ( 119 | # Imidazoles 120 | ('[n+;H]', 'n'), 121 | # Amines 122 | ('[N+;!H0]', 'N'), 123 | # Carboxylic acids and alcohols 124 | ('[$([O-]);!$([O-][#7])]', 'O'), 125 | # Thiols 126 | ('[S-;X1]', 'S'), 127 | # Sulfonamides 128 | ('[$([N-;X2]S(=O)=O)]', 'N'), 129 | # Enamines 130 | ('[$([N-;X2][C,N]=C)]', 'N'), 131 | # Tetrazoles 132 | ('[n-]', '[nH]'), 133 | # Sulfoxides 134 | ('[$([S-]=O)]', 'S'), 135 | # Amides 136 | ('[$([N-]C=O)]', 'N'), 137 | ) 138 | return [(Chem.MolFromSmarts(x), Chem.MolFromSmiles(y, False)) for x, y in patts] 139 | 140 | 141 | def neutralise_charges(mol, reactions=None): 142 | replaced = False 143 | 144 | for i, (reactant, product) in enumerate(reactions): 145 | while mol.HasSubstructMatch(reactant): 146 | replaced = True 147 | rms = AllChem.ReplaceSubstructs(mol, reactant, product) 148 | mol = rms[0] 149 | if replaced: 150 | Chem.SanitizeMol(mol) 151 | return mol, True 152 | else: 153 | return mol, False 154 | 155 | 156 | def filter_and_canonicalize(smiles: str, holdout_set, holdout_fps, neutralization_rxns, tanimoto_cutoff=0.5, 157 | include_stereocenters=False): 158 | """ 159 | Args: 160 | smiles: the molecule to process 161 | holdout_set: smiles of the holdout set 162 | holdout_fps: ECFP4 
fingerprints of the holdout set 163 | neutralization_rxns: neutralization rdkit reactions 164 | tanimoto_cutoff: Remove molecules with a higher ECFP4 tanimoto similarity than this cutoff from the set 165 | include_stereocenters: whether to keep stereocenters during canonicalization 166 | 167 | Returns: 168 | list with canonical smiles as a list with one element, or an empty list. This allows the results to be flat-mapped. 169 | """ 170 | try: 171 | # Drop out if too long 172 | if len(smiles) > 200: 173 | return [] 174 | mol = Chem.MolFromSmiles(smiles) 175 | # Drop out if invalid 176 | if mol is None: 177 | return [] 178 | mol = Chem.RemoveHs(mol) 179 | 180 | # We only accept molecules consisting of H, B, C, N, O, F, Si, P, S, Cl, Se, Br, I. 181 | metal_smarts = Chem.MolFromSmarts('[!#1!#5!#6!#7!#8!#9!#14!#15!#16!#17!#34!#35!#53]') 182 | 183 | has_metal = mol.HasSubstructMatch(metal_smarts) 184 | 185 | # Exclude molecules containing the forbidden elements. 186 | if has_metal: 187 | print(f'metal {smiles}') 188 | return [] 189 | 190 | canon_smi = Chem.MolToSmiles(mol, isomericSmiles=include_stereocenters) 191 | 192 | # Drop out if the canonical SMILES is too long: 193 | if len(canon_smi) > 100: 194 | return [] 195 | # Balance charges if unbalanced 196 | if canon_smi.count('+') - canon_smi.count('-') != 0: 197 | new_mol, changed = neutralise_charges(mol, reactions=neutralization_rxns) 198 | if changed: 199 | mol = new_mol 200 | canon_smi = Chem.MolToSmiles(mol, isomericSmiles=include_stereocenters) 201 | 202 | # Get most similar to holdout fingerprints, and exclude too similar molecules.
203 | max_tanimoto = highest_tanimoto_precalc_fps(mol, holdout_fps) 204 | if max_tanimoto < tanimoto_cutoff and canon_smi not in holdout_set: 205 | return [canon_smi] 206 | else: 207 | print("Exclude: {} {}".format(canon_smi, max_tanimoto)) 208 | except Exception as e: 209 | print(e) 210 | return [] 211 | 212 | 213 | def calculate_internal_pairwise_similarities(smiles_list: Collection[str]) -> np.ndarray: 214 | """ 215 | Computes the pairwise similarities of the provided list of smiles against itself. 216 | 217 | Returns: 218 | Symmetric matrix of pairwise similarities. Diagonal is set to zero. 219 | """ 220 | if len(smiles_list) > 10000: 221 | logger.warning(f'Calculating internal similarity on large set of ' 222 | f'SMILES strings ({len(smiles_list)})') 223 | 224 | mols = get_mols(smiles_list) 225 | fps = get_fingerprints(mols) 226 | nfps = len(fps) 227 | 228 | similarities = np.zeros((nfps, nfps)) 229 | 230 | for i in range(1, nfps): 231 | sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i]) 232 | similarities[i, :i] = sims 233 | similarities[:i, i] = sims 234 | 235 | return similarities 236 | 237 | 238 | def calculate_pairwise_similarities(smiles_list1: List[str], smiles_list2: List[str]) -> np.ndarray: 239 | """ 240 | Computes the pairwise ECFP4 tanimoto similarity of the two smiles containers. 
241 | 242 | Returns: 243 | Pairwise similarity matrix as np.ndarray 244 | """ 245 | if len(smiles_list1) > 10000 or len(smiles_list2) > 10000: 246 | logger.warning(f'Calculating similarity between large sets of ' 247 | f'SMILES strings ({len(smiles_list1)} x {len(smiles_list2)})') 248 | 249 | mols1 = get_mols(smiles_list1) 250 | fps1 = get_fingerprints(mols1) 251 | 252 | mols2 = get_mols(smiles_list2) 253 | fps2 = get_fingerprints(mols2) 254 | 255 | similarities = [] 256 | 257 | for fp1 in fps1: 258 | sims = DataStructs.BulkTanimotoSimilarity(fp1, fps2) 259 | 260 | similarities.append(sims) 261 | 262 | return np.array(similarities) 263 | 264 | 265 | def get_fingerprints_from_smileslist(smiles_list): 266 | """ 267 | Converts the provided smiles into ECFP4 bitvectors of length 4096. 268 | 269 | Args: 270 | smiles_list: list of SMILES strings 271 | 272 | Returns: ECFP4 bitvectors of length 4096. 273 | 274 | """ 275 | return get_fingerprints(get_mols(smiles_list)) 276 | 277 | 278 | def get_fingerprints(mols: Iterable[Chem.Mol], radius=2, length=4096): 279 | """ 280 | Converts molecules to ECFP bitvectors. 
281 | 282 | Args: 283 | mols: RDKit molecules 284 | radius: ECFP fingerprint radius 285 | length: number of bits 286 | 287 | Returns: a list of fingerprints 288 | """ 289 | return [AllChem.GetMorganFingerprintAsBitVect(m, radius, length) for m in mols] 290 | 291 | 292 | def get_mols(smiles_list: Iterable[str]) -> Iterable[Chem.Mol]: 293 | for i in smiles_list: 294 | try: 295 | mol = Chem.MolFromSmiles(i) 296 | if mol is not None: 297 | yield mol 298 | except Exception as e: 299 | logger.warning(e) 300 | 301 | 302 | def highest_tanimoto_precalc_fps(mol, fps): 303 | """ 304 | 305 | Args: 306 | mol: Rdkit molecule 307 | fps: precalculated ECFP4 bitvectors 308 | 309 | Returns: 310 | 311 | """ 312 | 313 | if fps is None or len(fps) == 0: 314 | return 0 315 | 316 | fp1 = AllChem.GetMorganFingerprintAsBitVect(mol, 2, 4096) 317 | sims = np.array(DataStructs.BulkTanimotoSimilarity(fp1, fps)) 318 | 319 | return sims.max() 320 | 321 | 322 | def continuous_kldiv(X_baseline: np.ndarray, X_sampled: np.ndarray) -> float: 323 | kde_P = gaussian_kde(X_baseline) 324 | kde_Q = gaussian_kde(X_sampled) 325 | x_eval = np.linspace(np.hstack([X_baseline, X_sampled]).min(), np.hstack([X_baseline, X_sampled]).max(), num=1000) 326 | P = kde_P(x_eval) + 1e-10 327 | Q = kde_Q(x_eval) + 1e-10 328 | 329 | return entropy(P, Q) 330 | 331 | 332 | def discrete_kldiv(X_baseline: np.ndarray, X_sampled: np.ndarray) -> float: 333 | P, bins = histogram(X_baseline, bins=10, density=True) 334 | P += 1e-10 335 | Q, _ = histogram(X_sampled, bins=bins, density=True) 336 | Q += 1e-10 337 | 338 | return entropy(P, Q) 339 | 340 | 341 | def calculate_pc_descriptors(smiles: Iterable[str], pc_descriptors: List[str]) -> np.ndarray: 342 | output = [] 343 | 344 | for i in smiles: 345 | d = _calculate_pc_descriptors(i, pc_descriptors) 346 | if d is not None: 347 | output.append(d) 348 | 349 | return np.array(output) 350 | 351 | 352 | def _calculate_pc_descriptors(smiles: str, pc_descriptors: List[str]) -> 
Optional[np.ndarray]: 353 | calc = MoleculeDescriptors.MolecularDescriptorCalculator(pc_descriptors) 354 | 355 | mol = Chem.MolFromSmiles(smiles) 356 | if mol is None: 357 | return None 358 | _fp = calc.CalcDescriptors(mol) 359 | _fp = np.array(_fp) 360 | mask = np.isfinite(_fp) 361 | if (mask == 0).sum() > 0: 362 | logger.warning(f'{smiles} contains an NAN physchem descriptor') 363 | _fp[~mask] = 0 364 | 365 | return _fp 366 | 367 | 368 | def parse_molecular_formula(formula: str) -> List[Tuple[str, int]]: 369 | """ 370 | Parse a molecular formula to get the element types and counts. 371 | 372 | Args: 373 | formula: molecular formula, e.g. "C8H3F3Br" 374 | 375 | Returns: 376 | A list of tuples containing element types and number of occurrences. 377 | """ 378 | matches = re.findall(r'([A-Z][a-z]*)(\d*)', formula) 379 | 380 | # Convert matches to the required format 381 | results = [] 382 | for match in matches: 383 | # convert count to an integer, and set it to 1 if the count is not visible in the molecular formula 384 | count = 1 if not match[1] else int(match[1]) 385 | results.append((match[0], count)) 386 | 387 | return results 388 | -------------------------------------------------------------------------------- /guacamol/utils/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | from typing import List, Any, Optional, Set 5 | from urllib.request import urlretrieve 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | 11 | def remove_duplicates(list_with_duplicates): 12 | """ 13 | Removes the duplicates and keeps the ordering of the original list. 14 | For duplicates, the first occurrence is kept and the later occurrences are ignored. 15 | 16 | Args: 17 | list_with_duplicates: list that possibly contains duplicates 18 | 19 | Returns: 20 | A list with no duplicates.
21 | """ 22 | 23 | unique_set: Set[Any] = set() 24 | unique_list = [] 25 | for element in list_with_duplicates: 26 | if element not in unique_set: 27 | unique_set.add(element) 28 | unique_list.append(element) 29 | 30 | return unique_list 31 | 32 | 33 | def get_random_subset(dataset: List[Any], subset_size: int, seed: Optional[int] = None) -> List[Any]: 34 | """ 35 | Get a random subset of some dataset. 36 | 37 | For reproducibility, the random number generator seed can be specified. 38 | Nevertheless, the state of the random number generator is restored to avoid side effects. 39 | 40 | Args: 41 | dataset: full set to select a subset from 42 | subset_size: target size of the subset 43 | seed: random number generator seed. Defaults to not setting the seed. 44 | 45 | Returns: 46 | subset of the original dataset as a list 47 | """ 48 | if len(dataset) < subset_size: 49 | raise Exception(f'The dataset to extract a subset from is too small: ' 50 | f'{len(dataset)} < {subset_size}') 51 | 52 | # save random number generator state 53 | rng_state = np.random.get_state() 54 | 55 | if seed is not None: 56 | # extract a subset (for a given training set, the subset will always be identical). 57 | np.random.seed(seed) 58 | 59 | subset = np.random.choice(dataset, subset_size, replace=False) 60 | 61 | if seed is not None: 62 | # reset random number generator state, only if needed 63 | np.random.set_state(rng_state) 64 | 65 | return list(subset) 66 | 67 | 68 | def download_if_not_present(filename, uri): 69 | """ 70 | Download a file from a URI if it doesn't already exist. 
71 | """ 72 | if os.path.isfile(filename): 73 | print("{} already downloaded, reusing.".format(filename)) 74 | else: 75 | with open(filename, "wb") as fd: 76 | print('Starting {} download from {}...'.format(filename, uri)) 77 | with ProgressBarUpTo(unit='B', unit_scale=True, unit_divisor=1024, miniters=1) as t: 78 | urlretrieve(uri, fd.name, reporthook=t.update_to) 79 | print('Finished {} download.'.format(filename)) 80 | 81 | 82 | class ProgressBar(tqdm): 83 | """ 84 | Create a version of TQDM that notices whether it is going to the output or a file. 85 | """ 86 | 87 | def __init__(self, *args, **kwargs) -> None: 88 | """Override TQDM and detect if output is a file or not. 89 | """ 90 | # If output is not a terminal (e.g. redirected to a file), only refresh every 30 seconds 91 | if not sys.stdout.isatty(): 92 | kwargs['mininterval'] = 30.0 93 | kwargs['maxinterval'] = 30.0 94 | super(ProgressBar, self).__init__(*args, **kwargs) 95 | 96 | 97 | class ProgressBarUpTo(ProgressBar): 98 | """ 99 | Fancy Progress Bar that accepts a position rather than a delta. 100 | """ 101 | 102 | def update_to(self, b=1, bsize=1, tsize=None): 103 | """ 104 | Update to a specified position.
105 | """ 106 | if tsize is not None: 107 | self.total = tsize 108 | self.update(b * bsize - self.n) # will also set self.n = b * bsize 109 | 110 | 111 | def get_time_string(): 112 | lt = time.localtime() 113 | return "%04d%02d%02d-%02d%02d" % (lt.tm_year, lt.tm_mon, lt.tm_mday, lt.tm_hour, lt.tm_min) 114 | -------------------------------------------------------------------------------- /guacamol/utils/descriptors.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | from rdkit.Chem import Descriptors, Mol, rdMolDescriptors 3 | 4 | 5 | def logP(mol: Mol) -> float: 6 | return Descriptors.MolLogP(mol) 7 | 8 | 9 | def qed(mol: Mol) -> float: 10 | return Descriptors.qed(mol) 11 | 12 | 13 | def tpsa(mol: Mol) -> float: 14 | return Descriptors.TPSA(mol) 15 | 16 | 17 | def bertz(mol: Mol) -> float: 18 | return Descriptors.BertzCT(mol) 19 | 20 | 21 | def mol_weight(mol: Mol) -> float: 22 | return Descriptors.MolWt(mol) 23 | 24 | 25 | def num_H_donors(mol: Mol) -> int: 26 | return Descriptors.NumHDonors(mol) 27 | 28 | 29 | def num_H_acceptors(mol: Mol) -> int: 30 | return Descriptors.NumHAcceptors(mol) 31 | 32 | 33 | def num_rotatable_bonds(mol: Mol) -> int: 34 | return Descriptors.NumRotatableBonds(mol) 35 | 36 | 37 | def num_rings(mol: Mol) -> int: 38 | return rdMolDescriptors.CalcNumRings(mol) 39 | 40 | 41 | def num_aromatic_rings(mol: Mol) -> int: 42 | return rdMolDescriptors.CalcNumAromaticRings(mol) 43 | 44 | 45 | def num_atoms(mol: Mol) -> int: 46 | """ 47 | Returns the total number of atoms, H included 48 | """ 49 | mol = Chem.AddHs(mol) 50 | return mol.GetNumAtoms() 51 | 52 | 53 | class AtomCounter: 54 | 55 | def __init__(self, element: str) -> None: 56 | """ 57 | Args: 58 | element: element to count within a molecule 59 | """ 60 | self.element = element 61 | 62 | def __call__(self, mol: Mol) -> int: 63 | """ 64 | Count the number of atoms of a given type. 
65 | 66 | Args: 67 | mol: molecule 68 | 69 | Returns: 70 | The number of atoms of the given type. 71 | """ 72 | # if the molecule contains H atoms, they may be implicit, so add them 73 | if self.element == 'H': 74 | mol = Chem.AddHs(mol) 75 | 76 | return sum(1 for a in mol.GetAtoms() if a.GetSymbol() == self.element) 77 | -------------------------------------------------------------------------------- /guacamol/utils/fingerprints.py: -------------------------------------------------------------------------------- 1 | from rdkit.Chem import AllChem, Mol 2 | from rdkit.Chem.AtomPairs.Sheridan import GetBPFingerprint, GetBTFingerprint 3 | from rdkit.Chem.Pharm2D import Generate, Gobbi_Pharm2D 4 | 5 | 6 | class _FingerprintCalculator: 7 | """ 8 | Calculate the fingerprint while avoiding a series of if-else. 9 | See recipe 8.21 of the book "Python Cookbook". 10 | 11 | To support a new type of fingerprint, just add a function "get_fpname(self, mol)". 12 | """ 13 | 14 | def get_fingerprint(self, mol: Mol, fp_type: str): 15 | method_name = 'get_' + fp_type 16 | method = getattr(self, method_name) 17 | if method is None: 18 | raise Exception(f'{fp_type} is not a supported fingerprint type.') 19 | return method(mol) 20 | 21 | def get_AP(self, mol: Mol): 22 | return AllChem.GetAtomPairFingerprint(mol, maxLength=10) 23 | 24 | def get_PHCO(self, mol: Mol): 25 | return Generate.Gen2DFingerprint(mol, Gobbi_Pharm2D.factory) 26 | 27 | def get_BPF(self, mol: Mol): 28 | return GetBPFingerprint(mol) 29 | 30 | def get_BTF(self, mol: Mol): 31 | return GetBTFingerprint(mol) 32 | 33 | def get_PATH(self, mol: Mol): 34 | return AllChem.RDKFingerprint(mol) 35 | 36 | def get_ECFP4(self, mol: Mol): 37 | return AllChem.GetMorganFingerprint(mol, 2) 38 | 39 | def get_ECFP6(self, mol: Mol): 40 | return AllChem.GetMorganFingerprint(mol, 3) 41 | 42 | def get_FCFP4(self, mol: Mol): 43 | return AllChem.GetMorganFingerprint(mol, 2, useFeatures=True) 44 | 45 | def get_FCFP6(self, mol: Mol): 46 | return 
AllChem.GetMorganFingerprint(mol, 3, useFeatures=True) 47 | 48 | 49 | def get_fingerprint(mol: Mol, fp_type: str): 50 | return _FingerprintCalculator().get_fingerprint(mol=mol, fp_type=fp_type) 51 | -------------------------------------------------------------------------------- /guacamol/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_default_logger(): 5 | """ 6 | Call this function in your main function to initialize a basic logger. 7 | 8 | To have more control on the format or level, call `logging.basicConfig()` directly instead. 9 | 10 | If you don't initialize any logger, log entries from the guacamol package will not appear anywhere. 11 | """ 12 | logging.basicConfig(format='%(levelname)s : %(message)s', level=logging.INFO) 13 | -------------------------------------------------------------------------------- /guacamol/utils/math.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | 5 | 6 | def arithmetic_mean(values: List[float]) -> float: 7 | """ 8 | Computes the arithmetic mean of a list of values. 9 | """ 10 | return sum(values) / len(values) 11 | 12 | 13 | def geometric_mean(values: List[float]) -> float: 14 | """ 15 | Computes the geometric mean of a list of values. 
16 | """ 17 | a = np.array(values) 18 | return a.prod() ** (1.0 / len(a)) 19 | -------------------------------------------------------------------------------- /guacamol/utils/sampling_helpers.py: -------------------------------------------------------------------------------- 1 | from typing import List, Set 2 | 3 | from guacamol.distribution_matching_generator import DistributionMatchingGenerator 4 | from guacamol.utils.chemistry import is_valid, canonicalize 5 | 6 | 7 | def sample_valid_molecules(model: DistributionMatchingGenerator, number_molecules: int, max_tries=10) -> List[str]: 8 | """ 9 | Sample from the given generator until the desired number of valid molecules 10 | has been sampled (i.e., ignore invalid molecules). 11 | 12 | Args: 13 | model: model to sample from 14 | number_molecules: number of valid molecules to generate 15 | max_tries: determines the maximum number N of samples to draw, N = number_molecules * max_tries 16 | 17 | Returns: 18 | A list of number_molecules valid molecules. If this was not possible with the given max_tries, the list may be shorter. 19 | """ 20 | 21 | max_samples = max_tries * number_molecules 22 | number_already_sampled = 0 23 | 24 | valid_molecules: List[str] = [] 25 | 26 | while len(valid_molecules) < number_molecules and number_already_sampled < max_samples: 27 | remaining_to_sample = number_molecules - len(valid_molecules) 28 | 29 | samples = model.generate(remaining_to_sample) 30 | number_already_sampled += remaining_to_sample 31 | 32 | valid_molecules += [m for m in samples if is_valid(m)] 33 | 34 | return valid_molecules 35 | 36 | 37 | def sample_unique_molecules(model: DistributionMatchingGenerator, number_molecules: int, max_tries=10) -> List[str]: 38 | """ 39 | Sample from the given generator until the desired number of unique (distinct) molecules 40 | has been sampled (i.e., ignore duplicate molecules). 
41 | 42 | Args: 43 | model: model to sample from 44 | number_molecules: number of unique (distinct) molecules to generate 45 | max_tries: determines the maximum number N of samples to draw, N = number_molecules * max_tries 46 | 47 | Returns: 48 | A list of number_molecules unique molecules, in canonicalized form. 49 | If this was not possible with the given max_tries, the list may be shorter. 50 | The generation order is kept. 51 | """ 52 | 53 | max_samples = max_tries * number_molecules 54 | number_already_sampled = 0 55 | 56 | unique_list: List[str] = [] 57 | unique_set: Set[str] = set() 58 | 59 | while len(unique_list) < number_molecules and number_already_sampled < max_samples: 60 | remaining_to_sample = number_molecules - len(unique_list) 61 | 62 | samples = model.generate(remaining_to_sample) 63 | number_already_sampled += remaining_to_sample 64 | 65 | for smiles in samples: 66 | canonical_smiles = canonicalize(smiles) 67 | if canonical_smiles is not None and canonical_smiles not in unique_set: 68 | unique_set.add(canonical_smiles) 69 | unique_list.append(canonical_smiles) 70 | 71 | # this should always be True 72 | assert len(unique_set) == len(unique_list) 73 | 74 | return unique_list 75 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | check_untyped_defs = True 3 | 4 | [mypy-fcd.*] 5 | ignore_missing_imports = True 6 | 7 | [mypy-joblib.*] 8 | ignore_missing_imports = True 9 | 10 | [mypy-numpy.*] 11 | ignore_missing_imports = True 12 | 13 | [mypy-pytest.*] 14 | ignore_missing_imports = True 15 | 16 | [mypy-rdkit.*] 17 | ignore_missing_imports = True 18 | 19 | [mypy-scipy.*] 20 | ignore_missing_imports = True 21 | 22 | [mypy-tqdm.*] 23 | ignore_missing_imports = True 24 | -------------------------------------------------------------------------------- /setup.py:
-------------------------------------------------------------------------------- 1 | import io 2 | import re 3 | from os import path 4 | from setuptools import setup 5 | 6 | # Get the version from guacamol/__init__.py 7 | # Adapted from https://stackoverflow.com/a/39671214 8 | __version__ = re.search(r'__version__\s*=\s*[\'"]([^\'"]*)[\'"]', 9 | io.open('guacamol/__init__.py', encoding='utf_8_sig').read() 10 | ).group(1) 11 | 12 | this_directory = path.abspath(path.dirname(__file__)) 13 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: 14 | long_description = f.read() 15 | 16 | setup(name='guacamol', 17 | version=__version__, 18 | author='BenevolentAI', 19 | author_email='guacamol@benevolent.ai', 20 | description='Guacamol: benchmarks for de novo molecular design', 21 | long_description=long_description, 22 | long_description_content_type='text/markdown', 23 | url='https://github.com/BenevolentAI/guacamol', 24 | packages=['guacamol', 'guacamol.data', 'guacamol.utils'], 25 | license='MIT', 26 | install_requires=[ 27 | 'joblib>=0.12.5', 28 | 'numpy>=1.15.2', 29 | 'scipy>=1.1.0', 30 | 'tqdm>=4.26.0', 31 | 'FCD>=1.1', 32 | 'rdkit-pypi>=2021.9.2.1', 33 | ], 34 | python_requires='>=3.6', 35 | extras_require={ 36 | 'rdkit': ['rdkit>=2018.09.1.0'], 37 | }, 38 | include_package_data=True, 39 | zip_safe=False, 40 | ) 41 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenevolentAI/guacamol/60ebe1f6a396f16e08b834dce448e9343d259feb/tests/__init__.py -------------------------------------------------------------------------------- /tests/mock_generator.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from guacamol.distribution_matching_generator import DistributionMatchingGenerator 4 | 5 | 6 | class 
MockGenerator(DistributionMatchingGenerator): 7 | """ 8 | Mock generator that returns pre-defined molecules, 9 | possibly split in several calls 10 | """ 11 | 12 | def __init__(self, molecules: List[str]) -> None: 13 | self.molecules = molecules 14 | self.cursor = 0 15 | 16 | def generate(self, number_samples: int) -> List[str]: 17 | end = self.cursor + number_samples 18 | 19 | sampled_molecules = self.molecules[self.cursor:end] 20 | self.cursor = end 21 | return sampled_molecules 22 | -------------------------------------------------------------------------------- /tests/test_distribution_learning_benchmarks.py: -------------------------------------------------------------------------------- 1 | from guacamol.distribution_learning_benchmark import ValidityBenchmark, UniquenessBenchmark, NoveltyBenchmark, \ 2 | KLDivBenchmark 3 | from guacamol.assess_distribution_learning import _assess_distribution_learning 4 | from .mock_generator import MockGenerator 5 | import numpy as np 6 | import tempfile 7 | from os.path import join 8 | 9 | 10 | def test_validity_does_not_penalize_duplicates(): 11 | generator = MockGenerator(['CCC', 'CCC']) 12 | benchmark = ValidityBenchmark(number_samples=2) 13 | 14 | assert benchmark.assess_model(generator).score == 1.0 15 | 16 | 17 | def test_validity_score_is_proportion_of_valid_molecules(): 18 | generator = MockGenerator(['CCC', 'CC(CC)C', 'invalidMolecule']) 19 | benchmark = ValidityBenchmark(number_samples=3) 20 | 21 | assert benchmark.assess_model(generator).score == 2.0 / 3.0 22 | 23 | 24 | def test_uniqueness_penalizes_duplicates(): 25 | generator = MockGenerator(['CCC', 'CCC', 'CCC']) 26 | benchmark = UniquenessBenchmark(number_samples=3) 27 | 28 | assert benchmark.assess_model(generator).score == 1.0 / 3.0 29 | 30 | 31 | def test_uniqueness_penalizes_duplicates_with_different_smiles_strings(): 32 | generator = MockGenerator(['C(O)C', 'CCO', 'OCC']) 33 | benchmark = UniquenessBenchmark(number_samples=3) 34 | 35 | assert 
benchmark.assess_model(generator).score == 1.0 / 3.0 36 | 37 | 38 | def test_uniqueness_does_not_penalize_invalid_molecules(): 39 | generator = MockGenerator(['C(O)C', 'invalid1', 'invalid2', 'CCC', 'NCCN']) 40 | benchmark = UniquenessBenchmark(number_samples=3) 41 | 42 | assert benchmark.assess_model(generator).score == 1.0 43 | 44 | 45 | def test_novelty_score_is_zero_if_no_molecule_is_new(): 46 | molecules = ['CCOCC', 'NNNNONNN', 'C=CC=C'] 47 | generator = MockGenerator(molecules) 48 | benchmark = NoveltyBenchmark(number_samples=3, training_set=molecules) 49 | 50 | assert benchmark.assess_model(generator).score == 0.0 51 | 52 | 53 | def test_novelty_score_is_one_if_all_molecules_are_new(): 54 | generator = MockGenerator(['CCOCC', 'NNNNONNN', 'C=CC=C']) 55 | benchmark = NoveltyBenchmark(number_samples=3, training_set=['CO', 'CC']) 56 | 57 | assert benchmark.assess_model(generator).score == 1.0 58 | 59 | 60 | def test_novelty_score_does_not_penalize_duplicates(): 61 | generator = MockGenerator(['CCOCC', 'O(CC)CC', 'C=CC=C', 'CC']) 62 | benchmark = NoveltyBenchmark(number_samples=3, training_set=['CO', 'CC']) 63 | 64 | # Gets 2 out of 3: one of the duplicated molecules is ignored, so the sampled molecules are 65 | # ['CCOCC', 'C=CC=C', 'CC'], and 'CC' is not novel 66 | assert benchmark.assess_model(generator).score == 2.0 / 3.0 67 | 68 | 69 | def test_novelty_score_penalizes_invalid_molecules(): 70 | generator = MockGenerator(['CCOCC', 'invalid1', 'invalid2', 'CCCC', 'CC']) 71 | benchmark = NoveltyBenchmark(number_samples=3, training_set=['CO', 'CC']) 72 | 73 | assert benchmark.assess_model(generator).score == 2.0 / 3.0 74 | 75 | 76 | def test_KLdiv_benchmark_same_dist(): 77 | generator = MockGenerator(['CCOCC', 'NNNNONNN', 'C=CC=C']) 78 | benchmark = KLDivBenchmark(number_samples=3, training_set=['CCOCC', 'NNNNONNN', 'C=CC=C']) 79 | result = benchmark.assess_model(generator) 80 | print(result.metadata) 81 | assert np.isclose(result.score, 1.0, ) 82 | 83 | 84 | def 
test_KLdiv_benchmark_different_dist(): 85 | generator = MockGenerator(['CCOCC', 'NNNNONNN', 'C=CC=C']) 86 | benchmark = KLDivBenchmark(number_samples=3, training_set=['FCCOCC', 'CC(CC)CCCCNONNN', 'C=CC=O']) 87 | result = benchmark.assess_model(generator) 88 | print(result.metadata) 89 | 90 | assert result.metadata['number_samples'] == 3 91 | assert result.metadata.get('kl_divs') is not None 92 | assert result.metadata['kl_divs'].get('BertzCT') > 0 93 | assert result.metadata['kl_divs'].get('MolLogP', None) > 0 94 | assert result.metadata['kl_divs'].get('MolWt', None) > 0 95 | assert result.metadata['kl_divs'].get('TPSA', None) > 0 96 | assert result.metadata['kl_divs'].get('NumHAcceptors', None) > 0 97 | assert result.metadata['kl_divs'].get('NumHDonors', None) > 0 98 | assert result.metadata['kl_divs'].get('NumRotatableBonds', None) > 0 99 | assert result.score < 1.0 100 | 101 | 102 | def test_distribution_learning_suite_v1(): 103 | generator = MockGenerator( 104 | ['CCl', 'CCOCCCl', 'ClCCF', 'CCCOCCOCCCO', 'CF', 'CCOCC', 'CCF', 'CCCOCC', 'NNNNONNN', 'C=CC=C'] * 10) 105 | 106 | mock_chembl = ['FCCOCC', 'C=CC=O', 'CCl', 'CCOCCCl', 'ClCCF', 'CCCOCCOCCCO', 'CF', 'CCOCC', 107 | 'CCF'] 108 | 109 | temp_dir = tempfile.mkdtemp() 110 | smiles_path = join(temp_dir, 'mock.smiles') 111 | with open(smiles_path, 'w') as f: 112 | for i in mock_chembl: 113 | f.write(f'{i}\n') 114 | f.close() 115 | 116 | json_path = join(temp_dir, 'output.json') 117 | 118 | _assess_distribution_learning(model=generator, 119 | chembl_training_file=smiles_path, 120 | json_output_file=json_path, 121 | benchmark_version='v1', 122 | number_samples=4) 123 | 124 | with open(json_path, 'r') as f: 125 | print(f.read()) 126 | -------------------------------------------------------------------------------- /tests/test_goal_directed_benchmark.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import pytest 4 | from rdkit import 
Chem 5 | 6 | from guacamol.goal_directed_benchmark import GoalDirectedBenchmark 7 | from guacamol.goal_directed_generator import GoalDirectedGenerator 8 | from guacamol.goal_directed_score_contributions import uniform_specification 9 | from guacamol.scoring_function import ScoringFunctionBasedOnRdkitMol, ScoringFunction 10 | 11 | 12 | class MockScoringFunction(ScoringFunctionBasedOnRdkitMol): 13 | """ 14 | For testing purposes: scoring function that returns 0.1 * (number of atoms) 15 | """ 16 | 17 | def score_mol(self, mol: Chem.Mol) -> float: 18 | return 0.1 * mol.GetNumAtoms() 19 | 20 | 21 | class MockGenerator(GoalDirectedGenerator): 22 | """ 23 | Mock generator that returns pre-defined molecules 24 | """ 25 | def __init__(self, molecules: List[str]) -> None: 26 | self.molecules = molecules 27 | 28 | def generate_optimized_molecules(self, scoring_function: ScoringFunction, number_molecules: int, 29 | starting_population: Optional[List[str]] = None) -> List[str]: 30 | assert number_molecules == len(self.molecules) 31 | return self.molecules 32 | 33 | 34 | def test_removes_duplicates(): 35 | """ 36 | Assert that duplicated molecules (even with different SMILES strings) are considered only once. 
37 | """ 38 | top3 = uniform_specification(3) 39 | benchmark = GoalDirectedBenchmark('benchmark', MockScoringFunction(), top3) 40 | generator = MockGenerator(['OCC', 'CCO', 'C(O)C']) 41 | 42 | individual_mock_score = 0.3 43 | 44 | assert benchmark.assess_model(generator).score == pytest.approx(individual_mock_score / 3) 45 | 46 | 47 | def test_removes_invalid_molecules(): 48 | top3 = uniform_specification(3) 49 | benchmark = GoalDirectedBenchmark('benchmark', MockScoringFunction(), top3) 50 | generator = MockGenerator(['OCC', 'invalid', 'invalid2']) 51 | 52 | individual_mock_score = 0.3 53 | 54 | assert benchmark.assess_model(generator).score == pytest.approx(individual_mock_score / 3) 55 | 56 | 57 | def test_correct_score_averaging(): 58 | top3 = uniform_specification(3) 59 | benchmark = GoalDirectedBenchmark('benchmark', MockScoringFunction(), top3) 60 | generator = MockGenerator(['OCC', 'CCCCOCCCC', 'C']) 61 | 62 | expected_score = (0.3 + 0.9 + 0.1) / 3 63 | 64 | assert benchmark.assess_model(generator).score == pytest.approx(expected_score) 65 | 66 | 67 | def test_correct_score_with_multiple_contributions(): 68 | """ 69 | Verify that 0.5 * (top1 + top3) delivers the correct result 70 | """ 71 | specification = uniform_specification(1, 3) 72 | benchmark = GoalDirectedBenchmark('benchmark', MockScoringFunction(), specification) 73 | generator = MockGenerator(['OCC', 'CCCCOCCCC', 'C']) 74 | 75 | top3 = (0.3 + 0.9 + 0.1) / 3 76 | top1 = 0.9 77 | expected_score = (top1 + top3) / 2 78 | 79 | assert benchmark.assess_model(generator).score == pytest.approx(expected_score) 80 | -------------------------------------------------------------------------------- /tests/test_sampling_helpers.py: -------------------------------------------------------------------------------- 1 | from guacamol.utils.sampling_helpers import sample_valid_molecules, sample_unique_molecules 2 | from .mock_generator import MockGenerator 3 | 4 | 5 | def test_sample_valid_molecules_for_valid_only(): 
6 | generator = MockGenerator(['CCCC', 'CC']) 7 | 8 | mols = sample_valid_molecules(generator, 2) 9 | 10 | assert mols == ['CCCC', 'CC'] 11 | 12 | 13 | def test_sample_valid_molecules_with_invalid_molecules(): 14 | generator = MockGenerator(['invalid', 'invalid', 'invalid', 'CCCC', 'invalid', 'CC']) 15 | 16 | mols = sample_valid_molecules(generator, 2) 17 | 18 | assert mols == ['CCCC', 'CC'] 19 | 20 | 21 | def test_sample_valid_molecules_if_not_enough_valid_generated(): 22 | # does not raise an exception if 23 | molecules = ['invalid' for _ in range(20)] 24 | molecules[-1] = 'CC' 25 | molecules[-2] = 'CN' 26 | generator = MockGenerator(molecules) 27 | 28 | # samples a max of 9*2 molecules and just does not sample the good ones 29 | # in this case the list of generated molecules is empty 30 | assert not sample_valid_molecules(generator, 2, max_tries=9) 31 | 32 | # with a max of 10*2 molecules two valid molecules can be sampled 33 | generator = MockGenerator(molecules) 34 | mols = sample_valid_molecules(generator, 2) 35 | assert mols == ['CN', 'CC'] 36 | 37 | 38 | def test_sample_unique_molecules_for_valid_only(): 39 | generator = MockGenerator(['CCCC', 'CC']) 40 | 41 | mols = sample_unique_molecules(generator, 2) 42 | 43 | assert mols == ['CCCC', 'CC'] 44 | 45 | 46 | def test_sample_unique_molecules_with_invalid_molecules(): 47 | generator = MockGenerator(['invalid1', 'invalid2', 'inv3', 'CCCC', 'CC']) 48 | 49 | mols = sample_unique_molecules(generator, 2) 50 | 51 | assert mols == ['CCCC', 'CC'] 52 | 53 | 54 | def test_sample_unique_molecules_with_duplicate_molecules(): 55 | generator = MockGenerator(['CO', 'C(O)', 'CCCC', 'CC']) 56 | 57 | mols = sample_unique_molecules(generator, 2) 58 | 59 | assert mols == ['CO', 'CCCC'] 60 | 61 | 62 | def test_sample_unique_molecules_if_not_enough_unique_generated(): 63 | # does not raise an exception if 64 | molecules = ['CO' for _ in range(20)] 65 | molecules[-1] = 'CC' 66 | generator = MockGenerator(molecules) 67 | 68 | # 
samples a max of 9*2 molecules and just does not sample the other molecule 69 | # in this case the list of generated molecules contains just 'CO' 70 | mols = sample_unique_molecules(generator, 2, max_tries=9) 71 | assert mols == ['CO'] 72 | 73 | # with a max of 10*2 molecules two valid molecules can be sampled 74 | generator = MockGenerator(molecules) 75 | mols = sample_unique_molecules(generator, 2) 76 | assert mols == ['CO', 'CC'] 77 | -------------------------------------------------------------------------------- /tests/test_score_modifier.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from guacamol.score_modifier import LinearModifier, SquaredModifier, AbsoluteScoreModifier, GaussianModifier, \ 7 | MinGaussianModifier, MaxGaussianModifier, ThresholdedLinearModifier, ClippedScoreModifier, \ 8 | SmoothClippedScoreModifier, ChainedModifier 9 | 10 | scalar_value = 8.343 11 | value_array = np.array([[-3.3, 0, 5.5], 12 | [0.011, 2.0, -33]]) 13 | 14 | 15 | def test_linear_function_default(): 16 | f = LinearModifier() 17 | 18 | assert f(scalar_value) == scalar_value 19 | assert np.array_equal(f(value_array), value_array) 20 | 21 | 22 | def test_linear_function_with_slope(): 23 | slope = 3.3 24 | f = LinearModifier(slope=slope) 25 | 26 | assert f(scalar_value) == slope * scalar_value 27 | assert np.array_equal(f(value_array), slope * value_array) 28 | 29 | 30 | def test_squared_function(): 31 | target_value = 5.555 32 | coefficient = 0.123 33 | f = SquaredModifier(target_value=target_value, coefficient=coefficient) 34 | 35 | expected_scalar = 1.0 - coefficient * (target_value - scalar_value) ** 2 36 | expected_array = 1.0 - coefficient * np.square(target_value - value_array) 37 | 38 | assert f(scalar_value) == expected_scalar 39 | assert np.array_equal(f(value_array), expected_array) 40 | 41 | 42 | def test_absolute_function(): 43 | target_value 
= 5.555 44 | f = AbsoluteScoreModifier(target_value=target_value) 45 | 46 | expected_scalar = 1.0 - abs(target_value - scalar_value) 47 | expected_array = 1.0 - np.abs(target_value - value_array) 48 | 49 | assert f(scalar_value) == expected_scalar 50 | assert np.array_equal(f(value_array), expected_array) 51 | 52 | 53 | def gaussian(x, mu, sig): 54 | return np.exp(-np.power(x - mu, 2.) / (2 * np.power(sig, 2.))) 55 | 56 | 57 | def test_gaussian_function(): 58 | mu = -1.223 59 | sigma = 0.334 60 | 61 | f = GaussianModifier(mu=mu, sigma=sigma) 62 | 63 | assert f(mu) == 1.0 64 | assert f(scalar_value) == gaussian(scalar_value, mu, sigma) 65 | assert np.allclose(f(value_array), gaussian(value_array, mu, sigma)) 66 | 67 | 68 | def test_min_gaussian_function(): 69 | mu = -1.223 70 | sigma = 0.334 71 | 72 | f = MinGaussianModifier(mu=mu, sigma=sigma) 73 | 74 | assert f(mu) == 1.0 75 | 76 | low_value = -np.inf 77 | large_value = np.inf 78 | 79 | assert f(low_value) == 1.0 80 | assert f(large_value) == 0.0 81 | 82 | full_gaussian = partial(gaussian, mu=mu, sig=sigma) 83 | min_gaussian_lambda = lambda x: 1.0 if x < mu else full_gaussian(x) 84 | min_gaussian = np.vectorize(min_gaussian_lambda) 85 | 86 | assert f(scalar_value) == min_gaussian(scalar_value) 87 | assert np.allclose(f(value_array), min_gaussian(value_array)) 88 | 89 | 90 | def test_max_gaussian_function(): 91 | mu = -1.223 92 | sigma = 0.334 93 | 94 | f = MaxGaussianModifier(mu=mu, sigma=sigma) 95 | 96 | assert f(mu) == 1.0 97 | 98 | low_value = -np.inf 99 | large_value = np.inf 100 | 101 | assert f(low_value) == 0.0 102 | assert f(large_value) == 1.0 103 | 104 | full_gaussian = partial(gaussian, mu=mu, sig=sigma) 105 | max_gaussian_lambda = lambda x: 1.0 if x > mu else full_gaussian(x) 106 | max_gaussian = np.vectorize(max_gaussian_lambda) 107 | 108 | assert f(scalar_value) == max_gaussian(scalar_value) 109 | assert np.allclose(f(value_array), max_gaussian(value_array)) 110 | 111 | 112 | def 
test_tanimoto_threshold_function(): 113 | threshold = 5.555 114 | f = ThresholdedLinearModifier(threshold=threshold) 115 | 116 | large_value = np.inf 117 | 118 | assert f(large_value) == 1.0 119 | assert f(threshold) == 1.0 120 | 121 | expected_array = np.minimum(value_array, threshold) / threshold 122 | assert np.array_equal(f(value_array), expected_array) 123 | 124 | 125 | def test_clipped_function(): 126 | min_x = 4.4 127 | max_x = 8.8 128 | min_score = -3.3 129 | max_score = 9.2 130 | 131 | modifier = ClippedScoreModifier(upper_x=max_x, lower_x=min_x, high_score=max_score, low_score=min_score) 132 | 133 | # values smaller than min_x should be assigned min_score 134 | for x in [-2, 0, 4, 4.4]: 135 | assert modifier(x) == min_score 136 | 137 | # values larger than max_x should be assigned min_score 138 | for x in [8.8, 9.0, 1000]: 139 | assert modifier(x) == max_score 140 | 141 | # values in between are interpolated 142 | slope = (max_score - min_score) / (max_x - min_x) 143 | for x in [4.4, 4.8, 5.353, 8.034, 8.8]: 144 | dx = x - min_x 145 | dy = dx * slope 146 | assert modifier(x) == pytest.approx(min_score + dy) 147 | 148 | 149 | def test_clipped_function_inverted(): 150 | # The clipped function also works for decreasing scores 151 | max_x = 4.4 152 | min_x = 8.8 153 | min_score = -3.3 154 | max_score = 9.2 155 | 156 | modifier = ClippedScoreModifier(upper_x=max_x, lower_x=min_x, high_score=max_score, low_score=min_score) 157 | 158 | # values smaller than max_x should be assigned the maximal score 159 | for x in [-2, 0, 4, 4.4]: 160 | assert modifier(x) == max_score 161 | 162 | # values larger than min_x should be assigned min_score 163 | for x in [8.8, 9.0, 1000]: 164 | assert modifier(x) == min_score 165 | 166 | # values in between are interpolated 167 | slope = (max_score - min_score) / (max_x - min_x) 168 | for x in [4.4, 4.8, 5.353, 8.034, 8.8]: 169 | dx = x - min_x 170 | dy = dx * slope 171 | assert modifier(x) == pytest.approx(min_score + dy) 172 | 173 
| 174 | def test_thresholded_is_special_case_of_clipped_for_positive_input(): 175 | threshold = 4.584 176 | thresholded_modifier = ThresholdedLinearModifier(threshold=threshold) 177 | clipped_modifier = ClippedScoreModifier(upper_x=threshold) 178 | 179 | values = np.array([0, 2.3, 8.545, 3.23, 0.12, 55.555]) 180 | 181 | assert np.allclose(thresholded_modifier(values), clipped_modifier(values)) 182 | 183 | 184 | def test_smooth_clipped(): 185 | min_x = 4.4 186 | max_x = 8.8 187 | min_score = -3.3 188 | max_score = 9.2 189 | 190 | modifier = SmoothClippedScoreModifier(upper_x=max_x, lower_x=min_x, high_score=max_score, low_score=min_score) 191 | 192 | # assert that the slope in the middle is correct 193 | 194 | middle_x = (min_x + max_x) / 2 195 | delta = 1e-5 196 | vp = modifier(middle_x + delta) 197 | vm = modifier(middle_x - delta) 198 | 199 | slope = (vp - vm) / (2 * delta) 200 | expected_slope = (max_score - min_score) / (max_x - min_x) 201 | 202 | assert slope == pytest.approx(expected_slope) 203 | 204 | # assert behavior at +- infinity 205 | 206 | assert modifier(1e5) == pytest.approx(max_score) 207 | assert modifier(-1e5) == pytest.approx(min_score) 208 | 209 | 210 | def test_smooth_clipped_inverted(): 211 | # The smooth clipped function also works for decreasing scores 212 | max_x = 4.4 213 | min_x = 8.8 214 | min_score = -3.3 215 | max_score = 9.2 216 | 217 | modifier = SmoothClippedScoreModifier(upper_x=max_x, lower_x=min_x, high_score=max_score, low_score=min_score) 218 | 219 | # assert that the slope in the middle is correct 220 | 221 | middle_x = (min_x + max_x) / 2 222 | delta = 1e-5 223 | vp = modifier(middle_x + delta) 224 | vm = modifier(middle_x - delta) 225 | 226 | slope = (vp - vm) / (2 * delta) 227 | expected_slope = (max_score - min_score) / (max_x - min_x) 228 | 229 | assert slope == pytest.approx(expected_slope) 230 | 231 | # assert behavior at +- infinity 232 | 233 | assert modifier(1e5) == pytest.approx(min_score) 234 | assert 
modifier(-1e5) == pytest.approx(max_score) 235 | 236 | 237 | def test_chained_modifier(): 238 | linear = LinearModifier(slope=2) 239 | squared = SquaredModifier(10.0) 240 | 241 | chained_1 = ChainedModifier([linear, squared]) 242 | chained_2 = ChainedModifier([squared, linear]) 243 | 244 | expected_1 = 1.0 - np.square(10.0 - (2 * scalar_value)) 245 | expected_2 = 2 * (1.0 - np.square(10.0 - scalar_value)) 246 | 247 | assert chained_1(scalar_value) == expected_1 248 | assert chained_2(scalar_value) == expected_2 249 | -------------------------------------------------------------------------------- /tests/test_scoring_functions.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | from typing import List 3 | 4 | import pytest 5 | 6 | from guacamol.common_scoring_functions import IsomerScoringFunction, SMARTSScoringFunction 7 | from guacamol.score_modifier import GaussianModifier 8 | from guacamol.scoring_function import BatchScoringFunction, ArithmeticMeanScoringFunction, GeometricMeanScoringFunction 9 | from guacamol.utils.math import geometric_mean 10 | 11 | 12 | class MockScoringFunction(BatchScoringFunction): 13 | """ 14 | Mock scoring function that returns values from an array given in the constructor. 
15 | """ 16 | 17 | def __init__(self, values: List[float]) -> None: 18 | super().__init__() 19 | self.values = values 20 | self.index = 0 21 | 22 | def raw_score_list(self, smiles_list: List[str]) -> List[float]: 23 | start = self.index 24 | self.index += len(smiles_list) 25 | end = self.index 26 | return self.values[start:end] 27 | 28 | 29 | def test_isomer_scoring_function_uses_geometric_mean_by_default(): 30 | scoring_function = IsomerScoringFunction('C2H4') 31 | assert scoring_function.mean_function == geometric_mean 32 | 33 | 34 | def test_isomer_scoring_function_returns_one_for_correct_molecule(): 35 | c11h24_arithmetic = IsomerScoringFunction('C11H24', mean_function='arithmetic') 36 | c11h24_geometric = IsomerScoringFunction('C11H24', mean_function='geometric') 37 | 38 | # all those smiles fit the formula C11H24 39 | smiles1 = 'CCCCCCCCCCC' 40 | smiles2 = 'CC(CCC)CCCCCC' 41 | smiles3 = 'CCCCC(CC(C)CC)C' 42 | 43 | assert c11h24_arithmetic.score(smiles1) == 1.0 44 | assert c11h24_arithmetic.score(smiles2) == 1.0 45 | assert c11h24_arithmetic.score(smiles3) == 1.0 46 | assert c11h24_geometric.score(smiles1) == 1.0 47 | assert c11h24_geometric.score(smiles2) == 1.0 48 | assert c11h24_geometric.score(smiles3) == 1.0 49 | 50 | 51 | def test_isomer_scoring_function_penalizes_additional_atoms(): 52 | c11h24_arithmetic = IsomerScoringFunction('C11H24', mean_function='arithmetic') 53 | c11h24_geometric = IsomerScoringFunction('C11H24', mean_function='geometric') 54 | 55 | # all those smiles are C11H24O 56 | smiles1 = 'CCCCCCCCCCCO' 57 | smiles2 = 'CC(CCC)COCCCCC' 58 | smiles3 = 'CCCCOC(CC(C)CC)C' 59 | 60 | # the penalty corresponds to a deviation of 1.0 from the gaussian modifier for the total number of atoms 61 | n_atoms_score = GaussianModifier(mu=0, sigma=2)(1.0) 62 | c_score = 1.0 63 | h_score = 1.0 64 | expected_score_arithmetic = (n_atoms_score + c_score + h_score) / 3.0 65 | expected_score_geometric = (n_atoms_score * c_score * h_score) ** (1 / 3) 66 | 67 | 
assert c11h24_arithmetic.score(smiles1) == pytest.approx(expected_score_arithmetic) 68 | assert c11h24_arithmetic.score(smiles2) == pytest.approx(expected_score_arithmetic) 69 | assert c11h24_arithmetic.score(smiles3) == pytest.approx(expected_score_arithmetic) 70 | assert c11h24_geometric.score(smiles1) == pytest.approx(expected_score_geometric) 71 | assert c11h24_geometric.score(smiles2) == pytest.approx(expected_score_geometric) 72 | assert c11h24_geometric.score(smiles3) == pytest.approx(expected_score_geometric) 73 | 74 | 75 | def test_isomer_scoring_function_penalizes_incorrect_number_atoms(): 76 | c11h24_arithmetic = IsomerScoringFunction('C12H24', mean_function='arithmetic') 77 | c11h24_geometric = IsomerScoringFunction('C12H24', mean_function='geometric') 78 | 79 | # all those smiles fit the formula C11H24O 80 | smiles1 = 'CCCCCCCCOCCC' 81 | smiles2 = 'CC(CCOC)CCCCCC' 82 | smiles3 = 'COCCCC(CC(C)CC)C' 83 | 84 | # the penalty corresponds to a deviation of 1.0 from the gaussian modifier in the number of C atoms 85 | c_score = GaussianModifier(mu=0, sigma=1)(1.0) 86 | n_atoms_score = 1.0 87 | h_score = 1.0 88 | expected_score_arithmetic = (n_atoms_score + c_score + h_score) / 3.0 89 | expected_score_geometric = (n_atoms_score * c_score * h_score) ** (1 / 3) 90 | 91 | assert c11h24_arithmetic.score(smiles1) == pytest.approx(expected_score_arithmetic) 92 | assert c11h24_arithmetic.score(smiles2) == pytest.approx(expected_score_arithmetic) 93 | assert c11h24_arithmetic.score(smiles3) == pytest.approx(expected_score_arithmetic) 94 | assert c11h24_geometric.score(smiles1) == pytest.approx(expected_score_geometric) 95 | assert c11h24_geometric.score(smiles2) == pytest.approx(expected_score_geometric) 96 | assert c11h24_geometric.score(smiles3) == pytest.approx(expected_score_geometric) 97 | 98 | 99 | def test_isomer_scoring_function_invalid_molecule(): 100 | sf = IsomerScoringFunction('C60') 101 | 102 | assert sf.score('CCCinvalid') == sf.corrupt_score 103 | 104 | 
105 | def test_smarts_function(): 106 | mol1 = 'COc1cc(N(C)CCN(C)C)c(NC(=O)C=C)cc1Nc2nccc(n2)c3cn(C)c4ccccc34' 107 | mol2 = 'Cc1c(C)c2OC(C)(COc3ccc(CC4SC(=O)NC4=O)cc3)CCc2c(C)c1O' 108 | smarts = '[#7;h1]c1ncccn1' 109 | 110 | scofu1 = SMARTSScoringFunction(target=smarts) 111 | scofu_inv = SMARTSScoringFunction(target=smarts, inverse=True) 112 | 113 | assert scofu1.score(mol1) == 1.0 114 | assert scofu1.score(mol2) == 0.0 115 | assert scofu_inv.score(mol1) == 0.0 116 | assert scofu_inv.score(mol2) == 1.0 117 | 118 | assert scofu1.score_list([mol1])[0] == 1.0 119 | assert scofu1.score_list([mol2])[0] == 0.0 120 | 121 | 122 | def test_arithmetic_mean_scoring_function(): 123 | # define a scoring function returning the mean from two mock functions 124 | # and assert that it returns the correct values. 125 | 126 | weight_1 = 0.4 127 | weight_2 = 0.6 128 | 129 | mock_values_1 = [0.232, 0.665, 0.0, 1.0, 0.993] 130 | mock_values_2 = [0.010, 0.335, 0.8, 0.3, 0.847] 131 | 132 | mock_1 = MockScoringFunction(mock_values_1) 133 | mock_2 = MockScoringFunction(mock_values_2) 134 | 135 | scoring_function = ArithmeticMeanScoringFunction(scoring_functions=[mock_1, mock_2], 136 | weights=[weight_1, weight_2]) 137 | 138 | smiles = ['CC'] * 5 139 | 140 | scores = scoring_function.score_list(smiles) 141 | expected = [weight_1 * v1 + weight_2 * v2 for v1, v2 in zip(mock_values_1, mock_values_2)] 142 | 143 | assert scores == expected 144 | 145 | 146 | def test_geometric_mean_scoring_function(): 147 | # define a scoring function returning the geometric mean from two mock functions 148 | # and assert that it returns the correct values. 
149 | 150 | mock_values_1 = [0.232, 0.665, 0.0, 1.0, 0.993] 151 | mock_values_2 = [0.010, 0.335, 0.8, 0.3, 0.847] 152 | 153 | mock_1 = MockScoringFunction(mock_values_1) 154 | mock_2 = MockScoringFunction(mock_values_2) 155 | 156 | scoring_function = GeometricMeanScoringFunction(scoring_functions=[mock_1, mock_2]) 157 | 158 | smiles = ['CC'] * 5 159 | 160 | scores = scoring_function.score_list(smiles) 161 | expected = [sqrt(v1 * v2) for v1, v2 in zip(mock_values_1, mock_values_2)] 162 | 163 | assert scores == expected 164 | -------------------------------------------------------------------------------- /tests/utils/test_chemistry.py: -------------------------------------------------------------------------------- 1 | from guacamol.utils.chemistry import canonicalize, canonicalize_list, is_valid, \ 2 | calculate_internal_pairwise_similarities, calculate_pairwise_similarities, parse_molecular_formula 3 | 4 | 5 | def test_validity_empty_molecule(): 6 | smiles = '' 7 | assert not is_valid(smiles) 8 | 9 | 10 | def test_validity_incorrect_syntax(): 11 | smiles = 'CCCincorrectsyntaxCCC' 12 | assert not is_valid(smiles) 13 | 14 | 15 | def test_validity_incorrect_valence(): 16 | smiles = 'CCC(CC)(CC)(=O)CCC' 17 | assert not is_valid(smiles) 18 | 19 | 20 | def test_validity_correct_molecules(): 21 | smiles_1 = 'O' 22 | smiles_2 = 'C' 23 | smiles_3 = 'CC(ONONOC)CCCc1ccccc1' 24 | 25 | assert is_valid(smiles_1) 26 | assert is_valid(smiles_2) 27 | assert is_valid(smiles_3) 28 | 29 | 30 | def test_isomeric_canonicalisation(): 31 | endiandric_acid = r'OC(=O)[C@H]5C2\C=C/C3[C@@H]5CC4[C@H](C\C=C\C=C\c1ccccc1)[C@@H]2[C@@H]34' 32 | 33 | with_stereocenters = canonicalize(endiandric_acid, include_stereocenters=True) 34 | without_stereocenters = canonicalize(endiandric_acid, include_stereocenters=False) 35 | 36 | expected_with_stereocenters = 'O=C(O)[C@H]1C2C=CC3[C@@H]1CC1[C@H](C/C=C/C=C/c4ccccc4)[C@@H]2[C@@H]31' 37 | expected_without_stereocenters = 
'O=C(O)C1C2C=CC3C1CC1C(CC=CC=Cc4ccccc4)C2C31' 38 | 39 | assert with_stereocenters == expected_with_stereocenters 40 | assert without_stereocenters == expected_without_stereocenters 41 | 42 | 43 | def test_list_canonicalization_removes_none(): 44 | m1 = 'CCC(OCOCO)CC(=O)NCC' 45 | m2 = 'this.is.not.a.molecule' 46 | m3 = 'c1ccccc1' 47 | m4 = 'CC(OCON=N)CC' 48 | 49 | molecules = [m1, m2, m3, m4] 50 | canonicalized_molecules = canonicalize_list(molecules) 51 | 52 | valid_molecules = [m1, m3, m4] 53 | expected = [canonicalize(smiles) for smiles in valid_molecules] 54 | 55 | assert canonicalized_molecules == expected 56 | 57 | 58 | def test_internal_sim(): 59 | molz = ['OCCCF', 'c1cc(F)ccc1', 'c1cnc(CO)cc1', 'FOOF'] 60 | sim = calculate_internal_pairwise_similarities(molz) 61 | 62 | assert sim.shape[0] == 4 63 | assert sim.shape[1] == 4 64 | # check elements 65 | for i in range(sim.shape[0]): 66 | for j in range(sim.shape[1]): 67 | assert sim[i, j] == sim[j, i] 68 | if i != j: 69 | assert sim[i, j] < 1.0 70 | else: 71 | assert sim[i, j] == 0 72 | 73 | 74 | def test_external_sim(): 75 | molz1 = ['OCCCF', 'c1cc(F)ccc1', 'c1cnc(CO)cc1', 'FOOF'] 76 | molz2 = ['c1cc(Cl)ccc1', '[Cr][Ac][K]', '[Ca](F)[Fe]'] 77 | sim = calculate_pairwise_similarities(molz1, molz2) 78 | 79 | assert sim.shape[0] == 4 80 | assert sim.shape[1] == 3 81 | # check elements 82 | for i in range(sim.shape[0]): 83 | for j in range(sim.shape[1]): 84 | assert sim[i, j] < 1.0 85 | 86 | 87 | def test_parse_molecular_formula(): 88 | formula = 'C6H9NOF2Cl2Br' 89 | parsed = parse_molecular_formula(formula) 90 | 91 | expected = [ 92 | ('C', 6), 93 | ('H', 9), 94 | ('N', 1), 95 | ('O', 1), 96 | ('F', 2), 97 | ('Cl', 2), 98 | ('Br', 1) 99 | ] 100 | 101 | assert parsed == expected 102 | -------------------------------------------------------------------------------- /tests/utils/test_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 
| from guacamol.utils.data import get_random_subset 5 | 6 | 7 | def test_subset(): 8 | dataset = list(np.random.rand(100)) 9 | 10 | subset = get_random_subset(dataset, 10) 11 | 12 | for s in subset: 13 | assert s in dataset 14 | 15 | 16 | def test_subset_if_dataset_too_small(): 17 | dataset = list(np.random.rand(100)) 18 | 19 | with pytest.raises(Exception): 20 | get_random_subset(dataset, 1000) 21 | 22 | 23 | def test_subset_with_no_seed(): 24 | dataset = list(np.random.rand(100)) 25 | 26 | subset1 = get_random_subset(dataset, 10) 27 | subset2 = get_random_subset(dataset, 10) 28 | 29 | assert subset1 != subset2 30 | 31 | 32 | def test_subset_with_random_seed(): 33 | dataset = list(np.random.rand(100)) 34 | 35 | subset1 = get_random_subset(dataset, 10, seed=33) 36 | subset2 = get_random_subset(dataset, 10, seed=33) 37 | subset3 = get_random_subset(dataset, 10, seed=43) 38 | 39 | assert subset1 == subset2 40 | assert subset1 != subset3 41 | -------------------------------------------------------------------------------- /tests/utils/test_descriptors.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | 3 | from guacamol.utils.descriptors import num_atoms, AtomCounter 4 | 5 | 6 | def test_num_atoms(): 7 | smiles = 'CCOC(CCC)' 8 | mol = Chem.MolFromSmiles(smiles) 9 | assert num_atoms(mol) == 21 10 | 11 | 12 | def test_num_atoms_does_not_change_mol_instance(): 13 | smiles = 'CCOC(CCC)' 14 | mol = Chem.MolFromSmiles(smiles) 15 | 16 | assert mol.GetNumAtoms() == 7 17 | num_atoms(mol) 18 | assert mol.GetNumAtoms() == 7 19 | 20 | 21 | def test_count_c_atoms(): 22 | smiles = 'CCOC(CCC)' 23 | mol = Chem.MolFromSmiles(smiles) 24 | assert AtomCounter('C')(mol) == 6 25 | 26 | 27 | def test_count_h_atoms(): 28 | smiles = 'CCOC(CCC)' 29 | mol = Chem.MolFromSmiles(smiles) 30 | assert AtomCounter('H')(mol) == 14 31 | 32 | 33 | def test_count_h_atoms_does_not_change_mol_instance(): 34 | smiles = 'CCOC(CCC)' 35 | mol = 
Chem.MolFromSmiles(smiles) 36 | 37 | assert mol.GetNumAtoms() == 7 38 | AtomCounter('H')(mol) 39 | assert mol.GetNumAtoms() == 7 40 | --------------------------------------------------------------------------------