├── .github
└── workflows
│ ├── docs.yml
│ ├── publish_to_pypi.yml
│ └── tests.yml
├── .gitignore
├── CHANGELOG.md
├── CITATION.cff
├── LICENSE
├── README.md
├── aizynthfinder
├── __init__.py
├── aizynthfinder.py
├── analysis
│ ├── __init__.py
│ ├── routes.py
│ ├── tree_analysis.py
│ └── utils.py
├── chem
│ ├── __init__.py
│ ├── mol.py
│ ├── reaction.py
│ └── serialization.py
├── context
│ ├── __init__.py
│ ├── collection.py
│ ├── config.py
│ ├── policy
│ │ ├── __init__.py
│ │ ├── expansion_strategies.py
│ │ ├── filter_strategies.py
│ │ ├── policies.py
│ │ └── utils.py
│ ├── scoring
│ │ ├── __init__.py
│ │ ├── collection.py
│ │ └── scorers.py
│ └── stock
│ │ ├── __init__.py
│ │ ├── queries.py
│ │ └── stock.py
├── data
│ ├── default_training.yml
│ ├── logging.yml
│ └── templates
│ │ ├── reaction_tree.dot
│ │ └── reaction_tree.thtml
├── interfaces
│ ├── __init__.py
│ ├── aizynthapp.py
│ ├── aizynthcli.py
│ └── gui
│ │ ├── __init__.py
│ │ ├── clustering.py
│ │ ├── pareto_fronts.py
│ │ └── utils.py
├── reactiontree.py
├── search
│ ├── __init__.py
│ ├── andor_trees.py
│ ├── breadth_first
│ │ ├── __init__.py
│ │ ├── nodes.py
│ │ └── search_tree.py
│ ├── dfpn
│ │ ├── __init__.py
│ │ ├── nodes.py
│ │ └── search_tree.py
│ ├── mcts
│ │ ├── __init__.py
│ │ ├── node.py
│ │ ├── search.py
│ │ ├── state.py
│ │ └── utils.py
│ └── retrostar
│ │ ├── __init__.py
│ │ ├── cost.py
│ │ ├── nodes.py
│ │ └── search_tree.py
├── tools
│ ├── __init__.py
│ ├── cat_output.py
│ ├── download_public_data.py
│ └── make_stock.py
└── utils
│ ├── __init__.py
│ ├── bonds.py
│ ├── exceptions.py
│ ├── files.py
│ ├── image.py
│ ├── loading.py
│ ├── logging.py
│ ├── math.py
│ ├── models.py
│ ├── mongo.py
│ ├── paths.py
│ ├── sc_score.py
│ └── type_utils.py
├── contrib
└── notebook.ipynb
├── docs
├── analysis-rel.png
├── analysis-seq.png
├── cli.rst
├── conf.py
├── configuration.rst
├── gui.rst
├── gui_clustering.png
├── gui_input.png
├── gui_results.png
├── howto.rst
├── index.rst
├── line-desc.png
├── python_interface.rst
├── relationships.rst
├── scoring.rst
├── sequences.rst
├── stocks.rst
├── treesearch-rel.png
└── treesearch-seq.png
├── env-dev.yml
├── plugins
├── README.md
└── expansion_strategies.py
├── poetry.lock
├── pyproject.toml
├── tasks.py
└── tests
├── __init__.py
├── breadth_first
├── __init__.py
├── test_nodes.py
└── test_search.py
├── chem
├── __init__.py
├── test_mol.py
├── test_reaction.py
└── test_serialization.py
├── conftest.py
├── context
├── __init__.py
├── conftest.py
├── data
│ ├── custom_loader.py
│ ├── custom_loader2.py
│ ├── linear_route_w_metadata.json
│ ├── simple_filter.bloom
│ └── test_reactions_template.csv
├── test_collection.py
├── test_expansion_strategies.py
├── test_mcts_config.py
├── test_policy.py
├── test_score.py
└── test_stock.py
├── data
├── and_or_tree.json
├── branched_route.json
├── combined_example_tree.json
├── combined_example_tree2.json
├── dummy2_raw_template_library.csv
├── dummy_noclass_raw_template_library.csv
├── dummy_raw_template_library.csv
├── dummy_sani_raw_template_library.csv
├── full_search_tree.json.gz
├── input_checkpoint.json.gz
├── linear_route.json
├── post_processing_test.py
├── pre_processing_test.py
├── routes_for_clustering.json
├── test_reactions_template.csv
└── tree_for_clustering.json
├── dfpn
├── __init__.py
├── test_nodes.py
└── test_search.py
├── mcts
├── __init__.py
├── conftest.py
├── test_multiobjective.py
├── test_node.py
├── test_reward.py
├── test_serialization.py
└── test_tree.py
├── retrostar
├── __init__.py
├── conftest.py
├── data
│ └── andor_tree_for_clustering.json
├── test_retrostar.py
├── test_retrostar_cost.py
└── test_retrostar_nodes.py
├── test_analysis.py
├── test_cli.py
├── test_expander.py
├── test_finder.py
├── test_gui.py
├── test_reactiontree.py
└── utils
├── __init__.py
├── test_bonds.py
├── test_dynamic_loading.py
├── test_external_tf_models.py
├── test_file_utils.py
├── test_image.py
├── test_local_onnx_model.py
└── test_scscore.py
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: docs
2 |
3 | on:
4 | push:
5 | branches: [ master ]
6 |
7 | jobs:
8 | build:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v3
12 | - name: build
13 | run: |
          wget -q -O - https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh |
15 | bash -s -- --batch
16 | conda env create -f env-dev.yml
17 | conda run --name aizynth-dev poetry install -E all
18 | conda run --name aizynth-dev inv build-docs
19 | - name: deploy
20 | uses: peaceiris/actions-gh-pages@v3
21 | with:
22 | github_token: ${{ secrets.GITHUB_TOKEN }}
23 | publish_dir: ./docs/build/html
24 | publish_branch: gh-pages
25 | force_orphan: true
26 |
--------------------------------------------------------------------------------
/.github/workflows/publish_to_pypi.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python distributions to PyPI
2 |
3 | on: push
4 |
5 | jobs:
6 | build-n-publish:
7 | name: Build and publish Python distributions to PyPI
8 | runs-on: ubuntu-latest
9 |
10 | steps:
11 | - uses: actions/checkout@v3
12 | - name: Set up Python
13 | uses: actions/setup-python@v4
14 | with:
15 | python-version: "3.x"
16 | - name: Install pypa/build
17 | run: >-
18 | python3 -m
19 | pip install
20 | build
21 | --user
22 | - name: Build a binary wheel and a source tarball
23 | run: >-
24 | python3 -m
25 | build
26 | --sdist
27 | --wheel
28 | --outdir dist/
29 | .
30 | - name: Publish distribution to PyPI
31 | if: startsWith(github.ref, 'refs/tags')
32 | uses: pypa/gh-action-pypi-publish@release/v1
33 | with:
34 | password: ${{ secrets.PYPI_API_TOKEN }}
35 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: tests
2 |
3 | on:
4 | push:
5 | branches: [master]
6 | pull_request:
7 | branches: [master]
8 |
9 | jobs:
10 | build:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v3
14 | - name: Run
15 | run: |
          wget -q -O - https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh |
17 | bash -s -- --batch
18 | conda env create -f env-dev.yml
19 | conda run --name aizynth-dev poetry install -E all -E tf
20 | conda run --name aizynth-dev inv full-tests
21 | - name: Upload coverage to Codecov
22 | uses: codecov/codecov-action@v1
23 | with:
24 | token: ${{ secrets.CODECOV_TOKEN }}
25 | files: ./coverage.xml
26 | directory: ./coverage/
27 | name: codecov-aizynth
28 | fail_ci_if_error: false
29 | verbose: true
30 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/build/
73 | docs/aizynth*
74 | docs/modules.rst
75 | docs/cli_help.txt
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # PyCharmm idea
135 | .idea/
136 |
137 | #VS Code
138 | .vscode/
139 |
140 | # Pytest coverage output
141 | coverage/
142 |
143 | # Files created by the tool
144 | smiles.txt
145 | config_local.yml
146 | devtools/seldon-output-*/
147 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | # YAML 1.2
2 | ---
3 | abstract: "We present the open-source AiZynthFinder software that can be readily used in retrosynthetic planning. The algorithm is based on a Monte Carlo tree search that recursively breaks down a molecule to purchasable precursors. The tree search is guided by an artificial neural network policy that suggests possible precursors by utilizing a library of known reaction templates. The software is fast and can typically find a solution in less than 10 s and perform a complete search in less than 1 min. Moreover, the development of the code was guided by a range of software engineering principles such as automatic testing, system design and continuous integration leading to robust software with high maintainability. Finally, the software is well documented to make it suitable for beginners. The software is available at http://www.github.com/MolecularAI/aizynthfinder."
4 | authors:
5 | -
6 | family-names: Genheden
7 | given-names: Samuel
8 | -
9 | family-names: Thakkar
10 | given-names: Amol
11 | -
12 | family-names: "Chadimová"
13 | given-names: Veronika
14 | -
15 | family-names: Reymond
16 | given-names: "Jean-Louis"
17 | -
18 | family-names: Engkvist
19 | given-names: Ola
20 | -
21 | family-names: Bjerrum
22 | given-names: Esben
23 | orcid: "https://orcid.org/0000-0003-1614-7376"
24 | cff-version: "1.1.0"
25 | date-released: 2020-12-08
doi: "10.1186/s13321-020-00472-1"
27 | identifiers:
28 | -
29 | type: doi
30 | value: "10.1186/s13321-020-00472-1"
31 | keywords:
32 | - retrosynthesis
  - casp
35 | - cheminformatics
36 | - "neural-networks"
37 | - "monte-carlo-tree-search"
38 | - "chemical-reactions"
39 | - astrazeneca
40 | - "reaction-informatics"
41 | license: MIT
42 | message: "If you use this software, please cite it using these metadata."
43 | repository-code: "https://github.com/MolecularAI/aizynthfinder"
44 | title: AiZynthFinder
45 | version: "2.2.1"
46 | ...
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2020 Samuel Genheden and Esben Bjerrum
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/aizynthfinder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/aizynthfinder/__init__.py
--------------------------------------------------------------------------------
/aizynthfinder/analysis/__init__.py:
--------------------------------------------------------------------------------
1 | """ Sub-package containing analysis routines
2 | """
3 | from aizynthfinder.analysis.tree_analysis import TreeAnalysis # isort: skip
4 | from aizynthfinder.analysis.routes import RouteCollection
5 | from aizynthfinder.analysis.utils import RouteSelectionArguments
6 |
--------------------------------------------------------------------------------
/aizynthfinder/chem/__init__.py:
--------------------------------------------------------------------------------
1 | """ Sub-package containing chemistry routines
2 | """
3 | from aizynthfinder.chem.mol import (
4 | Molecule,
5 | MoleculeException,
6 | TreeMolecule,
7 | UniqueMolecule,
8 | none_molecule,
9 | )
10 | from aizynthfinder.chem.reaction import (
11 | FixedRetroReaction,
12 | RetroReaction,
13 | SmilesBasedRetroReaction,
14 | TemplatedRetroReaction,
15 | hash_reactions,
16 | )
17 | from aizynthfinder.chem.serialization import (
18 | MoleculeDeserializer,
19 | MoleculeSerializer,
20 | deserialize_action,
21 | serialize_action,
22 | )
23 |
--------------------------------------------------------------------------------
/aizynthfinder/chem/serialization.py:
--------------------------------------------------------------------------------
1 | """ Module containing helper classes and routines for serialization.
2 | """
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | import aizynthfinder.chem
8 | from aizynthfinder.utils.loading import load_dynamic_class
9 |
10 | if TYPE_CHECKING:
11 | from aizynthfinder.chem import RetroReaction
12 | from aizynthfinder.utils.type_utils import Any, Dict, Optional, Sequence, StrDict
13 |
14 |
class MoleculeSerializer:
    """
    Utility class for serializing molecules.

    Querying the serializer with a molecule returns an integer id and
    registers the molecule (and, for tree molecules, its parent chain)
    in an internal store:

    .. code-block::

        serializer = MoleculeSerializer()
        mol = Molecule(smiles="CCCO")
        idx = serializer[mol]

    which will take care of the serialization of the molecule.
    """

    def __init__(self) -> None:
        # Maps id(molecule) -> serialized specification dictionary
        self._store: Dict[int, Any] = {}

    def __getitem__(self, mol: Optional[aizynthfinder.chem.Molecule]) -> Optional[int]:
        if mol is None:
            return None

        key = id(mol)
        if key not in self._store:
            self._add_mol(mol)
        return key

    @property
    def store(self) -> Dict[int, Any]:
        """Return all serialized molecules as a dictionary"""
        return self._store

    def _add_mol(self, mol: aizynthfinder.chem.Molecule) -> None:
        spec = {"smiles": mol.smiles, "class": mol.__class__.__name__}
        if isinstance(mol, aizynthfinder.chem.TreeMolecule):
            # Serializing the parent (recursively) first ensures parents
            # appear before their children in the store
            spec["parent"] = self[mol.parent]
            spec["transform"] = mol.transform
            spec["smiles"] = (
                mol.mapped_smiles if mol.parent else mol.original_smiles
            )
        self._store[id(mol)] = spec
58 |
59 |
class MoleculeDeserializer:
    """
    Utility class for deserializing molecules.
    The serialized molecules are created upon instantiation of the class.

    The deserialized molecules can be obtained with:

    .. code-block::

        deserializer = MoleculeDeserializer(store)
        mol = deserializer[idx]

    :param store: the serialized molecules, keyed by the ids assigned
                  by :class:`MoleculeSerializer` (keys may be strings
                  when the store was round-tripped through JSON)
    """

    def __init__(self, store: Dict[int, Any]) -> None:
        self._objects: Dict[int, Any] = {}
        self._create_molecules(store)

    def __getitem__(self, id_: Optional[int]) -> Optional[aizynthfinder.chem.Molecule]:
        if id_ is None:
            return None
        return self._objects[id_]

    def get_tree_molecules(
        self, ids: Sequence[int]
    ) -> Sequence[aizynthfinder.chem.TreeMolecule]:
        """
        Return multiple deserialized tree molecules

        :param ids: the list of IDs to deserialize
        :return: the molecule objects
        :raises ValueError: if an id does not correspond to a tree molecule
        """
        objects = []
        for id_ in ids:
            obj = self[id_]
            if obj is None or not isinstance(obj, aizynthfinder.chem.TreeMolecule):
                raise ValueError(f"Failed to deserialize molecule with id {id_}")
            objects.append(obj)
        return objects

    def _create_molecules(self, store: dict) -> None:
        for id_, spec in store.items():
            # JSON object keys are always strings
            if isinstance(id_, str):
                id_ = int(id_)

            # Work on a copy so the caller's store is not modified in place
            # (the original implementation overwrote spec["parent"] with the
            # deserialized molecule object)
            kwargs = dict(spec)
            cls = kwargs.pop("class")
            if "parent" in kwargs:
                # Assumes parents were stored before their children,
                # which MoleculeSerializer guarantees
                kwargs["parent"] = self[kwargs["parent"]]
            self._objects[id_] = getattr(aizynthfinder.chem, cls)(**kwargs)
112 |
113 |
def serialize_action(
    action: RetroReaction, molecule_store: MoleculeSerializer
) -> StrDict:
    """
    Serialize a retrosynthesis action

    :param action: the (re)action to serialize
    :param molecule_store: the molecule serialization object
    :return: the action as a dictionary
    """
    data = action.to_dict()
    # Replace the molecule object with its serialization id
    data["mol"] = molecule_store[data["mol"]]
    if not action.unqueried:
        data["reactants"] = [
            [molecule_store[mol] for mol in reactant_tuple]
            for reactant_tuple in action.reactants
        ]
    klass = action.__class__
    data["class"] = f"{klass.__module__}.{klass.__name__}"
    return data
132 |
133 |
def deserialize_action(
    dict_: StrDict, molecule_store: MoleculeDeserializer
) -> RetroReaction:
    """
    Deserialize a retrosynthesis action

    :param dict_: the (re)action as a dictionary
    :param molecule_store: the molecule deserialization object
    :return: the created action object
    """
    mol_id = dict_.pop("mol")
    dict_["mol"] = molecule_store.get_tree_molecules([mol_id])[0]
    # Older serializations pre-date the "class" entry; fall back to the
    # templated reaction class
    class_spec = dict_.pop("class", "aizynthfinder.chem.TemplatedRetroReaction")
    cls = load_dynamic_class(class_spec)
    if "reactants" not in dict_:
        return cls(**dict_)
    reactant_lists = [
        molecule_store.get_tree_molecules(ids) for ids in dict_.pop("reactants")
    ]
    return cls.from_serialization(dict_, reactant_lists)
157 |
--------------------------------------------------------------------------------
/aizynthfinder/context/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/aizynthfinder/context/__init__.py
--------------------------------------------------------------------------------
/aizynthfinder/context/collection.py:
--------------------------------------------------------------------------------
1 | """ Module containing a class that is the base class for all collection classes (stock, policies, scorers)
2 | """
3 | from __future__ import annotations
4 |
5 | import abc
6 | from typing import TYPE_CHECKING
7 |
8 | from aizynthfinder.utils.logging import logger
9 |
10 | if TYPE_CHECKING:
11 | from aizynthfinder.utils.type_utils import Any, List, Optional, StrDict, Union
12 |
13 |
class ContextCollection(abc.ABC):
    """
    Abstract base class for a collection of items
    that can be loaded and then (de-)selected.

    Individual items are accessed by key:

    .. code-block::

        an_item = collection["key"]

    and removed with:

    .. code-block::

        del collection["key"]
    """

    # Sub-classes override these to restrict selection to one item and
    # to customize log/error messages
    _single_selection = False
    _collection_name = "collection"

    def __init__(self) -> None:
        self._items: StrDict = {}
        self._selection: List[str] = []
        self._logger = logger()

    def __delitem__(self, key: str) -> None:
        self._check_loaded(key)
        del self._items[key]

    def __getitem__(self, key: str) -> Any:
        self._check_loaded(key)
        return self._items[key]

    def __len__(self) -> int:
        return len(self._items)

    @property
    def items(self) -> List[str]:
        """The available item keys"""
        return list(self._items.keys())

    @property
    def selection(self) -> Union[List[str], str, None]:
        """The keys of the selected item(s)"""
        if not self._single_selection:
            return self._selection
        return self._selection[0] if self._selection else None

    @selection.setter
    def selection(self, value: str) -> None:
        self.select(value)

    def deselect(self, key: Optional[str] = None) -> None:
        """
        Deselect one or all items

        If no key is passed, all items will be deselected.

        :param key: the key of the item to deselect, defaults to None
        :raises KeyError: if the key is not among the selected ones
        """
        if not key:
            self._selection = []
            return

        if key in self._selection:
            self._selection.remove(key)
            return
        raise KeyError(f"Cannot deselect {key} because it is not selected")

    @abc.abstractmethod
    def load(self, *_: Any) -> None:
        """Load an item. Needs to be implemented by a sub-class"""

    @abc.abstractmethod
    def load_from_config(self, **config: Any) -> None:
        """Load items from a configuration. Needs to be implemented by a sub-class"""

    def select(self, value: Union[str, List[str]], append: bool = False) -> None:
        """
        Select one or more items.

        If this is a single selection collection, only a single value is accepted.
        If this is a multiple selection collection it will overwrite the selection completely,
        unless ``append`` is True and a single key is given.

        :param value: the key or keys of the item(s) to select
        :param append: if True will append single keys to existing selection
        :raises ValueError: if this a single collection and value is multiple keys
        :raises KeyError: if at least one of the keys are not corresponding to a loaded item
        """
        is_single_key = isinstance(value, str)
        if self._single_selection and not is_single_key and len(value) > 1:
            raise ValueError(f"Cannot select more than one {self._collection_name}")

        keys = [value] if is_single_key else value

        # Validate every key before touching the current selection
        for key in keys:
            if key not in self._items:
                raise KeyError(
                    f"Invalid key specified {key} when selecting {self._collection_name}"
                )

        if self._single_selection:
            self._selection = [keys[0]]
        elif is_single_key and append:
            self._selection.append(value)
        else:
            self._selection = list(keys)

        self._logger.info(f"Selected as {self._collection_name}: {', '.join(keys)}")

    def select_all(self) -> None:
        """Select all loaded items"""
        if self.items:
            self.select(self.items)

    def select_first(self) -> None:
        """Select the first loaded item"""
        if self.items:
            self.select(self.items[0])

    def select_last(self) -> None:
        """Select the last loaded item"""
        if self.items:
            self.select(self.items[-1])

    def _check_loaded(self, key: str) -> None:
        # Shared guard for item access and removal
        if key not in self._items:
            raise KeyError(
                f"{self._collection_name.capitalize()} with name {key} not loaded."
            )
147 |
--------------------------------------------------------------------------------
/aizynthfinder/context/policy/__init__.py:
--------------------------------------------------------------------------------
1 | """ Sub-package containing policy routines
2 | """
3 |
4 | from aizynthfinder.context.policy.expansion_strategies import (
5 | ExpansionStrategy,
6 | MultiExpansionStrategy,
7 | TemplateBasedDirectExpansionStrategy,
8 | TemplateBasedExpansionStrategy,
9 | )
10 | from aizynthfinder.context.policy.filter_strategies import (
11 | BondFilter,
12 | FilterStrategy,
13 | QuickKerasFilter,
14 | ReactantsCountFilter,
15 | )
16 | from aizynthfinder.context.policy.policies import ExpansionPolicy, FilterPolicy
17 | from aizynthfinder.utils.exceptions import PolicyException
18 |
--------------------------------------------------------------------------------
/aizynthfinder/context/policy/utils.py:
--------------------------------------------------------------------------------
1 | """ Module containing helper routines for policies
2 | """
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | import numpy as np
8 |
9 | if TYPE_CHECKING:
10 | from aizynthfinder.chem import TreeMolecule
11 | from aizynthfinder.chem.reaction import RetroReaction
12 | from aizynthfinder.utils.type_utils import Any, Union
13 |
14 |
15 | def _make_fingerprint(
16 | obj: Union[TreeMolecule, RetroReaction], model: Any, chiral: bool = False
17 | ) -> np.ndarray:
18 | fingerprint = obj.fingerprint(radius=2, nbits=len(model), chiral=chiral)
19 | return fingerprint.reshape([1, len(model)])
20 |
--------------------------------------------------------------------------------
/aizynthfinder/context/scoring/__init__.py:
--------------------------------------------------------------------------------
1 | """ Sub-package containing scoring routines
2 | """
3 |
4 | from aizynthfinder.context.scoring.collection import ScorerCollection
5 | from aizynthfinder.context.scoring.scorers import (
6 | AverageTemplateOccurrenceScorer,
7 | BrokenBondsScorer,
8 | CombinedScorer,
9 | DeltaSyntheticComplexityScorer,
10 | FractionInStockScorer,
11 | MaxTransformScorerer,
12 | NumberOfPrecursorsInStockScorer,
13 | NumberOfPrecursorsScorer,
14 | NumberOfReactionsScorer,
15 | PriceSumScorer,
16 | ReactionClassMembershipScorer,
17 | RouteCostScorer,
18 | RouteSimilarityScorer,
19 | Scorer,
20 | StateScorer,
21 | StockAvailabilityScorer,
22 | SUPPORT_DISTANCES,
23 | )
24 | from aizynthfinder.utils.exceptions import ScorerException
25 |
--------------------------------------------------------------------------------
/aizynthfinder/context/stock/__init__.py:
--------------------------------------------------------------------------------
1 | """ Sub-package containing stock routines
2 | """
3 | from aizynthfinder.context.stock.queries import (
4 | InMemoryInchiKeyQuery,
5 | MongoDbInchiKeyQuery,
6 | StockQueryMixin,
7 | )
8 | from aizynthfinder.context.stock.stock import Stock
9 | from aizynthfinder.utils.exceptions import StockException
10 |
--------------------------------------------------------------------------------
/aizynthfinder/data/default_training.yml:
--------------------------------------------------------------------------------
1 | library_headers: ["index", "ID", "reaction_hash", "reactants", "products", "classification", "retro_template", "template_hash", "selectivity", "outcomes", "template_code"]
2 | column_map:
3 | reaction_hash: reaction_hash
4 | reactants: reactants
5 | products: products
6 | retro_template: retro_template
7 | template_hash: template_hash
8 | metadata_headers: ["template_hash", "classification"]
9 | in_csv_headers: False
10 | csv_sep: ","
11 | reaction_smiles_column: ""
12 | output_path: "."
13 | file_prefix: ""
14 | file_postfix:
15 | raw_library: _raw_template_library.csv
16 | library: _template_library.csv
17 | false_library: _template_library_false.csv
18 | training_labels: _training_labels.npz
19 | validation_labels: _validation_labels.npz
20 | testing_labels: _testing_labels.npz
21 | training_inputs: _training_inputs.npz
22 | validation_inputs: _validation_inputs.npz
23 | testing_inputs: _testing_inputs.npz
24 | training_inputs2: _training_inputs2.npz
25 | validation_inputs2: _validation_inputs2.npz
26 | testing_inputs2: _testing_inputs2.npz
27 | training_library: _training.csv
28 | validation_library: _validation.csv
29 | testing_library: _testing.csv
30 | unique_templates: _unique_templates.hdf5
31 | split_size:
32 | training: 0.9
33 | testing: 0.05
34 | validation: 0.05
35 | batch_size: 256
36 | epochs: 100
37 | fingerprint_radius: 2
38 | fingerprint_len: 2048
39 | template_occurrence: 3
40 | remove_unsanitizable_products: False
41 | negative_data:
42 | random_trials: 1000
43 | recommender_model: ""
44 | recommender_topn: 20
45 | model:
46 | drop_out: 0.4
47 | hidden_nodes: 512
48 |
--------------------------------------------------------------------------------
/aizynthfinder/data/logging.yml:
--------------------------------------------------------------------------------
1 | version: 1
2 | formatters:
3 | file:
4 | format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
5 | console:
6 | format: '%(message)s'
7 | handlers:
8 | console:
9 | class: logging.StreamHandler
10 | level: INFO
11 | formatter: console
12 | stream: ext://sys.stdout
13 | file:
14 | class: logging.FileHandler
15 | level: DEBUG
16 | formatter: file
17 | filename: aizynthfinder.log
18 | loggers:
19 | aizynthfinder:
20 | level: DEBUG
21 | handlers:
22 | - console
23 | - file
24 | root:
25 | level: ERROR
26 | handlers: []
27 |
--------------------------------------------------------------------------------
/aizynthfinder/data/templates/reaction_tree.dot:
--------------------------------------------------------------------------------
1 | strict digraph "" {
2 | graph [layout="dot",
3 | rankdir="RL",
4 | {% if use_splines %}
5 | splines="ortho"
6 | {% endif %}
7 | ];
8 | node [label="\N"];
9 | {% for molecule, image_filepath in molecules.items() %}
10 | {{ id(molecule) }} [
11 | label="",
12 | color="white",
13 | shape="none",
14 | image="{{ image_filepath }}"
15 | ];
16 | {% endfor %}
17 | {% for reaction, reaction_shape in reactions %}
18 | {{ id(reaction) }} [
19 | label="",
20 | fillcolor="black",
21 | shape="{{ reaction_shape }}",
22 | style="filled",
23 | width="0.1",
24 | fixedsize="true"
25 | ];
26 | {% endfor %}
27 | {% for nodes in edges %}
28 | {{ id(nodes[0]) }} -> {{ id(nodes[1]) }} [arrowhead="none"];
29 | {% endfor %}
30 | }
--------------------------------------------------------------------------------
/aizynthfinder/data/templates/reaction_tree.thtml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/aizynthfinder/interfaces/__init__.py:
--------------------------------------------------------------------------------
1 | """ Module for interfaces to the AiZynthFinder application
2 | """
# NOTE(review): the aizynthapp module appears to pull in optional
# (notebook/GUI) dependencies -- confirm against aizynthapp's imports.
# If they are missing, the import raises ModuleNotFoundError and
# AiZynthApp is simply not exported from this package.
try:
    from aizynthfinder.interfaces.aizynthapp import AiZynthApp  # noqa
