├── syntheseus ├── py.typed ├── cli │ ├── __init__.py │ ├── main.py │ └── search_config.yml ├── tests │ ├── __init__.py │ ├── cli │ │ ├── __init__.py │ │ ├── test_search.py │ │ └── test_cli.py │ ├── interface │ │ ├── __init__.py │ │ ├── test_bag.py │ │ ├── test_toy_models.py │ │ ├── test_reaction.py │ │ └── test_molecule.py │ ├── search │ │ ├── __init__.py │ │ ├── graph │ │ │ ├── __init__.py │ │ │ ├── test_route.py │ │ │ ├── test_base.py │ │ │ ├── test_standardization.py │ │ │ ├── test_message_passing.py │ │ │ └── test_molset.py │ │ ├── algorithms │ │ │ ├── __init__.py │ │ │ └── test_random.py │ │ ├── analysis │ │ │ ├── __init__.py │ │ │ ├── test_solution_time.py │ │ │ ├── conftest.py │ │ │ └── test_diversity.py │ │ ├── node_evaluation │ │ │ ├── __init__.py │ │ │ └── test_common.py │ │ ├── test_visualization.py │ │ └── test_mol_inventory.py │ ├── reaction_prediction │ │ ├── __init__.py │ │ ├── chem │ │ │ ├── __init__.py │ │ │ └── test_utils.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ └── test_dataset.py │ │ ├── inference │ │ │ ├── __init__.py │ │ │ └── test_models.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── test_misc.py │ │ │ └── test_parallel.py │ └── conftest.py ├── interface │ ├── __init__.py │ ├── bag.py │ ├── reaction.py │ └── molecule.py ├── search │ ├── analysis │ │ ├── __init__.py │ │ └── solution_time.py │ ├── graph │ │ ├── __init__.py │ │ ├── message_passing │ │ │ ├── __init__.py │ │ │ └── update_functions.py │ │ ├── route.py │ │ ├── node.py │ │ └── base_graph.py │ ├── utils │ │ ├── __init__.py │ │ └── misc.py │ ├── algorithms │ │ ├── __init__.py │ │ ├── mcts │ │ │ ├── __init__.py │ │ │ └── molset.py │ │ ├── best_first │ │ │ └── __init__.py │ │ ├── mixins.py │ │ ├── breadth_first.py │ │ └── random.py │ ├── __init__.py │ ├── node_evaluation │ │ ├── __init__.py │ │ └── common.py │ └── mol_inventory.py ├── reaction_prediction │ ├── __init__.py │ ├── chem │ │ ├── __init__.py │ │ └── utils.py │ ├── data │ │ ├── __init__.py │ │ ├── reaction_sample.py │ │ 
└── dataset.py │ ├── utils │ │ ├── __init__.py │ │ ├── testing.py │ │ ├── metrics.py │ │ ├── model_loading.py │ │ ├── parallel.py │ │ ├── config.py │ │ ├── downloading.py │ │ ├── inference.py │ │ └── misc.py │ ├── models │ │ ├── __init__.py │ │ └── retro_knn.py │ ├── environment_gln │ │ ├── environment.yml │ │ └── Dockerfile │ └── inference │ │ ├── default_checkpoint_links.yml │ │ ├── __init__.py │ │ ├── config.py │ │ ├── base.py │ │ ├── gln.py │ │ ├── toy_models.py │ │ ├── graph2edits.py │ │ ├── retro_knn.py │ │ └── local_retro.py └── __init__.py ├── docs ├── images │ ├── logo.png │ └── logo.svg ├── tutorials │ └── .gitignore ├── index.md ├── cli │ ├── search.md │ └── eval_single_step.md ├── single_step.md └── installation.md ├── environment.yml ├── CODE_OF_CONDUCT.md ├── .gitignore ├── SUPPORT.md ├── .coveragerc ├── .github ├── azure_pipelines │ └── code-security-analysis.yml └── workflows │ ├── release.yml │ ├── docs.yml │ └── ci.yml ├── environment_full.yml ├── LICENSE ├── DEVELOPMENT.md ├── mkdocs.yml ├── .pre-commit-config.yaml ├── SECURITY.md ├── pyproject.toml └── README.md /syntheseus/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/cli/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/interface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/tests/cli/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /syntheseus/search/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/search/graph/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/search/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/tests/interface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/tests/search/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/search/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/tests/search/graph/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/chem/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/syntheseus/reaction_prediction/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/search/algorithms/mcts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/tests/reaction_prediction/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/tests/search/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/tests/search/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/search/algorithms/best_first/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/tests/reaction_prediction/chem/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/syntheseus/tests/reaction_prediction/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/tests/search/node_evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/tests/reaction_prediction/inference/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /syntheseus/tests/reaction_prediction/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/syntheseus/HEAD/docs/images/logo.png -------------------------------------------------------------------------------- /docs/tutorials/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore files created as outputs from the tutorials 2 | *.pdf 3 | *.pkl 4 | -------------------------------------------------------------------------------- /syntheseus/search/utils/misc.py: -------------------------------------------------------------------------------- 1 | def lookup_by_name(module, name): 2 | return module.__dict__[name] 3 | -------------------------------------------------------------------------------- /syntheseus/search/__init__.py: -------------------------------------------------------------------------------- 1 | INT_INF = int(1e17) # integer large enough to practically be infinity, but less than 2^63 - 1 2 | -------------------------------------------------------------------------------- 
/syntheseus/search/node_evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from syntheseus.search.node_evaluation.base import BaseNodeEvaluator 2 | 3 | __all__ = ["BaseNodeEvaluator"] 4 | -------------------------------------------------------------------------------- /syntheseus/tests/search/analysis/test_solution_time.py: -------------------------------------------------------------------------------- 1 | """This function is tested minimally in the algorithm tests, so there are no specific tests at this time.""" 2 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: syntheseus 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - numpy 7 | - pip # ideally >=25.2 unless using python 3.8 8 | - python>=3.8 # ideally 3.9+ 9 | - rdkit 10 | -------------------------------------------------------------------------------- /syntheseus/search/graph/message_passing/__init__.py: -------------------------------------------------------------------------------- 1 | from syntheseus.search.graph.message_passing.run import run_message_passing 2 | from syntheseus.search.graph.message_passing.update_functions import ( 3 | depth_update, 4 | has_solution_update, 5 | ) 6 | 7 | __all__ = ["run_message_passing", "has_solution_update", "depth_update"] 8 | -------------------------------------------------------------------------------- /syntheseus/tests/search/graph/test_route.py: -------------------------------------------------------------------------------- 1 | """Route objects are tested implicity in other tests, so there are only minimal tests for now.""" 2 | 3 | from syntheseus.interface.molecule import Molecule 4 | 5 | 6 | def test_route_starting_molecules(minimal_synthesis_graph): 7 | assert minimal_synthesis_graph.get_starting_molecules() == {Molecule("CC")} 8 | 
-------------------------------------------------------------------------------- /syntheseus/search/algorithms/mcts/molset.py: -------------------------------------------------------------------------------- 1 | from syntheseus.search.algorithms.base import MolSetSearchAlgorithm 2 | from syntheseus.search.algorithms.mcts.base import BaseMCTS 3 | from syntheseus.search.graph.molset import MolSetGraph, MolSetNode 4 | 5 | 6 | class MolSetMCTS(BaseMCTS[MolSetGraph, MolSetNode, MolSetNode], MolSetSearchAlgorithm[int]): 7 | pass 8 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/environment_gln/environment.yml: -------------------------------------------------------------------------------- 1 | name: gln-env 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | - pyg 8 | dependencies: 9 | - python=3.9 10 | - pytorch=1.13.0=py3.9_cuda11.6_cudnn8.3.2_0 11 | - scipy 12 | - tqdm 13 | - boost 14 | - boost-cpp 15 | - cairo 16 | - cmake 17 | - eigen 18 | - gxx_linux-64 19 | - mkl<2024.1.0 20 | - pillow 21 | - pkg-config 22 | - py-boost 23 | - pyg==2.1.0 24 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /syntheseus/__init__.py: -------------------------------------------------------------------------------- 1 | from syntheseus.interface.bag import Bag 2 | from syntheseus.interface.models import BackwardReactionModel, ForwardReactionModel, ReactionModel 3 | from syntheseus.interface.molecule import Molecule 4 | from syntheseus.interface.reaction import Reaction, SingleProductReaction 5 | 6 | __all__ = [ 7 | "Molecule", 8 | "Reaction", 9 | "SingleProductReaction", 10 | "Bag", 11 | "ReactionModel", 12 | "BackwardReactionModel", 13 | "ForwardReactionModel", 14 | ] 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # Notebook checkpoints: 30 | .ipynb_checkpoints 31 | 32 | # Unit test / coverage reports 33 | .coverage 34 | .coverage.* 35 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. 
Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 8 | 9 | For help and questions about using this project, please file a GitHub Issue or email us. 10 | We will try to respond to all issues. 11 | 12 | ## Microsoft Support Policy 13 | 14 | Support for this project is limited to the resources listed above. 15 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = syntheseus 4 | omit = syntheseus/tests/* 5 | 6 | [report] 7 | # Regexes for lines to exclude from consideration 8 | exclude_lines = 9 | # Have to re-enable the standard pragma 10 | pragma: no cover 11 | 12 | # Don't complain about missing debug-only code: 13 | def __repr__ 14 | if self\.debug 15 | 16 | # Don't complain if tests don't hit defensive assertion code: 17 | raise AssertionError 18 | raise NotImplementedError 19 | 20 | # Don't complain about abstract methods, they aren't run: 21 | @(abc\.)?abstractmethod 22 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/inference/default_checkpoint_links.yml: -------------------------------------------------------------------------------- 1 | backward: 2 | Chemformer: https://figshare.com/ndownloader/files/42009888 3 | GLN: https://figshare.com/ndownloader/files/45882867 4 | Graph2Edits: https://figshare.com/ndownloader/files/44194301 5 | LocalRetro: https://figshare.com/ndownloader/files/42287319 6 | MEGAN: https://figshare.com/ndownloader/files/42012732 7 | MHNreact: https://figshare.com/ndownloader/files/42012777 8 | RetroKNN: https://figshare.com/ndownloader/files/45662430 9 | RootAligned: https://figshare.com/ndownloader/files/42012792 10 | forward: 11 | Chemformer: https://figshare.com/ndownloader/files/42012708 12 | 
-------------------------------------------------------------------------------- /syntheseus/tests/search/graph/test_base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class BaseNodeTest(abc.ABC): 5 | """Base class which defines common tests for graph nodes.""" 6 | 7 | @abc.abstractmethod 8 | def get_node(self): 9 | pass 10 | 11 | def test_node_comparison(self): 12 | """Test that nodes are equal if and only if they are the same object.""" 13 | node1 = self.get_node() 14 | node2 = self.get_node() 15 | assert node1 is not node2 16 | assert node1 != node2 17 | 18 | @abc.abstractmethod 19 | def test_nodes_not_frozen(self): 20 | """Test that the fields of the node can be modified.""" 21 | -------------------------------------------------------------------------------- /.github/azure_pipelines/code-security-analysis.yml: -------------------------------------------------------------------------------- 1 | schedules: 2 | - cron: '0 4 * * 0' 3 | displayName: Weekly build 4 | branches: 5 | include: 6 | - main 7 | 8 | pr: none 9 | trigger: none 10 | 11 | pool: 12 | vmImage: 'windows-latest' 13 | 14 | steps: 15 | - task: CredScan@2 16 | inputs: 17 | toolMajorVersion: 'V2' 18 | - task: ComponentGovernanceComponentDetection@0 19 | inputs: 20 | scanType: 'Register' 21 | verbosity: 'Verbose' 22 | alertWarningLevel: 'High' 23 | - task: PublishSecurityAnalysisLogs@2 24 | inputs: 25 | ArtifactName: 'CodeAnalysisLogs' 26 | ArtifactType: 'Container' 27 | AllTools: true 28 | ToolLogsNotFoundAction: 'Standard' 29 | -------------------------------------------------------------------------------- /syntheseus/tests/interface/test_bag.py: -------------------------------------------------------------------------------- 1 | from syntheseus.interface.bag import Bag 2 | 3 | 4 | def test_basic_operations() -> None: 5 | bag_1 = Bag(["a", "b", "a"]) 6 | bag_2 = Bag(["b", "a", "a"]) 7 | bag_3 = Bag(["a", "b"]) 8 | 9 | # Test 
`__contains__`. 10 | assert "a" in bag_1 11 | assert "b" in bag_1 12 | assert "c" not in bag_1 13 | 14 | # Test `__eq__`. 15 | assert bag_1 == bag_2 16 | assert bag_1 != bag_3 17 | 18 | # Test `__iter__`. 19 | assert Bag(bag_1) == bag_1 20 | 21 | # Test `__len__`. 22 | assert len(bag_1) == 3 23 | assert len(bag_3) == 2 24 | 25 | # Test `__hash__`. 26 | assert len(set([bag_1, bag_2, bag_3])) == 2 27 | -------------------------------------------------------------------------------- /syntheseus/tests/reaction_prediction/utils/test_misc.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | import pytest 4 | 5 | from syntheseus.reaction_prediction.utils.misc import parallelize 6 | 7 | 8 | def square(x: int) -> int: 9 | return x * x 10 | 11 | 12 | @pytest.mark.parametrize("use_iterator", [False, True]) 13 | @pytest.mark.parametrize("num_processes", [0, 2]) 14 | def test_parallelize(use_iterator: bool, num_processes: int) -> None: 15 | inputs: Iterable[int] = [1, 2, 3, 4] 16 | expected_outputs = [square(x) for x in inputs] 17 | 18 | if use_iterator: 19 | inputs = iter(inputs) 20 | 21 | assert list(parallelize(square, inputs, num_processes=num_processes)) == expected_outputs 22 | -------------------------------------------------------------------------------- /environment_full.yml: -------------------------------------------------------------------------------- 1 | name: syntheseus-full 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - pyg # RetroKNN 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - faiss-gpu # RetroKNN 10 | - numpy 11 | - pip>=25.2 12 | - python==3.9.7 13 | - pytorch=2.2.2=py3.9_cuda12.1_cudnn8.9.2_0 14 | - pytorch-lightning==2.2.2 # Chemformer 15 | - pytorch-scatter==2.1.2 # RetroKNN 16 | - rdchiral_cpp # MHNreact 17 | - rdkit=2021.09.4 18 | - pip: 19 | - --find-links https://data.dgl.ai/wheels/torch-2.2/cu121/repo.html 20 | - dgl==2.4.0 # LocalRetro, RetroKNN 21 | 
-------------------------------------------------------------------------------- /syntheseus/reaction_prediction/utils/testing.py: -------------------------------------------------------------------------------- 1 | def are_single_step_models_installed(): 2 | """Check whether the single-step models can be imported successfully.""" 3 | try: 4 | # Try to import the single-step model repositories to check if they are installed. It could 5 | # be the case that these are installed but their dependencies are not, in which case trying 6 | # to *use* the models would fail; nevertheless, the below is good enough for our usecases. 7 | 8 | import chemformer # noqa: F401 9 | import graph2edits # noqa: F401 10 | import local_retro # noqa: F401 11 | import megan # noqa: F401 12 | import mhnreact # noqa: F401 13 | import root_aligned # noqa: F401 14 | 15 | return True 16 | except ModuleNotFoundError: 17 | return False 18 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release Stable Version 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | version: 7 | required: true 8 | type: string 9 | 10 | permissions: 11 | contents: write 12 | 13 | jobs: 14 | push-tag: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Configure Git 19 | run: | 20 | git config user.name github-actions[bot] 21 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 22 | - run: | 23 | git tag -a v${{ inputs.version }} -m "Release v${{ inputs.version }}" 24 | git push origin v${{ inputs.version }} 25 | build-docs: 26 | needs: push-tag 27 | uses: ./.github/workflows/docs.yml 28 | with: 29 | versions: ${{ inputs.version }} stable 30 | -------------------------------------------------------------------------------- /syntheseus/tests/search/graph/test_standardization.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | At the moment, only a smoke-test is done for graph standardization, 3 | but the exact behaviour is not well-tested. 4 | These tests should be added in the future. 5 | """ 6 | 7 | import pytest 8 | 9 | from syntheseus.search.graph.and_or import AndOrGraph 10 | from syntheseus.search.graph.molset import MolSetGraph 11 | from syntheseus.search.graph.standardization import get_unique_node_andor_graph 12 | 13 | 14 | def test_smoke_andor(andor_graph_non_minimal: AndOrGraph): 15 | with pytest.warns(UserWarning): 16 | output = get_unique_node_andor_graph(andor_graph_non_minimal) 17 | 18 | assert len(output) == len(andor_graph_non_minimal) # no nodes deleted here 19 | 20 | 21 | def test_smoke_molset(molset_tree_non_minimal: MolSetGraph): 22 | with pytest.warns(UserWarning): 23 | get_unique_node_andor_graph(molset_tree_non_minimal) 24 | -------------------------------------------------------------------------------- /syntheseus/cli/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Callable, Dict 3 | 4 | from syntheseus.cli import eval_single_step, search 5 | 6 | 7 | def main() -> None: 8 | supported_commands: Dict[str, Callable] = { 9 | "search": search.main, 10 | "eval-single-step": eval_single_step.main, 11 | } 12 | supported_command_names = ", ".join(supported_commands.keys()) 13 | 14 | if len(sys.argv) == 1: 15 | raise ValueError(f"Please choose a command from: {supported_command_names}") 16 | 17 | command = sys.argv[1] 18 | if command not in supported_commands: 19 | raise ValueError(f"Command {command} not supported; choose from: {supported_command_names}") 20 | 21 | # Drop the subcommand name and let the chosen command parse the rest of the arguments. 
22 | del sys.argv[1] 23 | supported_commands[command]() 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /syntheseus/search/graph/message_passing/update_functions.py: -------------------------------------------------------------------------------- 1 | from syntheseus.search.graph.base_graph import RetrosynthesisSearchGraph 2 | from syntheseus.search.graph.node import BaseGraphNode 3 | 4 | 5 | def depth_update(node: BaseGraphNode, graph: RetrosynthesisSearchGraph) -> bool: 6 | parent_depths = [n.depth for n in graph.predecessors(node)] 7 | if len(parent_depths) == 0: 8 | new_depth = 0 9 | else: 10 | new_depth = min(parent_depths) + 1 11 | 12 | depth_changed = node.depth != new_depth 13 | node.depth = new_depth 14 | return depth_changed 15 | 16 | 17 | def has_solution_update(node: BaseGraphNode, graph: RetrosynthesisSearchGraph) -> bool: 18 | new_has_solution = node._has_intrinsic_solution() or node._has_solution_from_children( 19 | list(graph.successors(node)) 20 | ) 21 | 22 | old_has_solution = node.has_solution 23 | node.has_solution = new_has_solution 24 | return new_has_solution != old_has_solution 25 | -------------------------------------------------------------------------------- /syntheseus/search/algorithms/mixins.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Generic, TypeVar 4 | 5 | from syntheseus.search.algorithms.base import SearchAlgorithm 6 | from syntheseus.search.graph.node import BaseGraphNode 7 | from syntheseus.search.node_evaluation import BaseNodeEvaluator 8 | 9 | NodeType = TypeVar("NodeType", bound=BaseGraphNode) 10 | 11 | 12 | class ValueFunctionMixin(SearchAlgorithm, Generic[NodeType]): 13 | def __init__(self, *args, value_function: BaseNodeEvaluator[NodeType], **kwargs): 14 | super().__init__(*args, **kwargs) 15 | self.value_function = 
value_function 16 | 17 | def set_node_values(self, nodes, graph): 18 | output_nodes = super().set_node_values(nodes, graph) 19 | for node in output_nodes: 20 | node.data.setdefault("num_calls_value_function", self.value_function.num_calls) 21 | return output_nodes 22 | 23 | def reset(self) -> None: 24 | super().reset() 25 | self.value_function.reset() 26 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from syntheseus.reaction_prediction.inference.chemformer import ChemformerModel 2 | from syntheseus.reaction_prediction.inference.gln import GLNModel 3 | from syntheseus.reaction_prediction.inference.graph2edits import Graph2EditsModel 4 | from syntheseus.reaction_prediction.inference.local_retro import LocalRetroModel 5 | from syntheseus.reaction_prediction.inference.megan import MEGANModel 6 | from syntheseus.reaction_prediction.inference.mhnreact import MHNreactModel 7 | from syntheseus.reaction_prediction.inference.retro_knn import RetroKNNModel 8 | from syntheseus.reaction_prediction.inference.root_aligned import RootAlignedModel 9 | from syntheseus.reaction_prediction.inference.toy_models import ( 10 | LinearMoleculesToyModel, 11 | ListOfReactionsToyModel, 12 | ) 13 | 14 | __all__ = [ 15 | "ChemformerModel", 16 | "GLNModel", 17 | "Graph2EditsModel", 18 | "LinearMoleculesToyModel", 19 | "ListOfReactionsToyModel", 20 | "LocalRetroModel", 21 | "MEGANModel", 22 | "MHNreactModel", 23 | "RetroKNNModel", 24 | "RootAlignedModel", 25 | ] 26 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Docs 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | workflow_call: 7 | inputs: 8 | versions: 9 | required: true 10 | type: string 11 | workflow_dispatch: 12 | 
13 | permissions: 14 | contents: write 15 | 16 | jobs: 17 | deploy: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Configure Git 22 | run: | 23 | git config user.name github-actions[bot] 24 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 25 | git fetch origin gh-pages --depth=1 26 | - uses: actions/setup-python@v4 27 | with: 28 | python-version: 3.x 29 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 30 | - uses: actions/cache@v3 31 | with: 32 | key: mkdocs-material-${{ env.cache_id }} 33 | path: .cache 34 | restore-keys: | 35 | mkdocs-material- 36 | - run: pip install mkdocs-material mkdocs-jupyter mike 37 | - run: mike deploy --push --update-aliases ${{ inputs.versions != '' && inputs.versions || 'dev' }} 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 |
2 | ![Image title](https://github.com/microsoft/syntheseus/assets/61470923/f01a9939-61fa-4461-a124-c13eddcdd75a){width="450"} 3 |

Navigating the labyrinth of synthesis planning

4 |
class Comparable(Protocol):
    """Structural type for elements that support `<` (required so bags can be sorted)."""

    def __lt__(self, __other) -> bool:
        ...


ElementT = TypeVar("ElementT", bound=Comparable)


class Bag(Collection, Generic[ElementT]):
    """Class representing a frozen multi-set (i.e. set where elements are allowed to repeat).

    The bag elements are internally stored as a sorted tuple for simplicity, thus lookup time is
    linear in terms of the size of the bag. Sorting also makes equality, hashing and `repr`
    independent of the order in which the elements were provided.
    """

    def __init__(self, values: Iterable[ElementT]) -> None:
        # Sort once upfront so that equal multi-sets share an identical canonical representation.
        self._values = tuple(sorted(values))

    def __iter__(self) -> Iterator[ElementT]:
        return iter(self._values)

    def __contains__(self, element) -> bool:
        # Linear scan over the tuple; acceptable because bags are typically small.
        return element in self._values

    def __eq__(self, other) -> bool:
        if isinstance(other, Bag):
            return self._values == other._values
        # Return `NotImplemented` (rather than `False`) so Python falls back to the reflected
        # comparison on `other`; `bag == non_bag` still evaluates to `False` if neither side
        # knows how to compare.
        return NotImplemented

    def __len__(self) -> int:
        return len(self._values)

    def __repr__(self) -> str:
        return repr(self._values)

    def __hash__(self) -> int:
        # Safe to hash: the sorted tuple is immutable and order-canonical.
        return hash(self._values)
Release a new version to PyPI (e.g. by following [these instructions](https://realpython.com/pypi-publish-python-package/)). Consider publishing to [Test PyPI](https://test.pypi.org/) first to verify that the README renders correctly. 7 | -------------------------------------------------------------------------------- /syntheseus/tests/conftest.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | from syntheseus.interface.bag import Bag 6 | from syntheseus.interface.molecule import Molecule 7 | from syntheseus.interface.reaction import SingleProductReaction 8 | 9 | 10 | @pytest.fixture 11 | def cocs_mol() -> Molecule: 12 | """Returns the molecule with smiles 'COCS'.""" 13 | return Molecule("COCS", make_rdkit_mol=False) 14 | 15 | 16 | @pytest.fixture 17 | def ossc_mol() -> Molecule: 18 | """Returns the molecule with smiles 'OSSC'.""" 19 | return Molecule("OSSC", make_rdkit_mol=False) 20 | 21 | 22 | @pytest.fixture 23 | def soos_mol() -> Molecule: 24 | """Returns the molecule with smiles 'SOOS'.""" 25 | return Molecule("SOOS", make_rdkit_mol=False) 26 | 27 | 28 | @pytest.fixture 29 | def rxn_cocs_from_co_cs(cocs_mol: Molecule) -> SingleProductReaction: 30 | """Returns a reaction with COCS as the product.""" 31 | return SingleProductReaction(product=cocs_mol, reactants=Bag([Molecule("CO"), Molecule("CS")])) 32 | 33 | 34 | @pytest.fixture 35 | def rxn_cs_from_cc() -> SingleProductReaction: 36 | return SingleProductReaction(product=Molecule("CS"), reactants=Bag([Molecule("CC")])) 37 | 38 | 39 | @pytest.fixture 40 | def rxn_cocs_from_cocc(cocs_mol: Molecule) -> SingleProductReaction: 41 | return SingleProductReaction(product=cocs_mol, reactants=Bag([Molecule("COCC")])) 42 | -------------------------------------------------------------------------------- /syntheseus/tests/reaction_prediction/utils/test_parallel.py: 
import pytest

from syntheseus import Molecule
from syntheseus.tests.cli.test_eval_single_step import DummyModel

# `torch` (and hence `ParallelReactionModel`, which imports it) is an optional dependency,
# so detect its availability up-front and gate the tests on the flags below.
try:
    import torch

    from syntheseus.reaction_prediction.utils.parallel import ParallelReactionModel

    torch_available = True
    cuda_available = torch.cuda.is_available()
except ModuleNotFoundError:
    torch_available = False
    cuda_available = False


@pytest.mark.skipif(
    not torch_available, reason="Simple testing of parallel inference requires torch"
)
def test_parallel_reaction_model_cpu() -> None:
    # We cannot really run this on CPU, so just check if the model creation works as normal.
    parallel_model: ParallelReactionModel = ParallelReactionModel(
        model_fn=DummyModel, devices=["cpu"] * 4
    )
    # An empty batch should trivially produce an empty output list.
    assert parallel_model([]) == []


@pytest.mark.skipif(
    not cuda_available, reason="Full testing of parallel inference requires GPU to be available"
)
def test_parallel_reaction_model_gpu() -> None:
    # Compare the parallel wrapper against a single plain replica: with four replicas on the
    # same GPU the fanned-out results must match the single-model results exactly.
    model = DummyModel()
    parallel_model: ParallelReactionModel = ParallelReactionModel(
        model_fn=DummyModel, devices=["cuda:0"] * 4
    )

    inputs = [Molecule("C" * length) for length in range(1, 6)]
    assert parallel_model(inputs) == model(inputs)
class ForwardModelClass(Enum):
    """Supported forward (reactants -> product) model wrappers, keyed by user-facing name."""

    Chemformer = ChemformerModel


class BackwardModelClass(Enum):
    """Supported backward (product -> reactants) model wrappers, keyed by user-facing name."""

    Chemformer = ChemformerModel
    GLN = GLNModel
    Graph2Edits = Graph2EditsModel
    LocalRetro = LocalRetroModel
    MEGAN = MEGANModel
    MHNreact = MHNreactModel
    RetroKNN = RetroKNNModel
    RootAligned = RootAlignedModel


@dataclass
class ModelConfig:
    """Config for loading any reaction models, forward or backward."""

    # Path to the model checkpoint directory; `MISSING` makes this a required field
    # unless a default checkpoint can be resolved elsewhere.
    model_dir: str = MISSING
    # Extra keyword arguments forwarded verbatim to the model constructor.
    model_kwargs: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ForwardModelConfig(ModelConfig):
    """Config for loading one of the supported forward models."""

    # Which wrapper class to instantiate; the enum value is the class itself.
    model_class: ForwardModelClass = MISSING


@dataclass
class BackwardModelConfig(ModelConfig):
    """Config for loading one of the supported backward models."""

    # Which wrapper class to instantiate; the enum value is the class itself.
    model_class: BackwardModelClass = MISSING
19 | """ 20 | rxn_model = LinearMoleculesToyModel() 21 | with pytest.raises(AssertionError): 22 | rxn_model([Molecule("CC(C)C")]) 23 | 24 | 25 | def test_list_of_reactions_model( 26 | rxn_cocs_from_co_cs: SingleProductReaction, 27 | rxn_cocs_from_cocc: SingleProductReaction, 28 | rxn_cs_from_cc: SingleProductReaction, 29 | ) -> None: 30 | """Simple test of the ListOfReactionsModel class.""" 31 | model = ListOfReactionsToyModel([rxn_cocs_from_co_cs, rxn_cocs_from_cocc, rxn_cs_from_cc]) 32 | output = model([Molecule("COCS"), Molecule("CS"), Molecule("CO")]) 33 | assert output == [[rxn_cocs_from_co_cs, rxn_cocs_from_cocc], [rxn_cs_from_cc], []] 34 | -------------------------------------------------------------------------------- /syntheseus/cli/search_config.yml: -------------------------------------------------------------------------------- 1 | mcts: 2 | Chemformer: 3 | bound_constant: 1 4 | policy_kwargs: 5 | temperature: 8.0 6 | value_function_kwargs: 7 | constant: 0.75 8 | GLN: 9 | bound_constant: 100 10 | policy_kwargs: 11 | temperature: 4.0 12 | value_function_kwargs: 13 | constant: 0.5 14 | LocalRetro: 15 | bound_constant: 1 16 | policy_kwargs: 17 | temperature: 4.0 18 | value_function_kwargs: 19 | constant: 0.5 20 | MEGAN: 21 | bound_constant: 1 22 | policy_kwargs: 23 | clip_probability_max: 0.9999 24 | clip_probability_min: 1.0e-05 25 | temperature: 2.0 26 | value_function_kwargs: 27 | constant: 0.75 28 | MHNreact: 29 | bound_constant: 1 30 | policy_kwargs: 31 | temperature: 8.0 32 | value_function_kwargs: 33 | constant: 0.5 34 | RetroKNN: 35 | bound_constant: 1 36 | policy_kwargs: 37 | temperature: 8.0 38 | value_function_kwargs: 39 | constant: 0.75 40 | RootAligned: 41 | bound_constant: 10 42 | policy_kwargs: 43 | clip_probability_max: 0.999 44 | clip_probability_min: 1.0e-05 45 | temperature: 8.0 46 | value_function_kwargs: 47 | constant: 0.5 48 | retro_star: 49 | MEGAN: 50 | and_node_cost_fn_kwargs: 51 | clip_probability_max: 0.99 52 | 
clip_probability_min: 1.0e-05 53 | RootAligned: 54 | and_node_cost_fn_kwargs: 55 | clip_probability_max: 0.9999 56 | clip_probability_min: 1.0e-08 57 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Syntheseus 2 | site_url: https://microsoft.github.io/syntheseus/ 3 | 4 | repo_name: microsoft/syntheseus 5 | repo_url: https://github.com/microsoft/syntheseus 6 | edit_uri: edit/main/docs/ 7 | 8 | theme: 9 | name: material 10 | palette: 11 | - media: "(prefers-color-scheme: light)" 12 | scheme: default 13 | primary: black 14 | accent: red 15 | toggle: 16 | icon: material/toggle-switch 17 | name: "Switch to dark mode" 18 | - media: "(prefers-color-scheme: dark)" 19 | scheme: slate 20 | primary: black 21 | accent: red 22 | toggle: 23 | icon: material/toggle-switch-off-outline 24 | name: "Switch to light mode" 25 | features: 26 | - content.code.copy 27 | - navigation.tabs 28 | 29 | nav: 30 | - Get Started: 31 | - Overview: index.md 32 | - Installation: installation.md 33 | - Single-Step Models: single_step.md 34 | - CLI: 35 | - Single-Step Evaluation: cli/eval_single_step.md 36 | - Running Search: cli/search.md 37 | - Tutorials: 38 | - Quick Start: tutorials/quick_start.ipynb 39 | - Integrating a Custom Model: tutorials/custom_model.ipynb 40 | - Multi-step Search on PaRoutes: tutorials/paroutes_benchmark.ipynb 41 | 42 | plugins: 43 | - mkdocs-jupyter 44 | 45 | markdown_extensions: 46 | - admonition 47 | - attr_list 48 | - md_in_html 49 | - pymdownx.details 50 | - pymdownx.superfences 51 | - pymdownx.tabbed: 52 | alternate_style: true 53 | 54 | extra: 55 | version: 56 | default: stable 57 | provider: mike 58 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/environment_gln/Dockerfile: -------------------------------------------------------------------------------- 1 | 
FROM mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.8-cudnn8-ubuntu22.04

# `MAINTAINER` has been deprecated in favor of `LABEL maintainer=...`.
LABEL maintainer="krmaziar@microsoft.com"

# Set bash, as conda doesn't like dash
SHELL [ "/bin/bash", "--login", "-c" ]

# Make bash aware of conda
RUN echo ". /opt/miniconda/etc/profile.d/conda.sh" >> ~/.profile

# Turn off caching in pip
ENV PIP_NO_CACHE_DIR=1

# Install the dependencies into conda's default environment
COPY ./environment.yml /tmp/
RUN conda config --remove channels defaults && \
    conda config --add channels conda-forge && \
    conda update --all && \
    conda install mamba -n base -c conda-forge && \
    conda clean --all --yes
RUN mamba env update -p /opt/miniconda -f /tmp/environment.yml && conda clean --all --yes

# Install RDKit from source (pinned to a known-good commit)
RUN git clone https://github.com/rdkit/rdkit.git
WORKDIR /rdkit
RUN git checkout 8bd74d91118f3fdb370081ef0a18d71715e7c6cf
RUN mkdir build && cd build && cmake -DPy_ENABLE_SHARED=1 \
                                     -DRDK_INSTALL_INTREE=ON \
                                     -DRDK_INSTALL_STATIC_LIBS=OFF \
                                     -DRDK_BUILD_CPP_TESTS=ON \
                                     -DPYTHON_NUMPY_INCLUDE_PATH="$(python -c 'import numpy ; print(numpy.get_include())')" \
                                     -DBOOST_ROOT="$CONDA_PREFIX" \
                                     .. && make && make install
WORKDIR /

# Install GLN (this relies on `CUDA_HOME` being set correctly)
RUN git clone https://github.com/kmaziarz/GLN.git
WORKDIR /GLN
RUN pip install -e .

ENV PYTHONPATH=$PYTHONPATH:/rdkit:/GLN
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/rdkit/lib
[In our experience](https://arxiv.org/abs/2310.19796) both Retro* and MCTS show similar performance when tuned properly, but you may want to try both for your particular use case and see which one works best empirically.
class ExternalReactionModel(ReactionModel[InputType, ReactionType]):
    """Base class for the external reaction models, abstracting out common functionality."""

    def __init__(
        self, model_dir: Optional[Union[str, Path]] = None, device: Optional[str] = None, **kwargs
    ) -> None:
        """Set up checkpoint location and compute device.

        Args:
            model_dir: Directory holding the model checkpoint; if `None`, a default
                checkpoint is resolved via `get_default_model_dir`.
            device: Torch device string (e.g. "cuda:0"); defaults to the first GPU when
                CUDA is available, otherwise CPU.
            **kwargs: Forwarded to the `ReactionModel` constructor.
        """
        super().__init__(**kwargs)
        # Imported lazily so that importing this module does not itself require torch.
        import torch

        self.model_dir = Path(model_dir or self.get_default_model_dir())
        self.device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")

    def get_default_model_dir(self) -> Path:
        """Return the cached default checkpoint directory for this model.

        Raises:
            ValueError: If no default checkpoint could be obtained from the cache.
        """
        model_dir = get_default_model_dir_from_cache(self.name, is_forward=self.is_forward())

        if model_dir is None:
            raise ValueError(
                f"Could not obtain a default checkpoint for model {self.name}, "
                "please provide an explicit value for `model_dir`"
            )

        return model_dir

    @property
    def name(self) -> str:
        # E.g. `ChemformerModel` -> `Chemformer`.
        # NOTE(review): `str.removesuffix` requires Python 3.9+ — confirm this against the
        # package's minimum supported Python version.
        return self.__class__.__name__.removesuffix("Model")


class ExternalBackwardReactionModel(ExternalReactionModel, BackwardReactionModel):
    """External model used in the backward (retrosynthesis) direction."""

    pass


class ExternalForwardReactionModel(ExternalReactionModel, ForwardReactionModel):
    """External model used in the forward (reaction prediction) direction."""

    pass
a bug dealing with `python>=3.10` 2 | # (see https://github.com/conda/conda/issues/10969), and the below makes that go away. 3 | default_language_version: 4 | python: python3 5 | 6 | repos: 7 | # Generally useful pre-commit hooks 8 | - repo: https://github.com/pre-commit/pre-commit-hooks 9 | rev: v4.3.0 # Use the ref you want to point at 10 | hooks: 11 | - id: check-added-large-files 12 | - id: check-case-conflict 13 | - id: check-docstring-first 14 | - id: check-executables-have-shebangs 15 | - id: check-json 16 | - id: check-merge-conflict 17 | - id: check-symlinks 18 | - id: check-toml 19 | - id: check-yaml 20 | - id: debug-statements 21 | - id: destroyed-symlinks 22 | - id: detect-private-key 23 | - id: end-of-file-fixer 24 | - id: name-tests-test 25 | args: ["--pytest-test-first"] 26 | - id: trailing-whitespace 27 | args: [--markdown-linebreak-ext=md] 28 | 29 | # latest version of black when this pre-commit config is being set up 30 | - repo: https://github.com/psf/black 31 | rev: 23.3.0 32 | hooks: 33 | - id: black 34 | name: "black" 35 | args: ["--config=pyproject.toml"] 36 | 37 | # latest version of mypy at time pre-commit config is being set up 38 | # NOTE: only checks code in "syntheseus" directory. 
class TopKMetricsAccumulator:
    """Accumulates top-k accuracy and Mean Reciprocal Rank (MRR) over a stream of samples.

    The caller decides what "correct" means: each sample is reported as a list of booleans,
    one per model output in rank order, and only the best-ranked correct output counts.
    """

    def __init__(self, max_num_results: int):
        self._max_num_results = max_num_results

        # `_num_correct_in_top[k - 1]` counts samples whose first correct output had rank <= k.
        self._num_correct_in_top = np.zeros(max_num_results)
        self._reciprocal_rank_total = 0.0
        self._num_samples = 0

    def add(self, is_output_correct: List[bool]) -> None:
        """Record one sample given per-output correctness flags (best-first order)."""
        assert len(is_output_correct) <= self._max_num_results

        self._num_samples += 1

        # Find the rank of the first correct output, if any; only it contributes to the metrics.
        first_correct_rank = next(
            (rank for rank, correct in enumerate(is_output_correct) if correct), None
        )
        if first_correct_rank is not None:
            self._num_correct_in_top[first_correct_rank:] += 1.0
            self._reciprocal_rank_total += 1.0 / (first_correct_rank + 1)

    @property
    def num_samples(self) -> int:
        """Total number of samples recorded so far."""
        return self._num_samples

    @property
    def top_k(self) -> List[float]:
        """Top-k accuracy for each k from 1 to `max_num_results`."""
        return list(self._num_correct_in_top / self._num_samples)

    @property
    def mrr(self) -> float:
        """Mean reciprocal rank of the first correct output (0 contribution if none correct)."""
        return self._reciprocal_rank_total / self._num_samples
# Use the dotted module name (not `__file__`) so the logger slots into the standard
# `logging` hierarchy and can be configured alongside the rest of the package; naming a
# logger after a filesystem path is environment-dependent and defeats hierarchical config.
logger = logging.getLogger(__name__)


def get_model(
    config: Union[BackwardModelConfig, ForwardModelConfig],
    batch_size: int,
    num_gpus: int,
    **model_kwargs,
) -> ReactionModel:
    """Instantiate the reaction model described by `config`.

    Args:
        config: Model config naming the model class, checkpoint directory and extra kwargs.
        batch_size: Inference batch size; used to validate the split of work across GPUs.
        num_gpus: Number of GPUs to use (0 means CPU; >1 enables data-parallel inference).
        **model_kwargs: Additional constructor kwargs; must not overlap `config.model_kwargs`.

    Returns:
        The instantiated model (wrapped in `ParallelReactionModel` for multi-GPU use).

    Raises:
        ValueError: If kwargs overlap, the batch cannot be split across GPUs, or multi-GPU
            use is requested for a model without torch support.
    """
    # Check that model kwargs don't overlap.
    overlapping_kwargs = set(config.model_kwargs.keys()) & set(model_kwargs.keys())
    if overlapping_kwargs:
        raise ValueError(f"Model kwargs overlap: {overlapping_kwargs}")

    def model_fn(device):
        # Build a single replica of the configured model on the given device.
        return config.model_class.value(
            model_dir=OmegaConf.select(config, "model_dir"),
            device=device,
            **config.model_kwargs,
            **model_kwargs,
        )

    if num_gpus == 0:
        return model_fn("cpu")
    elif num_gpus == 1:
        return model_fn("cuda:0")
    else:
        if batch_size < num_gpus:
            raise ValueError(f"Cannot split batch of size {batch_size} across {num_gpus} GPUs")

        batch_size_per_gpu = batch_size // num_gpus

        if batch_size_per_gpu < 16:
            logger.warning(f"Batch size per GPU is very small: ~{batch_size_per_gpu}")

        # `ParallelReactionModel` imports torch, hence the guarded import here.
        try:
            from syntheseus.reaction_prediction.utils.parallel import ParallelReactionModel
        except ModuleNotFoundError as e:
            # Chain the cause so the missing-module detail is preserved in tracebacks.
            raise ValueError("Multi-GPU evaluation is only supported for torch-based models") from e

        return ParallelReactionModel(
            model_fn=model_fn, devices=[f"cuda:{idx}" for idx in range(num_gpus)]
        )
18 | """ 19 | 20 | @pytest.mark.skip(reason="Random search is very inefficient") 21 | def test_found_routes1(self, retrosynthesis_task1: RetrosynthesisTask) -> None: 22 | pass 23 | 24 | @pytest.mark.skip(reason="Random search is very inefficient") 25 | def test_found_routes2(self, retrosynthesis_task2: RetrosynthesisTask) -> None: 26 | pass 27 | 28 | @pytest.mark.flaky(reruns=3) 29 | @pytest.mark.parametrize("limit", [0, 1, 2, 1000]) 30 | def test_limit_iterations( 31 | self, 32 | retrosynthesis_task1: RetrosynthesisTask, 33 | retrosynthesis_task2: RetrosynthesisTask, 34 | retrosynthesis_task3: RetrosynthesisTask, 35 | limit: int, 36 | ) -> None: 37 | # Here we are just overriding the limits which are tested. 38 | # Random search is inefficient, so sometimes after 100 iterations not all tasks are solved. 39 | super().test_limit_iterations( 40 | retrosynthesis_task1, retrosynthesis_task2, retrosynthesis_task3, limit 41 | ) 42 | 43 | 44 | class TestAndOrRandomSearch(BaseRandomSearchTest): 45 | def setup_algorithm(self, **kwargs) -> AndOr_RandomSearch: 46 | return AndOr_RandomSearch(**kwargs) 47 | 48 | 49 | class TestMolSetRandomSearch(BaseRandomSearchTest): 50 | def setup_algorithm(self, **kwargs) -> MolSet_RandomSearch: 51 | return MolSet_RandomSearch(**kwargs) 52 | -------------------------------------------------------------------------------- /syntheseus/tests/reaction_prediction/inference/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from syntheseus.interface.bag import Bag 4 | from syntheseus.interface.molecule import Molecule 5 | from syntheseus.reaction_prediction.inference.base import ExternalBackwardReactionModel 6 | from syntheseus.reaction_prediction.inference.config import BackwardModelClass 7 | from syntheseus.reaction_prediction.utils.testing import are_single_step_models_installed 8 | 9 | pytestmark = pytest.mark.skipif( 10 | not are_single_step_models_installed(), 11 | 
class ParallelReactionModel(ReactionModel[InputType, ReactionType]):
    """Wraps an arbitrary `ReactionModel` to enable multi-GPU inference.

    Unlike most off-the-shelf multi-GPU approaches (e.g. strategies in `pytorch_lightning`,
    `nn.DataParallel`, `nn.DistributedDataParallel`), this class only handles inference (not
    training), and because of that it can be much looser in terms of the constraints the
    parallelized model has to satisfy. It also works with lists of inputs (chunking them up
    appropriately), whereas other approaches usually only work with tensors.
    """

    def __init__(self, *args, model_fn: Callable, devices: List, **kwargs) -> None:
        """Build one independent model replica per device.

        Args:
            model_fn: Factory accepting a `device` keyword and returning a `ReactionModel`.
            devices: Torch device identifiers (e.g. "cuda:0"), one per desired replica.
        """
        super().__init__(*args, **kwargs)

        self._devices = devices
        self._model_replicas = [model_fn(device=device) for device in devices]

    def _get_reactions(
        self, inputs: List[InputType], num_results: int
    ) -> List[Sequence[ReactionType]]:
        # Chunk up the inputs into (roughly) equal-sized chunks.
        chunk_size = math.ceil(len(inputs) / len(self._devices))
        input_chunks = list((input,) for input in chunked(inputs, chunk_size))

        # If `len(inputs)` is not divisible by `len(self._devices)` the last chunk may end up empty.
        num_chunks = len(input_chunks)

        # Run each replica on its own chunk in parallel via torch's `parallel_apply`.
        outputs = torch.nn.parallel.parallel_apply(
            self._model_replicas[:num_chunks],
            input_chunks,
            tuple({"num_results": num_results} for _ in range(num_chunks)),
            self._devices[:num_chunks],
        )

        # Concatenate all outputs from the replicas (chunks are in input order, so the
        # concatenation preserves the original ordering).
        return sum(outputs, [])

    def is_forward(self) -> bool:
        # All replicas wrap the same model, so the first one determines the direction.
        return self._model_replicas[0].is_forward()
def remove_atom_mapping(smiles: str) -> str:
    """Removes the atom mapping from a SMILES string.

    Args:
        smiles: Molecule SMILES to be modified.

    Returns:
        str: Input SMILES with atom map numbers stripped away (re-canonicalized
            by rdkit, so the output may differ from the input in atom ordering).

    Raises:
        ValueError: If `smiles` cannot be parsed by rdkit. Previously an invalid
            input surfaced as an opaque `AttributeError` on the `None` returned
            by `Chem.MolFromSmiles`; `ValueError` matches the error type used
            elsewhere in this module (see `molecule_bag_from_smiles`).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # `MolFromSmiles` signals parse failure by returning `None`.
        raise ValueError(f"Could not parse SMILES: {smiles}")

    remove_atom_mapping_from_mol(mol)
    return Chem.MolToSmiles(mol)
class GeneralBreadthFirstSearch(SearchAlgorithm[GraphType, int], Generic[GraphType]):
    """Breadth-first search over a retrosynthesis graph.

    The expansion pseudo-code is identical for every graph data structure,
    so concrete AND/OR and MolSet variants only pin down the node type.
    """

    @property
    def requires_tree(self) -> bool:
        # BFS never relies on unique parents, so arbitrary graphs are fine.
        return False

    def _run_from_graph_after_setup(self, graph: GraphType) -> int:
        verbose_level = logging.DEBUG - 1
        should_log = logger.isEnabledFor(verbose_level)

        # FIFO frontier seeded with every node not yet expanded.
        frontier = collections.deque(n for n in graph._graph.nodes() if not n.is_expanded)

        step = 0  # defined up-front in case the loop body never runs
        for step in range(self.limit_iterations):
            if self.should_stop_search(graph) or not frontier:
                break

            # Pop the oldest node and expand it (unless something already did).
            current = frontier.popleft()
            if current.is_expanded:
                outcome = "already expanded, do nothing"
            else:
                children = self.expand_node(current, graph)
                self.set_node_values({current, *children}, graph)
                frontier.extend(c for c in children if self.can_expand_node(c, graph))
                outcome = f"expanded, created {len(children)} new nodes"

            if should_log:
                logger.log(
                    verbose_level,
                    f"Step {step}: node {current} {outcome}. Queue size: {len(frontier)}",
                )

        return step
class BaseRandomSearch(SearchAlgorithm[GraphType, int], Generic[GraphType]):
    """Search that repeatedly expands a uniformly random expandable node.

    Shared implementation for both the AND/OR and MolSet variants.
    """

    @property
    def requires_tree(self) -> bool:
        # Nodes are expanded at most once each, so any graph works.
        return False

    def _run_from_graph_after_setup(self, graph: GraphType) -> int:
        verbose_level = logging.DEBUG - 1
        should_log = logger.isEnabledFor(verbose_level)

        # Pool of nodes currently eligible for expansion.
        candidates = {n for n in graph.nodes() if self.can_expand_node(n, graph)}

        step = 0  # defined up-front in case the loop body never runs
        for step in range(self.limit_iterations):
            if self.should_stop_search(graph) or not candidates:
                break

            # Draw a random node and take it out of the pool.
            chosen = self.random_state.choice(list(candidates))
            candidates.remove(chosen)

            # Expand it and add any newly-expandable children to the pool.
            children = self.expand_node(chosen, graph)
            self.set_node_values({chosen, *children}, graph)
            candidates.update(c for c in children if self.can_expand_node(c, graph))

            if should_log:
                logger.log(
                    verbose_level,
                    f"Step {step}: node {chosen} expanded, created {len(children)} new nodes. "
                    f"Num expandable nodes: {len(candidates)}.",
                )

        return step
class ReactionModelProbPolicy(ReactionModelBasedEvaluator[Union[MolSetNode, AndNode]]):
    """Evaluator that uses the reactions' probability to form a policy (useful for MCTS)."""

    def __init__(self, **kwargs) -> None:
        # Unless the caller explicitly disables it, probabilities are normalized.
        kwargs.setdefault("normalize", True)
        super().__init__(return_log=False, **kwargs)

    def _get_reaction(self, node: Union[MolSetNode, AndNode], graph) -> SingleProductReaction:
        if isinstance(node, AndNode):
            # AND nodes carry their reaction directly.
            return node.reaction
        if isinstance(node, MolSetNode):
            # In a MolSet tree the reaction is stored on the edge from the unique parent.
            parents = list(graph.predecessors(node))
            assert len(parents) == 1, "Graph must be a tree"
            return graph._graph.edges[parents[0], node]["reaction"]
        raise ValueError(f"ReactionModelProbPolicy does not support nodes of type {type(node)}")
currently supports 8 established single-step models. 2 | 3 | For convenience, for each model we include a default checkpoint trained on USPTO-50K. 4 | If no checkpoint directory is provided during model loading, `syntheseus` will automatically download a default checkpoint and cache it on disk for future use. 5 | The default path for the cache is `$HOME/.cache/torch/syntheseus`, but it can be overridden by setting the `SYNTHESEUS_CACHE_DIR` environment variable. 6 | See table below for the links to the default checkpoints. 7 | 8 | | Model checkpoint link | Source | 9 | |----------------------------------------------------------------|--------| 10 | | [Chemformer](https://figshare.com/ndownloader/files/42009888) | finetuned by us starting from checkpoint released by authors | 11 | | [GLN](https://figshare.com/ndownloader/files/45882867) | released by authors | 12 | | [Graph2Edits](https://figshare.com/ndownloader/files/44194301) | released by authors | 13 | | [LocalRetro](https://figshare.com/ndownloader/files/42287319) | trained by us | 14 | | [MEGAN](https://figshare.com/ndownloader/files/42012732) | trained by us | 15 | | [MHNreact](https://figshare.com/ndownloader/files/42012777) | trained by us | 16 | | [RetroKNN](https://figshare.com/ndownloader/files/45662430) | trained by us | 17 | | [RootAligned](https://figshare.com/ndownloader/files/42012792) | released by authors | 18 | 19 | ??? note "More advanced datasets" 20 | 21 | The USPTO-50K dataset is well-established but relatively small. Advanced users may prefer to retrain their models of interest on a larger dataset, such as USPTO-FULL or Pistachio. To do that, please follow the instructions in the original model repositories. 22 | 23 | In `reaction_prediction/cli/eval.py` a forward model can be used for computing back-translation (round-trip) accuracy. 24 | See [here](https://figshare.com/ndownloader/files/42012708) for a Chemformer checkpoint finetuned for forward prediction on USPTO-50K.
As for the backward direction, pretrained weights released by original authors were used as a starting point. 25 | 26 | ??? info "Licenses" 27 | All checkpoints were produced in a way that involved external model repositories, hence may be affected by the exact license each model was released with. 28 | For more details about a particular model see the top of the corresponding model wrapper file in `reaction_prediction/inference/`. 29 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/utils/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, cast 4 | 5 | from omegaconf import DictConfig, ListConfig, OmegaConf 6 | 7 | R = TypeVar("R") 8 | 9 | 10 | def get_config( 11 | argv: Optional[List[str]], 12 | config_cls: Callable[..., R], 13 | defaults: Optional[Dict[str, Any]] = None, 14 | ) -> R: 15 | """ 16 | Utility function to get `OmegaConf` config options. 17 | 18 | Args: 19 | argv: Either a list of command line arguments to parse, or `None`. If `None`, this argument 20 | is set from `sys.argv`. 21 | config_cls: Dataclass object specifying config structure (i.e. which fields to expect in the 22 | config). It should be the class itself, not an instance of the class. 23 | 24 | Returns: 25 | Config object, which will pass as an instance of `config_cls` among other things. Note: the 26 | type for this could be specified more carefully, but `OmegaConf`'s typing system is a bit 27 | complex. Search `OmegaConf`'s docs for "structured" for more info. 
28 | """ 29 | 30 | if argv is None: 31 | argv = sys.argv[1:] 32 | # Parse command line arguments 33 | parser = argparse.ArgumentParser(allow_abbrev=False) # prevent prefix matching issues 34 | parser.add_argument( 35 | "--config", 36 | type=str, 37 | action="append", 38 | default=list(), 39 | help="Path to a yaml config file. " 40 | "Argument can be repeated multiple times, with later configs overwriting previous ones.", 41 | ) 42 | args, config_changes = parser.parse_known_args(argv) 43 | 44 | # Read configs from defaults, file and command line 45 | conf_yamls: List[Union[DictConfig, ListConfig]] = [] 46 | if defaults: 47 | conf_yamls = [OmegaConf.create(defaults)] 48 | 49 | conf_yamls += [OmegaConf.load(c) for c in args.config] 50 | conf_cli = OmegaConf.from_cli(config_changes) 51 | 52 | # Make merged config options 53 | # CLI options take priority over YAML file options 54 | schema = OmegaConf.structured(config_cls) 55 | config = OmegaConf.merge(schema, *conf_yamls, conf_cli) 56 | OmegaConf.set_readonly(config, True) # should not be written to 57 | return cast(R, config) 58 | 59 | 60 | def get_error_message_for_missing_value(name: str, possible_values: List[str]) -> str: 61 | return f"{name} should be set to one of [{', '.join(possible_values)}]" 62 | -------------------------------------------------------------------------------- /syntheseus/tests/search/test_mol_inventory.py: -------------------------------------------------------------------------------- 1 | """Tests for MolInventory objects, focusing on the provided SmilesListInventory.""" 2 | 3 | import pytest 4 | 5 | from syntheseus.interface.molecule import Molecule 6 | from syntheseus.search.mol_inventory import SmilesListInventory 7 | 8 | PURCHASABLE_SMILES = ["CC", "c1ccccc1", "CCO"] 9 | NON_PURCHASABLE_SMILES = ["C", "C1CCCCC1", "OCCO"] 10 | 11 | 12 | @pytest.fixture 13 | def example_inventory() -> SmilesListInventory: 14 | """Returns a SmilesListInventory with arbitrary molecules.""" 15 | return 
def test_is_purchasable(example_inventory: SmilesListInventory) -> None:
    """Check that `is_purchasable` is true exactly for the purchasable SMILES."""
    assert all(example_inventory.is_purchasable(Molecule(sm)) for sm in PURCHASABLE_SMILES)
    assert not any(example_inventory.is_purchasable(Molecule(sm)) for sm in NON_PURCHASABLE_SMILES)
57 | """ 58 | expected_set = {Molecule(sm) for sm in PURCHASABLE_SMILES} 59 | observed_set = set(example_inventory.to_purchasable_mols()) 60 | assert expected_set == observed_set 61 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/models/retro_knn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class Adapter(nn.Module): 7 | def __init__(self, dim, k=32): 8 | from dgl.nn import GINEConv 9 | 10 | super().__init__() 11 | self.gnn = GINEConv(nn.Linear(dim, dim)) 12 | self.node_proj = nn.Linear(dim, 2) # [tmp, p] 13 | self.edge_proj = nn.Linear(dim, 2) # [tmp, p] 14 | 15 | self.node_dist_in_proj = nn.Linear(k, k) 16 | self.edge_dist_in_proj = nn.Linear(k, k) 17 | 18 | self.node_ffn = nn.Linear(dim + k, dim) 19 | self.edge_ffn = nn.Linear(dim * 2 + k, dim) 20 | 21 | nn.init.kaiming_uniform_(self.node_dist_in_proj.weight, a=100) 22 | nn.init.kaiming_uniform_(self.edge_dist_in_proj.weight, a=100) 23 | 24 | nn.init.zeros_(self.node_proj.weight) 25 | nn.init.zeros_(self.edge_proj.weight) 26 | 27 | nn.init.constant_(self.node_proj.bias[0], 10.0) 28 | nn.init.constant_(self.edge_proj.bias[0], 10.0) 29 | 30 | def forward(self, g, nfeat, efeat, ndist, edist): 31 | from local_retro.scripts.model_utils import pair_atom_feats 32 | 33 | x = self.gnn(g, nfeat, efeat) 34 | x = F.relu(x) 35 | 36 | ndist = F.relu(self.node_dist_in_proj(ndist)) 37 | edist = F.relu(self.edge_dist_in_proj(edist)) 38 | 39 | node_x = torch.cat((x, ndist), dim=-1) 40 | node_x = self.node_ffn(node_x) 41 | node_x = F.relu(node_x) 42 | node_x = self.node_proj(node_x) 43 | 44 | edge_x = pair_atom_feats(g, x) 45 | edge_x = F.relu(edge_x) 46 | edge_x = torch.cat((edge_x, edist), dim=-1) 47 | edge_x = self.edge_ffn(edge_x) 48 | edge_x = F.relu(edge_x) 49 | edge_x = self.edge_proj(edge_x) 50 | 51 | node_t = 
torch.clamp(node_x[:, 0], 1, 100) 52 | node_p = torch.sigmoid(node_x[:, 1]) 53 | 54 | edge_t = torch.clamp(edge_x[:, 0], 1, 100) 55 | edge_p = torch.sigmoid(edge_x[:, 1]) 56 | 57 | return (r.unsqueeze(-1) for r in (node_t, node_p, edge_t, edge_p)) 58 | 59 | 60 | def knn_prob(feats, store, lables, max_idx, k=32, temperature=5): 61 | from torch_scatter import scatter 62 | 63 | dis, idx = store.search(feats, k) # [B, K] 64 | pred = lables[idx].unsqueeze(-1) # [B, K, 1] 65 | 66 | re_compute_dists = -1 * dis 67 | knn_weight = torch.softmax(re_compute_dists / temperature, dim=-1).unsqueeze(-1) # [B, K, 1] 68 | 69 | bsz = feats.shape[0] 70 | output = torch.zeros(bsz, k, max_idx).to(feats) 71 | 72 | scatter(src=knn_weight, out=output, index=pred, dim=-1) 73 | 74 | return output.sum(dim=1) 75 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/utils/downloading.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import urllib.request 4 | import zipfile 5 | from pathlib import Path 6 | from typing import Optional 7 | 8 | import yaml 9 | 10 | logger = logging.getLogger(__file__) 11 | 12 | 13 | def get_cache_dir(key: str) -> Path: 14 | """Get the cache directory for a given key (e.g. model name).""" 15 | 16 | # First, check if the cache directory has been manually overriden. 17 | cache_dir_from_env = os.getenv("SYNTHESEUS_CACHE_DIR") 18 | if cache_dir_from_env is not None: 19 | # If yes, use the path provided. 20 | cache_dir = Path(cache_dir_from_env) 21 | else: 22 | # If not, construct a reasonable default. 
def get_cache_dir_download_if_missing(key: str, link: str) -> Path:
    """Get the cache directory for a given key, but populate by downloading from link if empty.

    Args:
        key: Cache key (e.g. model name) identifying which subdirectory to use.
        link: URL of a zip archive to download and unpack if the cache is empty.

    Returns:
        Path to the (now populated) cache directory.
    """
    import shutil

    cache_dir = get_cache_dir(key)

    # "Populated" simply means "directory is non-empty", so any partial state
    # left behind by a failed attempt would be mistaken for a complete cache on
    # the next call. Clean up on failure to keep the check trustworthy.
    if not any(cache_dir.iterdir()):
        cache_zip_path = cache_dir / "model.zip"

        logger.info(f"Downloading data from {link} to {cache_zip_path}")
        try:
            urllib.request.urlretrieve(link, cache_zip_path)

            # NOTE(review): `extractall` trusts member paths inside the archive;
            # this assumes the checkpoint links (from the repo's own yml) point
            # to trusted archives -- confirm before accepting arbitrary links.
            with zipfile.ZipFile(cache_zip_path, "r") as f_zip:
                f_zip.extractall(cache_dir)
        except BaseException:
            # Remove partially downloaded/extracted files so a later call
            # retries the download instead of using a broken cache.
            shutil.rmtree(cache_dir, ignore_errors=True)
            raise

        cache_zip_path.unlink()

    return cache_dir
-------------------------------------------------------------------------------- 1 | # Single-step Evaluation 2 | 3 | ## Usage 4 | 5 | ``` 6 | syntheseus eval-single-step \ 7 | data_dir=[DATA_DIR] \ 8 | fold=[TRAIN, VAL or TEST] \ 9 | model_class=[MODEL_CLASS] \ 10 | model_dir=[MODEL_DIR] 11 | ``` 12 | 13 | The `eval-single-step` command accepts further arguments to customize the evaluation; see `BaseEvalConfig` in `cli/eval_single_step.py` for the complete list. 14 | 15 | The code will scan `data_dir` looking for files matching `*{train, val, test}.{jsonl, csv, smi}` and select the right data format based on the file extension. An error will be raised in case of ambiguity. Only the fold that was selected for evaluation has to be present. 16 | 17 | ## Data format 18 | 19 | The single-step evaluation script supports reaction data in one of three formats. 20 | 21 | ### JSONL 22 | 23 | Our internal format based on `*.jsonl` files in which each line is a JSON representation of a single reaction, for example: 24 | ```json 25 | {"reactants": [{"smiles": "Cc1ccc(Br)cc1"}, {"smiles": "Cc1ccc(B(O)O)cc1"}], "products": [{"smiles": "Cc1ccc(-c2ccc(C)cc2)cc1"}]} 26 | ``` 27 | This format is designed to be flexible at the expense of taking more disk space. The JSON is parsed into a `ReactionSample` object, so it can include additional metadata such as template information, while the reactants and products can include other fields accepted by the `Molecule` object. The evaluation script will only use reactant and product SMILES to compute the metrics. 28 | 29 | Unlike the other formats below, reactants and products in this format are assumed to be already stripped of atom mapping, which leads to slightly faster data loading as that avoids extra calls to `rdkit`. 30 | 31 | ### CSV 32 | 33 | This format is based on `*.csv` files and is commonly used to store raw USPTO data, e.g. 
as released by [Dai et al.](https://github.com/Hanjun-Dai/GLN): 34 | 35 | ``` 36 | id,class,reactants>reagents>production 37 | ID,UNK,[cH:1]1[cH:2][c:3]([CH3:4])[cH:5][cH:6][c:7]1Br.B(O)(O)[c:8]1[cH:9][cH:10][c:11]([CH3:12])[cH:13][cH:14]1>>[cH:1]1[cH:2][c:3]([CH3:4])[cH:5][cH:6][c:7]1[c:8]2[cH:14][cH:13][c:11]([CH3:12])[cH:10][cH:9]2 38 | ``` 39 | 40 | The evaluation script will look for the `reactants>reagents>production` column to extract the reaction SMILES, which are stripped of atom mapping and canonicalized before being fed to the model. 41 | 42 | ### SMILES 43 | 44 | The most compact format is to list reaction SMILES line-by-line in a `*.smi` file: 45 | 46 | ``` 47 | [cH:1]1[cH:2][c:3]([CH3:4])[cH:5][cH:6][c:7]1Br.B(O)(O)[c:8]1[cH:9][cH:10][c:11]([CH3:12])[cH:13][cH:14]1>>[cH:1]1[cH:2][c:3]([CH3:4])[cH:5][cH:6][c:7]1[c:8]2[cH:14][cH:13][c:11]([CH3:12])[cH:10][cH:9]2 48 | ``` 49 | 50 | The data will be handled in the same way as for the CSV format, i.e. it will only be fed into model after removing atom mapping and canonicalization. 51 | -------------------------------------------------------------------------------- /syntheseus/search/analysis/solution_time.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from syntheseus.search.graph.and_or import AndNode, OrNode 4 | from syntheseus.search.graph.base_graph import RetrosynthesisSearchGraph 5 | from syntheseus.search.graph.message_passing import run_message_passing 6 | from syntheseus.search.graph.molset import MolSetNode 7 | from syntheseus.search.graph.node import BaseGraphNode 8 | 9 | 10 | def get_first_solution_time(graph: RetrosynthesisSearchGraph) -> float: 11 | """Get the time of the first solution. 
def first_solution_time_update(node: BaseGraphNode, graph: RetrosynthesisSearchGraph) -> bool:
    """Message-passing update for a node's 'first_solution_time' attribute.

    Recomputes the earliest time at which `node` became solved, combining its
    own "intrinsic" solution (if any) with its children's current values, and
    stores the result in `node.data["first_solution_time"]`.

    Returns:
        True if the stored value changed (so message passing should continue
        to this node's predecessors), False if it is unchanged.
    """
    NO_SOLUTION_TIME = math.inf  # being unsolved = inf time until solution found

    # Calculate "intrinsic solution time": if the node is solved on its own
    # (independent of children), it was solved the moment it was analyzed.
    if node._has_intrinsic_solution():
        intrinsic_solution_time = node.data["analysis_time"]
    else:
        intrinsic_solution_time = NO_SOLUTION_TIME

    # Calculate solution age from children (children not yet visited by the
    # message passing default to "unsolved").
    children_soln_time_list = [
        c.data.get("first_solution_time", NO_SOLUTION_TIME) for c in graph.successors(node)
    ]
    if len(children_soln_time_list) == 0:
        children_solution_time = NO_SOLUTION_TIME
    elif isinstance(node, (OrNode, MolSetNode)):
        # Or node is first solved when one child is solved,
        # so its solution time is the min of its children's
        children_solution_time = min(children_soln_time_list)
    elif isinstance(node, AndNode):
        # AndNode requires all children to be solved,
        # so it is first solved when the LAST child is solved
        children_solution_time = max(children_soln_time_list)
    else:
        raise TypeError(f"Node type {type(node)} not supported.")

    # Min solution time is time of first intrinsic solution or solution from children
    new_min_soln_time = min(intrinsic_solution_time, children_solution_time)

    # Correct one case that can arise with loops: the children could potentially
    # be solved before this node was created.
    # Ensure that new min soln time is at least this node's age!
    new_min_soln_time = max(new_min_soln_time, node.data["analysis_time"])

    # Perform update; report "changed" if there was no previous value or the
    # value moved by more than float tolerance.
    old_min_solution_time = node.data.get("first_solution_time")
    node.data["first_solution_time"] = new_min_soln_time
    return old_min_solution_time is None or not math.isclose(
        old_min_solution_time, new_min_soln_time
    )
If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/data/reaction_sample.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import inspect 4 | from dataclasses import dataclass, field 5 | from typing import Any, Dict, Optional, Type, TypeVar 6 | 7 | from syntheseus.interface.bag import Bag 8 | from syntheseus.interface.molecule import SMILES_SEPARATOR, Molecule 9 | from syntheseus.interface.reaction import REACTION_SEPARATOR, Reaction 10 | from syntheseus.reaction_prediction.chem.utils import remove_atom_mapping 11 | from syntheseus.reaction_prediction.utils.misc import undictify_bag_of_molecules 12 | 13 | ReactionType = TypeVar("ReactionType", bound="ReactionSample") 14 | 15 | 16 | @dataclass(frozen=True, order=False) 17 | class ReactionSample(Reaction): 18 | """Extends a `Reaction` with fields and methods relevant to data loading and saving.""" 19 | 20 | reagents: str = field(default="", hash=True, compare=True) 21 | mapped_reaction_smiles: Optional[str] = field(default=None, hash=False, compare=False) 22 | 23 | @property 24 | def reaction_smiles_with_reagents(self) -> str: 25 | return ( 26 | f"{self.reactants_str}{REACTION_SEPARATOR}" 27 | f"{self.reagents}{REACTION_SEPARATOR}" 28 | f"{self.products_str}" 29 | ) 30 | 31 | @classmethod 32 | def from_dict(cls: Type[ReactionType], data: Dict[str, Any]) -> ReactionType: 33 | """Creates a sample from the given arguments ignoring superfluous ones.""" 34 | for key in ["reactants", "products"]: 35 | data[key] = undictify_bag_of_molecules(data[key]) 36 | 37 | return cls( 38 | **{ 39 | key: value 40 | for key, value in data.items() 41 | if key in inspect.signature(cls).parameters 42 | } 43 | ) 44 | 45 | @classmethod 46 | def from_reaction_smiles_strict( 47 | cls: Type[ReactionType], reaction_smiles: str, mapped: bool, **kwargs 48 | ) -> ReactionType: 49 | # Split the 
reaction SMILES into reactants, reagents, and products (reagents are kept separately).
15 | readme = "README.md" 16 | requires-python = ">=3.8" 17 | license = {file = "LICENSE"} 18 | dynamic = ["version"] 19 | dependencies = [ 20 | "more_itertools", # reaction_prediction 21 | "networkx", # search 22 | "numpy", # reaction_prediction, search 23 | "omegaconf", # reaction_prediction 24 | "rdkit", # reaction_prediction, search 25 | "tqdm", # reaction_prediction 26 | ] 27 | 28 | [project.optional-dependencies] 29 | viz = [ 30 | "pillow", 31 | "graphviz" 32 | ] 33 | dev = [ 34 | "pytest", 35 | "pytest-cov", 36 | "pytest-rerunfailures", 37 | "pre-commit" 38 | ] 39 | chemformer = ["syntheseus-chemformer==0.3.0"] 40 | graph2edits = ["syntheseus-graph2edits==0.2.0"] 41 | local-retro = ["syntheseus-local-retro==0.5.0"] 42 | megan = ["syntheseus-megan==0.2.0"] 43 | mhn-react = ["syntheseus-mhnreact==1.0.0"] 44 | retro-knn = ["syntheseus[local-retro]"] 45 | root-aligned = ["syntheseus-root-aligned==0.2.0"] 46 | all-single-step = [ 47 | "syntheseus[chemformer,graph2edits,local-retro,megan,mhn-react,retro-knn,root-aligned]" 48 | ] 49 | all = [ 50 | "syntheseus[viz,dev,all-single-step]" 51 | ] 52 | 53 | [project.urls] 54 | Documentation = "https://microsoft.github.io/syntheseus" 55 | Repository = "https://github.com/microsoft/syntheseus" 56 | 57 | [project.scripts] 58 | syntheseus = "syntheseus.cli.main:main" 59 | 60 | [tool.setuptools.packages.find] 61 | # Guidance from: https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html 62 | where = ["."] 63 | include = ["syntheseus*"] 64 | exclude = ["syntheseus.tests*"] 65 | namespaces = false 66 | 67 | [tool.setuptools.package-data] 68 | "syntheseus" = ["py.typed"] 69 | 70 | [tool.setuptools_scm] 71 | 72 | [tool.black] 73 | line-length = 100 74 | include = '\.pyi?$' 75 | 76 | [tool.mypy] 77 | python_version = 3.9 # pin modern python version 78 | ignore_missing_imports = true 79 | 80 | [tool.ruff] 81 | line-length = 100 82 | # Check https://beta.ruff.rs/docs/rules/ for full list of rules 83 | lint.select = [ 84 
| "E", "W", # pycodestyle 85 | "F", # Pyflakes 86 | "I", # isort 87 | "NPY201", # check for functions/constants deprecated in numpy 2.* 88 | ] 89 | lint.ignore = [ 90 | # W605: invalid escape sequence -- triggered by pseudo-LaTeX in comments 91 | "W605", 92 | # E501: Line too long -- triggered by comments and such. black deals with shortening. 93 | "E501", 94 | # E741: Do not use variables named 'l', 'o', or 'i' -- disagree with PEP8 95 | "E741", 96 | ] 97 | -------------------------------------------------------------------------------- /syntheseus/search/graph/route.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections import Counter 4 | from typing import Sequence, Union 5 | 6 | from syntheseus.interface.molecule import Molecule 7 | from syntheseus.interface.reaction import SingleProductReaction 8 | from syntheseus.search.graph.base_graph import BaseReactionGraph 9 | 10 | MOL_AND_RXN = Union[Molecule, SingleProductReaction] 11 | 12 | 13 | class SynthesisGraph(BaseReactionGraph[SingleProductReaction]): 14 | """ 15 | Data structure used to hold a retrosynthesis graph containing only 16 | reaction objects. The purpose of this class is as a minimal container 17 | for route objects, instead of storing them as AndOrGraphs or MolSetGraphs. 
18 | """ 19 | 20 | def __init__(self, root_node: SingleProductReaction, **kwargs) -> None: 21 | super().__init__(**kwargs) 22 | self._root_node = root_node 23 | self._graph.add_node(self._root_node) 24 | 25 | @property 26 | def root_node(self) -> SingleProductReaction: 27 | return self._root_node 28 | 29 | @property 30 | def root_mol(self) -> Molecule: 31 | return self.root_node.product 32 | 33 | def is_minimal(self) -> bool: 34 | # Check if any product appears more than once 35 | for rxn in self._graph.nodes: 36 | product_count = Counter([rxn.product for rxn in self.successors(rxn)]) 37 | if any(v > 1 for v in product_count.values()): 38 | return False 39 | return True 40 | 41 | def assert_validity(self) -> None: 42 | # Everything from superclass applies 43 | super().assert_validity() 44 | 45 | for node in self._graph.nodes: 46 | assert isinstance(node, SingleProductReaction) 47 | for parent in self.predecessors(node): 48 | assert isinstance(parent, SingleProductReaction) 49 | assert node.product in parent.reactants 50 | children = list(self.successors(node)) 51 | assert len(children) == len(set(children)) # all children should be unique 52 | assert set([child.product for child in children]) <= set( 53 | node.reactants 54 | ) # all children should be reactants 55 | 56 | def expand_with_reactions( 57 | self, 58 | reactions: list[SingleProductReaction], 59 | node: SingleProductReaction, 60 | ensure_tree: bool, 61 | ) -> Sequence[SingleProductReaction]: 62 | raise NotImplementedError 63 | 64 | def get_starting_molecules(self) -> set[Molecule]: 65 | """ 66 | Get the 'starting molecules' for this route, 67 | i.e. reactant molecules which are not a product of a child reaction. 
68 | """ 69 | output: set[Molecule] = set() 70 | for rxn in self._graph.nodes: 71 | successor_products = {child_rxn.product for child_rxn in self.successors(rxn)} 72 | for reactant in rxn.unique_reactants: 73 | if reactant not in successor_products: 74 | output.add(reactant) 75 | return output 76 | 77 | def __str__(self) -> str: 78 | return str([rxn.reaction_smiles for rxn in self.nodes()]) 79 | -------------------------------------------------------------------------------- /syntheseus/interface/reaction.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Optional, TypedDict 5 | 6 | from syntheseus.interface.bag import Bag 7 | from syntheseus.interface.molecule import Molecule, molecule_bag_to_smiles 8 | 9 | REACTION_SEPARATOR = ">" 10 | 11 | 12 | class ReactionMetaData(TypedDict, total=False): 13 | """Class to add typing to optional meta-data fields for reactions.""" 14 | 15 | cost: float 16 | template: str 17 | source: str # any explanation of the source of this reaction 18 | probability: float # probability for this reaction (e.g. from a model) 19 | log_probability: float # log probability for this reaction (should match log of above) 20 | score: float # any kind of score for this reaction (e.g. 
softmax value, probability) 21 | confidence: float # confidence (probability) that this reaction is possible 22 | reaction_id: int # template id or other kind of reaction id, if applicable 23 | reaction_smiles: str # reaction smiles for this reaction 24 | ground_truth_match: bool # whether this reaction matches ground truth 25 | 26 | 27 | def reaction_string(reactants_str: str, products_str: str) -> str: 28 | """Produces a consistent string representation of a reaction.""" 29 | return f"{reactants_str}{2 * REACTION_SEPARATOR}{products_str}" 30 | 31 | 32 | @dataclass(frozen=True, order=False) 33 | class Reaction: 34 | reactants: Bag[Molecule] = field(hash=True, compare=True) 35 | products: Bag[Molecule] = field(hash=True, compare=True) 36 | identifier: Optional[str] = field(default=None, hash=True, compare=True) 37 | 38 | # Dictionary to hold additional metadata. 39 | metadata: ReactionMetaData = field( 40 | default_factory=lambda: ReactionMetaData(), 41 | hash=False, 42 | compare=False, 43 | ) 44 | 45 | @property 46 | def unique_reactants(self) -> set[Molecule]: 47 | return set(self.reactants) 48 | 49 | @property 50 | def unique_products(self) -> set[Molecule]: 51 | return set(self.products) 52 | 53 | @property 54 | def reactants_str(self) -> str: 55 | return molecule_bag_to_smiles(self.reactants) 56 | 57 | @property 58 | def products_str(self) -> str: 59 | return molecule_bag_to_smiles(self.products) 60 | 61 | @property 62 | def reaction_smiles(self) -> str: 63 | return reaction_string(reactants_str=self.reactants_str, products_str=self.products_str) 64 | 65 | def __str__(self) -> str: 66 | output = self.reaction_smiles 67 | if self.identifier is not None: 68 | output += f" ({self.identifier})" 69 | return output 70 | 71 | 72 | @dataclass(frozen=True, order=False) 73 | class SingleProductReaction(Reaction): 74 | def __init__(self, *, reactants: Bag[Molecule], product: Molecule, **kwargs) -> None: 75 | super().__init__(reactants=reactants, products=Bag([product]), 
**kwargs) 76 | 77 | @property 78 | def product(self) -> Molecule: 79 | """Handle for the single product of this reaction.""" 80 | assert len(self.products) == 1 # Guaranteed in `__init__`. 81 | return next(iter(self.products)) 82 | -------------------------------------------------------------------------------- /syntheseus/tests/interface/test_reaction.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import copy 4 | from dataclasses import FrozenInstanceError 5 | 6 | import pytest 7 | 8 | from syntheseus.interface.bag import Bag 9 | from syntheseus.interface.molecule import Molecule 10 | from syntheseus.interface.reaction import Reaction, SingleProductReaction 11 | 12 | 13 | def test_reaction_objects_basic(): 14 | """Instantiate a `Reaction` object.""" 15 | C2 = Molecule(2 * "C") 16 | C3 = Molecule(3 * "C") 17 | C5 = Molecule(5 * "C") 18 | 19 | # Single product reaction 20 | rxn1 = SingleProductReaction( 21 | reactants=Bag([C2, C3]), 22 | product=C5, 23 | ) 24 | assert rxn1.reaction_smiles == "CC.CCC>>CCCCC" 25 | 26 | # Standadr (multi-product) reaction 27 | rxn2 = Reaction( 28 | reactants=Bag([C5]), 29 | products=Bag([C2, C3]), 30 | ) 31 | assert rxn2.reaction_smiles == "CCCCC>>CC.CCC" 32 | 33 | 34 | class TestSingleProductReactions: 35 | def test_positive_equality(self, rxn_cocs_from_co_cs: SingleProductReaction) -> None: 36 | """Various tests that 2 reactions with same products and reactants should be equal.""" 37 | 38 | rxn_copy = SingleProductReaction( 39 | product=copy.deepcopy(rxn_cocs_from_co_cs.product), 40 | reactants=copy.deepcopy(rxn_cocs_from_co_cs.reactants), 41 | ) 42 | 43 | # Test 1: original and copy should be equal 44 | assert rxn_copy == rxn_cocs_from_co_cs 45 | 46 | # Test 2: although equal, they should be distinct objects 47 | assert rxn_cocs_from_co_cs is not rxn_copy 48 | 49 | # Test 3: differences in metadata should not affect equality 50 | 
rxn_cocs_from_co_cs.metadata["test"] = "str1" # type: ignore[typeddict-unknown-key] 51 | rxn_copy.metadata["test"] = "str2" # type: ignore[typeddict-unknown-key] 52 | assert rxn_copy == rxn_cocs_from_co_cs 53 | assert rxn_cocs_from_co_cs.metadata != rxn_copy.metadata 54 | 55 | def test_negative_equality(self, rxn_cocs_from_co_cs: SingleProductReaction) -> None: 56 | """Various tests that reactions which should not be equal are not equal.""" 57 | 58 | # Test 1: changing identifier makes reactions not equal 59 | rxn_with_different_id = SingleProductReaction( 60 | product=copy.deepcopy(rxn_cocs_from_co_cs.product), 61 | reactants=copy.deepcopy(rxn_cocs_from_co_cs.reactants), 62 | identifier="different", 63 | ) 64 | assert rxn_cocs_from_co_cs != rxn_with_different_id 65 | 66 | # Test 2: different products and reactions are not equal 67 | diff_rxn = SingleProductReaction( 68 | product=Molecule("CC"), reactants=Bag([Molecule("CO"), Molecule("CS")]) 69 | ) 70 | assert rxn_cocs_from_co_cs != diff_rxn 71 | 72 | def test_frozen(self, rxn_cocs_from_co_cs: SingleProductReaction) -> None: 73 | """Test that the fields of the reaction are frozen.""" 74 | with pytest.raises(FrozenInstanceError): 75 | # type ignore is because mypy complains we are modifying a frozen field, which is the point of the test 76 | rxn_cocs_from_co_cs.identifier = "abc" # type: ignore[misc] 77 | 78 | def test_rxn_smiles(self, rxn_cocs_from_co_cs: SingleProductReaction) -> None: 79 | """Test that the reaction SMILES is as expected.""" 80 | assert rxn_cocs_from_co_cs.reaction_smiles == "CO.CS>>COCS" 81 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | We support two installation modes: 2 | 3 | - *core installation* allows you to build and benchmark your own models or search algorithms 4 | - *full installation* also allows you to perform end-to-end search using the 
supported models 5 | 6 | There are also two installation sources: 7 | 8 | - *pip*, which provides the most recent released version 9 | - *GitHub*, which provides the latest changes but may be less stable and may not be 10 | backward-compatible with the latest released version 11 | 12 | === "Core (pip)" 13 | 14 | ```bash 15 | conda env create -f environment.yml 16 | conda activate syntheseus 17 | 18 | pip install syntheseus 19 | ``` 20 | 21 | === "Full (pip)" 22 | 23 | ```bash 24 | conda env create -f environment_full.yml 25 | conda activate syntheseus-full 26 | 27 | pip install "syntheseus[all]" 28 | ``` 29 | 30 | === "Core (GitHub)" 31 | 32 | ```bash 33 | conda env create -f environment.yml 34 | conda activate syntheseus 35 | 36 | pip install -e . 37 | ``` 38 | 39 | === "Full (GitHub)" 40 | 41 | ```bash 42 | conda env create -f environment_full.yml 43 | conda activate syntheseus-full 44 | 45 | pip install -e ".[all]" 46 | ``` 47 | 48 | !!! note 49 | 50 | Make sure you are viewing the version of the docs matching your `syntheseus` installation. 51 | Select the `x.y.z` version you installed if you used `pip` (go [here](https://microsoft.github.io/syntheseus/stable/) for the latest one), 52 | or [dev](https://microsoft.github.io/syntheseus/dev/) if you installed `syntheseus` directly from GitHub. 53 | 54 | Core installation includes only minimal dependencies (no ML libraries), while full installation includes all supported models and also dependencies for visualization/development. 55 | 56 | Instructions above assume you already cloned the repository via 57 | 58 | ```bash 59 | git clone https://github.com/microsoft/syntheseus.git 60 | cd syntheseus 61 | ``` 62 | 63 | Note that `environment_full.yml` pins the CUDA version (to 11.3) for reproducibility. 64 | If you want to use a different one, make sure to edit the environment file accordingly. 65 | 66 | ??? 
info "Setting up GLN" 67 | 68 | We also support GLN, but it requires a specialized environment and is thus not installed via `pip`. 69 | See [here](https://github.com/microsoft/syntheseus/blob/main/syntheseus/reaction_prediction/environment_gln/Dockerfile) for a Docker environment necessary for running GLN. 70 | 71 | ## Reducing the number of dependencies 72 | 73 | To keep the environment smaller, you can replace the `all` option with a comma-separated subset of `{chemformer,local-retro,megan,mhn-react,retro-knn,root-aligned,viz,dev}` (`viz` and `dev` correspond to visualization and development dependencies, respectively). 74 | For example, `pip install -e ".[local-retro,root-aligned]"` installs only LocalRetro and RootAligned. 75 | If installing a subset of models, you can also delete the lines in `environment_full.yml` marked with names of models you do not wish to use. 76 | 77 | If you only want to use a very specific part of `syntheseus`, you could also install it without dependencies: 78 | 79 | ```bash 80 | pip install -e . --no-dependencies 81 | ``` 82 | 83 | You then would need to manually install a subset of dependencies that are required for a particular functionality you want to access. 84 | See `pyproject.toml` for a list of dependencies tied to the `search` and `reaction_prediction` subpackages. 85 | -------------------------------------------------------------------------------- /syntheseus/interface/molecule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classes to hold molecules, without reference to the reactions they may take part in. 3 | """ 4 | 5 | from dataclasses import InitVar, dataclass, field 6 | from typing import Optional, TypedDict, Union 7 | 8 | from rdkit import Chem 9 | 10 | from syntheseus.interface.bag import Bag 11 | 12 | SMILES_SEPARATOR = "." 
13 | 14 | 15 | class MoleculeMetaData(TypedDict, total=False): 16 | """Class to add typing to optional meta-data fields for molecules.""" 17 | 18 | rdkit_mol: Chem.Mol 19 | 20 | # Things related to multi-step retrosynthesis 21 | is_purchasable: bool 22 | cost: float 23 | supplier: str 24 | 25 | # Other potentially relevant data 26 | purity: float 27 | 28 | 29 | @dataclass(frozen=True, order=True) 30 | class Molecule: 31 | """ 32 | Object representing a molecule with its SMILES string and an optional 33 | identifier to distinguish molecules with identical SMILES strings (usually not used). 34 | Everything else is considered metadata and is stored in a dictionary which is not 35 | compared or hashed. 36 | 37 | The class is frozen since it should not need to be edited, 38 | and this will auto-implement __eq__ and __hash__ methods. 39 | 40 | On initialization it is possible to automatically convert to canonical 41 | smiles (default True), and to store the rdkit molecule (default True). 42 | If set to false, there is no guarantee of canonicalization or storage of 43 | an rdkit mol. 
44 | """ 45 | 46 | smiles: str = field(hash=True, compare=True) 47 | identifier: Optional[Union[str, int]] = field(default=None, hash=True, compare=True) 48 | 49 | canonicalize: InitVar[bool] = True 50 | make_rdkit_mol: InitVar[bool] = True 51 | 52 | metadata: MoleculeMetaData = field( 53 | default_factory=lambda: MoleculeMetaData(), 54 | hash=False, 55 | compare=False, 56 | ) 57 | 58 | def __post_init__(self, canonicalize: bool, make_rdkit_mol: bool) -> None: 59 | if canonicalize or make_rdkit_mol: 60 | try: 61 | rdkit_mol = Chem.MolFromSmiles(self.smiles) 62 | except Exception as e: 63 | raise ValueError(f"Cannot create a molecule with SMILES '{self.smiles}'") from e 64 | 65 | if make_rdkit_mol: 66 | self.metadata["rdkit_mol"] = rdkit_mol 67 | 68 | if canonicalize: 69 | try: 70 | smiles_canonical = Chem.MolToSmiles(rdkit_mol) 71 | except Exception as e: 72 | raise ValueError( 73 | f"Cannot canonicalize a molecule with SMILES '{self.smiles}'" 74 | ) from e 75 | 76 | object.__setattr__(self, "smiles", smiles_canonical) 77 | 78 | @property 79 | def rdkit_mol(self) -> Chem.Mol: 80 | """Makes an rdkit mol if one does yet exist""" 81 | if "rdkit_mol" not in self.metadata: 82 | self.metadata["rdkit_mol"] = Chem.MolFromSmiles(self.smiles) 83 | return self.metadata["rdkit_mol"] 84 | 85 | 86 | def molecule_bag_to_smiles(mols: Bag[Molecule]) -> str: 87 | """Combine SMILES strings of molecules in a `Bag` into a single string. 88 | 89 | For two bags that represent the same multiset of molecules this function will return the same 90 | result, because iteration order over a `Bag` is deterministic (sorted using default comparator). 
91 | """ 92 | return SMILES_SEPARATOR.join(mol.smiles for mol in mols) 93 | -------------------------------------------------------------------------------- /syntheseus/search/mol_inventory.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import abc 4 | import warnings 5 | from collections.abc import Collection 6 | from pathlib import Path 7 | from typing import Union 8 | 9 | from rdkit import Chem 10 | 11 | from syntheseus.interface.molecule import Molecule 12 | 13 | 14 | class BaseMolInventory(abc.ABC): 15 | @abc.abstractmethod 16 | def is_purchasable(self, mol: Molecule) -> bool: 17 | """Whether or not a molecule is purchasable.""" 18 | raise NotImplementedError 19 | 20 | def fill_metadata(self, mol: Molecule) -> None: 21 | """ 22 | Fills any/all metadata of a molecule. This method should be fast to call, 23 | and many algorithms will assume that it sets `is_purchasable`. 24 | """ 25 | 26 | # Default just adds whether the molecule is purchasable 27 | mol.metadata["is_purchasable"] = self.is_purchasable(mol) 28 | 29 | 30 | class ExplicitMolInventory(BaseMolInventory): 31 | """ 32 | Base class for MolInventories which store an explicit list of purchasable molecules. 33 | It exposes and additional method to explore this list. 34 | 35 | If it is unclear how a mol inventory might *not* have an explicit list of purchasable 36 | molecules, imagine a toy problem where every molecule with <= 10 atoms is purchasable. 37 | It is easy to check if a molecule has <= 10 atoms, but it is difficult to enumerate 38 | all molecules with <= 10 atoms. 39 | """ 40 | 41 | @abc.abstractmethod 42 | def to_purchasable_mols(self) -> Collection[Molecule]: 43 | """Returns an explicit collection of all purchasable molecules. 44 | 45 | Likely expensive for large inventories, should be used mostly for testing or debugging. 
46 | """ 47 | 48 | def purchasable_mols(self) -> Collection[Molecule]: 49 | warnings.warn( 50 | "purchasable_mols is deprecated, use to_purchasable_mols instead", DeprecationWarning 51 | ) 52 | return self.to_purchasable_mols() 53 | 54 | @abc.abstractmethod 55 | def __len__(self) -> int: 56 | """Return the number of purchasable molecules in the inventory.""" 57 | 58 | 59 | class SmilesListInventory(ExplicitMolInventory): 60 | """Most common type of inventory: a list of purchasable SMILES.""" 61 | 62 | def __init__(self, smiles_list: list[str], canonicalize: bool = True): 63 | if canonicalize: 64 | # For canonicalization we sequence `MolFromSmiles` and `MolToSmiles` to exactly match 65 | # the process employed in the `Molecule` class. 66 | smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(s)) for s in smiles_list] 67 | 68 | self._smiles_set = set(smiles_list) 69 | 70 | def is_purchasable(self, mol: Molecule) -> bool: 71 | if mol.identifier is not None: 72 | warnings.warn( 73 | f"Molecule identifier {mol.identifier} will be ignored during inventory lookup" 74 | ) 75 | 76 | return mol.smiles in self._smiles_set 77 | 78 | def to_purchasable_mols(self) -> Collection[Molecule]: 79 | return {Molecule(s, make_rdkit_mol=False, canonicalize=False) for s in self._smiles_set} 80 | 81 | def __len__(self) -> int: 82 | return len(self._smiles_set) 83 | 84 | @classmethod 85 | def load_from_file(cls, path: Union[str, Path], **kwargs) -> SmilesListInventory: 86 | """Load the inventory SMILES from a file.""" 87 | with open(path, "rt") as f_inventory: 88 | return cls([line.strip() for line in f_inventory], **kwargs) 89 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/utils/inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, List, Sequence, Union 3 | 4 | from syntheseus.interface.bag import Bag 5 | from 
syntheseus.interface.molecule import Molecule 6 | from syntheseus.interface.reaction import ( 7 | Reaction, 8 | ReactionMetaData, 9 | SingleProductReaction, 10 | ) 11 | from syntheseus.reaction_prediction.chem.utils import molecule_bag_from_smiles 12 | 13 | 14 | def process_raw_smiles_outputs_backwards( 15 | input: Molecule, output_list: List[str], metadata_list: List[ReactionMetaData] 16 | ) -> Sequence[SingleProductReaction]: 17 | """Convert raw SMILES outputs into a list of `SingleProductReaction` objects. 18 | 19 | Args: 20 | input: Model input. 21 | output_list: Raw SMILES outputs (including potentially invalid ones). 22 | metadata_list: Additional metadata to attach to the predictions (e.g. probability). 23 | 24 | Returns: 25 | A list of `SingleProductReaction`s; may be shorter than `outputs` if some of the raw 26 | SMILES could not be parsed into valid reactant bags. 27 | """ 28 | predictions: List[SingleProductReaction] = [] 29 | 30 | for raw_output, metadata in zip(output_list, metadata_list): 31 | reactants = molecule_bag_from_smiles(raw_output) 32 | 33 | # Only consider the prediction if the SMILES can be parsed. 34 | if reactants is not None: 35 | predictions.append( 36 | SingleProductReaction(product=input, reactants=reactants, metadata=metadata) 37 | ) 38 | 39 | return predictions 40 | 41 | 42 | def process_raw_smiles_outputs_forwards( 43 | input: Bag[Molecule], output_list: List[str], metadata_list: List[ReactionMetaData] 44 | ) -> Sequence[Reaction]: 45 | """Convert raw SMILES outputs into a list of `Reaction` objects. 46 | Like method `process_raw_smiles_outputs_backwards`, but for forward models. 47 | 48 | Args: 49 | input: Model input. 50 | output_list: Raw SMILES outputs (including potentially invalid ones). 51 | metadata_list: Additional metadata to attach to the predictions (e.g. probability). 
52 | 53 | Returns: 54 | A list of `Reaction`s; may be shorter than `outputs` if some of the raw 55 | SMILES could not be parsed into valid reactant bags. 56 | """ 57 | predictions: List[Reaction] = [] 58 | 59 | for raw_output, metadata in zip(output_list, metadata_list): 60 | products = molecule_bag_from_smiles(raw_output) 61 | 62 | # Only consider the prediction if the SMILES can be parsed. 63 | if products is not None: 64 | predictions.append(Reaction(products=products, reactants=input, metadata=metadata)) 65 | 66 | return predictions 67 | 68 | 69 | def get_unique_file_in_dir(dir: Union[str, Path], pattern: str) -> Path: 70 | candidates = list(Path(dir).glob(pattern)) 71 | if len(candidates) != 1: 72 | raise ValueError( 73 | f"Expected a unique match for {pattern} in {dir}, found {len(candidates)}: {candidates}" 74 | ) 75 | 76 | return candidates[0] 77 | 78 | 79 | def get_module_path(module: Any) -> str: 80 | """Heuristically extract the local path to an imported module.""" 81 | 82 | # In some cases, `module.__path__` is already a `List`, while in other cases it may be a 83 | # `_NamespacePath` object. Either way the conversion below leaves us with `List[str]`. 84 | path_list: List[str] = list(module.__path__) 85 | 86 | if len(path_list) != 1: 87 | raise ValueError(f"Cannot extract path to module {module} from {path_list}") 88 | 89 | return path_list[0] 90 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/inference/gln.py: -------------------------------------------------------------------------------- 1 | """Inference wrapper for the Graph Logic Network (GLN) model. 2 | 3 | Paper: https://arxiv.org/abs/2001.01408 4 | Code: https://github.com/Hanjun-Dai/GLN 5 | 6 | The original GLN code is released under the MIT license. 
7 | """ 8 | 9 | import sys 10 | from pathlib import Path 11 | from typing import List, Sequence 12 | 13 | from syntheseus.interface.molecule import Molecule 14 | from syntheseus.interface.reaction import SingleProductReaction 15 | from syntheseus.reaction_prediction.inference.base import ExternalBackwardReactionModel 16 | from syntheseus.reaction_prediction.utils.inference import process_raw_smiles_outputs_backwards 17 | from syntheseus.reaction_prediction.utils.misc import suppress_outputs 18 | 19 | 20 | class GLNModel(ExternalBackwardReactionModel): 21 | def __init__(self, *args, dataset_name: str = "schneider50k", **kwargs) -> None: 22 | """Initializes the GLN model wrapper. 23 | 24 | Assumed format of the model directory: 25 | - `model_dir` contains files necessary to build `RetroGLN` 26 | - `model_dir/{dataset_name}.ckpt` is the model checkpoint 27 | - `model_dir/cooked_{dataset_name}/atom_list.txt` is the atom type list 28 | """ 29 | super().__init__(*args, **kwargs) 30 | 31 | import torch 32 | 33 | chkpt_path = Path(self.model_dir) / f"{dataset_name}.ckpt" 34 | gln_args = { 35 | "dropbox": self.model_dir, 36 | "data_name": dataset_name, 37 | "model_for_test": chkpt_path, 38 | "tpl_name": "default", 39 | "f_atoms": Path(self.model_dir) / f"cooked_{dataset_name}" / "atom_list.txt", 40 | "gpu": torch.device(self.device).index, 41 | } 42 | 43 | # Suppress most of the prints from GLN's internals. This only works on messages that 44 | # originate from Python, so the C++-based ones slip through. 45 | with suppress_outputs(): 46 | # GLN makes heavy use of global state (saved either in `gln.common.cmd_args` or `sys.argv`), 47 | # so we have to hack both of these sources below. 48 | from gln.common.cmd_args import cmd_args 49 | 50 | sys.argv = [] 51 | for name, value in gln_args.items(): 52 | setattr(cmd_args, name, value) 53 | sys.argv += [f"-{name}", str(value)] 54 | 55 | # The global state hackery has to happen before this. 
56 | from gln.test.model_inference import RetroGLN 57 | 58 | self.model = RetroGLN(self.model_dir, chkpt_path) 59 | 60 | @property 61 | def name(self) -> str: 62 | return "GLN" 63 | 64 | def get_parameters(self): 65 | return self.model.gln.parameters() 66 | 67 | def _get_model_predictions( 68 | self, input: Molecule, num_results: int 69 | ) -> Sequence[SingleProductReaction]: 70 | with suppress_outputs(): 71 | result = self.model.run(input.smiles, num_results, num_results) 72 | 73 | if result is None: 74 | return [] 75 | else: 76 | # `scores` are actually probabilities (produced by running `softmax`). 77 | return process_raw_smiles_outputs_backwards( 78 | input=input, 79 | output_list=result["reactants"], 80 | metadata_list=[ 81 | {"probability": probability.item()} for probability in result["scores"] 82 | ], 83 | ) 84 | 85 | def _get_reactions( 86 | self, inputs: List[Molecule], num_results: int 87 | ) -> List[Sequence[SingleProductReaction]]: 88 | return [self._get_model_predictions(input, num_results=num_results) for input in inputs] 89 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | workflow_dispatch: 9 | 10 | jobs: 11 | test-core: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] 17 | defaults: 18 | run: 19 | shell: bash -l {0} 20 | name: test-core (Python ${{ matrix.python-version }}) 21 | steps: 22 | - uses: actions/checkout@v3 23 | - uses: conda-incubator/setup-miniconda@v3 24 | with: 25 | mamba-version: "*" 26 | channels: conda-forge,defaults 27 | channel-priority: true 28 | python-version: ${{ matrix.python-version }} 29 | environment-file: environment.yml 30 | - name: Install syntheseus 31 | run: | 32 | pip install 
.[dev] 33 | - name: Run pre-commit 34 | run: | 35 | pre-commit run --verbose --all-files 36 | - name: Run unit tests 37 | run: | 38 | coverage run -p -m pytest ./syntheseus/tests/ 39 | coverage report --data-file .coverage.* 40 | - name: Upload coverage report 41 | uses: actions/upload-artifact@v4 42 | with: 43 | name: .coverage.core-${{ matrix.python-version }} 44 | path: .coverage.* 45 | include-hidden-files: true 46 | test-models: 47 | strategy: 48 | fail-fast: false 49 | matrix: 50 | runner: ["ubuntu-latest", "Syntheseus-GPU"] 51 | runs-on: ${{ matrix.runner }} 52 | defaults: 53 | run: 54 | shell: bash -l {0} 55 | name: test-models (${{ matrix.runner == 'ubuntu-latest' && 'CPU' || 'GPU' }}) 56 | steps: 57 | - name: Free extra disk space 58 | uses: jlumbroso/free-disk-space@main 59 | - uses: actions/checkout@v3 60 | - uses: conda-incubator/setup-miniconda@v3 61 | with: 62 | mamba-version: "*" 63 | channels: conda-forge,defaults 64 | channel-priority: true 65 | environment-file: environment_full.yml 66 | - name: Install syntheseus with all single-step models 67 | run: | 68 | sudo apt install -y graphviz 69 | pip install .[all] 70 | - name: Verify GPU is available 71 | if: ${{ matrix.runner != 'ubuntu-latest' }} 72 | run: | 73 | nvidia-smi || (echo "❌ No GPU detected" && exit 1) 74 | python -c "import torch; assert torch.cuda.is_available(), '❌ GPU not found'; print('✅ Found GPU:', torch.cuda.get_device_name(0))" 75 | - name: Run single-step model tests 76 | run: | 77 | coverage run -p -m pytest \ 78 | ./syntheseus/tests/cli/test_cli.py \ 79 | ./syntheseus/tests/reaction_prediction/inference/test_models.py \ 80 | ./syntheseus/tests/reaction_prediction/utils/test_parallel.py 81 | coverage report --data-file .coverage.* 82 | - name: Upload coverage report 83 | uses: actions/upload-artifact@v4 84 | with: 85 | name: .coverage.models-${{ matrix.runner }} 86 | path: .coverage.* 87 | include-hidden-files: true 88 | coverage: 89 | needs: [test-core, test-models] 90 | 
class FlakyReactionModel(BackwardReactionModel):
    """Dummy reaction model that only works when called for the first time."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        # Flipped to `True` after the first (and only successful) call.
        self._used = False

    def _get_reactions(
        self, inputs: list[Molecule], num_results: int
    ) -> list[Sequence[SingleProductReaction]]:
        if self._used:
            raise RuntimeError()
        self._used = True

        def make_prediction(product: Molecule) -> SingleProductReaction:
            # Every product is "made" from a single carbon atom with probability 1.
            return SingleProductReaction(
                reactants=Bag([Molecule("C")]), product=product, metadata={"probability": 1.0}
            )

        return [[make_prediction(product)] for product in inputs]
f_search_targets.write("CC\nCC\nCC\nCC\n") 44 | 45 | inventory_file_path = tmpdir / "inventory.smiles" 46 | with open(inventory_file_path, "wt") as f_inventory: 47 | f_inventory.write("C\n") 48 | 49 | # Inject our flaky reaction model into the set of supported model classes. 50 | BackwardModelClass._member_map_["FlakyReactionModel"] = SimpleNamespace( # type: ignore 51 | name="FlakyReactionModel", value=FlakyReactionModel 52 | ) 53 | 54 | config = OmegaConf.create( # type: ignore 55 | SearchConfig( 56 | model_class="FlakyReactionModel", # type: ignore[arg-type] 57 | search_algorithm="retro_star", # type: ignore[arg-type] 58 | search_targets_file=str(search_targets_file_path), 59 | inventory_smiles_file=str(inventory_file_path), 60 | results_dir=str(tmpdir), 61 | append_timestamp_to_dir=False, 62 | limit_iterations=1, 63 | num_routes_to_plot=0, 64 | ) 65 | ) 66 | 67 | results_dir = tmpdir / "FlakyReactionModel" 68 | 69 | def file_exist(idx: int, name: str) -> bool: 70 | return (results_dir / str(idx) / name).exists() 71 | 72 | # Try to run search three times; each time we will succeed solving one target (which requires one 73 | # call) and then fail on the next one. 74 | for trial_idx in range(3): 75 | with pytest.raises(RuntimeError): 76 | run_from_config(config) 77 | 78 | for idx in range(trial_idx + 1): 79 | assert file_exist(idx, "stats.json") 80 | assert not file_exist(idx, ".lock") 81 | 82 | assert not file_exist(trial_idx + 1, "stats.json") 83 | assert file_exist(trial_idx + 1, ".lock") 84 | 85 | run_from_config(config) 86 | 87 | # The last search needs to solve one final target so it will succeed. 88 | for idx in range(4): 89 | assert file_exist(idx, "stats.json") 90 | assert not file_exist(idx, ".lock") 91 | 92 | with open(results_dir / "stats.json", "rt") as f_stats: 93 | stats = json.load(f_stats) 94 | 95 | # Even though each search only solved a single target, final stats should include everything. 
class LinearMoleculesToyModel(BackwardReactionModel):
    """
    Toy backward "reaction" model over linear ball-and-stick molecules,
    where the possible reactions are string cuts and end-atom substitutions.

    Molecules handled by this model must consist solely of C, O, and S atoms
    with single bonds. The reactions produced are:
    - string cuts, e.g. "CCOC" -> "CC" + "OC" (*see note 2 below)
    - substitution of the atom on either end of the molecule, e.g. "CCOC" -> "CCOO"

    NOTE 1: molecules formed by this model are mostly unphysical and the reactions
    are not actual chemical reactions. This model is intended for testing and debugging.

    NOTE 2: all molecules are returned with canonical SMILES, so the outputs may not look
    the same as a string cut. For example, "CCOC" -> "CC" + "OC" will get canonicalized to
    "CC" + "CO" (i.e. the "O" and "C" will swap places). Fundamentally this doesn't change
    anything since "CO" and "OC" are the same molecule, but it may be confusing when debugging.
    """

    def __init__(self, allow_substitution: bool = True, **kwargs) -> None:
        super().__init__(**kwargs)
        # Treated as immutable after construction.
        self._allow_substitution = allow_substitution

    def _get_single_backward_reactions(self, mol: Molecule) -> list[SingleProductReaction]:
        """Enumerate every single-step toy "reaction" producing `mol`."""
        smiles = mol.smiles
        assert set(smiles) <= set("COS"), "Molecules must be formed out of C,O, and S atoms."
        assert len(smiles) > 0, "Molecules must have at least 1 atom."

        reactions: list[SingleProductReaction] = []

        # All possible string cuts: split the SMILES at every interior position.
        for cut in range(1, len(smiles)):
            left = Molecule(smiles[:cut], make_rdkit_mol=False)
            right = Molecule(smiles[cut:], make_rdkit_mol=False)
            reactions.append(
                SingleProductReaction(
                    product=mol,
                    # A set is used so that identical halves collapse into a single reactant.
                    reactants=Bag({left, right}),
                    metadata={"source": f"string cut at idx {cut}"},
                )
            )

        # End-atom substitutions. The first and last position coincide for
        # single-atom molecules, hence the set-based deduplication.
        if self._allow_substitution:
            for position in sorted({0, len(smiles) - 1}):
                for atom in "COS":
                    if smiles[position] == atom:
                        continue
                    substituted = Molecule(
                        smiles[:position] + atom + smiles[position + 1 :],
                        make_rdkit_mol=False,
                    )
                    reactions.append(
                        SingleProductReaction(
                            product=mol,
                            reactants=Bag([substituted]),
                            metadata={"source": f"substitution idx {position} with {atom}"},
                        )
                    )

        return reactions

    def _get_reactions(
        self, inputs: list[Molecule], num_results: int
    ) -> list[Sequence[SingleProductReaction]]:
        """Return, for every input molecule, all of its backward reactions."""
        return [self._get_single_backward_reactions(molecule) for molecule in inputs]
8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | import sys 13 | from typing import Sequence 14 | 15 | from rdkit import Chem 16 | 17 | from syntheseus.interface.molecule import Molecule 18 | from syntheseus.interface.reaction import SingleProductReaction 19 | from syntheseus.reaction_prediction.inference.base import ExternalBackwardReactionModel 20 | from syntheseus.reaction_prediction.utils.inference import ( 21 | get_module_path, 22 | get_unique_file_in_dir, 23 | process_raw_smiles_outputs_backwards, 24 | ) 25 | from syntheseus.reaction_prediction.utils.misc import suppress_outputs, suppress_rdkit_outputs 26 | 27 | 28 | class Graph2EditsModel(ExternalBackwardReactionModel): 29 | def __init__(self, *args, max_edit_steps: int = 9, **kwargs) -> None: 30 | """Initializes the Graph2Edits model wrapper. 31 | 32 | Assumed format of the model directory: 33 | - `model_dir` contains the model checkpoint as the only `*.pt` file 34 | """ 35 | super().__init__(*args, **kwargs) 36 | 37 | import graph2edits 38 | import torch 39 | 40 | sys.path.insert(0, str(get_module_path(graph2edits))) 41 | 42 | from graph2edits.models import BeamSearch, Graph2Edits 43 | 44 | checkpoint = torch.load( 45 | get_unique_file_in_dir(self.model_dir, pattern="*.pt"), map_location=self.device 46 | ) 47 | 48 | model = Graph2Edits(**checkpoint["saveables"], device=self.device) 49 | model.load_state_dict(checkpoint["state"]) 50 | model.to(self.device) 51 | model.eval() 52 | 53 | # We set the beam size to a placeholder value for now and override it in `_get_reactions`. 
54 | self.model = BeamSearch(model=model, step_beam_size=10, beam_size=None, use_rxn_class=False) 55 | self._max_edit_steps = max_edit_steps 56 | 57 | def get_parameters(self): 58 | return self.model.model.parameters() 59 | 60 | def _get_reactions( 61 | self, inputs: list[Molecule], num_results: int 62 | ) -> list[Sequence[SingleProductReaction]]: 63 | import torch 64 | 65 | self.model.beam_size = num_results 66 | 67 | batch_predictions = [] 68 | for input in inputs: 69 | # Copy the `rdkit` molecule as below we modify it in-place. 70 | mol = Chem.Mol(input.rdkit_mol) 71 | 72 | # Assign a dummy atom mapping as Graph2Edits depends on it. This has no connection to 73 | # the ground-truth atom mapping, which we do not have access to. 74 | for idx, atom in enumerate(mol.GetAtoms()): 75 | atom.SetAtomMapNum(idx + 1) 76 | 77 | with torch.no_grad(), suppress_outputs(), suppress_rdkit_outputs(): 78 | try: 79 | raw_results = self.model.run_search( 80 | prod_smi=Chem.MolToSmiles(mol), 81 | max_steps=self._max_edit_steps, 82 | rxn_class=None, 83 | ) 84 | except IndexError: 85 | # This can happen in some rare edge cases (e.g. "OBr"). 86 | raw_results = [] 87 | 88 | # Errors are returned as a string "final_smi_unmapped"; we get rid of those here. 
89 | raw_results = [ 90 | raw_result 91 | for raw_result in raw_results 92 | if raw_result["final_smi"] != "final_smi_unmapped" 93 | ] 94 | 95 | batch_predictions.append( 96 | process_raw_smiles_outputs_backwards( 97 | input=input, 98 | output_list=[raw_result["final_smi"] for raw_result in raw_results], 99 | metadata_list=[ 100 | {"probability": raw_result["prob"]} for raw_result in raw_results 101 | ], 102 | ) 103 | ) 104 | 105 | return batch_predictions 106 | -------------------------------------------------------------------------------- /syntheseus/search/graph/node.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import abc 4 | import datetime 5 | import math 6 | from collections.abc import Collection 7 | from dataclasses import dataclass, field 8 | from typing import TypedDict 9 | 10 | 11 | class _NodeData_Time(TypedDict, total=False): 12 | """Holds optional data about the time a node was created.""" 13 | 14 | # How many times has the rxn model been called when this node was created? 15 | num_calls_rxn_model: int 16 | 17 | # How many times has the value function been called when this node was created? 
class _NodeData_Algorithms(TypedDict, total=False):
    """Holds optional data used by specific algorithms."""

    # ==================================================
    # General
    # ==================================================
    policy_score: float

    # ==================================================
    # Retro*
    # ==================================================
    retro_star_min_cost: float  # minimum cost found so far
    retro_star_reaction_number: float
    reaction_number_estimate: float
    retro_star_value: float
    retro_star_rxn_cost: float
    retro_star_mol_cost: float

    # ==================================================
    # MCTS
    # ==================================================
    mcts_value: float
    mcts_prev_reward: float  # the most recent reward received

    # ==================================================
    # PDVN MCTS
    # ==================================================
    pdvn_reaction_cost: float
    pdvn_mcts_v_syn: float
    pdvn_mcts_v_cost: float
    pdvn_mcts_prev_reward_syn: float
    pdvn_mcts_prev_reward_cost: float
    pdvn_min_syn_cost: float


class _NodeData_Analysis(TypedDict, total=False):
    """Holds optional data used during analysis of search results."""

    analysis_time: float  # Used to hold a node's creation time (measured any way) for analysis purposes
    first_solution_time: float  # time of first solution (according to analysis_time)
    route_cost: float  # non-negative cost that this node contributes to the entire route


class NodeData(_NodeData_Time, _NodeData_Algorithms, _NodeData_Analysis):
    """Holds all kinds of node data (time, algorithm-specific, and analysis fields)."""

    pass


@dataclass
class BaseGraphNode(abc.ABC):
    """Abstract base class for nodes in a search graph.

    Subclasses define what it means for a node to "intrinsically" have a
    solution and how solutions propagate from a node's children.
    """

    # Whether the node is "solved" (has a synthesis route leading to it)
    has_solution: bool = False

    # How many times has the node been "visited".
    # The meaning of a "visit" will be different for different algorithms.
    num_visit: int = 0

    # How "deep" is this node, i.e. the length of the path from the root node to this node.
    # It is initialized to inf to indicate "not set" (and this is the only value which will be
    # stable with graphs with no root node where depth is ill-defined)
    # NOTE: the default is a float sentinel on an int-annotated field, hence the ignore.
    depth: int = math.inf  # type: ignore

    # Whether the node has been expanded
    is_expanded: bool = False

    # Time when this node was created (used for analysis of search results).
    creation_time: datetime.datetime = field(
        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
    )

    # Any other node data, stored as a TypedDict to allow arbitrary values to be tracked
    # while also allowing type-checking.
    data: NodeData = field(default_factory=lambda: NodeData())

    def __eq__(self, other):
        # No comparison of node values, only identity.
        return self is other

    def __hash__(self):
        # Hash nodes based on id:
        # this ensures distinct nodes always have a distinct hash.
        return id(self)

    @abc.abstractmethod
    def _has_intrinsic_solution(self) -> bool:
        """Whether this node has a solution without considering its children."""
        raise NotImplementedError

    @abc.abstractmethod
    def _has_solution_from_children(self, children: Collection[BaseGraphNode]) -> bool:
        """Whether this node has a solution, exclusively considering its children."""
        raise NotImplementedError
30 | # For Jaccard distance choose a value of 1 (the max value dist distance can take) 31 | # For symmetric difference distance choose a value of 20 (more than the max number of mols/rxns in a route) 32 | (molecule_jaccard_distance, 1.0, 1), 33 | (reaction_jaccard_distance, 1.0, 1), 34 | (molecule_symmetric_difference_distance, 20, 1), 35 | (reaction_symmetric_difference_distance, 20, 1), 36 | # 37 | # Easy case 2: distance threshold = 0 so it should just count number of distinct routes (11) 38 | (molecule_jaccard_distance, 0, 11), 39 | (reaction_jaccard_distance, 0, 11), 40 | (molecule_symmetric_difference_distance, 0, 11), 41 | (reaction_symmetric_difference_distance, 0, 11), 42 | # 43 | # Individual harder cases 44 | (molecule_jaccard_distance, 0.8, 2), 45 | (molecule_jaccard_distance, 0.5, 7), 46 | (molecule_jaccard_distance, 0.2, 9), 47 | ( 48 | reaction_jaccard_distance, 49 | 0.99, 50 | 7, 51 | ), # routes with completely non-overlapping reaction sets 52 | (reaction_symmetric_difference_distance, 4, 6), 53 | (molecule_symmetric_difference_distance, 2, 7), # differ in at least 2 molecules 54 | ], 55 | ) 56 | def test_estimate_packing_number( 57 | sample_synthesis_routes: list[SynthesisGraph], 58 | metric, 59 | threshold: float, 60 | expected_packing_number: int, 61 | ) -> None: 62 | """ 63 | Check that after a large number of trials, the correct packing number is found 64 | for a set of routes. 
65 | """ 66 | 67 | # Run the packing number estimation 68 | distinct_routes = estimate_packing_number( 69 | routes=sample_synthesis_routes, 70 | radius=threshold, 71 | distance_metric=metric, 72 | num_tries=1000, 73 | random_state=random.Random(100), 74 | ) 75 | 76 | # Check that routes returned are all a distance > threshold from each other 77 | for i, route1 in enumerate(distinct_routes): 78 | for route2 in distinct_routes[i + 1 :]: 79 | assert metric(route1, route2) > threshold 80 | 81 | # Check that the correct packing number is found 82 | assert len(distinct_routes) == expected_packing_number 83 | 84 | 85 | @pytest.mark.parametrize( 86 | "metric, threshold, max_packing_number, expected_packing_number", 87 | [ 88 | # Easy case 1: distance threshold is very large so only 1 route can be returned 89 | (reaction_jaccard_distance, 1.0, 0, 0), # limit is 0 so no routes can be returned 90 | (reaction_jaccard_distance, 1.0, 10, 1), # limit is higher than actual number 91 | (reaction_jaccard_distance, 1e-3, 5, 5), # are 11 such routes but limit is 5 92 | ], 93 | ) 94 | def test_max_packing_number( 95 | sample_synthesis_routes: list[SynthesisGraph], 96 | metric, 97 | threshold: float, 98 | max_packing_number: int, 99 | expected_packing_number: int, 100 | ) -> None: 101 | """Test that max packing number actually limits the packing number.""" 102 | 103 | # Run the packing number estimation 104 | distinct_routes = estimate_packing_number( 105 | routes=sample_synthesis_routes, 106 | radius=threshold, 107 | distance_metric=metric, 108 | max_packing_number=max_packing_number, 109 | num_tries=100, 110 | random_state=random.Random(100), 111 | ) 112 | assert len(distinct_routes) == expected_packing_number 113 | -------------------------------------------------------------------------------- /syntheseus/tests/interface/test_molecule.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses 
def test_positive_equality(cocs_mol: Molecule) -> None:
    """Various tests that 2 molecules which should be equal are actually equal."""

    mol_copy = Molecule(smiles=str(cocs_mol.smiles))

    # Test 1: original and copy should be equal
    assert cocs_mol == mol_copy

    # Test 2: although equal, they should be distinct objects
    assert cocs_mol is not mol_copy

    # Test 3: differences in metadata should not affect equality
    # (type ignores are because we are adding arbitrary unrealistic metadata)
    cocs_mol.metadata["test"] = "str1"  # type: ignore[typeddict-unknown-key]
    mol_copy.metadata["test"] = "str2"  # type: ignore[typeddict-unknown-key]
    mol_copy.metadata["other_field"] = "not in other mol"  # type: ignore[typeddict-unknown-key]
    assert cocs_mol == mol_copy
    assert cocs_mol.metadata != mol_copy.metadata


def test_negative_equality(cocs_mol: Molecule) -> None:
    """Various tests that molecules which should not be equal are not equal."""

    # Test 1: SMILES strings are not mol objects
    assert cocs_mol != cocs_mol.smiles

    # Test 2: changing the identifier makes mols not equal
    mol_with_different_id = Molecule(smiles=cocs_mol.smiles, identifier="different")
    assert cocs_mol != mol_with_different_id

    # Test 3: equality should only be true if class is the same,
    # so another object with the same fields should still be not equal

    @dataclass
    class FakeMoleculeClass:
        smiles: str
        identifier: Optional[str]
        metadata: dict[str, Any]

    fake_mol = FakeMoleculeClass(smiles=cocs_mol.smiles, identifier=None, metadata=dict())
    assert fake_mol != cocs_mol

    # Test 4: same molecule but with non-canonical SMILES will still compare to False
    non_canonical_mol = Molecule(smiles="SCOC", canonicalize=False)
    assert non_canonical_mol != cocs_mol


def test_frozen(cocs_mol: Molecule) -> None:
    """Test that the fields of the Molecule cannot be modified (i.e. is actually frozen)."""
    with pytest.raises(FrozenInstanceError):
        # type ignore is because mypy complains we are modifying a frozen field, which is the point of the test
        cocs_mol.smiles = "xyz"  # type: ignore[misc]


def test_canonicalization() -> None:
    """
    Test that the `canonicalize` argument works as expected,
    canonicalizing the SMILES if True and leaving it unchanged if False.
    """
    non_canonical_smiles = "OCC"
    canonical_smiles = "CCO"

    # Test 1: canonicalize=True should canonicalize the SMILES
    mol1 = Molecule(smiles=non_canonical_smiles, canonicalize=True)
    assert mol1.smiles == canonical_smiles

    # Test 2: canonicalize=False should leave the SMILES unchanged
    mol2 = Molecule(smiles=non_canonical_smiles, canonicalize=False)
    assert mol2.smiles == non_canonical_smiles


def test_make_rdkit_mol() -> None:
    """Test that the argument `make_rdkit_mol` works as expected."""

    # Test 1: make_rdkit_mol=True eagerly stores the rdkit mol in metadata
    smiles = "CCO"
    mol_with_rdkit_mol = Molecule(smiles=smiles, make_rdkit_mol=True)
    assert "rdkit_mol" in mol_with_rdkit_mol.metadata

    # Test 2: make_rdkit_mol=False does not
    mol_without_rdkit_mol = Molecule(smiles=smiles, make_rdkit_mol=False)
    assert "rdkit_mol" not in mol_without_rdkit_mol.metadata

    # Test 3: accessing the rdkit mol should create it (lazily, on first access)
    mol_without_rdkit_mol.rdkit_mol
    assert "rdkit_mol" in mol_without_rdkit_mol.metadata
mol3 = Molecule("CCC", identifier="abc") 106 | mol4 = Molecule("CCC", identifier="def") 107 | 108 | # Test sorting 109 | mol_list = [mol4, mol3, mol2, mol1] 110 | mol_list.sort() 111 | assert mol_list == [mol1, mol2, mol3, mol4] 112 | 113 | 114 | def test_create_and_compare() -> None: 115 | # NOTE: this test has some redundancy with tests from above 116 | mol_1 = Molecule("C") 117 | mol_2 = Molecule("C1=CC(N)=CC=C1") 118 | mol_3 = Molecule("c1cccc(N)c1") 119 | 120 | assert mol_1 < mol_2 # Lexicographical comparison on SMILES. 121 | assert mol_2 == mol_3 # Should be equal after canonicalization. 122 | 123 | 124 | def test_order_of_components() -> None: 125 | assert Molecule("C.CC") == Molecule("CC.C") 126 | 127 | 128 | def test_create_invalid() -> None: 129 | with pytest.raises(ValueError): 130 | Molecule("not-a-real-SMILES") 131 | -------------------------------------------------------------------------------- /syntheseus/tests/search/graph/test_message_passing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Message passing is used in other algorithms and is tested implicitly in the algorithm tests. 3 | 4 | Therefore this file just contains minimal tests of correctness and edge cases. 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import pytest 10 | 11 | from syntheseus.search.graph.and_or import AndOrGraph 12 | from syntheseus.search.graph.message_passing import ( 13 | depth_update, 14 | has_solution_update, 15 | run_message_passing, 16 | ) 17 | 18 | 19 | def test_no_update_functions(andor_tree_non_minimal: AndOrGraph) -> None: 20 | """ 21 | Test that if no update functions are provided, the message passing algorithm 22 | doesn't actually run. 
23 | """ 24 | g = andor_tree_non_minimal # rename for brevity 25 | output = run_message_passing(g, g.nodes(), update_fns=[], max_iterations=1) 26 | assert len(output) == 0 27 | 28 | 29 | def test_no_input_nodes(andor_tree_non_minimal: AndOrGraph) -> None: 30 | """ 31 | If no input nodes are provided, the message passing algorithm should 32 | terminate without running. 33 | """ 34 | g = andor_tree_non_minimal # rename for brevity 35 | output = run_message_passing(g, [], update_fns=[has_solution_update], max_iterations=1) 36 | assert len(output) == 0 37 | 38 | 39 | @pytest.mark.parametrize("update_successors", [True, False]) 40 | def test_update_successors(andor_tree_non_minimal: AndOrGraph, update_successors: bool) -> None: 41 | """ 42 | Test that the "update successors" function works as expected by 43 | setting the "has_solution" attribute of the root node to False. 44 | 45 | If update_successors=False then message passing should terminate after 1 iteration. 46 | However, if update_successors=True then message passing should terminate visiting 47 | the root node and its children (3 iterations). In both cases only 1 node should be updated. 
48 | """ 49 | g = andor_tree_non_minimal # rename for brevity 50 | if update_successors: 51 | enough_iterations = 3 52 | else: 53 | enough_iterations = 1 54 | too_few_iterations = enough_iterations - 1 55 | 56 | # Test 1: in both cases, root node should be updated to has_solution=True 57 | # and that should be the only node updated 58 | g.root_node.has_solution = False 59 | output = run_message_passing( 60 | g, 61 | [g.root_node], 62 | update_fns=[has_solution_update], 63 | update_successors=update_successors, 64 | max_iterations=enough_iterations, 65 | ) # should run without error 66 | assert g.root_node.has_solution 67 | assert len(output) == 1 68 | 69 | # Test 2: should raise error if too few iterations 70 | g.root_node.has_solution = False 71 | with pytest.raises(RuntimeError): 72 | run_message_passing( 73 | g, 74 | [g.root_node], 75 | update_fns=[has_solution_update], 76 | update_successors=update_successors, 77 | max_iterations=too_few_iterations, 78 | ) 79 | 80 | 81 | @pytest.mark.parametrize("update_predecessors", [True, False]) 82 | def test_update_predecessors(andor_tree_non_minimal: AndOrGraph, update_predecessors: bool) -> None: 83 | """Similar to above but for predecessors, updating the depth of a node.""" 84 | 85 | g = andor_tree_non_minimal # rename for brevity 86 | node_to_perturb = list(g.successors(g.root_node))[0] 87 | if update_predecessors: # NOTE: only enough if not updating successors 88 | too_few_iterations = 1 89 | enough_iterations = 2 90 | else: 91 | too_few_iterations = 0 92 | enough_iterations = 1 93 | 94 | # Test 1: set depth of child node incorrectly 95 | # and check that it is updated correctly. 
96 | node_to_perturb.depth = 500 97 | output = run_message_passing( 98 | g, 99 | [node_to_perturb], 100 | update_fns=[depth_update], 101 | update_predecessors=update_predecessors, 102 | max_iterations=enough_iterations, 103 | update_successors=False, 104 | ) # should run without error 105 | assert node_to_perturb.depth == 1 106 | assert len(output) == 1 107 | 108 | # Test 2: should raise error if too few iterations 109 | node_to_perturb.depth = 500 110 | with pytest.raises(RuntimeError): 111 | run_message_passing( 112 | g, 113 | [node_to_perturb], 114 | update_fns=[depth_update], 115 | update_predecessors=update_predecessors, 116 | max_iterations=too_few_iterations, 117 | update_successors=False, 118 | ) 119 | 120 | 121 | def update_fn_which_can_never_converge(*args, **kwargs) -> bool: 122 | return True 123 | 124 | 125 | def test_no_convergence(andor_tree_non_minimal: AndOrGraph) -> None: 126 | """Test that if no convergence is reached, an error is raised.""" 127 | g = andor_tree_non_minimal # rename for brevity 128 | with pytest.raises(RuntimeError): 129 | run_message_passing( 130 | g, g.nodes(), update_fns=[update_fn_which_can_never_converge], max_iterations=1_000 131 | ) 132 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/utils/misc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import multiprocessing 3 | import os 4 | import random 5 | import warnings 6 | from contextlib import contextmanager, redirect_stderr, redirect_stdout 7 | from dataclasses import fields, is_dataclass 8 | from itertools import islice 9 | from os import devnull 10 | from typing import Any, Dict, Iterable, Iterator, List, Optional 11 | 12 | import numpy as np 13 | from rdkit import rdBase 14 | 15 | from syntheseus.interface.bag import Bag 16 | from syntheseus.interface.molecule import Molecule 17 | 18 | 19 | def set_random_seed(seed: int = 0) -> None: 20 | """Set 
random seed for `Python`, `torch` and `numpy`."""
    random.seed(seed)
    np.random.seed(seed)

    # If `torch` is installed set its seed as well.
    try:
        import torch

        torch.manual_seed(seed)
    except ModuleNotFoundError:
        pass


@contextmanager
def suppress_outputs():
    """Suppress messages written to both stdout and stderr."""
    with open(devnull, "w") as fnull:
        with redirect_stderr(fnull), redirect_stdout(fnull):
            # Save the root logger handlers in order to restore them afterwards.
            root_handlers = list(logging.root.handlers)
            logging.disable(logging.CRITICAL)

            yield

            logging.root.handlers = root_handlers
            logging.disable(logging.NOTSET)


@contextmanager
def suppress_rdkit_outputs():
    """Suppress warning messages produced by `rdkit`."""
    try:
        previous_settings = dict(line.split(":") for line in rdBase.LogStatus().split("\n"))
    except Exception:
        # If `rdkit` internals change in the future we give up on restoring previous settings
        warnings.warn("Could not read rdkit log settings, warnings will be silenced permanently")
        previous_settings = {}

    rdBase.DisableLog("rdApp.*")

    yield

    # Re-enable only those log levels that were enabled before entering the context.
    for level, setting in previous_settings.items():
        if setting == "enabled":
            rdBase.EnableLog(level)


def dictify(data: Any) -> Any:
    """Recursively convert `data` into a structure of dicts, lists and primitives."""
    # Need to ensure we make return objects fully serializable
    if isinstance(data, (int, float, str)) or data is None:
        return data
    elif isinstance(data, Molecule):
        return {"smiles": data.smiles}
    elif isinstance(data, (List, tuple, Bag)):
        # Captures lists of `Prediction`s
        return [dictify(x) for x in data]
    elif isinstance(data, dict):
        return {k: dictify(v) for k, v in data.items()}
    elif is_dataclass(data):
        result = {}
        for f in fields(data):
            value = getattr(data, f.name)
            result[f.name] = dictify(value)
        return result
    else:
        raise TypeError(f"Type {type(data)} cannot be handled by `dictify`")


def asdict_extended(data) -> Dict[str, Any]:
    """Convert a dataclass containing various reaction-related objects into a dict."""
    if not is_dataclass(data):
        raise TypeError(f"asdict_extended only for use on dataclasses, input is type {type(data)}")

    return dictify(data)


def undictify_bag_of_molecules(data: List[Dict[str, str]]) -> Bag[Molecule]:
    """Recovers a bag of molecules serialized with `dictify`."""
    return Bag(Molecule(d["smiles"]) for d in data)


def parallelize(
    fn,
    inputs: Iterable,
    num_processes: int = 0,
    chunksize: int = 32,
    num_chunks_per_process_per_segment: Optional[int] = 64,
) -> Iterator:
    """Parallelize an application of an arbitrary function using a pool of processes.

    With `num_processes == 0` the function is applied sequentially in the current process.
    """
    if num_processes == 0:
        yield from map(fn, inputs)
    else:
        # Needed for the chunking code to work on repeatable iterables e.g. lists.
        inputs = iter(inputs)

        with multiprocessing.Pool(num_processes) as pool:
            if num_chunks_per_process_per_segment is None:
                yield from pool.imap(fn, inputs, chunksize=chunksize)
            else:
                # A new segment will only be started if the previous one was consumed; this avoids doing
                # all the work upfront and storing it in memory if the consumer of the output is slow.
                segmentsize = num_chunks_per_process_per_segment * num_processes * chunksize

                non_empty = True
                while non_empty:
                    non_empty = False

                    # Call `imap` segment-by-segment to make sure the consumer of the output keeps up.
                    for result in pool.imap(fn, islice(inputs, segmentsize), chunksize=chunksize):
                        yield result
                        non_empty = True


def cpu_count(default: int = 8) -> int:
    """Return the number of CPUs, fallback to `default` if it cannot be determined."""
    return os.cpu_count() or default
-------------------------------------------------------------------------------- /syntheseus/reaction_prediction/inference/retro_knn.py: --------------------------------------------------------------------------------
"""Inference wrapper for the RetroKNN model.

Paper: https://arxiv.org/abs/2306.04123
"""

from pathlib import Path
from typing import List, Optional, Sequence, Union

import numpy as np

from syntheseus.interface.molecule import Molecule
from syntheseus.interface.reaction import SingleProductReaction
from syntheseus.reaction_prediction.inference.local_retro import LocalRetroModel
from syntheseus.reaction_prediction.utils.inference import get_unique_file_in_dir


class RetroKNNModel(LocalRetroModel):
    """Wrapper for RetroKNN model."""

    def __init__(self, model_dir: Optional[Union[str, Path]] = None, *args, **kwargs) -> None:
        """Initializes the RetroKNN model wrapper.

        Assumed format of the model directory:
        - `model_dir/local_retro` contains the files needed to load the LocalRetro wrapper
        - `model_dir/knn/` contains the adapter checkpoint as the only `*.pt` file
        - `model_dir/knn/datastore` contains the data store files
        """
        model_dir = Path(model_dir or self.get_default_model_dir())
        super().__init__(model_dir / "local_retro", *args, **kwargs)

        import torch

        from syntheseus.reaction_prediction.models.retro_knn import Adapter

        adapter_chkpt_path = get_unique_file_in_dir(Path(model_dir) / "knn", pattern="*.pt")
        datastore_path = Path(model_dir) / "knn" / "datastore"

        import faiss
        import faiss.contrib.torch_utils  # make faiss available for torch tensors

        def load_data_store(path: Path, device: str):
            # Read the faiss index from disk; unless running on CPU, clone it onto GPU 0
            # using float16 to reduce memory use.
            index = faiss.read_index(str(path), faiss.IO_FLAG_ONDISK_SAME_DIR)

            if device == "cpu":
                return index
            else:
                res = faiss.StandardGpuResources()
                co = faiss.GpuClonerOptions()
                co.useFloat16 = True
                return faiss.index_cpu_to_gpu(res, 0, index, co)

        self.atom_store = load_data_store(datastore_path / "data.atom_idx", device=self.device)
        self.bond_store = load_data_store(datastore_path / "data.bond_idx", device=self.device)
        self.raw_data = np.load(datastore_path / "data.npz")

        self.adapter = Adapter(self.model.linearB.weight.shape[0], k=32).to(self.device)
        self.adapter.load_state_dict(torch.load(adapter_chkpt_path, map_location=self.device))
        self.adapter.eval()

    def _forward_localretro(self, bg):
        # Forward pass through the underlying LocalRetro model which returns not only the
        # atom/bond logits but also the intermediate atom/bond features; the features are
        # used in `_get_reactions` to query the k-NN datastores.
        from local_retro.scripts.model_utils import pair_atom_feats, unbatch_feats, unbatch_mask

        bg = bg.to(self.device)
        node_feats = bg.ndata.pop("h").to(self.device)
        edge_feats = bg.edata.pop("e").to(self.device)

        node_feats = self.model.mpnn(bg, node_feats, edge_feats)
        atom_feats = node_feats
        bond_feats = self.model.linearB(pair_atom_feats(bg, node_feats))
        edit_feats, mask = unbatch_mask(bg, atom_feats, bond_feats)
        _, edit_feats = self.model.att(edit_feats, mask)

        atom_feats, bond_feats = unbatch_feats(bg, edit_feats)
        atom_outs = self.model.atom_linear(atom_feats)
        bond_outs = self.model.bond_linear(bond_feats)

        return atom_outs, bond_outs, atom_feats, bond_feats

    def _get_reactions(
        self, inputs: List[Molecule], num_results: int
    ) -> List[Sequence[SingleProductReaction]]:
        import torch

        from syntheseus.reaction_prediction.models.retro_knn import knn_prob

        batch = self._mols_to_batch(inputs)
        (
            batch_atom_logits,
            batch_bond_logits,
            atom_feats,
            bond_feats,
        ) = self._forward_localretro(batch)
        sg = batch.remove_self_loop().to(self.device)

        # Distances to the 32 nearest neighbours in the atom/bond datastores.
        node_dis, _ = self.atom_store.search(atom_feats, k=32)
        edge_dis, _ = self.bond_store.search(bond_feats, k=32)

        # `node_p`/`edge_p` are used below as mixing weights between the model and k-NN
        # distributions; `node_t`/`edge_t` are passed to `knn_prob` (presumably k-NN
        # temperatures — TODO confirm against `retro_knn.Adapter`).
        node_t, node_p, edge_t, edge_p = self.adapter(
            sg, atom_feats, bond_feats, node_dis, edge_dis
        )

        batch_atom_prob_nn = torch.nn.Softmax(dim=1)(batch_atom_logits)
        batch_bond_prob_nn = torch.nn.Softmax(dim=1)(batch_bond_logits)

        atom_output_label = torch.from_numpy(self.raw_data["atom_output_label"]).to(self.device)
        bond_output_label = torch.from_numpy(self.raw_data["bond_output_label"]).to(self.device)

        batch_atom_prob_knn = knn_prob(
            atom_feats, self.atom_store, atom_output_label, batch_atom_logits.shape[1], 32, node_t
        )
        batch_bond_prob_knn = knn_prob(
            bond_feats, self.bond_store, bond_output_label, batch_bond_logits.shape[1], 32, edge_t
        )

        # NOTE: despite the names, from this point on `batch_*_logits` hold the blended
        # probabilities (convex combination of the NN and k-NN distributions).
        batch_atom_logits = node_p * batch_atom_prob_nn + (1 - node_p) * batch_atom_prob_knn
        batch_bond_logits = edge_p * batch_bond_prob_nn + (1 - edge_p) * batch_bond_prob_knn

        return self._build_batch_predictions(
            batch, num_results, inputs, batch_atom_logits, batch_bond_logits
        )
-------------------------------------------------------------------------------- /syntheseus/search/graph/base_graph.py: --------------------------------------------------------------------------------
from __future__ import annotations

import abc
from collections.abc import Container, Iterable, Sized
from typing import Generic, Sequence, TypeVar

import networkx as nx

from syntheseus.interface.molecule import Molecule
from syntheseus.interface.reaction import SingleProductReaction
from syntheseus.search.graph.node import BaseGraphNode

NodeType = TypeVar(
    "NodeType",
)
SearchNodeType = TypeVar("SearchNodeType", bound=BaseGraphNode)


class BaseReactionGraph(Container, Sized, Generic[NodeType], abc.ABC):
    """
    Base class for holding a retrosynthesis graph where nodes represent molecules/reactions.
    Retrosynthesis graphs have the following properties:

    - Directed: A node usually represents a reaction/molecule used to synthesize its predecessors.
      Search generally goes from predecessors to successors,
      while the actual synthesis would go from successors to predecessors.
    - Root node: there is a root node representing the molecule to be synthesized.
      This node should never have a parent.
    - Implicit: the graph is a subgraph of a much larger reaction graph,
      so nodes can be "unexpanded", meaning that no children nodes are specified,
      even though the node should have children. Search typically "expands"
      a node by adding reactions.

    The actual implementation of the graph uses networkx's DiGraph
    (but it does not inherit from it because this causes many methods to break).
36 | """ 37 | 38 | def __init__(self, *args, **kwargs) -> None: 39 | self._graph = nx.DiGraph(*args, **kwargs) 40 | 41 | def __contains__(self, node: object) -> bool: 42 | return node in self._graph 43 | 44 | def __len__(self) -> int: 45 | return len(self._graph) 46 | 47 | @property 48 | @abc.abstractmethod 49 | def root_node(self) -> NodeType: 50 | """Root node of the graph, representing the molecule to be synthesized.""" 51 | pass 52 | 53 | @property 54 | @abc.abstractmethod 55 | def root_mol(self) -> Molecule: 56 | """The molecule to be synthesized.""" 57 | pass 58 | 59 | @abc.abstractmethod 60 | def is_minimal(self) -> bool: 61 | """Checks whether this is a *minimal* graph (i.e. contains a single synthesis route).""" 62 | pass 63 | 64 | def is_tree(self) -> bool: 65 | """Performs a [possibly expensive] check to see if the graph is a tree.""" 66 | return nx.is_arborescence(self._graph) 67 | 68 | def nodes(self) -> Iterable[NodeType]: 69 | return self._graph.nodes() 70 | 71 | def predecessors(self, node: NodeType) -> Iterable[NodeType]: 72 | """Returns the predecessors of a node.""" 73 | return self._graph.predecessors(node) 74 | 75 | def successors(self, node: NodeType) -> Iterable[NodeType]: 76 | """Returns the successors of a node.""" 77 | return self._graph.successors(node) 78 | 79 | def assert_validity(self) -> None: 80 | """ 81 | A (potentially expensive) function to check the graph's validity. 
82 | """ 83 | 84 | # Check root node is in the graph 85 | assert self.root_node in self 86 | 87 | # Check root node has no parents 88 | assert len(list(self.predecessors(self.root_node))) == 0 89 | 90 | # Graph should be connected 91 | assert nx.is_weakly_connected(self._graph) 92 | 93 | def __eq__(self, __value: object) -> bool: 94 | """Equality is defined as having the same root node, nodes, and edges.""" 95 | if isinstance(__value, BaseReactionGraph): 96 | return (self.root_node == __value.root_node) and nx.utils.graphs_equal( 97 | self._graph, __value._graph 98 | ) 99 | else: 100 | return False 101 | 102 | @abc.abstractmethod 103 | def expand_with_reactions( 104 | self, 105 | reactions: list[SingleProductReaction], 106 | node: NodeType, 107 | ensure_tree: bool, 108 | ) -> Sequence[NodeType]: 109 | """ 110 | Expands a node with a series of reactions. 111 | For reproducibility, it ensures that the order of the nodes is consistent, 112 | which is why a sequence is returned. 113 | It is also encouraged (but not required) for all return nodes to be unique. 114 | 115 | Subclass implementations should ensure that the root node never has any predecessors 116 | as a result of this function, and if ensure_tree=True, should also ensure that the 117 | graph remains a tree. 
118 | """ 119 | pass 120 | 121 | 122 | class RetrosynthesisSearchGraph(BaseReactionGraph[SearchNodeType], Generic[SearchNodeType]): 123 | """Subclass with more specific type requirements for the nodes.""" 124 | 125 | @abc.abstractmethod 126 | def _assert_valid_reactions(self) -> None: 127 | """Checks that all reactions are valid.""" 128 | pass 129 | 130 | def assert_validity(self) -> None: 131 | super().assert_validity() 132 | 133 | # Check that all nodes with children are marked as expanded 134 | for n in self._graph.nodes: 135 | if not n.is_expanded: 136 | assert len(list(self.successors(n))) == 0 137 | 138 | # Check valid reactions 139 | self._assert_valid_reactions() 140 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

Navigating the labyrinth of synthesis planning

4 | 5 | --- 6 | 7 |

8 | Docs • 9 | CLI • 10 | Tutorials • 11 | Paper 12 |

13 | 14 | [![CI](https://github.com/microsoft/syntheseus/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/microsoft/syntheseus/actions/workflows/ci.yml) 15 | [![Python Version](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/) 16 | [![pypi](https://img.shields.io/pypi/v/syntheseus.svg)](https://pypi.org/project/syntheseus/) 17 | [![code style](https://img.shields.io/badge/code%20style-black-202020.svg)](https://github.com/ambv/black) 18 | [![License](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/microsoft/syntheseus/blob/main/LICENSE) 19 | 20 |
21 | 22 | ## Overview 23 | 24 | Syntheseus is a package for end-to-end retrosynthetic planning. 25 | - ⚒️ Combines search algorithms and reaction models in a standardized way 26 | - 🧭 Includes implementations of common search algorithms 27 | - 🧪 Includes wrappers for state-of-the-art reaction models 28 | - ⚙️ Exposes a simple API to plug in custom models and algorithms 29 | - 📈 Can be used to benchmark components of a retrosynthesis pipeline 30 | 31 | ## Quick Start 32 | 33 | To install `syntheseus` with all the extras, run 34 | 35 | ```bash 36 | conda env create -f environment_full.yml 37 | conda activate syntheseus-full 38 | 39 | pip install "syntheseus[all]" 40 | ``` 41 | 42 | This sample environment pins `torch` to version `2.2.2`. To run the models under `1.x` please downgrade to `syntheseus 0.6.0`. 43 | 44 | See [here](https://microsoft.github.io/syntheseus/stable/installation) if you prefer a more lightweight installation that only includes the parts you actually need. 45 | 46 | ## Citation and usage 47 | 48 | Since the release of our package, we've been thrilled to see syntheseus be used in the following projects: 49 | 50 | | **Project** | **Usage** | **Reference(s)** | 51 | |:--------------|:-----|:-----------| 52 | |Retro-fallback search|Multi-step search|ICLR [paper](https://arxiv.org/abs/2310.09270), [code](https://github.com/AustinT/retro-fallback-iclr24)| 53 | |RetroGFN|Pre-packaged single-step models|arXiv [paper](https://arxiv.org/abs/2406.18739), [code](https://github.com/gmum/RetroGFN)| 54 | |TANGO|Single-step and multi-step|arXiv [paper](https://arxiv.org/abs/2410.11527)| 55 | |SimpRetro|Multi-step search|JCIM [paper](https://pubs.acs.org/doi/10.1021/acs.jcim.4c00432), [code](https://github.com/catalystforyou/SimpRetro)| 56 | 57 | If you use syntheseus in an academic project, please consider citing our 58 | [associated paper from Faraday Discussions](https://pubs.rsc.org/en/content/articlelanding/2024/fd/d4fd00093e) 59 | (bibtex below). 
You can also message us or submit a PR to have your project added to the table above! 60 | 61 | ``` 62 | @article{maziarz2024re, 63 | title={Re-evaluating retrosynthesis algorithms with syntheseus}, 64 | author={Maziarz, Krzysztof and Tripp, Austin and Liu, Guoqing and Stanley, Megan and Xie, Shufang and Gainski, Piotr and Seidl, Philipp and Segler, Marwin}, 65 | journal={Faraday Discussions}, 66 | year={2024}, 67 | publisher={Royal Society of Chemistry} 68 | } 69 | ``` 70 | 71 | ## Development 72 | 73 | Syntheseus is currently under active development. 74 | If you want to help us develop syntheseus please install and run `pre-commit` 75 | checks before committing code. 76 | 77 | We use `pytest` for testing. Please make sure tests pass on your branch before 78 | submitting a PR (and try to maintain high test coverage). 79 | 80 | ```bash 81 | python -m pytest --cov syntheseus/tests 82 | ``` 83 | 84 | ## Contributing 85 | 86 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 87 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 88 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 89 | 90 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 91 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 92 | provided by the bot. You will only need to do this once across all repos using our CLA. 93 | 94 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 95 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 96 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 
97 | 98 | ## Trademarks 99 | 100 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 101 | trademarks or logos is subject to and must follow 102 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 103 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 104 | Any use of third-party trademarks or logos are subject to those third-party's policies. 105 | -------------------------------------------------------------------------------- /syntheseus/tests/search/graph/test_molset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for MolSet graph/nodes. 3 | 4 | Because a lot of the behaviour is implicitly tested when the algorithms are tested, 5 | the tests here are sparse and mainly check edge cases which won't come up in algorithms. 
6 | """ 7 | import pytest 8 | 9 | from syntheseus.interface.bag import Bag 10 | from syntheseus.interface.molecule import Molecule 11 | from syntheseus.interface.reaction import SingleProductReaction 12 | from syntheseus.search.graph.molset import MolSetGraph, MolSetNode 13 | from syntheseus.search.graph.route import SynthesisGraph 14 | from syntheseus.tests.search.graph.test_base import BaseNodeTest 15 | 16 | 17 | class TestMolSetNode(BaseNodeTest): 18 | def get_node(self): 19 | return MolSetNode(mols=Bag([Molecule("CC")])) 20 | 21 | def test_nodes_not_frozen(self): 22 | node = self.get_node() 23 | node.mols = None 24 | 25 | 26 | class TestMolSetGraph: 27 | """Tests for MolSetGraph: they mostly follow the same pattern as AND/OR graph tests.""" 28 | 29 | def test_basic_properties1( 30 | self, cocs_mol: Molecule, molset_tree_non_minimal: MolSetGraph 31 | ) -> None: 32 | """Test some basic properties (len, contains) on a graph.""" 33 | assert len(molset_tree_non_minimal) == 10 34 | assert molset_tree_non_minimal.root_node in molset_tree_non_minimal 35 | assert molset_tree_non_minimal.root_mol == cocs_mol 36 | 37 | def _check_not_minimal(self, graph: MolSetGraph) -> None: 38 | """Checks that a graph is not minimal (eliminate code duplication between tests below).""" 39 | assert not graph.is_minimal() 40 | assert graph.is_tree() # should always be a tree 41 | with pytest.raises(AssertionError): 42 | graph.to_synthesis_graph() 43 | 44 | def test_minimal_negative(self, molset_tree_non_minimal: MolSetGraph) -> None: 45 | """Test that a non-minimal molset graph is not identified as minimal.""" 46 | self._check_not_minimal(molset_tree_non_minimal) 47 | 48 | def test_minimal_negative_hard(self, molset_tree_almost_minimal: MolSetGraph) -> None: 49 | """Harder test for minimal negative: the 'almost minimal' graph.""" 50 | self._check_not_minimal(molset_tree_almost_minimal) 51 | 52 | def test_minimal_positive( 53 | self, molset_tree_minimal: MolSetGraph, 
minimal_synthesis_graph: SynthesisGraph 54 | ) -> None: 55 | graph = molset_tree_minimal 56 | assert graph.is_minimal() 57 | assert graph.is_tree() # should always be a tree 58 | route = graph.to_synthesis_graph() # should run without error 59 | assert route == minimal_synthesis_graph 60 | 61 | @pytest.mark.parametrize( 62 | "reason", ["root_has_parent", "unexpanded_expanded", "reactions_dont_match"] 63 | ) 64 | def test_assert_validity_negative( 65 | self, molset_tree_non_minimal: MolSetGraph, reason: str 66 | ) -> None: 67 | """ 68 | Test that an invalid MolSet graph is correctly identified as invalid. 69 | 70 | Different reasons for invalidity are tested. 71 | """ 72 | graph = molset_tree_non_minimal 73 | 74 | if reason == "root_has_parent": 75 | # Add a random connection to the root node 76 | random_node = [n for n in graph.nodes() if n is not graph.root_node][0] 77 | graph._graph.add_edge(random_node, graph.root_node) 78 | elif reason == "unexpanded_expanded": 79 | graph.root_node.is_expanded = False 80 | elif reason == "reactions_dont_match": 81 | child = list(graph.successors(graph.root_node))[0] 82 | 83 | # Set the reaction to a random incorrect reaction 84 | graph._graph.edges[graph.root_node, child]["reaction"] = SingleProductReaction( 85 | reactants=Bag([Molecule("OO")]), product=Molecule("CC") 86 | ) 87 | 88 | # Not only should the graph not be valid below, 89 | # but specifically the reaction should also be invalid 90 | with pytest.raises(AssertionError): 91 | graph._assert_valid_reactions() 92 | else: 93 | raise ValueError(f"Unsupported reason: {reason}") 94 | 95 | with pytest.raises(AssertionError): 96 | graph.assert_validity() 97 | 98 | @pytest.mark.parametrize("reason", ["already_expanded", "wrong_product"]) 99 | def test_invalid_expansions( 100 | self, 101 | molset_tree_non_minimal: MolSetGraph, 102 | rxn_cs_from_cc: SingleProductReaction, 103 | reason: str, 104 | ) -> None: 105 | """ 106 | Test that invalid expansions raise an error. 
107 | Note that valid expansions are tested implicitly elsewhere. 108 | 109 | NOTE: because graphs with 1 node per molset are not properly supported now, 110 | we don't test the case where the product of a reaction is the root mol. 111 | """ 112 | graph = molset_tree_non_minimal 113 | cc_node = [n for n in graph.nodes() if n.mols == {Molecule("CC")}].pop() 114 | 115 | if reason == "already_expanded": 116 | with pytest.raises(AssertionError): 117 | graph.expand_with_reactions(reactions=[], node=graph.root_node, ensure_tree=True) 118 | elif reason == "wrong_product": 119 | with pytest.raises(AssertionError): 120 | graph.expand_with_reactions( 121 | reactions=[rxn_cs_from_cc], node=cc_node, ensure_tree=True 122 | ) 123 | else: 124 | raise ValueError(f"Unsupported reason: {reason}") 125 | -------------------------------------------------------------------------------- /syntheseus/tests/reaction_prediction/data/test_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from pathlib import Path 4 | from typing import Generator 5 | 6 | import pytest 7 | 8 | from syntheseus.interface.bag import Bag 9 | from syntheseus.interface.molecule import Molecule 10 | from syntheseus.reaction_prediction.data.dataset import ( 11 | CSV_REACTION_SMILES_COLUMN_NAME, 12 | DataFold, 13 | DataFormat, 14 | DiskReactionDataset, 15 | ) 16 | from syntheseus.reaction_prediction.data.reaction_sample import ReactionSample 17 | 18 | 19 | @pytest.fixture(params=[False, True]) 20 | def temp_path(request, tmp_path: Path) -> Generator[Path, None, None]: 21 | """Fixture to provide a temporary path that is either absolute or relative.""" 22 | if request.param: 23 | # The built-in `tmp_path` fixture is absolute; if the fixture parameter is `True` we convert 24 | # that to a relative path. 
This is done by stripping away the root `/` that signifies an 25 | # absolute path and changing the working directory to `/` so that the path remains correct. 26 | 27 | old_working_dir = os.getcwd() 28 | os.chdir("/") 29 | 30 | yield Path(*list(tmp_path.parts)[1:]) 31 | 32 | os.chdir(old_working_dir) 33 | else: 34 | yield tmp_path 35 | 36 | 37 | @pytest.mark.parametrize("mapped", [False, True]) 38 | def test_save_and_load(temp_path: Path, mapped: bool) -> None: 39 | samples = [ 40 | ReactionSample.from_reaction_smiles_strict(reaction_smiles, mapped=mapped) 41 | for reaction_smiles in [ 42 | "O[c:1]1[cH:2][c:3](=[O:4])[nH:5][cH:6][cH:7]1>>[cH:1]1[cH:2][c:3](=[O:4])[nH:5][cH:6][cH:7]1", 43 | "CC(C)(C)OC(=O)[N:1]1[CH2:2][CH2:3][C@H:4]([F:5])[CH2:6]1>>[NH:1]1[CH2:2][CH2:3][C@H:4]([F:5])[CH2:6]1", 44 | ] 45 | ] 46 | 47 | for fold in DataFold: 48 | DiskReactionDataset.save_samples_to_file(data_dir=temp_path, fold=fold, samples=samples) 49 | 50 | for load_format in [None, DataFormat.JSONL]: 51 | # Now try to load the data we just saved. 52 | dataset = DiskReactionDataset(temp_path, sample_cls=ReactionSample, data_format=load_format) 53 | 54 | for fold in DataFold: 55 | assert list(dataset[fold]) == samples 56 | 57 | 58 | @pytest.mark.parametrize("format", [DataFormat.CSV, DataFormat.SMILES]) 59 | def test_load_external_format(temp_path: Path, format: DataFormat) -> None: 60 | # Example reaction SMILES, purposefully using non-canonical forms of reactants and product. 
61 | reaction_smiles = ( 62 | "[cH:1]1[cH:2][c:3]([CH3:4])[cH:5][cH:6][c:7]1Br.B(O)(O)[c:8]1[cH:9][cH:10][c:11]([CH3:12])[cH:13][cH:14]1>>" 63 | "[cH:1]1[cH:2][c:3]([CH3:4])[cH:5][cH:6][c:7]1[c:8]2[cH:14][cH:13][c:11]([CH3:12])[cH:10][cH:9]2" 64 | ) 65 | 66 | filename = DiskReactionDataset.get_filename_suffix(format=format, fold=DataFold.TRAIN) 67 | with open(temp_path / filename, "wt") as f: 68 | if format == DataFormat.CSV: 69 | writer = csv.DictWriter(f, fieldnames=["id", "class", CSV_REACTION_SMILES_COLUMN_NAME]) 70 | writer.writeheader() 71 | writer.writerow( 72 | {"id": 0, "class": "UNK", CSV_REACTION_SMILES_COLUMN_NAME: reaction_smiles} 73 | ) 74 | else: 75 | f.write(f"{reaction_smiles}\n") 76 | 77 | for load_format in [None, format]: 78 | dataset = DiskReactionDataset(temp_path, sample_cls=ReactionSample, data_format=load_format) 79 | assert dataset.get_num_samples(DataFold.TRAIN) == 1 80 | 81 | samples = list(dataset[DataFold.TRAIN]) 82 | assert len(samples) == 1 83 | 84 | [sample] = samples 85 | 86 | # After loading, the reactants and products should be in canonical SMILES form, with the 87 | # atom mapping removed. 88 | assert sample == ReactionSample( 89 | reactants=Bag([Molecule("Cc1ccc(Br)cc1"), Molecule("Cc1ccc(B(O)O)cc1")]), 90 | products=Bag([Molecule("Cc1ccc(-c2ccc(C)cc2)cc1")]), 91 | ) 92 | 93 | # The original reaction SMILES should have been saved separately. 94 | assert sample.mapped_reaction_smiles == reaction_smiles 95 | 96 | 97 | @pytest.mark.parametrize("format", [DataFormat.JSONL, DataFormat.CSV, DataFormat.SMILES]) 98 | def test_format_detection(temp_path: Path, format: DataFormat) -> None: 99 | other_format = (set(DataFormat) - {format}).pop() 100 | 101 | # Create two files with different extensions, so that it is ambiguous which format we want. 
102 | (temp_path / DiskReactionDataset.get_filename_suffix(format, DataFold.TRAIN)).touch() 103 | (temp_path / DiskReactionDataset.get_filename_suffix(other_format, DataFold.TEST)).touch() 104 | 105 | # Loading with automatic resolution should fail. 106 | with pytest.raises(ValueError): 107 | DiskReactionDataset(data_dir=temp_path, sample_cls=ReactionSample) 108 | 109 | # Loading with an explicit format should succeed. 110 | for f in [format, other_format]: 111 | DiskReactionDataset(data_dir=temp_path, sample_cls=ReactionSample, data_format=f) 112 | 113 | # Loading with an explicit format but no matching files should fail. 114 | another_format = (set(DataFormat) - {format, other_format}).pop() 115 | with pytest.raises(ValueError): 116 | DiskReactionDataset( 117 | data_dir=temp_path, sample_cls=ReactionSample, data_format=another_format 118 | ) 119 | 120 | # Create another file with the right suffix. 121 | (temp_path / f"raw_{DiskReactionDataset.get_filename_suffix(format, DataFold.TRAIN)}").touch() 122 | 123 | # Loading with an explicit format should now fail due to ambiguity. 
124 | with pytest.raises(ValueError): 125 | DiskReactionDataset(data_dir=temp_path, sample_cls=ReactionSample, data_format=format) 126 | -------------------------------------------------------------------------------- /syntheseus/reaction_prediction/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import csv 4 | import json 5 | import logging 6 | from abc import abstractmethod 7 | from enum import Enum 8 | from functools import partial 9 | from pathlib import Path 10 | from typing import Dict, Generic, Iterable, List, Optional, Type, TypeVar, Union 11 | 12 | from more_itertools import ilen 13 | 14 | from syntheseus.reaction_prediction.data.reaction_sample import ReactionSample 15 | from syntheseus.reaction_prediction.utils.misc import asdict_extended, parallelize 16 | 17 | logger = logging.getLogger(__file__) 18 | 19 | SampleType = TypeVar("SampleType", bound=ReactionSample) 20 | 21 | 22 | CSV_REACTION_SMILES_COLUMN_NAME = "reactants>reagents>production" 23 | 24 | 25 | class DataFold(Enum): 26 | TRAIN = "train" 27 | VALIDATION = "val" 28 | TEST = "test" 29 | 30 | 31 | class DataFormat(Enum): 32 | JSONL = "jsonl" 33 | CSV = "csv" 34 | SMILES = "smi" 35 | 36 | 37 | class ReactionDataset(Generic[SampleType]): 38 | """Dataset holding raw reactions split into folds.""" 39 | 40 | @abstractmethod 41 | def __getitem__(self, fold: DataFold) -> Iterable[SampleType]: 42 | pass 43 | 44 | @abstractmethod 45 | def get_num_samples(self, fold: DataFold) -> int: 46 | pass 47 | 48 | 49 | class DiskReactionDataset(ReactionDataset[SampleType]): 50 | def __init__( 51 | self, 52 | data_dir: Union[str, Path], 53 | sample_cls: Type[SampleType], 54 | num_processes: int = 0, 55 | data_format: Optional[DataFormat] = None, 56 | ): 57 | self._data_dir = Path(data_dir) 58 | self._sample_cls = sample_cls 59 | self._num_processes = num_processes 60 | 61 | paths = list(self._data_dir.iterdir()) 62 | 63 
| if data_format is None: 64 | logger.info(f"Detecting data format from files: {[path.name for path in paths]}") 65 | formats_to_try = list(DataFormat) 66 | else: 67 | formats_to_try = [data_format] 68 | 69 | matches = { 70 | format: DiskReactionDataset.match_paths_to_folds(format=format, paths=paths) 71 | for format in formats_to_try 72 | } 73 | matches = {key: values for key, values in matches.items() if values} 74 | 75 | if data_format is None: 76 | if len(matches) != 1: 77 | raise ValueError( 78 | f"Format detection failed (formats matching the file list: {[f.name for f in matches]})" 79 | ) 80 | elif not matches: 81 | raise ValueError( 82 | f"No files matching *{{train, val, test}}.{data_format.value} were found" 83 | ) 84 | 85 | [(self._data_format, self._fold_to_path)] = matches.items() 86 | 87 | if data_format is None: 88 | logger.info(f"Detected format: {self._data_format.name}") 89 | 90 | logger.info(f"Loading data from files {self._fold_to_path}") 91 | self._num_samples: Dict[DataFold, int] = {} 92 | 93 | def _get_lines(self, fold: DataFold) -> Iterable[str]: 94 | if fold not in self._fold_to_path: 95 | return [] 96 | else: 97 | with open(self._fold_to_path[fold]) as f: 98 | if self._data_format == DataFormat.CSV: 99 | for row in csv.DictReader(f): 100 | if CSV_REACTION_SMILES_COLUMN_NAME not in row: 101 | raise ValueError( 102 | f"No {CSV_REACTION_SMILES_COLUMN_NAME} column found in the CSV data file" 103 | ) 104 | yield row[CSV_REACTION_SMILES_COLUMN_NAME] 105 | else: 106 | for line in f: 107 | yield line.rstrip() 108 | 109 | def __getitem__(self, fold: DataFold) -> Iterable[SampleType]: 110 | if self._data_format == DataFormat.JSONL: 111 | parse_fn = partial(DiskReactionDataset.sample_from_json, sample_cls=self._sample_cls) 112 | else: 113 | parse_fn = partial(self._sample_cls.from_reaction_smiles_strict, mapped=True) 114 | 115 | yield from parallelize(parse_fn, self._get_lines(fold), num_processes=self._num_processes) 116 | 117 | def 
get_num_samples(self, fold: DataFold) -> int: 118 | if fold not in self._num_samples: 119 | self._num_samples[fold] = ilen(self._get_lines(fold)) 120 | 121 | return self._num_samples[fold] 122 | 123 | @staticmethod 124 | def match_paths_to_folds(format: DataFormat, paths: List[Path]) -> Dict[DataFold, Path]: 125 | fold_to_path: Dict[DataFold, Path] = {} 126 | for fold in DataFold: 127 | suffix = DiskReactionDataset.get_filename_suffix(format, fold) 128 | matching_paths = [path for path in paths if path.name.endswith(suffix)] 129 | 130 | if len(matching_paths) > 1: 131 | raise ValueError( 132 | f"Found more than one {format.value} file for fold {fold.name}: {matching_paths}" 133 | ) 134 | 135 | if matching_paths: 136 | [path] = matching_paths 137 | fold_to_path[fold] = path 138 | 139 | return fold_to_path 140 | 141 | @staticmethod 142 | def get_filename_suffix(format: DataFormat, fold: DataFold) -> str: 143 | return f"{fold.value}.{format.value}" 144 | 145 | @staticmethod 146 | def sample_from_json(data: str, sample_cls: Type[SampleType]) -> SampleType: 147 | return sample_cls.from_dict(json.loads(data)) 148 | 149 | @staticmethod 150 | def save_samples_to_file( 151 | data_dir: Union[str, Path], fold: DataFold, samples: Iterable[SampleType] 152 | ) -> None: 153 | filename = DiskReactionDataset.get_filename_suffix(format=DataFormat.JSONL, fold=fold) 154 | 155 | with open(Path(data_dir) / filename, "wt") as f: 156 | for sample in samples: 157 | f.write(json.dumps(asdict_extended(sample)) + "\n") 158 | -------------------------------------------------------------------------------- /syntheseus/tests/cli/test_cli.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import math 4 | import sys 5 | import tempfile 6 | import urllib 7 | import zipfile 8 | from pathlib import Path 9 | from typing import Generator, List 10 | 11 | import omegaconf 12 | import pytest 13 | 14 | from 
import glob
import json
import math
import sys
import tempfile
import urllib.request  # Fixed: `import urllib` alone does not make `urllib.request` available.
import zipfile
from pathlib import Path
from typing import Generator, List

import omegaconf
import pytest

from syntheseus.reaction_prediction.inference.config import BackwardModelClass
from syntheseus.reaction_prediction.utils.testing import are_single_step_models_installed

# Skip the whole module unless every single-step model dependency is installed.
pytestmark = pytest.mark.skipif(
    not are_single_step_models_installed(),
    reason="CLI tests require all single-step models to be installed",
)


# GLN is excluded; the remaining backward models are exercised below.
MODEL_CLASSES_TO_TEST = set(BackwardModelClass) - {BackwardModelClass.GLN}


@pytest.fixture(scope="module")
def data_dir() -> Generator[Path, None, None]:
    """Module-scoped temp directory with USPTO-50K data, a search target and an inventory."""
    with tempfile.TemporaryDirectory() as raw_tempdir:
        tempdir = Path(raw_tempdir)

        # Download the raw USPTO-50K data released by the authors of GLN.
        uspto50k_zip_path = tempdir / "uspto50k.zip"
        urllib.request.urlretrieve(
            "https://figshare.com/ndownloader/files/45206101", uspto50k_zip_path
        )

        with zipfile.ZipFile(uspto50k_zip_path, "r") as f_zip:
            f_zip.extractall(tempdir)

        # Create a simple search targets file with a single target and a matching inventory following
        # the same example class of reactions as those used in `test_call`.

        search_targets_file_path = tempdir / "search_targets.smiles"
        with open(search_targets_file_path, "wt") as f_search_targets:
            f_search_targets.write("Cc1ccc(-c2ccc(C)cc2)cc1\n")

        inventory_file_path = tempdir / "inventory.smiles"
        with open(inventory_file_path, "wt") as f_inventory:
            for leaving_group in ["Br", "B(O)O", "I", "[Mg+]"]:
                f_inventory.write(f"Cc1ccc({leaving_group})cc1\n")

        yield tempdir


@pytest.fixture
def eval_cli_argv() -> List[str]:
    """Base argv for the `eval-single-step` command, using a GPU iff one is available."""
    import torch

    return ["eval-single-step", f"num_gpus={int(torch.cuda.is_available())}"]


@pytest.fixture
def search_cli_argv() -> List[str]:
    """Base argv for the `search` command, using a GPU iff one is available."""
    import torch

    return ["search", f"use_gpu={torch.cuda.is_available()}"]


def run_cli_with_argv(argv: List[str]) -> None:
    """Invoke the `syntheseus` CLI entry point with the given arguments."""
    # The import below pulls in some optional dependencies, so do it locally to avoid executing it
    # if the test suite is being skipped.
    from syntheseus.cli.main import main

    sys.argv = ["syntheseus"] + argv
    main()


def test_cli_invalid(
    data_dir: Path, tmpdir: Path, eval_cli_argv: List[str], search_cli_argv: List[str]
) -> None:
    """Test various incomplete or invalid CLI calls that should all raise an error."""
    argv_lists: List[List[str]] = [
        [],
        ["not-a-real-command"],
        eval_cli_argv + ["model_class=LocalRetro"],  # No data dir.
        eval_cli_argv + ["model_class=LocalRetro", f"data_dir={tmpdir}"],  # No data.
        eval_cli_argv + ["model_class=FakeModel", f"data_dir={data_dir}"],  # Not a real model.
        search_cli_argv
        + [
            "model_class=LocalRetro",
            f"search_targets_file={data_dir}/search_targets.smiles",
        ],  # No inventory.
        search_cli_argv
        + [
            "model_class=LocalRetro",
            f"inventory_smiles_file={data_dir}/inventory.smiles",
        ],  # No search targets.
        search_cli_argv
        + [
            "model_class=FakeModel",
            f"search_targets_file={data_dir}/search_targets.smiles",
            f"inventory_smiles_file={data_dir}/inventory.smiles",
        ],  # Not a real model.
    ]

    for argv in argv_lists:
        with pytest.raises((ValueError, omegaconf.errors.MissingMandatoryValue)):
            run_cli_with_argv(argv)


@pytest.mark.parametrize("model_class", MODEL_CLASSES_TO_TEST)
def test_cli_eval_single_step(
    model_class: BackwardModelClass, data_dir: Path, tmpdir: Path, eval_cli_argv: List[str]
) -> None:
    """Run single-step evaluation on a tiny data slice and sanity-check the accuracy."""
    run_cli_with_argv(
        eval_cli_argv
        + [
            f"model_class={model_class}",
            f"data_dir={data_dir}",
            f"results_dir={tmpdir}",
            "num_top_results=5",
            "print_idxs=[1,5]",
            "num_dataset_truncation=10",
        ]
    )

    # The CLI writes exactly one results file named after the model class.
    [results_path] = glob.glob(f"{tmpdir}/{model_class.name}_*.json")

    with open(results_path, "rt") as f:
        results = json.load(f)

    top_1_accuracy = results["top_k"][0]

    # We just evaluated a tiny sample of the data, so only make a rough check that the accuracy is
    # ballpark reasonable (full test set accuracy would be around ~50%).
    assert 0.2 <= top_1_accuracy <= 0.8


@pytest.mark.parametrize("model_class", MODEL_CLASSES_TO_TEST)
@pytest.mark.parametrize("search_algorithm", ["retro_star", "mcts", "pdvn"])
def test_cli_search(
    model_class: BackwardModelClass,
    search_algorithm: str,
    data_dir: Path,
    tmpdir: Path,
    search_cli_argv: List[str],
) -> None:
    """Run a short search for each model/algorithm pair and check that a route is found."""
    run_cli_with_argv(
        search_cli_argv
        + [
            f"model_class={model_class}",
            f"search_algorithm={search_algorithm}",
            f"results_dir={tmpdir}",
            f"search_targets_file={data_dir}/search_targets.smiles",
            f"inventory_smiles_file={data_dir}/inventory.smiles",
            "limit_iterations=3",
            "num_top_results=5",
        ]
    )

    results_dir = f"{tmpdir}/{model_class.name}_*/"
    [results_path] = glob.glob(f"{results_dir}/stats.json")

    with open(results_path, "rt") as f:
        results = json.load(f)

    # Assert that a solution was found.
    assert results["soln_time_rxn_model_calls"] < math.inf
    # At least one route should have been rendered to PDF.
    assert len(glob.glob(f"{results_dir}/route_*.pdf")) >= 1
"""Inference wrapper for the LocalRetro model.

Paper: https://pubs.acs.org/doi/10.1021/jacsau.1c00246
Code: https://github.com/kaist-amsg/LocalRetro

The original LocalRetro code is released under the Apache 2.0 license.
Parts of this file are based on code from the GitHub repository above.
"""

from pathlib import Path
from typing import Any, List, Sequence

from syntheseus.interface.molecule import Molecule
from syntheseus.interface.reaction import SingleProductReaction
from syntheseus.reaction_prediction.inference.base import ExternalBackwardReactionModel
from syntheseus.reaction_prediction.utils.inference import (
    get_unique_file_in_dir,
    process_raw_smiles_outputs_backwards,
)
from syntheseus.reaction_prediction.utils.misc import suppress_outputs


class LocalRetroModel(ExternalBackwardReactionModel):
    def __init__(self, *args, **kwargs) -> None:
        """Initializes the LocalRetro model wrapper.

        Assumed format of the model directory:
        - `model_dir` contains the model checkpoint as the only `*.pth` file
        - `model_dir` contains the config as the only `*.json` file
        - `model_dir/data` contains `*.csv` data files needed by LocalRetro
        """
        super().__init__(*args, **kwargs)

        # Imported lazily so that the heavy `local_retro` dependency is only pulled in
        # when the model is actually constructed.
        from local_retro.Retrosynthesis import load_templates
        from local_retro.scripts.utils import init_featurizer, load_model

        data_dir = Path(self.model_dir) / "data"
        # `self.args` doubles as LocalRetro's global config dict; several methods below
        # read from and write to it.
        self.args = init_featurizer(
            {
                "mode": "test",
                "device": self.device,
                "model_path": get_unique_file_in_dir(self.model_dir, pattern="*.pth"),
                "config_path": get_unique_file_in_dir(self.model_dir, pattern="*.json"),
                "data_dir": data_dir,
                "rxn_class_given": False,
            }
        )

        # `load_model` prints verbose progress; silence it.
        with suppress_outputs():
            self.model = load_model(self.args)

        # Template data is stored back into `self.args`, as LocalRetro's decoding
        # helpers expect to find it there.
        [
            self.args["atom_templates"],
            self.args["bond_templates"],
            self.args["template_infos"],
        ] = load_templates(self.args)

    def get_parameters(self):
        """Return the parameters of the underlying LocalRetro model."""
        return self.model.parameters()

    def _mols_to_batch(self, mols: List[Molecule]) -> Any:
        """Convert a list of molecules into a batched DGL graph for the model."""
        from dgllife.utils import smiles_to_bigraph
        from local_retro.scripts.utils import collate_molgraphs_test

        graphs = [
            smiles_to_bigraph(
                mol.smiles,
                node_featurizer=self.args["node_featurizer"],
                edge_featurizer=self.args["edge_featurizer"],
                add_self_loop=True,
                # NOTE(review): presumably atom order must match the input SMILES for the
                # edit decoding to line up — confirm against LocalRetro's data pipeline.
                canonical_atom_order=False,
            )
            for mol in mols
        ]

        # `collate_molgraphs_test` expects (smiles, graph, label) triples; only the
        # batched graph (index 1 of the result) is needed here.
        return collate_molgraphs_test([(None, graph, None) for graph in graphs])[1]

    def _build_batch_predictions(
        self, batch, num_results: int, inputs: List[Molecule], batch_atom_logits, batch_bond_logits
    ) -> List[Sequence[SingleProductReaction]]:
        """Decode per-atom/per-bond scores into up to `num_results` reactions per input.

        NOTE(review): this mutates `self.args` ("top_k", "raw_predictions") as a way of
        passing state to `get_k_predictions`, so concurrent calls on one instance would
        interfere — confirm single-threaded use.
        """
        from local_retro.scripts.Decode_predictions import get_k_predictions
        from local_retro.scripts.get_edit import combined_edit, get_bg_partition

        # Node/edge boundaries of each molecule inside the batched graph.
        graphs, nodes_sep, edges_sep = get_bg_partition(batch)
        start_node = 0
        start_edge = 0

        self.args["top_k"] = num_results
        self.args["raw_predictions"] = []

        for input, graph, end_node, end_edge in zip(inputs, graphs, nodes_sep, edges_sep):
            # Slice out this molecule's logits and rank candidate edits.
            pred_types, pred_sites, pred_scores = combined_edit(
                graph,
                batch_atom_logits[start_node:end_node],
                batch_bond_logits[start_edge:end_edge],
                num_results,
            )
            start_node, start_edge = end_node, end_edge

            # Render predictions in the string format `get_k_predictions` expects.
            raw_predictions = [
                f"({pred_types[i]}, {pred_sites[i][0]}, {pred_sites[i][1]}, {pred_scores[i]:.3f})"
                for i in range(num_results)
            ]

            self.args["raw_predictions"].append([input.smiles] + raw_predictions)

        batch_predictions = []
        for idx, input in enumerate(inputs):
            try:
                raw_str_results = get_k_predictions(test_id=idx, args=self.args)[1][0]
            except RuntimeError:
                # In very rare cases we may get `rdkit` errors.
                raw_str_results = []

            # We have to `eval` the predictions as they come rendered into strings. Second tuple
            # component is empirically (on USPTO-50K test set) in [0, 1], resembling a probability,
            # but does not sum up to 1.0 (usually to something in [0.5, 2.0]).
            # SECURITY NOTE(review): `eval` on strings produced by the model pipeline; safe only
            # as long as checkpoints/templates are trusted. Consider `ast.literal_eval` if the
            # rendered tuples are guaranteed to be literals.
            raw_results = [eval(str_result) for str_result in raw_str_results]

            if raw_results:
                raw_outputs, probabilities = zip(*raw_results)
            else:
                raw_outputs = probabilities = []

            batch_predictions.append(
                process_raw_smiles_outputs_backwards(
                    input=input,
                    output_list=raw_outputs,
                    metadata_list=[{"probability": probability} for probability in probabilities],
                )
            )

        return batch_predictions

    def _get_reactions(
        self, inputs: List[Molecule], num_results: int
    ) -> List[Sequence[SingleProductReaction]]:
        """Run the model on `inputs` and return up to `num_results` reactions per molecule."""
        import torch
        from local_retro.scripts.utils import predict

        batch = self._mols_to_batch(inputs)
        batch_atom_logits, batch_bond_logits, _ = predict(self.args, self.model, batch)

        # Normalize raw logits into per-site probabilities before decoding.
        batch_atom_logits = torch.nn.Softmax(dim=1)(batch_atom_logits)
        batch_bond_logits = torch.nn.Softmax(dim=1)(batch_bond_logits)

        return self._build_batch_predictions(
            batch, num_results, inputs, batch_atom_logits, batch_bond_logits
        )
import math

import pytest

from syntheseus.search.graph.and_or import AndNode, AndOrGraph
from syntheseus.search.graph.molset import MolSetGraph
from syntheseus.search.node_evaluation.common import (
    ConstantNodeEvaluator,
    HasSolutionValueFunction,
    ReactionModelLogProbCost,
    ReactionModelProbPolicy,
)


class TestConstantNodeEvaluator:
    @pytest.mark.parametrize("constant", [0.3, 0.8])
    def test_values(self, constant: float, andor_graph_non_minimal: AndOrGraph) -> None:
        """Every node should be assigned exactly the configured constant value."""
        val_fn = ConstantNodeEvaluator(constant)
        vals = val_fn(list(andor_graph_non_minimal.nodes()))
        assert all([v == constant for v in vals])  # values should match
        assert val_fn.num_calls == len(
            andor_graph_non_minimal
        )  # should have been called once per node


class TestHasSolutionValueFunction:
    def test_values(self, andor_graph_non_minimal: AndOrGraph) -> None:
        """Each node's value should be 1.0/0.0 depending on its `has_solution` flag."""
        val_fn = HasSolutionValueFunction()
        nodes = list(andor_graph_non_minimal.nodes())
        vals = val_fn(nodes)
        assert all([v == float(n.has_solution) for v, n in zip(vals, nodes)])  # values should match
        assert val_fn.num_calls == len(
            andor_graph_non_minimal
        )  # should have been called once per node


class TestReactionModelLogProbCost:
    @pytest.mark.parametrize("normalize", [False, True])
    @pytest.mark.parametrize("temperature", [1.0, 2.0])
    @pytest.mark.parametrize("clip_probability_min", [0.1, 0.5])
    @pytest.mark.parametrize("clip_probability_max", [0.5, 1.0])
    def test_values(
        self,
        andor_graph_non_minimal: AndOrGraph,
        normalize: bool,
        temperature: float,
        clip_probability_min: float,
        clip_probability_max: float,
    ) -> None:
        """Costs should equal -log(clipped probability) / temperature, optionally normalized."""
        val_fn = ReactionModelLogProbCost(
            normalize=normalize,
            temperature=temperature,
            clip_probability_min=clip_probability_min,
            clip_probability_max=clip_probability_max,
        )
        # Only AND nodes carry reactions, so only they get a cost.
        nodes = [node for node in andor_graph_non_minimal.nodes() if isinstance(node, AndNode)]

        # The toy model does not set reaction probabilities, so set these manually.
        node_val_expected = {}
        for idx, node in enumerate(nodes):
            # Probabilities spread evenly over [0, 1] across the nodes.
            prob = idx / (len(nodes) - 1)
            node.reaction.metadata["probability"] = prob  # type: ignore

            # Expected cost: clip the probability, then take -log and divide by temperature.
            node_val_expected[node] = (
                -math.log(min(clip_probability_max, max(clip_probability_min, prob))) / temperature
            )

        if normalize:
            # Log-normalization: shift costs so that exp(-cost) sums to 1.
            normalization_constant = math.log(sum(math.exp(-v) for v in node_val_expected.values()))
            node_val_expected = {
                key: value + normalization_constant for key, value in node_val_expected.items()
            }

        vals = val_fn(nodes)
        for val_computed, node in zip(vals, nodes):  # values should match
            assert math.isclose(val_computed, node_val_expected[node])

        assert val_fn.num_calls == len(nodes)  # should have been called once per AND node

    def test_enforces_min_clipping(self) -> None:
        """A zero min-clip is invalid because log(0) is undefined."""
        with pytest.raises(ValueError):
            ReactionModelLogProbCost(clip_probability_min=0.0)  # should fail as `return_log = True`


class TestReactionModelProbPolicy:
    @pytest.mark.parametrize("normalize", [False, True])
    @pytest.mark.parametrize("temperature", [1.0, 2.0])
    @pytest.mark.parametrize("clip_probability_min", [0.1, 0.5])
    @pytest.mark.parametrize("clip_probability_max", [0.5, 1.0])
    def test_values(
        self,
        molset_tree_non_minimal: MolSetGraph,
        normalize: bool,
        temperature: float,
        clip_probability_min: float,
        clip_probability_max: float,
    ) -> None:
        """Policy values should equal clipped probability ^ (1/T), optionally sum-normalized."""
        val_fn = ReactionModelProbPolicy(
            normalize=normalize,
            temperature=temperature,
            clip_probability_min=clip_probability_min,
            clip_probability_max=clip_probability_max,
        )
        # In a MolSet graph reactions live on edges; every non-root node has an incoming one.
        nodes = [
            node
            for node in molset_tree_non_minimal.nodes()
            if node != molset_tree_non_minimal.root_node
        ]

        # The toy model does not set reaction probabilities, so set these manually.
        node_val_expected = {}
        for idx, node in enumerate(nodes):
            [parent] = molset_tree_non_minimal.predecessors(node)
            reaction = molset_tree_non_minimal._graph.edges[parent, node]["reaction"]

            # Be careful not to overwrite things as some reactions in the graph are repeated.
            if "probability" not in reaction.metadata:
                reaction.metadata["probability"] = prob = idx / (len(nodes) - 1)
            else:
                prob = reaction.metadata["probability"]

            # Expected value: clip the probability, then apply the temperature exponent.
            node_val_expected[node] = min(
                clip_probability_max, max(clip_probability_min, prob)
            ) ** (1.0 / temperature)

        if normalize:
            # Normalize so that the values sum to 1 over the evaluated nodes.
            normalization_factor = sum(node_val_expected.values())
            node_val_expected = {
                key: value / normalization_factor for key, value in node_val_expected.items()
            }

        vals = val_fn(nodes, graph=molset_tree_non_minimal)
        for val_computed, node in zip(vals, nodes):  # values should match
            assert math.isclose(val_computed, node_val_expected[node])

        assert (
            val_fn.num_calls == len(molset_tree_non_minimal) - 1
        )  # should have been called once per non-root node

    def test_enforces_min_clipping(self) -> None:
        """A zero min-clip is only invalid when normalization is enabled."""
        with pytest.raises(ValueError):
            ReactionModelProbPolicy(clip_probability_min=0.0)  # should fail as `normalize = True`

        ReactionModelProbPolicy(
            normalize=False, clip_probability_min=0.0
        )  # should succeed if we explicitly turn off normalization