├── .github └── workflows │ ├── docs.yml │ ├── publish_to_pypi.yml │ └── tests.yml ├── .gitignore ├── CHANGELOG.md ├── CITATION.cff ├── LICENSE ├── README.md ├── aizynthfinder ├── __init__.py ├── aizynthfinder.py ├── analysis │ ├── __init__.py │ ├── routes.py │ ├── tree_analysis.py │ └── utils.py ├── chem │ ├── __init__.py │ ├── mol.py │ ├── reaction.py │ └── serialization.py ├── context │ ├── __init__.py │ ├── collection.py │ ├── config.py │ ├── policy │ │ ├── __init__.py │ │ ├── expansion_strategies.py │ │ ├── filter_strategies.py │ │ ├── policies.py │ │ └── utils.py │ ├── scoring │ │ ├── __init__.py │ │ ├── collection.py │ │ └── scorers.py │ └── stock │ │ ├── __init__.py │ │ ├── queries.py │ │ └── stock.py ├── data │ ├── default_training.yml │ ├── logging.yml │ └── templates │ │ ├── reaction_tree.dot │ │ └── reaction_tree.thtml ├── interfaces │ ├── __init__.py │ ├── aizynthapp.py │ ├── aizynthcli.py │ └── gui │ │ ├── __init__.py │ │ ├── clustering.py │ │ ├── pareto_fronts.py │ │ └── utils.py ├── reactiontree.py ├── search │ ├── __init__.py │ ├── andor_trees.py │ ├── breadth_first │ │ ├── __init__.py │ │ ├── nodes.py │ │ └── search_tree.py │ ├── dfpn │ │ ├── __init__.py │ │ ├── nodes.py │ │ └── search_tree.py │ ├── mcts │ │ ├── __init__.py │ │ ├── node.py │ │ ├── search.py │ │ ├── state.py │ │ └── utils.py │ └── retrostar │ │ ├── __init__.py │ │ ├── cost.py │ │ ├── nodes.py │ │ └── search_tree.py ├── tools │ ├── __init__.py │ ├── cat_output.py │ ├── download_public_data.py │ └── make_stock.py └── utils │ ├── __init__.py │ ├── bonds.py │ ├── exceptions.py │ ├── files.py │ ├── image.py │ ├── loading.py │ ├── logging.py │ ├── math.py │ ├── models.py │ ├── mongo.py │ ├── paths.py │ ├── sc_score.py │ └── type_utils.py ├── contrib └── notebook.ipynb ├── docs ├── analysis-rel.png ├── analysis-seq.png ├── cli.rst ├── conf.py ├── configuration.rst ├── gui.rst ├── gui_clustering.png ├── gui_input.png ├── gui_results.png ├── howto.rst ├── index.rst ├── line-desc.png ├── python_interface.rst ├── relationships.rst ├── scoring.rst ├── sequences.rst ├── stocks.rst ├── treesearch-rel.png └── treesearch-seq.png ├── env-dev.yml ├── plugins ├── README.md └── expansion_strategies.py ├── poetry.lock ├── pyproject.toml ├── tasks.py └── tests ├── __init__.py ├── breadth_first ├── __init__.py ├── test_nodes.py └── test_search.py ├── chem ├── __init__.py ├── test_mol.py ├── test_reaction.py └── test_serialization.py ├── conftest.py ├── context ├── __init__.py ├── conftest.py ├── data │ ├── custom_loader.py │ ├── custom_loader2.py │ ├── linear_route_w_metadata.json │ ├── simple_filter.bloom │ └── test_reactions_template.csv ├── test_collection.py ├── test_expansion_strategies.py ├── test_mcts_config.py ├── test_policy.py ├── test_score.py └── test_stock.py ├── data ├── and_or_tree.json ├── branched_route.json ├── combined_example_tree.json ├── combined_example_tree2.json ├── dummy2_raw_template_library.csv ├── dummy_noclass_raw_template_library.csv ├── dummy_raw_template_library.csv ├── dummy_sani_raw_template_library.csv ├── full_search_tree.json.gz ├── input_checkpoint.json.gz ├── linear_route.json ├── post_processing_test.py ├── pre_processing_test.py ├── routes_for_clustering.json ├── test_reactions_template.csv └── tree_for_clustering.json ├── dfpn ├── __init__.py ├── test_nodes.py └── test_search.py ├── mcts ├── __init__.py ├── conftest.py ├── test_multiobjective.py ├── test_node.py ├── test_reward.py ├── test_serialization.py └── test_tree.py ├── retrostar ├── __init__.py ├── conftest.py ├── data │ └── andor_tree_for_clustering.json ├── test_retrostar.py ├── test_retrostar_cost.py └── test_retrostar_nodes.py ├── test_analysis.py ├── test_cli.py ├── test_expander.py ├── test_finder.py ├── test_gui.py ├── test_reactiontree.py └── utils ├── __init__.py ├── test_bonds.py ├── test_dynamic_loading.py ├── test_external_tf_models.py ├── test_file_utils.py ├── test_image.py ├── test_local_onnx_model.py └── test_scscore.py /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: build 13 | run: | 14 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh | 15 | bash -s -- --batch 16 | conda env create -f env-dev.yml 17 | conda run --name aizynth-dev poetry install -E all 18 | conda run --name aizynth-dev inv build-docs 19 | - name: deploy 20 | uses: peaceiris/actions-gh-pages@v3 21 | with: 22 | github_token: ${{ secrets.GITHUB_TOKEN }} 23 | publish_dir: ./docs/build/html 24 | publish_branch: gh-pages 25 | force_orphan: true 26 | -------------------------------------------------------------------------------- /.github/workflows/publish_to_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python distributions to PyPI 2 | 3 | on: push 4 | 5 | jobs: 6 | build-n-publish: 7 | name: Build and publish Python distributions to PyPI 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: Set up Python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: "3.x" 16 | - name: Install pypa/build 17 | run: >- 18 | python3 -m 19 | pip install 20 | build 21 | --user 22 | - name: Build a binary wheel and a source tarball 23 | run: >- 24 | python3 -m 25 | build 26 | --sdist 27 | --wheel 28 | --outdir dist/ 29 | . 30 | - name: Publish distribution to PyPI 31 | if: startsWith(github.ref, 'refs/tags') 32 | uses: pypa/gh-action-pypi-publish@release/v1 33 | with: 34 | password: ${{ secrets.PYPI_API_TOKEN }} 35 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | - name: Run 15 | run: | 16 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh | 17 | bash -s -- --batch 18 | conda env create -f env-dev.yml 19 | conda run --name aizynth-dev poetry install -E all -E tf 20 | conda run --name aizynth-dev inv full-tests 21 | - name: Upload coverage to Codecov 22 | uses: codecov/codecov-action@v1 23 | with: 24 | token: ${{ secrets.CODECOV_TOKEN }} 25 | files: ./coverage.xml 26 | directory: ./coverage/ 27 | name: codecov-aizynth 28 | fail_ci_if_error: false 29 | verbose: true 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/build/ 73 | docs/aizynth* 74 | docs/modules.rst 75 | docs/cli_help.txt 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # PyCharmm idea 135 | .idea/ 136 | 137 | #VS Code 138 | .vscode/ 139 | 140 | # Pytest coverage output 141 | coverage/ 142 | 143 | # Files created by the tool 144 | smiles.txt 145 | config_local.yml 146 | devtools/seldon-output-*/ 147 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # YAML 1.2 2 | --- 3 | abstract: "We present the open-source AiZynthFinder software that can be readily used in retrosynthetic planning. The algorithm is based on a Monte Carlo tree search that recursively breaks down a molecule to purchasable precursors. The tree search is guided by an artificial neural network policy that suggests possible precursors by utilizing a library of known reaction templates. The software is fast and can typically find a solution in less than 10 s and perform a complete search in less than 1 min. Moreover, the development of the code was guided by a range of software engineering principles such as automatic testing, system design and continuous integration leading to robust software with high maintainability. Finally, the software is well documented to make it suitable for beginners. The software is available at http://www.github.com/MolecularAI/aizynthfinder." 4 | authors: 5 | - 6 | family-names: Genheden 7 | given-names: Samuel 8 | - 9 | family-names: Thakkar 10 | given-names: Amol 11 | - 12 | family-names: "Chadimová" 13 | given-names: Veronika 14 | - 15 | family-names: Reymond 16 | given-names: "Jean-Louis" 17 | - 18 | family-names: Engkvist 19 | given-names: Ola 20 | - 21 | family-names: Bjerrum 22 | given-names: Esben 23 | orcid: "https://orcid.org/0000-0003-1614-7376" 24 | cff-version: "1.1.0" 25 | date-released: 2020-12-08 26 | doi: "https://doi.org/10.1186/s13321-020-00472-1" 27 | identifiers: 28 | - 29 | type: doi 30 | value: "10.1186/s13321-020-00472-1" 31 | keywords: 32 | - retrosynthesis 33 | - casp 34 | - retrosynthesis 35 | - cheminformatics 36 | - "neural-networks" 37 | - "monte-carlo-tree-search" 38 | - "chemical-reactions" 39 | - astrazeneca 40 | - "reaction-informatics" 41 | license: MIT 42 | message: "If you use this software, please cite it using these metadata." 43 | repository-code: "https://github.com/MolecularAI/aizynthfinder" 44 | title: AiZynthFinder 45 | version: "2.2.1" 46 | ... 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020 Samuel Genheden and Esben Bjerrum 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /aizynthfinder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/aizynthfinder/__init__.py -------------------------------------------------------------------------------- /aizynthfinder/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | """ Sub-package containing analysis routines 2 | """ 3 | from aizynthfinder.analysis.tree_analysis import TreeAnalysis # isort: skip 4 | from aizynthfinder.analysis.routes import RouteCollection 5 | from aizynthfinder.analysis.utils import RouteSelectionArguments 6 | -------------------------------------------------------------------------------- /aizynthfinder/chem/__init__.py: -------------------------------------------------------------------------------- 1 | """ Sub-package containing chemistry routines 2 | """ 3 | from aizynthfinder.chem.mol import ( 4 | Molecule, 5 | MoleculeException, 6 | TreeMolecule, 7 | UniqueMolecule, 8 | none_molecule, 9 | ) 10 | from aizynthfinder.chem.reaction import ( 11 | FixedRetroReaction, 12 | RetroReaction, 13 | SmilesBasedRetroReaction, 14 | TemplatedRetroReaction, 15 | hash_reactions, 16 | ) 17 | from aizynthfinder.chem.serialization import ( 18 | MoleculeDeserializer, 19 | MoleculeSerializer, 20 | deserialize_action, 21 | serialize_action, 22 | ) 23 | -------------------------------------------------------------------------------- /aizynthfinder/chem/serialization.py: -------------------------------------------------------------------------------- 1 | """ Module containing helper classes and routines for serialization. 2 | """ 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | import aizynthfinder.chem 8 | from aizynthfinder.utils.loading import load_dynamic_class 9 | 10 | if TYPE_CHECKING: 11 | from aizynthfinder.chem import RetroReaction 12 | from aizynthfinder.utils.type_utils import Any, Dict, Optional, Sequence, StrDict 13 | 14 | 15 | class MoleculeSerializer: 16 | """ 17 | Utility class for serializing molecules 18 | 19 | The id of the molecule to be serialized can be obtained with: 20 | 21 | .. code-block:: 22 | 23 | serializer = MoleculeSerializer() 24 | mol = Molecule(smiles="CCCO") 25 | idx = serializer[mol] 26 | 27 | which will take care of the serialization of the molecule. 28 | """ 29 | 30 | def __init__(self) -> None: 31 | self._store: Dict[int, Any] = {} 32 | 33 | def __getitem__(self, mol: Optional[aizynthfinder.chem.Molecule]) -> Optional[int]: 34 | if mol is None: 35 | return None 36 | 37 | id_ = id(mol) 38 | if id_ not in self._store: 39 | self._add_mol(mol) 40 | return id_ 41 | 42 | @property 43 | def store(self) -> Dict[int, Any]: 44 | """Return all serialized molecules as a dictionary""" 45 | return self._store 46 | 47 | def _add_mol(self, mol: aizynthfinder.chem.Molecule) -> None: 48 | id_ = id(mol) 49 | dict_ = {"smiles": mol.smiles, "class": mol.__class__.__name__} 50 | if isinstance(mol, aizynthfinder.chem.TreeMolecule): 51 | dict_["parent"] = self[mol.parent] 52 | dict_["transform"] = mol.transform 53 | if not mol.parent: 54 | dict_["smiles"] = mol.original_smiles 55 | else: 56 | dict_["smiles"] = mol.mapped_smiles 57 | self._store[id_] = dict_ 58 | 59 | 60 | class MoleculeDeserializer: 61 | """ 62 | Utility class for deserializing molecules. 63 | The serialized molecules are created upon instantiation of the class. 64 | 65 | The deserialized molecules can be obtained with: 66 | 67 | .. code-block:: 68 | 69 | deserializer = MoleculeDeserializer() 70 | mol = deserializer[idx] 71 | 72 | """ 73 | 74 | def __init__(self, store: Dict[int, Any]) -> None: 75 | self._objects: Dict[int, Any] = {} 76 | self._create_molecules(store) 77 | 78 | def __getitem__(self, id_: Optional[int]) -> Optional[aizynthfinder.chem.Molecule]: 79 | if id_ is None: 80 | return None 81 | return self._objects[id_] 82 | 83 | def get_tree_molecules( 84 | self, ids: Sequence[int] 85 | ) -> Sequence[aizynthfinder.chem.TreeMolecule]: 86 | """ 87 | Return multiple deserialized tree molecules 88 | 89 | :param ids: the list of IDs to deserialize 90 | :return: the molecule objects 91 | """ 92 | objects = [] 93 | for id_ in ids: 94 | obj = self[id_] 95 | if obj is None or not isinstance(obj, aizynthfinder.chem.TreeMolecule): 96 | raise ValueError(f"Failed to deserialize molecule with id {id_}") 97 | objects.append(obj) 98 | return objects 99 | 100 | def _create_molecules(self, store: dict) -> None: 101 | for id_, spec in store.items(): 102 | if isinstance(id_, str): 103 | id_ = int(id_) 104 | 105 | cls = spec["class"] 106 | if "parent" in spec: 107 | spec["parent"] = self[spec["parent"]] 108 | 109 | kwargs = dict(spec) 110 | del kwargs["class"] 111 | self._objects[id_] = getattr(aizynthfinder.chem, cls)(**kwargs) 112 | 113 | 114 | def serialize_action( 115 | action: RetroReaction, molecule_store: MoleculeSerializer 116 | ) -> StrDict: 117 | """ 118 | Serialize a retrosynthesis action 119 | 120 | :param action: the (re)action to serialize 121 | :param molecule_store: the molecule serialization object 122 | :return: the action as a dictionary 123 | """ 124 | dict_ = action.to_dict() 125 | dict_["mol"] = molecule_store[dict_["mol"]] 126 | if not action.unqueried: 127 | dict_["reactants"] = [ 128 | [molecule_store[item] for item in lst_] for lst_ in action.reactants 129 | ] 130 | dict_["class"] = f"{action.__class__.__module__}.{action.__class__.__name__}" 131 | return dict_ 132 | 133 | 134 | def deserialize_action( 135 | dict_: StrDict, molecule_store: MoleculeDeserializer 136 | ) -> RetroReaction: 137 | """ 138 | Deserialize a retrosynthesis action 139 | 140 | :param dict_: the (re)action as a dictionary 141 | :param molecule_store: the molecule deserialization object 142 | :return: the created action object 143 | """ 144 | mol_spec = dict_.pop("mol") 145 | dict_["mol"] = molecule_store.get_tree_molecules([mol_spec])[0] 146 | try: 147 | class_spec = dict_.pop("class") 148 | except KeyError: 149 | class_spec = "aizynthfinder.chem.TemplatedRetroReaction" 150 | cls = load_dynamic_class(class_spec) 151 | if "reactants" in dict_: 152 | reactants = [ 153 | molecule_store.get_tree_molecules(lst_) for lst_ in dict_.pop("reactants") 154 | ] 155 | return cls.from_serialization(dict_, reactants) 156 | return cls(**dict_) 157 | -------------------------------------------------------------------------------- /aizynthfinder/context/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/aizynthfinder/context/__init__.py -------------------------------------------------------------------------------- /aizynthfinder/context/collection.py: -------------------------------------------------------------------------------- 1 | """ Module containing a class that is the base class for all collection classes (stock, policies, scorers) 2 | """ 3 | from __future__ import annotations 4 | 5 | import abc 6 | from typing import TYPE_CHECKING 7 | 8 | from aizynthfinder.utils.logging import logger 9 | 10 | if TYPE_CHECKING: 11 | from aizynthfinder.utils.type_utils import Any, List, Optional, StrDict, Union 12 | 13 | 14 | class ContextCollection(abc.ABC): 15 | """ 16 | Abstract base class for a collection of items 17 | that can be loaded and then (de-)selected. 18 | 19 | One can obtain individual items with: 20 | 21 | .. code-block:: 22 | 23 | an_item = collection["key"] 24 | 25 | And delete items with 26 | 27 | .. code-block:: 28 | 29 | del collection["key"] 30 | 31 | 32 | """ 33 | 34 | _single_selection = False 35 | _collection_name = "collection" 36 | 37 | def __init__(self) -> None: 38 | self._items: StrDict = {} 39 | self._selection: List[str] = [] 40 | self._logger = logger() 41 | 42 | def __delitem__(self, key: str) -> None: 43 | if key not in self._items: 44 | raise KeyError( 45 | f"{self._collection_name.capitalize()} with name {key} not loaded." 46 | ) 47 | del self._items[key] 48 | 49 | def __getitem__(self, key: str) -> Any: 50 | if key not in self._items: 51 | raise KeyError( 52 | f"{self._collection_name.capitalize()} with name {key} not loaded." 53 | ) 54 | return self._items[key] 55 | 56 | def __len__(self) -> int: 57 | return len(self._items) 58 | 59 | @property 60 | def items(self) -> List[str]: 61 | """The available item keys""" 62 | return list(self._items.keys()) 63 | 64 | @property 65 | def selection(self) -> Union[List[str], str, None]: 66 | """The keys of the selected item(s)""" 67 | if self._single_selection: 68 | return self._selection[0] if self._selection else None 69 | return self._selection 70 | 71 | @selection.setter 72 | def selection(self, value: str) -> None: 73 | self.select(value) 74 | 75 | def deselect(self, key: Optional[str] = None) -> None: 76 | """ 77 | Deselect one or all items 78 | 79 | If no key is passed, all items will be deselected. 80 | 81 | :param key: the key of the item to deselect, defaults to None 82 | :raises KeyError: if the key is not among the selected ones 83 | """ 84 | if not key: 85 | self._selection = [] 86 | return 87 | 88 | if key not in self._selection: 89 | raise KeyError(f"Cannot deselect {key} because it is not selected") 90 | self._selection.remove(key) 91 | 92 | @abc.abstractmethod 93 | def load(self, *_: Any) -> None: 94 | """Load an item. Needs to be implemented by a sub-class""" 95 | 96 | @abc.abstractmethod 97 | def load_from_config(self, **config: Any) -> None: 98 | """Load items from a configuration. Needs to be implemented by a sub-class""" 99 | 100 | def select(self, value: Union[str, List[str]], append: bool = False) -> None: 101 | """ 102 | Select one or more items. 103 | 104 | If this is a single selection collection, only a single value is accepted. 105 | If this is a multiple selection collection it will overwrite the selection completely, 106 | unless ``append`` is True and a single key is given. 107 | 108 | :param value: the key or keys of the item(s) to select 109 | :param append: if True will append single keys to existing selection 110 | :raises ValueError: if this a single collection and value is multiple keys 111 | :raises KeyError: if at least one of the keys are not corresponding to a loaded item 112 | """ 113 | if self._single_selection and not isinstance(value, str) and len(value) > 1: 114 | raise ValueError(f"Cannot select more than one {self._collection_name}") 115 | 116 | keys = [value] if isinstance(value, str) else value 117 | 118 | for key in keys: 119 | if key not in self._items: 120 | raise KeyError( 121 | f"Invalid key specified {key} when selecting {self._collection_name}" 122 | ) 123 | 124 | if self._single_selection: 125 | self._selection = [keys[0]] 126 | elif isinstance(value, str) and append: 127 | self._selection.append(value) 128 | else: 129 | self._selection = list(keys) 130 | 131 | self._logger.info(f"Selected as {self._collection_name}: {', '.join(keys)}") 132 | 133 | def select_all(self) -> None: 134 | """Select all loaded items""" 135 | if self.items: 136 | self.select(self.items) 137 | 138 | def select_first(self) -> None: 139 | """Select the first loaded item""" 140 | if self.items: 141 | self.select(self.items[0]) 142 | 143 | def select_last(self) -> None: 144 | """Select the last loaded item""" 145 | if self.items: 146 | self.select(self.items[-1]) 147 | -------------------------------------------------------------------------------- /aizynthfinder/context/policy/__init__.py: -------------------------------------------------------------------------------- 1 | """ Sub-package containing policy routines 2 | """ 3 | 4 | from aizynthfinder.context.policy.expansion_strategies import ( 5 | ExpansionStrategy, 6 | MultiExpansionStrategy, 7 | TemplateBasedDirectExpansionStrategy, 8 | TemplateBasedExpansionStrategy, 9 | ) 10 | from aizynthfinder.context.policy.filter_strategies import ( 11 | BondFilter, 12 | FilterStrategy, 13 | QuickKerasFilter, 14 | ReactantsCountFilter, 15 | ) 16 | from aizynthfinder.context.policy.policies import ExpansionPolicy, FilterPolicy 17 | from aizynthfinder.utils.exceptions import PolicyException 18 | -------------------------------------------------------------------------------- /aizynthfinder/context/policy/utils.py: -------------------------------------------------------------------------------- 1 | """ Module containing helper routines for policies 2 | """ 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | import numpy as np 8 | 9 | if TYPE_CHECKING: 10 | from aizynthfinder.chem import TreeMolecule 11 | from aizynthfinder.chem.reaction import RetroReaction 12 | from aizynthfinder.utils.type_utils import Any, Union 13 | 14 | 15 | def _make_fingerprint( 16 | obj: Union[TreeMolecule, RetroReaction], model: Any, chiral: bool = False 17 | ) -> np.ndarray: 18 | fingerprint = obj.fingerprint(radius=2, nbits=len(model), chiral=chiral) 19 | return fingerprint.reshape([1, len(model)]) 20 | -------------------------------------------------------------------------------- /aizynthfinder/context/scoring/__init__.py: -------------------------------------------------------------------------------- 1 | """ Sub-package containing scoring routines 2 | """ 3 | 4 | from aizynthfinder.context.scoring.collection import ScorerCollection 5 | from aizynthfinder.context.scoring.scorers import ( 6 | AverageTemplateOccurrenceScorer, 7 | BrokenBondsScorer, 8 | CombinedScorer, 9 | DeltaSyntheticComplexityScorer, 10 | FractionInStockScorer, 11 | MaxTransformScorerer, 12 | NumberOfPrecursorsInStockScorer, 13 | NumberOfPrecursorsScorer, 14 | NumberOfReactionsScorer, 15 | PriceSumScorer, 16 | ReactionClassMembershipScorer, 17 | RouteCostScorer, 18 | RouteSimilarityScorer, 19 | Scorer, 20 | StateScorer, 21 | StockAvailabilityScorer, 22 | SUPPORT_DISTANCES, 23 | ) 24 | from aizynthfinder.utils.exceptions import ScorerException 25 | -------------------------------------------------------------------------------- /aizynthfinder/context/stock/__init__.py: -------------------------------------------------------------------------------- 1 | """ Sub-package containing stock routines 2 | """ 3 | from aizynthfinder.context.stock.queries import ( 4 | InMemoryInchiKeyQuery, 5 | MongoDbInchiKeyQuery, 6 | StockQueryMixin, 7 | ) 8 | from aizynthfinder.context.stock.stock import Stock 9 | from aizynthfinder.utils.exceptions import StockException 10 | -------------------------------------------------------------------------------- /aizynthfinder/data/default_training.yml: -------------------------------------------------------------------------------- 1 | library_headers: ["index", "ID", "reaction_hash", "reactants", "products", "classification", "retro_template", "template_hash", "selectivity", "outcomes", "template_code"] 2 | column_map: 3 | reaction_hash: reaction_hash 4 | reactants: reactants 5 | products: products 6 | retro_template: retro_template 7 | template_hash: template_hash 8 | metadata_headers: ["template_hash", "classification"] 9 | in_csv_headers: False 10 | csv_sep: "," 11 | reaction_smiles_column: "" 12 | output_path: "." 13 | file_prefix: "" 14 | file_postfix: 15 | raw_library: _raw_template_library.csv 16 | library: _template_library.csv 17 | false_library: _template_library_false.csv 18 | training_labels: _training_labels.npz 19 | validation_labels: _validation_labels.npz 20 | testing_labels: _testing_labels.npz 21 | training_inputs: _training_inputs.npz 22 | validation_inputs: _validation_inputs.npz 23 | testing_inputs: _testing_inputs.npz 24 | training_inputs2: _training_inputs2.npz 25 | validation_inputs2: _validation_inputs2.npz 26 | testing_inputs2: _testing_inputs2.npz 27 | training_library: _training.csv 28 | validation_library: _validation.csv 29 | testing_library: _testing.csv 30 | unique_templates: _unique_templates.hdf5 31 | split_size: 32 | training: 0.9 33 | testing: 0.05 34 | validation: 0.05 35 | batch_size: 256 36 | epochs: 100 37 | fingerprint_radius: 2 38 | fingerprint_len: 2048 39 | template_occurrence: 3 40 | remove_unsanitizable_products: False 41 | negative_data: 42 | random_trials: 1000 43 | recommender_model: "" 44 | recommender_topn: 20 45 | model: 46 | drop_out: 0.4 47 | hidden_nodes: 512 48 | -------------------------------------------------------------------------------- /aizynthfinder/data/logging.yml: -------------------------------------------------------------------------------- 1 | version: 1 2 | formatters: 3 | file: 4 | format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 5 | console: 6 | format: '%(message)s' 7 | handlers: 8 | console: 9 | class: logging.StreamHandler 10 | level: INFO 11 | formatter: console 12 | stream: ext://sys.stdout 13 | file: 14 | class: logging.FileHandler 15 | level: DEBUG 16 | formatter: file 17 | filename: aizynthfinder.log 18 | loggers: 19 | aizynthfinder: 20 | level: DEBUG 21 | handlers: 22 | - console 23 | - file 24 | root: 25 | level: ERROR 26 | handlers: [] 27 | -------------------------------------------------------------------------------- /aizynthfinder/data/templates/reaction_tree.dot: -------------------------------------------------------------------------------- 1 | strict digraph "" { 2 | graph [layout="dot", 3 | rankdir="RL", 4 | {% if use_splines %} 5 | splines="ortho" 6 | {% endif %} 7 | ]; 8 | node [label="\N"]; 9 | {% for molecule, image_filepath in molecules.items() %} 10 | {{ id(molecule) }} [ 11 | label="", 12 | color="white", 13 | shape="none", 14 | image="{{ image_filepath }}" 15 | ]; 16 | {% endfor %} 17 | {% for reaction, reaction_shape in reactions %} 18 | {{ id(reaction) }} [ 19 | label="", 20 | fillcolor="black", 21 | shape="{{ reaction_shape }}", 22 | style="filled", 23 | width="0.1", 24 | fixedsize="true" 25 | ]; 26 | {% endfor %} 27 | {% for nodes in edges %} 28 | {{ id(nodes[0]) }} -> {{ id(nodes[1]) }} [arrowhead="none"]; 29 | {% endfor %} 30 | } -------------------------------------------------------------------------------- /aizynthfinder/data/templates/reaction_tree.thtml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /aizynthfinder/interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | """ Module for interfaces to the AiZynthFinder application 2 | """ 3 | try: 4 | from aizynthfinder.interfaces.aizynthapp import AiZynthApp # noqa 5 | except ModuleNotFoundError: 6 | pass 7 | -------------------------------------------------------------------------------- /aizynthfinder/interfaces/gui/__init__.py: -------------------------------------------------------------------------------- 1 | """ Package for GUI extensions 2 | """ 3 | -------------------------------------------------------------------------------- /aizynthfinder/interfaces/gui/clustering.py: -------------------------------------------------------------------------------- 1 | """ Module containing a GUI extension for clustering 2 | """ 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | import numpy as np 8 | from IPython.display import display 9 | from ipywidgets import BoundedIntText, Button, HBox, Label, Output, Tab 10 | 11 | try: 12 | import matplotlib.pylab as plt 13 | from route_distances.clustering import ClusteringHelper 14 | from scipy.cluster.hierarchy import dendrogram 15 | except ImportError: 16 | raise ImportError( 17 | "Clustering is not supported by this installation." 18 | " Please install aizynthfinder with extras dependencies." 19 | ) 20 | 21 | 22 | if TYPE_CHECKING: 23 | from aizynthfinder.analysis import RouteCollection 24 | from aizynthfinder.interfaces.aizynthapp import AiZynthApp 25 | from aizynthfinder.utils.type_utils import StrDict 26 | 27 | 28 | class ClusteringGui: 29 | """ 30 | GUI extension to cluster routes 31 | 32 | :param routes: the routes to cluster 33 | :param content: what to cluster on 34 | """ 35 | 36 | def __init__(self, routes: RouteCollection, content: str = "both"): 37 | self._routes = routes 38 | self._routes.distance_matrix(content=content, recreate=True) 39 | self._input: StrDict = dict() 40 | self._output: StrDict = dict() 41 | self._buttons: StrDict = dict() 42 | self._create_dendrogram() 43 | self._create_input() 44 | self._create_output() 45 | 46 | @classmethod 47 | def from_app(cls, app: AiZynthApp, content: str = "both"): 48 | """ 49 | Helper function to create a GUI from a GUI app interface 50 | 51 | :param app: the app to extract the routes from 52 | :param content: what to cluster on 53 | :return: the GUI object 54 | """ 55 | if app.finder.routes is None: 56 | raise ValueError("Cannot initialize GUI from no routes") 57 | return ClusteringGui(app.finder.routes, content) 58 | 59 | def _create_dendrogram(self) -> None: 60 | dend_out = Output( 61 | layout={"width": "99%", "height": "310px", "overflow_y": "auto"} 62 | ) 63 | with dend_out: 64 | print("This is the hierarchy of the routes") 65 | fig = plt.Figure() 66 | dendrogram( 67 | ClusteringHelper(self._routes.distance_matrix()).linkage_matrix(), 68 | color_threshold=0.0, 69 | labels=np.arange(1, len(self._routes) + 1), 70 | ax=fig.gca(), 71 | ) 72 | fig.gca().set_xlabel("Route") 73 | fig.gca().set_ylabel("Distance") 74 | display(fig) 75 | display(dend_out) 76 | 77 | def _create_input(self) -> None: 78 | self._input["number_clusters"] = BoundedIntText( 79 | continuous_update=True, 80 | min=1, 81 | max=len(self._routes) - 1, 82 | layout={"width": "80px"}, 83 | ) 84 | self._buttons["cluster"] = Button(description="Cluster") 85 | self._buttons["cluster"].on_click(self._on_cluster_button_clicked) 86 | box = HBox( 87 | [ 88 | Label("Number of clusters to make"), 89 | self._input["number_clusters"], 90 | self._buttons["cluster"], 91 | ] 92 | ) 93 | display(box) 94 | help_out = Output() 95 | with help_out: 96 | print( 97 | "Optimization is carried out if the number of given clusters are less than 2" 98 | ) 99 | display(help_out) 100 | 101 | def _create_output(self) -> None: 102 | self._output["clusters"] = Tab() 103 | display(self._output["clusters"]) 104 | 105 | def _on_cluster_button_clicked(self, _) -> None: 106 | self._buttons["cluster"].enabled = False 107 | self._routes.cluster(self._input["number_clusters"].value) 108 | self._buttons["cluster"].enabled = True 109 | 110 | outputs = [] 111 | for i, cluster in enumerate(self._routes.clusters or []): 112 | output = Output( 113 | layout={ 114 | "border": "1px solid silver", 115 | "width": "99%", 116 | "height": "500px", 117 | "overflow_y": "auto", 118 | } 119 | ) 120 | with output: 121 | for image in cluster.images: 122 | print(f"Route {self._routes.images.index(image)+1}") 123 | display(image) 124 | outputs.append(output) 125 | self._output["clusters"].set_title(i, f"Cluster {i+1}") 126 | self._output["clusters"].children = outputs 127 | -------------------------------------------------------------------------------- /aizynthfinder/interfaces/gui/utils.py: -------------------------------------------------------------------------------- 1 | """Module containing utility functions for GUI. 2 | """ 3 | from __future__ import annotations 4 | 5 | from collections import defaultdict 6 | from typing import TYPE_CHECKING 7 | 8 | import ipywidgets as widgets 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import pandas as pd 12 | import seaborn as sns 13 | from IPython.display import HTML, display 14 | from paretoset import paretorank 15 | 16 | if TYPE_CHECKING: 17 | from aizynthfinder.analysis.routes import RouteCollection 18 | from aizynthfinder.utils.type_utils import List, Optional 19 | 20 | 21 | def pareto_fronts_plot( 22 | routes: RouteCollection, 23 | ) -> None: 24 | """Plot the pareto front(s). 25 | 26 | :param routes: the route collection to plot as Pareto fronts 27 | """ 28 | scorer_names = list(routes.scores[0].keys()) 29 | scores = np.array( 30 | [[score_dict[name] for name in scorer_names] for score_dict in routes.scores] 31 | ) 32 | direction_arr = np.repeat("max", len(scorer_names)) 33 | pareto_ranks = paretorank(scores, sense=direction_arr, distinct=False) 34 | 35 | pareto_fronts = pd.DataFrame(scores, columns=scorer_names) 36 | pareto_fronts.loc[:, "pareto_rank"] = pareto_ranks 37 | pareto_fronts.loc[:, "route"] = np.arange(1, scores.shape[0] + 1) 38 | pareto_fronts_unique = pareto_fronts.drop_duplicates(scorer_names) 39 | # Apply the default theme 40 | sns.set_theme() 41 | fig = sns.relplot( 42 | data=pareto_fronts_unique, 43 | x=scorer_names[0], 44 | y=scorer_names[1], 45 | hue="pareto_rank", 46 | kind="line", 47 | markers=True, 48 | ) 49 | fig.set( 50 | xlabel=f"{scorer_names[0]}", 51 | ylabel=f"{scorer_names[1]}", 52 | title="Pareto Fronts", 53 | ) 54 | 55 | objectives2solutions = defaultdict(list) 56 | for _, row in pareto_fronts.iterrows(): 57 | x_val = row[scorer_names[0]] 58 | y_val = row[scorer_names[1]] 59 | objectives2solutions[(x_val, y_val)].append(int(row["route"])) 60 | 61 | # Add route label on pareto line(s) 62 | for _, row in pareto_fronts_unique.iterrows(): 63 | x_val = row[scorer_names[0]] 64 | y_val = row[scorer_names[1]] 65 | point_val = f"Option {_values_to_string(objectives2solutions[(x_val, y_val)])}" 66 | plt.text(x=x_val, y=y_val, s=point_val, size=10) 67 | plt.show() 68 | 69 | 70 | def route_display( 71 | index: Optional[int], routes: RouteCollection, output_widget: widgets.Output 72 | ) -> None: 73 | """ 74 | Display a route with auxillary information in a widget 75 | 76 | :param index: the index of the route to display 77 | :param routes: the route collection 78 | :param output_widget: the widget to display the route on 79 | """ 80 | if index is None or routes is None or index >= len(routes): 81 | return 82 | 83 | route = routes[index] 84 | state = route["node"].state 85 | status = "Solved" if state.is_solved else "Not Solved" 86 | 87 | output_widget.clear_output() 88 | with output_widget: 89 | display(HTML(f"