except ModuleNotFoundError:
    pass
7 |
--------------------------------------------------------------------------------
/aizynthfinder/interfaces/gui/__init__.py:
--------------------------------------------------------------------------------
1 | """ Package for GUI extensions
2 | """
3 |
--------------------------------------------------------------------------------
/aizynthfinder/interfaces/gui/clustering.py:
--------------------------------------------------------------------------------
1 | """ Module containing a GUI extension for clustering
2 | """
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | import numpy as np
8 | from IPython.display import display
9 | from ipywidgets import BoundedIntText, Button, HBox, Label, Output, Tab
10 |
11 | try:
12 | import matplotlib.pylab as plt
13 | from route_distances.clustering import ClusteringHelper
14 | from scipy.cluster.hierarchy import dendrogram
15 | except ImportError:
16 | raise ImportError(
17 | "Clustering is not supported by this installation."
18 | " Please install aizynthfinder with extras dependencies."
19 | )
20 |
21 |
22 | if TYPE_CHECKING:
23 | from aizynthfinder.analysis import RouteCollection
24 | from aizynthfinder.interfaces.aizynthapp import AiZynthApp
25 | from aizynthfinder.utils.type_utils import StrDict
26 |
27 |
class ClusteringGui:
    """
    GUI extension to cluster routes

    :param routes: the routes to cluster
    :param content: what to cluster on
    """

    def __init__(self, routes: RouteCollection, content: str = "both"):
        self._routes = routes
        # Force re-computation so the matrix reflects the requested content
        self._routes.distance_matrix(content=content, recreate=True)
        self._input: StrDict = dict()
        self._output: StrDict = dict()
        self._buttons: StrDict = dict()
        self._create_dendrogram()
        self._create_input()
        self._create_output()

    @classmethod
    def from_app(cls, app: AiZynthApp, content: str = "both"):
        """
        Helper function to create a GUI from a GUI app interface

        :param app: the app to extract the routes from
        :param content: what to cluster on
        :return: the GUI object
        :raises ValueError: if the app has no routes to cluster
        """
        if app.finder.routes is None:
            raise ValueError("Cannot initialize GUI from no routes")
        return ClusteringGui(app.finder.routes, content)

    def _create_dendrogram(self) -> None:
        """Display a dendrogram of the hierarchical relationship of the routes"""
        dend_out = Output(
            layout={"width": "99%", "height": "310px", "overflow_y": "auto"}
        )
        with dend_out:
            print("This is the hierarchy of the routes")
            fig = plt.Figure()
            dendrogram(
                ClusteringHelper(self._routes.distance_matrix()).linkage_matrix(),
                color_threshold=0.0,
                labels=np.arange(1, len(self._routes) + 1),
                ax=fig.gca(),
            )
            fig.gca().set_xlabel("Route")
            fig.gca().set_ylabel("Distance")
            display(fig)
        display(dend_out)

    def _create_input(self) -> None:
        """Display the widgets for selecting the number of clusters"""
        self._input["number_clusters"] = BoundedIntText(
            continuous_update=True,
            min=1,
            max=len(self._routes) - 1,
            layout={"width": "80px"},
        )
        self._buttons["cluster"] = Button(description="Cluster")
        self._buttons["cluster"].on_click(self._on_cluster_button_clicked)
        box = HBox(
            [
                Label("Number of clusters to make"),
                self._input["number_clusters"],
                self._buttons["cluster"],
            ]
        )
        display(box)
        help_out = Output()
        with help_out:
            print(
                "Optimization is carried out if the number of given clusters are less than 2"
            )
        display(help_out)

    def _create_output(self) -> None:
        """Display the (initially empty) tab widget that will hold the clusters"""
        self._output["clusters"] = Tab()
        display(self._output["clusters"])

    def _on_cluster_button_clicked(self, _) -> None:
        """Cluster the routes and fill one tab per cluster with route images"""
        # Fix: ipywidgets Button has a `disabled` trait, not `enabled`;
        # the previous `.enabled = False` set an inert attribute and the
        # button was never actually greyed out during clustering
        self._buttons["cluster"].disabled = True
        self._routes.cluster(self._input["number_clusters"].value)
        self._buttons["cluster"].disabled = False

        outputs = []
        for i, cluster in enumerate(self._routes.clusters or []):
            output = Output(
                layout={
                    "border": "1px solid silver",
                    "width": "99%",
                    "height": "500px",
                    "overflow_y": "auto",
                }
            )
            with output:
                for image in cluster.images:
                    print(f"Route {self._routes.images.index(image)+1}")
                    display(image)
            outputs.append(output)
            self._output["clusters"].set_title(i, f"Cluster {i+1}")
        self._output["clusters"].children = outputs
127 |
--------------------------------------------------------------------------------
/aizynthfinder/interfaces/gui/utils.py:
--------------------------------------------------------------------------------
1 | """Module containing utility functions for GUI.
2 | """
3 | from __future__ import annotations
4 |
5 | from collections import defaultdict
6 | from typing import TYPE_CHECKING
7 |
8 | import ipywidgets as widgets
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import pandas as pd
12 | import seaborn as sns
13 | from IPython.display import HTML, display
14 | from paretoset import paretorank
15 |
16 | if TYPE_CHECKING:
17 | from aizynthfinder.analysis.routes import RouteCollection
18 | from aizynthfinder.utils.type_utils import List, Optional
19 |
20 |
def pareto_fronts_plot(
    routes: RouteCollection,
) -> None:
    """Plot the pareto front(s).

    :param routes: the route collection to plot as Pareto fronts
    """
    objective_names = list(routes.scores[0].keys())
    score_matrix = np.array(
        [
            [score_dict[name] for name in objective_names]
            for score_dict in routes.scores
        ]
    )
    senses = np.repeat("max", len(objective_names))
    ranks = paretorank(score_matrix, sense=senses, distinct=False)

    fronts = pd.DataFrame(score_matrix, columns=objective_names)
    fronts.loc[:, "pareto_rank"] = ranks
    fronts.loc[:, "route"] = np.arange(1, score_matrix.shape[0] + 1)
    fronts_unique = fronts.drop_duplicates(objective_names)

    # Apply the default seaborn theme before drawing
    sns.set_theme()
    grid = sns.relplot(
        data=fronts_unique,
        x=objective_names[0],
        y=objective_names[1],
        hue="pareto_rank",
        kind="line",
        markers=True,
    )
    grid.set(
        xlabel=f"{objective_names[0]}",
        ylabel=f"{objective_names[1]}",
        title="Pareto Fronts",
    )

    # Group the route numbers that share the same (x, y) objective values
    objectives2solutions = defaultdict(list)
    for _, row in fronts.iterrows():
        point = (row[objective_names[0]], row[objective_names[1]])
        objectives2solutions[point].append(int(row["route"]))

    # Annotate each unique point with the route number(s) it represents
    for _, row in fronts_unique.iterrows():
        point = (row[objective_names[0]], row[objective_names[1]])
        label = f"Option {_values_to_string(objectives2solutions[point])}"
        plt.text(x=point[0], y=point[1], s=label, size=10)
    plt.show()
68 |
69 |
def route_display(
    index: Optional[int], routes: RouteCollection, output_widget: widgets.Output
) -> None:
    """
    Display a route with auxillary information in a widget

    :param index: the index of the route to display
    :param routes: the route collection
    :param output_widget: the widget to display the route on
    """
    # Silently ignore out-of-range or undefined selections
    if index is None or routes is None or index >= len(routes):
        return

    route = routes[index]
    state = route["node"].state
    status = "Solved" if state.is_solved else "Not Solved"

    output_widget.clear_output()
    with output_widget:
        display(HTML(f"<H2>{status}"))
        # One table row per score; scalar scores are formatted directly,
        # sequence scores are shown as a parenthesised tuple
        table_content = "".join(
            f"<tr><td>{name}</td><td>{score:.4f}</td></tr>"
            if isinstance(score, (float, int))
            else f"<tr><td>{name}</td><td>({', '.join(f'{val:.4f}' for val in score)})</td></tr>"
            for name, score in route["all_score"].items()
        )
        display(HTML(f"<table>{table_content}</table>"))
        display(HTML("<H2>Compounds to Procure"))
        display(state.to_image())
        display(HTML("<H2>Steps"))
        display(routes[index]["image"])
101 |
102 |
103 | def _values_to_string(vals: List[int]) -> str:
104 | """
105 | Given a list of integers, produce a nice-looking string
106 |
107 | Example:
108 | >> _values_to_string([1,2,3,10,11,19])
109 | "1-3, 10-11, 19"
110 | """
111 | groups = []
112 | start = end = vals[0]
113 | for prev_val, val in zip(vals[:-1], vals[1:]):
114 | if prev_val == val - 1:
115 | end += 1
116 | else:
117 | groups.append((start, end))
118 | start = end = val
119 | groups.append((start, end))
120 |
121 | group_strs = []
122 | for start, end in groups:
123 | if start == end:
124 | group_strs.append(str(start))
125 | else:
126 | group_strs.append(f"{start}-{end}")
127 | if len(group_strs) % 5 == 0:
128 | group_strs[-1] = "\n" + group_strs[-1]
129 | return ", ".join(group_strs)
130 |
--------------------------------------------------------------------------------
/aizynthfinder/search/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/aizynthfinder/search/__init__.py
--------------------------------------------------------------------------------
/aizynthfinder/search/breadth_first/__init__.py:
--------------------------------------------------------------------------------
1 | """ Sub-package containing breadth first routines
2 | """
3 | from aizynthfinder.search.breadth_first.search_tree import SearchTree
4 |
--------------------------------------------------------------------------------
/aizynthfinder/search/dfpn/__init__.py:
--------------------------------------------------------------------------------
1 | """ Sub-package containing DFPN routines
2 | """
3 | from aizynthfinder.search.dfpn.search_tree import SearchTree
4 |
--------------------------------------------------------------------------------
/aizynthfinder/search/dfpn/search_tree.py:
--------------------------------------------------------------------------------
1 | """ Module containing a class that holds the tree search
2 | """
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from aizynthfinder.reactiontree import ReactionTree
8 | from aizynthfinder.search.andor_trees import AndOrSearchTreeBase, SplitAndOrTree
9 | from aizynthfinder.search.dfpn.nodes import MoleculeNode, ReactionNode
10 | from aizynthfinder.utils.logging import logger
11 |
12 | if TYPE_CHECKING:
13 | from aizynthfinder.chem import RetroReaction
14 | from aizynthfinder.context.config import Configuration
15 | from aizynthfinder.search.andor_trees import TreeNodeMixin
16 | from aizynthfinder.utils.type_utils import List, Optional, Sequence, Union
17 |
18 |
class SearchTree(AndOrSearchTreeBase):
    """
    Encapsulation of the Depth-First Proof-Number (DFPN) search algorithm.

    This algorithm does not support:
    1. Filter policy
    2. Serialization and deserialization

    :ivar config: settings of the tree search algorithm
    :ivar root: the root node

    :param config: settings of the tree search algorithm
    :param root_smiles: the root will be set to a node representing this molecule, defaults to None
    """

    def __init__(
        self, config: Configuration, root_smiles: Optional[str] = None
    ) -> None:
        super().__init__(config, root_smiles)
        # Flat list of all molecule (OR) nodes created so far
        self._mol_nodes: List[MoleculeNode] = []
        self._logger = logger()
        self._root_smiles = root_smiles
        if root_smiles:
            self.root: Optional[MoleculeNode] = MoleculeNode.create_root(
                root_smiles, config, self
            )
            self._mol_nodes.append(self.root)
        else:
            self.root = None

        # Cached routes; cleared at the start of each iteration
        self._routes: List[ReactionTree] = []
        # Node currently being expanded; None signals search termination
        self._frontier: Optional[Union[MoleculeNode, ReactionNode]] = None
        # NOTE(review): _initiated is initialised here and tested in
        # one_iteration() but never set to True anywhere in this class,
        # so the guarded root check runs on every call — confirm intent
        self._initiated = False

        self.profiling = {
            "expansion_calls": 0,
            "reactants_generations": 0,
        }

    @property
    def mol_nodes(self) -> Sequence[MoleculeNode]:  # type: ignore
        """Return the molecule nodes of the tree"""
        return self._mol_nodes

    def one_iteration(self) -> bool:
        """
        Perform one iteration of expansion.

        If possible expand the frontier node twice, i.e. expanding an OR
        node and then and AND node. If frontier not expandable step up in the
        tree and find a new frontier to expand.

        If a solution is found, mask that tree for exploration and start over.

        :raises StopIteration: if the search should be pre-maturely terminated
        :return: if a solution was found
        :rtype: bool
        """
        if not self._initiated:
            if self.root is None:
                raise ValueError("Root is undefined. Cannot make an iteration")

        # Each iteration restarts the descent from the root
        self._routes = []
        self._frontier = self.root
        assert self.root is not None

        while True:
            # Expand frontier, should be OR node
            assert isinstance(self._frontier, MoleculeNode)
            expanded_or = self._search_step()
            expanded_and = False
            if self._frontier:
                # Expand frontier again, this time an AND node
                assert isinstance(self._frontier, ReactionNode)
                expanded_and = self._search_step()
            # Stop once something was expanded, or the search backtracked
            # to the root or above it
            if (
                expanded_or
                or expanded_and
                or self._frontier is None
                or self._frontier is self.root
            ):
                break

        found_solution = any(child.proven for child in self.root.children)
        if self._frontier is self.root:
            # NOTE(review): reset() presumably clears the masking so the
            # search can start over from the root — confirm in MoleculeNode
            self.root.reset()

        if self._frontier is None:
            # Backtracked above the root: nothing left to explore
            raise StopIteration()

        return found_solution

    def routes(self) -> List[ReactionTree]:
        """
        Extracts and returns routes from the AND/OR tree

        :return: the routes
        """
        if self.root is None:
            return []
        if not self._routes:
            # Lazily compute and cache the routes from the current tree
            self._routes = SplitAndOrTree(self.root, self.config.stock).routes
        return self._routes

    def _search_step(self) -> bool:
        """One expansion/backtracking step; True if the frontier was expanded."""
        assert self._frontier is not None
        expanded = False
        if self._frontier.expandable:
            self._frontier.expand()
            expanded = True
            if isinstance(self._frontier, ReactionNode):
                # New molecule nodes become visible once an AND node expands
                self._mol_nodes.extend(self._frontier.children)

        self._frontier.update()
        if not self._frontier.explorable():
            # Nothing left to explore here: backtrack to the parent
            self._frontier = self._frontier.parent
            return False

        child = self._frontier.promising_child()
        if not child:
            # No child worth exploring: backtrack to the parent
            self._frontier = self._frontier.parent
            return False

        self._frontier = child
        return expanded
144 |
--------------------------------------------------------------------------------
/aizynthfinder/search/mcts/__init__.py:
--------------------------------------------------------------------------------
1 | """ Sub-package containing MCTS routines
2 | """
3 | from aizynthfinder.search.mcts.node import MctsNode
4 | from aizynthfinder.search.mcts.search import MctsSearchTree
5 | from aizynthfinder.search.mcts.state import MctsState
6 |
--------------------------------------------------------------------------------
/aizynthfinder/search/mcts/utils.py:
--------------------------------------------------------------------------------
1 | """ Module containing utility routines for MCTS. This is not part of public interface """
2 | from __future__ import annotations
3 |
4 | from typing import TYPE_CHECKING
5 |
6 | from aizynthfinder.reactiontree import ReactionTreeLoader
7 |
8 | if TYPE_CHECKING:
9 | from aizynthfinder.chem import RetroReaction
10 | from aizynthfinder.search.mcts import MctsNode
11 | from aizynthfinder.utils.type_utils import List, Optional, Tuple
12 |
13 |
class ReactionTreeFromSuperNode(ReactionTreeLoader):
    """
    Creates a reaction tree object from MCTS-like nodes and reaction objects
    """

    def _load(self, base_node: MctsNode) -> None:  # type: ignore
        # Recover the chain of actions/nodes from the root down to `base_node`
        actions, nodes = route_to_node(base_node)
        self.tree.created_at_iteration = base_node.created_at_iteration
        # The first molecule of the root state becomes the root of the tree
        root_mol = nodes[0].state.mols[0]
        self._unique_mols[id(root_mol)] = root_mol.make_unique()
        self._add_node(
            self._unique_mols[id(root_mol)],
            in_stock=nodes[0].state.is_solved,
        )

        # Each (child, action) pair adds one reaction layer to the tree
        for child, action in zip(nodes[1:], actions):
            self._add_bipartite(child, action)

    def _add_bipartite(self, child: MctsNode, action: RetroReaction) -> None:
        # Add the reaction node and connect it to its product molecule;
        # molecules sit at even depths (2*transform), reactions at odd depths
        reaction_obj = self._unique_reaction(action)
        self._add_node(reaction_obj, depth=2 * action.mol.transform + 1)
        self.tree.graph.add_edge(self._unique_mol(action.mol), reaction_obj)
        reactant_nodes = []
        for mol in child.state.mols:
            # Only molecules created by this action belong to this reaction
            if mol.parent is action.mol:
                self._add_node(
                    self._unique_mol(mol),
                    depth=2 * mol.transform,
                    transform=mol.transform,
                    in_stock=mol in child.state.stock,
                )
                self.tree.graph.add_edge(reaction_obj, self._unique_mol(mol))
                reactant_nodes.append(self._unique_mol(mol))
        # NOTE(review): reactants is a 1-tuple wrapping the reactant tuple —
        # presumably matching RetroReaction's alternatives shape; confirm
        reaction_obj.reactants = (tuple(reactant_nodes),)
48 |
49 |
def route_to_node(
    from_node: MctsNode,
) -> Tuple[List[RetroReaction], List[MctsNode]]:
    """
    Return the route between a given node and the root.

    Both the actions taken between the nodes and the nodes themselves
    are returned, ordered from the root down to `from_node`.

    :param from_node: the end of the route
    :return: the route
    """
    actions: List[RetroReaction] = []
    nodes: List[MctsNode] = []
    node: Optional[MctsNode] = from_node

    while node is not None:
        nodes.append(node)
        parent = node.parent
        if parent is not None:
            # the edge data on the parent holds the action leading to `node`
            actions.append(parent[node]["action"])
        node = parent

    # Collected leaf-to-root; flip to root-to-leaf order before returning
    actions.reverse()
    nodes.reverse()
    return actions, nodes
74 |
--------------------------------------------------------------------------------
/aizynthfinder/search/retrostar/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/aizynthfinder/search/retrostar/__init__.py
--------------------------------------------------------------------------------
/aizynthfinder/search/retrostar/cost.py:
--------------------------------------------------------------------------------
1 | """ Module containing Retro* cost models """
2 | from __future__ import annotations
3 |
4 | import pickle
5 | from typing import TYPE_CHECKING
6 |
7 | import numpy as np
8 |
9 | from aizynthfinder.search.retrostar.cost import __name__ as retrostar_cost_module
10 | from aizynthfinder.utils.loading import load_dynamic_class
11 |
12 | if TYPE_CHECKING:
13 | from aizynthfinder.chem import Molecule
14 | from aizynthfinder.context.config import Configuration
15 | from aizynthfinder.utils.type_utils import Any, List, Tuple
16 |
17 |
class MoleculeCost:
    """
    Calculator of the cost of a molecule.

    The cost function is taken from the input config. If no `molecule_cost`
    is set, `ZeroMoleculeCost` is used as the default `cost`. The
    `molecule_cost` can be set as a dictionary in config under `search`
    in the following format:
        'algorithm': 'retrostar'
        'algorithm_config': {
            'molecule_cost': {
                'cost': name of the search cost class or custom_package.custom_model.CustomClass,
                other settings or params
            }
        }

    The cost can be computed by calling the instantiated class with a molecule.

    .. code-block::

        calculator = MyCost(config)
        cost = calculator.calculate(molecule)

    :param config: the configuration of the tree search
    """

    def __init__(self, config: Configuration) -> None:
        self._config = config
        algorithm_config = self._config.search.algorithm_config
        # Fall back to a zero-cost model when nothing is configured
        if "molecule_cost" not in algorithm_config:
            algorithm_config["molecule_cost"] = {"cost": "ZeroMoleculeCost"}

        # Work on a copy so the configuration itself is not consumed
        settings = algorithm_config["molecule_cost"].copy()
        cost_cls = load_dynamic_class(settings.pop("cost"), retrostar_cost_module)
        # Remaining settings are forwarded to the cost class constructor
        self.molecule_cost = cost_cls(**settings) if settings else cost_cls()

    def __call__(self, mol: Molecule) -> float:
        return self.molecule_cost.calculate(mol)
58 |
59 |
class RetroStarCost:
    """
    Encapsulation of the original Retro* molecular cost model

    Numpy implementation of original pytorch model

    The predictions of the score is made on a Molecule object

    .. code-block::

        mol = Molecule(smiles="CCC")
        scorer = RetroStarCost()
        score = scorer.calculate(mol)

    The model provided when creating the scorer object should be a pickled
    tuple.
    The first item of the tuple should be a list of the model weights for each layer.
    The second item of the tuple should be a list of the model biases for each layer.

    :param model_path: the filename of the model weights and biases
    :param fingerprint_length: the number of bits in the fingerprint
    :param fingerprint_radius: the radius of the fingerprint
    :param dropout_rate: the dropout_rate
    """

    _required_kwargs = ["model_path"]

    def __init__(self, **kwargs: Any) -> None:
        model_path = kwargs["model_path"]
        self.fingerprint_length: int = int(kwargs.get("fingerprint_length", 2048))
        self.fingerprint_radius: int = int(kwargs.get("fingerprint_radius", 2))
        self.dropout_rate: float = float(kwargs.get("dropout_rate", 0.1))
        # keep-probability used for the inverted-dropout rescaling
        self._dropout_prob = 1.0 - self.dropout_rate
        self._weights, self._biases = self._load_model(model_path)

    def __repr__(self) -> str:
        return "retrostar"

    def calculate(self, mol: Molecule) -> float:
        """Run the molecule fingerprint through the feed-forward network."""
        # pylint: disable=invalid-name
        mol.sanitize()
        activations = mol.fingerprint(
            radius=self.fingerprint_radius, nbits=self.fingerprint_length
        )
        last_layer = len(self._weights) - 1
        for layer, (weight, bias) in enumerate(zip(self._weights, self._biases)):
            activations = np.matmul(activations, weight) + bias
            if layer == last_layer:
                # no activation/dropout after the output layer
                break
            activations *= activations > 0  # ReLU
            # inverted dropout: zero units at random, rescale by keep-prob
            mask = np.random.binomial(1, self._dropout_prob, size=activations.shape)
            activations *= mask / self._dropout_prob
        # softplus maps the final activation onto a positive cost
        return float(np.log(1 + np.exp(activations)))

    @staticmethod
    def _load_model(model_path: str) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Read the pickled (weights, biases) tuple and convert to numpy arrays."""
        with open(model_path, "rb") as fileobj:
            raw_weights, raw_biases = pickle.load(fileobj)
        return (
            list(map(np.asarray, raw_weights)),
            list(map(np.asarray, raw_biases)),
        )
124 |
125 |
class ZeroMoleculeCost:
    """Encapsulation of a Zero cost model

    Always returns a cost of 0.0; used as the default when no
    `molecule_cost` is configured.
    """

    def __repr__(self) -> str:
        return "zero"

    # Fix: the directive previously read "# pytest: disable=unused-argument",
    # which no tool recognises; it is a pylint directive
    def calculate(self, _mol: Molecule) -> float:  # pylint: disable=unused-argument
        """Return the constant zero cost for any molecule."""
        return 0.0
134 |
--------------------------------------------------------------------------------
/aizynthfinder/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/aizynthfinder/tools/__init__.py
--------------------------------------------------------------------------------
/aizynthfinder/tools/cat_output.py:
--------------------------------------------------------------------------------
1 | """ Module containing a CLI for concatenating output files (hdf5/json.gz files)
2 | """
3 | import argparse
4 |
5 | from aizynthfinder.utils.files import cat_datafiles
6 |
7 |
def main() -> None:
    """Entry-point for the cat_aizynth_output CLI"""
    parser = argparse.ArgumentParser("cat_aizynthcli_output")
    parser.add_argument(
        "--files",
        required=True,
        nargs="+",
        help="the output filenames",
    )
    # Fix: "--output" was declared both required and with a default; a
    # required option never uses its default, so make it optional instead
    parser.add_argument(
        "--output",
        default="output.json.gz",
        help="the name of the concatenate output file ",
    )
    parser.add_argument(
        "--trees",
        help="if given, save all trees to this file",
    )
    args = parser.parse_args()

    cat_datafiles(args.files, args.output, args.trees)


if __name__ == "__main__":
    main()
34 |
--------------------------------------------------------------------------------
/aizynthfinder/tools/download_public_data.py:
--------------------------------------------------------------------------------
1 | """ Module with script to download public data
2 | """
3 | import argparse
4 | import os
5 | import sys
6 |
7 | import requests
8 | import tqdm
9 |
# Public model and data files hosted on Zenodo and Figshare, keyed by purpose;
# each entry maps to the local filename and the download URL
FILES_TO_DOWNLOAD = {
    "policy_model_onnx": {
        "filename": "uspto_model.onnx",
        "url": "https://zenodo.org/record/7797465/files/uspto_model.onnx",
    },
    "template_file": {
        "filename": "uspto_templates.csv.gz",
        "url": "https://zenodo.org/record/7341155/files/uspto_unique_templates.csv.gz",
    },
    "ringbreaker_model_onnx": {
        "filename": "uspto_ringbreaker_model.onnx",
        "url": "https://zenodo.org/record/7797465/files/uspto_ringbreaker_model.onnx",
    },
    "ringbreaker_templates": {
        "filename": "uspto_ringbreaker_templates.csv.gz",
        "url": "https://zenodo.org/record/7341155/files/uspto_ringbreaker_unique_templates.csv.gz",
    },
    "stock": {
        "filename": "zinc_stock.hdf5",
        "url": "https://ndownloader.figshare.com/files/23086469",
    },
    "filter_policy_onnx": {
        "filename": "uspto_filter_model.onnx",
        "url": "https://zenodo.org/record/7797465/files/uspto_filter_model.onnx",
    },
}

# Skeleton for the generated config.yml; the placeholders are filled with
# the absolute paths of the downloaded files
YAML_TEMPLATE = """expansion:
  uspto:
    - {}
    - {}
  ringbreaker:
    - {}
    - {}
filter:
  uspto: {}
stock:
  zinc: {}
"""
49 |
50 |
def _download_file(url: str, filename: str) -> None:
    """
    Stream the content of `url` into `filename`, showing a progress bar.

    :param url: the URL to fetch
    :param filename: the destination path on disk
    :raises requests.HTTPError: if the server responds with an error status
    """
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        total_size = int(response.headers.get("content-length", 0))
        # Fix: use context managers so the progress bar and the file are
        # closed even if the download or the write fails midway
        with tqdm.tqdm(
            total=total_size, desc=os.path.basename(filename), unit="B", unit_scale=True
        ) as pbar:
            with open(filename, "wb") as fileobj:
                for chunk in response.iter_content(chunk_size=1024):
                    fileobj.write(chunk)
                    pbar.update(len(chunk))
63 |
64 |
def main() -> None:
    """Entry-point for CLI

    Downloads all public data files and writes a config.yml pointing at them.
    """
    parser = argparse.ArgumentParser("download_public_data")
    # Fix: a positional argument ignores `default` unless nargs="?" is set;
    # previously the path was effectively mandatory despite the default
    parser.add_argument(
        "path",
        nargs="?",
        default=".",
        help="the path to download the files",
    )
    path = parser.parse_args().path

    try:
        for filespec in FILES_TO_DOWNLOAD.values():
            _download_file(filespec["url"], os.path.join(path, filespec["filename"]))
    except requests.HTTPError as err:
        print(f"Download failed with message {str(err)}")
        sys.exit(1)

    with open(os.path.join(path, "config.yml"), "w") as fileobj:
        # Use absolute paths so the config works from any working directory
        path = os.path.abspath(path)
        fileobj.write(
            YAML_TEMPLATE.format(
                os.path.join(path, FILES_TO_DOWNLOAD["policy_model_onnx"]["filename"]),
                os.path.join(path, FILES_TO_DOWNLOAD["template_file"]["filename"]),
                os.path.join(
                    path, FILES_TO_DOWNLOAD["ringbreaker_model_onnx"]["filename"]
                ),
                os.path.join(
                    path, FILES_TO_DOWNLOAD["ringbreaker_templates"]["filename"]
                ),
                os.path.join(path, FILES_TO_DOWNLOAD["filter_policy_onnx"]["filename"]),
                os.path.join(path, FILES_TO_DOWNLOAD["stock"]["filename"]),
            )
        )
    print("Configuration file written to config.yml")


if __name__ == "__main__":
    main()
103 |
--------------------------------------------------------------------------------
/aizynthfinder/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/aizynthfinder/utils/__init__.py
--------------------------------------------------------------------------------
/aizynthfinder/utils/bonds.py:
--------------------------------------------------------------------------------
1 | """Module containing a class to identify broken focussed bonds
2 | """
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | if TYPE_CHECKING:
8 | from typing import List, Sequence, Tuple
9 | from aizynthfinder.chem.mol import TreeMolecule
10 | from aizynthfinder.chem.reaction import RetroReaction
11 |
12 |
class BrokenBonds:
    """
    A class to keep track of focussed bonds breaking in a target molecule.

    :param focussed_bonds: A list of focussed bond pairs. The bond pairs are represented
        as tuples of size 2. These bond pairs exist in the target molecule's atom bonds.
    """

    def __init__(self, focussed_bonds: Sequence[Sequence[int]]) -> None:
        # store every bond with its atom indices in ascending order
        self.focussed_bonds: List[Tuple[int, int]] = [
            tuple(sorted(bond)) for bond in focussed_bonds  # type: ignore
        ]
        self.filtered_focussed_bonds: List[Tuple[int, int]] = []

    def __call__(self, reaction: RetroReaction) -> List[Tuple[int, int]]:
        """
        Provides a list of focussed bonds that break in any of the molecule's reactants.

        :param reaction: A retro reaction.
        :return: A list of all the focussed bonds that broke within the reactants
            that constitute the target molecule.
        """
        self.filtered_focussed_bonds = self._get_filtered_focussed_bonds(reaction.mol)
        if not self.filtered_focussed_bonds:
            return []

        # collect the (sorted) bonds of all reactants of the selected outcome
        reactant_bonds: List[Tuple[int, int]] = []
        for reactant in reaction.reactants[reaction.index]:
            reactant_bonds.extend(
                tuple(sorted(bond)) for bond in reactant.mapped_atom_bonds
            )
        return self._get_broken_frozen_bonds(reactant_bonds)

    def _get_broken_frozen_bonds(
        self,
        molecule_bonds: List[Tuple[int, int]],
    ) -> List[Tuple[int, int]]:
        # focussed bonds no longer present among the reactant bonds
        return list(set(self.filtered_focussed_bonds).difference(molecule_bonds))

    def _get_filtered_focussed_bonds(
        self, molecule: TreeMolecule
    ) -> List[Tuple[int, int]]:
        # keep only focussed bonds whose atoms both occur in the molecule
        atom_maps = {
            atom_map for bond in molecule.mapped_atom_bonds for atom_map in bond
        }
        return [
            (idx1, idx2)
            for idx1, idx2 in self.focussed_bonds
            if idx1 in atom_maps and idx2 in atom_maps
        ]
66 |
67 |
def sort_bonds(bonds: Sequence[Sequence[int]]) -> List[Tuple[int, int]]:
    """Return every bond as a tuple with its atom indices in ascending order."""
    normalised: List[Tuple[int, int]] = []
    for bond in bonds:
        normalised.append(tuple(sorted(bond)))  # type: ignore
    return normalised
70 |
--------------------------------------------------------------------------------
/aizynthfinder/utils/exceptions.py:
--------------------------------------------------------------------------------
1 | """ Module containing custom exception classes
2 | """
3 |
4 |
class CostException(Exception):
    """Exception raised by molecule cost classes"""


class ExternalModelAPIError(Exception):
    """Exception raised to signal a failure in an external model"""


class MoleculeException(Exception):
    """Exception raised by molecule classes"""


class NodeUnexpectedBehaviourException(Exception):
    """Exception raised if the tree search is behaving unexpectedly."""


class PolicyException(Exception):
    """Exception raised by policy classes"""


class RejectionException(Exception):
    """Exception raised if a retro action should be rejected"""


class ScorerException(Exception):
    """Exception raised by scoring classes"""


class StockException(Exception):
    """Exception raised by stock classes"""


class TreeAnalysisException(Exception):
    """Exception raised when analysing trees"""
39 |
--------------------------------------------------------------------------------
/aizynthfinder/utils/loading.py:
--------------------------------------------------------------------------------
1 | """ Module containing routine to dynamically load a class from a specification """
2 | from __future__ import annotations
3 |
4 | import importlib
5 | from typing import TYPE_CHECKING
6 |
7 | if TYPE_CHECKING:
8 | from typing import Any, Optional
9 |
10 |
def load_dynamic_class(
    name_spec: str,
    default_module: Optional[str] = None,
    exception_cls: Any = ValueError,
) -> Any:
    """
    Load an object from a dynamic specification.

    The specification can be either:
    ClassName, in-case the module name is taken from the `default_module` argument
    or
    package_name.module_name.ClassName, in-case the module is taken as `package_name.module_name`

    :param name_spec: the class specification
    :param default_module: the default module
    :param exception_cls: the exception class to raise on exception
    :return: the loaded class
    """
    if "." not in name_spec:
        name = name_spec
        if not default_module:
            raise exception_cls(
                "Must provide default_module argument if not given in name_spec"
            )
        module_name = default_module
    else:
        module_name, name = name_spec.rsplit(".", maxsplit=1)

    try:
        loaded_module = importlib.import_module(module_name)
    except ImportError as err:
        # Fix: chain the original ImportError as __cause__ so the
        # underlying import failure is visible in the traceback
        raise exception_cls(f"Unable to load module: {module_name}") from err

    if not hasattr(loaded_module, name):
        raise exception_cls(
            f"Module ({module_name}) does not have a class called {name}"
        )

    return getattr(loaded_module, name)
50 |
--------------------------------------------------------------------------------
/aizynthfinder/utils/logging.py:
--------------------------------------------------------------------------------
1 | """ Module containing routines to setup proper logging
2 | """
3 | # pylint: disable=ungrouped-imports, wrong-import-order, wrong-import-position, unused-import
4 | import logging.config
5 | import os
6 |
7 | import yaml
8 |
9 | # See Github issue 30 why sklearn is imported here
10 | try:
11 | import sklearn # noqa
12 | except ImportError:
13 | pass
14 | from rdkit import RDLogger
15 |
16 | from aizynthfinder.utils.paths import data_path
17 | from aizynthfinder.utils.type_utils import Optional
18 |
# Suppress RDKit errors due to incomplete templates (e.g. aromatic non-ring atoms).
# Done once at import time so every module importing this one is covered.
rd_logger = RDLogger.logger()
rd_logger.setLevel(RDLogger.CRITICAL)
22 |
23 |
def logger() -> logging.Logger:
    """
    Return the single named logger shared by all classes in the package.

    :return: the logger object
    """
    logger_name = "aizynthfinder"
    return logging.getLogger(logger_name)
31 |
32 |
def setup_logger(
    console_level: int, file_level: Optional[int] = None
) -> logging.Logger:
    """
    Configure and return the logger shared by all classes.

    The base configuration is read from the `logging.yml` file in the
    package data directory.

    :param console_level: the level of logging to the console
    :param file_level: the level of logging to file, if not set logging to file is disabled, default to None
    :return: the logger object
    """
    config_path = os.path.join(data_path(), "logging.yml")
    with open(config_path, "r") as fileobj:
        config = yaml.safe_load(fileobj)

    config["handlers"]["console"]["level"] = console_level
    if not file_level:
        # No file level given: remove the file handler entirely
        config["handlers"].pop("file")
        config["loggers"]["aizynthfinder"]["handlers"].remove("file")
    else:
        config["handlers"]["file"]["level"] = file_level

    logging.config.dictConfig(config)
    return logger()
58 |
--------------------------------------------------------------------------------
/aizynthfinder/utils/math.py:
--------------------------------------------------------------------------------
1 | """ Module containing diverse math functions, including neural network-related functions. """
2 |
3 | import numpy as np
4 | from aizynthfinder.utils.type_utils import Callable
5 |
6 |
# pylint: disable=invalid-name
def dense_layer_forward_pass(
    x: np.ndarray, weights: np.ndarray, bias: np.ndarray, activation: Callable
) -> np.ndarray:
    """
    Run the input through a single fully-connected layer.

    :param x: layer input
    :param weights: layer weights
    :param bias: layer bias
    :param activation: layer activation function
    :return: the layer output
    """
    pre_activation = np.matmul(x, weights) + bias
    return activation(pre_activation)
21 |
22 |
# pylint: disable=invalid-name
def rectified_linear_unit(x: np.ndarray) -> np.ndarray:
    """ReLU activation: zero out every non-positive entry of `x`."""
    positive_mask = x > 0
    return positive_mask * x
27 |
28 |
# pylint: disable=invalid-name
def sigmoid(x: np.ndarray) -> np.ndarray:
    """Logistic sigmoid activation, 1 / (1 + exp(-x))."""
    exp_neg = np.exp(-x)
    return 1 / (1 + exp_neg)
33 |
34 |
# pylint: disable=invalid-name
def softmax(x: np.ndarray) -> np.ndarray:
    """
    Compute softmax values for each sets of scores in x.

    The scores are shifted by their maximum before exponentiation.
    This leaves the result mathematically unchanged but prevents
    overflow in ``np.exp`` for large input values.

    :param x: the input scores
    :return: the normalized exponentials (summing to 1 along axis 0)
    """
    shifted_exp = np.exp(x - np.max(x, axis=0))
    return shifted_exp / np.sum(shifted_exp, axis=0)
39 |
--------------------------------------------------------------------------------
/aizynthfinder/utils/mongo.py:
--------------------------------------------------------------------------------
1 | """ Module containing routines to obtain a MongoClient instance
2 | """
from typing import Optional
from urllib.parse import quote_plus, urlencode
5 |
6 | try:
7 | from pymongo import MongoClient
8 | except ImportError:
9 | MongoClient = None
10 | HAS_PYMONGO = False
11 | else:
12 | HAS_PYMONGO = True
13 |
14 | from aizynthfinder.utils.logging import logger
15 |
# Module-level cache so that only a single MongoClient is ever created
_CLIENT = None
17 |
18 |
def get_mongo_client(
    host: str = "localhost",
    port: int = 27017,
    user: Optional[str] = None,
    password: Optional[str] = None,
    tls_certs_path: str = "",
) -> Optional[MongoClient]:
    """
    A helper function to create and reuse MongoClient

    The client is only setup once. Therefore if this function is called a second
    time with different parameters, it would still return the first client.

    :param host: the host
    :param port: the host port
    :param user: username, defaults to None
    :param password: password, defaults to None
    :param tls_certs_path: the path to TLS certificates if to be used, defaults to ""
    :return: the MongoDB client, or None if pymongo is not installed
    """
    if not HAS_PYMONGO:
        return None

    global _CLIENT
    if _CLIENT is None:
        params = {}
        if tls_certs_path:
            params.update({"ssl": "true", "ssl_ca_certs": tls_certs_path})
        if password:
            # Credentials must be percent-escaped in a MongoDB URI, otherwise
            # characters such as '@', ':' or '/' break the URI parsing
            cred_str = f"{quote_plus(user or '')}:{quote_plus(password)}@"
        else:
            cred_str = ""
        uri = f"mongodb://{cred_str}{host}:{port}/?{urlencode(params)}"
        logger().debug(f"Connecting to MongoDB on {host}:{port}")
        _CLIENT = MongoClient(uri)  # pylint: disable=C0103
    return _CLIENT
53 |
--------------------------------------------------------------------------------
/aizynthfinder/utils/paths.py:
--------------------------------------------------------------------------------
1 | """ Module containing routines for returning package paths
2 | """
3 | import os
4 |
5 |
def package_path() -> str:
    """Return the absolute path to the top-level package directory."""
    utils_dir = os.path.dirname(__file__)
    return os.path.abspath(os.path.dirname(utils_dir))
9 |
10 |
def data_path() -> str:
    """Return the path to the package's ``data`` directory."""
    package_root = package_path()
    return os.path.join(package_root, "data")
14 |
--------------------------------------------------------------------------------
/aizynthfinder/utils/sc_score.py:
--------------------------------------------------------------------------------
1 | """ Module containing the implementation of the SC-score model for synthetic complexity scoring. """
2 |
3 | from __future__ import annotations
4 |
5 | import pickle
6 | from typing import TYPE_CHECKING
7 |
8 | import numpy as np
9 | from rdkit.Chem import AllChem
10 |
11 | from aizynthfinder.utils.math import (
12 | dense_layer_forward_pass,
13 | rectified_linear_unit,
14 | sigmoid,
15 | )
16 |
17 | if TYPE_CHECKING:
18 | from aizynthfinder.utils.type_utils import RdMol, Sequence, Tuple
19 |
20 |
class SCScore:
    """
    Wrapper around the SCScore synthetic-complexity model.

    This is a re-implementation of the SCScorer from the scscorer package.
    Predictions are made with a sanitized instance of an RDKit molecule:

    .. code-block::

        mol = Molecule(smiles="CCC", sanitize=True)
        scscorer = SCScorer("path_to_model")
        score = scscorer(mol.rd_mol)

    The model file given to the constructor should be a pickled tuple:
    the first item is a list of layer weights, the second item is a
    list of layer biases.

    :param model_path: the filename of the model weights and biases
    :param fingerprint_length: the number of bits in the fingerprint
    :param fingerprint_radius: the radius of the fingerprint
    """

    def __init__(
        self,
        model_path: str,
        fingerprint_length: int = 1024,
        fingerprint_radius: int = 2,
    ) -> None:
        self._fingerprint_length = fingerprint_length
        self._fingerprint_radius = fingerprint_radius
        self._weights, self._biases = self._load_model(model_path)
        self.score_scale = 5.0

    def __call__(self, rd_mol: RdMol) -> float:
        fingerprint = self._make_fingerprint(rd_mol)
        normalized = self.forward(fingerprint)
        # Map the network output from [0, 1] onto the [1, score_scale] range
        return (1 + (self.score_scale - 1) * normalized)[0]

    # pylint: disable=invalid-name
    def forward(self, x: np.ndarray) -> np.ndarray:
        """Forward pass with dense neural network"""
        # All hidden layers use ReLU, the output layer uses a sigmoid
        for layer_weights, layer_bias in zip(self._weights[:-1], self._biases[:-1]):
            x = dense_layer_forward_pass(
                x, layer_weights, layer_bias, rectified_linear_unit
            )
        return dense_layer_forward_pass(
            x, self._weights[-1], self._biases[-1], sigmoid
        )

    def _load_model(
        self, model_path: str
    ) -> Tuple[Sequence[np.ndarray], Sequence[np.ndarray]]:
        """Read the pickled layer weights and biases from disk."""
        with open(model_path, "rb") as fileobj:
            raw_weights, raw_biases = pickle.load(fileobj)
        return (
            [np.asarray(item) for item in raw_weights],
            [np.asarray(item) for item in raw_biases],
        )

    def _make_fingerprint(self, rd_mol: RdMol) -> np.ndarray:
        """Compute the molecule's Morgan fingerprint as a boolean vector."""
        bit_vector = AllChem.GetMorganFingerprintAsBitVect(
            rd_mol,
            self._fingerprint_radius,
            nBits=self._fingerprint_length,
            useChirality=True,
        )
        return np.array(bit_vector, dtype=bool)
91 |
--------------------------------------------------------------------------------
/aizynthfinder/utils/type_utils.py:
--------------------------------------------------------------------------------
1 | """ Module containing all types and type imports
2 | """
3 | # pylint: disable=unused-import
4 | from typing import Callable # noqa
5 | from typing import Iterable # noqa
6 | from typing import List # noqa
7 | from typing import Sequence # noqa
8 | from typing import Set # noqa
9 | from typing import TypeVar # noqa
10 | from typing import Any, Dict, Optional, Tuple, Union
11 |
12 | from PIL.Image import Image
13 | from rdkit import Chem
14 | from rdkit.Chem import rdChemReactions
15 | from rdkit.DataStructs.cDataStructs import ExplicitBitVect
16 |
# Short-hands for commonly used composite and third-party types
StrDict = Dict[str, Any]  # a dictionary keyed by strings
RdMol = Chem.rdchem.Mol  # an RDKit molecule
RdReaction = Chem.rdChemReactions.ChemicalReaction  # an RDKit chemical reaction
BitVector = ExplicitBitVect  # an RDKit explicit bit vector (fingerprint)
PilImage = Image  # a PIL image
PilColor = Union[str, Tuple[int, int, int]]  # a color name or an RGB tuple
FrameColors = Optional[Dict[bool, PilColor]]  # mapping from a boolean flag to a color
24 |
--------------------------------------------------------------------------------
/contrib/notebook.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "ObD2YL7nM2_X"
7 | },
8 | "source": [
9 | "# AiZynthFinder\n",
10 | "\n",
11 | "Click the ▶ play button at the left of the **Installation** text below to install the application. The initial installation process may take a few minutes. Then run **Start application** cell. \n",
12 | "\n",
13 | "1. Enter the target compound [SMILES][1] code.\n",
"2. Click the **Run Search** button to start the algorithm.\n",
"3. Once it stops searching, click the **Show Reactions** button.\n",
16 | "\n",
17 | "[1]: https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": null,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "#@title Installation -- Run this cell to install AiZynthFinder\n",
27 | "\n",
28 | "# This might not install the dependencies that you need\n",
29 | "#!(AIZ_LATEST_TAG=$(git -c 'versionsort.suffix=-' \\\n",
30 | "# ls-remote --exit-code --refs --sort='version:refname' --tags https://github.com/MolecularAI/aizynthfinder '*.*.*' \\\n",
31 | "# | tail --lines=1 \\\n",
32 | "# | cut -f 3 -d /) \\\n",
33 | "# && pip install --quiet https://github.com/MolecularAI/aizynthfinder/archive/${AIZ_LATEST_TAG}.tar.gz)\n",
34 | "!pip install --quiet aizynthfinder[all]\n",
35 | "!pip install --ignore-installed Pillow==9.0.0\n",
36 | "!mkdir --parents data && download_public_data data\n"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {
43 | "id": "beDP-CSNM429"
44 | },
45 | "outputs": [],
46 | "source": [
47 | "#@title Start application. {display-mode: \"form\"}\n",
48 | "\n",
49 | "from rdkit.Chem.Draw import IPythonConsole\n",
50 | "from aizynthfinder.interfaces import AiZynthApp\n",
51 | "application = AiZynthApp(\"./data/config.yml\")"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "metadata": {
57 | "id": "bwxusoogwlI9"
58 | },
59 | "source": [
60 | "# Bibliography\n",
61 | "\n",
62 | "_Genheden S, Thakkar A, Chadimova V, et al (2020) AiZynthFinder: a fast, robust and flexible open-source software for retrosynthetic planning. J. Cheminf. https://doi.org/10.1186/s13321-020-00472-1 ([GitHub](https://github.com/MolecularAI/aizynthfinder) & [Documentation](https://molecularai.github.io/aizynthfinder/html/index.html))_"
63 | ]
64 | }
65 | ],
66 | "metadata": {
67 | "accelerator": "GPU",
68 | "colab": {
69 | "collapsed_sections": [],
70 | "name": "AiZynthFinder.ipynb",
71 | "private_outputs": true,
72 | "provenance": [],
73 | "toc_visible": true
74 | },
75 | "kernelspec": {
76 | "display_name": "Python 3 (ipykernel)",
77 | "language": "python",
78 | "name": "python3"
79 | },
80 | "language_info": {
81 | "codemirror_mode": {
82 | "name": "ipython",
83 | "version": 3
84 | },
85 | "file_extension": ".py",
86 | "mimetype": "text/x-python",
87 | "name": "python",
88 | "nbconvert_exporter": "python",
89 | "pygments_lexer": "ipython3",
90 | "version": "3.10.9"
91 | }
92 | },
93 | "nbformat": 4,
94 | "nbformat_minor": 1
95 | }
96 |
--------------------------------------------------------------------------------
/docs/analysis-rel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/analysis-rel.png
--------------------------------------------------------------------------------
/docs/analysis-seq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/analysis-seq.png
--------------------------------------------------------------------------------
/docs/cli.rst:
--------------------------------------------------------------------------------
1 | Command-line interface
2 | ======================
3 |
This tool provides the possibility to perform tree search on a batch of molecules.
5 |
6 | In its simplest form, you type
7 |
8 | .. code-block:: bash
9 |
10 | aizynthcli --config config_local.yml --smiles smiles.txt
11 |
where `config_local.yml` contains configurations such as paths to policy models and stocks (see :doc:`here <configuration>`)
13 | and `smiles.txt` is a simple text file with SMILES (one on each row).
14 |
15 |
16 | To find out what other arguments are available use the ``-h`` flag.
17 |
18 | .. code-block:: bash
19 |
20 | aizynthcli -h
21 |
22 | That gives something like this:
23 |
24 | .. include:: cli_help.txt
25 |
26 |
27 | By default:
28 |
29 | * `All` stocks are selected if no stock is specified
* `First` expansion policy is selected if no expansion policy is specified
31 | * `All` filter policies are selected if it is not specified on the command-line
32 |
33 | Analysing output
34 | ----------------
35 |
36 |
The results from the ``aizynthcli`` tool when supplying multiple SMILES is a JSON or HDF5 file that can be read as a pandas dataframe.
38 | It will be called `output.json.gz` by default.
39 |
40 | A `checkpoint.json.gz` will also be generated if a checkpoint file path is provided as input when calling the ``aizynthcli`` tool. The
41 | checkpoint data will contain the processed smiles with their corresponding results in each line of the file.
42 |
43 | .. code-block::
44 |
45 | import pandas as pd
46 | data = pd.read_json("output.json.gz", orient="table")
47 |
48 | it will contain statistics about the tree search and the top-ranked routes (as JSONs) for each target compound, see below.
49 |
50 | When a single SMILES is provided to the tool, the statistics will be written to the terminal, and the top-ranked routes to
51 | a JSON file (`trees.json` by default).
52 |
53 |
54 | This is an example of how to create images of the top-ranked routes for the first target compound
55 |
56 |
57 | .. code-block::
58 |
59 | import pandas as pd
60 | from aizynthfinder.reactiontree import ReactionTree
61 |
62 | data = pd.read_json("output.json.gz", orient="table")
63 | all_trees = data.trees.values # This contains a list of all the trees for all the compounds
64 | trees_for_first_target = all_trees[0]
65 |
66 | for itree, tree in enumerate(trees_for_first_target):
67 | imagefile = f"route{itree:03d}.png"
68 | ReactionTree.from_dict(tree).to_image().save(imagefile)
69 |
70 | The images will be called `route000.png`, `route001.png` etc.
71 |
72 |
73 | Specification of output
74 | -----------------------
75 |
76 | The JSON or HDF5 file created when running the tool with a list of SMILES will have the following columns
77 |
78 | ============================= ===========
79 | Column Description
80 | ============================= ===========
81 | target The target SMILES
82 | search_time The total search time in seconds
83 | first_solution_time The time elapsed until the first solution was found
84 | first_solution_iteration The number of iterations completed until the first solution was found
85 | number_of_nodes The number of nodes in the search tree
86 | max_transforms The maximum number of transformations for all routes in the search tree
87 | max_children The maximum number of children for a search node
88 | number_of_routes The number of routes in the search tree
89 | number_of_solved_routes The number of solved routes in search tree
90 | top_score The score of the top-scored route (default to MCTS reward)
91 | is_solved If the top-scored route is solved
92 | number_of_steps The number of reactions in the top-scored route
93 | number_of_precursors The number of starting materials
94 | number_of_precursors_in_stock The number of starting materials in stock
95 | precursors_in_stock Comma-separated list of SMILES of starting material in stock
96 | precursors_not_in_stock Comma-separated list of SMILES of starting material not in stock
precursors_availability Semi-colon separated list of stock availability of the starting material
98 | policy_used_counts Dictionary of the total number of times an expansion policy have been used
99 | profiling Profiling information from the search tree, including expansion models call and reactant generation
100 | stock_info Dictionary of the stock availability for each of the starting material in all extracted routes
101 | top_scores Comma-separated list of the score of the extracted routes (default to MCTS reward)
102 | trees A list of the extracted routes as dictionaries
103 | ============================= ===========
104 |
If you run the tool with a single SMILES, all of this data will be printed to the screen, except
106 | the ``stock_info`` and ``trees``.
107 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
"""Sphinx configuration for the aizynthfinder documentation."""
import os
import sys