{status}")) 90 | table_content = "".join( 91 | f"{name}{score:.4f}" 92 | if isinstance(score, (float, int)) 93 | else f"{name}({', '.join(f'{val:.4f}' for val in score)})" 94 | for name, score in route["all_score"].items() 95 | ) 96 | display(HTML(f"{table_content}
")) 97 | display(HTML("

Compounds to Procure")) 98 | display(state.to_image()) 99 | display(HTML("

Steps")) 100 | display(routes[index]["image"]) 101 | 102 | 103 | def _values_to_string(vals: List[int]) -> str: 104 | """ 105 | Given a list of integers, produce a nice-looking string 106 | 107 | Example: 108 | >> _values_to_string([1,2,3,10,11,19]) 109 | "1-3, 10-11, 19" 110 | """ 111 | groups = [] 112 | start = end = vals[0] 113 | for prev_val, val in zip(vals[:-1], vals[1:]): 114 | if prev_val == val - 1: 115 | end += 1 116 | else: 117 | groups.append((start, end)) 118 | start = end = val 119 | groups.append((start, end)) 120 | 121 | group_strs = [] 122 | for start, end in groups: 123 | if start == end: 124 | group_strs.append(str(start)) 125 | else: 126 | group_strs.append(f"{start}-{end}") 127 | if len(group_strs) % 5 == 0: 128 | group_strs[-1] = "\n" + group_strs[-1] 129 | return ", ".join(group_strs) 130 | -------------------------------------------------------------------------------- /aizynthfinder/search/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/aizynthfinder/search/__init__.py -------------------------------------------------------------------------------- /aizynthfinder/search/breadth_first/__init__.py: -------------------------------------------------------------------------------- 1 | """ Sub-package containing breadth first routines 2 | """ 3 | from aizynthfinder.search.breadth_first.search_tree import SearchTree 4 | -------------------------------------------------------------------------------- /aizynthfinder/search/dfpn/__init__.py: -------------------------------------------------------------------------------- 1 | """ Sub-package containing DFPN routines 2 | """ 3 | from aizynthfinder.search.dfpn.search_tree import SearchTree 4 | -------------------------------------------------------------------------------- /aizynthfinder/search/dfpn/search_tree.py: -------------------------------------------------------------------------------- 1 | """ Module containing a class that holds the tree search 2 | """ 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from aizynthfinder.reactiontree import ReactionTree 8 | from aizynthfinder.search.andor_trees import AndOrSearchTreeBase, SplitAndOrTree 9 | from aizynthfinder.search.dfpn.nodes import MoleculeNode, ReactionNode 10 | from aizynthfinder.utils.logging import logger 11 | 12 | if TYPE_CHECKING: 13 | from aizynthfinder.chem import RetroReaction 14 | from aizynthfinder.context.config import Configuration 15 | from aizynthfinder.search.andor_trees import TreeNodeMixin 16 | from aizynthfinder.utils.type_utils import List, Optional, Sequence, Union 17 | 18 | 19 | class SearchTree(AndOrSearchTreeBase): 20 | """ 21 | Encapsulation of the Depth-First Proof-Number (DFPN) search algorithm. 22 | 23 | This algorithm does not support: 24 | 1. Filter policy 25 | 2. Serialization and deserialization 26 | 27 | :ivar config: settings of the tree search algorithm 28 | :ivar root: the root node 29 | 30 | :param config: settings of the tree search algorithm 31 | :param root_smiles: the root will be set to a node representing this molecule, defaults to None 32 | """ 33 | 34 | def __init__( 35 | self, config: Configuration, root_smiles: Optional[str] = None 36 | ) -> None: 37 | super().__init__(config, root_smiles) 38 | self._mol_nodes: List[MoleculeNode] = [] 39 | self._logger = logger() 40 | self._root_smiles = root_smiles 41 | if root_smiles: 42 | self.root: Optional[MoleculeNode] = MoleculeNode.create_root( 43 | root_smiles, config, self 44 | ) 45 | self._mol_nodes.append(self.root) 46 | else: 47 | self.root = None 48 | 49 | self._routes: List[ReactionTree] = [] 50 | self._frontier: Optional[Union[MoleculeNode, ReactionNode]] = None 51 | self._initiated = False 52 | 53 | self.profiling = { 54 | "expansion_calls": 0, 55 | "reactants_generations": 0, 56 | } 57 | 58 | @property 59 | def mol_nodes(self) -> Sequence[MoleculeNode]: # type: ignore 60 | """Return the molecule nodes of the tree""" 61 | return self._mol_nodes 62 | 63 | def one_iteration(self) -> bool: 64 | """ 65 | Perform one iteration of expansion. 66 | 67 | If possible expand the frontier node twice, i.e. expanding an OR 68 | node and then and AND node. If frontier not expandable step up in the 69 | tree and find a new frontier to expand. 70 | 71 | If a solution is found, mask that tree for exploration and start over. 72 | 73 | :raises StopIteration: if the search should be pre-maturely terminated 74 | :return: if a solution was found 75 | :rtype: bool 76 | """ 77 | if not self._initiated: 78 | if self.root is None: 79 | raise ValueError("Root is undefined. Cannot make an iteration") 80 | 81 | self._routes = [] 82 | self._frontier = self.root 83 | assert self.root is not None 84 | 85 | while True: 86 | # Expand frontier, should be OR node 87 | assert isinstance(self._frontier, MoleculeNode) 88 | expanded_or = self._search_step() 89 | expanded_and = False 90 | if self._frontier: 91 | # Expand frontier again, this time an AND node 92 | assert isinstance(self._frontier, ReactionNode) 93 | expanded_and = self._search_step() 94 | if ( 95 | expanded_or 96 | or expanded_and 97 | or self._frontier is None 98 | or self._frontier is self.root 99 | ): 100 | break 101 | 102 | found_solution = any(child.proven for child in self.root.children) 103 | if self._frontier is self.root: 104 | self.root.reset() 105 | 106 | if self._frontier is None: 107 | raise StopIteration() 108 | 109 | return found_solution 110 | 111 | def routes(self) -> List[ReactionTree]: 112 | """ 113 | Extracts and returns routes from the AND/OR tree 114 | 115 | :return: the routes 116 | """ 117 | if self.root is None: 118 | return [] 119 | if not self._routes: 120 | self._routes = SplitAndOrTree(self.root, self.config.stock).routes 121 | return self._routes 122 | 123 | def _search_step(self) -> bool: 124 | assert self._frontier is not None 125 | expanded = False 126 | if self._frontier.expandable: 127 | self._frontier.expand() 128 | expanded = True 129 | if isinstance(self._frontier, ReactionNode): 130 | self._mol_nodes.extend(self._frontier.children) 131 | 132 | self._frontier.update() 133 | if not self._frontier.explorable(): 134 | self._frontier = self._frontier.parent 135 | return False 136 | 137 | child = self._frontier.promising_child() 138 | if not child: 139 | self._frontier = self._frontier.parent 140 | return False 141 | 142 | self._frontier = child 143 | return expanded 144 | -------------------------------------------------------------------------------- /aizynthfinder/search/mcts/__init__.py: -------------------------------------------------------------------------------- 1 | """ Sub-package containing MCTS routines 2 | """ 3 | from aizynthfinder.search.mcts.node import MctsNode 4 | from aizynthfinder.search.mcts.search import MctsSearchTree 5 | from aizynthfinder.search.mcts.state import MctsState 6 | -------------------------------------------------------------------------------- /aizynthfinder/search/mcts/utils.py: -------------------------------------------------------------------------------- 1 | """ Module containing utility routines for MCTS. This is not part of public interface """ 2 | from __future__ import annotations 3 | 4 | from typing import TYPE_CHECKING 5 | 6 | from aizynthfinder.reactiontree import ReactionTreeLoader 7 | 8 | if TYPE_CHECKING: 9 | from aizynthfinder.chem import RetroReaction 10 | from aizynthfinder.search.mcts import MctsNode 11 | from aizynthfinder.utils.type_utils import List, Optional, Tuple 12 | 13 | 14 | class ReactionTreeFromSuperNode(ReactionTreeLoader): 15 | """ 16 | Creates a reaction tree object from MCTS-like nodes and reaction objects 17 | """ 18 | 19 | def _load(self, base_node: MctsNode) -> None: # type: ignore 20 | actions, nodes = route_to_node(base_node) 21 | self.tree.created_at_iteration = base_node.created_at_iteration 22 | root_mol = nodes[0].state.mols[0] 23 | self._unique_mols[id(root_mol)] = root_mol.make_unique() 24 | self._add_node( 25 | self._unique_mols[id(root_mol)], 26 | in_stock=nodes[0].state.is_solved, 27 | ) 28 | 29 | for child, action in zip(nodes[1:], actions): 30 | self._add_bipartite(child, action) 31 | 32 | def _add_bipartite(self, child: MctsNode, action: RetroReaction) -> None: 33 | reaction_obj = self._unique_reaction(action) 34 | self._add_node(reaction_obj, depth=2 * action.mol.transform + 1) 35 | self.tree.graph.add_edge(self._unique_mol(action.mol), reaction_obj) 36 | reactant_nodes = [] 37 | for mol in child.state.mols: 38 | if mol.parent is action.mol: 39 | self._add_node( 40 | self._unique_mol(mol), 41 | depth=2 * mol.transform, 42 | transform=mol.transform, 43 | in_stock=mol in child.state.stock, 44 | ) 45 | self.tree.graph.add_edge(reaction_obj, self._unique_mol(mol)) 46 | reactant_nodes.append(self._unique_mol(mol)) 47 | reaction_obj.reactants = (tuple(reactant_nodes),) 48 | 49 | 50 | def route_to_node( 51 | from_node: MctsNode, 52 | ) -> Tuple[List[RetroReaction], List[MctsNode]]: 53 | """ 54 | Return the route to a give node to the root. 55 | 56 | Will return both the actions taken to go between the nodes, 57 | and the nodes in the route themselves. 58 | 59 | :param from_node: the end of the route 60 | :return: the route 61 | """ 62 | actions = [] 63 | nodes = [] 64 | current: Optional[MctsNode] = from_node 65 | 66 | while current is not None: 67 | parent = current.parent 68 | if parent is not None: 69 | action = parent[current]["action"] 70 | actions.append(action) 71 | nodes.append(current) 72 | current = parent 73 | return actions[::-1], nodes[::-1] 74 | -------------------------------------------------------------------------------- /aizynthfinder/search/retrostar/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/aizynthfinder/search/retrostar/__init__.py -------------------------------------------------------------------------------- /aizynthfinder/search/retrostar/cost.py: -------------------------------------------------------------------------------- 1 | """ Module containing Retro* cost models """ 2 | from __future__ import annotations 3 | 4 | import pickle 5 | from typing import TYPE_CHECKING 6 | 7 | import numpy as np 8 | 9 | from aizynthfinder.search.retrostar.cost import __name__ as retrostar_cost_module 10 | from aizynthfinder.utils.loading import load_dynamic_class 11 | 12 | if TYPE_CHECKING: 13 | from aizynthfinder.chem import Molecule 14 | from aizynthfinder.context.config import Configuration 15 | from aizynthfinder.utils.type_utils import Any, List, Tuple 16 | 17 | 18 | class MoleculeCost: 19 | """ 20 | A class to compute the molecule cost. 21 | 22 | The cost to be computed is taken from the input config. If no `molecule_cost` is 23 | set, assigns ZeroMoleculeCost as the `cost` by default. The `molecule_cost` can be 24 | set as a dictionary in config under `search` in the following format: 25 | 'algorithm': 'retrostar' 26 | 'algorithm_config': { 27 | 'molecule_cost': { 28 | 'cost': name of the search cost class or custom_package.custom_model.CustomClass, 29 | other settings or params 30 | } 31 | } 32 | 33 | The cost can be computed by calling the instantiated class with a molecule. 34 | 35 | .. code-block:: 36 | 37 | calculator = MyCost(config) 38 | cost = calculator.calculate(molecule) 39 | 40 | :param config: the configuration of the tree search 41 | """ 42 | 43 | def __init__(self, config: Configuration) -> None: 44 | self._config = config 45 | if "molecule_cost" not in self._config.search.algorithm_config: 46 | self._config.search.algorithm_config["molecule_cost"] = { 47 | "cost": "ZeroMoleculeCost" 48 | } 49 | kwargs = self._config.search.algorithm_config["molecule_cost"].copy() 50 | 51 | cls = load_dynamic_class(kwargs["cost"], retrostar_cost_module) 52 | del kwargs["cost"] 53 | 54 | self.molecule_cost = cls(**kwargs) if kwargs else cls() 55 | 56 | def __call__(self, mol: Molecule) -> float: 57 | return self.molecule_cost.calculate(mol) 58 | 59 | 60 | class RetroStarCost: 61 | """ 62 | Encapsulation of the original Retro* molecular cost model 63 | 64 | Numpy implementation of original pytorch model 65 | 66 | The predictions of the score is made on a Molecule object 67 | 68 | .. code-block:: 69 | 70 | mol = Molecule(smiles="CCC") 71 | scorer = RetroStarCost() 72 | score = scorer.calculate(mol) 73 | 74 | The model provided when creating the scorer object should be a pickled 75 | tuple. 76 | The first item of the tuple should be a list of the model weights for each layer. 77 | The second item of the tuple should be a list of the model biases for each layer. 78 | 79 | :param model_path: the filename of the model weights and biases 80 | :param fingerprint_length: the number of bits in the fingerprint 81 | :param fingerprint_radius: the radius of the fingerprint 82 | :param dropout_rate: the dropout_rate 83 | """ 84 | 85 | _required_kwargs = ["model_path"] 86 | 87 | def __init__(self, **kwargs: Any) -> None: 88 | model_path = kwargs["model_path"] 89 | self.fingerprint_length: int = int(kwargs.get("fingerprint_length", 2048)) 90 | self.fingerprint_radius: int = int(kwargs.get("fingerprint_radius", 2)) 91 | self.dropout_rate: float = float(kwargs.get("dropout_rate", 0.1)) 92 | 93 | self._dropout_prob = 1.0 - self.dropout_rate 94 | self._weights, self._biases = self._load_model(model_path) 95 | 96 | def __repr__(self) -> str: 97 | return "retrostar" 98 | 99 | def calculate(self, mol: Molecule) -> float: 100 | # pylint: disable=invalid-name 101 | mol.sanitize() 102 | vec = mol.fingerprint( 103 | radius=self.fingerprint_radius, nbits=self.fingerprint_length 104 | ) 105 | for W, b in zip(self._weights[:-1], self._biases[:-1]): 106 | vec = np.matmul(vec, W) + b 107 | vec *= vec > 0 # ReLU 108 | # Drop-out 109 | vec *= np.random.binomial(1, self._dropout_prob, size=vec.shape) / ( 110 | self._dropout_prob 111 | ) 112 | vec = np.matmul(vec, self._weights[-1]) + self._biases[-1] 113 | return float(np.log(1 + np.exp(vec))) 114 | 115 | @staticmethod 116 | def _load_model(model_path: str) -> Tuple[List[np.ndarray], List[np.ndarray]]: 117 | with open(model_path, "rb") as fileobj: 118 | weights, biases = pickle.load(fileobj) 119 | 120 | return ( 121 | [np.asarray(item) for item in weights], 122 | [np.asarray(item) for item in biases], 123 | ) 124 | 125 | 126 | class ZeroMoleculeCost: 127 | """Encapsulation of a Zero cost model""" 128 | 129 | def __repr__(self) -> str: 130 | return "zero" 131 | 132 | def calculate(self, _mol: Molecule) -> float: # pytest: disable=unused-argument 133 | return 0.0 134 | -------------------------------------------------------------------------------- /aizynthfinder/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/aizynthfinder/tools/__init__.py -------------------------------------------------------------------------------- /aizynthfinder/tools/cat_output.py: -------------------------------------------------------------------------------- 1 | """ Module containing a CLI for concatenating output files (hdf5/json.gz files) 2 | """ 3 | import argparse 4 | 5 | from aizynthfinder.utils.files import cat_datafiles 6 | 7 | 8 | def main() -> None: 9 | """Entry-point for the cat_aizynth_output CLI""" 10 | parser = argparse.ArgumentParser("cat_aizynthcli_output") 11 | parser.add_argument( 12 | "--files", 13 | required=True, 14 | nargs="+", 15 | help="the output filenames", 16 | ) 17 | parser.add_argument( 18 | "--output", 19 | required=True, 20 | default="output.json.gz", 21 | help="the name of the concatenate output file ", 22 | ) 23 | parser.add_argument( 24 | "--trees", 25 | help="if given, save all trees to this file", 26 | ) 27 | args = parser.parse_args() 28 | 29 | cat_datafiles(args.files, args.output, args.trees) 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /aizynthfinder/tools/download_public_data.py: -------------------------------------------------------------------------------- 1 | """ Module with script to download public data 2 | """ 3 | import argparse 4 | import os 5 | import sys 6 | 7 | import requests 8 | import tqdm 9 | 10 | FILES_TO_DOWNLOAD = { 11 | "policy_model_onnx": { 12 | "filename": "uspto_model.onnx", 13 | "url": "https://zenodo.org/record/7797465/files/uspto_model.onnx", 14 | }, 15 | "template_file": { 16 | "filename": "uspto_templates.csv.gz", 17 | "url": "https://zenodo.org/record/7341155/files/uspto_unique_templates.csv.gz", 18 | }, 19 | "ringbreaker_model_onnx": { 20 | "filename": "uspto_ringbreaker_model.onnx", 21 | "url": "https://zenodo.org/record/7797465/files/uspto_ringbreaker_model.onnx", 22 | }, 23 | "ringbreaker_templates": { 24 | "filename": "uspto_ringbreaker_templates.csv.gz", 25 | "url": "https://zenodo.org/record/7341155/files/uspto_ringbreaker_unique_templates.csv.gz", 26 | }, 27 | "stock": { 28 | "filename": "zinc_stock.hdf5", 29 | "url": "https://ndownloader.figshare.com/files/23086469", 30 | }, 31 | "filter_policy_onnx": { 32 | "filename": "uspto_filter_model.onnx", 33 | "url": "https://zenodo.org/record/7797465/files/uspto_filter_model.onnx", 34 | }, 35 | } 36 | 37 | YAML_TEMPLATE = """expansion: 38 | uspto: 39 | - {} 40 | - {} 41 | ringbreaker: 42 | - {} 43 | - {} 44 | filter: 45 | uspto: {} 46 | stock: 47 | zinc: {} 48 | """ 49 | 50 | 51 | def _download_file(url: str, filename: str) -> None: 52 | with requests.get(url, stream=True) as response: 53 | response.raise_for_status() 54 | total_size = int(response.headers.get("content-length", 0)) 55 | pbar = tqdm.tqdm( 56 | total=total_size, desc=os.path.basename(filename), unit="B", unit_scale=True 57 | ) 58 | with open(filename, "wb") as fileobj: 59 | for chunk in response.iter_content(chunk_size=1024): 60 | fileobj.write(chunk) 61 | pbar.update(len(chunk)) 62 | pbar.close() 63 | 64 | 65 | def main() -> None: 66 | """Entry-point for CLI""" 67 | parser = argparse.ArgumentParser("download_public_data") 68 | parser.add_argument( 69 | "path", 70 | default=".", 71 | help="the path to download the files", 72 | ) 73 | path = parser.parse_args().path 74 | 75 | try: 76 | for filespec in FILES_TO_DOWNLOAD.values(): 77 | _download_file(filespec["url"], os.path.join(path, filespec["filename"])) 78 | except requests.HTTPError as err: 79 | print(f"Download failed with message {str(err)}") 80 | sys.exit(1) 81 | 82 | with open(os.path.join(path, "config.yml"), "w") as fileobj: 83 | path = os.path.abspath(path) 84 | fileobj.write( 85 | YAML_TEMPLATE.format( 86 | os.path.join(path, FILES_TO_DOWNLOAD["policy_model_onnx"]["filename"]), 87 | os.path.join(path, FILES_TO_DOWNLOAD["template_file"]["filename"]), 88 | os.path.join( 89 | path, FILES_TO_DOWNLOAD["ringbreaker_model_onnx"]["filename"] 90 | ), 91 | os.path.join( 92 | path, FILES_TO_DOWNLOAD["ringbreaker_templates"]["filename"] 93 | ), 94 | os.path.join(path, FILES_TO_DOWNLOAD["filter_policy_onnx"]["filename"]), 95 | os.path.join(path, FILES_TO_DOWNLOAD["stock"]["filename"]), 96 | ) 97 | ) 98 | print("Configuration file written to config.yml") 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /aizynthfinder/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/aizynthfinder/utils/__init__.py -------------------------------------------------------------------------------- /aizynthfinder/utils/bonds.py: -------------------------------------------------------------------------------- 1 | """Module containing a class to identify broken focussed bonds 2 | """ 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | if TYPE_CHECKING: 8 | from typing import List, Sequence, Tuple 9 | from aizynthfinder.chem.mol import TreeMolecule 10 | from aizynthfinder.chem.reaction import RetroReaction 11 | 12 | 13 | class BrokenBonds: 14 | """ 15 | A class to keep track of focussed bonds breaking in a target molecule. 16 | 17 | :param focussed_bonds: A list of focussed bond pairs. The bond pairs are represented 18 | as tuples of size 2. These bond pairs exist in the target molecule's atom bonds. 19 | """ 20 | 21 | def __init__(self, focussed_bonds: Sequence[Sequence[int]]) -> None: 22 | self.focussed_bonds = sort_bonds(focussed_bonds) 23 | self.filtered_focussed_bonds: List[Tuple[int, int]] = [] 24 | 25 | def __call__(self, reaction: RetroReaction) -> List[Tuple[int, int]]: 26 | """ 27 | Provides a list of focussed bonds that break in any of the molecule's reactants. 28 | 29 | :param reaction: A retro reaction. 30 | :return: A list of all the focussed bonds that broke within the reactants 31 | that constitute the target molecule. 32 | """ 33 | self.filtered_focussed_bonds = self._get_filtered_focussed_bonds(reaction.mol) 34 | if not self.filtered_focussed_bonds: 35 | return [] 36 | 37 | molecule_bonds = [] 38 | for reactant in reaction.reactants[reaction.index]: 39 | molecule_bonds += reactant.mapped_atom_bonds 40 | 41 | broken_focussed_bonds = self._get_broken_frozen_bonds( 42 | sort_bonds(molecule_bonds) 43 | ) 44 | return broken_focussed_bonds 45 | 46 | def _get_broken_frozen_bonds( 47 | self, 48 | molecule_bonds: List[Tuple[int, int]], 49 | ) -> List[Tuple[int, int]]: 50 | broken_focussed_bonds = list( 51 | set(self.filtered_focussed_bonds) - set(molecule_bonds) 52 | ) 53 | return broken_focussed_bonds 54 | 55 | def _get_filtered_focussed_bonds( 56 | self, molecule: TreeMolecule 57 | ) -> List[Tuple[int, int]]: 58 | molecule_bonds = molecule.mapped_atom_bonds 59 | atom_maps = [atom_map for bonds in molecule_bonds for atom_map in bonds] 60 | 61 | filtered_focussed_bonds = [] 62 | for idx1, idx2 in self.focussed_bonds: 63 | if idx1 in atom_maps and idx2 in atom_maps: 64 | filtered_focussed_bonds.append((idx1, idx2)) 65 | return filtered_focussed_bonds 66 | 67 | 68 | def sort_bonds(bonds: Sequence[Sequence[int]]) -> List[Tuple[int, int]]: 69 | return [tuple(sorted(bond)) for bond in bonds] # type: ignore 70 | -------------------------------------------------------------------------------- /aizynthfinder/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | """ Module containing custom exception classes 2 | """ 3 | 4 | 5 | class CostException(Exception): 6 | """Exception raised by molecule cost classes""" 7 | 8 | 9 | class ExternalModelAPIError(Exception): 10 | """Custom error type to signal failure in External model""" 11 | 12 | 13 | class MoleculeException(Exception): 14 | """An exception that is raised by molecule class""" 15 | 16 | 17 | class NodeUnexpectedBehaviourException(Exception): 18 | """Exception that is raised if the tree search is behaving unexpectedly.""" 19 | 20 | 21 | class PolicyException(Exception): 22 | """An exception raised by policy classes""" 23 | 24 | 25 | class RejectionException(Exception): 26 | """An exception raised if a retro action should be rejected""" 27 | 28 | 29 | class ScorerException(Exception): 30 | """Exception raised by scoring classes""" 31 | 32 | 33 | class StockException(Exception): 34 | """An exception raised by stock classes""" 35 | 36 | 37 | class TreeAnalysisException(Exception): 38 | """Exception raised when analysing trees""" 39 | -------------------------------------------------------------------------------- /aizynthfinder/utils/loading.py: -------------------------------------------------------------------------------- 1 | """ Module containing routine to dynamically load a class from a specification """ 2 | from __future__ import annotations 3 | 4 | import importlib 5 | from typing import TYPE_CHECKING 6 | 7 | if TYPE_CHECKING: 8 | from typing import Any, Optional 9 | 10 | 11 | def load_dynamic_class( 12 | name_spec: str, 13 | default_module: Optional[str] = None, 14 | exception_cls: Any = ValueError, 15 | ) -> Any: 16 | """ 17 | Load an object from a dynamic specification. 18 | 19 | The specification can be either: 20 | ClassName, in-case the module name is taken from the `default_module` argument 21 | or 22 | package_name.module_name.ClassName, in-case the module is taken as `package_name.module_name` 23 | 24 | :param name_spec: the class specification 25 | :param default_module: the default module 26 | :param exception_cls: the exception class to raise on exception 27 | :return: the loaded class 28 | """ 29 | if "." not in name_spec: 30 | name = name_spec 31 | if not default_module: 32 | raise exception_cls( 33 | "Must provide default_module argument if not given in name_spec" 34 | ) 35 | module_name = default_module 36 | else: 37 | module_name, name = name_spec.rsplit(".", maxsplit=1) 38 | 39 | try: 40 | loaded_module = importlib.import_module(module_name) 41 | except ImportError: 42 | raise exception_cls(f"Unable to load module: {module_name}") 43 | 44 | if not hasattr(loaded_module, name): 45 | raise exception_cls( 46 | f"Module ({module_name}) does not have a class called {name}" 47 | ) 48 | 49 | return getattr(loaded_module, name) 50 | -------------------------------------------------------------------------------- /aizynthfinder/utils/logging.py: -------------------------------------------------------------------------------- 1 | """ Module containing routines to setup proper logging 2 | """ 3 | # pylint: disable=ungrouped-imports, wrong-import-order, wrong-import-position, unused-import 4 | import logging.config 5 | import os 6 | 7 | import yaml 8 | 9 | # See Github issue 30 why sklearn is imported here 10 | try: 11 | import sklearn # noqa 12 | except ImportError: 13 | pass 14 | from rdkit import RDLogger 15 | 16 | from aizynthfinder.utils.paths import data_path 17 | from aizynthfinder.utils.type_utils import Optional 18 | 19 | # Suppress RDKit errors due to incomplete template (e.g. aromatic non-ring atoms) 20 | rd_logger = RDLogger.logger() 21 | rd_logger.setLevel(RDLogger.CRITICAL) 22 | 23 | 24 | def logger() -> logging.Logger: 25 | """ 26 | Returns the logger that should be used by all classes 27 | 28 | :return: the logger object 29 | """ 30 | return logging.getLogger("aizynthfinder") 31 | 32 | 33 | def setup_logger( 34 | console_level: int, file_level: Optional[int] = None 35 | ) -> logging.Logger: 36 | """ 37 | Setup the logger that should be used by all classes 38 | 39 | The logger configuration is read from the `logging.yml` file. 40 | 41 | :param console_level: the level of logging to the console 42 | :param file_level: the level of logging to file, if not set logging to file is disabled, default to None 43 | :return: the logger object 44 | """ 45 | filename = os.path.join(data_path(), "logging.yml") 46 | with open(filename, "r") as fileobj: 47 | config = yaml.load(fileobj.read(), Loader=yaml.SafeLoader) 48 | 49 | config["handlers"]["console"]["level"] = console_level 50 | if file_level: 51 | config["handlers"]["file"]["level"] = file_level 52 | else: 53 | del config["handlers"]["file"] 54 | config["loggers"]["aizynthfinder"]["handlers"].remove("file") 55 | 56 | logging.config.dictConfig(config) 57 | return logger() 58 | -------------------------------------------------------------------------------- /aizynthfinder/utils/math.py: -------------------------------------------------------------------------------- 1 | """ Module containing diverse math functions, including neural network-related functions. """ 2 | 3 | import numpy as np 4 | from aizynthfinder.utils.type_utils import Callable 5 | 6 | 7 | # pylint: disable=invalid-name 8 | def dense_layer_forward_pass( 9 | x: np.ndarray, weights: np.ndarray, bias: np.ndarray, activation: Callable 10 | ) -> np.ndarray: 11 | """ 12 | Forward pass through a dense neural network layer. 13 | :param x: layer input 14 | :param weights: layer weights 15 | :param bias: layer bias 16 | :param activation: layer activation function 17 | :return: the layer output 18 | """ 19 | x = np.matmul(x, weights) + bias 20 | return activation(x) 21 | 22 | 23 | # pylint: disable=invalid-name 24 | def rectified_linear_unit(x: np.ndarray) -> np.ndarray: 25 | """ReLU activation function""" 26 | return x * (x > 0) 27 | 28 | 29 | # pylint: disable=invalid-name 30 | def sigmoid(x: np.ndarray) -> np.ndarray: 31 | """Sigmoid activation function""" 32 | return 1 / (1 + np.exp(-x)) 33 | 34 | 35 | # pylint: disable=invalid-name 36 | def softmax(x: np.ndarray) -> np.ndarray: 37 | """Compute softmax values for each sets of scores in x.""" 38 | return np.exp(x) / np.sum(np.exp(x), axis=0) 39 | -------------------------------------------------------------------------------- /aizynthfinder/utils/mongo.py: -------------------------------------------------------------------------------- 1 | """ Module containing routines to obtain a MongoClient instance 2 | """ 3 | from typing import Optional 4 | from urllib.parse import urlencode 5 | 6 | try: 7 | from pymongo import MongoClient 8 | except ImportError: 9 | MongoClient = None 10 | HAS_PYMONGO = False 11 | else: 12 | HAS_PYMONGO = True 13 | 14 | from aizynthfinder.utils.logging import logger 15 | 16 | _CLIENT = None 17 | 18 | 19 | def get_mongo_client( 20 | host: str = "localhost", 21 | port: int = 27017, 22 | user: Optional[str] = None, 23 | password: Optional[str] = None, 24 | tls_certs_path: str = "", 25 | ) -> Optional[MongoClient]: 26 | """ 27 | A helper function to create and reuse MongoClient 28 | 29 | The client is only setup once. Therefore if this function is called a second 30 | time with different parameters, it would still return the first client. 31 | 32 | :param host: the host 33 | :param port: the host port 34 | :param user: username, defaults to None 35 | :param password: password, defaults to None 36 | :param tls_certs_path: the path to TLS certificates if to be used, defaults to "" 37 | :raises ValueError: if host and port is not given first time 38 | :return: the MongoDB client 39 | """ 40 | if not HAS_PYMONGO: 41 | return None 42 | 43 | global _CLIENT 44 | if _CLIENT is None: 45 | params = {} 46 | if tls_certs_path: 47 | params.update({"ssl": "true", "ssl_ca_certs": tls_certs_path}) 48 | cred_str = f"{user}:{password}@" if password else "" 49 | uri = f"mongodb://{cred_str}{host}:{port}/?{urlencode(params)}" 50 | logger().debug(f"Connecting to MongoDB on {host}:{port}") 51 | _CLIENT = MongoClient(uri) # pylint: disable=C0103 52 | return _CLIENT 53 | -------------------------------------------------------------------------------- /aizynthfinder/utils/paths.py: -------------------------------------------------------------------------------- 1 | """ Module containing routines for returning package paths 2 | """ 3 | import os 4 | 5 | 6 | def package_path() -> str: 7 | """Return the path to the package""" 8 | return os.path.abspath(os.path.dirname(os.path.dirname(__file__))) 9 | 10 | 11 | def data_path() -> str: 12 | """Return the path to the ``data`` directory of the package""" 13 | return os.path.join(package_path(), "data") 14 | -------------------------------------------------------------------------------- /aizynthfinder/utils/sc_score.py: -------------------------------------------------------------------------------- 1 | """ Module containing the implementation of the SC-score model for synthetic complexity scoring. """ 2 | 3 | from __future__ import annotations 4 | 5 | import pickle 6 | from typing import TYPE_CHECKING 7 | 8 | import numpy as np 9 | from rdkit.Chem import AllChem 10 | 11 | from aizynthfinder.utils.math import ( 12 | dense_layer_forward_pass, 13 | rectified_linear_unit, 14 | sigmoid, 15 | ) 16 | 17 | if TYPE_CHECKING: 18 | from aizynthfinder.utils.type_utils import RdMol, Sequence, Tuple 19 | 20 | 21 | class SCScore: 22 | """ 23 | Encapsulation of the SCScore model 24 | 25 | Re-write of the SCScorer from the scscorer package 26 | 27 | The predictions of the score is made with a sanitized instance of an RDKit molecule 28 | 29 | .. code-block:: 30 | 31 | mol = Molecule(smiles="CCC", sanitize=True) 32 | scscorer = SCScorer("path_to_model") 33 | score = scscorer(mol.rd_mol) 34 | 35 | The model provided when creating the scorer object should be pickled tuple. 36 | The first item of the tuple should be a list of the model weights for each layer. 37 | The second item of the tuple should be a list of the model biases for each layer. 38 | 39 | :param model_path: the filename of the model weights and biases 40 | :param fingerprint_length: the number of bits in the fingerprint 41 | :param fingerprint_radius: the radius of the fingerprint 42 | """ 43 | 44 | def __init__( 45 | self, 46 | model_path: str, 47 | fingerprint_length: int = 1024, 48 | fingerprint_radius: int = 2, 49 | ) -> None: 50 | self._fingerprint_length = fingerprint_length 51 | self._fingerprint_radius = fingerprint_radius 52 | self._weights, self._biases = self._load_model(model_path) 53 | self.score_scale = 5.0 54 | 55 | def __call__(self, rd_mol: RdMol) -> float: 56 | fingerprint = self._make_fingerprint(rd_mol) 57 | normalized_score = self.forward(fingerprint) 58 | sc_score = (1 + (self.score_scale - 1) * normalized_score)[0] 59 | return sc_score 60 | 61 | # pylint: disable=invalid-name 62 | def forward(self, x: np.ndarray) -> np.ndarray: 63 | """Forward pass with dense neural network""" 64 | for weights, bias in zip(self._weights[:-1], self._biases[:-1]): 65 | x = dense_layer_forward_pass(x, weights, bias, rectified_linear_unit) 66 | return dense_layer_forward_pass(x, self._weights[-1], self._biases[-1], sigmoid) 67 | 68 | def _load_model( 69 | self, model_path: str 70 | ) -> Tuple[Sequence[np.ndarray], Sequence[np.ndarray]]: 71 | """Returns neural network model parameters.""" 72 | with open(model_path, "rb") as fileobj: 73 | weights, biases = pickle.load(fileobj) 74 | 75 | weights = [np.asarray(item) for item in weights] 76 | biases = [np.asarray(item) for item in biases] 77 | return weights, biases 78 | 79 | def _make_fingerprint(self, rd_mol: RdMol) -> np.ndarray: 80 | """Returns the molecule's Morgan fingerprint""" 81 | fp_vec = AllChem.GetMorganFingerprintAsBitVect( 82 | rd_mol, 83 | self._fingerprint_radius, 84 | nBits=self._fingerprint_length, 85 | useChirality=True, 86 | ) 87 | return np.array( 88 | fp_vec, 89 | dtype=bool, 90 | ) 91 | -------------------------------------------------------------------------------- /aizynthfinder/utils/type_utils.py: -------------------------------------------------------------------------------- 1 | """ Module containing all types and type imports 2 | """ 3 | # pylint: disable=unused-import 4 | from typing import Callable # noqa 5 | from typing import Iterable # noqa 6 | from typing import List # noqa 7 | from typing import Sequence # noqa 8 | from typing import Set # noqa 9 | from typing import TypeVar # noqa 10 | from typing import Any, Dict, Optional, Tuple, Union 11 | 12 | from PIL.Image import Image 13 | from rdkit import Chem 14 | from rdkit.Chem import rdChemReactions 15 | from rdkit.DataStructs.cDataStructs import ExplicitBitVect 16 | 17 | StrDict = Dict[str, Any] 18 | RdMol = Chem.rdchem.Mol 19 | RdReaction = Chem.rdChemReactions.ChemicalReaction 20 | BitVector = ExplicitBitVect 21 | PilImage = Image 22 | PilColor = Union[str, Tuple[int, int, int]] 23 | FrameColors = Optional[Dict[bool, PilColor]] 24 | -------------------------------------------------------------------------------- /contrib/notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "ObD2YL7nM2_X" 7 | }, 8 | "source": [ 9 | "# AiZynthFinder\n", 10 | "\n", 11 | "Click the ▶ play button at the left of the **Installation** text below to install the application. The initial installation process may take a few minutes. Then run **Start application** cell. \n", 12 | "\n", 13 | "1. Enter the target compound [SMILES][1] code.\n", 14 | "3. Click the **Run Search** button to start the algorithm.\n", 15 | "4. Once it stops serching, click the **Show Reactions** button.\n", 16 | "\n", 17 | "[1]: https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "#@title Installation -- Run this cell to install AiZynthFinder\n", 27 | "\n", 28 | "# This might not install the dependencies that you need\n", 29 | "#!(AIZ_LATEST_TAG=$(git -c 'versionsort.suffix=-' \\\n", 30 | "# ls-remote --exit-code --refs --sort='version:refname' --tags https://github.com/MolecularAI/aizynthfinder '*.*.*' \\\n", 31 | "# | tail --lines=1 \\\n", 32 | "# | cut -f 3 -d /) \\\n", 33 | "# && pip install --quiet https://github.com/MolecularAI/aizynthfinder/archive/${AIZ_LATEST_TAG}.tar.gz)\n", 34 | "!pip install --quiet aizynthfinder[all]\n", 35 | "!pip install --ignore-installed Pillow==9.0.0\n", 36 | "!mkdir --parents data && download_public_data data\n" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": { 43 | "id": "beDP-CSNM429" 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "#@title Start application. {display-mode: \"form\"}\n", 48 | "\n", 49 | "from rdkit.Chem.Draw import IPythonConsole\n", 50 | "from aizynthfinder.interfaces import AiZynthApp\n", 51 | "application = AiZynthApp(\"./data/config.yml\")" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": { 57 | "id": "bwxusoogwlI9" 58 | }, 59 | "source": [ 60 | "# Bibliography\n", 61 | "\n", 62 | "_Genheden S, Thakkar A, Chadimova V, et al (2020) AiZynthFinder: a fast, robust and flexible open-source software for retrosynthetic planning. J. Cheminf. https://doi.org/10.1186/s13321-020-00472-1 ([GitHub](https://github.com/MolecularAI/aizynthfinder) & [Documentation](https://molecularai.github.io/aizynthfinder/html/index.html))_" 63 | ] 64 | } 65 | ], 66 | "metadata": { 67 | "accelerator": "GPU", 68 | "colab": { 69 | "collapsed_sections": [], 70 | "name": "AiZynthFinder.ipynb", 71 | "private_outputs": true, 72 | "provenance": [], 73 | "toc_visible": true 74 | }, 75 | "kernelspec": { 76 | "display_name": "Python 3 (ipykernel)", 77 | "language": "python", 78 | "name": "python3" 79 | }, 80 | "language_info": { 81 | "codemirror_mode": { 82 | "name": "ipython", 83 | "version": 3 84 | }, 85 | "file_extension": ".py", 86 | "mimetype": "text/x-python", 87 | "name": "python", 88 | "nbconvert_exporter": "python", 89 | "pygments_lexer": "ipython3", 90 | "version": "3.10.9" 91 | } 92 | }, 93 | "nbformat": 4, 94 | "nbformat_minor": 1 95 | } 96 | -------------------------------------------------------------------------------- /docs/analysis-rel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/analysis-rel.png -------------------------------------------------------------------------------- /docs/analysis-seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/analysis-seq.png -------------------------------------------------------------------------------- /docs/cli.rst: -------------------------------------------------------------------------------- 1 | Command-line interface 2 | ====================== 3 | 4 | This tools provide the possibility to perform tree search on a batch of molecules. 5 | 6 | In its simplest form, you type 7 | 8 | .. code-block:: bash 9 | 10 | aizynthcli --config config_local.yml --smiles smiles.txt 11 | 12 | where `config_local.yml` contains configurations such as paths to policy models and stocks (see :doc:`here `) 13 | and `smiles.txt` is a simple text file with SMILES (one on each row). 14 | 15 | 16 | To find out what other arguments are available use the ``-h`` flag. 17 | 18 | .. code-block:: bash 19 | 20 | aizynthcli -h 21 | 22 | That gives something like this: 23 | 24 | .. include:: cli_help.txt 25 | 26 | 27 | By default: 28 | 29 | * `All` stocks are selected if no stock is specified 30 | * `First` expansion policy is selected if not expansion policy is specified 31 | * `All` filter policies are selected if it is not specified on the command-line 32 | 33 | Analysing output 34 | ---------------- 35 | 36 | 37 | The results from the ``aizynthcli`` tool when supplying multiple SMILES is an JSON or HDF5 file that can be read as a pandas dataframe. 38 | It will be called `output.json.gz` by default. 39 | 40 | A `checkpoint.json.gz` will also be generated if a checkpoint file path is provided as input when calling the ``aizynthcli`` tool. The 41 | checkpoint data will contain the processed smiles with their corresponding results in each line of the file. 42 | 43 | .. code-block:: 44 | 45 | import pandas as pd 46 | data = pd.read_json("output.json.gz", orient="table") 47 | 48 | it will contain statistics about the tree search and the top-ranked routes (as JSONs) for each target compound, see below. 49 | 50 | When a single SMILES is provided to the tool, the statistics will be written to the terminal, and the top-ranked routes to 51 | a JSON file (`trees.json` by default). 52 | 53 | 54 | This is an example of how to create images of the top-ranked routes for the first target compound 55 | 56 | 57 | .. code-block:: 58 | 59 | import pandas as pd 60 | from aizynthfinder.reactiontree import ReactionTree 61 | 62 | data = pd.read_json("output.json.gz", orient="table") 63 | all_trees = data.trees.values # This contains a list of all the trees for all the compounds 64 | trees_for_first_target = all_trees[0] 65 | 66 | for itree, tree in enumerate(trees_for_first_target): 67 | imagefile = f"route{itree:03d}.png" 68 | ReactionTree.from_dict(tree).to_image().save(imagefile) 69 | 70 | The images will be called `route000.png`, `route001.png` etc. 71 | 72 | 73 | Specification of output 74 | ----------------------- 75 | 76 | The JSON or HDF5 file created when running the tool with a list of SMILES will have the following columns 77 | 78 | ============================= =========== 79 | Column Description 80 | ============================= =========== 81 | target The target SMILES 82 | search_time The total search time in seconds 83 | first_solution_time The time elapsed until the first solution was found 84 | first_solution_iteration The number of iterations completed until the first solution was found 85 | number_of_nodes The number of nodes in the search tree 86 | max_transforms The maximum number of transformations for all routes in the search tree 87 | max_children The maximum number of children for a search node 88 | number_of_routes The number of routes in the search tree 89 | number_of_solved_routes The number of solved routes in search tree 90 | top_score The score of the top-scored route (default to MCTS reward) 91 | is_solved If the top-scored route is solved 92 | number_of_steps The number of reactions in the top-scored route 93 | number_of_precursors The number of starting materials 94 | number_of_precursors_in_stock The number of starting materials in stock 95 | precursors_in_stock Comma-separated list of SMILES of starting material in stock 96 | precursors_not_in_stock Comma-separated list of SMILES of starting material not in stock 97 | precursors_availability Semi-colon separated list of stock availability of the staring material 98 | policy_used_counts Dictionary of the total number of times an expansion policy have been used 99 | profiling Profiling information from the search tree, including expansion models call and reactant generation 100 | stock_info Dictionary of the stock availability for each of the starting material in all extracted routes 101 | top_scores Comma-separated list of the score of the extracted routes (default to MCTS reward) 102 | trees A list of the extracted routes as dictionaries 103 | ============================= =========== 104 | 105 | If you running the tool with a single SMILES, all of this data will be printed to the screen, except 106 | the ``stock_info`` and ``trees``. 107 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath(".")) 5 | 6 | project = "aizynthfinder" 7 | copyright = "2020-2025, Molecular AI group" 8 | author = "Molecular AI group" 9 | release = "4.3.2" 10 | 11 | # This make sure that the cli_help.txt file is properly formated 12 | with open("cli_help.txt", "r") as fileobj: 13 | lines = fileobj.read().splitlines() 14 | with open("cli_help.txt", "w") as fileobj: 15 | fileobj.write(".. code-block::\n\n") 16 | fileobj.write(" " + "\n ".join(lines)) 17 | 18 | extensions = [ 19 | "sphinx.ext.autodoc", 20 | ] 21 | autodoc_member_order = "bysource" 22 | autodoc_typehints = "description" 23 | 24 | html_theme = "alabaster" 25 | html_theme_options = { 26 | "description": "A fast, robust and flexible software for retrosynthetic planning", 27 | "fixed_sidebar": True, 28 | } 29 | -------------------------------------------------------------------------------- /docs/gui.rst: -------------------------------------------------------------------------------- 1 | Graphical user interface 2 | ======================== 3 | 4 | This tool to provide the possibility to perform the tree search on a single compound using a GUI 5 | through a Jupyter notebook. If you are unfamiliar with notebooks, you find some introduction `here `_. 6 | 7 | To bring up the notebook, use 8 | 9 | .. code-block:: bash 10 | 11 | jupyter notebook 12 | 13 | and browse to an existing notebook or create a new one. 14 | 15 | Add these lines to the first cell in the notebook. 16 | 17 | .. code-block:: 18 | 19 | from aizynthfinder.interfaces import AiZynthApp 20 | app = AiZynthApp("/path/to/configfile.yaml") 21 | 22 | where the ``AiZynthApp`` class needs to be instantiated with the path to a configuration file (see :doc:`here `). 23 | 24 | To use the interface, follow these steps: 25 | 26 | 1. Executed the code in the cell (press ``Ctrl+Enter``) and a simple GUI will appear 27 | 2. Enter the target SMILES and select stocks and policy model. 28 | 3. Press the ``Run Search`` button to perform the tree search. 29 | 30 | .. image:: gui_input.png 31 | 32 | 33 | 4. Press the ``Show Reactions`` to see the top-ranked routes 34 | 35 | 36 | .. image:: gui_results.png 37 | 38 | You can also choose to select and sort the top-ranked routes based on another scoring function. 39 | 40 | 41 | Creating the notebook 42 | --------------------- 43 | 44 | It is possible to create a notebook automatically with the ``aizynthapp`` tool 45 | 46 | .. code-block:: bash 47 | 48 | aizynthapp --config config_local.yml 49 | 50 | which will also automatically opens up the created notebook. 51 | 52 | Analysing the results 53 | --------------------- 54 | 55 | When the tree search has been finished. One can continue exploring the tree and extract output. 56 | This is done by using the ``finder`` property of the app object. The finder holds a reference to an ``AiZynthFinder`` object. 57 | 58 | .. code-block:: 59 | 60 | finder = app.finder 61 | stats = finder.extract_statistics() 62 | 63 | 64 | Clustering 65 | ----------- 66 | 67 | There is a GUI extension to perform clustering of the routes. Enter the following a new cell 68 | 69 | .. code-block:: 70 | 71 | %matplotlib inline 72 | from aizynthfinder.interfaces.gui.clustering import ClusteringGui 73 | ClusteringGui.from_app(app) 74 | 75 | 76 | A GUI like this will be shown, where you see the hierarchy of the routes and then can select how many 77 | clusters you want to create. 78 | 79 | .. image:: gui_clustering.png -------------------------------------------------------------------------------- /docs/gui_clustering.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/gui_clustering.png -------------------------------------------------------------------------------- /docs/gui_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/gui_input.png -------------------------------------------------------------------------------- /docs/gui_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/gui_results.png -------------------------------------------------------------------------------- /docs/howto.rst: -------------------------------------------------------------------------------- 1 | How-to 2 | ======= 3 | 4 | This page outlines a few guidelines on some more advanced use-cases of AiZynthFinder or 5 | frequently raised issues. 6 | 7 | 8 | Using Retro* 9 | ------------ 10 | 11 | AiZynthFinder implements other search algorithms than MCTS. This is an example of how Retro* can be used. 12 | 13 | The search algorithm is specified in the configuration file. 14 | 15 | .. code-block:: yaml 16 | 17 | search: 18 | algorithm: aizynthfinder.search.retrostar.search_tree.SearchTree 19 | 20 | 21 | This will use Retro* without a constant-valued oracle function. To specify the oracle function, you can 22 | do 23 | 24 | .. code-block:: yaml 25 | 26 | search: 27 | algorithm: aizynthfinder.search.retrostar.search_tree.SearchTree 28 | algorithm_config: 29 | molecule_cost: 30 | cost: aizynthfinder.search.retrostar.cost.RetroStarCost 31 | model_path: retrostar_value_model.pickle 32 | fingerprint_length: 2048 33 | fingerprint_radius: 2 34 | dropout_rate: 0.1 35 | 36 | The pickle file can be downloaded from `here `_ 37 | 38 | 39 | Using multiple expansion policies 40 | --------------------------------- 41 | 42 | AiZynthFinder can use multiple expansion policies. This gives an example how a general USPTO and a RingBreaker model 43 | can be used together 44 | 45 | .. code-block:: yaml 46 | 47 | expansion: 48 | uspto: 49 | - uspto_keras_model.hdf5 50 | - uspto_unique_templates.csv.gz 51 | ringbreaker: 52 | - uspto_ringbreaker_keras_model.hdf5 53 | - uspto_ringbreaker_unique_templates.csv.gz 54 | multi_expansion_strategy: 55 | type: aizynthfinder.context.policy.MultiExpansionStrategy 56 | expansion_strategies: [uspto, ringbreaker] 57 | additive_expansion: True 58 | 59 | and then to use this with ``aizynthcli`` do something like this 60 | 61 | .. code-block:: 62 | 63 | aizynthcli --smiles smiles.txt --config config.yml --policy multi_expansion_strategy 64 | 65 | 66 | Output more routes 67 | ------------------ 68 | 69 | The number of routes in the output of ``aizynthcli`` can be controlled from the configuration file. 70 | 71 | This is how you can extract at least 25 routes but not more than 50 per target 72 | 73 | .. code-block:: yaml 74 | 75 | post_processing: 76 | min_routes: 25 77 | max_routes: 50 78 | 79 | Alternatively, you can extract all solved routes. If a target is unsolved, it will return the number 80 | of routes specified by ``min_routes`` and ``max_routes``. 81 | 82 | .. code-block:: yaml 83 | 84 | post_processing: 85 | min_routes: 5 86 | max_routes: 10 87 | all_routes: True 88 | 89 | 90 | Running multi-objective (MO) MCTS with disconnection-aware Chemformer 91 | ------------------ 92 | Disconnection-aware retrosynthesis can be done with 1) MO-MCTS (state score + broken bonds score), 2) Chemformer or 3) both. 93 | 94 | First, you need to specify the bond constraints under search, see below. 95 | To run the MO-MCTS with the "broken bonds" score, add the "broken bonds" score to the list of search_rewards: 96 | 97 | .. code-block:: yaml 98 | 99 | search: 100 | break_bonds: [[1, 2], [3, 4]] 101 | freeze_bonds: [] 102 | algorithm_config: 103 | search_rewards: ["state score", "broken bonds"] 104 | 105 | To use the disconnection-aware Chemformer, you first need to add the `plugins` folder to the `PYTHONPATH`, e.g. 106 | 107 | export PYTHONPATH=~/aizynthfinder/plugins/ 108 | 109 | The script for starting a disconnection-aware Chemformer service is available at https://github.com/MolecularAI/Chemformer. 110 | The multi-expansion policy with template-based model and Chemformer is specified with: 111 | 112 | .. code-block:: yaml 113 | expansion: 114 | standard: 115 | type: template-based 116 | model: path/to/model 117 | template: path/to/templates 118 | chemformer_disconnect: 119 | type: expansion_strategies.DisconnectionAwareExpansionStrategy 120 | url: "http://localhost:8023/chemformer-disconnect-api/predict-disconnection" 121 | n_beams: 5 122 | multi_expansion: 123 | type: aizynthfinder.context.policy.MultiExpansionStrategy 124 | expansion_strategies: [chemformer_disconnect, standard] 125 | additive_expansion: True 126 | cutoff_number: 50 127 | 128 | 129 | To use MO-tree ranking and building, set: 130 | 131 | .. code-block:: yaml 132 | post_processing: 133 | route_scorers: ["state score", "broken bonds"] 134 | 135 | Note: If post_processing.route_scorers is not specified, it will default to search.algorithm_config.search_rewards. -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | aizynthfinder documentation 2 | =========================== 3 | 4 | aizynthfinder is a tool for retrosynthetic planning. The default algorithm is based on a Monte Carlo tree search that recursively breaks down a molecule to purchasable precursors. The tree search is guided by a policy that suggests possible precursors by utilizing a neural network trained on a library of known reaction templates. 5 | 6 | Introduction 7 | ------------ 8 | 9 | You run retrosynthesis experiments you need a trained model and a stock collection. You can download a public available model based on USPTO and a stock collection from ZINC database. 10 | 11 | .. code-block:: 12 | 13 | download_public_data . 14 | 15 | This will download the data to your current directory. The ``config.yml`` file can be used directly with the interfaces. 16 | 17 | 18 | There are two main interfaces provided by the package: 19 | 20 | * a script that performs tree search in batch mode and 21 | * an interface that is providing a GUI within a Jupyter notebook. 22 | 23 | 24 | The GUI interface should be run in a Jupyter notebook. This is a simple example of the code in a Jupyter notebook cell. 25 | 26 | .. code-block:: 27 | 28 | from aizynthfinder.interfaces import AiZynthApp 29 | app = AiZynthApp("/path/to/configfile.yaml") 30 | 31 | where the ``AiZynthApp`` class needs to be instantiated with the path to a configuration file (see :doc:`here `). 32 | 33 | To use the interface, follow these steps: 34 | 35 | 1. Executed the code in the cell (press ``Ctrl+Enter``) and a simple GUI will appear 36 | 2. Enter the target SMILES and select stocks and policy model. 37 | 3. Press the ``Run Search`` button to perform the tree search. 38 | 4. Press the ``Show Reactions`` to see the top-ranked routes 39 | 40 | 41 | 42 | The batch-mode script is called ``aizynthcli`` and can be executed like: 43 | 44 | .. code-block:: bash 45 | 46 | aizynthcli --config config.yml --smiles smiles.txt 47 | 48 | 49 | where `config.yml` contains configurations such as paths to policy models and stocks (see :doc:`here `), and `smiles.txt` is a simple text 50 | file with SMILES (one on each row). 51 | 52 | If you just want to perform the tree search on a single molecule. You can directly specify it on the command-line 53 | within quotes: 54 | 55 | .. code-block:: bash 56 | 57 | aizynthcli --config config.yml --smiles "COc1cccc(OC(=O)/C=C/c2cc(OC)c(OC)c(OC)c2)c1" 58 | 59 | 60 | The output is some statistics about the tree search, the scores of the top-ranked routes, and the reaction tree 61 | of the top-ranked routes. When smiles are provided in a text file the results are stored in a JSON file, 62 | whereas if the SMILEs is provided on the command-line it is printed directly to the prompt 63 | (except the reaction trees, which are written to a JSON file). 64 | 65 | 66 | .. toctree:: 67 | :hidden: 68 | 69 | gui 70 | cli 71 | python_interface 72 | configuration 73 | stocks 74 | scoring 75 | howto 76 | aizynthfinder 77 | sequences 78 | relationships -------------------------------------------------------------------------------- /docs/line-desc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/line-desc.png -------------------------------------------------------------------------------- /docs/python_interface.rst: -------------------------------------------------------------------------------- 1 | Python interface 2 | ================ 3 | 4 | This page gives a quick example of how the tree search can be completed 5 | by writing your own python interface. This is not recommended for most users. 6 | 7 | 8 | 1. Import the necessary class 9 | 10 | .. code-block:: python 11 | 12 | from aizynthfinder.aizynthfinder import AiZynthFinder 13 | 14 | 15 | 2. Instantiate that class by providing a configuration file. 16 | 17 | 18 | .. code-block:: python 19 | 20 | filename = "config.yml" 21 | finder = AiZynthFinder(configfile=filename) 22 | 23 | 24 | 3. Select stock and policy 25 | 26 | 27 | .. code-block:: python 28 | 29 | finder.stock.select("zinc") 30 | finder.expansion_policy.select("uspto") 31 | finder.filter_policy.select("uspto") 32 | 33 | `zinc` and `uspto` where the keys given to the stock and the policy in the configuration file. 34 | The first policy set is the expansion policy and the second is the filter policy. The filter policy is optional. 35 | 36 | 4. Set the target SMILES and perform the tree search 37 | 38 | 39 | .. code-block:: python 40 | 41 | finder.target_smiles = "Cc1cccc(c1N(CC(=O)Nc2ccc(cc2)c3ncon3)C(=O)C4CCS(=O)(=O)CC4)C" 42 | finder.tree_search() 43 | 44 | 45 | 5. Analyse the search tree and build routes 46 | 47 | 48 | .. code-block:: python 49 | 50 | finder.build_routes() 51 | stats = finder.extract_statistics() 52 | 53 | 54 | The ``build_routes`` method needs to be called before any analysis can be done. 55 | 56 | Expansion interface 57 | ------------------- 58 | 59 | There is an interface for the expansion policy as well. It can be used to break down a molecule into reactants. 60 | 61 | .. code-block:: python 62 | 63 | filename = "config.yml" 64 | expander = AiZynthExpander(configfile=filename) 65 | expander.expansion_policy.select("uspto") 66 | expander.filter_policy.select("uspto") 67 | reactions = expander.do_expansion("Cc1cccc(c1N(CC(=O)Nc2ccc(cc2)c3ncon3)C(=O)C4CCS(=O)(=O)CC4)C") 68 | 69 | for this, you only need to select the policies. The filter policy is optional and using it will only add the 70 | feasibility of the reactions not filter it out. 71 | 72 | The result is a nested list of `FixedRetroReaction` objects. This you can manipulate to for instance get 73 | out all the reactants SMILES strings 74 | 75 | .. code-block:: python 76 | 77 | reactants_smiles = [] 78 | for reaction_tuple in reactions: 79 | reactants_smiles.append([mol.smiles for mol in reaction_tuple[0].reactants[0]) 80 | 81 | or you can put all the metadata of all the reactions in a pandas dataframe 82 | 83 | .. code-block:: python 84 | 85 | import pandas as pd 86 | metadata = [] 87 | for reaction_tuple in reactions: 88 | for reaction in reaction_tuple: 89 | metadata.append(reaction.metadata) 90 | df = pd.DataFrame(metadata) 91 | 92 | 93 | Further reading 94 | --------------- 95 | 96 | The docstrings of all modules, classes and methods can be consulted :doc:`here ` 97 | 98 | 99 | and you can always find them in an interactive Python shell using for instance: 100 | 101 | .. code-block:: python 102 | 103 | from aizynthfinder.chem import Molecule 104 | help(Molecule) 105 | help(Molecule.fingerprint) 106 | 107 | 108 | If you are interested in the the relationships between the classes have a look :doc:`here ` 109 | and if you want to dig deeper in to the main algorithmic sequences have a look :doc:`here ` -------------------------------------------------------------------------------- /docs/relationships.rst: -------------------------------------------------------------------------------- 1 | Relationships 2 | ============= 3 | 4 | This page shows some relationship diagrams, i.e. how the different objects are connect in a typical retrosynthesis 5 | analysis using Monte Carlo tree search. 6 | 7 | These are the tree different types of relationships used: 8 | 9 | .. image:: line-desc.png 10 | 11 | Tree search 12 | ----------- 13 | 14 | This diagram explains how the different object are connect that are responsible for the Monte-Carlo tree search. 15 | 16 | .. image:: treesearch-rel.png 17 | 18 | 19 | Analysis / post-processing 20 | -------------------------- 21 | 22 | This diagram explains how the different objects involved in the analysis of the search are connected. 23 | 24 | .. image:: analysis-rel.png 25 | -------------------------------------------------------------------------------- /docs/scoring.rst: -------------------------------------------------------------------------------- 1 | Scoring 2 | ======= 3 | 4 | aizynthfinder is capable of scoring reaction routes, both in the form of ``MctsNode`` objects when a search tree is available, 5 | and in the form of ``ReactionTree`` objects if post-processing is required. 6 | 7 | Currently, there are a few scoring functions available 8 | 9 | * State score - a function of the number of precursors in stock and the length of the route 10 | * Number of reactions - the number of steps in the route 11 | * Number of pre-cursors - the number of pre-cursors in the route 12 | * Number of pre-cursors in stock - the number of the pre-cursors that are purchaseable 13 | * Average template occurrence - the average occurrence of the templates used in the route 14 | * Sum of prices - the plain sum of the price of all pre-cursors 15 | * Route cost score - the cost of the synthesizing the route (Badowski et al. Chem Sci. 2019, 10, 4640) 16 | 17 | 18 | The *State score* is the score that is guiding the tree search in the :doc:`update phase `, and 19 | this is not configurable. 20 | 21 | In the Jupyter notebook :doc:`GUI ` one can choose to score the routes with any of the loaded the scorers. 22 | 23 | The first four scoring functions are loaded automatically when an ``aizynthfinder`` object is created. 24 | 25 | 26 | Add new scoring functions 27 | ------------------------- 28 | 29 | 30 | Additional scoring functions can be implemented by inheriting from the class ``Scorer`` in the ``aizynthfinder.context.scoring.scorers`` module. 31 | The scoring class needs to implement the ``_score_node``, ``_score_reaction_tree`` and the ``__repr__`` methods. 32 | 33 | This is an example of that. 34 | 35 | .. code-block:: python 36 | 37 | from aizynthfinder.context.scoring.scorers import Scorer 38 | 39 | class DeltaNumberOfTransformsScorer(Scorer): 40 | 41 | def __repr__(self): 42 | return "delta number of transforms" 43 | 44 | def _score_node(self, node): 45 | return self._config.max_transforms - node.state.max_transforms 46 | 47 | def _score_reaction_tree(self, tree): 48 | return self._config.max_transforms - len(list(tree.reactions())) 49 | 50 | 51 | This can then be added to the ``scorers`` attribute of an ``aizynthfinderfinder`` object. The ``scorers`` attribute is a collection 52 | of ``Scorer`` objects. 53 | 54 | For instance to use this in the Jupyter notebook GUI, one can do 55 | 56 | .. code-block:: python 57 | 58 | from aizynthfinder.interfaces import AiZynthApp 59 | app = AiZynthApp("config_local.yml", setup=False) 60 | scorer = DeltaNumberOfTransformsScorer(app.finder.config) 61 | app.finder.scorers.load(scorer) 62 | app.setup() 63 | 64 | -------------------------------------------------------------------------------- /docs/treesearch-rel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/treesearch-rel.png -------------------------------------------------------------------------------- /docs/treesearch-seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/treesearch-seq.png -------------------------------------------------------------------------------- /env-dev.yml: -------------------------------------------------------------------------------- 1 | name: aizynth-dev 2 | channels: 3 | - https://conda.anaconda.org/conda-forge 4 | - defaults 5 | dependencies: 6 | - python>=3.9,<3.11 7 | - poetry>=1.1.4,<1.4.1 -------------------------------------------------------------------------------- /plugins/README.md: -------------------------------------------------------------------------------- 1 | # Plugins 2 | 3 | This folder contains some features for `aizynthfinder` that 4 | does not yet fit into the main codebase. It could be experimental 5 | features, or features that require the user to install some 6 | additional third-party dependencies. 7 | 8 | For the expansion models, you generally need to add the `plugins` folder to the `PYTHONPATH`, e.g. 9 | 10 | export PYTHONPATH=~/aizynthfinder/plugins/ 11 | 12 | where the `aizynthfinder` repository is in the home folder 13 | 14 | ## Chemformer expansion model 15 | 16 | An expansion model using a REST API for the Chemformer model 17 | is supplied in the `expansion_strategies` module. 18 | 19 | To use it, you first need to install the `chemformer` package 20 | and launch the REST API service that comes with it. 21 | 22 | To use the expansion model in `aizynthfinder` you can use a config-file 23 | containing these lines 24 | 25 | expansion: 26 | chemformer: 27 | type: expansion_strategies.ChemformerBasedExpansionStrategy 28 | url: http://localhost:8000/chemformer-api/predict 29 | search: 30 | algorithm_config: 31 | immediate_instantiation: [chemformer] 32 | time_limit: 300 33 | 34 | The `time_limit` is a recommandation for allowing the more expensive expansion model 35 | to finish a sufficient number of retrosynthesis iterations. 36 | 37 | You would have to change `localhost:8000` to the name and port of the machine hosting the REST service. 38 | 39 | You can then use the config-file with either `aizynthcli` or the Jupyter notebook interface. 40 | 41 | ## ModelZoo expansion model 42 | 43 | An expansion model using the ModelZoo feature is supplied in the `expansion_strategies` 44 | module. This is an adoption of the code from this repo: `https://github.com/AlanHassen/modelsmatter` that were used in the publications [Models Matter: The Impact of Single-Step Models on Synthesis Prediction](https://arxiv.org/abs/2308.05522) and [Mind the Retrosynthesis Gap: Bridging the divide between Single-step and Multi-step Retrosynthesis Prediction](https://openreview.net/forum?id=LjdtY0hM7tf). 45 | 46 | To use it, you first need to install the `modelsmatter_modelzoo` package from 47 | https://github.com/PTorrenPeraire/modelsmatter_modelzoo and set up the `ssbenchmark` 48 | environment. 49 | 50 | Ensure that the `external_models` sub-package contains the models required. 51 | If it does not, you will need to manually clone the required model repositories 52 | within `external_models`. 53 | 54 | To use the expansion model in `aizynthfinder`, you can specify it in the config-file 55 | under `expansion`. Here is an example setting to use the expansion model with `chemformer` 56 | as the external model: 57 | 58 | expansion: 59 | chemformer: 60 | type: expansion_strategies.ModelZooExpansionStrategy: 61 | module_path: /path_to_folder_containing_cloned_repository/modelsmatter_modelzoo/external_models/modelsmatter_chemformer_hpc/ 62 | use_gpu: False 63 | params: 64 | module_path: /path_to_model_file/chemformer_backward.ckpt 65 | vocab_path: /path_to_vocab_file/bart_vocab_downstream.txt 66 | search: 67 | algorithm_config: 68 | immediate_instantiation: [chemformer] 69 | time_limit: 300 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "aizynthfinder" 3 | version = "4.3.2" 4 | description = "Retrosynthetic route finding using neural network guided Monte-Carlo tree search" 5 | authors = ["Molecular AI group "] 6 | license = "MIT" 7 | include = ["aizynthfinder/data/*.yml", "aizynthfinder/data/templates/*"] 8 | readme = "README.md" 9 | homepage = "https://github.com/MolecularAI/aizynthfinder/" 10 | repository = "https://github.com/MolecularAI/aizynthfinder/" 11 | documentation = "https://molecularai.github.io/aizynthfinder/" 12 | 13 | [tool.pytest.ini_options] 14 | mccabe-complexity = 9 15 | 16 | [tool.pylint.'MESSAGES CONTROL'] 17 | max-line-length = 120 18 | max-args = 8 19 | max-attributes = 20 20 | min-public-methods = 0 21 | disable = "C0116, E0401, E1101, I1101, R0801, R0902, R0903, R0914, R1732, R1735, W0221, W0237, W0406, W0602, W0603, W0707, W1201, W1203, W1514, W3101" 22 | 23 | [tool.coverage.report] 24 | exclude_also = [ 25 | "if TYPE_CHECKING:" 26 | ] 27 | 28 | [tool.poetry.dependencies] 29 | python = ">=3.9,<3.12" 30 | ipywidgets = "^7.5.1" 31 | jinja2 = "^3.0.0" 32 | jupyter = "^1.0.0" 33 | jupytext = "^1.3.3" 34 | notebook = "^6.5.3" 35 | networkx = "^2.4" 36 | deprecated = "^1.2.10" 37 | pandas = "^1.0.0" 38 | pillow = "^9.0.0" 39 | requests = "^2.23.0" 40 | rdchiral = "^1.0.0" 41 | rdkit = "^2023.9.1" 42 | tables = "^3.6.1" 43 | tqdm = "^4.42.1" 44 | onnxruntime = "<1.17.0" 45 | tensorflow = {version = "^2.8.0", optional=true} 46 | grpcio = {version = "^1.24.0", optional=true} 47 | tensorflow-serving-api = {version = "^2.1.0", optional=true} 48 | pymongo = {version = "^3.10.1", optional=true} 49 | route-distances = {version = "^1.2.4", optional=true} 50 | scipy = {version = "^1.0", optional=true} 51 | matplotlib = "^3.0.0" 52 | timeout-decorator = {version = "^0.5.0", optional=true} 53 | molbloom = {version = "^2.1.0", optional=true} 54 | paretoset = "^1.2.3" 55 | seaborn = "^0.13.2" 56 | numpy = "<2.0.0" 57 | 58 | [tool.poetry.dev-dependencies] 59 | black = "^22.0.0" 60 | invoke = "^2.2.0" 61 | pytest = "^6.2.2" 62 | pytest-black = "^0.3.12" 63 | pytest-cov = "^2.11.0" 64 | pytest-datadir = "^1.3.1" 65 | pytest-mock = "^3.5.0" 66 | pytest-mccabe = "^2.0.0" 67 | Sphinx = "^7.3.7" 68 | mypy = "^1.0.0" 69 | pylint = "^2.16.0" 70 | 71 | [tool.poetry.extras] 72 | all = ["pymongo", "route-distances", "scipy", "timeout-decorator", "molbloom"] 73 | tf = ["tensorflow", "grpcio", "tensorflow-serving-api"] 74 | 75 | [tool.poetry.scripts] 76 | aizynthapp = "aizynthfinder.interfaces.aizynthapp:main" 77 | aizynthcli = "aizynthfinder.interfaces.aizynthcli:main" 78 | cat_aizynth_output = "aizynthfinder.tools.cat_output:main" 79 | download_public_data = "aizynthfinder.tools.download_public_data:main" 80 | smiles2stock = "aizynthfinder.tools.make_stock:main" 81 | 82 | [tool.coverage.run] 83 | relative_files = true 84 | 85 | [build-system] 86 | requires = ["poetry_core>=1.0.0"] 87 | build-backend = "poetry.core.masonry.api" 88 | -------------------------------------------------------------------------------- /tasks.py: -------------------------------------------------------------------------------- 1 | from invoke import task 2 | 3 | 4 | @task 5 | def build_docs(context): 6 | context.run("aizynthcli -h > ./docs/cli_help.txt") 7 | context.run("sphinx-apidoc -o ./docs ./aizynthfinder") 8 | context.run("sphinx-build -M html ./docs ./docs/build") 9 | 10 | 11 | @task 12 | def full_tests(context): 13 | cmd = ( 14 | "pytest --black --mccabe " 15 | "--cov aizynthfinder --cov-branch --cov-report html:coverage --cov-report xml " 16 | "tests/" 17 | ) 18 | context.run(cmd) 19 | 20 | 21 | @task 22 | def run_mypy(context): 23 | context.run("mypy --ignore-missing-imports --show-error-codes aizynthfinder") 24 | 25 | 26 | @task 27 | def run_linting(context): 28 | print("Running mypy...") 29 | context.run("mypy --install-types", pty=True) 30 | context.run( 31 | "mypy --ignore-missing-imports --show-error-codes --implicit-optional aizynthfinder" 32 | ) 33 | print("Running pylint...") 34 | context.run("pylint aizynthfinder") 35 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/__init__.py -------------------------------------------------------------------------------- /tests/breadth_first/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/breadth_first/__init__.py -------------------------------------------------------------------------------- /tests/breadth_first/test_nodes.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aizynthfinder.search.breadth_first.nodes import MoleculeNode 4 | from aizynthfinder.chem.serialization import MoleculeSerializer, MoleculeDeserializer 5 | 6 | 7 | @pytest.fixture 8 | def setup_root(default_config): 9 | def wrapper(smiles): 10 | return MoleculeNode.create_root(smiles, config=default_config) 11 | 12 | return wrapper 13 | 14 | 15 | def test_create_root_node(setup_root): 16 | node = setup_root("CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1") 17 | 18 | assert node.ancestors() == {node.mol} 19 | assert node.expandable 20 | assert not node.children 21 | 22 | 23 | def test_create_stub(setup_root, get_action): 24 | root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" 25 | root = setup_root(root_smiles) 26 | reaction = get_action() 27 | 28 | root.add_stub(reaction=reaction) 29 | 30 | assert len(root.children) == 1 31 | assert len(root.children[0].children) == 2 32 | rxn_node = root.children[0] 33 | assert rxn_node.reaction is reaction 34 | exp_list = [node.mol for node in rxn_node.children] 35 | assert exp_list == list(reaction.reactants[0]) 36 | 37 | 38 | def test_initialize_stub_one_solved_leaf( 39 | setup_root, get_action, default_config, setup_stock 40 | ): 41 | root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" 42 | root = setup_root(root_smiles) 43 | reaction = get_action() 44 | setup_stock(default_config, reaction.reactants[0][0]) 45 | 46 | root.add_stub(reaction=reaction) 47 | 48 | assert not root.children[0].children[0].expandable 49 | assert root.children[0].children[1].expandable 50 | 51 | 52 | def test_serialization_deserialization( 53 | setup_root, get_action, default_config, setup_stock 54 | ): 55 | root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" 56 | root = setup_root(root_smiles) 57 | reaction = get_action() 58 | setup_stock(default_config, *reaction.reactants[0]) 59 | root.add_stub(reaction=reaction) 60 | 61 | molecule_serializer = MoleculeSerializer() 62 | dict_ = root.serialize(molecule_serializer) 63 | 64 | molecule_deserializer = MoleculeDeserializer(molecule_serializer.store) 65 | node = MoleculeNode.from_dict(dict_, default_config, molecule_deserializer) 66 | 67 | assert node.mol == root.mol 68 | assert len(node.children) == len(root.children) 69 | 70 | rxn_node = node.children[0] 71 | assert rxn_node.reaction.smarts == reaction.smarts 72 | assert rxn_node.reaction.metadata == reaction.metadata 73 | 74 | for grandchild1, grandchild2 in zip(rxn_node.children, root.children[0].children): 75 | assert grandchild1.mol == grandchild2.mol 76 | -------------------------------------------------------------------------------- /tests/breadth_first/test_search.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import pytest 4 | 5 | from aizynthfinder.search.breadth_first.search_tree import SearchTree 6 | 7 | 8 | def test_one_iteration(default_config, setup_policies, setup_stock): 9 | root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" 10 | child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"] 11 | child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"] 12 | grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"] 13 | lookup = { 14 | root_smi: [ 15 | {"smiles": ".".join(child1_smi), "prior": 0.7}, 16 | {"smiles": ".".join(child2_smi), "prior": 0.3}, 17 | ], 18 | child1_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, 19 | child2_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, 20 | } 21 | stock = [child1_smi[0], child1_smi[2]] + grandchild_smi 22 | setup_policies(lookup, config=default_config) 23 | setup_stock(default_config, *stock) 24 | tree = SearchTree(default_config, root_smi) 25 | 26 | assert len(tree.mol_nodes) == 1 27 | 28 | assert not tree.one_iteration() 29 | 30 | assert len(tree.mol_nodes) == 6 31 | smiles = [node.mol.smiles for node in tree.mol_nodes] 32 | assert smiles == [root_smi] + child1_smi + child2_smi 33 | 34 | assert tree.one_iteration() 35 | 36 | assert len(tree.mol_nodes) == 10 37 | smiles = [node.mol.smiles for node in tree.mol_nodes] 38 | assert ( 39 | smiles == [root_smi] + child1_smi + child2_smi + grandchild_smi + grandchild_smi 40 | ) 41 | 42 | with pytest.raises(StopIteration): 43 | tree.one_iteration() 44 | 45 | 46 | def test_search_incomplete(default_config, setup_policies, setup_stock): 47 | root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" 48 | child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"] 49 | child2_smi = ["ClC(=O)c1ccc(F)cc1", "CN1CCC(CC1)C(=O)c1cccc(N)c1F"] 50 | grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"] 51 | lookup = { 52 | root_smi: [ 53 | {"smiles": ".".join(child1_smi), "prior": 0.7}, 54 | {"smiles": ".".join(child2_smi), "prior": 0.3}, 55 | ], 56 | child1_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, 57 | } 58 | stock = [child1_smi[0], child1_smi[2]] + [grandchild_smi[0]] 59 | setup_policies(lookup, config=default_config) 60 | setup_stock(default_config, *stock) 61 | tree = SearchTree(default_config, root_smi) 62 | 63 | assert len(tree.mol_nodes) == 1 64 | 65 | tree.one_iteration() 66 | assert len(tree.mol_nodes) == 6 67 | 68 | assert not tree.one_iteration() 69 | 70 | assert len(tree.mol_nodes) == 8 71 | 72 | with pytest.raises(StopIteration): 73 | tree.one_iteration() 74 | 75 | 76 | def test_routes(default_config, setup_policies, setup_stock): 77 | random.seed(666) 78 | root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" 79 | child1_smi = ["O", "CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"] 80 | child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"] 81 | grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"] 82 | lookup = { 83 | root_smi: [ 84 | {"smiles": ".".join(child1_smi), "prior": 0.7}, 85 | {"smiles": ".".join(child2_smi), "prior": 0.3}, 86 | ], 87 | child1_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, 88 | } 89 | stock = [child1_smi[0], child1_smi[2]] + grandchild_smi 90 | setup_policies(lookup, config=default_config) 91 | setup_stock(default_config, *stock) 92 | tree = SearchTree(default_config, root_smi) 93 | 94 | while True: 95 | try: 96 | tree.one_iteration() 97 | except StopIteration: 98 | break 99 | 100 | routes = tree.routes() 101 | 102 | assert len(routes) == 2 103 | smiles = [mol.smiles for mol in routes[1].molecules()] 104 | assert smiles == [root_smi] + child1_smi + grandchild_smi 105 | smiles = [mol.smiles for mol in routes[0].molecules()] 106 | assert smiles == [root_smi] + child2_smi + grandchild_smi 107 | -------------------------------------------------------------------------------- /tests/chem/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/chem/__init__.py -------------------------------------------------------------------------------- /tests/chem/test_mol.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from rdkit import Chem 3 | 4 | from aizynthfinder.chem import MoleculeException, Molecule 5 | 6 | 7 | def test_no_input(): 8 | with pytest.raises(MoleculeException): 9 | Molecule() 10 | 11 | 12 | def test_create_with_mol(): 13 | rd_mol = Chem.MolFromSmiles("O") 14 | 15 | mol = Molecule(rd_mol=rd_mol) 16 | 17 | assert mol.smiles == "O" 18 | 19 | 20 | def test_create_with_smiles(): 21 | mol = Molecule(smiles="O") 22 | 23 | assert Chem.MolToSmiles(mol.rd_mol) == "O" 24 | 25 | 26 | def test_inchi(): 27 | mol = Molecule(smiles="O") 28 | 29 | assert mol.inchi == "InChI=1S/H2O/h1H2" 30 | 31 | 32 | def test_inchi_key(): 33 | mol = Molecule(smiles="O") 34 | 35 | assert mol.inchi_key == "XLYOFNOQVPJJNP-UHFFFAOYSA-N" 36 | 37 | 38 | def test_fingerprint(): 39 | mol = Molecule(smiles="O") 40 | 41 | assert sum(mol.fingerprint(2)) == 1 42 | 43 | assert sum(mol.fingerprint(2, 10)) == 1 44 | 45 | 46 | def test_sanitize(): 47 | mol = Molecule(smiles="O", sanitize=True) 48 | 49 | assert Chem.MolToSmiles(mol.rd_mol) == "O" 50 | 51 | mol = Molecule(smiles="c1ccccc1(C)(C)") 52 | 53 | with pytest.raises(MoleculeException): 54 | mol.sanitize() 55 | 56 | mol.sanitize(raise_exception=False) 57 | assert mol.smiles == "CC1(C)CCCCC1" 58 | 59 | 60 | def test_equality(): 61 | mol1 = Molecule(smiles="CCCCO") 62 | mol2 = Molecule(smiles="OCCCC") 63 | 64 | assert mol1 == mol2 65 | 66 | 67 | def test_basic_equality(): 68 | mol1 = Molecule(smiles="CC[C@@H](C)O") # R-2-butanol 69 | mol2 = Molecule(smiles="CC[C@H](C)O") # S-2-butanol 70 | 71 | assert mol1 != mol2 72 | assert mol1.basic_compare(mol2) 73 | 74 | 75 | def test_has_atom_mapping(): 76 | mol1 = Molecule(smiles="CCCCO") 77 | mol2 = Molecule(smiles="C[C:5]CCO") 78 | 79 | assert not mol1.has_atom_mapping() 80 | assert mol2.has_atom_mapping() 81 | 82 | 83 | def test_remove_atom_mapping(): 84 | mol = Molecule(smiles="C[C:5]CCO") 85 | 86 | assert mol.has_atom_mapping() 87 | 88 | mol.remove_atom_mapping() 89 | 90 | assert not mol.has_atom_mapping() 91 | 92 | 93 | def test_chiral_fingerprint(): 94 | mol1 = Molecule(smiles="C[C@@H](C(=O)O)N") 95 | mol2 = Molecule(smiles="C[C@@H](C(=O)O)N") 96 | 97 | fp1 = mol1.fingerprint(radius=2, chiral=False) 98 | fp2 = mol2.fingerprint(radius=2, chiral=True) 99 | 100 | assert fp1.tolist() != fp2.tolist() 101 | -------------------------------------------------------------------------------- /tests/chem/test_serialization.py: -------------------------------------------------------------------------------- 1 | from aizynthfinder.chem.serialization import MoleculeSerializer, MoleculeDeserializer 2 | from aizynthfinder.chem import Molecule, TreeMolecule 3 | 4 | 5 | def test_empty_store(): 6 | serializer = MoleculeSerializer() 7 | 8 | assert serializer.store == {} 9 | 10 | 11 | def test_add_single_mol(): 12 | serializer = MoleculeSerializer() 13 | mol = Molecule(smiles="CCC") 14 | 15 | id_ = serializer[mol] 16 | 17 | assert id_ == id(mol) 18 | assert serializer.store == {id_: {"smiles": "CCC", "class": "Molecule"}} 19 | 20 | 21 | def test_add_tree_mol(): 22 | serializer = MoleculeSerializer() 23 | mol1 = TreeMolecule(parent=None, smiles="CCC", transform=1) 24 | mol2 = TreeMolecule(smiles="CCO", parent=mol1) 25 | 26 | id_ = serializer[mol2] 27 | 28 | assert id_ == id(mol2) 29 | assert list(serializer.store.keys()) == [id(mol1), id_] 30 | assert serializer.store == { 31 | id_: { 32 | "smiles": "CCO", 33 | "class": "TreeMolecule", 34 | "parent": id(mol1), 35 | "transform": 2, 36 | }, 37 | id(mol1): { 38 | "smiles": "CCC", 39 | "class": "TreeMolecule", 40 | "parent": None, 41 | "transform": 1, 42 | }, 43 | } 44 | 45 | 46 | def test_deserialize_single_mol(): 47 | store = {123: {"smiles": "CCC", "class": "Molecule"}} 48 | deserializer = MoleculeDeserializer(store) 49 | 50 | assert deserializer[123].smiles == "CCC" 51 | 52 | 53 | def test_deserialize_tree_mols(): 54 | store = { 55 | 123: { 56 | "smiles": "CCC", 57 | "class": "TreeMolecule", 58 | "parent": None, 59 | "transform": 1, 60 | }, 61 | 234: {"smiles": "CCO", "class": "TreeMolecule", "parent": 123, "transform": 2}, 62 | } 63 | 64 | deserializer = MoleculeDeserializer(store) 65 | 66 | assert deserializer[123].smiles == "CCC" 67 | assert deserializer[234].smiles == "CCO" 68 | assert deserializer[123].parent is None 69 | assert deserializer[234].parent is deserializer[123] 70 | assert deserializer[123].transform == 1 71 | assert deserializer[234].transform == 2 72 | 73 | 74 | def test_chaining(): 75 | serializer = MoleculeSerializer() 76 | mol1 = TreeMolecule(parent=None, smiles="CCC", transform=1) 77 | mol2 = TreeMolecule(smiles="CCO", parent=mol1) 78 | 79 | id_ = serializer[mol2] 80 | 81 | deserializer = MoleculeDeserializer(serializer.store) 82 | 83 | assert deserializer[id_].smiles == mol2.smiles 84 | assert deserializer[id(mol1)].smiles == mol1.smiles 85 | assert id(deserializer[id_]) != id_ 86 | assert id(deserializer[id(mol1)]) != id(mol1) 87 | -------------------------------------------------------------------------------- /tests/context/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/context/__init__.py -------------------------------------------------------------------------------- /tests/context/conftest.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | import pytest_mock 7 | 8 | from aizynthfinder.context.policy import TemplateBasedExpansionStrategy 9 | from aizynthfinder.context.scoring import BrokenBondsScorer 10 | from aizynthfinder.context.stock import ( 11 | MongoDbInchiKeyQuery, 12 | StockException, 13 | StockQueryMixin, 14 | ) 15 | 16 | 17 | @pytest.fixture 18 | def create_templates_file(tmpdir): 19 | def wrapper(templates): 20 | data = {"retro_template": templates} 21 | filename = str(tmpdir / "dummy_templates.hdf5") 22 | pd.DataFrame(data).to_hdf(filename, "table") 23 | return filename 24 | 25 | return wrapper 26 | 27 | 28 | @pytest.fixture 29 | def make_stock_query(): 30 | class StockQuery(StockQueryMixin): 31 | def __init__(self, mols, price, amount): 32 | self.mols = mols 33 | self._price = price 34 | self._amount = amount 35 | 36 | def __contains__(self, mol): 37 | return mol in self.mols 38 | 39 | def __len__(self): 40 | return len(self.mols) 41 | 42 | def amount(self, mol): 43 | if mol in self._amount: 44 | return self._amount[mol] 45 | raise StockException() 46 | 47 | def price(self, mol): 48 | if mol in self._price: 49 | return self._price[mol] 50 | raise StockException() 51 | 52 | def wrapper(mols, amount=None, price=None): 53 | return StockQuery(mols, price or {}, amount or {}) 54 | 55 | return wrapper 56 | 57 | 58 | @pytest.fixture 59 | def mocked_mongo_db_query(mocker): 60 | mocked_client = mocker.patch("aizynthfinder.context.stock.queries.get_mongo_client") 61 | 62 | def wrapper(**kwargs): 63 | return mocked_client, MongoDbInchiKeyQuery(**kwargs) 64 | 65 | return wrapper 66 | 67 | 68 | @pytest.fixture 69 | def setup_template_expansion_policy( 70 | default_config, create_dummy_templates, create_templates_file, mock_onnx_model 71 | ): 72 | def wrapper( 73 | key="policy1", templates=None, expansion_cls=TemplateBasedExpansionStrategy 74 | ): 75 | if templates is None: 76 | templates_filename = create_dummy_templates(3) 77 | else: 78 | templates_filename = create_templates_file(templates) 79 | 80 | strategy = expansion_cls( 81 | key, default_config, model="dummy.onnx", template=templates_filename 82 | ) 83 | 84 | return strategy, mock_onnx_model 85 | 86 | return wrapper 87 | 88 | 89 | @pytest.fixture 90 | def setup_mcts_broken_bonds(setup_stock, setup_expanded_mcts, shared_datadir): 91 | def wrapper(broken=True, config=None): 92 | root_smi = "CN1CCC(C(=O)c2cccc([NH:1][C:2](=O)c3ccc(F)cc3)c2F)CC1" 93 | 94 | reaction_template = pd.read_csv( 95 | shared_datadir / "test_reactions_template.csv", sep="\t" 96 | ) 97 | template1_smarts = reaction_template["RetroTemplate"][0] 98 | template2_smarts = reaction_template["RetroTemplate"][1] 99 | 100 | child1_smi = ["N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "CN1CCC(Cl)CC1", "O"] 101 | child2_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"] 102 | 103 | if not broken: 104 | template2_smarts = reaction_template["RetroTemplate"][2] 105 | child2_smi = ["N#Cc1cccc(Cl)c1F", "NC(=O)c1ccc(F)cc1"] 106 | 107 | lookup = { 108 | root_smi: {"smarts": template1_smarts, "prior": 1.0}, 109 | child1_smi[0]: { 110 | "smarts": template2_smarts, 111 | "prior": 1.0, 112 | }, 113 | } 114 | 115 | stock = [child1_smi[1], child1_smi[2]] + child2_smi 116 | 117 | if config: 118 | config.scorers.create_default_scorers() 119 | config.scorers.load(BrokenBondsScorer(config)) 120 | setup_stock(config, *stock) 121 | return setup_expanded_mcts(lookup) 122 | 123 | return wrapper 124 | -------------------------------------------------------------------------------- /tests/context/data/custom_loader.py: -------------------------------------------------------------------------------- 1 | def extract_smiles(filename): 2 | with open(filename, "r") as fileobj: 3 | for i, line in enumerate(fileobj.readlines()): 4 | if i == 0: 5 | continue 6 | yield line.strip().split(",")[0] 7 | -------------------------------------------------------------------------------- /tests/context/data/custom_loader2.py: -------------------------------------------------------------------------------- 1 | def extract_smiles(): 2 | return ["c1ccccc1", "Cc1ccccc1", "c1ccccc1", "CCO"] 3 | -------------------------------------------------------------------------------- /tests/context/data/linear_route_w_metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "mol", 3 | "route_metadata": { 4 | "created_at_iteration": 1, 5 | "is_solved": true 6 | }, 7 | "hide": false, 8 | "smiles": "OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1", 9 | "is_chemical": true, 10 | "in_stock": false, 11 | "children": [ 12 | { 13 | "type": "reaction", 14 | "hide": false, 15 | "smiles": "OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1>>OOc1ccc(-c2ccccc2)cc1.NC1CCCC(C2C=CC=C2)C1", 16 | "is_reaction": true, 17 | "metadata": { 18 | "classification": "abc" 19 | }, 20 | "children": [ 21 | { 22 | "type": "mol", 23 | "hide": false, 24 | "smiles": "NC1CCCC(C2C=CC=C2)C1", 25 | "is_chemical": true, 26 | "in_stock": true 27 | }, 28 | { 29 | "type": "mol", 30 | "hide": false, 31 | "smiles": "OOc1ccc(-c2ccccc2)cc1", 32 | "is_chemical": true, 33 | "in_stock": false, 34 | "children": [ 35 | { 36 | "type": "reaction", 37 | "hide": false, 38 | "smiles": "OOc1ccc(-c2ccccc2)cc1>>c1ccccc1.OOc1ccccc1", 39 | "is_reaction": true, 40 | "metadata": { 41 | "classification": "xyz" 42 | }, 43 | "children": [ 44 | { 45 | "type": "mol", 46 | "hide": false, 47 | "smiles": "c1ccccc1", 48 | "is_chemical": true, 49 | "in_stock": true 50 | }, 51 | { 52 | "type": "mol", 53 | "hide": false, 54 | "smiles": "OOc1ccccc1", 55 | "is_chemical": true, 56 | "in_stock": true 57 | } 58 | ] 59 | } 60 | ] 61 | } 62 | ] 63 | } 64 | ] 65 | } -------------------------------------------------------------------------------- /tests/context/data/test_reactions_template.csv: -------------------------------------------------------------------------------- 1 | ReactionSmilesClean mapped_rxn confidence RetroTemplate TemplateHash TemplateError 2 | CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O>>CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1 Cl[CH:5]1[CH2:4][CH2:3][N:2]([CH3:1])[CH2:26][CH2:25]1.N#[C:6][c:8]1[cH:9][cH:10][cH:11][c:12]([NH:13][C:14](=[O:15])[c:16]2[cH:17][cH:18][c:19]([F:20])[cH:21][cH:22]2)[c:23]1[F:24].[OH2:7]>>[CH3:1][N:2]1[CH2:3][CH2:4][CH:5]([C:6](=[O:7])[c:8]2[cH:9][cH:10][cH:11][c:12]([NH:13][C:14](=[O:15])[c:16]3[cH:17][cH:18][c:19]([F:20])[cH:21][cH:22]3)[c:23]2[F:24])[CH2:25][CH2:26]1 0.9424734380772164 [C:2]-[CH;D3;+0:1](-[C:3])-[C;H0;D3;+0:4](=[O;H0;D1;+0:6])-[c:5]>>Cl-[CH;D3;+0:1](-[C:2])-[C:3].N#[C;H0;D2;+0:4]-[c:5].[OH2;D0;+0:6] 23f00a3c507eef75b22252e8eea99b2ce8eab9a573ff2fbd8227d93f49960a27 3 | N#Cc1cccc(N)c1F.O=C(Cl)c1ccc(F)cc1>>N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F Cl[C:9](=[O:10])[c:11]1[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]1.[N:1]#[C:2][c:3]1[cH:4][cH:5][cH:6][c:7]([NH2:8])[c:18]1[F:19]>>[N:1]#[C:2][c:3]1[cH:4][cH:5][cH:6][c:7]([NH:8][C:9](=[O:10])[c:11]2[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]2)[c:18]1[F:19] 0.9855040989023992 [O;D1;H0:2]=[C;H0;D3;+0:1](-[c:3])-[NH;D2;+0:4]-[c:5]>>Cl-[C;H0;D3;+0:1](=[O;D1;H0:2])-[c:3].[NH2;D1;+0:4]-[c:5] dadfa1075f086a1e76ed69f4d1c5cc44999708352c462aac5817785131169c41 4 | NC(=O)c1ccc(F)cc1.Fc1c(Cl)cccc1C#N>>N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F Cl[c:7]1[cH:6][cH:5][cH:4][c:3]([C:2]#[N:1])[c:18]1[F:19].[NH2:8][C:9](=[O:10])[c:11]1[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]1>>[N:1]#[C:2][c:3]1[cH:4][cH:5][cH:6][c:7]([NH:8][C:9](=[O:10])[c:11]2[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]2)[c:18]1[F:19] 0.9707971018098916 [O;D1;H0:6]=[C:5](-[NH;D2;+0:4]-[c;H0;D3;+0:1](:[c:2]):[c:3])-[c:7]1:[c:8]:[c:9]:[c:10]:[c:11]:[c:12]:1>>Cl-[c;H0;D3;+0:1](:[c:2]):[c:3].[NH2;D1;+0:4]-[C:5](=[O;D1;H0:6])-[c:7]1:[c:8]:[c:9]:[c:10]:[c:11]:[c:12]:1 01ba54afaa6c16613833b643d8f2503d9222eec7a3834f1cdd002faeb50ef239 5 | -------------------------------------------------------------------------------- /tests/context/test_collection.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aizynthfinder.context.collection import ContextCollection 4 | 5 | 6 | class StringCollection(ContextCollection): 7 | def load(self, key, value): 8 | self._items[key] = value 9 | 10 | def load_from_config(self, config): 11 | for key, value in config.items(): 12 | self.load(key, value) 13 | 14 | 15 | class SingleStringCollection(StringCollection): 16 | _single_selection = True 17 | 18 | 19 | def test_empty_collection(): 20 | collection = StringCollection() 21 | 22 | assert len(collection) == 0 23 | assert collection.selection == [] 24 | 25 | 26 | def test_add_single_item(): 27 | collection = StringCollection() 28 | 29 | collection.load("key1", "value1") 30 | 31 | assert len(collection) == 1 32 | assert collection.items == ["key1"] 33 | assert collection.selection == [] 34 | 35 | 36 | def test_get_item(): 37 | collection = StringCollection() 38 | collection.load("key1", "value1") 39 | 40 | assert collection["key1"] == "value1" 41 | 42 | with pytest.raises(KeyError): 43 | collection["key2"] 44 | 45 | 46 | def test_del_item(): 47 | collection = StringCollection() 48 | collection.load("key1", "value1") 49 | 50 | del collection["key1"] 51 | 52 | assert len(collection) == 0 53 | 54 | with pytest.raises(KeyError): 55 | del collection["key2"] 56 | 57 | 58 | def test_select_single_item(): 59 | collection = StringCollection() 60 | collection.load("key1", "value1") 61 | 62 | collection.selection = "key1" 63 | 64 | assert collection.selection == ["key1"] 65 | 66 | with pytest.raises(KeyError): 67 | collection.selection = "key2" 68 | 69 | collection.load("key2", "value2") 70 | 71 | collection.selection = "key2" 72 | 73 | assert collection.selection == ["key2"] 74 | 75 | 76 | def test_select_append(): 77 | collection = StringCollection() 78 | collection.load("key1", "value1") 79 | collection.load("key2", "value2") 80 | 81 | collection.selection = "key1" 82 | 83 | assert collection.selection == ["key1"] 84 | 85 | collection.select("key2", append=True) 86 | 87 | assert collection.selection == ["key1", "key2"] 88 | 89 | 90 | def test_select_all(): 91 | collection = StringCollection() 92 | 93 | collection.select_all() 94 | 95 | assert collection.selection == [] 96 | 97 | collection.load("key1", "value1") 98 | collection.load("key2", "value2") 99 | 100 | collection.select_all() 101 | 102 | assert collection.selection == ["key1", "key2"] 103 | 104 | 105 | def test_select_first(): 106 | collection = StringCollection() 107 | 108 | collection.select_first() 109 | 110 | assert collection.selection == [] 111 | 112 | collection.load("key1", "value1") 113 | collection.load("key2", "value2") 114 | 115 | collection.select_first() 116 | 117 | assert collection.selection == ["key1"] 118 | 119 | 120 | def test_select_overwrite(): 121 | collection = StringCollection() 122 | collection.load("key1", "value1") 123 | collection.load("key2", "value2") 124 | 125 | collection.selection = "key1" 126 | 127 | assert collection.selection == ["key1"] 128 | 129 | collection.selection = ["key2"] 130 | 131 | assert collection.selection == ["key2"] 132 | 133 | 134 | def test_deselect_all(): 135 | collection = StringCollection() 136 | collection.load("key1", "value1") 137 | collection.load("key2", "value2") 138 | collection.selection = ("key1", "key2") 139 | 140 | assert len(collection) == 2 141 | assert collection.selection == ["key1", "key2"] 142 | 143 | collection.deselect() 144 | 145 | assert len(collection.selection) == 0 146 | 147 | 148 | def test_deselect_one(): 149 | collection = StringCollection() 150 | collection.load("key1", "value1") 151 | collection.load("key2", "value2") 152 | collection.selection = ("key1", "key2") 153 | 154 | collection.deselect("key1") 155 | 156 | assert collection.selection == ["key2"] 157 | 158 | with pytest.raises(KeyError): 159 | collection.deselect("key1") 160 | 161 | 162 | def test_empty_single_collection(): 163 | collection = SingleStringCollection() 164 | 165 | assert len(collection) == 0 166 | assert collection.selection is None 167 | 168 | 169 | def test_select_single_collection(): 170 | collection = SingleStringCollection() 171 | collection.load("key1", "value1") 172 | collection.load("key2", "value2") 173 | 174 | collection.selection = "key1" 175 | 176 | assert collection.selection == "key1" 177 | 178 | collection.selection = "key2" 179 | 180 | assert collection.selection == "key2" 181 | 182 | 183 | def test_select_multiple_single_collection(): 184 | collection = SingleStringCollection() 185 | collection.load("key1", "value1") 186 | collection.load("key2", "value2") 187 | 188 | collection.selection = ["key1"] 189 | 190 | assert collection.selection == "key1" 191 | 192 | with pytest.raises(ValueError): 193 | collection.selection = ["key1", "key2"] 194 | -------------------------------------------------------------------------------- /tests/data/branched_route.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "mol", 3 | "hide": false, 4 | "smiles": "OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1", 5 | "is_chemical": true, 6 | "in_stock": false, 7 | "children": [ 8 | { 9 | "type": "reaction", 10 | "hide": false, 11 | "smiles": "OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1>>OOc1ccc(-c2ccccc2)cc1.NC1CCCC(C2C=CC=C2)C1", 12 | "is_reaction": true, 13 | "metadata": {}, 14 | "children": [ 15 | { 16 | "type": "mol", 17 | "hide": false, 18 | "smiles": "OOc1ccc(-c2ccccc2)cc1", 19 | "is_chemical": true, 20 | "in_stock": false, 21 | "children": [ 22 | { 23 | "type": "reaction", 24 | "hide": false, 25 | "smiles": "OOc1ccc(-c2ccccc2)cc1>>c1ccccc1.OOc1ccccc1", 26 | "is_reaction": true, 27 | "metadata": {}, 28 | "children": [ 29 | { 30 | "type": "mol", 31 | "hide": false, 32 | "smiles": "c1ccccc1", 33 | "is_chemical": true, 34 | "in_stock": true 35 | }, 36 | { 37 | "type": "mol", 38 | "hide": false, 39 | "smiles": "OOc1ccccc1", 40 | "is_chemical": true, 41 | "in_stock": false, 42 | "children": [ 43 | { 44 | "type": "reaction", 45 | "hide": false, 46 | "smiles": "OOc1ccccc1>>O.Oc1ccccc1", 47 | "is_reaction": true, 48 | "metadata": {}, 49 | "children": [ 50 | { 51 | "type": "mol", 52 | "hide": false, 53 | "smiles": "O", 54 | "is_chemical": true, 55 | "in_stock": false 56 | }, 57 | { 58 | "type": "mol", 59 | "hide": false, 60 | "smiles": "Oc1ccccc1", 61 | "is_chemical": true, 62 | "in_stock": true 63 | } 64 | ] 65 | } 66 | ] 67 | } 68 | ] 69 | } 70 | ] 71 | }, 72 | { 73 | "type": "mol", 74 | "hide": false, 75 | "smiles": "NC1CCCC(C2C=CC=C2)C1", 76 | "is_chemical": true, 77 | "in_stock": false, 78 | "children": [ 79 | { 80 | "type": "reaction", 81 | "hide": false, 82 | "smiles": "NC1CCCC(C2C=CC=C2)C1>>NC1CCCCC1.C1=CCC=C1", 83 | "is_reaction": true, 84 | "metadata": {}, 85 | "children": [ 86 | { 87 | "type": "mol", 88 | "hide": false, 89 | "smiles": "NC1CCCCC1", 90 | "is_chemical": true, 91 | "in_stock": true 92 | }, 93 | { 94 | "type": "mol", 95 | "hide": false, 96 | "smiles": "C1=CCC=C1", 97 | "is_chemical": true, 98 | "in_stock": true 99 | } 100 | ] 101 | } 102 | ] 103 | } 104 | ] 105 | } 106 | ] 107 | } -------------------------------------------------------------------------------- /tests/data/combined_example_tree.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "mol", 3 | "hide": false, 4 | "smiles": "Cc1ccc2nc3ccccc3c(Nc3ccc(NC(=S)Nc4ccccc4)cc3)c2c1", 5 | "is_chemical": true, 6 | "in_stock": false, 7 | "children": [ 8 | { 9 | "type": "reaction", 10 | "hide": false, 11 | "smiles": "[c:1]1([N:7][cH3:8])[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1>>Cl[c:1]1[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1.[N:7][cH3:8]", 12 | "is_reaction": true, 13 | "metadata": {}, 14 | "children": [ 15 | { 16 | "type": "mol", 17 | "hide": false, 18 | "smiles": "Cc1ccc2nc3ccccc3c(Cl)c2c1", 19 | "is_chemical": true, 20 | "in_stock": true 21 | }, 22 | { 23 | "type": "mol", 24 | "hide": false, 25 | "smiles": "Nc1ccc(NC(=S)Nc2ccccc2)cc1", 26 | "is_chemical": true, 27 | "in_stock": true 28 | } 29 | ] 30 | }, 31 | { 32 | "type": "reaction", 33 | "hide": false, 34 | "smiles": "[N:1]([cH3:2])[C:4](=[S:3])[N:5][cH3:6]>>[N:1][cH3:2].[S:3]=[C:4]=[N:5][cH3:6]", 35 | "is_reaction": true, 36 | "metadata": {}, 37 | "children": [ 38 | { 39 | "type": "mol", 40 | "hide": false, 41 | "smiles": "S=C=Nc1ccccc1", 42 | "is_chemical": true, 43 | "in_stock": true 44 | }, 45 | { 46 | "type": "mol", 47 | "hide": false, 48 | "smiles": "Cc1ccc2nc3ccccc3c(Nc3ccc(N)cc3)c2c1", 49 | "is_chemical": true, 50 | "in_stock": false, 51 | "children": [ 52 | { 53 | "type": "reaction", 54 | "hide": false, 55 | "smiles": "[c:1]1([N:7][cH3:8])[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1>>Cl[c:1]1[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1.[N:7][cH3:8]", 56 | "is_reaction": true, 57 | "metadata": {}, 58 | "children": [ 59 | { 60 | "type": "mol", 61 | "hide": false, 62 | "smiles": "Nc1ccc(N)cc1", 63 | "is_chemical": true, 64 | "in_stock": true 65 | }, 66 | { 67 | "type": "mol", 68 | "hide": false, 69 | "smiles": "Cc1ccc2nc3ccccc3c(Cl)c2c1", 70 | "is_chemical": true, 71 | "in_stock": true 72 | } 73 | ] 74 | }, 75 | { 76 | "type": "reaction", 77 | "hide": false, 78 | "smiles": "[c:1]([cH2:2])([cH2:3])[N:4][cH3:5]>>Br[c:1]([cH2:2])[cH2:3].[N:4][cH3:5]", 79 | "is_reaction": true, 80 | "metadata": {}, 81 | "children": [ 82 | { 83 | "type": "mol", 84 | "hide": false, 85 | "smiles": "Nc1ccc(Br)cc1", 86 | "is_chemical": true, 87 | "in_stock": true 88 | }, 89 | { 90 | "type": "mol", 91 | "hide": false, 92 | "smiles": "Cc1ccc2nc3ccccc3c(N)c2c1", 93 | "is_chemical": true, 94 | "in_stock": true 95 | } 96 | ] 97 | } 98 | ] 99 | } 100 | ] 101 | } 102 | ] 103 | } -------------------------------------------------------------------------------- /tests/data/combined_example_tree2.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "mol", 3 | "hide": false, 4 | "smiles": "Cc1ccc2nc3ccccc3c(Nc3ccc(NC(=S)Nc4ccccc4)cc3)c2c1", 5 | "is_chemical": true, 6 | "in_stock": false, 7 | "children": [ 8 | { 9 | "type": "reaction", 10 | "hide": false, 11 | "smiles": "[c:1]1([N:7][cH3:8])[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1>>Cl[c:1]1[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1.[N:7][cH3:8]", 12 | "is_reaction": true, 13 | "metadata": {}, 14 | "children": [ 15 | { 16 | "type": "mol", 17 | "hide": false, 18 | "smiles": "Cc1ccc2nc3ccccc3c(Cl)c2c1", 19 | "is_chemical": true, 20 | "in_stock": true 21 | }, 22 | { 23 | "type": "mol", 24 | "hide": false, 25 | "smiles": "Nc1ccc(NC(=S)Nc2ccccc2)cc1", 26 | "is_chemical": true, 27 | "in_stock": true 28 | } 29 | ] 30 | }, 31 | { 32 | "type": "reaction", 33 | "hide": false, 34 | "smiles": "[N:1]([cH3:2])[C:4](=[S:3])[N:5][cH3:6]>>[N:1][cH3:2].[S:3]=[C:4]=[N:5][cH3:6]", 35 | "is_reaction": true, 36 | "metadata": {}, 37 | "children": [ 38 | { 39 | "type": "mol", 40 | "hide": false, 41 | "smiles": "Nc1ccccc1", 42 | "is_chemical": true, 43 | "in_stock": false 44 | }, 45 | { 46 | "type": "mol", 47 | "hide": false, 48 | "smiles": "Cc1ccc2nc3ccccc3c(Nc3ccc(N=C=S)cc3)c2c1", 49 | "is_chemical": true, 50 | "in_stock": false, 51 | "children": [ 52 | { 53 | "type": "reaction", 54 | "hide": false, 55 | "smiles": "[c:1]1([N:7][cH3:8])[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1>>Cl[c:1]1[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1.[N:7][cH3:8]", 56 | "is_reaction": true, 57 | "metadata": {}, 58 | "children": [ 59 | { 60 | "type": "mol", 61 | "hide": false, 62 | "smiles": "Nc1ccc(N=C=S)cc1", 63 | "is_chemical": true, 64 | "in_stock": false 65 | }, 66 | { 67 | "type": "mol", 68 | "hide": false, 69 | "smiles": "Cc1ccc2nc3ccccc3c(Cl)c2c1", 70 | "is_chemical": true, 71 | "in_stock": true 72 | } 73 | ] 74 | }, 75 | { 76 | "type": "reaction", 77 | "hide": false, 78 | "smiles": "[c:1]([cH2:2])([cH2:3])[N:4][cH3:5]>>Br[c:1]([cH2:2])[cH2:3].[N:4][cH3:5]", 79 | "is_reaction": true, 80 | "metadata": {}, 81 | "children": [ 82 | { 83 | "type": "mol", 84 | "hide": false, 85 | "smiles": "Cc1ccc2nc3ccccc3c(N)c2c1", 86 | "is_chemical": true, 87 | "in_stock": true 88 | }, 89 | { 90 | "type": "mol", 91 | "hide": false, 92 | "smiles": "S=C=Nc1ccc(Br)cc1", 93 | "is_chemical": true, 94 | "in_stock": false 95 | } 96 | ] 97 | } 98 | ] 99 | } 100 | ] 101 | } 102 | ] 103 | } -------------------------------------------------------------------------------- /tests/data/dummy2_raw_template_library.csv: -------------------------------------------------------------------------------- 1 | ID,PseudoHash,RSMI,classification,retro_template,template_hash 2 | 0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX 3 | 0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX 4 | 0,BAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX 5 | 0,ABA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX 6 | 0,AAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX 7 | 0,BBA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX 8 | 0,ABB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY 9 | 0,BAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY 10 | 0,CAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY 11 | 0,ACA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY 12 | 0,AAC,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY 13 | 0,CCA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),DXX -------------------------------------------------------------------------------- /tests/data/dummy_noclass_raw_template_library.csv: -------------------------------------------------------------------------------- 1 | 0,0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX 2 | 0,0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX 3 | 0,0,BAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX 4 | 0,0,ABA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX 5 | 0,0,AAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX 6 | 0,0,BBA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX 7 | 0,0,ABB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY 8 | 0,0,BAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY 9 | 0,0,CAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY 10 | 0,0,ACA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY 11 | 0,0,AAC,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY 12 | 0,0,CCA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),DXX -------------------------------------------------------------------------------- /tests/data/dummy_raw_template_library.csv: -------------------------------------------------------------------------------- 1 | 0,0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2 2 | 0,0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2 3 | 0,0,BAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2 4 | 0,0,ABA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2 5 | 0,0,AAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2 6 | 0,0,BBA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2 7 | 0,0,ABB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2 8 | 0,0,BAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2 9 | 0,0,CAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2 10 | 0,0,ACA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2 11 | 0,0,AAC,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2 12 | 0,0,CCA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),DXX,0,2 -------------------------------------------------------------------------------- /tests/data/dummy_sani_raw_template_library.csv: -------------------------------------------------------------------------------- 1 | 0,0,AAA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2 2 | 0,0,AAA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2 3 | 0,0,BAA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2 4 | 0,0,ABA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2 5 | 0,0,AAB,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2 6 | 0,0,BBA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2 7 | 0,0,ABB,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2 8 | 0,0,BAB,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2 9 | 0,0,CAA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2 10 | 0,0,ACA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2 11 | 0,0,AAC,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2 12 | 0,0,CCA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),DXX,0,2 13 | 0,0,SSA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,c1ccccc1(C)(C),unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),SXX,0,2 14 | 0,0,SSB,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,c1ccccc1(C)(C),unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),SXX,0,2 15 | 0,0,SSC,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,c1ccccc1(C)(C),unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),SXX,0,2 16 | 0,0,SSB,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,c1ccccc1(C)(C),unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),SXX,0,2 -------------------------------------------------------------------------------- /tests/data/full_search_tree.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/data/full_search_tree.json.gz -------------------------------------------------------------------------------- /tests/data/input_checkpoint.json.gz: -------------------------------------------------------------------------------- 1 | {"processed_smiles": "c1ccccc1", "results": {"a": 1, "b": 2, "is_solved": true, "stock_info": 1, "trees": 3}} 2 | {"processed_smiles": "Cc1ccccc1", "results": {"a": 1, "b": 2, "is_solved": true, "stock_info": 1, "trees": 3}} 3 | -------------------------------------------------------------------------------- /tests/data/linear_route.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "mol", 3 | "route_metadata": { 4 | "created_at_iteration": 1, 5 | "is_solved": true 6 | }, 7 | "hide": false, 8 | "smiles": "OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1", 9 | "is_chemical": true, 10 | "in_stock": false, 11 | "children": [ 12 | { 13 | "type": "reaction", 14 | "hide": false, 15 | "smiles": "OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1>>OOc1ccc(-c2ccccc2)cc1.NC1CCCC(C2C=CC=C2)C1", 16 | "is_reaction": true, 17 | "metadata": {}, 18 | "children": [ 19 | { 20 | "type": "mol", 21 | "hide": false, 22 | "smiles": "NC1CCCC(C2C=CC=C2)C1", 23 | "is_chemical": true, 24 | "in_stock": true 25 | }, 26 | { 27 | "type": "mol", 28 | "hide": false, 29 | "smiles": "OOc1ccc(-c2ccccc2)cc1", 30 | "is_chemical": true, 31 | "in_stock": false, 32 | "children": [ 33 | { 34 | "type": "reaction", 35 | "hide": false, 36 | "smiles": "OOc1ccc(-c2ccccc2)cc1>>c1ccccc1.OOc1ccccc1", 37 | "is_reaction": true, 38 | "metadata": {}, 39 | "children": [ 40 | { 41 | "type": "mol", 42 | "hide": false, 43 | "smiles": "c1ccccc1", 44 | "is_chemical": true, 45 | "in_stock": true 46 | }, 47 | { 48 | "type": "mol", 49 | "hide": false, 50 | "smiles": "OOc1ccccc1", 51 | "is_chemical": true, 52 | "in_stock": true 53 | } 54 | ] 55 | } 56 | ] 57 | } 58 | ] 59 | } 60 | ] 61 | } -------------------------------------------------------------------------------- /tests/data/post_processing_test.py: -------------------------------------------------------------------------------- 1 | def post_processing(finder): 2 | return {"quantity": 5, "another_quantity": 10} 3 | -------------------------------------------------------------------------------- /tests/data/pre_processing_test.py: -------------------------------------------------------------------------------- 1 | def pre_processing(finder, target_idx): 2 | raise ValueError(target_idx) 3 | -------------------------------------------------------------------------------- /tests/data/test_reactions_template.csv: -------------------------------------------------------------------------------- 1 | ReactionSmilesClean mapped_rxn confidence RetroTemplate TemplateHash TemplateError 2 | CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O>>CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1 Cl[CH:5]1[CH2:4][CH2:3][N:2]([CH3:1])[CH2:26][CH2:25]1.N#[C:6][c:8]1[cH:9][cH:10][cH:11][c:12]([NH:13][C:14](=[O:15])[c:16]2[cH:17][cH:18][c:19]([F:20])[cH:21][cH:22]2)[c:23]1[F:24].[OH2:7]>>[CH3:1][N:2]1[CH2:3][CH2:4][CH:5]([C:6](=[O:7])[c:8]2[cH:9][cH:10][cH:11][c:12]([NH:13][C:14](=[O:15])[c:16]3[cH:17][cH:18][c:19]([F:20])[cH:21][cH:22]3)[c:23]2[F:24])[CH2:25][CH2:26]1 0.9424734380772164 [C:2]-[CH;D3;+0:1](-[C:3])-[C;H0;D3;+0:4](=[O;H0;D1;+0:6])-[c:5]>>Cl-[CH;D3;+0:1](-[C:2])-[C:3].N#[C;H0;D2;+0:4]-[c:5].[OH2;D0;+0:6] 23f00a3c507eef75b22252e8eea99b2ce8eab9a573ff2fbd8227d93f49960a27 3 | N#Cc1cccc(N)c1F.O=C(Cl)c1ccc(F)cc1>>N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F Cl[C:9](=[O:10])[c:11]1[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]1.[N:1]#[C:2][c:3]1[cH:4][cH:5][cH:6][c:7]([NH2:8])[c:18]1[F:19]>>[N:1]#[C:2][c:3]1[cH:4][cH:5][cH:6][c:7]([NH:8][C:9](=[O:10])[c:11]2[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]2)[c:18]1[F:19] 0.9855040989023992 [O;D1;H0:2]=[C;H0;D3;+0:1](-[c:3])-[NH;D2;+0:4]-[c:5]>>Cl-[C;H0;D3;+0:1](=[O;D1;H0:2])-[c:3].[NH2;D1;+0:4]-[c:5] dadfa1075f086a1e76ed69f4d1c5cc44999708352c462aac5817785131169c41 4 | NC(=O)c1ccc(F)cc1.Fc1c(Cl)cccc1C#N>>N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F Cl[c:7]1[cH:6][cH:5][cH:4][c:3]([C:2]#[N:1])[c:18]1[F:19].[NH2:8][C:9](=[O:10])[c:11]1[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]1>>[N:1]#[C:2][c:3]1[cH:4][cH:5][cH:6][c:7]([NH:8][C:9](=[O:10])[c:11]2[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]2)[c:18]1[F:19] 0.9707971018098916 [O;D1;H0:6]=[C:5](-[NH;D2;+0:4]-[c;H0;D3;+0:1](:[c:2]):[c:3])-[c:7]1:[c:8]:[c:9]:[c:10]:[c:11]:[c:12]:1>>Cl-[c;H0;D3;+0:1](:[c:2]):[c:3].[NH2;D1;+0:4]-[C:5](=[O;D1;H0:6])-[c:7]1:[c:8]:[c:9]:[c:10]:[c:11]:[c:12]:1 01ba54afaa6c16613833b643d8f2503d9222eec7a3834f1cdd002faeb50ef239 -------------------------------------------------------------------------------- /tests/dfpn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/dfpn/__init__.py -------------------------------------------------------------------------------- /tests/dfpn/test_nodes.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aizynthfinder.search.dfpn.nodes import MoleculeNode, BIG_INT 4 | from aizynthfinder.search.dfpn import SearchTree 5 | 6 | 7 | @pytest.fixture 8 | def setup_root(default_config): 9 | def wrapper(smiles): 10 | owner = SearchTree(default_config) 11 | return MoleculeNode.create_root(smiles, config=default_config, owner=owner) 12 | 13 | return wrapper 14 | 15 | 16 | def test_create_root_node(setup_root): 17 | node = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1") 18 | 19 | assert node.expandable 20 | assert not node.children 21 | assert node.dn == 1 22 | assert node.pn == 1 23 | 24 | 25 | def test_expand_mol_node( 26 | default_config, setup_root, setup_policies, get_linear_expansion 27 | ): 28 | node = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1") 29 | setup_policies(get_linear_expansion) 30 | 31 | node.expand() 32 | 33 | assert not node.expandable 34 | assert len(node.children) == 1 35 | 36 | 37 | def test_promising_child( 38 | default_config, setup_root, setup_policies, get_linear_expansion 39 | ): 40 | node = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1") 41 | setup_policies(get_linear_expansion) 42 | node.expand() 43 | 44 | child = node.promising_child() 45 | 46 | assert child is node.children[0] 47 | assert child.pn_threshold == BIG_INT - 1 48 | assert child.dn_threshold == BIG_INT 49 | 50 | 51 | def test_expand_reaction_node( 52 | default_config, setup_root, setup_policies, get_linear_expansion 53 | ): 54 | node = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1") 55 | setup_policies(get_linear_expansion) 56 | node.expand() 57 | child = node.promising_child() 58 | 59 | child.expand() 60 | 61 | assert len(child.children) == 2 62 | 63 | 64 | def test_promising_child_reaction_node( 65 | default_config, 66 | setup_root, 67 | setup_policies, 68 | get_linear_expansion, 69 | ): 70 | node = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1") 71 | setup_policies(get_linear_expansion) 72 | node.expand() 73 | 74 | child = node.promising_child() 75 | child.expand() 76 | 77 | grandchild = child.promising_child() 78 | 79 | assert grandchild.mol.smiles == "OOc1ccc(-c2ccccc2)cc1" 80 | assert grandchild.pn_threshold == BIG_INT - 1 81 | assert grandchild.dn_threshold == 2 82 | 83 | 84 | def test_update( 85 | default_config, 86 | setup_root, 87 | setup_policies, 88 | get_linear_expansion, 89 | setup_stock, 90 | ): 91 | node = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1") 92 | setup_stock(default_config, "OOc1ccc(-c2ccccc2)cc1", "NC1CCCC(C2C=CC=C2)C1") 93 | setup_policies(get_linear_expansion) 94 | node.expand() 95 | 96 | child = node.promising_child() 97 | child.expand() 98 | child.update() 99 | 100 | node.update() 101 | 102 | assert node.proven 103 | assert not node.disproven 104 | -------------------------------------------------------------------------------- /tests/dfpn/test_search.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import pytest 4 | 5 | from aizynthfinder.search.dfpn.search_tree import SearchTree 6 | 7 | 8 | def test_search(default_config, setup_policies, setup_stock): 9 | root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" 10 | child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"] 11 | child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"] 12 | grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"] 13 | lookup = { 14 | root_smi: [ 15 | {"smiles": ".".join(child1_smi), "prior": 0.7}, 16 | {"smiles": ".".join(child2_smi), "prior": 0.3}, 17 | ], 18 | child1_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, 19 | child2_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, 20 | } 21 | stock = [child1_smi[0], child1_smi[2]] + grandchild_smi 22 | setup_policies(lookup, config=default_config) 23 | setup_stock(default_config, *stock) 24 | tree = SearchTree(default_config, root_smi) 25 | 26 | assert not tree.one_iteration() 27 | 28 | routes = tree.routes() 29 | assert all([not route.is_solved for route in routes]) 30 | assert len(tree.mol_nodes) == 4 31 | 32 | assert not tree.one_iteration() 33 | assert tree.one_iteration() 34 | assert tree.one_iteration() 35 | assert tree.one_iteration() 36 | assert tree.one_iteration() 37 | 38 | routes = tree.routes() 39 | assert len(routes) == 2 40 | assert all(route.is_solved for route in routes) 41 | 42 | with pytest.raises(StopIteration): 43 | tree.one_iteration() 44 | 45 | 46 | def test_search_incomplete(default_config, setup_policies, setup_stock): 47 | root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" 48 | child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"] 49 | child2_smi = ["ClC(=O)c1ccc(F)cc1", "CN1CCC(CC1)C(=O)c1cccc(N)c1F"] 50 | grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"] 51 | lookup = { 52 | root_smi: [ 53 | {"smiles": ".".join(child1_smi), "prior": 0.7}, 54 | {"smiles": ".".join(child2_smi), "prior": 0.3}, 55 | ], 56 | child1_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, 57 | } 58 | stock = [child1_smi[0], child1_smi[2]] + [grandchild_smi[0]] 59 | setup_policies(lookup, config=default_config) 60 | setup_stock(default_config, *stock) 61 | tree = SearchTree(default_config, root_smi) 62 | 63 | while True: 64 | try: 65 | tree.one_iteration() 66 | except StopIteration: 67 | break 68 | 69 | routes = tree.routes() 70 | assert len(routes) == 2 71 | assert all(not route.is_solved for route in routes) 72 | -------------------------------------------------------------------------------- /tests/mcts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/mcts/__init__.py -------------------------------------------------------------------------------- /tests/mcts/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aizynthfinder.search.mcts import MctsNode, MctsSearchTree 4 | 5 | 6 | @pytest.fixture 7 | def generate_root(default_config): 8 | def wrapper(smiles, config=None): 9 | return MctsNode.create_root(smiles, tree=None, config=config or default_config) 10 | 11 | return wrapper 12 | 13 | 14 | @pytest.fixture 15 | def fresh_tree(default_config): 16 | return MctsSearchTree(config=default_config, root_smiles=None) 17 | 18 | 19 | @pytest.fixture 20 | def set_default_prior(default_config): 21 | default_config.search.algorithm_config["use_prior"] = False 22 | 23 | def wrapper(prior): 24 | default_config.search.algorithm_config["default_prior"] = prior 25 | 26 | yield wrapper 27 | default_config.search.algorithm_config["use_prior"] = True 28 | 29 | 30 | @pytest.fixture 31 | def setup_mcts_search(get_one_step_expansion, setup_policies, generate_root): 32 | expansion_strategy, filter_strategy = setup_policies(get_one_step_expansion) 33 | root_smiles = list(expansion_strategy.lookup.keys())[0] 34 | return ( 35 | generate_root(root_smiles), 36 | expansion_strategy, 37 | filter_strategy, 38 | ) 39 | 40 | 41 | @pytest.fixture 42 | def setup_complete_mcts_tree(default_config, setup_policies, setup_stock): 43 | root_smiles = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" 44 | tree = MctsSearchTree(config=default_config, root_smiles=root_smiles) 45 | lookup = { 46 | root_smiles: { 47 | "smiles": "CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O", 48 | "prior": 1.0, 49 | }, 50 | "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F": { 51 | "smiles": "N#Cc1cccc(N)c1F.O=C(Cl)c1ccc(F)cc1", 52 | "prior": 1.0, 53 | }, 54 | } 55 | setup_policies(lookup) 56 | 57 | setup_stock( 58 | default_config, "CN1CCC(Cl)CC1", "O", "N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1" 59 | ) 60 | 61 | node1 = tree.root 62 | node1.expand() 63 | 64 | node2 = node1.promising_child() 65 | node2.expand() 66 | 67 | node3 = node2.promising_child() 68 | node3.expand() 69 | 70 | return tree, [node1, node2, node3] 71 | -------------------------------------------------------------------------------- /tests/mcts/test_multiobjective.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aizynthfinder.search.mcts.node import ParetoMctsNode 4 | from aizynthfinder.search.mcts import MctsSearchTree 5 | 6 | 7 | @pytest.fixture 8 | def generate_root(default_config): 9 | def wrapper(smiles, config=None): 10 | return ParetoMctsNode.create_root( 11 | smiles, tree=None, config=config or default_config 12 | ) 13 | 14 | return wrapper 15 | 16 | 17 | @pytest.fixture 18 | def setup_mcts_search( 19 | default_config, get_one_step_expansion, setup_policies, generate_root 20 | ): 21 | default_config.search.algorithm_config["search_rewards"] = [ 22 | "number of reactions", 23 | "number of pre-cursors in stock", 24 | ] 25 | expansion_strategy, filter_strategy = setup_policies(get_one_step_expansion) 26 | root_smiles = list(expansion_strategy.lookup.keys())[0] 27 | return ( 28 | generate_root(root_smiles), 29 | expansion_strategy, 30 | filter_strategy, 31 | ) 32 | 33 | 34 | def test_expand_root_node(setup_mcts_search): 35 | root, _, _ = setup_mcts_search 36 | 37 | root.expand() 38 | 39 | view = root.children_view() 40 | assert len(view["actions"]) == 3 41 | assert view["priors"] == [[0.7, 0.7], [0.5, 0.5], [0.3, 0.3]] 42 | assert view["values"] == [[0.7, 0.7], [0.5, 0.5], [0.3, 0.3]] 43 | assert view["visitations"] == [1, 1, 1] 44 | assert view["objects"] == [None, None, None] 45 | 46 | 47 | def test_expand_root_with_default_priors(setup_mcts_search, set_default_prior): 48 | root, _, _ = setup_mcts_search 49 | set_default_prior(0.01) 50 | 51 | root.expand() 52 | 53 | view = root.children_view() 54 | assert len(view["actions"]) == 3 55 | assert view["priors"] == [[0.01, 0.01], [0.01, 0.01], [0.01, 0.01]] 56 | assert view["values"] == [[0.01, 0.01], [0.01, 0.01], [0.01, 0.01]] 57 | assert view["visitations"] == [1, 1, 1] 58 | assert view["objects"] == [None, None, None] 59 | 60 | 61 | def test_backpropagate(setup_mcts_search): 62 | root, _, _ = setup_mcts_search 63 | root.expand() 64 | child = root.promising_child() 65 | view_prior = root.children_view() 66 | 67 | root.backpropagate(child, [1.5, 2.0]) 68 | 69 | view_post = root.children_view() 70 | assert view_post["visitations"][0] == view_prior["visitations"][0] + 1 71 | assert view_prior["visitations"][1:] == view_post["visitations"][1:] 72 | assert view_prior["values"] == view_post["values"] 73 | assert view_post["rewards_cum"][0][0] == view_prior["rewards_cum"][0][0] + 1.5 74 | assert view_post["rewards_cum"][0][1] == view_prior["rewards_cum"][0][1] + 2.0 75 | assert view_prior["rewards_cum"][1:] == view_post["rewards_cum"][1:] 76 | 77 | 78 | def test_setup_weighted_sum_tree(default_config): 79 | default_config.search.algorithm_config["search_rewards"] = [ 80 | "number of reactions", 81 | "number of pre-cursors in stock", 82 | ] 83 | default_config.search.algorithm_config["search_rewards_weights"] = [0.5, 0.1] 84 | root_smiles = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" 85 | tree = MctsSearchTree(config=default_config, root_smiles=root_smiles) 86 | 87 | assert tree.mode == "weighted-sum" 88 | assert len(tree.reward_scorer.selection) == 2 89 | assert tree.compute_reward(tree.root) == 0.0 90 | 91 | 92 | def test_setup_mo_tree(default_config): 93 | default_config.search.algorithm_config["search_rewards"] = [ 94 | "number of reactions", 95 | "number of pre-cursors in stock", 96 | ] 97 | root_smiles = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" 98 | tree = MctsSearchTree(config=default_config, root_smiles=root_smiles) 99 | 100 | assert tree.mode == "multi-objective" 101 | assert len(tree.reward_scorer.selection) == 2 102 | assert tree.compute_reward(tree.root) == [0.0, 0.0] 103 | -------------------------------------------------------------------------------- /tests/mcts/test_reward.py: -------------------------------------------------------------------------------- 1 | from aizynthfinder.analysis import TreeAnalysis 2 | from aizynthfinder.context.scoring import ( 3 | NumberOfReactionsScorer, 4 | StateScorer, 5 | ) 6 | from aizynthfinder.search.mcts import MctsSearchTree 7 | 8 | 9 | def test_reward_node(default_config, generate_root): 10 | config = default_config 11 | search_reward_scorer = repr(StateScorer(config)) 12 | post_process_reward_scorer = repr(NumberOfReactionsScorer()) 13 | 14 | config.search.algorithm_config["search_rewards"] = [search_reward_scorer] 15 | config.post_processing.route_scorers = [post_process_reward_scorer] 16 | 17 | node = generate_root("CCCCOc1ccc(CC(=O)N(C)O)cc1", config) 18 | 19 | search_scorer = config.scorers[config.search.algorithm_config["search_rewards"][0]] 20 | route_scorer = config.scorers[config.post_processing.route_scorers[0]] 21 | 22 | assert round(search_scorer(node), 4) == 0.0491 23 | assert route_scorer(node) == 0 24 | 25 | 26 | def test_default_postprocessing_reward(setup_aizynthfinder): 27 | """Test using default postprocessing.route_score""" 28 | root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" 29 | child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"] 30 | lookup = {root_smi: {"smiles": ".".join(child1_smi), "prior": 1.0}} 31 | finder = setup_aizynthfinder(lookup, child1_smi) 32 | 33 | config = finder.config 34 | config.search.return_first = True 35 | 36 | search_reward_scorer = repr(NumberOfReactionsScorer()) 37 | state_scorer = repr(StateScorer(config)) 38 | 39 | config.search.algorithm_config["search_rewards"] = [search_reward_scorer] 40 | finder.config = config 41 | 42 | assert len(finder.config.post_processing.route_scorers) == 0 43 | 44 | finder.tree_search() 45 | tree_analysis_search = TreeAnalysis( 46 | finder.tree, scorer=config.scorers[search_reward_scorer] 47 | ) 48 | tree_analysis_pp = TreeAnalysis(finder.tree, scorer=config.scorers[state_scorer]) 49 | 50 | finder.build_routes() 51 | assert finder.tree.reward_scorer_name == search_reward_scorer 52 | 53 | top_score_tree_analysis = tree_analysis_search.tree_statistics()["top_score"] 54 | top_score_finder = finder.tree.compute_reward(tree_analysis_search.best()) 55 | 56 | assert top_score_finder == top_score_tree_analysis 57 | 58 | top_score_tree_analysis = tree_analysis_pp.tree_statistics()["top_score"] 59 | top_score_finder = finder.analysis.tree_statistics()["top_score"] 60 | 61 | # Finder used the search_reward_scorer and not state_scorer 62 | assert top_score_finder != top_score_tree_analysis 63 | 64 | 65 | def test_custom_reward(setup_aizynthfinder): 66 | """Test using different custom reward functions for MCTS and route building.""" 67 | 68 | root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" 69 | child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"] 70 | lookup = {root_smi: {"smiles": ".".join(child1_smi), "prior": 1.0}} 71 | finder = setup_aizynthfinder(lookup, child1_smi) 72 | 73 | # Test first with return_first and multiple route scores 74 | config = finder.config 75 | config.search.return_first = True 76 | 77 | search_reward_scorer = repr(StateScorer(config)) 78 | post_process_reward_scorer = repr(NumberOfReactionsScorer()) 79 | 80 | config.search.algorithm_config["search_rewards"] = [search_reward_scorer] 81 | config.post_processing.route_scorers = [post_process_reward_scorer] 82 | finder.config = config 83 | 84 | assert finder.config.post_processing.route_scorers == [post_process_reward_scorer] 85 | 86 | finder.tree_search() 87 | tree_analysis_search = TreeAnalysis( 88 | finder.tree, scorer=config.scorers[search_reward_scorer] 89 | ) 90 | tree_analysis_pp = TreeAnalysis( 91 | finder.tree, scorer=config.scorers[post_process_reward_scorer] 92 | ) 93 | 94 | finder.build_routes() 95 | 96 | assert finder.config.post_processing.route_scorers == [post_process_reward_scorer] 97 | assert finder.tree.reward_scorer_name == search_reward_scorer 98 | 99 | top_score_tree_analysis = tree_analysis_search.tree_statistics()["top_score"] 100 | top_score_finder = finder.tree.compute_reward(tree_analysis_search.best()) 101 | 102 | assert top_score_finder == top_score_tree_analysis 103 | 104 | top_score_tree_analysis = tree_analysis_pp.tree_statistics()["top_score"] 105 | top_score_finder = finder.analysis.tree_statistics()["top_score"] 106 | 107 | assert top_score_finder == top_score_tree_analysis 108 | 109 | 110 | def test_reward_node_backward_compatibility(default_config): 111 | reward_scorer = repr(NumberOfReactionsScorer()) 112 | default_config.search.algorithm_config["search_reward"] = reward_scorer 113 | 114 | tree = MctsSearchTree(config=default_config, root_smiles=None) 115 | 116 | assert tree.reward_scorer_name == reward_scorer 117 | -------------------------------------------------------------------------------- /tests/mcts/test_tree.py: -------------------------------------------------------------------------------- 1 | def test_select_leaf_root(setup_complete_mcts_tree): 2 | tree, nodes = setup_complete_mcts_tree 3 | nodes[0].is_expanded = False 4 | 5 | leaf = tree.select_leaf() 6 | 7 | assert leaf is nodes[0] 8 | 9 | 10 | def test_select_leaf(setup_complete_mcts_tree): 11 | tree, nodes = setup_complete_mcts_tree 12 | 13 | leaf = tree.select_leaf() 14 | 15 | assert leaf is nodes[2] 16 | 17 | 18 | def test_backpropagation(setup_complete_mcts_tree, mocker): 19 | tree, nodes = setup_complete_mcts_tree 20 | for node in nodes: 21 | node.backpropagate = mocker.MagicMock() 22 | score = tree.reward_scorer[tree.reward_scorer_name](nodes[2]) 23 | 24 | tree.backpropagate(nodes[2]) 25 | 26 | nodes[0].backpropagate.assert_called_once_with(nodes[1], score) 27 | nodes[1].backpropagate.assert_called_once_with(nodes[2], score) 28 | nodes[2].backpropagate.assert_not_called() 29 | 30 | 31 | def test_route_to_node(setup_complete_mcts_tree): 32 | tree, nodes = setup_complete_mcts_tree 33 | 34 | actions, route_nodes = nodes[2].path_to() 35 | 36 | assert len(actions) == 2 37 | assert len(nodes) == 3 38 | assert nodes[0] == route_nodes[0] 39 | assert nodes[1] == route_nodes[1] 40 | assert nodes[2] == route_nodes[2] 41 | 42 | 43 | def test_create_graph(setup_complete_mcts_tree): 44 | tree, nodes = setup_complete_mcts_tree 45 | 46 | graph = tree.graph() 47 | 48 | assert len(graph) == 3 49 | assert list(graph.successors(nodes[0])) == [nodes[1]] 50 | assert list(graph.successors(nodes[1])) == [nodes[2]] 51 | -------------------------------------------------------------------------------- /tests/retrostar/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/retrostar/__init__.py -------------------------------------------------------------------------------- /tests/retrostar/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | 5 | from aizynthfinder.search.retrostar.search_tree import SearchTree 6 | from aizynthfinder.search.retrostar.cost import MoleculeCost 7 | from aizynthfinder.search.retrostar.nodes import MoleculeNode 8 | from aizynthfinder.aizynthfinder import AiZynthFinder 9 | 10 | 11 | @pytest.fixture 12 | def setup_aizynthfinder(setup_policies, setup_stock): 13 | def wrapper(expansions, stock): 14 | finder = AiZynthFinder() 15 | root_smi = list(expansions.keys())[0] 16 | setup_policies(expansions, config=finder.config) 17 | setup_stock(finder.config, *stock) 18 | finder.target_smiles = root_smi 19 | finder.config.search.algorithm = ( 20 | "aizynthfinder.search.retrostar.search_tree.SearchTree" 21 | ) 22 | return finder 23 | 24 | return wrapper 25 | 26 | 27 | @pytest.fixture 28 | def setup_search_tree(default_config, setup_policies, setup_stock): 29 | root_smiles = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" 30 | tree = SearchTree(config=default_config, root_smiles=root_smiles) 31 | lookup = { 32 | root_smiles: { 33 | "smiles": "CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O", 34 | "prior": 1.0, 35 | } 36 | } 37 | setup_policies(lookup) 38 | 39 | setup_stock( 40 | default_config, "CN1CCC(Cl)CC1", "O", "N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1" 41 | ) 42 | return tree 43 | 44 | 45 | @pytest.fixture 46 | def setup_star_root(default_config): 47 | def wrapper(smiles): 48 | return MoleculeNode.create_root( 49 | smiles, config=default_config, molecule_cost=MoleculeCost(default_config) 50 | ) 51 | 52 | return wrapper 53 | -------------------------------------------------------------------------------- /tests/retrostar/test_retrostar.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from aizynthfinder.search.retrostar.search_tree import SearchTree 4 | from aizynthfinder.chem.serialization import MoleculeSerializer 5 | 6 | 7 | def test_one_iteration(setup_search_tree): 8 | tree = setup_search_tree 9 | 10 | tree.one_iteration() 11 | 12 | assert len(tree.root.children) == 1 13 | assert len(tree.root.children[0].children) == 3 14 | 15 | 16 | def test_one_iteration_filter_unfeasible(setup_search_tree): 17 | tree = setup_search_tree 18 | smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1>>CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O" 19 | tree.config.filter_policy["dummy"].lookup[smi] = 0.0 20 | 21 | tree.one_iteration() 22 | assert len(tree.root.children) == 0 23 | 24 | 25 | def test_one_iteration_filter_feasible(setup_search_tree): 26 | tree = setup_search_tree 27 | smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1>>CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O" 28 | tree.config.filter_policy["dummy"].lookup[smi] = 0.5 29 | 30 | tree.one_iteration() 31 | assert len(tree.root.children) == 1 32 | 33 | 34 | def test_one_expansion_with_finder(setup_aizynthfinder): 35 | """ 36 | Test the building of this tree: 37 | root 38 | | 39 | child 1 40 | """ 41 | root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" 42 | child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"] 43 | lookup = {root_smi: {"smiles": ".".join(child1_smi), "prior": 1.0}} 44 | finder = setup_aizynthfinder(lookup, child1_smi) 45 | 46 | # Test first with return_first 47 | finder.config.search.return_first = True 48 | finder.tree_search() 49 | 50 | nodes = finder.tree.mol_nodes 51 | assert len(nodes) == 4 52 | assert nodes[0].mol.smiles == root_smi 53 | assert nodes[1].mol.smiles == child1_smi[0] 54 | assert finder.search_stats["iterations"] == 1 55 | assert finder.search_stats["returned_first"] 56 | 57 | # then test with iteration limit 58 | finder.config.search.return_first = False 59 | finder.config.search.iteration_limit = 45 60 | finder.prepare_tree() 61 | finder.tree_search() 62 | 63 | assert len(finder.tree.mol_nodes) == 4 64 | # It will not continue because it cannot expand any more nodes 65 | assert finder.search_stats["iterations"] == 2 66 | assert not finder.search_stats["returned_first"] 67 | 68 | 69 | def test_serialization_deserialization( 70 | mocker, setup_search_tree, tmpdir, default_config 71 | ): 72 | tree = setup_search_tree 73 | tree.one_iteration() 74 | 75 | mocked_json_dump = mocker.patch( 76 | "aizynthfinder.search.retrostar.search_tree.json.dump" 77 | ) 78 | serializer = MoleculeSerializer() 79 | filename = str(tmpdir / "dummy.json") 80 | 81 | # Test serialization 82 | 83 | tree.serialize(filename) 84 | 85 | expected_dict = { 86 | "tree": tree.root.serialize(serializer), 87 | "molecules": serializer.store, 88 | } 89 | 90 | mocked_json_dump.assert_called_once_with( 91 | expected_dict, mocker.ANY, indent=mocker.ANY 92 | ) 93 | 94 | # Test deserialization 95 | 96 | mocker.patch( 97 | "aizynthfinder.search.retrostar.search_tree.json.load", 98 | return_value=expected_dict, 99 | ) 100 | mocker.patch( 101 | "aizynthfinder.search.retrostar.nodes.deserialize_action", return_value=None 102 | ) 103 | 104 | new_tree = SearchTree.from_json(filename, default_config) 105 | 106 | assert new_tree.root.mol == tree.root.mol 107 | assert len(new_tree.root.children) == len(tree.root.children) 108 | 109 | 110 | def test_split_andor_tree(shared_datadir, default_config): 111 | tree = SearchTree.from_json( 112 | str(shared_datadir / "andor_tree_for_clustering.json"), default_config 113 | ) 114 | 115 | routes = tree.routes() 116 | 117 | assert len(routes) == 97 118 | 119 | 120 | def test_update(shared_datadir, default_config, setup_stock): 121 | # Todo: re-write 122 | setup_stock( 123 | default_config, 124 | "Nc1ccc(NC(=S)Nc2ccccc2)cc1", 125 | "Cc1ccc2nc3ccccc3c(Cl)c2c1", 126 | "Nc1ccccc1", 127 | "Nc1ccc(N=C=S)cc1", 128 | "Cc1ccc2nc3ccccc3c(Br)c2c1", 129 | "Nc1ccc(Br)cc1", 130 | ) 131 | tree = SearchTree.from_json( 132 | str(shared_datadir / "andor_tree_for_clustering.json"), default_config 133 | ) 134 | 135 | saved_root_value = tree.root.value 136 | tree.mol_nodes[-1].parent.update(35, from_mol=tree.mol_nodes[-1].mol) 137 | 138 | assert [np.round(child.value, 2) for child in tree.root.children][:2] == [ 139 | 3.17, 140 | 3.31, 141 | ] 142 | assert tree.root.value == saved_root_value 143 | 144 | tree.serialize("temp.json") 145 | -------------------------------------------------------------------------------- /tests/retrostar/test_retrostar_cost.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aizynthfinder.chem import Molecule 4 | from aizynthfinder.search.retrostar.cost import ( 5 | MoleculeCost, 6 | RetroStarCost, 7 | ZeroMoleculeCost, 8 | ) 9 | 10 | 11 | def test_retrostar_cost(setup_mocked_model): 12 | mol = Molecule(smiles="CCCC") 13 | 14 | cost = RetroStarCost(model_path="dummy", fingerprint_length=10, dropout_rate=0.0) 15 | assert pytest.approx(cost.calculate(mol), abs=0.001) == 30 16 | 17 | 18 | def test_zero_molecule_cost(): 19 | mol = Molecule(smiles="CCCC") 20 | 21 | cost = ZeroMoleculeCost().calculate(mol) 22 | assert cost == 0.0 23 | 24 | 25 | def test_molecule_cost_zero(default_config): 26 | default_config.search.algorithm_config["molecule_cost"] = { 27 | "cost": "ZeroMoleculeCost" 28 | } 29 | mol = Molecule(smiles="CCCC") 30 | 31 | molecule_cost = MoleculeCost(default_config)(mol) 32 | assert molecule_cost == 0.0 33 | 34 | 35 | def test_molecule_cost_retrostar(default_config, setup_mocked_model): 36 | default_config.search.algorithm_config["molecule_cost"] = { 37 | "cost": "aizynthfinder.search.retrostar.cost.RetroStarCost", 38 | "model_path": "dummy", 39 | "fingerprint_length": 10, 40 | "dropout_rate": 0.0, 41 | } 42 | mol = Molecule(smiles="CCCC") 43 | 44 | molecule_cost = MoleculeCost(default_config)(mol) 45 | assert pytest.approx(molecule_cost, abs=0.001) == 30 46 | -------------------------------------------------------------------------------- /tests/retrostar/test_retrostar_nodes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | 4 | from aizynthfinder.search.retrostar.nodes import MoleculeNode 5 | from aizynthfinder.chem.serialization import MoleculeSerializer, MoleculeDeserializer 6 | from aizynthfinder.search.andor_trees import ReactionTreeFromAndOrTrace 7 | 8 | 9 | def test_create_root_node(setup_star_root): 10 | node = setup_star_root("CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1") 11 | 12 | assert node.target_value == 0 13 | assert node.ancestors() == {node.mol} 14 | assert node.expandable 15 | assert not node.children 16 | 17 | 18 | def test_close_single_node(setup_star_root): 19 | node = setup_star_root("CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1") 20 | 21 | assert node.expandable 22 | 23 | delta = node.close() 24 | 25 | assert not np.isfinite(delta) 26 | assert not node.solved 27 | assert not node.expandable 28 | 29 | 30 | def test_create_stub(setup_star_root, get_action): 31 | root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" 32 | root = setup_star_root(root_smiles) 33 | reaction = get_action() 34 | 35 | root.add_stub(cost=5.0, reaction=reaction) 36 | 37 | assert len(root.children) == 1 38 | assert len(root.children[0].children) == 2 39 | rxn_node = root.children[0] 40 | assert rxn_node.reaction is reaction 41 | exp_list = [node.mol for node in rxn_node.children] 42 | assert exp_list == list(reaction.reactants[0]) 43 | assert rxn_node.value == rxn_node.target_value == 5 44 | assert not rxn_node.solved 45 | 46 | # This is done after a node has been expanded 47 | delta = root.close() 48 | 49 | assert delta == 5.0 50 | assert root.value == 5.0 51 | assert rxn_node.children[0].ancestors() == {root.mol, rxn_node.children[0].mol} 52 | 53 | 54 | def test_initialize_stub_one_solved_leaf( 55 | setup_star_root, get_action, default_config, setup_stock 56 | ): 57 | root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" 58 | root = setup_star_root(root_smiles) 59 | reaction = get_action() 60 | setup_stock(default_config, reaction.reactants[0][0]) 61 | 62 | root.add_stub(cost=5.0, reaction=reaction) 63 | root.close() 64 | 65 | assert not root.children[0].solved 66 | assert not root.solved 67 | assert root.children[0].children[0].solved 68 | 69 | 70 | def test_initialize_stub_two_solved_leafs( 71 | setup_star_root, get_action, default_config, setup_stock 72 | ): 73 | root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" 74 | root = setup_star_root(root_smiles) 75 | reaction = get_action() 76 | setup_stock(default_config, *reaction.reactants[0]) 77 | 78 | root.add_stub(cost=5.0, reaction=reaction) 79 | root.close() 80 | 81 | assert root.children[0].solved 82 | assert root.solved 83 | 84 | 85 | def test_serialization_deserialization( 86 | setup_star_root, get_action, default_config, setup_stock 87 | ): 88 | root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" 89 | root = setup_star_root(root_smiles) 90 | reaction = get_action() 91 | setup_stock(default_config, *reaction.reactants[0]) 92 | root.add_stub(cost=5.0, reaction=reaction) 93 | root.close() 94 | 95 | molecule_serializer = MoleculeSerializer() 96 | dict_ = root.serialize(molecule_serializer) 97 | 98 | molecule_deserializer = MoleculeDeserializer(molecule_serializer.store) 99 | node = MoleculeNode.from_dict( 100 | dict_, default_config, molecule_deserializer, root.molecule_cost 101 | ) 102 | 103 | assert node.mol == root.mol 104 | assert node.value == root.value 105 | assert node.cost == root.cost 106 | assert len(node.children) == len(root.children) 107 | 108 | rxn_node = node.children[0] 109 | assert rxn_node.reaction.smarts == reaction.smarts 110 | assert rxn_node.reaction.metadata == reaction.metadata 111 | assert rxn_node.cost == root.children[0].cost 112 | assert rxn_node.value == root.children[0].value 113 | 114 | for grandchild1, grandchild2 in zip(rxn_node.children, root.children[0].children): 115 | assert grandchild1.mol == grandchild2.mol 116 | 117 | 118 | def test_conversion_to_reaction_tree( 119 | setup_star_root, get_action, default_config, setup_stock 120 | ): 121 | root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" 122 | root = setup_star_root(root_smiles) 123 | reaction = get_action() 124 | setup_stock(default_config, *reaction.reactants[0]) 125 | root.add_stub(cost=5.0, reaction=reaction) 126 | root.close() 127 | graph = nx.DiGraph() 128 | graph.add_edge(root, root.children[0]) 129 | graph.add_edge(root.children[0], root.children[0].children[0]) 130 | graph.add_edge(root.children[0], root.children[0].children[1]) 131 | 132 | rt = ReactionTreeFromAndOrTrace(graph, default_config.stock).tree 133 | 134 | molecules = list(rt.molecules()) 135 | rt_reactions = list(rt.reactions()) 136 | assert len(molecules) == 3 137 | assert len(list(rt.leafs())) == 2 138 | assert len(rt_reactions) == 1 139 | assert molecules[0].inchi_key == root.mol.inchi_key 140 | assert molecules[1].inchi_key == root.children[0].children[0].mol.inchi_key 141 | assert molecules[2].inchi_key == root.children[0].children[1].mol.inchi_key 142 | assert rt_reactions[0].reaction_smiles() == reaction.reaction_smiles() 143 | -------------------------------------------------------------------------------- /tests/test_expander.py: -------------------------------------------------------------------------------- 1 | from aizynthfinder.aizynthfinder import AiZynthExpander 2 | 3 | 4 | def test_expander_defaults(get_one_step_expansion, setup_policies): 5 | expander = AiZynthExpander() 6 | setup_policies(get_one_step_expansion, config=expander.config) 7 | smi = "CCCCOc1ccc(CC(=O)N(C)O)cc1" 8 | 9 | reactions = expander.do_expansion(smi) 10 | 11 | assert len(reactions) == 2 12 | assert len(reactions[0]) == 1 13 | assert len(reactions[1]) == 1 14 | 15 | assert reactions[0][0].mol.smiles == smi 16 | assert reactions[1][0].mol.smiles == smi 17 | assert len(reactions[0][0].reactants[0]) == 2 18 | assert len(reactions[1][0].reactants[0]) == 2 19 | smi1 = [mol.smiles for mol in reactions[0][0].reactants[0]] 20 | smi2 = [mol.smiles for mol in reactions[1][0].reactants[0]] 21 | assert smi1 != smi2 22 | 23 | 24 | def test_expander_top1(get_one_step_expansion, setup_policies): 25 | expander = AiZynthExpander() 26 | setup_policies(get_one_step_expansion, config=expander.config) 27 | smi = "CCCCOc1ccc(CC(=O)N(C)O)cc1" 28 | 29 | reactions = expander.do_expansion(smi, return_n=1) 30 | 31 | assert len(reactions) == 1 32 | smiles_list = [mol.smiles for mol in reactions[0][0].reactants[0]] 33 | assert smiles_list == ["CCCCOc1ccc(CC(=O)Cl)cc1", "CNO"] 34 | 35 | 36 | def test_expander_filter(get_one_step_expansion, setup_policies): 37 | def filter_func(reaction): 38 | return "CNO" not in [mol.smiles for mol in reaction.reactants[0]] 39 | 40 | expander = AiZynthExpander() 41 | setup_policies(get_one_step_expansion, config=expander.config) 42 | smi = "CCCCOc1ccc(CC(=O)N(C)O)cc1" 43 | 44 | reactions = expander.do_expansion(smi, filter_func=filter_func) 45 | 46 | assert len(reactions) == 1 47 | smiles_list = [mol.smiles for mol in reactions[0][0].reactants[0]] 48 | assert smiles_list == ["CCCCBr", "CN(O)C(=O)Cc1ccc(O)cc1"] 49 | 50 | 51 | def test_expander_filter_policy(get_one_step_expansion, setup_policies): 52 | expander = AiZynthExpander() 53 | _, filter_strategy = setup_policies(get_one_step_expansion, config=expander.config) 54 | filter_strategy.lookup[ 55 | "CCCCOc1ccc(CC(=O)N(C)O)cc1>>CCCCOc1ccc(CC(=O)Cl)cc1.CNO" 56 | ] = 0.5 57 | smi = "CCCCOc1ccc(CC(=O)N(C)O)cc1" 58 | 59 | reactions = expander.do_expansion(smi) 60 | 61 | assert len(reactions) == 2 62 | assert reactions[0][0].metadata["feasibility"] == 0.5 63 | assert reactions[1][0].metadata["feasibility"] == 0.0 64 | -------------------------------------------------------------------------------- /tests/test_gui.py: -------------------------------------------------------------------------------- 1 | from aizynthfinder.analysis.routes import RouteCollection 2 | from aizynthfinder.interfaces.gui.utils import pareto_fronts_plot, route_display 3 | from aizynthfinder.interfaces.gui.pareto_fronts import ParetoFrontsGUI 4 | 5 | 6 | def test_plot_pareto_fronts(setup_mo_scorer, setup_analysis, default_config): 7 | analysis, _ = setup_analysis(scorer=setup_mo_scorer(default_config)) 8 | routes = RouteCollection.from_analysis(analysis) 9 | 10 | pareto_fronts_plot(routes) 11 | 12 | 13 | def test_display_route(setup_analysis, mocker): 14 | display_patch = mocker.patch("aizynthfinder.interfaces.gui.utils.display") 15 | mocked_widget = mocker.MagicMock() 16 | analysis, _ = setup_analysis() 17 | routes = RouteCollection.from_analysis(analysis) 18 | routes.make_images() 19 | 20 | route_display(0, routes, mocked_widget) 21 | display_patch.assert_called() 22 | 23 | 24 | def test_pareto_gui(setup_analysis, mocker, default_config): 25 | display_patch = mocker.patch("aizynthfinder.interfaces.gui.pareto_fronts.display") 26 | analysis, _ = setup_analysis() 27 | routes = RouteCollection.from_analysis(analysis) 28 | 29 | ParetoFrontsGUI(routes.reaction_trees, default_config.scorers) 30 | 31 | display_patch.assert_called() 32 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/utils/__init__.py -------------------------------------------------------------------------------- /tests/utils/test_bonds.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests to check the functionality of the bonds script. 3 | """ 4 | from aizynthfinder.chem import TreeMolecule, SmilesBasedRetroReaction 5 | from aizynthfinder.utils.bonds import BrokenBonds 6 | 7 | 8 | def test_focussed_bonds_broken(): 9 | mol = TreeMolecule(smiles="[CH3:1][NH:2][C:3](C)=[O:4]", parent=None) 10 | reaction = SmilesBasedRetroReaction( 11 | mol, 12 | mapped_prod_smiles="[CH3:1][NH:2][C:3](C)=[O:4]", 13 | reactants_str="C[C:3](=[O:4])O.[CH3:1][NH:2]", 14 | ) 15 | focussed_bonds = [(1, 2), (3, 4), (3, 2)] 16 | broken_bonds = BrokenBonds(focussed_bonds) 17 | broken_focussed_bonds = broken_bonds(reaction) 18 | 19 | assert broken_focussed_bonds == [(2, 3)] 20 | 21 | 22 | def test_focussed_bonds_not_broken(): 23 | mol = TreeMolecule(smiles="[CH3:1][NH:2][C:3](C)=[O:4]", parent=None) 24 | reaction = SmilesBasedRetroReaction( 25 | mol, 26 | mapped_prod_smiles="[CH3:1][NH:2][C:3](C)=[O:4]", 27 | reactants_str="C[C:3](=[O:4])O.[CH3:1][NH:2]", 28 | ) 29 | focussed_bonds = [(1, 2), (3, 4)] 30 | broken_bonds = BrokenBonds(focussed_bonds) 31 | broken_focussed_bonds = broken_bonds(reaction) 32 | 33 | assert broken_focussed_bonds == [] 34 | assert mol.has_all_focussed_bonds(focussed_bonds) is True 35 | 36 | 37 | def test_focussed_bonds_not_in_target_mol(): 38 | mol = TreeMolecule(smiles="[CH3:1][NH:2][C:3](C)=[O:4]", parent=None) 39 | focussed_bonds = [(1, 4)] 40 | 41 | assert mol.has_all_focussed_bonds(focussed_bonds) is False 42 | -------------------------------------------------------------------------------- /tests/utils/test_dynamic_loading.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aizynthfinder.utils.loading import load_dynamic_class 4 | 5 | 6 | class MyException(Exception): 7 | pass 8 | 9 | 10 | def test_simple(): 11 | cls = load_dynamic_class("aizynthfinder.reactiontree.ReactionTree") 12 | 13 | assert cls.__name__ == "ReactionTree" 14 | 15 | 16 | def test_default_module(): 17 | cls = load_dynamic_class( 18 | "ReactionTree", default_module="aizynthfinder.reactiontree" 19 | ) 20 | 21 | assert cls.__name__ == "ReactionTree" 22 | 23 | 24 | def test_no_default_module(): 25 | with pytest.raises(ValueError, match="default_module"): 26 | load_dynamic_class("ReactionTree") 27 | 28 | with pytest.raises(MyException, match="default_module"): 29 | load_dynamic_class("ReactionTree", exception_cls=MyException) 30 | 31 | 32 | def test_incorrect_module(): 33 | bad_module = "aizynthfinder.rt." 34 | with pytest.raises(ValueError, match=bad_module): 35 | load_dynamic_class(f"{bad_module}.ReactionTree") 36 | 37 | with pytest.raises(MyException, match=bad_module): 38 | load_dynamic_class(f"{bad_module}.ReactionTree", exception_cls=MyException) 39 | 40 | 41 | def test_incorrect_class(): 42 | bad_class = "ReactionTreee" 43 | with pytest.raises(ValueError, match=bad_class): 44 | load_dynamic_class(f"aizynthfinder.reactiontree.{bad_class}") 45 | 46 | with pytest.raises(MyException, match=bad_class): 47 | load_dynamic_class( 48 | f"aizynthfinder.reactiontree.{bad_class}", exception_cls=MyException 49 | ) 50 | -------------------------------------------------------------------------------- /tests/utils/test_external_tf_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import aizynthfinder.utils.models as models 5 | from aizynthfinder.utils.models import ( 6 | SUPPORT_EXTERNAL_APIS, 7 | ExternalModelViaGRPC, 8 | ExternalModelViaREST, 9 | ) 10 | 11 | 12 | @pytest.mark.xfail( 13 | condition=not SUPPORT_EXTERNAL_APIS, reason="API packages not installed" 14 | ) 15 | @pytest.fixture() 16 | def setup_rest_mock(mocker): 17 | models.TF_SERVING_HOST = "localhost" 18 | models.TF_SERVING_REST_PORT = "255" 19 | mocked_request = mocker.patch("aizynthfinder.utils.models.requests.request") 20 | mocked_request.return_value.status_code = 200 21 | mocked_request.return_value.headers = {"Content-Type": "application/json"} 22 | 23 | def wrapper(response): 24 | if isinstance(response, list): 25 | mocked_request.return_value.json.side_effect = response 26 | else: 27 | mocked_request.return_value.json.return_value = response 28 | return mocked_request 29 | 30 | yield wrapper 31 | 32 | models.TF_SERVING_HOST = None 33 | models.TF_SERVING_REST_PORT = None 34 | 35 | 36 | @pytest.mark.xfail( 37 | condition=not SUPPORT_EXTERNAL_APIS, reason="API packages not installed" 38 | ) 39 | @pytest.fixture() 40 | def setup_grpc_mock(mocker, signature_grpc): 41 | models.TF_SERVING_HOST = "localhost" 42 | models.TF_SERVING_GRPC_PORT = "255" 43 | mocker.patch("aizynthfinder.utils.models.grpc.insecure_channel") 44 | mocked_pred_service = mocker.patch( 45 | "aizynthfinder.utils.models.prediction_service_pb2_grpc.PredictionServiceStub" 46 | ) 47 | mocker.patch( 48 | "aizynthfinder.utils.models.get_model_metadata_pb2.GetModelMetadataRequest" 49 | ) 50 | mocker.patch("aizynthfinder.utils.models.predict_pb2.PredictRequest") 51 | mocked_message = mocker.patch("aizynthfinder.utils.models.MessageToDict") 52 | mocked_message.return_value = signature_grpc 53 | 54 | def wrapper(response=None): 55 | if not response: 56 | return 57 | mocked_pred_service.return_value.Predict.return_value.outputs = response 58 | 59 | yield wrapper 60 | 61 | models.TF_SERVING_HOST = None 62 | models.TF_SERVING_GRPC_PORT = None 63 | 64 | 65 | @pytest.fixture() 66 | def signature_rest(): 67 | return { 68 | "metadata": { 69 | "signature_def": { 70 | "signature_def": { 71 | "serving_default": { 72 | "inputs": { 73 | "first_layer": { 74 | "tensor_shape": {"dim": [{"size": 1}, {"size": 2048}]} 75 | } 76 | } 77 | } 78 | } 79 | } 80 | } 81 | } 82 | 83 | 84 | @pytest.fixture() 85 | def signature_grpc(): 86 | return { 87 | "metadata": { 88 | "signature_def": { 89 | "signatureDef": { 90 | "serving_default": { 91 | "inputs": { 92 | "first_layer": { 93 | "tensorShape": {"dim": [{"size": 1}, {"size": 2048}]} 94 | } 95 | }, 96 | "outputs": {"output": None}, 97 | } 98 | } 99 | } 100 | } 101 | } 102 | 103 | 104 | @pytest.mark.xfail( 105 | condition=not SUPPORT_EXTERNAL_APIS, reason="API packages not installed" 106 | ) 107 | def test_setup_tf_rest_model(signature_rest, setup_rest_mock): 108 | setup_rest_mock(signature_rest) 109 | 110 | model = ExternalModelViaREST("dummy") 111 | 112 | assert len(model) == 2048 113 | 114 | 115 | @pytest.mark.xfail( 116 | condition=not SUPPORT_EXTERNAL_APIS, reason="API packages not installed" 117 | ) 118 | def test_predict_tf_rest_model(signature_rest, setup_rest_mock): 119 | responses = [signature_rest, {"outputs": [0.0, 1.0]}] 120 | setup_rest_mock(responses) 121 | model = ExternalModelViaREST("dummy") 122 | 123 | out = model.predict(np.zeros([1, len(model)])) 124 | 125 | assert list(out) == [0.0, 1.0] 126 | 127 | 128 | @pytest.mark.xfail( 129 | condition=not SUPPORT_EXTERNAL_APIS, reason="API packages not installed" 130 | ) 131 | def test_setup_tf_grpc_model(setup_grpc_mock): 132 | setup_grpc_mock() 133 | 134 | model = ExternalModelViaGRPC("dummy") 135 | 136 | assert len(model) == 2048 137 | 138 | 139 | @pytest.mark.xfail( 140 | condition=not SUPPORT_EXTERNAL_APIS, 141 | reason="Tensorflow and API packages not installed", 142 | ) 143 | def test_predict_tf_grpc_model(setup_grpc_mock): 144 | import tensorflow as tf 145 | 146 | setup_grpc_mock({"output": tf.make_tensor_proto(tf.constant([0.0, 1.0]))}) 147 | model = ExternalModelViaGRPC("dummy") 148 | 149 | out = model.predict(np.zeros([1, len(model)])) 150 | 151 | assert list(out) == [0.0, 1.0] 152 | -------------------------------------------------------------------------------- /tests/utils/test_file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gzip 3 | import json 4 | 5 | import pytest 6 | import pandas as pd 7 | 8 | from aizynthfinder.utils.files import ( 9 | cat_datafiles, 10 | split_file, 11 | start_processes, 12 | read_datafile, 13 | save_datafile, 14 | ) 15 | 16 | 17 | @pytest.fixture 18 | def create_dummy_file(tmpdir, mocker): 19 | patched_tempfile = mocker.patch("aizynthfinder.utils.files.tempfile.mktemp") 20 | split_files = [ 21 | tmpdir / "split1", 22 | tmpdir / "split2", 23 | tmpdir / "split3", 24 | ] 25 | patched_tempfile.side_effect = split_files 26 | filename = tmpdir / "input" 27 | 28 | def wrapper(content): 29 | with open(filename, "w") as fileobj: 30 | fileobj.write(content) 31 | return filename, split_files 32 | 33 | return wrapper 34 | 35 | 36 | def test_split_file_even(create_dummy_file): 37 | filename, split_files = create_dummy_file("\n".join(list("abcdef"))) 38 | 39 | split_file(filename, 3) 40 | 41 | read_lines = [] 42 | for filename in split_files: 43 | assert os.path.exists(filename) 44 | with open(filename, "r") as fileobj: 45 | read_lines.append(fileobj.read()) 46 | assert read_lines[0] == "a\nb" 47 | assert read_lines[1] == "c\nd" 48 | assert read_lines[2] == "e\nf" 49 | 50 | 51 | def test_split_file_odd(create_dummy_file): 52 | filename, split_files = create_dummy_file("\n".join(list("abcdefg"))) 53 | 54 | split_file(filename, 3) 55 | 56 | read_lines = [] 57 | for filename in split_files: 58 | assert os.path.exists(filename) 59 | with open(filename, "r") as fileobj: 60 | read_lines.append(fileobj.read()) 61 | assert read_lines[0] == "a\nb\nc" 62 | assert read_lines[1] == "d\ne" 63 | assert read_lines[2] == "f\ng" 64 | 65 | 66 | def test_start_processes(tmpdir): 67 | 68 | script_filename = str(tmpdir / "dummy.py") 69 | with open(script_filename, "w") as fileobj: 70 | fileobj.write("import sys\nimport time\nprint(sys.argv[1])\ntime.sleep(2)\n") 71 | 72 | def create_cmd(index, filename): 73 | return ["python", script_filename, f"{filename}-{index}"] 74 | 75 | start_processes(["dummy", "dummy"], str(tmpdir / "log"), create_cmd, 2) 76 | 77 | for index in [1, 2]: 78 | logfile = str(tmpdir / f"log{index}.log") 79 | assert os.path.exists(logfile) 80 | with open(logfile, "r") as fileobj: 81 | lines = fileobj.read() 82 | assert lines == f"dummy-{index}\n" 83 | 84 | 85 | def test_cat_hdf(create_dummy_stock1, create_dummy_stock2, tmpdir): 86 | filename = str(tmpdir / "output.hdf") 87 | inputs = [create_dummy_stock1("hdf5"), create_dummy_stock2] 88 | 89 | cat_datafiles(inputs, filename) 90 | 91 | data = pd.read_hdf(filename, "table") 92 | assert len(data) == 4 93 | assert list(data.inchi_key.values) == [ 94 | "UHOVQNZJYSORNB-UHFFFAOYSA-N", 95 | "YXFVVABEGXRONW-UHFFFAOYSA-N", 96 | "UHOVQNZJYSORNB-UHFFFAOYSA-N", 97 | "ISWSIDIOOBJBQZ-UHFFFAOYSA-N", 98 | ] 99 | 100 | 101 | def test_cat_hdf_no_trees(tmpdir, create_dummy_stock1, create_dummy_stock2): 102 | hdf_filename = str(tmpdir / "output.hdf") 103 | tree_filename = str(tmpdir / "trees.json") 104 | inputs = [create_dummy_stock1("hdf5"), create_dummy_stock2] 105 | 106 | cat_datafiles(inputs, hdf_filename, tree_filename) 107 | 108 | assert not os.path.exists(tree_filename) 109 | 110 | 111 | def test_cat_hdf_trees(tmpdir): 112 | hdf_filename = str(tmpdir / "output.hdf") 113 | tree_filename = str(tmpdir / "trees.json") 114 | filename1 = str(tmpdir / "file1.hdf5") 115 | filename2 = str(tmpdir / "file2.hdf5") 116 | trees1 = [[1], [2]] 117 | trees2 = [[3], [4]] 118 | pd.DataFrame({"mol": ["A", "B"], "trees": trees1}).to_hdf(filename1, "table") 119 | pd.DataFrame({"mol": ["A", "B"], "trees": trees2}).to_hdf(filename2, "table") 120 | 121 | cat_datafiles([filename1, filename2], hdf_filename, tree_filename) 122 | 123 | assert os.path.exists(tree_filename + ".gz") 124 | with gzip.open(tree_filename + ".gz", "rt", encoding="UTF-8") as fileobj: 125 | trees_cat = json.load(fileobj) 126 | assert trees_cat == trees1 + trees2 127 | assert "trees" not in pd.read_hdf(hdf_filename, "table") 128 | 129 | 130 | @pytest.mark.parametrize( 131 | ("filename"), 132 | [ 133 | ("temp.json"), 134 | ("temp.hdf5"), 135 | ], 136 | ) 137 | def test_save_load_datafile_roundtrip(filename, tmpdir): 138 | data1 = pd.DataFrame({"a": [0, 1, 2], "b": [2, 3, 4]}) 139 | 140 | save_datafile(data1, tmpdir / filename) 141 | 142 | data2 = read_datafile(tmpdir / filename) 143 | 144 | assert data1.columns.to_list() == data2.columns.to_list() 145 | assert data1.a.to_list() == data2.a.to_list() 146 | assert data1.b.to_list() == data2.b.to_list() 147 | -------------------------------------------------------------------------------- /tests/utils/test_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import json 4 | from tarfile import TarFile 5 | from pathlib import Path 6 | 7 | import pytest 8 | from PIL import Image, ImageDraw 9 | 10 | from aizynthfinder.utils import image 11 | from aizynthfinder.chem import TreeMolecule, TemplatedRetroReaction 12 | 13 | 14 | @pytest.fixture 15 | def new_image(): 16 | img = Image.new(mode="RGB", size=(300, 300), color="white") 17 | draw = ImageDraw.Draw(img) 18 | draw.point([5, 150, 100, 250], fill="black") 19 | return img 20 | 21 | 22 | @pytest.fixture 23 | def setup_graph(): 24 | mol1 = TreeMolecule(smiles="CCCO", parent=None) 25 | reaction = TemplatedRetroReaction(mol=mol1, smarts="") 26 | 27 | return [mol1], [reaction], [(mol1, reaction)], ["green"] 28 | 29 | 30 | def test_crop(new_image): 31 | cropped = image.crop_image(new_image) 32 | 33 | assert cropped.width == 136 34 | assert cropped.height == 141 35 | assert cropped.getpixel((21, 21)) == (0, 0, 0) 36 | assert cropped.getpixel((116, 121)) == (0, 0, 0) 37 | 38 | 39 | def test_rounded_rectangle(new_image): 40 | color = (255, 0, 0) 41 | modified = image.draw_rounded_rectangle(new_image, color=color) 42 | 43 | assert modified.getpixel((0, 150)) == color 44 | assert modified.getpixel((150, 0)) == color 45 | assert modified.getpixel((299, 150)) == color 46 | assert modified.getpixel((150, 299)) == color 47 | 48 | 49 | def test_save_molecule_images(): 50 | nfiles = len(os.listdir(image.IMAGE_FOLDER)) 51 | 52 | mols = [ 53 | TreeMolecule(smiles="CCCO", parent=None), 54 | TreeMolecule(smiles="CCCO", parent=None), 55 | TreeMolecule(smiles="CCCCO", parent=None), 56 | ] 57 | 58 | image.save_molecule_images(mols, ["green", "green", "green"]) 59 | 60 | assert len(os.listdir(image.IMAGE_FOLDER)) == nfiles + 2 61 | 62 | image.save_molecule_images(mols, ["green", "orange", "green"]) 63 | 64 | assert len(os.listdir(image.IMAGE_FOLDER)) == nfiles + 2 65 | 66 | 67 | def test_visjs_page(mocker, tmpdir, setup_graph): 68 | mkdtemp_patch = mocker.patch("aizynthfinder.utils.image.tempfile.mkdtemp") 69 | mkdtemp_patch.return_value = str(tmpdir / "tmp") 70 | os.mkdir(tmpdir / "tmp") 71 | molecules, reactions, edges, frame_colors = setup_graph 72 | filename = str(tmpdir / "arch.tar") 73 | 74 | image.make_visjs_page(filename, molecules, reactions, edges, frame_colors) 75 | 76 | assert os.path.exists(filename) 77 | with TarFile(filename) as tarobj: 78 | assert "./route.html" in tarobj.getnames() 79 | assert len([name for name in tarobj.getnames() if name.endswith(".png")]) == 1 80 | 81 | 82 | def test_image_factory(request): 83 | route_path = Path(request.fspath).parent.parent / "data" / "branched_route.json" 84 | with open(route_path, "r") as fileobj: 85 | dict_ = json.load(fileobj) 86 | dict_["children"][0]["children"][1]["hide"] = True 87 | 88 | factory0 = image.RouteImageFactory(dict_) 89 | 90 | factory_tighter = image.RouteImageFactory(dict_, margin=50) 91 | assert factory0.image.width == factory_tighter.image.width + 150 92 | assert factory0.image.height == factory_tighter.image.height + 175 93 | 94 | factory_hidden = image.RouteImageFactory(dict_, show_all=False) 95 | assert factory0.image.width == factory_hidden.image.width 96 | assert factory0.image.height > factory_hidden.image.height 97 | -------------------------------------------------------------------------------- /tests/utils/test_local_onnx_model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import numpy as np 4 | import pytest 5 | import pytest_mock 6 | from aizynthfinder.utils import models 7 | 8 | 9 | def test_local_onnx_model_predict(mock_onnx_model: pytest_mock.MockerFixture) -> None: 10 | onnx_model = models.LocalOnnxModel("test_model.onnx") 11 | output = onnx_model.predict(np.array([1])) 12 | expected_output = np.array([[0.2, 0.7, 0.1]]) 13 | 14 | assert np.array_equal(output, expected_output) 15 | 16 | 17 | def test_local_onnx_model_length(mock_onnx_model: pytest_mock.MockerFixture) -> None: 18 | onnx_model = models.LocalOnnxModel("test_model.onnx") 19 | output = len(onnx_model) 20 | expected_output = 3 21 | 22 | assert output == expected_output 23 | 24 | 25 | def test_local_onnx_model_output_size( 26 | mock_onnx_model: pytest_mock.MockerFixture, 27 | ) -> None: 28 | onnx_model = models.LocalOnnxModel("test_model.onnx") 29 | output = onnx_model.output_size 30 | expected_output = 3 31 | 32 | assert output == expected_output 33 | -------------------------------------------------------------------------------- /tests/utils/test_scscore.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import pytest 4 | from rdkit import Chem 5 | 6 | from aizynthfinder.utils.sc_score import SCScore 7 | 8 | # Dummy, tiny SCScore model 9 | _weights0 = [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]] 10 | _weights1 = [[1, 1], [1, 1]] 11 | _weights = [_weights0, _weights1] 12 | _biases = [[0, 0], [0]] 13 | 14 | 15 | def test_scscore(tmpdir): 16 | filename = str(tmpdir / "dummy.pickle") 17 | with open(filename, "wb") as fileobj: 18 | pickle.dump((_weights, _biases), fileobj) 19 | scorer = SCScore(filename, 5) 20 | mol = Chem.MolFromSmiles("C") 21 | 22 | assert pytest.approx(scorer(mol), abs=1e-3) == 4.523 23 | --------------------------------------------------------------------------------