sys.path.insert(0, os.path.abspath("."))

project = "aizynthfinder"
copyright = "2020-2025, Molecular AI group"
author = "Molecular AI group"
release = "4.3.2"

# Make sure that the cli_help.txt file is properly formatted as a literal
# code block. The rewrite is skipped if the file has already been formatted,
# so repeated builds do not nest the directive and indentation again.
with open("cli_help.txt", "r") as fileobj:
    lines = fileobj.read().splitlines()
if not lines or lines[0] != ".. code-block::":
    with open("cli_help.txt", "w") as fileobj:
        fileobj.write(".. code-block::\n\n")
        fileobj.write("    " + "\n    ".join(lines))

extensions = [
    "sphinx.ext.autodoc",
]
autodoc_member_order = "bysource"
autodoc_typehints = "description"

html_theme = "alabaster"
html_theme_options = {
    "description": "A fast, robust and flexible software for retrosynthetic planning",
    "fixed_sidebar": True,
}
29 |
--------------------------------------------------------------------------------
/docs/gui.rst:
--------------------------------------------------------------------------------
1 | Graphical user interface
2 | ========================
3 |
This tool provides the possibility to perform the tree search on a single compound using a GUI
5 | through a Jupyter notebook. If you are unfamiliar with notebooks, you find some introduction `here `_.
6 |
7 | To bring up the notebook, use
8 |
9 | .. code-block:: bash
10 |
11 | jupyter notebook
12 |
13 | and browse to an existing notebook or create a new one.
14 |
15 | Add these lines to the first cell in the notebook.
16 |
17 | .. code-block::
18 |
19 | from aizynthfinder.interfaces import AiZynthApp
20 | app = AiZynthApp("/path/to/configfile.yaml")
21 |
where the ``AiZynthApp`` class needs to be instantiated with the path to a configuration file (see :doc:`here <configuration>`).
23 |
24 | To use the interface, follow these steps:
25 |
1. Execute the code in the cell (press ``Ctrl+Enter``) and a simple GUI will appear
27 | 2. Enter the target SMILES and select stocks and policy model.
28 | 3. Press the ``Run Search`` button to perform the tree search.
29 |
30 | .. image:: gui_input.png
31 |
32 |
33 | 4. Press the ``Show Reactions`` to see the top-ranked routes
34 |
35 |
36 | .. image:: gui_results.png
37 |
38 | You can also choose to select and sort the top-ranked routes based on another scoring function.
39 |
40 |
41 | Creating the notebook
42 | ---------------------
43 |
44 | It is possible to create a notebook automatically with the ``aizynthapp`` tool
45 |
46 | .. code-block:: bash
47 |
48 | aizynthapp --config config_local.yml
49 |
which will also automatically open up the created notebook.
51 |
52 | Analysing the results
53 | ---------------------
54 |
55 | When the tree search has been finished. One can continue exploring the tree and extract output.
56 | This is done by using the ``finder`` property of the app object. The finder holds a reference to an ``AiZynthFinder`` object.
57 |
58 | .. code-block::
59 |
60 | finder = app.finder
61 | stats = finder.extract_statistics()
62 |
63 |
64 | Clustering
65 | -----------
66 |
67 | There is a GUI extension to perform clustering of the routes. Enter the following a new cell
68 |
69 | .. code-block::
70 |
71 | %matplotlib inline
72 | from aizynthfinder.interfaces.gui.clustering import ClusteringGui
73 | ClusteringGui.from_app(app)
74 |
75 |
76 | A GUI like this will be shown, where you see the hierarchy of the routes and then can select how many
77 | clusters you want to create.
78 |
79 | .. image:: gui_clustering.png
--------------------------------------------------------------------------------
/docs/gui_clustering.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/gui_clustering.png
--------------------------------------------------------------------------------
/docs/gui_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/gui_input.png
--------------------------------------------------------------------------------
/docs/gui_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/gui_results.png
--------------------------------------------------------------------------------
/docs/howto.rst:
--------------------------------------------------------------------------------
1 | How-to
2 | =======
3 |
4 | This page outlines a few guidelines on some more advanced use-cases of AiZynthFinder or
5 | frequently raised issues.
6 |
7 |
8 | Using Retro*
9 | ------------
10 |
11 | AiZynthFinder implements other search algorithms than MCTS. This is an example of how Retro* can be used.
12 |
13 | The search algorithm is specified in the configuration file.
14 |
15 | .. code-block:: yaml
16 |
17 | search:
18 | algorithm: aizynthfinder.search.retrostar.search_tree.SearchTree
19 |
20 |
21 | This will use Retro* without a constant-valued oracle function. To specify the oracle function, you can
22 | do
23 |
24 | .. code-block:: yaml
25 |
26 | search:
27 | algorithm: aizynthfinder.search.retrostar.search_tree.SearchTree
28 | algorithm_config:
29 | molecule_cost:
30 | cost: aizynthfinder.search.retrostar.cost.RetroStarCost
31 | model_path: retrostar_value_model.pickle
32 | fingerprint_length: 2048
33 | fingerprint_radius: 2
34 | dropout_rate: 0.1
35 |
36 | The pickle file can be downloaded from `here `_
37 |
38 |
39 | Using multiple expansion policies
40 | ---------------------------------
41 |
42 | AiZynthFinder can use multiple expansion policies. This gives an example how a general USPTO and a RingBreaker model
43 | can be used together
44 |
45 | .. code-block:: yaml
46 |
47 | expansion:
48 | uspto:
49 | - uspto_keras_model.hdf5
50 | - uspto_unique_templates.csv.gz
51 | ringbreaker:
52 | - uspto_ringbreaker_keras_model.hdf5
53 | - uspto_ringbreaker_unique_templates.csv.gz
54 | multi_expansion_strategy:
55 | type: aizynthfinder.context.policy.MultiExpansionStrategy
56 | expansion_strategies: [uspto, ringbreaker]
57 | additive_expansion: True
58 |
59 | and then to use this with ``aizynthcli`` do something like this
60 |
61 | .. code-block::
62 |
63 | aizynthcli --smiles smiles.txt --config config.yml --policy multi_expansion_strategy
64 |
65 |
66 | Output more routes
67 | ------------------
68 |
69 | The number of routes in the output of ``aizynthcli`` can be controlled from the configuration file.
70 |
71 | This is how you can extract at least 25 routes but not more than 50 per target
72 |
73 | .. code-block:: yaml
74 |
75 | post_processing:
76 | min_routes: 25
77 | max_routes: 50
78 |
79 | Alternatively, you can extract all solved routes. If a target is unsolved, it will return the number
80 | of routes specified by ``min_routes`` and ``max_routes``.
81 |
82 | .. code-block:: yaml
83 |
84 | post_processing:
85 | min_routes: 5
86 | max_routes: 10
87 | all_routes: True
88 |
89 |
90 | Running multi-objective (MO) MCTS with disconnection-aware Chemformer
----------------------------------------------------------------------
92 | Disconnection-aware retrosynthesis can be done with 1) MO-MCTS (state score + broken bonds score), 2) Chemformer or 3) both.
93 |
94 | First, you need to specify the bond constraints under search, see below.
95 | To run the MO-MCTS with the "broken bonds" score, add the "broken bonds" score to the list of search_rewards:
96 |
97 | .. code-block:: yaml
98 |
99 | search:
100 | break_bonds: [[1, 2], [3, 4]]
101 | freeze_bonds: []
102 | algorithm_config:
103 | search_rewards: ["state score", "broken bonds"]
104 |
105 | To use the disconnection-aware Chemformer, you first need to add the `plugins` folder to the `PYTHONPATH`, e.g.
106 |
107 | export PYTHONPATH=~/aizynthfinder/plugins/
108 |
109 | The script for starting a disconnection-aware Chemformer service is available at https://github.com/MolecularAI/Chemformer.
110 | The multi-expansion policy with template-based model and Chemformer is specified with:
111 |
.. code-block:: yaml

    expansion:
      standard:
        type: template-based
        model: path/to/model
        template: path/to/templates
      chemformer_disconnect:
        type: expansion_strategies.DisconnectionAwareExpansionStrategy
        url: "http://localhost:8023/chemformer-disconnect-api/predict-disconnection"
        n_beams: 5
      multi_expansion:
        type: aizynthfinder.context.policy.MultiExpansionStrategy
        expansion_strategies: [chemformer_disconnect, standard]
        additive_expansion: True
        cutoff_number: 50
127 |
128 |
129 | To use MO-tree ranking and building, set:
130 |
.. code-block:: yaml

    post_processing:
      route_scorers: ["state score", "broken bonds"]
134 |
135 | Note: If post_processing.route_scorers is not specified, it will default to search.algorithm_config.search_rewards.
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | aizynthfinder documentation
2 | ===========================
3 |
4 | aizynthfinder is a tool for retrosynthetic planning. The default algorithm is based on a Monte Carlo tree search that recursively breaks down a molecule to purchasable precursors. The tree search is guided by a policy that suggests possible precursors by utilizing a neural network trained on a library of known reaction templates.
5 |
6 | Introduction
7 | ------------
8 |
To run retrosynthesis experiments you need a trained model and a stock collection. You can download a publicly available model based on USPTO and a stock collection from the ZINC database.
10 |
11 | .. code-block::
12 |
13 | download_public_data .
14 |
15 | This will download the data to your current directory. The ``config.yml`` file can be used directly with the interfaces.
16 |
17 |
18 | There are two main interfaces provided by the package:
19 |
20 | * a script that performs tree search in batch mode and
21 | * an interface that is providing a GUI within a Jupyter notebook.
22 |
23 |
24 | The GUI interface should be run in a Jupyter notebook. This is a simple example of the code in a Jupyter notebook cell.
25 |
26 | .. code-block::
27 |
28 | from aizynthfinder.interfaces import AiZynthApp
29 | app = AiZynthApp("/path/to/configfile.yaml")
30 |
where the ``AiZynthApp`` class needs to be instantiated with the path to a configuration file (see :doc:`here <configuration>`).
32 |
33 | To use the interface, follow these steps:
34 |
1. Execute the code in the cell (press ``Ctrl+Enter``) and a simple GUI will appear
36 | 2. Enter the target SMILES and select stocks and policy model.
37 | 3. Press the ``Run Search`` button to perform the tree search.
38 | 4. Press the ``Show Reactions`` to see the top-ranked routes
39 |
40 |
41 |
42 | The batch-mode script is called ``aizynthcli`` and can be executed like:
43 |
44 | .. code-block:: bash
45 |
46 | aizynthcli --config config.yml --smiles smiles.txt
47 |
48 |
where `config.yml` contains configurations such as paths to policy models and stocks (see :doc:`here <configuration>`), and `smiles.txt` is a simple text
50 | file with SMILES (one on each row).
51 |
52 | If you just want to perform the tree search on a single molecule. You can directly specify it on the command-line
53 | within quotes:
54 |
55 | .. code-block:: bash
56 |
57 | aizynthcli --config config.yml --smiles "COc1cccc(OC(=O)/C=C/c2cc(OC)c(OC)c(OC)c2)c1"
58 |
59 |
60 | The output is some statistics about the tree search, the scores of the top-ranked routes, and the reaction tree
61 | of the top-ranked routes. When smiles are provided in a text file the results are stored in a JSON file,
62 | whereas if the SMILES is provided on the command-line it is printed directly to the prompt
63 | (except the reaction trees, which are written to a JSON file).
64 |
65 |
66 | .. toctree::
67 | :hidden:
68 |
69 | gui
70 | cli
71 | python_interface
72 | configuration
73 | stocks
74 | scoring
75 | howto
76 | aizynthfinder
77 | sequences
78 | relationships
--------------------------------------------------------------------------------
/docs/line-desc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/line-desc.png
--------------------------------------------------------------------------------
/docs/python_interface.rst:
--------------------------------------------------------------------------------
1 | Python interface
2 | ================
3 |
4 | This page gives a quick example of how the tree search can be completed
5 | by writing your own python interface. This is not recommended for most users.
6 |
7 |
8 | 1. Import the necessary class
9 |
10 | .. code-block:: python
11 |
12 | from aizynthfinder.aizynthfinder import AiZynthFinder
13 |
14 |
15 | 2. Instantiate that class by providing a configuration file.
16 |
17 |
18 | .. code-block:: python
19 |
20 | filename = "config.yml"
21 | finder = AiZynthFinder(configfile=filename)
22 |
23 |
24 | 3. Select stock and policy
25 |
26 |
27 | .. code-block:: python
28 |
29 | finder.stock.select("zinc")
30 | finder.expansion_policy.select("uspto")
31 | finder.filter_policy.select("uspto")
32 |
33 | `zinc` and `uspto` were the keys given to the stock and the policy in the configuration file.
34 | The first policy set is the expansion policy and the second is the filter policy. The filter policy is optional.
35 |
36 | 4. Set the target SMILES and perform the tree search
37 |
38 |
39 | .. code-block:: python
40 |
41 | finder.target_smiles = "Cc1cccc(c1N(CC(=O)Nc2ccc(cc2)c3ncon3)C(=O)C4CCS(=O)(=O)CC4)C"
42 | finder.tree_search()
43 |
44 |
45 | 5. Analyse the search tree and build routes
46 |
47 |
48 | .. code-block:: python
49 |
50 | finder.build_routes()
51 | stats = finder.extract_statistics()
52 |
53 |
54 | The ``build_routes`` method needs to be called before any analysis can be done.
55 |
56 | Expansion interface
57 | -------------------
58 |
59 | There is an interface for the expansion policy as well. It can be used to break down a molecule into reactants.
60 |
61 | .. code-block:: python
62 |
63 | filename = "config.yml"
64 | expander = AiZynthExpander(configfile=filename)
65 | expander.expansion_policy.select("uspto")
66 | expander.filter_policy.select("uspto")
67 | reactions = expander.do_expansion("Cc1cccc(c1N(CC(=O)Nc2ccc(cc2)c3ncon3)C(=O)C4CCS(=O)(=O)CC4)C")
68 |
69 | for this, you only need to select the policies. The filter policy is optional and using it will only add the
70 | feasibility of the reactions, not filter them out.
71 |
72 | The result is a nested list of `FixedRetroReaction` objects. This you can manipulate to for instance get
73 | out all the reactants SMILES strings
74 |
75 | .. code-block:: python
76 |
77 | reactants_smiles = []
78 | for reaction_tuple in reactions:
79 |         reactants_smiles.append([mol.smiles for mol in reaction_tuple[0].reactants[0]])
80 |
81 | or you can put all the metadata of all the reactions in a pandas dataframe
82 |
83 | .. code-block:: python
84 |
85 | import pandas as pd
86 | metadata = []
87 | for reaction_tuple in reactions:
88 | for reaction in reaction_tuple:
89 | metadata.append(reaction.metadata)
90 | df = pd.DataFrame(metadata)
91 |
92 |
93 | Further reading
94 | ---------------
95 |
96 | The docstrings of all modules, classes and methods can be consulted :doc:`here `
97 |
98 |
99 | and you can always find them in an interactive Python shell using for instance:
100 |
101 | .. code-block:: python
102 |
103 | from aizynthfinder.chem import Molecule
104 | help(Molecule)
105 | help(Molecule.fingerprint)
106 |
107 |
108 | If you are interested in the relationships between the classes have a look :doc:`here `
109 | and if you want to dig deeper in to the main algorithmic sequences have a look :doc:`here `
--------------------------------------------------------------------------------
/docs/relationships.rst:
--------------------------------------------------------------------------------
1 | Relationships
2 | =============
3 |
4 | This page shows some relationship diagrams, i.e. how the different objects are connected in a typical retrosynthesis
5 | analysis using Monte Carlo tree search.
6 |
7 | These are the three different types of relationships used:
8 |
9 | .. image:: line-desc.png
10 |
11 | Tree search
12 | -----------
13 |
14 | This diagram explains how the different objects that are responsible for the Monte-Carlo tree search are connected.
15 |
16 | .. image:: treesearch-rel.png
17 |
18 |
19 | Analysis / post-processing
20 | --------------------------
21 |
22 | This diagram explains how the different objects involved in the analysis of the search are connected.
23 |
24 | .. image:: analysis-rel.png
25 |
--------------------------------------------------------------------------------
/docs/scoring.rst:
--------------------------------------------------------------------------------
1 | Scoring
2 | =======
3 |
4 | aizynthfinder is capable of scoring reaction routes, both in the form of ``MctsNode`` objects when a search tree is available,
5 | and in the form of ``ReactionTree`` objects if post-processing is required.
6 |
7 | Currently, there are a few scoring functions available
8 |
9 | * State score - a function of the number of precursors in stock and the length of the route
10 | * Number of reactions - the number of steps in the route
11 | * Number of pre-cursors - the number of pre-cursors in the route
12 | * Number of pre-cursors in stock - the number of the pre-cursors that are purchaseable
13 | * Average template occurrence - the average occurrence of the templates used in the route
14 | * Sum of prices - the plain sum of the price of all pre-cursors
15 | * Route cost score - the cost of synthesizing the route (Badowski et al. Chem Sci. 2019, 10, 4640)
16 |
17 |
18 | The *State score* is the score that is guiding the tree search in the :doc:`update phase `, and
19 | this is not configurable.
20 |
21 | In the Jupyter notebook :doc:`GUI ` one can choose to score the routes with any of the loaded scorers.
22 |
23 | The first four scoring functions are loaded automatically when an ``aizynthfinder`` object is created.
24 |
25 |
26 | Add new scoring functions
27 | -------------------------
28 |
29 |
30 | Additional scoring functions can be implemented by inheriting from the class ``Scorer`` in the ``aizynthfinder.context.scoring.scorers`` module.
31 | The scoring class needs to implement the ``_score_node``, ``_score_reaction_tree`` and the ``__repr__`` methods.
32 |
33 | This is an example of that.
34 |
35 | .. code-block:: python
36 |
37 | from aizynthfinder.context.scoring.scorers import Scorer
38 |
39 | class DeltaNumberOfTransformsScorer(Scorer):
40 |
41 | def __repr__(self):
42 | return "delta number of transforms"
43 |
44 | def _score_node(self, node):
45 | return self._config.max_transforms - node.state.max_transforms
46 |
47 | def _score_reaction_tree(self, tree):
48 | return self._config.max_transforms - len(list(tree.reactions()))
49 |
50 |
51 | This can then be added to the ``scorers`` attribute of an ``aizynthfinder`` object. The ``scorers`` attribute is a collection
52 | of ``Scorer`` objects.
53 |
54 | For instance to use this in the Jupyter notebook GUI, one can do
55 |
56 | .. code-block:: python
57 |
58 | from aizynthfinder.interfaces import AiZynthApp
59 | app = AiZynthApp("config_local.yml", setup=False)
60 | scorer = DeltaNumberOfTransformsScorer(app.finder.config)
61 | app.finder.scorers.load(scorer)
62 | app.setup()
63 |
64 |
--------------------------------------------------------------------------------
/docs/treesearch-rel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/treesearch-rel.png
--------------------------------------------------------------------------------
/docs/treesearch-seq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/docs/treesearch-seq.png
--------------------------------------------------------------------------------
/env-dev.yml:
--------------------------------------------------------------------------------
1 | name: aizynth-dev
2 | channels:
3 | - https://conda.anaconda.org/conda-forge
4 | - defaults
5 | dependencies:
6 | - python>=3.9,<3.11
7 | - poetry>=1.1.4,<1.4.1
--------------------------------------------------------------------------------
/plugins/README.md:
--------------------------------------------------------------------------------
1 | # Plugins
2 |
3 | This folder contains some features for `aizynthfinder` that
4 | do not yet fit into the main codebase. It could be experimental
5 | features, or features that require the user to install some
6 | additional third-party dependencies.
7 |
8 | For the expansion models, you generally need to add the `plugins` folder to the `PYTHONPATH`, e.g.
9 |
10 | export PYTHONPATH=~/aizynthfinder/plugins/
11 |
12 | where the `aizynthfinder` repository is in the home folder
13 |
14 | ## Chemformer expansion model
15 |
16 | An expansion model using a REST API for the Chemformer model
17 | is supplied in the `expansion_strategies` module.
18 |
19 | To use it, you first need to install the `chemformer` package
20 | and launch the REST API service that comes with it.
21 |
22 | To use the expansion model in `aizynthfinder` you can use a config-file
23 | containing these lines
24 |
25 | expansion:
26 | chemformer:
27 | type: expansion_strategies.ChemformerBasedExpansionStrategy
28 | url: http://localhost:8000/chemformer-api/predict
29 | search:
30 | algorithm_config:
31 | immediate_instantiation: [chemformer]
32 | time_limit: 300
33 |
34 | The `time_limit` is a recommendation for allowing the more expensive expansion model
35 | to finish a sufficient number of retrosynthesis iterations.
36 |
37 | You would have to change `localhost:8000` to the name and port of the machine hosting the REST service.
38 |
39 | You can then use the config-file with either `aizynthcli` or the Jupyter notebook interface.
40 |
41 | ## ModelZoo expansion model
42 |
43 | An expansion model using the ModelZoo feature is supplied in the `expansion_strategies`
44 | module. This is an adoption of the code from this repo: `https://github.com/AlanHassen/modelsmatter` that were used in the publications [Models Matter: The Impact of Single-Step Models on Synthesis Prediction](https://arxiv.org/abs/2308.05522) and [Mind the Retrosynthesis Gap: Bridging the divide between Single-step and Multi-step Retrosynthesis Prediction](https://openreview.net/forum?id=LjdtY0hM7tf).
45 |
46 | To use it, you first need to install the `modelsmatter_modelzoo` package from
47 | https://github.com/PTorrenPeraire/modelsmatter_modelzoo and set up the `ssbenchmark`
48 | environment.
49 |
50 | Ensure that the `external_models` sub-package contains the models required.
51 | If it does not, you will need to manually clone the required model repositories
52 | within `external_models`.
53 |
54 | To use the expansion model in `aizynthfinder`, you can specify it in the config-file
55 | under `expansion`. Here is an example setting to use the expansion model with `chemformer`
56 | as the external model:
57 |
58 | expansion:
59 | chemformer:
60 |             type: expansion_strategies.ModelZooExpansionStrategy
61 | module_path: /path_to_folder_containing_cloned_repository/modelsmatter_modelzoo/external_models/modelsmatter_chemformer_hpc/
62 | use_gpu: False
63 | params:
64 | module_path: /path_to_model_file/chemformer_backward.ckpt
65 | vocab_path: /path_to_vocab_file/bart_vocab_downstream.txt
66 | search:
67 | algorithm_config:
68 | immediate_instantiation: [chemformer]
69 | time_limit: 300
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "aizynthfinder"
3 | version = "4.3.2"
4 | description = "Retrosynthetic route finding using neural network guided Monte-Carlo tree search"
5 | authors = ["Molecular AI group "]
6 | license = "MIT"
7 | include = ["aizynthfinder/data/*.yml", "aizynthfinder/data/templates/*"]
8 | readme = "README.md"
9 | homepage = "https://github.com/MolecularAI/aizynthfinder/"
10 | repository = "https://github.com/MolecularAI/aizynthfinder/"
11 | documentation = "https://molecularai.github.io/aizynthfinder/"
12 |
13 | [tool.pytest.ini_options]
14 | mccabe-complexity = 9
15 |
16 | [tool.pylint.'MESSAGES CONTROL']
17 | max-line-length = 120
18 | max-args = 8
19 | max-attributes = 20
20 | min-public-methods = 0
21 | disable = "C0116, E0401, E1101, I1101, R0801, R0902, R0903, R0914, R1732, R1735, W0221, W0237, W0406, W0602, W0603, W0707, W1201, W1203, W1514, W3101"
22 |
23 | [tool.coverage.report]
24 | exclude_also = [
25 | "if TYPE_CHECKING:"
26 | ]
27 |
28 | [tool.poetry.dependencies]
29 | python = ">=3.9,<3.12"
30 | ipywidgets = "^7.5.1"
31 | jinja2 = "^3.0.0"
32 | jupyter = "^1.0.0"
33 | jupytext = "^1.3.3"
34 | notebook = "^6.5.3"
35 | networkx = "^2.4"
36 | deprecated = "^1.2.10"
37 | pandas = "^1.0.0"
38 | pillow = "^9.0.0"
39 | requests = "^2.23.0"
40 | rdchiral = "^1.0.0"
41 | rdkit = "^2023.9.1"
42 | tables = "^3.6.1"
43 | tqdm = "^4.42.1"
44 | onnxruntime = "<1.17.0"
45 | tensorflow = {version = "^2.8.0", optional=true}
46 | grpcio = {version = "^1.24.0", optional=true}
47 | tensorflow-serving-api = {version = "^2.1.0", optional=true}
48 | pymongo = {version = "^3.10.1", optional=true}
49 | route-distances = {version = "^1.2.4", optional=true}
50 | scipy = {version = "^1.0", optional=true}
51 | matplotlib = "^3.0.0"
52 | timeout-decorator = {version = "^0.5.0", optional=true}
53 | molbloom = {version = "^2.1.0", optional=true}
54 | paretoset = "^1.2.3"
55 | seaborn = "^0.13.2"
56 | numpy = "<2.0.0"
57 |
58 | [tool.poetry.dev-dependencies]
59 | black = "^22.0.0"
60 | invoke = "^2.2.0"
61 | pytest = "^6.2.2"
62 | pytest-black = "^0.3.12"
63 | pytest-cov = "^2.11.0"
64 | pytest-datadir = "^1.3.1"
65 | pytest-mock = "^3.5.0"
66 | pytest-mccabe = "^2.0.0"
67 | Sphinx = "^7.3.7"
68 | mypy = "^1.0.0"
69 | pylint = "^2.16.0"
70 |
71 | [tool.poetry.extras]
72 | all = ["pymongo", "route-distances", "scipy", "timeout-decorator", "molbloom"]
73 | tf = ["tensorflow", "grpcio", "tensorflow-serving-api"]
74 |
75 | [tool.poetry.scripts]
76 | aizynthapp = "aizynthfinder.interfaces.aizynthapp:main"
77 | aizynthcli = "aizynthfinder.interfaces.aizynthcli:main"
78 | cat_aizynth_output = "aizynthfinder.tools.cat_output:main"
79 | download_public_data = "aizynthfinder.tools.download_public_data:main"
80 | smiles2stock = "aizynthfinder.tools.make_stock:main"
81 |
82 | [tool.coverage.run]
83 | relative_files = true
84 |
85 | [build-system]
86 | requires = ["poetry_core>=1.0.0"]
87 | build-backend = "poetry.core.masonry.api"
88 |
--------------------------------------------------------------------------------
/tasks.py:
--------------------------------------------------------------------------------
1 | from invoke import task
2 |
3 |
@task
def build_docs(context):
    """Generate the CLI help text, then build the API docs and the Sphinx HTML site."""
    for command in (
        "aizynthcli -h > ./docs/cli_help.txt",
        "sphinx-apidoc -o ./docs ./aizynthfinder",
        "sphinx-build -M html ./docs ./docs/build",
    ):
        context.run(command)
9 |
10 |
@task
def full_tests(context):
    """Run the full test suite with black/mccabe checks and coverage reports."""
    context.run(
        "pytest --black --mccabe "
        "--cov aizynthfinder --cov-branch --cov-report html:coverage --cov-report xml "
        "tests/"
    )
19 |
20 |
@task
def run_mypy(context):
    """Type-check the aizynthfinder package with mypy."""
    context.run("mypy --ignore-missing-imports --show-error-codes aizynthfinder")
24 |
25 |
@task
def run_linting(context):
    """Run mypy (installing any missing type stubs first), then pylint."""
    print("Running mypy...")
    # pty=True so the interactive stub-installation prompt works
    context.run("mypy --install-types", pty=True)
    context.run(
        "mypy --ignore-missing-imports --show-error-codes --implicit-optional aizynthfinder"
    )
    print("Running pylint...")
    context.run("pylint aizynthfinder")
35 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/__init__.py
--------------------------------------------------------------------------------
/tests/breadth_first/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/breadth_first/__init__.py
--------------------------------------------------------------------------------
/tests/breadth_first/test_nodes.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from aizynthfinder.search.breadth_first.nodes import MoleculeNode
4 | from aizynthfinder.chem.serialization import MoleculeSerializer, MoleculeDeserializer
5 |
6 |
@pytest.fixture
def setup_root(default_config):
    """Return a factory that creates a root ``MoleculeNode`` from a SMILES string."""

    def _factory(smiles):
        return MoleculeNode.create_root(smiles, config=default_config)

    return _factory
13 |
14 |
def test_create_root_node(setup_root):
    """A freshly created root is expandable, childless and its own only ancestor."""
    root = setup_root("CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1")

    assert root.ancestors() == {root.mol}
    assert root.expandable
    assert not root.children
21 |
22 |
def test_create_stub(setup_root, get_action):
    """Adding a stub creates one reaction node with one child per reactant."""
    root = setup_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    reaction = get_action()

    root.add_stub(reaction=reaction)

    assert len(root.children) == 1
    rxn_node = root.children[0]
    assert rxn_node.reaction is reaction
    assert len(rxn_node.children) == 2
    assert [child.mol for child in rxn_node.children] == list(reaction.reactants[0])
36 |
37 |
def test_initialize_stub_one_solved_leaf(
    setup_root, get_action, default_config, setup_stock
):
    """A leaf molecule that is in stock is not expandable; the other leaf still is."""
    root = setup_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    reaction = get_action()
    # Put only the first reactant in stock
    setup_stock(default_config, reaction.reactants[0][0])

    root.add_stub(reaction=reaction)

    in_stock_leaf, other_leaf = root.children[0].children
    assert not in_stock_leaf.expandable
    assert other_leaf.expandable
50 |
51 |
def test_serialization_deserialization(
    setup_root, get_action, default_config, setup_stock
):
    """A node tree round-trips through the molecule (de)serializers."""
    root = setup_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    reaction = get_action()
    setup_stock(default_config, *reaction.reactants[0])
    root.add_stub(reaction=reaction)

    serializer = MoleculeSerializer()
    serialized = root.serialize(serializer)

    deserializer = MoleculeDeserializer(serializer.store)
    restored = MoleculeNode.from_dict(serialized, default_config, deserializer)

    assert restored.mol == root.mol
    assert len(restored.children) == len(root.children)

    rxn_node = restored.children[0]
    assert rxn_node.reaction.smarts == reaction.smarts
    assert rxn_node.reaction.metadata == reaction.metadata

    for restored_leaf, original_leaf in zip(
        rxn_node.children, root.children[0].children
    ):
        assert restored_leaf.mol == original_leaf.mol
76 |
--------------------------------------------------------------------------------
/tests/breadth_first/test_search.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import pytest
4 |
5 | from aizynthfinder.search.breadth_first.search_tree import SearchTree
6 |
7 |
def test_one_iteration(default_config, setup_policies, setup_stock):
    """Two iterations fully expand the tree; a third raises StopIteration."""
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"]
    grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    lookup = {
        root_smi: [
            {"smiles": ".".join(child1_smi), "prior": 0.7},
            {"smiles": ".".join(child2_smi), "prior": 0.3},
        ],
        child1_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7},
        child2_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7},
    }
    setup_policies(lookup, config=default_config)
    setup_stock(default_config, child1_smi[0], child1_smi[2], *grandchild_smi)
    tree = SearchTree(default_config, root_smi)

    def node_smiles():
        return [node.mol.smiles for node in tree.mol_nodes]

    assert len(tree.mol_nodes) == 1

    # First iteration: root is expanded but the tree is not yet solved
    assert not tree.one_iteration()

    assert len(tree.mol_nodes) == 6
    assert node_smiles() == [root_smi] + child1_smi + child2_smi

    # Second iteration: both unsolved leaves are expanded, solving the tree
    assert tree.one_iteration()

    assert len(tree.mol_nodes) == 10
    assert (
        node_smiles()
        == [root_smi] + child1_smi + child2_smi + grandchild_smi + grandchild_smi
    )

    with pytest.raises(StopIteration):
        tree.one_iteration()
44 |
45 |
def test_search_incomplete(default_config, setup_policies, setup_stock):
    """The search terminates even when one branch cannot be expanded further."""
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    child2_smi = ["ClC(=O)c1ccc(F)cc1", "CN1CCC(CC1)C(=O)c1cccc(N)c1F"]
    grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    # No expansion is registered for child2_smi[1], so that branch is a dead end
    lookup = {
        root_smi: [
            {"smiles": ".".join(child1_smi), "prior": 0.7},
            {"smiles": ".".join(child2_smi), "prior": 0.3},
        ],
        child1_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7},
    }
    setup_policies(lookup, config=default_config)
    setup_stock(default_config, child1_smi[0], child1_smi[2], grandchild_smi[0])
    tree = SearchTree(default_config, root_smi)

    assert len(tree.mol_nodes) == 1

    tree.one_iteration()
    assert len(tree.mol_nodes) == 6

    assert not tree.one_iteration()
    assert len(tree.mol_nodes) == 8

    with pytest.raises(StopIteration):
        tree.one_iteration()
74 |
75 |
def test_routes(default_config, setup_policies, setup_stock):
    """Exhausting the search yields two routes containing the expected molecules."""
    random.seed(666)
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    child1_smi = ["O", "CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"]
    child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"]
    grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    lookup = {
        root_smi: [
            {"smiles": ".".join(child1_smi), "prior": 0.7},
            {"smiles": ".".join(child2_smi), "prior": 0.3},
        ],
        child1_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7},
    }
    setup_policies(lookup, config=default_config)
    setup_stock(default_config, child1_smi[0], child1_smi[2], *grandchild_smi)
    tree = SearchTree(default_config, root_smi)

    # Iterate until the tree signals completion
    exhausted = False
    while not exhausted:
        try:
            tree.one_iteration()
        except StopIteration:
            exhausted = True

    routes = tree.routes()

    assert len(routes) == 2
    assert [mol.smiles for mol in routes[1].molecules()] == (
        [root_smi] + child1_smi + grandchild_smi
    )
    assert [mol.smiles for mol in routes[0].molecules()] == (
        [root_smi] + child2_smi + grandchild_smi
    )
107 |
--------------------------------------------------------------------------------
/tests/chem/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/chem/__init__.py
--------------------------------------------------------------------------------
/tests/chem/test_mol.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from rdkit import Chem
3 |
4 | from aizynthfinder.chem import MoleculeException, Molecule
5 |
6 |
def test_no_input():
    """Creating a Molecule without SMILES or an RDKit mol raises MoleculeException."""
    with pytest.raises(MoleculeException):
        Molecule()
10 |
11 |
def test_create_with_mol():
    """A Molecule built from an RDKit mol reports the corresponding SMILES."""
    assert Molecule(rd_mol=Chem.MolFromSmiles("O")).smiles == "O"
18 |
19 |
def test_create_with_smiles():
    """A Molecule built from SMILES wraps an equivalent RDKit mol."""
    molecule = Molecule(smiles="O")
    assert Chem.MolToSmiles(molecule.rd_mol) == "O"
24 |
25 |
def test_inchi():
    """The InChI string of water is computed correctly."""
    assert Molecule(smiles="O").inchi == "InChI=1S/H2O/h1H2"
30 |
31 |
def test_inchi_key():
    """The InChI key of water is computed correctly."""
    assert Molecule(smiles="O").inchi_key == "XLYOFNOQVPJJNP-UHFFFAOYSA-N"
36 |
37 |
def test_fingerprint():
    """Water's fingerprint has a single bit set, for default and custom lengths."""
    molecule = Molecule(smiles="O")

    assert sum(molecule.fingerprint(2)) == 1
    assert sum(molecule.fingerprint(2, 10)) == 1
44 |
45 |
def test_sanitize():
    """Sanitization either raises for an invalid molecule or repairs it in place."""
    valid = Molecule(smiles="O", sanitize=True)
    assert Chem.MolToSmiles(valid.rd_mol) == "O"

    invalid = Molecule(smiles="c1ccccc1(C)(C)")

    with pytest.raises(MoleculeException):
        invalid.sanitize()

    # Without raising, sanitization falls back to a repaired structure
    invalid.sanitize(raise_exception=False)
    assert invalid.smiles == "CC1(C)CCCCC1"
58 |
59 |
def test_equality():
    """Molecules compare equal by canonical structure, not by SMILES spelling."""
    assert Molecule(smiles="CCCCO") == Molecule(smiles="OCCCC")
65 |
66 |
def test_basic_equality():
    """Enantiomers differ under full equality but match under basic_compare."""
    r_isomer = Molecule(smiles="CC[C@@H](C)O")  # R-2-butanol
    s_isomer = Molecule(smiles="CC[C@H](C)O")  # S-2-butanol

    assert r_isomer != s_isomer
    assert r_isomer.basic_compare(s_isomer)
73 |
74 |
def test_has_atom_mapping():
    """Only SMILES containing atom-map numbers report having a mapping."""
    assert not Molecule(smiles="CCCCO").has_atom_mapping()
    assert Molecule(smiles="C[C:5]CCO").has_atom_mapping()
81 |
82 |
def test_remove_atom_mapping():
    """remove_atom_mapping strips atom-map numbers in place."""
    molecule = Molecule(smiles="C[C:5]CCO")
    assert molecule.has_atom_mapping()

    molecule.remove_atom_mapping()

    assert not molecule.has_atom_mapping()
91 |
92 |
def test_chiral_fingerprint():
    """For a stereocenter-containing molecule, chiral and achiral fingerprints differ."""
    molecule_a = Molecule(smiles="C[C@@H](C(=O)O)N")
    molecule_b = Molecule(smiles="C[C@@H](C(=O)O)N")

    achiral_fp = molecule_a.fingerprint(radius=2, chiral=False)
    chiral_fp = molecule_b.fingerprint(radius=2, chiral=True)

    assert achiral_fp.tolist() != chiral_fp.tolist()
101 |
--------------------------------------------------------------------------------
/tests/chem/test_serialization.py:
--------------------------------------------------------------------------------
1 | from aizynthfinder.chem.serialization import MoleculeSerializer, MoleculeDeserializer
2 | from aizynthfinder.chem import Molecule, TreeMolecule
3 |
4 |
def test_empty_store():
    """A fresh serializer starts with an empty store."""
    assert MoleculeSerializer().store == {}
9 |
10 |
def test_add_single_mol():
    """Serializing a plain Molecule keys it by object id and records its SMILES."""
    serializer = MoleculeSerializer()
    molecule = Molecule(smiles="CCC")

    key = serializer[molecule]

    assert key == id(molecule)
    assert serializer.store == {key: {"smiles": "CCC", "class": "Molecule"}}
19 |
20 |
def test_add_tree_mol():
    """Serializing a child TreeMolecule also serializes its parent, parent first."""
    serializer = MoleculeSerializer()
    parent = TreeMolecule(parent=None, smiles="CCC", transform=1)
    child = TreeMolecule(smiles="CCO", parent=parent)

    key = serializer[child]

    assert key == id(child)
    assert list(serializer.store.keys()) == [id(parent), key]
    assert serializer.store == {
        key: {
            "smiles": "CCO",
            "class": "TreeMolecule",
            "parent": id(parent),
            # The child's transform is derived from the parent's (1 + 1)
            "transform": 2,
        },
        id(parent): {
            "smiles": "CCC",
            "class": "TreeMolecule",
            "parent": None,
            "transform": 1,
        },
    }
44 |
45 |
def test_deserialize_single_mol():
    """A stored Molecule entry can be looked up by its key."""
    deserializer = MoleculeDeserializer({123: {"smiles": "CCC", "class": "Molecule"}})
    assert deserializer[123].smiles == "CCC"
51 |
52 |
def test_deserialize_tree_mols():
    """Deserialized TreeMolecules restore SMILES, parent links and transforms."""
    store = {
        123: {
            "smiles": "CCC",
            "class": "TreeMolecule",
            "parent": None,
            "transform": 1,
        },
        234: {"smiles": "CCO", "class": "TreeMolecule", "parent": 123, "transform": 2},
    }

    deserializer = MoleculeDeserializer(store)

    parent = deserializer[123]
    child = deserializer[234]
    assert parent.smiles == "CCC"
    assert child.smiles == "CCO"
    assert parent.parent is None
    assert child.parent is parent
    assert parent.transform == 1
    assert child.transform == 2
72 |
73 |
def test_chaining():
    """Serializing then deserializing yields equivalent but distinct molecule objects."""
    serializer = MoleculeSerializer()
    parent = TreeMolecule(parent=None, smiles="CCC", transform=1)
    child = TreeMolecule(smiles="CCO", parent=parent)

    child_key = serializer[child]

    deserializer = MoleculeDeserializer(serializer.store)

    assert deserializer[child_key].smiles == child.smiles
    assert deserializer[id(parent)].smiles == parent.smiles
    # The deserialized objects are new instances, not the originals
    assert id(deserializer[child_key]) != child_key
    assert id(deserializer[id(parent)]) != id(parent)
87 |
--------------------------------------------------------------------------------
/tests/context/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/context/__init__.py
--------------------------------------------------------------------------------
/tests/context/conftest.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 | import pytest_mock
7 |
8 | from aizynthfinder.context.policy import TemplateBasedExpansionStrategy
9 | from aizynthfinder.context.scoring import BrokenBondsScorer
10 | from aizynthfinder.context.stock import (
11 | MongoDbInchiKeyQuery,
12 | StockException,
13 | StockQueryMixin,
14 | )
15 |
16 |
@pytest.fixture
def create_templates_file(tmpdir):
    """Return a factory writing a list of retro templates to a temporary HDF5 file."""

    def _factory(templates):
        filename = str(tmpdir / "dummy_templates.hdf5")
        pd.DataFrame({"retro_template": templates}).to_hdf(filename, "table")
        return filename

    return _factory
26 |
27 |
@pytest.fixture
def make_stock_query():
    """Return a factory for a minimal ``StockQueryMixin`` implementation.

    Membership is answered from a fixed molecule list; ``price`` and ``amount``
    come from the supplied dictionaries and raise ``StockException`` for
    molecules without an entry.
    """

    class StockQuery(StockQueryMixin):
        def __init__(self, mols, price, amount):
            self.mols = mols
            self._price = price
            self._amount = amount

        def __contains__(self, mol):
            return mol in self.mols

        def __len__(self):
            return len(self.mols)

        def amount(self, mol):
            if mol not in self._amount:
                raise StockException()
            return self._amount[mol]

        def price(self, mol):
            if mol not in self._price:
                raise StockException()
            return self._price[mol]

    def _factory(mols, amount=None, price=None):
        return StockQuery(mols, price or {}, amount or {})

    return _factory
56 |
57 |
@pytest.fixture
def mocked_mongo_db_query(mocker):
    """Return a factory creating a ``MongoDbInchiKeyQuery`` with a mocked mongo client."""
    patched_client = mocker.patch("aizynthfinder.context.stock.queries.get_mongo_client")

    def _factory(**kwargs):
        return patched_client, MongoDbInchiKeyQuery(**kwargs)

    return _factory
66 |
67 |
@pytest.fixture
def setup_template_expansion_policy(
    default_config, create_dummy_templates, create_templates_file, mock_onnx_model
):
    """Return a factory building a template-based expansion strategy.

    Three dummy templates are generated unless an explicit template list is given.
    """

    def _factory(
        key="policy1", templates=None, expansion_cls=TemplateBasedExpansionStrategy
    ):
        if templates is None:
            templates_filename = create_dummy_templates(3)
        else:
            templates_filename = create_templates_file(templates)

        strategy = expansion_cls(
            key, default_config, model="dummy.onnx", template=templates_filename
        )
        return strategy, mock_onnx_model

    return _factory
87 |
88 |
@pytest.fixture
def setup_mcts_broken_bonds(setup_stock, setup_expanded_mcts, shared_datadir):
    """Fixture factory building an expanded MCTS search for broken-bond tests.

    The wrapper sets up a two-step retrosynthesis of a fixed root molecule
    using templates read from ``test_reactions_template.csv``.  With
    ``broken=False`` the second step switches to the third template in the
    file (and matching precursors) — presumably one that does not break the
    focus bond; confirm against the template data.  If ``config`` is given,
    a ``BrokenBondsScorer`` is added to its scorer collection.
    """

    def wrapper(broken=True, config=None):
        root_smi = "CN1CCC(C(=O)c2cccc([NH:1][C:2](=O)c3ccc(F)cc3)c2F)CC1"

        # Templates are stored one per row in the tab-separated test-data file
        reaction_template = pd.read_csv(
            shared_datadir / "test_reactions_template.csv", sep="\t"
        )
        template1_smarts = reaction_template["RetroTemplate"][0]
        template2_smarts = reaction_template["RetroTemplate"][1]

        # Precursors for the first and second retro step, respectively
        child1_smi = ["N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "CN1CCC(Cl)CC1", "O"]
        child2_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]

        if not broken:
            # Alternative second step: different template and precursor set
            template2_smarts = reaction_template["RetroTemplate"][2]
            child2_smi = ["N#Cc1cccc(Cl)c1F", "NC(=O)c1ccc(F)cc1"]

        # Expansion lookup: molecule SMILES -> template to apply (prior 1.0)
        lookup = {
            root_smi: {"smarts": template1_smarts, "prior": 1.0},
            child1_smi[0]: {
                "smarts": template2_smarts,
                "prior": 1.0,
            },
        }

        # Every leaf except the first-step intermediate goes into the stock
        stock = [child1_smi[1], child1_smi[2]] + child2_smi

        if config:
            config.scorers.create_default_scorers()
            config.scorers.load(BrokenBondsScorer(config))
        setup_stock(config, *stock)
        return setup_expanded_mcts(lookup)

    return wrapper
124 |
--------------------------------------------------------------------------------
/tests/context/data/custom_loader.py:
--------------------------------------------------------------------------------
def extract_smiles(filename):
    """Yield the first comma-separated column of each line in *filename*,
    skipping the header line.

    The file is streamed line-by-line instead of being fully materialized
    with ``readlines()``, and an empty file simply yields nothing.
    """
    with open(filename, "r") as fileobj:
        # Skip the header; the default argument avoids StopIteration (which
        # would surface as RuntimeError inside this generator, PEP 479) when
        # the file is empty
        next(fileobj, None)
        for line in fileobj:
            yield line.strip().split(",")[0]
7 |
--------------------------------------------------------------------------------
/tests/context/data/custom_loader2.py:
--------------------------------------------------------------------------------
def extract_smiles():
    """Return a fixed list of SMILES strings used as dummy stock data."""
    benzene = "c1ccccc1"
    return [benzene, "Cc1ccccc1", benzene, "CCO"]
3 |
--------------------------------------------------------------------------------
/tests/context/data/linear_route_w_metadata.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": "mol",
3 | "route_metadata": {
4 | "created_at_iteration": 1,
5 | "is_solved": true
6 | },
7 | "hide": false,
8 | "smiles": "OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1",
9 | "is_chemical": true,
10 | "in_stock": false,
11 | "children": [
12 | {
13 | "type": "reaction",
14 | "hide": false,
15 | "smiles": "OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1>>OOc1ccc(-c2ccccc2)cc1.NC1CCCC(C2C=CC=C2)C1",
16 | "is_reaction": true,
17 | "metadata": {
18 | "classification": "abc"
19 | },
20 | "children": [
21 | {
22 | "type": "mol",
23 | "hide": false,
24 | "smiles": "NC1CCCC(C2C=CC=C2)C1",
25 | "is_chemical": true,
26 | "in_stock": true
27 | },
28 | {
29 | "type": "mol",
30 | "hide": false,
31 | "smiles": "OOc1ccc(-c2ccccc2)cc1",
32 | "is_chemical": true,
33 | "in_stock": false,
34 | "children": [
35 | {
36 | "type": "reaction",
37 | "hide": false,
38 | "smiles": "OOc1ccc(-c2ccccc2)cc1>>c1ccccc1.OOc1ccccc1",
39 | "is_reaction": true,
40 | "metadata": {
41 | "classification": "xyz"
42 | },
43 | "children": [
44 | {
45 | "type": "mol",
46 | "hide": false,
47 | "smiles": "c1ccccc1",
48 | "is_chemical": true,
49 | "in_stock": true
50 | },
51 | {
52 | "type": "mol",
53 | "hide": false,
54 | "smiles": "OOc1ccccc1",
55 | "is_chemical": true,
56 | "in_stock": true
57 | }
58 | ]
59 | }
60 | ]
61 | }
62 | ]
63 | }
64 | ]
65 | }
--------------------------------------------------------------------------------
/tests/context/data/test_reactions_template.csv:
--------------------------------------------------------------------------------
1 | ReactionSmilesClean mapped_rxn confidence RetroTemplate TemplateHash TemplateError
2 | CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O>>CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1 Cl[CH:5]1[CH2:4][CH2:3][N:2]([CH3:1])[CH2:26][CH2:25]1.N#[C:6][c:8]1[cH:9][cH:10][cH:11][c:12]([NH:13][C:14](=[O:15])[c:16]2[cH:17][cH:18][c:19]([F:20])[cH:21][cH:22]2)[c:23]1[F:24].[OH2:7]>>[CH3:1][N:2]1[CH2:3][CH2:4][CH:5]([C:6](=[O:7])[c:8]2[cH:9][cH:10][cH:11][c:12]([NH:13][C:14](=[O:15])[c:16]3[cH:17][cH:18][c:19]([F:20])[cH:21][cH:22]3)[c:23]2[F:24])[CH2:25][CH2:26]1 0.9424734380772164 [C:2]-[CH;D3;+0:1](-[C:3])-[C;H0;D3;+0:4](=[O;H0;D1;+0:6])-[c:5]>>Cl-[CH;D3;+0:1](-[C:2])-[C:3].N#[C;H0;D2;+0:4]-[c:5].[OH2;D0;+0:6] 23f00a3c507eef75b22252e8eea99b2ce8eab9a573ff2fbd8227d93f49960a27
3 | N#Cc1cccc(N)c1F.O=C(Cl)c1ccc(F)cc1>>N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F Cl[C:9](=[O:10])[c:11]1[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]1.[N:1]#[C:2][c:3]1[cH:4][cH:5][cH:6][c:7]([NH2:8])[c:18]1[F:19]>>[N:1]#[C:2][c:3]1[cH:4][cH:5][cH:6][c:7]([NH:8][C:9](=[O:10])[c:11]2[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]2)[c:18]1[F:19] 0.9855040989023992 [O;D1;H0:2]=[C;H0;D3;+0:1](-[c:3])-[NH;D2;+0:4]-[c:5]>>Cl-[C;H0;D3;+0:1](=[O;D1;H0:2])-[c:3].[NH2;D1;+0:4]-[c:5] dadfa1075f086a1e76ed69f4d1c5cc44999708352c462aac5817785131169c41
4 | NC(=O)c1ccc(F)cc1.Fc1c(Cl)cccc1C#N>>N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F Cl[c:7]1[cH:6][cH:5][cH:4][c:3]([C:2]#[N:1])[c:18]1[F:19].[NH2:8][C:9](=[O:10])[c:11]1[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]1>>[N:1]#[C:2][c:3]1[cH:4][cH:5][cH:6][c:7]([NH:8][C:9](=[O:10])[c:11]2[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]2)[c:18]1[F:19] 0.9707971018098916 [O;D1;H0:6]=[C:5](-[NH;D2;+0:4]-[c;H0;D3;+0:1](:[c:2]):[c:3])-[c:7]1:[c:8]:[c:9]:[c:10]:[c:11]:[c:12]:1>>Cl-[c;H0;D3;+0:1](:[c:2]):[c:3].[NH2;D1;+0:4]-[C:5](=[O;D1;H0:6])-[c:7]1:[c:8]:[c:9]:[c:10]:[c:11]:[c:12]:1 01ba54afaa6c16613833b643d8f2503d9222eec7a3834f1cdd002faeb50ef239
5 |
--------------------------------------------------------------------------------
/tests/context/test_collection.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from aizynthfinder.context.collection import ContextCollection
4 |
5 |
class StringCollection(ContextCollection):
    """Minimal concrete ``ContextCollection`` storing plain strings for tests."""

    def load(self, key, value):
        """Register *value* under *key* in the collection."""
        self._items[key] = value

    def load_from_config(self, config):
        """Load every key/value pair from a configuration mapping."""
        for name, item in config.items():
            self.load(name, item)
13 |
14 |
class SingleStringCollection(StringCollection):
    # Variant where at most one item can be selected at a time
    _single_selection = True
17 |
18 |
def test_empty_collection():
    """A freshly created collection has no items and an empty selection."""
    coll = StringCollection()

    assert len(coll) == 0
    assert coll.selection == []
24 |
25 |
def test_add_single_item():
    """Loading one item registers it without selecting it."""
    coll = StringCollection()

    coll.load("key1", "value1")

    assert len(coll) == 1
    assert coll.items == ["key1"]
    assert coll.selection == []
34 |
35 |
def test_get_item():
    """Indexing returns loaded values; unknown keys raise KeyError."""
    coll = StringCollection()
    coll.load("key1", "value1")

    assert coll["key1"] == "value1"

    with pytest.raises(KeyError):
        coll["key2"]
44 |
45 |
def test_del_item():
    """Deletion removes an item; deleting an unknown key raises KeyError."""
    coll = StringCollection()
    coll.load("key1", "value1")

    del coll["key1"]

    assert len(coll) == 0

    with pytest.raises(KeyError):
        del coll["key2"]
56 |
57 |
def test_select_single_item():
    """Selecting a loaded key works; an unknown key raises KeyError."""
    coll = StringCollection()
    coll.load("key1", "value1")

    coll.selection = "key1"

    assert coll.selection == ["key1"]

    with pytest.raises(KeyError):
        coll.selection = "key2"

    coll.load("key2", "value2")

    coll.selection = "key2"

    assert coll.selection == ["key2"]
74 |
75 |
def test_select_append():
    """select(..., append=True) extends the current selection."""
    coll = StringCollection()
    coll.load_from_config({"key1": "value1", "key2": "value2"})

    coll.selection = "key1"

    assert coll.selection == ["key1"]

    coll.select("key2", append=True)

    assert coll.selection == ["key1", "key2"]
88 |
89 |
def test_select_all():
    """select_all() selects every loaded item; a no-op when empty."""
    coll = StringCollection()

    coll.select_all()

    assert coll.selection == []

    coll.load_from_config({"key1": "value1", "key2": "value2"})

    coll.select_all()

    assert coll.selection == ["key1", "key2"]
103 |
104 |
def test_select_first():
    """select_first() selects only the first loaded item; a no-op when empty."""
    coll = StringCollection()

    coll.select_first()

    assert coll.selection == []

    coll.load_from_config({"key1": "value1", "key2": "value2"})

    coll.select_first()

    assert coll.selection == ["key1"]
118 |
119 |
def test_select_overwrite():
    """Assigning a new selection replaces the previous one."""
    coll = StringCollection()
    coll.load_from_config({"key1": "value1", "key2": "value2"})

    coll.selection = "key1"

    assert coll.selection == ["key1"]

    coll.selection = ["key2"]

    assert coll.selection == ["key2"]
132 |
133 |
def test_deselect_all():
    """deselect() without arguments clears the whole selection."""
    coll = StringCollection()
    coll.load_from_config({"key1": "value1", "key2": "value2"})
    coll.selection = ("key1", "key2")

    assert len(coll) == 2
    assert coll.selection == ["key1", "key2"]

    coll.deselect()

    assert len(coll.selection) == 0
146 |
147 |
def test_deselect_one():
    """deselect(key) removes one key; removing it again raises KeyError."""
    coll = StringCollection()
    coll.load_from_config({"key1": "value1", "key2": "value2"})
    coll.selection = ("key1", "key2")

    coll.deselect("key1")

    assert coll.selection == ["key2"]

    with pytest.raises(KeyError):
        coll.deselect("key1")
160 |
161 |
def test_empty_single_collection():
    """A single-selection collection starts empty with selection None."""
    coll = SingleStringCollection()

    assert len(coll) == 0
    assert coll.selection is None
167 |
168 |
def test_select_single_collection():
    """Single-selection collections expose the selection as a plain string."""
    coll = SingleStringCollection()
    coll.load_from_config({"key1": "value1", "key2": "value2"})

    coll.selection = "key1"

    assert coll.selection == "key1"

    coll.selection = "key2"

    assert coll.selection == "key2"
181 |
182 |
def test_select_multiple_single_collection():
    """A one-element list is accepted; multiple keys raise ValueError."""
    coll = SingleStringCollection()
    coll.load_from_config({"key1": "value1", "key2": "value2"})

    coll.selection = ["key1"]

    assert coll.selection == "key1"

    with pytest.raises(ValueError):
        coll.selection = ["key1", "key2"]
194 |
--------------------------------------------------------------------------------
/tests/data/branched_route.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": "mol",
3 | "hide": false,
4 | "smiles": "OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1",
5 | "is_chemical": true,
6 | "in_stock": false,
7 | "children": [
8 | {
9 | "type": "reaction",
10 | "hide": false,
11 | "smiles": "OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1>>OOc1ccc(-c2ccccc2)cc1.NC1CCCC(C2C=CC=C2)C1",
12 | "is_reaction": true,
13 | "metadata": {},
14 | "children": [
15 | {
16 | "type": "mol",
17 | "hide": false,
18 | "smiles": "OOc1ccc(-c2ccccc2)cc1",
19 | "is_chemical": true,
20 | "in_stock": false,
21 | "children": [
22 | {
23 | "type": "reaction",
24 | "hide": false,
25 | "smiles": "OOc1ccc(-c2ccccc2)cc1>>c1ccccc1.OOc1ccccc1",
26 | "is_reaction": true,
27 | "metadata": {},
28 | "children": [
29 | {
30 | "type": "mol",
31 | "hide": false,
32 | "smiles": "c1ccccc1",
33 | "is_chemical": true,
34 | "in_stock": true
35 | },
36 | {
37 | "type": "mol",
38 | "hide": false,
39 | "smiles": "OOc1ccccc1",
40 | "is_chemical": true,
41 | "in_stock": false,
42 | "children": [
43 | {
44 | "type": "reaction",
45 | "hide": false,
46 | "smiles": "OOc1ccccc1>>O.Oc1ccccc1",
47 | "is_reaction": true,
48 | "metadata": {},
49 | "children": [
50 | {
51 | "type": "mol",
52 | "hide": false,
53 | "smiles": "O",
54 | "is_chemical": true,
55 | "in_stock": false
56 | },
57 | {
58 | "type": "mol",
59 | "hide": false,
60 | "smiles": "Oc1ccccc1",
61 | "is_chemical": true,
62 | "in_stock": true
63 | }
64 | ]
65 | }
66 | ]
67 | }
68 | ]
69 | }
70 | ]
71 | },
72 | {
73 | "type": "mol",
74 | "hide": false,
75 | "smiles": "NC1CCCC(C2C=CC=C2)C1",
76 | "is_chemical": true,
77 | "in_stock": false,
78 | "children": [
79 | {
80 | "type": "reaction",
81 | "hide": false,
82 | "smiles": "NC1CCCC(C2C=CC=C2)C1>>NC1CCCCC1.C1=CCC=C1",
83 | "is_reaction": true,
84 | "metadata": {},
85 | "children": [
86 | {
87 | "type": "mol",
88 | "hide": false,
89 | "smiles": "NC1CCCCC1",
90 | "is_chemical": true,
91 | "in_stock": true
92 | },
93 | {
94 | "type": "mol",
95 | "hide": false,
96 | "smiles": "C1=CCC=C1",
97 | "is_chemical": true,
98 | "in_stock": true
99 | }
100 | ]
101 | }
102 | ]
103 | }
104 | ]
105 | }
106 | ]
107 | }
--------------------------------------------------------------------------------
/tests/data/combined_example_tree.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": "mol",
3 | "hide": false,
4 | "smiles": "Cc1ccc2nc3ccccc3c(Nc3ccc(NC(=S)Nc4ccccc4)cc3)c2c1",
5 | "is_chemical": true,
6 | "in_stock": false,
7 | "children": [
8 | {
9 | "type": "reaction",
10 | "hide": false,
11 | "smiles": "[c:1]1([N:7][cH3:8])[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1>>Cl[c:1]1[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1.[N:7][cH3:8]",
12 | "is_reaction": true,
13 | "metadata": {},
14 | "children": [
15 | {
16 | "type": "mol",
17 | "hide": false,
18 | "smiles": "Cc1ccc2nc3ccccc3c(Cl)c2c1",
19 | "is_chemical": true,
20 | "in_stock": true
21 | },
22 | {
23 | "type": "mol",
24 | "hide": false,
25 | "smiles": "Nc1ccc(NC(=S)Nc2ccccc2)cc1",
26 | "is_chemical": true,
27 | "in_stock": true
28 | }
29 | ]
30 | },
31 | {
32 | "type": "reaction",
33 | "hide": false,
34 | "smiles": "[N:1]([cH3:2])[C:4](=[S:3])[N:5][cH3:6]>>[N:1][cH3:2].[S:3]=[C:4]=[N:5][cH3:6]",
35 | "is_reaction": true,
36 | "metadata": {},
37 | "children": [
38 | {
39 | "type": "mol",
40 | "hide": false,
41 | "smiles": "S=C=Nc1ccccc1",
42 | "is_chemical": true,
43 | "in_stock": true
44 | },
45 | {
46 | "type": "mol",
47 | "hide": false,
48 | "smiles": "Cc1ccc2nc3ccccc3c(Nc3ccc(N)cc3)c2c1",
49 | "is_chemical": true,
50 | "in_stock": false,
51 | "children": [
52 | {
53 | "type": "reaction",
54 | "hide": false,
55 | "smiles": "[c:1]1([N:7][cH3:8])[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1>>Cl[c:1]1[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1.[N:7][cH3:8]",
56 | "is_reaction": true,
57 | "metadata": {},
58 | "children": [
59 | {
60 | "type": "mol",
61 | "hide": false,
62 | "smiles": "Nc1ccc(N)cc1",
63 | "is_chemical": true,
64 | "in_stock": true
65 | },
66 | {
67 | "type": "mol",
68 | "hide": false,
69 | "smiles": "Cc1ccc2nc3ccccc3c(Cl)c2c1",
70 | "is_chemical": true,
71 | "in_stock": true
72 | }
73 | ]
74 | },
75 | {
76 | "type": "reaction",
77 | "hide": false,
78 | "smiles": "[c:1]([cH2:2])([cH2:3])[N:4][cH3:5]>>Br[c:1]([cH2:2])[cH2:3].[N:4][cH3:5]",
79 | "is_reaction": true,
80 | "metadata": {},
81 | "children": [
82 | {
83 | "type": "mol",
84 | "hide": false,
85 | "smiles": "Nc1ccc(Br)cc1",
86 | "is_chemical": true,
87 | "in_stock": true
88 | },
89 | {
90 | "type": "mol",
91 | "hide": false,
92 | "smiles": "Cc1ccc2nc3ccccc3c(N)c2c1",
93 | "is_chemical": true,
94 | "in_stock": true
95 | }
96 | ]
97 | }
98 | ]
99 | }
100 | ]
101 | }
102 | ]
103 | }
--------------------------------------------------------------------------------
/tests/data/combined_example_tree2.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": "mol",
3 | "hide": false,
4 | "smiles": "Cc1ccc2nc3ccccc3c(Nc3ccc(NC(=S)Nc4ccccc4)cc3)c2c1",
5 | "is_chemical": true,
6 | "in_stock": false,
7 | "children": [
8 | {
9 | "type": "reaction",
10 | "hide": false,
11 | "smiles": "[c:1]1([N:7][cH3:8])[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1>>Cl[c:1]1[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1.[N:7][cH3:8]",
12 | "is_reaction": true,
13 | "metadata": {},
14 | "children": [
15 | {
16 | "type": "mol",
17 | "hide": false,
18 | "smiles": "Cc1ccc2nc3ccccc3c(Cl)c2c1",
19 | "is_chemical": true,
20 | "in_stock": true
21 | },
22 | {
23 | "type": "mol",
24 | "hide": false,
25 | "smiles": "Nc1ccc(NC(=S)Nc2ccccc2)cc1",
26 | "is_chemical": true,
27 | "in_stock": true
28 | }
29 | ]
30 | },
31 | {
32 | "type": "reaction",
33 | "hide": false,
34 | "smiles": "[N:1]([cH3:2])[C:4](=[S:3])[N:5][cH3:6]>>[N:1][cH3:2].[S:3]=[C:4]=[N:5][cH3:6]",
35 | "is_reaction": true,
36 | "metadata": {},
37 | "children": [
38 | {
39 | "type": "mol",
40 | "hide": false,
41 | "smiles": "Nc1ccccc1",
42 | "is_chemical": true,
43 | "in_stock": false
44 | },
45 | {
46 | "type": "mol",
47 | "hide": false,
48 | "smiles": "Cc1ccc2nc3ccccc3c(Nc3ccc(N=C=S)cc3)c2c1",
49 | "is_chemical": true,
50 | "in_stock": false,
51 | "children": [
52 | {
53 | "type": "reaction",
54 | "hide": false,
55 | "smiles": "[c:1]1([N:7][cH3:8])[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1>>Cl[c:1]1[cH:2][cH:3]:[N:4]:[cH:5][cH:6]1.[N:7][cH3:8]",
56 | "is_reaction": true,
57 | "metadata": {},
58 | "children": [
59 | {
60 | "type": "mol",
61 | "hide": false,
62 | "smiles": "Nc1ccc(N=C=S)cc1",
63 | "is_chemical": true,
64 | "in_stock": false
65 | },
66 | {
67 | "type": "mol",
68 | "hide": false,
69 | "smiles": "Cc1ccc2nc3ccccc3c(Cl)c2c1",
70 | "is_chemical": true,
71 | "in_stock": true
72 | }
73 | ]
74 | },
75 | {
76 | "type": "reaction",
77 | "hide": false,
78 | "smiles": "[c:1]([cH2:2])([cH2:3])[N:4][cH3:5]>>Br[c:1]([cH2:2])[cH2:3].[N:4][cH3:5]",
79 | "is_reaction": true,
80 | "metadata": {},
81 | "children": [
82 | {
83 | "type": "mol",
84 | "hide": false,
85 | "smiles": "Cc1ccc2nc3ccccc3c(N)c2c1",
86 | "is_chemical": true,
87 | "in_stock": true
88 | },
89 | {
90 | "type": "mol",
91 | "hide": false,
92 | "smiles": "S=C=Nc1ccc(Br)cc1",
93 | "is_chemical": true,
94 | "in_stock": false
95 | }
96 | ]
97 | }
98 | ]
99 | }
100 | ]
101 | }
102 | ]
103 | }
--------------------------------------------------------------------------------
/tests/data/dummy2_raw_template_library.csv:
--------------------------------------------------------------------------------
1 | ID,PseudoHash,RSMI,classification,retro_template,template_hash
2 | 0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX
3 | 0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX
4 | 0,BAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX
5 | 0,ABA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX
6 | 0,AAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX
7 | 0,BBA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX
8 | 0,ABB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY
9 | 0,BAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY
10 | 0,CAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY
11 | 0,ACA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY
12 | 0,AAC,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY
13 | 0,CCA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),DXX
--------------------------------------------------------------------------------
/tests/data/dummy_noclass_raw_template_library.csv:
--------------------------------------------------------------------------------
1 | 0,0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX
2 | 0,0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX
3 | 0,0,BAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX
4 | 0,0,ABA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX
5 | 0,0,AAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX
6 | 0,0,BBA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX
7 | 0,0,ABB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY
8 | 0,0,BAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY
9 | 0,0,CAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY
10 | 0,0,ACA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY
11 | 0,0,AAC,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY
12 | 0,0,CCA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),DXX
--------------------------------------------------------------------------------
/tests/data/dummy_raw_template_library.csv:
--------------------------------------------------------------------------------
1 | 0,0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2
2 | 0,0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2
3 | 0,0,BAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2
4 | 0,0,ABA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2
5 | 0,0,AAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2
6 | 0,0,BBA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2
7 | 0,0,ABB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2
8 | 0,0,BAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2
9 | 0,0,CAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2
10 | 0,0,ACA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2
11 | 0,0,AAC,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2
12 | 0,0,CCA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),DXX,0,2
--------------------------------------------------------------------------------
/tests/data/dummy_sani_raw_template_library.csv:
--------------------------------------------------------------------------------
1 | 0,0,AAA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2
2 | 0,0,AAA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2
3 | 0,0,BAA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2
4 | 0,0,ABA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2
5 | 0,0,AAB,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2
6 | 0,0,BBA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX,0,2
7 | 0,0,ABB,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2
8 | 0,0,BAB,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2
9 | 0,0,CAA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2
10 | 0,0,ACA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2
11 | 0,0,AAC,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY,0,2
12 | 0,0,CCA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),DXX,0,2
13 | 0,0,SSA,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,c1ccccc1(C)(C),unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),SXX,0,2
14 | 0,0,SSB,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,c1ccccc1(C)(C),unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),SXX,0,2
15 | 0,0,SSC,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,c1ccccc1(C)(C),unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),SXX,0,2
16 | 0,0,SSB,CNO.CCCCOc1ccc(CC(=O)Cl)cc1,c1ccccc1(C)(C),unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),SXX,0,2
--------------------------------------------------------------------------------
/tests/data/full_search_tree.json.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/data/full_search_tree.json.gz
--------------------------------------------------------------------------------
/tests/data/input_checkpoint.json.gz:
--------------------------------------------------------------------------------
1 | {"processed_smiles": "c1ccccc1", "results": {"a": 1, "b": 2, "is_solved": true, "stock_info": 1, "trees": 3}}
2 | {"processed_smiles": "Cc1ccccc1", "results": {"a": 1, "b": 2, "is_solved": true, "stock_info": 1, "trees": 3}}
3 |
--------------------------------------------------------------------------------
/tests/data/linear_route.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": "mol",
3 | "route_metadata": {
4 | "created_at_iteration": 1,
5 | "is_solved": true
6 | },
7 | "hide": false,
8 | "smiles": "OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1",
9 | "is_chemical": true,
10 | "in_stock": false,
11 | "children": [
12 | {
13 | "type": "reaction",
14 | "hide": false,
15 | "smiles": "OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1>>OOc1ccc(-c2ccccc2)cc1.NC1CCCC(C2C=CC=C2)C1",
16 | "is_reaction": true,
17 | "metadata": {},
18 | "children": [
19 | {
20 | "type": "mol",
21 | "hide": false,
22 | "smiles": "NC1CCCC(C2C=CC=C2)C1",
23 | "is_chemical": true,
24 | "in_stock": true
25 | },
26 | {
27 | "type": "mol",
28 | "hide": false,
29 | "smiles": "OOc1ccc(-c2ccccc2)cc1",
30 | "is_chemical": true,
31 | "in_stock": false,
32 | "children": [
33 | {
34 | "type": "reaction",
35 | "hide": false,
36 | "smiles": "OOc1ccc(-c2ccccc2)cc1>>c1ccccc1.OOc1ccccc1",
37 | "is_reaction": true,
38 | "metadata": {},
39 | "children": [
40 | {
41 | "type": "mol",
42 | "hide": false,
43 | "smiles": "c1ccccc1",
44 | "is_chemical": true,
45 | "in_stock": true
46 | },
47 | {
48 | "type": "mol",
49 | "hide": false,
50 | "smiles": "OOc1ccccc1",
51 | "is_chemical": true,
52 | "in_stock": true
53 | }
54 | ]
55 | }
56 | ]
57 | }
58 | ]
59 | }
60 | ]
61 | }
--------------------------------------------------------------------------------
/tests/data/post_processing_test.py:
--------------------------------------------------------------------------------
def post_processing(finder):
    """Dummy post-processing hook used by tests; returns fixed quantities."""
    quantities = {"quantity": 5, "another_quantity": 10}
    return quantities
3 |
--------------------------------------------------------------------------------
/tests/data/pre_processing_test.py:
--------------------------------------------------------------------------------
def pre_processing(finder, target_idx):
    """Dummy pre-processing hook used by tests; always fails with the index."""
    failure = ValueError(target_idx)
    raise failure
3 |
--------------------------------------------------------------------------------
/tests/data/test_reactions_template.csv:
--------------------------------------------------------------------------------
1 | ReactionSmilesClean mapped_rxn confidence RetroTemplate TemplateHash TemplateError
2 | CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O>>CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1 Cl[CH:5]1[CH2:4][CH2:3][N:2]([CH3:1])[CH2:26][CH2:25]1.N#[C:6][c:8]1[cH:9][cH:10][cH:11][c:12]([NH:13][C:14](=[O:15])[c:16]2[cH:17][cH:18][c:19]([F:20])[cH:21][cH:22]2)[c:23]1[F:24].[OH2:7]>>[CH3:1][N:2]1[CH2:3][CH2:4][CH:5]([C:6](=[O:7])[c:8]2[cH:9][cH:10][cH:11][c:12]([NH:13][C:14](=[O:15])[c:16]3[cH:17][cH:18][c:19]([F:20])[cH:21][cH:22]3)[c:23]2[F:24])[CH2:25][CH2:26]1 0.9424734380772164 [C:2]-[CH;D3;+0:1](-[C:3])-[C;H0;D3;+0:4](=[O;H0;D1;+0:6])-[c:5]>>Cl-[CH;D3;+0:1](-[C:2])-[C:3].N#[C;H0;D2;+0:4]-[c:5].[OH2;D0;+0:6] 23f00a3c507eef75b22252e8eea99b2ce8eab9a573ff2fbd8227d93f49960a27
3 | N#Cc1cccc(N)c1F.O=C(Cl)c1ccc(F)cc1>>N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F Cl[C:9](=[O:10])[c:11]1[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]1.[N:1]#[C:2][c:3]1[cH:4][cH:5][cH:6][c:7]([NH2:8])[c:18]1[F:19]>>[N:1]#[C:2][c:3]1[cH:4][cH:5][cH:6][c:7]([NH:8][C:9](=[O:10])[c:11]2[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]2)[c:18]1[F:19] 0.9855040989023992 [O;D1;H0:2]=[C;H0;D3;+0:1](-[c:3])-[NH;D2;+0:4]-[c:5]>>Cl-[C;H0;D3;+0:1](=[O;D1;H0:2])-[c:3].[NH2;D1;+0:4]-[c:5] dadfa1075f086a1e76ed69f4d1c5cc44999708352c462aac5817785131169c41
4 | NC(=O)c1ccc(F)cc1.Fc1c(Cl)cccc1C#N>>N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F Cl[c:7]1[cH:6][cH:5][cH:4][c:3]([C:2]#[N:1])[c:18]1[F:19].[NH2:8][C:9](=[O:10])[c:11]1[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]1>>[N:1]#[C:2][c:3]1[cH:4][cH:5][cH:6][c:7]([NH:8][C:9](=[O:10])[c:11]2[cH:12][cH:13][c:14]([F:15])[cH:16][cH:17]2)[c:18]1[F:19] 0.9707971018098916 [O;D1;H0:6]=[C:5](-[NH;D2;+0:4]-[c;H0;D3;+0:1](:[c:2]):[c:3])-[c:7]1:[c:8]:[c:9]:[c:10]:[c:11]:[c:12]:1>>Cl-[c;H0;D3;+0:1](:[c:2]):[c:3].[NH2;D1;+0:4]-[C:5](=[O;D1;H0:6])-[c:7]1:[c:8]:[c:9]:[c:10]:[c:11]:[c:12]:1 01ba54afaa6c16613833b643d8f2503d9222eec7a3834f1cdd002faeb50ef239
--------------------------------------------------------------------------------
/tests/dfpn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/dfpn/__init__.py
--------------------------------------------------------------------------------
/tests/dfpn/test_nodes.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from aizynthfinder.search.dfpn.nodes import MoleculeNode, BIG_INT
4 | from aizynthfinder.search.dfpn import SearchTree
5 |
6 |
@pytest.fixture
def setup_root(default_config):
    """Fixture factory: root MoleculeNode attached to a fresh DFPN SearchTree."""

    def _make_root(smiles):
        tree = SearchTree(default_config)
        return MoleculeNode.create_root(smiles, config=default_config, owner=tree)

    return _make_root
14 |
15 |
def test_create_root_node(setup_root):
    """A freshly created root is expandable, childless and has unit numbers."""
    root = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1")

    assert root.pn == 1
    assert root.dn == 1
    assert root.expandable
    assert not root.children
23 |
24 |
def test_expand_mol_node(
    default_config, setup_root, setup_policies, get_linear_expansion
):
    """Expanding a molecule node removes its expandability and adds one child."""
    root = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1")
    setup_policies(get_linear_expansion)

    root.expand()

    assert len(root.children) == 1
    assert not root.expandable
35 |
36 |
def test_promising_child(
    default_config, setup_root, setup_policies, get_linear_expansion
):
    """The promising child is the first child, with the expected thresholds."""
    root = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1")
    setup_policies(get_linear_expansion)
    root.expand()

    selected = root.promising_child()

    assert selected is root.children[0]
    assert selected.dn_threshold == BIG_INT
    assert selected.pn_threshold == BIG_INT - 1
49 |
50 |
def test_expand_reaction_node(
    default_config, setup_root, setup_policies, get_linear_expansion
):
    """Expanding the reaction child yields its two reactant nodes."""
    root = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1")
    setup_policies(get_linear_expansion)
    root.expand()
    reaction_node = root.promising_child()

    reaction_node.expand()

    assert len(reaction_node.children) == 2
62 |
63 |
def test_promising_child_reaction_node(
    default_config,
    setup_root,
    setup_policies,
    get_linear_expansion,
):
    """Drilling down two levels selects the expected grandchild molecule."""
    root = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1")
    setup_policies(get_linear_expansion)
    root.expand()
    reaction_node = root.promising_child()
    reaction_node.expand()

    selected = reaction_node.promising_child()

    assert selected.mol.smiles == "OOc1ccc(-c2ccccc2)cc1"
    assert selected.dn_threshold == 2
    assert selected.pn_threshold == BIG_INT - 1
82 |
83 |
def test_update(
    default_config,
    setup_root,
    setup_policies,
    get_linear_expansion,
    setup_stock,
):
    """Updating after expansion with all reactants in stock proves the root."""
    root = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1")
    setup_stock(default_config, "OOc1ccc(-c2ccccc2)cc1", "NC1CCCC(C2C=CC=C2)C1")
    setup_policies(get_linear_expansion)
    root.expand()

    reaction_node = root.promising_child()
    reaction_node.expand()
    reaction_node.update()

    root.update()

    assert not root.disproven
    assert root.proven
104 |
--------------------------------------------------------------------------------
/tests/dfpn/test_search.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import pytest
4 |
5 | from aizynthfinder.search.dfpn.search_tree import SearchTree
6 |
7 |
def test_search(default_config, setup_policies, setup_stock):
    """A full DFPN search solves the target in the expected iteration count."""
    target_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    first_children = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    second_children = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"]
    grandchildren = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    lookup = {
        target_smi: [
            {"smiles": ".".join(first_children), "prior": 0.7},
            {"smiles": ".".join(second_children), "prior": 0.3},
        ],
        first_children[1]: {"smiles": ".".join(grandchildren), "prior": 0.7},
        second_children[1]: {"smiles": ".".join(grandchildren), "prior": 0.7},
    }
    setup_policies(lookup, config=default_config)
    setup_stock(
        default_config, first_children[0], first_children[2], *grandchildren
    )
    tree = SearchTree(default_config, target_smi)

    assert not tree.one_iteration()

    routes = tree.routes()
    assert not any(route.is_solved for route in routes)
    assert len(tree.mol_nodes) == 4

    # One more unsuccessful iteration, then four that report progress
    assert not tree.one_iteration()
    for _ in range(4):
        assert tree.one_iteration()

    routes = tree.routes()
    assert len(routes) == 2
    assert all(route.is_solved for route in routes)

    with pytest.raises(StopIteration):
        tree.one_iteration()
44 |
45 |
def test_search_incomplete(default_config, setup_policies, setup_stock):
    """When one branch cannot be expanded, the search terminates unsolved."""
    target_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    first_children = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    second_children = ["ClC(=O)c1ccc(F)cc1", "CN1CCC(CC1)C(=O)c1cccc(N)c1F"]
    grandchildren = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    lookup = {
        target_smi: [
            {"smiles": ".".join(first_children), "prior": 0.7},
            {"smiles": ".".join(second_children), "prior": 0.3},
        ],
        first_children[1]: {"smiles": ".".join(grandchildren), "prior": 0.7},
    }
    setup_policies(lookup, config=default_config)
    setup_stock(
        default_config, first_children[0], first_children[2], grandchildren[0]
    )
    tree = SearchTree(default_config, target_smi)

    # Exhaust the search until it signals completion
    while True:
        try:
            tree.one_iteration()
        except StopIteration:
            break

    routes = tree.routes()
    assert len(routes) == 2
    assert all(not route.is_solved for route in routes)
72 |
--------------------------------------------------------------------------------
/tests/mcts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/mcts/__init__.py
--------------------------------------------------------------------------------
/tests/mcts/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from aizynthfinder.search.mcts import MctsNode, MctsSearchTree
4 |
5 |
@pytest.fixture
def generate_root(default_config):
    """Fixture factory creating a detached MCTS root node from a SMILES."""

    def _make_root(smiles, config=None):
        cfg = config if config is not None else default_config
        return MctsNode.create_root(smiles, tree=None, config=cfg)

    return _make_root
12 |
13 |
@pytest.fixture
def fresh_tree(default_config):
    """An empty MCTS search tree without any root molecule."""
    return MctsSearchTree(root_smiles=None, config=default_config)
17 |
18 |
@pytest.fixture
def set_default_prior(default_config):
    """Disable the policy prior and let the test set a fixed default prior.

    Yields a setter ``wrapper(prior)`` that writes ``default_prior`` into
    the algorithm config.  On teardown the ``use_prior`` flag is restored
    to its original value (the old code hard-coded ``True``, which would
    corrupt the config if a session ever started with ``use_prior=False``).
    """
    original_use_prior = default_config.search.algorithm_config["use_prior"]
    default_config.search.algorithm_config["use_prior"] = False

    def wrapper(prior):
        default_config.search.algorithm_config["default_prior"] = prior

    yield wrapper
    # Restore the saved value rather than assuming it was True
    default_config.search.algorithm_config["use_prior"] = original_use_prior
28 |
29 |
@pytest.fixture
def setup_mcts_search(get_one_step_expansion, setup_policies, generate_root):
    """Install policies and return (root, expansion_strategy, filter_strategy)."""
    expansion, filtering = setup_policies(get_one_step_expansion)
    target = next(iter(expansion.lookup.keys()))
    return generate_root(target), expansion, filtering
39 |
40 |
@pytest.fixture
def setup_complete_mcts_tree(default_config, setup_policies, setup_stock):
    """Build a linear three-node MCTS tree for a solved two-step route."""
    root_smiles = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    tree = MctsSearchTree(config=default_config, root_smiles=root_smiles)
    lookup = {
        root_smiles: {
            "smiles": "CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O",
            "prior": 1.0,
        },
        "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F": {
            "smiles": "N#Cc1cccc(N)c1F.O=C(Cl)c1ccc(F)cc1",
            "prior": 1.0,
        },
    }
    setup_policies(lookup)
    setup_stock(
        default_config, "CN1CCC(Cl)CC1", "O", "N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"
    )

    # Expand root -> child -> grandchild to create a linear three-node path
    nodes = [tree.root]
    for _ in range(2):
        nodes[-1].expand()
        nodes.append(nodes[-1].promising_child())
    nodes[-1].expand()

    return tree, nodes
71 |
--------------------------------------------------------------------------------
/tests/mcts/test_multiobjective.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from aizynthfinder.search.mcts.node import ParetoMctsNode
4 | from aizynthfinder.search.mcts import MctsSearchTree
5 |
6 |
@pytest.fixture
def generate_root(default_config):
    """Fixture factory creating a detached Pareto MCTS root node."""

    def _make_root(smiles, config=None):
        cfg = config if config is not None else default_config
        return ParetoMctsNode.create_root(smiles, tree=None, config=cfg)

    return _make_root
15 |
16 |
@pytest.fixture
def setup_mcts_search(
    default_config, get_one_step_expansion, setup_policies, generate_root
):
    """Configure two search rewards and return (root, expansion, filter)."""
    default_config.search.algorithm_config["search_rewards"] = [
        "number of reactions",
        "number of pre-cursors in stock",
    ]
    expansion, filtering = setup_policies(get_one_step_expansion)
    target = next(iter(expansion.lookup.keys()))
    return generate_root(target), expansion, filtering
32 |
33 |
def test_expand_root_node(setup_mcts_search):
    """Expanding the root exposes three actions with paired MO priors/values."""
    root, _, _ = setup_mcts_search

    root.expand()

    view = root.children_view()
    expected = [[0.7, 0.7], [0.5, 0.5], [0.3, 0.3]]
    assert len(view["actions"]) == 3
    assert view["priors"] == expected
    assert view["values"] == expected
    assert view["visitations"] == [1, 1, 1]
    assert view["objects"] == [None, None, None]
45 |
46 |
def test_expand_root_with_default_priors(setup_mcts_search, set_default_prior):
    """With the policy prior disabled, children all carry the default prior."""
    root, _, _ = setup_mcts_search
    set_default_prior(0.01)

    root.expand()

    view = root.children_view()
    expected = [[0.01, 0.01], [0.01, 0.01], [0.01, 0.01]]
    assert len(view["actions"]) == 3
    assert view["priors"] == expected
    assert view["values"] == expected
    assert view["visitations"] == [1, 1, 1]
    assert view["objects"] == [None, None, None]
59 |
60 |
def test_backpropagate(setup_mcts_search):
    """Backpropagation bumps only the visited child's visits and rewards."""
    root, _, _ = setup_mcts_search
    root.expand()
    child = root.promising_child()
    before = root.children_view()

    root.backpropagate(child, [1.5, 2.0])

    after = root.children_view()
    assert after["visitations"][0] == before["visitations"][0] + 1
    assert after["visitations"][1:] == before["visitations"][1:]
    assert after["values"] == before["values"]
    assert after["rewards_cum"][0][0] == before["rewards_cum"][0][0] + 1.5
    assert after["rewards_cum"][0][1] == before["rewards_cum"][0][1] + 2.0
    assert after["rewards_cum"][1:] == before["rewards_cum"][1:]
76 |
77 |
def test_setup_weighted_sum_tree(default_config):
    """Providing reward weights switches the tree into weighted-sum mode."""
    algo_config = default_config.search.algorithm_config
    algo_config["search_rewards"] = [
        "number of reactions",
        "number of pre-cursors in stock",
    ]
    algo_config["search_rewards_weights"] = [0.5, 0.1]
    tree = MctsSearchTree(
        config=default_config,
        root_smiles="CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1",
    )

    assert tree.mode == "weighted-sum"
    assert len(tree.reward_scorer.selection) == 2
    assert tree.compute_reward(tree.root) == 0.0
90 |
91 |
def test_setup_mo_tree(default_config):
    """Two rewards without weights put the tree into multi-objective mode."""
    algo_config = default_config.search.algorithm_config
    algo_config["search_rewards"] = [
        "number of reactions",
        "number of pre-cursors in stock",
    ]
    tree = MctsSearchTree(
        config=default_config,
        root_smiles="CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1",
    )

    assert tree.mode == "multi-objective"
    assert len(tree.reward_scorer.selection) == 2
    assert tree.compute_reward(tree.root) == [0.0, 0.0]
103 |
--------------------------------------------------------------------------------
/tests/mcts/test_reward.py:
--------------------------------------------------------------------------------
1 | from aizynthfinder.analysis import TreeAnalysis
2 | from aizynthfinder.context.scoring import (
3 | NumberOfReactionsScorer,
4 | StateScorer,
5 | )
6 | from aizynthfinder.search.mcts import MctsSearchTree
7 |
8 |
def test_reward_node(default_config, generate_root):
    """Search and route scorers are resolved from config and applied to a node."""
    config = default_config
    search_scorer_name = repr(StateScorer(config))
    route_scorer_name = repr(NumberOfReactionsScorer())

    config.search.algorithm_config["search_rewards"] = [search_scorer_name]
    config.post_processing.route_scorers = [route_scorer_name]

    node = generate_root("CCCCOc1ccc(CC(=O)N(C)O)cc1", config)

    search_scorer = config.scorers[config.search.algorithm_config["search_rewards"][0]]
    route_scorer = config.scorers[config.post_processing.route_scorers[0]]

    assert round(search_scorer(node), 4) == 0.0491
    assert route_scorer(node) == 0
24 |
25 |
def test_default_postprocessing_reward(setup_aizynthfinder):
    """Without explicit route scorers, route building reuses the search reward."""
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    reactant_smis = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    lookup = {root_smi: {"smiles": ".".join(reactant_smis), "prior": 1.0}}
    finder = setup_aizynthfinder(lookup, reactant_smis)

    config = finder.config
    config.search.return_first = True

    search_scorer_name = repr(NumberOfReactionsScorer())
    state_scorer_name = repr(StateScorer(config))

    config.search.algorithm_config["search_rewards"] = [search_scorer_name]
    finder.config = config

    assert len(finder.config.post_processing.route_scorers) == 0

    finder.tree_search()
    analysis_search = TreeAnalysis(
        finder.tree, scorer=config.scorers[search_scorer_name]
    )
    analysis_state = TreeAnalysis(
        finder.tree, scorer=config.scorers[state_scorer_name]
    )

    finder.build_routes()
    assert finder.tree.reward_scorer_name == search_scorer_name

    expected_top = analysis_search.tree_statistics()["top_score"]
    actual_top = finder.tree.compute_reward(analysis_search.best())

    assert actual_top == expected_top

    # The finder scored with the search reward, not with the state scorer
    expected_top = analysis_state.tree_statistics()["top_score"]
    actual_top = finder.analysis.tree_statistics()["top_score"]

    assert actual_top != expected_top
63 |
64 |
def test_custom_reward(setup_aizynthfinder):
    """MCTS search and route building can use different reward functions."""

    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    reactant_smis = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    lookup = {root_smi: {"smiles": ".".join(reactant_smis), "prior": 1.0}}
    finder = setup_aizynthfinder(lookup, reactant_smis)

    # Use return_first together with distinct search/route scorers
    config = finder.config
    config.search.return_first = True

    search_scorer_name = repr(StateScorer(config))
    route_scorer_name = repr(NumberOfReactionsScorer())

    config.search.algorithm_config["search_rewards"] = [search_scorer_name]
    config.post_processing.route_scorers = [route_scorer_name]
    finder.config = config

    assert finder.config.post_processing.route_scorers == [route_scorer_name]

    finder.tree_search()
    analysis_search = TreeAnalysis(
        finder.tree, scorer=config.scorers[search_scorer_name]
    )
    analysis_route = TreeAnalysis(
        finder.tree, scorer=config.scorers[route_scorer_name]
    )

    finder.build_routes()

    assert finder.config.post_processing.route_scorers == [route_scorer_name]
    assert finder.tree.reward_scorer_name == search_scorer_name

    expected_top = analysis_search.tree_statistics()["top_score"]
    actual_top = finder.tree.compute_reward(analysis_search.best())

    assert actual_top == expected_top

    expected_top = analysis_route.tree_statistics()["top_score"]
    actual_top = finder.analysis.tree_statistics()["top_score"]

    assert actual_top == expected_top
108 |
109 |
def test_reward_node_backward_compatibility(default_config):
    """The legacy singular ``search_reward`` option is still honoured."""
    scorer_name = repr(NumberOfReactionsScorer())
    default_config.search.algorithm_config["search_reward"] = scorer_name

    tree = MctsSearchTree(root_smiles=None, config=default_config)

    assert tree.reward_scorer_name == scorer_name
117 |
--------------------------------------------------------------------------------
/tests/mcts/test_tree.py:
--------------------------------------------------------------------------------
def test_select_leaf_root(setup_complete_mcts_tree):
    """An unexpanded root is itself selected as the leaf."""
    tree, nodes = setup_complete_mcts_tree
    nodes[0].is_expanded = False

    assert tree.select_leaf() is nodes[0]
8 |
9 |
def test_select_leaf(setup_complete_mcts_tree):
    """Leaf selection walks down to the deepest node in the chain."""
    tree, nodes = setup_complete_mcts_tree

    assert tree.select_leaf() is nodes[2]
16 |
17 |
def test_backpropagation(setup_complete_mcts_tree, mocker):
    """The leaf score is propagated to every ancestor but not the leaf itself."""
    tree, nodes = setup_complete_mcts_tree
    for tree_node in nodes:
        tree_node.backpropagate = mocker.MagicMock()
    leaf_score = tree.reward_scorer[tree.reward_scorer_name](nodes[2])

    tree.backpropagate(nodes[2])

    nodes[0].backpropagate.assert_called_once_with(nodes[1], leaf_score)
    nodes[1].backpropagate.assert_called_once_with(nodes[2], leaf_score)
    nodes[2].backpropagate.assert_not_called()
29 |
30 |
def test_route_to_node(setup_complete_mcts_tree):
    """path_to() returns the actions and nodes from the root to this node."""
    tree, nodes = setup_complete_mcts_tree

    actions, route_nodes = nodes[2].path_to()

    assert len(actions) == 2
    # Check the length of the RETURNED route; the old `len(nodes) == 3`
    # only re-asserted the fixture's own list length and tested nothing
    assert len(route_nodes) == 3
    assert nodes[0] == route_nodes[0]
    assert nodes[1] == route_nodes[1]
    assert nodes[2] == route_nodes[2]
41 |
42 |
def test_create_graph(setup_complete_mcts_tree):
    """The tree graph mirrors the linear three-node structure."""
    tree, nodes = setup_complete_mcts_tree

    graph = tree.graph()

    assert len(graph) == 3
    for parent, child in ((nodes[0], nodes[1]), (nodes[1], nodes[2])):
        assert list(graph.successors(parent)) == [child]
51 |
--------------------------------------------------------------------------------
/tests/retrostar/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/retrostar/__init__.py
--------------------------------------------------------------------------------
/tests/retrostar/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 |
5 | from aizynthfinder.search.retrostar.search_tree import SearchTree
6 | from aizynthfinder.search.retrostar.cost import MoleculeCost
7 | from aizynthfinder.search.retrostar.nodes import MoleculeNode
8 | from aizynthfinder.aizynthfinder import AiZynthFinder
9 |
10 |
@pytest.fixture
def setup_aizynthfinder(setup_policies, setup_stock):
    """Fixture factory building an AiZynthFinder configured for Retro* search."""

    def _make_finder(expansions, stock):
        finder = AiZynthFinder()
        setup_policies(expansions, config=finder.config)
        setup_stock(finder.config, *stock)
        # The target is the first expansion key by convention
        finder.target_smiles = next(iter(expansions.keys()))
        finder.config.search.algorithm = (
            "aizynthfinder.search.retrostar.search_tree.SearchTree"
        )
        return finder

    return _make_finder
25 |
26 |
@pytest.fixture
def setup_search_tree(default_config, setup_policies, setup_stock):
    """A Retro* search tree with a one-step expansion and stocked leaves."""
    root_smiles = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    tree = SearchTree(config=default_config, root_smiles=root_smiles)
    expansion = {
        "smiles": "CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O",
        "prior": 1.0,
    }
    setup_policies({root_smiles: expansion})
    setup_stock(
        default_config, "CN1CCC(Cl)CC1", "O", "N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"
    )
    return tree
43 |
44 |
@pytest.fixture
def setup_star_root(default_config):
    """Fixture factory creating a Retro* root node with its molecule cost."""

    def _make_root(smiles):
        cost = MoleculeCost(default_config)
        return MoleculeNode.create_root(
            smiles, config=default_config, molecule_cost=cost
        )

    return _make_root
53 |
--------------------------------------------------------------------------------
/tests/retrostar/test_retrostar.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from aizynthfinder.search.retrostar.search_tree import SearchTree
4 | from aizynthfinder.chem.serialization import MoleculeSerializer
5 |
6 |
def test_one_iteration(setup_search_tree):
    """One iteration expands the root into a reaction with three reactants."""
    tree = setup_search_tree

    tree.one_iteration()

    reaction_nodes = tree.root.children
    assert len(reaction_nodes) == 1
    assert len(reaction_nodes[0].children) == 3
14 |
15 |
def test_one_iteration_filter_unfeasible(setup_search_tree):
    """A zero filter score rejects the expansion, leaving the root childless."""
    tree = setup_search_tree
    reaction_smi = (
        "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1>>"
        "CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O"
    )
    tree.config.filter_policy["dummy"].lookup[reaction_smi] = 0.0

    tree.one_iteration()

    assert not tree.root.children
23 |
24 |
def test_one_iteration_filter_feasible(setup_search_tree):
    """A positive filter score lets the expansion through."""
    tree = setup_search_tree
    reaction_smi = (
        "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1>>"
        "CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O"
    )
    tree.config.filter_policy["dummy"].lookup[reaction_smi] = 0.5

    tree.one_iteration()

    assert len(tree.root.children) == 1
32 |
33 |
def test_one_expansion_with_finder(setup_aizynthfinder):
    """
    Test the building of this tree:
    root
        |
    child 1
    """
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    reactant_smis = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    lookup = {root_smi: {"smiles": ".".join(reactant_smis), "prior": 1.0}}
    finder = setup_aizynthfinder(lookup, reactant_smis)

    # First run: stop as soon as a route is found
    finder.config.search.return_first = True
    finder.tree_search()

    mol_nodes = finder.tree.mol_nodes
    assert len(mol_nodes) == 4
    assert mol_nodes[0].mol.smiles == root_smi
    assert mol_nodes[1].mol.smiles == reactant_smis[0]
    assert finder.search_stats["iterations"] == 1
    assert finder.search_stats["returned_first"]

    # Second run: use an iteration limit instead of return_first
    finder.config.search.return_first = False
    finder.config.search.iteration_limit = 45
    finder.prepare_tree()
    finder.tree_search()

    assert len(finder.tree.mol_nodes) == 4
    # The search stops early because no more nodes can be expanded
    assert finder.search_stats["iterations"] == 2
    assert not finder.search_stats["returned_first"]
67 |
68 |
def test_serialization_deserialization(
    mocker, setup_search_tree, tmpdir, default_config
):
    """A serialized tree can be read back with identical root and children."""
    tree = setup_search_tree
    tree.one_iteration()

    dump_mock = mocker.patch(
        "aizynthfinder.search.retrostar.search_tree.json.dump"
    )
    serializer = MoleculeSerializer()
    filename = str(tmpdir / "dummy.json")

    # Serialization dumps the tree dict together with the molecule store

    tree.serialize(filename)

    expected_dict = {
        "tree": tree.root.serialize(serializer),
        "molecules": serializer.store,
    }

    dump_mock.assert_called_once_with(expected_dict, mocker.ANY, indent=mocker.ANY)

    # Deserialization reconstructs the same tree structure

    mocker.patch(
        "aizynthfinder.search.retrostar.search_tree.json.load",
        return_value=expected_dict,
    )
    mocker.patch(
        "aizynthfinder.search.retrostar.nodes.deserialize_action", return_value=None
    )

    restored = SearchTree.from_json(filename, default_config)

    assert restored.root.mol == tree.root.mol
    assert len(restored.root.children) == len(tree.root.children)
108 |
109 |
def test_split_andor_tree(shared_datadir, default_config):
    """Splitting the stored AND/OR tree yields all 97 synthesis routes."""
    path = str(shared_datadir / "andor_tree_for_clustering.json")
    tree = SearchTree.from_json(path, default_config)

    assert len(tree.routes()) == 97
118 |
119 |
def test_update(shared_datadir, default_config, setup_stock):
    """Updating a mid-tree reaction node re-scores children but not the root value."""
    # Todo: re-write
    setup_stock(
        default_config,
        "Nc1ccc(NC(=S)Nc2ccccc2)cc1",
        "Cc1ccc2nc3ccccc3c(Cl)c2c1",
        "Nc1ccccc1",
        "Nc1ccc(N=C=S)cc1",
        "Cc1ccc2nc3ccccc3c(Br)c2c1",
        "Nc1ccc(Br)cc1",
    )
    tree = SearchTree.from_json(
        str(shared_datadir / "andor_tree_for_clustering.json"), default_config
    )

    saved_root_value = tree.root.value
    tree.mol_nodes[-1].parent.update(35, from_mol=tree.mol_nodes[-1].mol)

    assert [np.round(child.value, 2) for child in tree.root.children][:2] == [
        3.17,
        3.31,
    ]
    assert tree.root.value == saved_root_value
    # Removed stray `tree.serialize("temp.json")`: it wrote an unchecked
    # artifact into the current working directory on every test run
145 |
--------------------------------------------------------------------------------
/tests/retrostar/test_retrostar_cost.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from aizynthfinder.chem import Molecule
4 | from aizynthfinder.search.retrostar.cost import (
5 | MoleculeCost,
6 | RetroStarCost,
7 | ZeroMoleculeCost,
8 | )
9 |
10 |
def test_retrostar_cost(setup_mocked_model):
    """The Retro* neural cost of a simple molecule matches the mocked model."""
    mol = Molecule(smiles="CCCC")
    cost_fn = RetroStarCost(
        model_path="dummy", fingerprint_length=10, dropout_rate=0.0
    )

    assert pytest.approx(cost_fn.calculate(mol), abs=0.001) == 30
16 |
17 |
def test_zero_molecule_cost():
    """The zero cost function returns 0.0 for any molecule."""
    mol = Molecule(smiles="CCCC")

    assert ZeroMoleculeCost().calculate(mol) == 0.0
23 |
24 |
def test_molecule_cost_zero(default_config):
    """MoleculeCost resolves the configured zero-cost implementation."""
    default_config.search.algorithm_config["molecule_cost"] = {
        "cost": "ZeroMoleculeCost"
    }

    cost = MoleculeCost(default_config)(Molecule(smiles="CCCC"))

    assert cost == 0.0
33 |
34 |
def test_molecule_cost_retrostar(default_config, setup_mocked_model):
    """MoleculeCost resolves the configured Retro* cost with its options."""
    default_config.search.algorithm_config["molecule_cost"] = {
        "cost": "aizynthfinder.search.retrostar.cost.RetroStarCost",
        "model_path": "dummy",
        "fingerprint_length": 10,
        "dropout_rate": 0.0,
    }

    cost = MoleculeCost(default_config)(Molecule(smiles="CCCC"))

    assert pytest.approx(cost, abs=0.001) == 30
46 |
--------------------------------------------------------------------------------
/tests/retrostar/test_retrostar_nodes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import networkx as nx
3 |
4 | from aizynthfinder.search.retrostar.nodes import MoleculeNode
5 | from aizynthfinder.chem.serialization import MoleculeSerializer, MoleculeDeserializer
6 | from aizynthfinder.search.andor_trees import ReactionTreeFromAndOrTrace
7 |
8 |
def test_create_root_node(setup_star_root):
    """A fresh Retro* root is expandable with itself as the only ancestor."""
    root = setup_star_root("CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1")

    assert root.target_value == 0
    assert root.expandable
    assert not root.children
    assert root.ancestors() == {root.mol}
16 |
17 |
def test_close_single_node(setup_star_root):
    """Closing an unexpanded root yields an infinite delta and an unsolved node."""
    root = setup_star_root("CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1")
    assert root.expandable

    delta = root.close()

    assert not np.isfinite(delta)
    assert not root.solved
    assert not root.expandable
28 |
29 |
def test_create_stub(setup_star_root, get_action):
    """Adding a stub creates one reaction node with the reactants as children."""
    root = setup_star_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    action = get_action()

    root.add_stub(cost=5.0, reaction=action)

    assert len(root.children) == 1
    rxn_node = root.children[0]
    assert len(rxn_node.children) == 2
    assert rxn_node.reaction is action
    assert [child.mol for child in rxn_node.children] == list(action.reactants[0])
    assert rxn_node.value == rxn_node.target_value == 5
    assert not rxn_node.solved

    # This is done after a node has been expanded
    delta = root.close()

    assert delta == 5.0
    assert root.value == 5.0
    assert rxn_node.children[0].ancestors() == {root.mol, rxn_node.children[0].mol}
52 |
53 |
def test_initialize_stub_one_solved_leaf(
    setup_star_root, get_action, default_config, setup_stock
):
    """With only one reactant in stock, neither reaction node nor root is solved."""
    root = setup_star_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    action = get_action()
    setup_stock(default_config, action.reactants[0][0])

    root.add_stub(cost=5.0, reaction=action)
    root.close()

    rxn_node = root.children[0]
    assert not rxn_node.solved
    assert not root.solved
    assert rxn_node.children[0].solved
68 |
69 |
def test_initialize_stub_two_solved_leafs(
    setup_star_root, get_action, default_config, setup_stock
):
    """With all reactants in stock, both the reaction node and the root are solved."""
    root = setup_star_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    action = get_action()
    setup_stock(default_config, *action.reactants[0])

    root.add_stub(cost=5.0, reaction=action)
    root.close()

    assert root.children[0].solved
    assert root.solved
83 |
84 |
def test_serialization_deserialization(
    setup_star_root, get_action, default_config, setup_stock
):
    """A serialized tree round-trips through MoleculeNode.from_dict unchanged."""
    root = setup_star_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    action = get_action()
    setup_stock(default_config, *action.reactants[0])
    root.add_stub(cost=5.0, reaction=action)
    root.close()

    serializer = MoleculeSerializer()
    serialized = root.serialize(serializer)

    deserializer = MoleculeDeserializer(serializer.store)
    restored = MoleculeNode.from_dict(
        serialized, default_config, deserializer, root.molecule_cost
    )

    assert restored.mol == root.mol
    assert restored.value == root.value
    assert restored.cost == root.cost
    assert len(restored.children) == len(root.children)

    restored_rxn = restored.children[0]
    original_rxn = root.children[0]
    assert restored_rxn.reaction.smarts == action.smarts
    assert restored_rxn.reaction.metadata == action.metadata
    assert restored_rxn.cost == original_rxn.cost
    assert restored_rxn.value == original_rxn.value

    for restored_child, original_child in zip(
        restored_rxn.children, original_rxn.children
    ):
        assert restored_child.mol == original_child.mol
116 |
117 |
def test_conversion_to_reaction_tree(
    setup_star_root, get_action, default_config, setup_stock
):
    """An AND/OR trace converts to a reaction tree with the expected content."""
    root = setup_star_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    action = get_action()
    setup_stock(default_config, *action.reactants[0])
    root.add_stub(cost=5.0, reaction=action)
    root.close()

    rxn_node = root.children[0]
    graph = nx.DiGraph()
    graph.add_edge(root, rxn_node)
    for child in rxn_node.children:
        graph.add_edge(rxn_node, child)

    rt = ReactionTreeFromAndOrTrace(graph, default_config.stock).tree

    molecules = list(rt.molecules())
    rt_reactions = list(rt.reactions())
    assert len(molecules) == 3
    assert len(list(rt.leafs())) == 2
    assert len(rt_reactions) == 1
    assert molecules[0].inchi_key == root.mol.inchi_key
    assert molecules[1].inchi_key == rxn_node.children[0].mol.inchi_key
    assert molecules[2].inchi_key == rxn_node.children[1].mol.inchi_key
    assert rt_reactions[0].reaction_smiles() == action.reaction_smiles()
143 |
--------------------------------------------------------------------------------
/tests/test_expander.py:
--------------------------------------------------------------------------------
1 | from aizynthfinder.aizynthfinder import AiZynthExpander
2 |
3 |
def test_expander_defaults(get_one_step_expansion, setup_policies):
    """Default expansion returns two single-reaction groups with distinct reactants."""
    expander = AiZynthExpander()
    setup_policies(get_one_step_expansion, config=expander.config)
    smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1"

    reactions = expander.do_expansion(smiles)

    assert len(reactions) == 2
    for group in reactions:
        assert len(group) == 1
        assert group[0].mol.smiles == smiles
        assert len(group[0].reactants[0]) == 2
    first = [mol.smiles for mol in reactions[0][0].reactants[0]]
    second = [mol.smiles for mol in reactions[1][0].reactants[0]]
    assert first != second
22 |
23 |
def test_expander_top1(get_one_step_expansion, setup_policies):
    """Limiting the expansion to one result returns the highest-ranked reaction."""
    expander = AiZynthExpander()
    setup_policies(get_one_step_expansion, config=expander.config)

    reactions = expander.do_expansion("CCCCOc1ccc(CC(=O)N(C)O)cc1", return_n=1)

    assert len(reactions) == 1
    reactant_smiles = [mol.smiles for mol in reactions[0][0].reactants[0]]
    assert reactant_smiles == ["CCCCOc1ccc(CC(=O)Cl)cc1", "CNO"]
34 |
35 |
def test_expander_filter(get_one_step_expansion, setup_policies):
    """A custom filter function removes reactions producing unwanted reactants."""

    def filter_func(reaction):
        return "CNO" not in [mol.smiles for mol in reaction.reactants[0]]

    expander = AiZynthExpander()
    setup_policies(get_one_step_expansion, config=expander.config)

    reactions = expander.do_expansion(
        "CCCCOc1ccc(CC(=O)N(C)O)cc1", filter_func=filter_func
    )

    assert len(reactions) == 1
    reactant_smiles = [mol.smiles for mol in reactions[0][0].reactants[0]]
    assert reactant_smiles == ["CCCCBr", "CN(O)C(=O)Cc1ccc(O)cc1"]
49 |
50 |
def test_expander_filter_policy(get_one_step_expansion, setup_policies):
    """A filter policy annotates each expanded reaction with a feasibility score."""
    expander = AiZynthExpander()
    _, filter_strategy = setup_policies(get_one_step_expansion, config=expander.config)
    known_reaction = "CCCCOc1ccc(CC(=O)N(C)O)cc1>>CCCCOc1ccc(CC(=O)Cl)cc1.CNO"
    filter_strategy.lookup[known_reaction] = 0.5

    reactions = expander.do_expansion("CCCCOc1ccc(CC(=O)N(C)O)cc1")

    assert len(reactions) == 2
    assert reactions[0][0].metadata["feasibility"] == 0.5
    assert reactions[1][0].metadata["feasibility"] == 0.0
64 |
--------------------------------------------------------------------------------
/tests/test_gui.py:
--------------------------------------------------------------------------------
1 | from aizynthfinder.analysis.routes import RouteCollection
2 | from aizynthfinder.interfaces.gui.utils import pareto_fronts_plot, route_display
3 | from aizynthfinder.interfaces.gui.pareto_fronts import ParetoFrontsGUI
4 |
5 |
def test_plot_pareto_fronts(setup_mo_scorer, setup_analysis, default_config):
    """Plotting pareto fronts from a multi-objective analysis should not raise."""
    analysis, _ = setup_analysis(scorer=setup_mo_scorer(default_config))

    pareto_fronts_plot(RouteCollection.from_analysis(analysis))
11 |
12 |
def test_display_route(setup_analysis, mocker):
    """route_display should render the selected route via the display hook."""
    display_patch = mocker.patch("aizynthfinder.interfaces.gui.utils.display")
    widget = mocker.MagicMock()
    analysis, _ = setup_analysis()
    routes = RouteCollection.from_analysis(analysis)
    routes.make_images()

    route_display(0, routes, widget)

    display_patch.assert_called()
22 |
23 |
def test_pareto_gui(setup_analysis, mocker, default_config):
    """Instantiating the Pareto GUI should call the display hook."""
    display_patch = mocker.patch("aizynthfinder.interfaces.gui.pareto_fronts.display")
    analysis, _ = setup_analysis()
    routes = RouteCollection.from_analysis(analysis)

    ParetoFrontsGUI(routes.reaction_trees, default_config.scorers)

    display_patch.assert_called()
32 |
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthfinder/8877e4ed61550eaa46186d22c08e56f5e0629e02/tests/utils/__init__.py
--------------------------------------------------------------------------------
/tests/utils/test_bonds.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests to check the functionality of the bonds script.
3 | """
4 | from aizynthfinder.chem import TreeMolecule, SmilesBasedRetroReaction
5 | from aizynthfinder.utils.bonds import BrokenBonds
6 |
7 |
def test_focussed_bonds_broken():
    """BrokenBonds reports which of the focussed bonds the reaction breaks."""
    mol = TreeMolecule(smiles="[CH3:1][NH:2][C:3](C)=[O:4]", parent=None)
    reaction = SmilesBasedRetroReaction(
        mol,
        mapped_prod_smiles="[CH3:1][NH:2][C:3](C)=[O:4]",
        reactants_str="C[C:3](=[O:4])O.[CH3:1][NH:2]",
    )

    broken = BrokenBonds([(1, 2), (3, 4), (3, 2)])(reaction)

    assert broken == [(2, 3)]
20 |
21 |
def test_focussed_bonds_not_broken():
    """No focussed bond is reported when the reaction keeps them all intact."""
    mol = TreeMolecule(smiles="[CH3:1][NH:2][C:3](C)=[O:4]", parent=None)
    reaction = SmilesBasedRetroReaction(
        mol,
        mapped_prod_smiles="[CH3:1][NH:2][C:3](C)=[O:4]",
        reactants_str="C[C:3](=[O:4])O.[CH3:1][NH:2]",
    )
    focussed_bonds = [(1, 2), (3, 4)]

    assert BrokenBonds(focussed_bonds)(reaction) == []
    assert mol.has_all_focussed_bonds(focussed_bonds) is True
35 |
36 |
def test_focussed_bonds_not_in_target_mol():
    """has_all_focussed_bonds is False for a bond absent from the molecule."""
    mol = TreeMolecule(smiles="[CH3:1][NH:2][C:3](C)=[O:4]", parent=None)

    assert mol.has_all_focussed_bonds([(1, 4)]) is False
42 |
--------------------------------------------------------------------------------
/tests/utils/test_dynamic_loading.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from aizynthfinder.utils.loading import load_dynamic_class
4 |
5 |
class MyException(Exception):
    """Custom exception used to exercise the ``exception_cls`` argument."""

    pass
8 |
9 |
def test_simple():
    """A fully qualified dotted path loads the named class."""
    loaded = load_dynamic_class("aizynthfinder.reactiontree.ReactionTree")

    assert loaded.__name__ == "ReactionTree"
14 |
15 |
def test_default_module():
    """A bare class name is resolved against the given default module."""
    loaded = load_dynamic_class(
        "ReactionTree", default_module="aizynthfinder.reactiontree"
    )

    assert loaded.__name__ == "ReactionTree"
22 |
23 |
def test_no_default_module():
    """Omitting the default module raises, using the requested exception class."""
    with pytest.raises(ValueError, match="default_module"):
        load_dynamic_class("ReactionTree")

    with pytest.raises(MyException, match="default_module"):
        load_dynamic_class("ReactionTree", exception_cls=MyException)
30 |
31 |
def test_incorrect_module():
    """A non-existing module raises, naming the offending module in the message."""
    bad_module = "aizynthfinder.rt."
    spec = f"{bad_module}.ReactionTree"

    with pytest.raises(ValueError, match=bad_module):
        load_dynamic_class(spec)

    with pytest.raises(MyException, match=bad_module):
        load_dynamic_class(spec, exception_cls=MyException)
39 |
40 |
def test_incorrect_class():
    """A non-existing class raises, naming the offending class in the message."""
    bad_class = "ReactionTreee"
    spec = f"aizynthfinder.reactiontree.{bad_class}"

    with pytest.raises(ValueError, match=bad_class):
        load_dynamic_class(spec)

    with pytest.raises(MyException, match=bad_class):
        load_dynamic_class(spec, exception_cls=MyException)
50 |
--------------------------------------------------------------------------------
/tests/utils/test_external_tf_models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | import aizynthfinder.utils.models as models
5 | from aizynthfinder.utils.models import (
6 | SUPPORT_EXTERNAL_APIS,
7 | ExternalModelViaGRPC,
8 | ExternalModelViaREST,
9 | )
10 |
11 |
@pytest.fixture()
def setup_rest_mock(mocker):
    """Mock the REST interface of a served model and yield a response setter.

    The returned wrapper installs either a single JSON response or, for a
    list, one response per call on the patched ``requests.request``.

    Fix: the ``@pytest.mark.xfail`` previously applied here was removed —
    pytest marks have no effect on fixtures (pytest warns about them); the
    tests that use this fixture carry the mark themselves.
    """
    models.TF_SERVING_HOST = "localhost"
    models.TF_SERVING_REST_PORT = "255"
    mocked_request = mocker.patch("aizynthfinder.utils.models.requests.request")
    mocked_request.return_value.status_code = 200
    mocked_request.return_value.headers = {"Content-Type": "application/json"}

    def wrapper(response):
        # A list provides one JSON payload per call; anything else is
        # returned for every call
        if isinstance(response, list):
            mocked_request.return_value.json.side_effect = response
        else:
            mocked_request.return_value.json.return_value = response
        return mocked_request

    yield wrapper

    # Restore the module-level globals so other tests see an unconfigured host
    models.TF_SERVING_HOST = None
    models.TF_SERVING_REST_PORT = None
34 |
35 |
@pytest.fixture()
def setup_grpc_mock(mocker, signature_grpc):
    """Mock the gRPC interface of a served model and yield a response setter.

    The returned wrapper optionally installs the tensor outputs returned by
    the mocked prediction service.

    Fix: the ``@pytest.mark.xfail`` previously applied here was removed —
    pytest marks have no effect on fixtures (pytest warns about them); the
    tests that use this fixture carry the mark themselves.
    """
    models.TF_SERVING_HOST = "localhost"
    models.TF_SERVING_GRPC_PORT = "255"
    mocker.patch("aizynthfinder.utils.models.grpc.insecure_channel")
    mocked_pred_service = mocker.patch(
        "aizynthfinder.utils.models.prediction_service_pb2_grpc.PredictionServiceStub"
    )
    mocker.patch(
        "aizynthfinder.utils.models.get_model_metadata_pb2.GetModelMetadataRequest"
    )
    mocker.patch("aizynthfinder.utils.models.predict_pb2.PredictRequest")
    mocked_message = mocker.patch("aizynthfinder.utils.models.MessageToDict")
    mocked_message.return_value = signature_grpc

    def wrapper(response=None):
        if not response:
            return
        mocked_pred_service.return_value.Predict.return_value.outputs = response

    yield wrapper

    # Restore the module-level globals so other tests see an unconfigured host
    models.TF_SERVING_HOST = None
    models.TF_SERVING_GRPC_PORT = None
63 |
64 |
@pytest.fixture()
def signature_rest():
    """Model metadata payload as the REST API returns it (snake_case keys).

    Declares a single input layer of shape 1 x 2048.
    """
    return {
        "metadata": {
            "signature_def": {
                "signature_def": {
                    "serving_default": {
                        "inputs": {
                            "first_layer": {
                                "tensor_shape": {"dim": [{"size": 1}, {"size": 2048}]}
                            }
                        }
                    }
                }
            }
        }
    }
82 |
83 |
@pytest.fixture()
def signature_grpc():
    """Model metadata payload as MessageToDict returns it (camelCase keys).

    Declares a single input layer of shape 1 x 2048 and one named output.
    """
    return {
        "metadata": {
            "signature_def": {
                "signatureDef": {
                    "serving_default": {
                        "inputs": {
                            "first_layer": {
                                "tensorShape": {"dim": [{"size": 1}, {"size": 2048}]}
                            }
                        },
                        "outputs": {"output": None},
                    }
                }
            }
        }
    }
102 |
103 |
@pytest.mark.xfail(
    condition=not SUPPORT_EXTERNAL_APIS, reason="API packages not installed"
)
def test_setup_tf_rest_model(signature_rest, setup_rest_mock):
    """The model length is read from the input-layer tensor shape."""
    setup_rest_mock(signature_rest)

    model = ExternalModelViaREST("dummy")

    assert len(model) == 2048
113 |
114 |
@pytest.mark.xfail(
    condition=not SUPPORT_EXTERNAL_APIS, reason="API packages not installed"
)
def test_predict_tf_rest_model(signature_rest, setup_rest_mock):
    """Predictions are taken from the mocked REST responses."""
    setup_rest_mock([signature_rest, {"outputs": [0.0, 1.0]}])
    model = ExternalModelViaREST("dummy")

    predictions = model.predict(np.zeros([1, len(model)]))

    assert list(predictions) == [0.0, 1.0]
126 |
127 |
@pytest.mark.xfail(
    condition=not SUPPORT_EXTERNAL_APIS, reason="API packages not installed"
)
def test_setup_tf_grpc_model(setup_grpc_mock):
    """The model length is read from the mocked gRPC metadata."""
    setup_grpc_mock()

    model = ExternalModelViaGRPC("dummy")

    assert len(model) == 2048
137 |
138 |
@pytest.mark.xfail(
    condition=not SUPPORT_EXTERNAL_APIS,
    reason="Tensorflow and API packages not installed",
)
def test_predict_tf_grpc_model(setup_grpc_mock):
    """Predictions are decoded from the mocked gRPC tensor response."""
    import tensorflow as tf

    setup_grpc_mock({"output": tf.make_tensor_proto(tf.constant([0.0, 1.0]))})
    model = ExternalModelViaGRPC("dummy")

    predictions = model.predict(np.zeros([1, len(model)]))

    assert list(predictions) == [0.0, 1.0]
152 |
--------------------------------------------------------------------------------
/tests/utils/test_file_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gzip
3 | import json
4 |
5 | import pytest
6 | import pandas as pd
7 |
8 | from aizynthfinder.utils.files import (
9 | cat_datafiles,
10 | split_file,
11 | start_processes,
12 | read_datafile,
13 | save_datafile,
14 | )
15 |
16 |
@pytest.fixture
def create_dummy_file(tmpdir, mocker):
    """Return a factory that writes an input file for split_file.

    ``tempfile.mktemp`` is patched so that the three split parts land on
    predictable paths, which the factory returns alongside the input file.
    """
    patched_tempfile = mocker.patch("aizynthfinder.utils.files.tempfile.mktemp")
    split_files = [tmpdir / f"split{idx}" for idx in (1, 2, 3)]
    patched_tempfile.side_effect = split_files
    input_filename = tmpdir / "input"

    def wrapper(content):
        with open(input_filename, "w") as fileobj:
            fileobj.write(content)
        return input_filename, split_files

    return wrapper
34 |
35 |
def test_split_file_even(create_dummy_file):
    """Six lines split over three files give two lines per part."""
    filename, split_files = create_dummy_file("\n".join(list("abcdef")))

    split_file(filename, 3)

    contents = []
    for part in split_files:
        assert os.path.exists(part)
        with open(part, "r") as fileobj:
            contents.append(fileobj.read())
    assert contents == ["a\nb", "c\nd", "e\nf"]
49 |
50 |
def test_split_file_odd(create_dummy_file):
    """Seven lines split over three files put the extra line in the first part."""
    filename, split_files = create_dummy_file("\n".join(list("abcdefg")))

    split_file(filename, 3)

    contents = []
    for part in split_files:
        assert os.path.exists(part)
        with open(part, "r") as fileobj:
            contents.append(fileobj.read())
    assert contents == ["a\nb\nc", "d\ne", "f\ng"]
64 |
65 |
def test_start_processes(tmpdir):
    """start_processes runs one subprocess per input and captures its stdout.

    Bug fix: ``create_cmd`` previously echoed the hard-coded placeholder
    ``"(unknown)"`` instead of the ``filename`` argument, so the log files
    could never contain the asserted ``dummy-<index>`` lines.
    """
    script_filename = str(tmpdir / "dummy.py")
    with open(script_filename, "w") as fileobj:
        fileobj.write("import sys\nimport time\nprint(sys.argv[1])\ntime.sleep(2)\n")

    def create_cmd(index, filename):
        # Echo the input filename together with the 1-based process index
        return ["python", script_filename, f"{filename}-{index}"]

    start_processes(["dummy", "dummy"], str(tmpdir / "log"), create_cmd, 2)

    for index in [1, 2]:
        logfile = str(tmpdir / f"log{index}.log")
        assert os.path.exists(logfile)
        with open(logfile, "r") as fileobj:
            assert fileobj.read() == f"dummy-{index}\n"
83 |
84 |
def test_cat_hdf(create_dummy_stock1, create_dummy_stock2, tmpdir):
    """Concatenating two stock files keeps all rows in input order."""
    filename = str(tmpdir / "output.hdf")

    cat_datafiles([create_dummy_stock1("hdf5"), create_dummy_stock2], filename)

    data = pd.read_hdf(filename, "table")
    assert len(data) == 4
    assert list(data.inchi_key.values) == [
        "UHOVQNZJYSORNB-UHFFFAOYSA-N",
        "YXFVVABEGXRONW-UHFFFAOYSA-N",
        "UHOVQNZJYSORNB-UHFFFAOYSA-N",
        "ISWSIDIOOBJBQZ-UHFFFAOYSA-N",
    ]
99 |
100 |
def test_cat_hdf_no_trees(tmpdir, create_dummy_stock1, create_dummy_stock2):
    """No tree file is written when the inputs contain no trees column."""
    hdf_filename = str(tmpdir / "output.hdf")
    tree_filename = str(tmpdir / "trees.json")

    cat_datafiles(
        [create_dummy_stock1("hdf5"), create_dummy_stock2], hdf_filename, tree_filename
    )

    assert not os.path.exists(tree_filename)
109 |
110 |
def test_cat_hdf_trees(tmpdir):
    """Trees columns are concatenated into a gzipped JSON side-car file."""
    hdf_filename = str(tmpdir / "output.hdf")
    tree_filename = str(tmpdir / "trees.json")
    trees1 = [[1], [2]]
    trees2 = [[3], [4]]
    filenames = [str(tmpdir / "file1.hdf5"), str(tmpdir / "file2.hdf5")]
    for path, trees in zip(filenames, [trees1, trees2]):
        pd.DataFrame({"mol": ["A", "B"], "trees": trees}).to_hdf(path, "table")

    cat_datafiles(filenames, hdf_filename, tree_filename)

    gzipped_filename = tree_filename + ".gz"
    assert os.path.exists(gzipped_filename)
    with gzip.open(gzipped_filename, "rt", encoding="UTF-8") as fileobj:
        assert json.load(fileobj) == trees1 + trees2
    assert "trees" not in pd.read_hdf(hdf_filename, "table")
128 |
129 |
@pytest.mark.parametrize(
    ("filename"),
    [
        ("temp.json"),
        ("temp.hdf5"),
    ],
)
def test_save_load_datafile_roundtrip(filename, tmpdir):
    """Data written with save_datafile is read back unchanged for both formats."""
    original = pd.DataFrame({"a": [0, 1, 2], "b": [2, 3, 4]})
    path = tmpdir / filename

    save_datafile(original, path)
    restored = read_datafile(path)

    assert original.columns.to_list() == restored.columns.to_list()
    assert original.a.to_list() == restored.a.to_list()
    assert original.b.to_list() == restored.b.to_list()
147 |
--------------------------------------------------------------------------------
/tests/utils/test_image.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import json
4 | from tarfile import TarFile
5 | from pathlib import Path
6 |
7 | import pytest
8 | from PIL import Image, ImageDraw
9 |
10 | from aizynthfinder.utils import image
11 | from aizynthfinder.chem import TreeMolecule, TemplatedRetroReaction
12 |
13 |
@pytest.fixture
def new_image():
    """A white 300x300 RGB image with two black pixels, for crop tests."""
    img = Image.new(mode="RGB", size=(300, 300), color="white")
    ImageDraw.Draw(img).point([5, 150, 100, 250], fill="black")
    return img
20 |
21 |
@pytest.fixture
def setup_graph():
    """A minimal route graph: one molecule, one reaction, one edge, one color."""
    molecule = TreeMolecule(smiles="CCCO", parent=None)
    reaction = TemplatedRetroReaction(mol=molecule, smarts="")
    return [molecule], [reaction], [(molecule, reaction)], ["green"]
28 |
29 |
def test_crop(new_image):
    """Cropping trims the canvas down around the drawn pixels."""
    cropped = image.crop_image(new_image)

    assert (cropped.width, cropped.height) == (136, 141)
    assert cropped.getpixel((21, 21)) == (0, 0, 0)
    assert cropped.getpixel((116, 121)) == (0, 0, 0)
37 |
38 |
def test_rounded_rectangle(new_image):
    """The rectangle outline touches the mid-point of each image edge."""
    color = (255, 0, 0)
    modified = image.draw_rounded_rectangle(new_image, color=color)

    for pixel in [(0, 150), (150, 0), (299, 150), (150, 299)]:
        assert modified.getpixel(pixel) == color
47 |
48 |
def test_save_molecule_images():
    """Saving images adds one file per unique molecule; repeats add none."""
    nfiles_before = len(os.listdir(image.IMAGE_FOLDER))
    mols = [
        TreeMolecule(smiles="CCCO", parent=None),
        TreeMolecule(smiles="CCCO", parent=None),
        TreeMolecule(smiles="CCCCO", parent=None),
    ]

    # Two unique molecules in the list -> two new image files
    image.save_molecule_images(mols, ["green", "green", "green"])
    assert len(os.listdir(image.IMAGE_FOLDER)) == nfiles_before + 2

    # Saving again with a different frame color adds no further files
    image.save_molecule_images(mols, ["green", "orange", "green"])
    assert len(os.listdir(image.IMAGE_FOLDER)) == nfiles_before + 2
65 |
66 |
def test_visjs_page(mocker, tmpdir, setup_graph):
    """make_visjs_page produces a tar archive with the HTML page and images."""
    mkdtemp_patch = mocker.patch("aizynthfinder.utils.image.tempfile.mkdtemp")
    mkdtemp_patch.return_value = str(tmpdir / "tmp")
    os.mkdir(tmpdir / "tmp")
    molecules, reactions, edges, frame_colors = setup_graph
    filename = str(tmpdir / "arch.tar")

    image.make_visjs_page(filename, molecules, reactions, edges, frame_colors)

    assert os.path.exists(filename)
    with TarFile(filename) as tarobj:
        names = tarobj.getnames()
    assert "./route.html" in names
    assert len([name for name in names if name.endswith(".png")]) == 1
80 |
81 |
def test_image_factory(request):
    """RouteImageFactory honours the margin and show_all options."""
    route_path = Path(request.fspath).parent.parent / "data" / "branched_route.json"
    with open(route_path, "r") as fileobj:
        route_dict = json.load(fileobj)
    route_dict["children"][0]["children"][1]["hide"] = True

    baseline = image.RouteImageFactory(route_dict)

    # A smaller margin shrinks the canvas by a fixed number of pixels
    tighter = image.RouteImageFactory(route_dict, margin=50)
    assert baseline.image.width == tighter.image.width + 150
    assert baseline.image.height == tighter.image.height + 175

    # Hiding marked branches reduces the height but not the width
    hidden = image.RouteImageFactory(route_dict, show_all=False)
    assert baseline.image.width == hidden.image.width
    assert baseline.image.height > hidden.image.height
97 |
--------------------------------------------------------------------------------
/tests/utils/test_local_onnx_model.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | import numpy as np
4 | import pytest
5 | import pytest_mock
6 | from aizynthfinder.utils import models
7 |
8 |
def test_local_onnx_model_predict(mock_onnx_model: pytest_mock.MockerFixture) -> None:
    """predict returns the mocked session output unchanged."""
    model = models.LocalOnnxModel("test_model.onnx")

    output = model.predict(np.array([1]))

    assert np.array_equal(output, np.array([[0.2, 0.7, 0.1]]))
15 |
16 |
def test_local_onnx_model_length(mock_onnx_model: pytest_mock.MockerFixture) -> None:
    """len() reflects the size of the mocked model output vector."""
    model = models.LocalOnnxModel("test_model.onnx")

    assert len(model) == 3
23 |
24 |
def test_local_onnx_model_output_size(
    mock_onnx_model: pytest_mock.MockerFixture,
) -> None:
    """output_size matches the mocked model output dimension."""
    model = models.LocalOnnxModel("test_model.onnx")

    assert model.output_size == 3
33 |
--------------------------------------------------------------------------------
/tests/utils/test_scscore.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | import pytest
4 | from rdkit import Chem
5 |
6 | from aizynthfinder.utils.sc_score import SCScore
7 |
# Dummy, tiny SCScore model
# Two layers: a 5x2 input-layer weight matrix and a 2x2 hidden-layer matrix
_weights0 = [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]
_weights1 = [[1, 1], [1, 1]]
_weights = [_weights0, _weights1]
# All-zero biases for the two layers, so only the weights affect the score
_biases = [[0, 0], [0]]
13 |
14 |
def test_scscore(tmpdir):
    """SCScore loads a pickled (weights, biases) model and scores a molecule."""
    filename = str(tmpdir / "dummy.pickle")
    with open(filename, "wb") as fileobj:
        pickle.dump((_weights, _biases), fileobj)
    scorer = SCScore(filename, 5)

    score = scorer(Chem.MolFromSmiles("C"))

    assert pytest.approx(score, abs=1e-3) == 4.523
23 |
--------------------------------------------------------------------------------