├── syntheseus
├── py.typed
├── cli
│ ├── __init__.py
│ ├── main.py
│ └── search_config.yml
├── tests
│ ├── __init__.py
│ ├── cli
│ │ ├── __init__.py
│ │ ├── test_search.py
│ │ └── test_cli.py
│ ├── interface
│ │ ├── __init__.py
│ │ ├── test_bag.py
│ │ ├── test_toy_models.py
│ │ ├── test_reaction.py
│ │ └── test_molecule.py
│ ├── search
│ │ ├── __init__.py
│ │ ├── graph
│ │ │ ├── __init__.py
│ │ │ ├── test_route.py
│ │ │ ├── test_base.py
│ │ │ ├── test_standardization.py
│ │ │ ├── test_message_passing.py
│ │ │ └── test_molset.py
│ │ ├── algorithms
│ │ │ ├── __init__.py
│ │ │ └── test_random.py
│ │ ├── analysis
│ │ │ ├── __init__.py
│ │ │ ├── test_solution_time.py
│ │ │ ├── conftest.py
│ │ │ └── test_diversity.py
│ │ ├── node_evaluation
│ │ │ ├── __init__.py
│ │ │ └── test_common.py
│ │ ├── test_visualization.py
│ │ └── test_mol_inventory.py
│ ├── reaction_prediction
│ │ ├── __init__.py
│ │ ├── chem
│ │ │ ├── __init__.py
│ │ │ └── test_utils.py
│ │ ├── data
│ │ │ ├── __init__.py
│ │ │ └── test_dataset.py
│ │ ├── inference
│ │ │ ├── __init__.py
│ │ │ └── test_models.py
│ │ └── utils
│ │   ├── __init__.py
│ │   ├── test_misc.py
│ │   └── test_parallel.py
│ └── conftest.py
├── interface
│ ├── __init__.py
│ ├── bag.py
│ ├── reaction.py
│ └── molecule.py
├── search
│ ├── analysis
│ │ ├── __init__.py
│ │ └── solution_time.py
│ ├── graph
│ │ ├── __init__.py
│ │ ├── message_passing
│ │ │ ├── __init__.py
│ │ │ └── update_functions.py
│ │ ├── route.py
│ │ ├── node.py
│ │ └── base_graph.py
│ ├── utils
│ │ ├── __init__.py
│ │ └── misc.py
│ ├── algorithms
│ │ ├── __init__.py
│ │ ├── mcts
│ │ │ ├── __init__.py
│ │ │ └── molset.py
│ │ ├── best_first
│ │ │ └── __init__.py
│ │ ├── mixins.py
│ │ ├── breadth_first.py
│ │ └── random.py
│ ├── __init__.py
│ ├── node_evaluation
│ │ ├── __init__.py
│ │ └── common.py
│ └── mol_inventory.py
├── reaction_prediction
│ ├── __init__.py
│ ├── chem
│ │ ├── __init__.py
│ │ └── utils.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── reaction_sample.py
│ │ └── dataset.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── testing.py
│ │ ├── metrics.py
│ │ ├── model_loading.py
│ │ ├── parallel.py
│ │ ├── config.py
│ │ ├── downloading.py
│ │ ├── inference.py
│ │ └── misc.py
│ ├── models
│ │ ├── __init__.py
│ │ └── retro_knn.py
│ ├── environment_gln
│ │ ├── environment.yml
│ │ └── Dockerfile
│ └── inference
│   ├── default_checkpoint_links.yml
│   ├── __init__.py
│   ├── config.py
│   ├── base.py
│   ├── gln.py
│   ├── toy_models.py
│   ├── graph2edits.py
│   ├── retro_knn.py
│   └── local_retro.py
└── __init__.py
├── docs
├── images
│ ├── logo.png
│ └── logo.svg
├── tutorials
│ └── .gitignore
├── index.md
├── cli
│ ├── search.md
│ └── eval_single_step.md
├── single_step.md
└── installation.md
├── environment.yml
├── CODE_OF_CONDUCT.md
├── .gitignore
├── SUPPORT.md
├── .coveragerc
├── .github
├── azure_pipelines
│ └── code-security-analysis.yml
└── workflows
  ├── release.yml
  ├── docs.yml
  └── ci.yml
├── environment_full.yml
├── LICENSE
├── DEVELOPMENT.md
├── mkdocs.yml
├── .pre-commit-config.yaml
├── SECURITY.md
├── pyproject.toml
└── README.md
/syntheseus/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/cli/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/interface/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/tests/cli/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/search/analysis/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/search/graph/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/search/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/tests/interface/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/tests/search/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/search/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/tests/search/graph/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/chem/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/search/algorithms/mcts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/tests/reaction_prediction/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/tests/search/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/tests/search/analysis/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/search/algorithms/best_first/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/tests/reaction_prediction/chem/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/tests/reaction_prediction/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/tests/search/node_evaluation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/tests/reaction_prediction/inference/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/syntheseus/tests/reaction_prediction/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/syntheseus/HEAD/docs/images/logo.png
--------------------------------------------------------------------------------
/docs/tutorials/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore files created as outputs from the tutorials
2 | *.pdf
3 | *.pkl
4 |
--------------------------------------------------------------------------------
/syntheseus/search/utils/misc.py:
--------------------------------------------------------------------------------
1 | def lookup_by_name(module, name):
2 | return module.__dict__[name]
3 |
--------------------------------------------------------------------------------
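
`lookup_by_name` simply resolves an attribute of a module from its string name. A minimal usage sketch (the standard-library `math` module is used purely as an example):

```python
import math

from syntheseus.search.utils.misc import lookup_by_name

# Resolve `math.pi` from its string name, e.g. when names come from a config file.
assert lookup_by_name(math, "pi") == math.pi
```
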
/syntheseus/search/__init__.py:
--------------------------------------------------------------------------------
1 | INT_INF = int(1e17) # integer large enough to practically be infinity, but less than 2^63 - 1
2 |
--------------------------------------------------------------------------------
/syntheseus/search/node_evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from syntheseus.search.node_evaluation.base import BaseNodeEvaluator
2 |
3 | __all__ = ["BaseNodeEvaluator"]
4 |
--------------------------------------------------------------------------------
/syntheseus/tests/search/analysis/test_solution_time.py:
--------------------------------------------------------------------------------
1 | """This function is tested minimally in the algorithm tests, so there are no specific tests at this time."""
2 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: syntheseus
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - numpy
7 | - pip # ideally >=25.2 unless using python 3.8
8 | - python>=3.8 # ideally 3.9+
9 | - rdkit
10 |
--------------------------------------------------------------------------------
/syntheseus/search/graph/message_passing/__init__.py:
--------------------------------------------------------------------------------
1 | from syntheseus.search.graph.message_passing.run import run_message_passing
2 | from syntheseus.search.graph.message_passing.update_functions import (
3 | depth_update,
4 | has_solution_update,
5 | )
6 |
7 | __all__ = ["run_message_passing", "has_solution_update", "depth_update"]
8 |
--------------------------------------------------------------------------------
/syntheseus/tests/search/graph/test_route.py:
--------------------------------------------------------------------------------
1 | """Route objects are tested implicity in other tests, so there are only minimal tests for now."""
2 |
3 | from syntheseus.interface.molecule import Molecule
4 |
5 |
6 | def test_route_starting_molecules(minimal_synthesis_graph):
7 | assert minimal_synthesis_graph.get_starting_molecules() == {Molecule("CC")}
8 |
--------------------------------------------------------------------------------
/syntheseus/search/algorithms/mcts/molset.py:
--------------------------------------------------------------------------------
1 | from syntheseus.search.algorithms.base import MolSetSearchAlgorithm
2 | from syntheseus.search.algorithms.mcts.base import BaseMCTS
3 | from syntheseus.search.graph.molset import MolSetGraph, MolSetNode
4 |
5 |
6 | class MolSetMCTS(BaseMCTS[MolSetGraph, MolSetNode, MolSetNode], MolSetSearchAlgorithm[int]):
7 | pass
8 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/environment_gln/environment.yml:
--------------------------------------------------------------------------------
1 | name: gln-env
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | - pyg
8 | dependencies:
9 | - python=3.9
10 | - pytorch=1.13.0=py3.9_cuda11.6_cudnn8.3.2_0
11 | - scipy
12 | - tqdm
13 | - boost
14 | - boost-cpp
15 | - cairo
16 | - cmake
17 | - eigen
18 | - gxx_linux-64
19 | - mkl<2024.1.0
20 | - pillow
21 | - pkg-config
22 | - py-boost
23 | - pyg==2.1.0
24 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/syntheseus/__init__.py:
--------------------------------------------------------------------------------
1 | from syntheseus.interface.bag import Bag
2 | from syntheseus.interface.models import BackwardReactionModel, ForwardReactionModel, ReactionModel
3 | from syntheseus.interface.molecule import Molecule
4 | from syntheseus.interface.reaction import Reaction, SingleProductReaction
5 |
6 | __all__ = [
7 | "Molecule",
8 | "Reaction",
9 | "SingleProductReaction",
10 | "Bag",
11 | "ReactionModel",
12 | "BackwardReactionModel",
13 | "ForwardReactionModel",
14 | ]
15 |
--------------------------------------------------------------------------------
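
These re-exports form the top-level public interface. A minimal sketch of constructing a reaction with them, mirroring the fixtures in `syntheseus/tests/conftest.py` (`make_rdkit_mol=False` follows those fixtures and skips building an RDKit molecule object):

```python
from syntheseus import Bag, Molecule, SingleProductReaction

# A reaction with COCS as the product and {CO, CS} as the reactant multi-set.
product = Molecule("COCS", make_rdkit_mol=False)
reaction = SingleProductReaction(
    product=product,
    reactants=Bag([Molecule("CO", make_rdkit_mol=False), Molecule("CS", make_rdkit_mol=False)]),
)
```
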
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # Notebook checkpoints:
30 | .ipynb_checkpoints
31 |
32 | # Unit test / coverage reports
33 | .coverage
34 | .coverage.*
35 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # Support
2 |
3 | ## How to file issues and get help
4 |
5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
7 | feature request as a new Issue.
8 |
9 | For help and questions about using this project, please file a GitHub Issue or email us.
10 | We will try to respond to all issues.
11 |
12 | ## Microsoft Support Policy
13 |
14 | Support for this project is limited to the resources listed above.
15 |
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | branch = True
3 | source = syntheseus
4 | omit = syntheseus/tests/*
5 |
6 | [report]
7 | # Regexes for lines to exclude from consideration
8 | exclude_lines =
9 | # Have to re-enable the standard pragma
10 | pragma: no cover
11 |
12 | # Don't complain about missing debug-only code:
13 | def __repr__
14 | if self\.debug
15 |
16 | # Don't complain if tests don't hit defensive assertion code:
17 | raise AssertionError
18 | raise NotImplementedError
19 |
20 | # Don't complain about abstract methods, they aren't run:
21 | @(abc\.)?abstractmethod
22 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/inference/default_checkpoint_links.yml:
--------------------------------------------------------------------------------
1 | backward:
2 | Chemformer: https://figshare.com/ndownloader/files/42009888
3 | GLN: https://figshare.com/ndownloader/files/45882867
4 | Graph2Edits: https://figshare.com/ndownloader/files/44194301
5 | LocalRetro: https://figshare.com/ndownloader/files/42287319
6 | MEGAN: https://figshare.com/ndownloader/files/42012732
7 | MHNreact: https://figshare.com/ndownloader/files/42012777
8 | RetroKNN: https://figshare.com/ndownloader/files/45662430
9 | RootAligned: https://figshare.com/ndownloader/files/42012792
10 | forward:
11 | Chemformer: https://figshare.com/ndownloader/files/42012708
12 |
--------------------------------------------------------------------------------
/syntheseus/tests/search/graph/test_base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class BaseNodeTest(abc.ABC):
5 | """Base class which defines common tests for graph nodes."""
6 |
7 | @abc.abstractmethod
8 | def get_node(self):
9 | pass
10 |
11 | def test_node_comparison(self):
12 | """Test that nodes are equal if and only if they are the same object."""
13 | node1 = self.get_node()
14 | node2 = self.get_node()
15 | assert node1 is not node2
16 | assert node1 != node2
17 |
18 | @abc.abstractmethod
19 | def test_nodes_not_frozen(self):
20 | """Test that the fields of the node can be modified."""
21 |
--------------------------------------------------------------------------------
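
Concrete node test classes are expected to subclass `BaseNodeTest` and fill in the two abstract methods. A hypothetical sketch (`DummyNode` is invented for illustration, not a real node class from the package):

```python
from dataclasses import dataclass, field

from syntheseus.tests.search.graph.test_base import BaseNodeTest


@dataclass(eq=False)  # eq=False keeps identity-based equality, as `test_node_comparison` expects
class DummyNode:
    data: dict = field(default_factory=dict)


class TestDummyNode(BaseNodeTest):
    def get_node(self):
        return DummyNode()

    def test_nodes_not_frozen(self):
        node = self.get_node()
        node.data["depth"] = 0  # mutation succeeds, i.e. the node is not frozen
        assert node.data["depth"] == 0
```
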
/.github/azure_pipelines/code-security-analysis.yml:
--------------------------------------------------------------------------------
1 | schedules:
2 | - cron: '0 4 * * 0'
3 | displayName: Weekly build
4 | branches:
5 | include:
6 | - main
7 |
8 | pr: none
9 | trigger: none
10 |
11 | pool:
12 | vmImage: 'windows-latest'
13 |
14 | steps:
15 | - task: CredScan@2
16 | inputs:
17 | toolMajorVersion: 'V2'
18 | - task: ComponentGovernanceComponentDetection@0
19 | inputs:
20 | scanType: 'Register'
21 | verbosity: 'Verbose'
22 | alertWarningLevel: 'High'
23 | - task: PublishSecurityAnalysisLogs@2
24 | inputs:
25 | ArtifactName: 'CodeAnalysisLogs'
26 | ArtifactType: 'Container'
27 | AllTools: true
28 | ToolLogsNotFoundAction: 'Standard'
29 |
--------------------------------------------------------------------------------
/syntheseus/tests/interface/test_bag.py:
--------------------------------------------------------------------------------
1 | from syntheseus.interface.bag import Bag
2 |
3 |
4 | def test_basic_operations() -> None:
5 | bag_1 = Bag(["a", "b", "a"])
6 | bag_2 = Bag(["b", "a", "a"])
7 | bag_3 = Bag(["a", "b"])
8 |
9 | # Test `__contains__`.
10 | assert "a" in bag_1
11 | assert "b" in bag_1
12 | assert "c" not in bag_1
13 |
14 | # Test `__eq__`.
15 | assert bag_1 == bag_2
16 | assert bag_1 != bag_3
17 |
18 | # Test `__iter__`.
19 | assert Bag(bag_1) == bag_1
20 |
21 | # Test `__len__`.
22 | assert len(bag_1) == 3
23 | assert len(bag_3) == 2
24 |
25 | # Test `__hash__`.
26 | assert len(set([bag_1, bag_2, bag_3])) == 2
27 |
--------------------------------------------------------------------------------
/syntheseus/tests/reaction_prediction/utils/test_misc.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable
2 |
3 | import pytest
4 |
5 | from syntheseus.reaction_prediction.utils.misc import parallelize
6 |
7 |
8 | def square(x: int) -> int:
9 | return x * x
10 |
11 |
12 | @pytest.mark.parametrize("use_iterator", [False, True])
13 | @pytest.mark.parametrize("num_processes", [0, 2])
14 | def test_parallelize(use_iterator: bool, num_processes: int) -> None:
15 | inputs: Iterable[int] = [1, 2, 3, 4]
16 | expected_outputs = [square(x) for x in inputs]
17 |
18 | if use_iterator:
19 | inputs = iter(inputs)
20 |
21 | assert list(parallelize(square, inputs, num_processes=num_processes)) == expected_outputs
22 |
--------------------------------------------------------------------------------
/environment_full.yml:
--------------------------------------------------------------------------------
1 | name: syntheseus-full
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - pyg # RetroKNN
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - faiss-gpu # RetroKNN
10 | - numpy
11 | - pip>=25.2
12 | - python==3.9.7
13 | - pytorch=2.2.2=py3.9_cuda12.1_cudnn8.9.2_0
14 | - pytorch-lightning==2.2.2 # Chemformer
15 | - pytorch-scatter==2.1.2 # RetroKNN
16 | - rdchiral_cpp # MHNreact
17 | - rdkit=2021.09.4
18 | - pip:
19 | - --find-links https://data.dgl.ai/wheels/torch-2.2/cu121/repo.html
20 | - dgl==2.4.0 # LocalRetro, RetroKNN
21 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/utils/testing.py:
--------------------------------------------------------------------------------
1 | def are_single_step_models_installed():
2 | """Check whether the single-step models can be imported successfully."""
3 | try:
4 | # Try to import the single-step model repositories to check if they are installed. It could
5 | # be the case that these are installed but their dependencies are not, in which case trying
 6 |         # to *use* the models would fail; nevertheless, the below is good enough for our use cases.
7 |
8 | import chemformer # noqa: F401
9 | import graph2edits # noqa: F401
10 | import local_retro # noqa: F401
11 | import megan # noqa: F401
12 | import mhnreact # noqa: F401
13 | import root_aligned # noqa: F401
14 |
15 | return True
16 | except ModuleNotFoundError:
17 | return False
18 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release Stable Version
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | version:
7 | required: true
8 | type: string
9 |
10 | permissions:
11 | contents: write
12 |
13 | jobs:
14 | push-tag:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v4
18 | - name: Configure Git
19 | run: |
20 | git config user.name github-actions[bot]
21 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com
22 | - run: |
23 | git tag -a v${{ inputs.version }} -m "Release v${{ inputs.version }}"
24 | git push origin v${{ inputs.version }}
25 | build-docs:
26 | needs: push-tag
27 | uses: ./.github/workflows/docs.yml
28 | with:
29 | versions: ${{ inputs.version }} stable
30 |
--------------------------------------------------------------------------------
/syntheseus/tests/search/graph/test_standardization.py:
--------------------------------------------------------------------------------
1 | """
2 | At the moment, only a smoke-test is done for graph standardization,
3 | but the exact behaviour is not well-tested.
4 | These tests should be added in the future.
5 | """
6 |
7 | import pytest
8 |
9 | from syntheseus.search.graph.and_or import AndOrGraph
10 | from syntheseus.search.graph.molset import MolSetGraph
11 | from syntheseus.search.graph.standardization import get_unique_node_andor_graph
12 |
13 |
14 | def test_smoke_andor(andor_graph_non_minimal: AndOrGraph):
15 | with pytest.warns(UserWarning):
16 | output = get_unique_node_andor_graph(andor_graph_non_minimal)
17 |
18 | assert len(output) == len(andor_graph_non_minimal) # no nodes deleted here
19 |
20 |
21 | def test_smoke_molset(molset_tree_non_minimal: MolSetGraph):
22 | with pytest.warns(UserWarning):
23 | get_unique_node_andor_graph(molset_tree_non_minimal)
24 |
--------------------------------------------------------------------------------
/syntheseus/cli/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Callable, Dict
3 |
4 | from syntheseus.cli import eval_single_step, search
5 |
6 |
7 | def main() -> None:
8 | supported_commands: Dict[str, Callable] = {
9 | "search": search.main,
10 | "eval-single-step": eval_single_step.main,
11 | }
12 | supported_command_names = ", ".join(supported_commands.keys())
13 |
14 | if len(sys.argv) == 1:
15 | raise ValueError(f"Please choose a command from: {supported_command_names}")
16 |
17 | command = sys.argv[1]
18 | if command not in supported_commands:
19 | raise ValueError(f"Command {command} not supported; choose from: {supported_command_names}")
20 |
21 | # Drop the subcommand name and let the chosen command parse the rest of the arguments.
22 | del sys.argv[1]
23 | supported_commands[command]()
24 |
25 |
26 | if __name__ == "__main__":
27 | main()
28 |
--------------------------------------------------------------------------------
/syntheseus/search/graph/message_passing/update_functions.py:
--------------------------------------------------------------------------------
1 | from syntheseus.search.graph.base_graph import RetrosynthesisSearchGraph
2 | from syntheseus.search.graph.node import BaseGraphNode
3 |
4 |
5 | def depth_update(node: BaseGraphNode, graph: RetrosynthesisSearchGraph) -> bool:
6 | parent_depths = [n.depth for n in graph.predecessors(node)]
7 | if len(parent_depths) == 0:
8 | new_depth = 0
9 | else:
10 | new_depth = min(parent_depths) + 1
11 |
12 | depth_changed = node.depth != new_depth
13 | node.depth = new_depth
14 | return depth_changed
15 |
16 |
17 | def has_solution_update(node: BaseGraphNode, graph: RetrosynthesisSearchGraph) -> bool:
18 | new_has_solution = node._has_intrinsic_solution() or node._has_solution_from_children(
19 | list(graph.successors(node))
20 | )
21 |
22 | old_has_solution = node.has_solution
23 | node.has_solution = new_has_solution
24 | return new_has_solution != old_has_solution
25 |
--------------------------------------------------------------------------------
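
Both functions follow the same contract: recompute a node property from its graph neighbourhood and return whether the value changed, which lets a driver (`run_message_passing`) iterate updates until a fixed point. An illustrative sketch of that contract on a toy DAG (plain dicts here, not the library's graph classes):

```python
from collections import deque

# Toy DAG: map each node to its parents.
parents = {"root": [], "a": ["root"], "b": ["root", "a"]}
depth = {node: -1 for node in parents}  # -1 means "not yet computed"


def toy_depth_update(node: str) -> bool:
    """Recompute `depth[node]` from its parents; return whether it changed."""
    parent_depths = [depth[p] for p in parents[node]]
    new_depth = 0 if not parent_depths else min(parent_depths) + 1
    changed = depth[node] != new_depth
    depth[node] = new_depth
    return changed


# Iterate to a fixed point, re-queueing children of any node whose value changed.
queue = deque(parents)
while queue:
    node = queue.popleft()
    if toy_depth_update(node):
        queue.extend(child for child, ps in parents.items() if node in ps)

assert depth == {"root": 0, "a": 1, "b": 1}
```
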
/syntheseus/search/algorithms/mixins.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Generic, TypeVar
4 |
5 | from syntheseus.search.algorithms.base import SearchAlgorithm
6 | from syntheseus.search.graph.node import BaseGraphNode
7 | from syntheseus.search.node_evaluation import BaseNodeEvaluator
8 |
9 | NodeType = TypeVar("NodeType", bound=BaseGraphNode)
10 |
11 |
12 | class ValueFunctionMixin(SearchAlgorithm, Generic[NodeType]):
13 | def __init__(self, *args, value_function: BaseNodeEvaluator[NodeType], **kwargs):
14 | super().__init__(*args, **kwargs)
15 | self.value_function = value_function
16 |
17 | def set_node_values(self, nodes, graph):
18 | output_nodes = super().set_node_values(nodes, graph)
19 | for node in output_nodes:
20 | node.data.setdefault("num_calls_value_function", self.value_function.num_calls)
21 | return output_nodes
22 |
23 | def reset(self) -> None:
24 | super().reset()
25 | self.value_function.reset()
26 |
--------------------------------------------------------------------------------
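
`ValueFunctionMixin` relies on Python's cooperative multiple inheritance: each class in the MRO extends the `set_node_values` hook and delegates the rest via `super()`. A generic sketch of this pattern with invented classes:

```python
class Base:
    def set_node_values(self, nodes):
        return nodes


class CountingMixin(Base):
    """Mixin that augments the hook, then defers to the next class in the MRO."""

    def set_node_values(self, nodes):
        nodes = super().set_node_values(nodes)
        self.num_updates = getattr(self, "num_updates", 0) + len(nodes)
        return nodes


class Algorithm(CountingMixin, Base):
    pass


alg = Algorithm()
alg.set_node_values([1, 2, 3])
assert alg.num_updates == 3
```
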
/syntheseus/reaction_prediction/inference/__init__.py:
--------------------------------------------------------------------------------
1 | from syntheseus.reaction_prediction.inference.chemformer import ChemformerModel
2 | from syntheseus.reaction_prediction.inference.gln import GLNModel
3 | from syntheseus.reaction_prediction.inference.graph2edits import Graph2EditsModel
4 | from syntheseus.reaction_prediction.inference.local_retro import LocalRetroModel
5 | from syntheseus.reaction_prediction.inference.megan import MEGANModel
6 | from syntheseus.reaction_prediction.inference.mhnreact import MHNreactModel
7 | from syntheseus.reaction_prediction.inference.retro_knn import RetroKNNModel
8 | from syntheseus.reaction_prediction.inference.root_aligned import RootAlignedModel
9 | from syntheseus.reaction_prediction.inference.toy_models import (
10 | LinearMoleculesToyModel,
11 | ListOfReactionsToyModel,
12 | )
13 |
14 | __all__ = [
15 | "ChemformerModel",
16 | "GLNModel",
17 | "Graph2EditsModel",
18 | "LinearMoleculesToyModel",
19 | "ListOfReactionsToyModel",
20 | "LocalRetroModel",
21 | "MEGANModel",
22 | "MHNreactModel",
23 | "RetroKNNModel",
24 | "RootAlignedModel",
25 | ]
26 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: Docs
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | workflow_call:
7 | inputs:
8 | versions:
9 | required: true
10 | type: string
11 | workflow_dispatch:
12 |
13 | permissions:
14 | contents: write
15 |
16 | jobs:
17 | deploy:
18 | runs-on: ubuntu-latest
19 | steps:
20 | - uses: actions/checkout@v4
21 | - name: Configure Git
22 | run: |
23 | git config user.name github-actions[bot]
24 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com
25 | git fetch origin gh-pages --depth=1
26 | - uses: actions/setup-python@v4
27 | with:
28 | python-version: 3.x
29 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
30 | - uses: actions/cache@v3
31 | with:
32 | key: mkdocs-material-${{ env.cache_id }}
33 | path: .cache
34 | restore-keys: |
35 | mkdocs-material-
36 | - run: pip install mkdocs-material mkdocs-jupyter mike
37 | - run: mike deploy --push --update-aliases ${{ inputs.versions != '' && inputs.versions || 'dev' }}
38 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 |
 2 | ![Syntheseus logo](images/logo.svg){width="450"}
3 | Navigating the labyrinth of synthesis planning
4 |
5 |
6 | ---
7 |
8 | [](https://github.com/microsoft/syntheseus/actions/workflows/ci.yml)
9 | [](https://www.python.org/downloads/)
10 | [](https://pypi.org/project/syntheseus/)
11 | [](https://github.com/ambv/black)
12 | [](https://github.com/microsoft/syntheseus/blob/main/LICENSE)
13 |
14 | Syntheseus is a package for end-to-end retrosynthetic planning.
15 |
16 | - ⚒️ Combines search algorithms and reaction models in a standardized way
17 | - 🧭 Includes implementations of common search algorithms
18 | - 🧪 Includes wrappers for state-of-the-art reaction models
19 | - ⚙️ Exposes a simple API to plug in custom models and algorithms
20 | - 📈 Can be used to benchmark components of a retrosynthesis pipeline
21 |
--------------------------------------------------------------------------------
/syntheseus/interface/bag.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Collection
2 | from typing import Generic, Iterable, Iterator, Protocol, TypeVar
3 |
4 |
5 | class Comparable(Protocol):
6 | def __lt__(self, __other) -> bool:
7 | ...
8 |
9 |
10 | ElementT = TypeVar("ElementT", bound=Comparable)
11 |
12 |
13 | class Bag(Collection, Generic[ElementT]):
14 | """Class representing a frozen multi-set (i.e. set where elements are allowed to repeat).
15 |
16 | The bag elements are internally stored as a sorted tuple for simplicity, thus lookup time is
17 | linear in terms of the size of the bag.
18 | """
19 |
20 | def __init__(self, values: Iterable[ElementT]) -> None:
21 | self._values = tuple(sorted(values))
22 |
23 | def __iter__(self) -> Iterator:
24 | return iter(self._values)
25 |
26 | def __contains__(self, element) -> bool:
27 | return element in self._values
28 |
29 | def __eq__(self, other) -> bool:
30 | if isinstance(other, Bag):
31 | return self._values == other._values
32 | else:
33 | return False
34 |
35 | def __len__(self) -> int:
36 | return len(self._values)
37 |
38 | def __repr__(self) -> str:
39 | return repr(self._values)
40 |
41 | def __hash__(self) -> int:
42 | return hash(self._values)
43 |
--------------------------------------------------------------------------------
/DEVELOPMENT.md:
--------------------------------------------------------------------------------
1 | To release a new stable version of `syntheseus` one needs to complete the following steps:
2 |
3 | 1. Create a PR editing the `CHANGELOG.md` analogous to [#88](https://github.com/microsoft/syntheseus/pull/88) (note how the changelog has to be modified in three places).
4 | 2. Run the ["Release Stable Version" workflow](https://github.com/microsoft/syntheseus/actions/workflows/release.yml), providing the version number as an argument (use the `x.y.z` format _without_ the leading "v"; the workflow will prepend it wherever necessary). Make sure the branch is set to `main`. If this step does not work as intended, the tag can always be deleted manually, while the changes to the docs can be reverted by removing a commit from the `gh-pages` branch (requires a force push).
5 | 3. Create a GitHub Release from the [newly created tag](https://github.com/microsoft/syntheseus/tags). Set the name to `syntheseus x.y.z`. The description should be the list of changes copied from the changelog (example [here](https://github.com/microsoft/syntheseus/releases/tag/v0.4.0)). Consider including a short description before the list of changes to describe the main gist of the release.
6 | 4. Release a new version to PyPI (e.g. by following [these instructions](https://realpython.com/pypi-publish-python-package/)). Consider publishing to [Test PyPI](https://test.pypi.org/) first to verify that the README renders correctly.
7 |
--------------------------------------------------------------------------------
/syntheseus/tests/conftest.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pytest
4 |
5 | from syntheseus.interface.bag import Bag
6 | from syntheseus.interface.molecule import Molecule
7 | from syntheseus.interface.reaction import SingleProductReaction
8 |
9 |
10 | @pytest.fixture
11 | def cocs_mol() -> Molecule:
12 | """Returns the molecule with smiles 'COCS'."""
13 | return Molecule("COCS", make_rdkit_mol=False)
14 |
15 |
16 | @pytest.fixture
17 | def ossc_mol() -> Molecule:
18 | """Returns the molecule with smiles 'OSSC'."""
19 | return Molecule("OSSC", make_rdkit_mol=False)
20 |
21 |
22 | @pytest.fixture
23 | def soos_mol() -> Molecule:
24 | """Returns the molecule with smiles 'SOOS'."""
25 | return Molecule("SOOS", make_rdkit_mol=False)
26 |
27 |
28 | @pytest.fixture
29 | def rxn_cocs_from_co_cs(cocs_mol: Molecule) -> SingleProductReaction:
30 | """Returns a reaction with COCS as the product."""
31 | return SingleProductReaction(product=cocs_mol, reactants=Bag([Molecule("CO"), Molecule("CS")]))
32 |
33 |
34 | @pytest.fixture
35 | def rxn_cs_from_cc() -> SingleProductReaction:
36 | return SingleProductReaction(product=Molecule("CS"), reactants=Bag([Molecule("CC")]))
37 |
38 |
39 | @pytest.fixture
40 | def rxn_cocs_from_cocc(cocs_mol: Molecule) -> SingleProductReaction:
41 | return SingleProductReaction(product=cocs_mol, reactants=Bag([Molecule("COCC")]))
42 |
--------------------------------------------------------------------------------
/syntheseus/tests/reaction_prediction/utils/test_parallel.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from syntheseus import Molecule
4 | from syntheseus.tests.cli.test_eval_single_step import DummyModel
5 |
6 | try:
7 | import torch
8 |
9 | from syntheseus.reaction_prediction.utils.parallel import ParallelReactionModel
10 |
11 | torch_available = True
12 | cuda_available = torch.cuda.is_available()
13 | except ModuleNotFoundError:
14 | torch_available = False
15 | cuda_available = False
16 |
17 |
18 | @pytest.mark.skipif(
19 | not torch_available, reason="Simple testing of parallel inference requires torch"
20 | )
21 | def test_parallel_reaction_model_cpu() -> None:
22 |     # We cannot really test parallel inference on CPU, so just check that model creation works
23 |     # and that a trivial call returns the expected result.
23 | parallel_model: ParallelReactionModel = ParallelReactionModel(
24 | model_fn=DummyModel, devices=["cpu"] * 4
25 | )
26 | assert parallel_model([]) == []
27 |
28 |
29 | @pytest.mark.skipif(
30 | not cuda_available, reason="Full testing of parallel inference requires GPU to be available"
31 | )
32 | def test_parallel_reaction_model_gpu() -> None:
33 | model = DummyModel()
34 | parallel_model: ParallelReactionModel = ParallelReactionModel(
35 | model_fn=DummyModel, devices=["cuda:0"] * 4
36 | )
37 |
38 | inputs = [Molecule("C" * length) for length in range(1, 6)]
39 | assert parallel_model(inputs) == model(inputs)
40 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/inference/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from enum import Enum
3 | from typing import Any, Dict
4 |
5 | from omegaconf import MISSING
6 |
7 | from syntheseus.reaction_prediction.inference import (
8 | ChemformerModel,
9 | GLNModel,
10 | Graph2EditsModel,
11 | LocalRetroModel,
12 | MEGANModel,
13 | MHNreactModel,
14 | RetroKNNModel,
15 | RootAlignedModel,
16 | )
17 |
18 |
19 | class ForwardModelClass(Enum):
20 | Chemformer = ChemformerModel
21 |
22 |
23 | class BackwardModelClass(Enum):
24 | Chemformer = ChemformerModel
25 | GLN = GLNModel
26 | Graph2Edits = Graph2EditsModel
27 | LocalRetro = LocalRetroModel
28 | MEGAN = MEGANModel
29 | MHNreact = MHNreactModel
30 | RetroKNN = RetroKNNModel
31 | RootAligned = RootAlignedModel
32 |
33 |
34 | @dataclass
35 | class ModelConfig:
36 | """Config for loading any reaction models, forward or backward."""
37 |
38 | model_dir: str = MISSING
39 | model_kwargs: Dict[str, Any] = field(default_factory=dict)
40 |
41 |
42 | @dataclass
43 | class ForwardModelConfig(ModelConfig):
44 | """Config for loading one of the supported forward models."""
45 |
46 | model_class: ForwardModelClass = MISSING
47 |
48 |
49 | @dataclass
50 | class BackwardModelConfig(ModelConfig):
51 | """Config for loading one of the supported backward models."""
52 |
53 | model_class: BackwardModelClass = MISSING
54 |
--------------------------------------------------------------------------------
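
The enum-of-classes pattern above means a config string such as `"LocalRetro"` maps directly to a model class via the member's `value`. A self-contained sketch with invented stand-in classes:

```python
from enum import Enum


class FooModel:
    pass


class BarModel:
    pass


class ToyModelClass(Enum):
    Foo = FooModel
    Bar = BarModel


model_cls = ToyModelClass["Foo"].value  # look up by config string, then take the class
assert model_cls is FooModel
model = model_cls()  # instantiate, analogous to `config.model_class.value(...)`
```
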
/syntheseus/tests/interface/test_toy_models.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from syntheseus.interface.molecule import Molecule
4 | from syntheseus.interface.reaction import SingleProductReaction
5 | from syntheseus.reaction_prediction.inference import (
6 | LinearMoleculesToyModel,
7 | ListOfReactionsToyModel,
8 | )
9 |
10 |
11 | def test_linear_molecules_invalid_molecule() -> None:
12 | """
13 | The LinearMolecules model is only defined on the space of linear molecules.
14 | If called on a non-linear molecule (e.g. with branching) it is currently set up to throw an error.
15 | This test ensures that this happens.
16 |
17 | NOTE: in the future the behaviour could be changed to just return an empty list,
18 |     but for a toy example I thought it would be best to alert the user with an explicit error.
19 | """
20 | rxn_model = LinearMoleculesToyModel()
21 | with pytest.raises(AssertionError):
22 | rxn_model([Molecule("CC(C)C")])
23 |
24 |
25 | def test_list_of_reactions_model(
26 | rxn_cocs_from_co_cs: SingleProductReaction,
27 | rxn_cocs_from_cocc: SingleProductReaction,
28 | rxn_cs_from_cc: SingleProductReaction,
29 | ) -> None:
30 | """Simple test of the ListOfReactionsModel class."""
31 | model = ListOfReactionsToyModel([rxn_cocs_from_co_cs, rxn_cocs_from_cocc, rxn_cs_from_cc])
32 | output = model([Molecule("COCS"), Molecule("CS"), Molecule("CO")])
33 | assert output == [[rxn_cocs_from_co_cs, rxn_cocs_from_cocc], [rxn_cs_from_cc], []]
34 |
--------------------------------------------------------------------------------
/syntheseus/cli/search_config.yml:
--------------------------------------------------------------------------------
1 | mcts:
2 | Chemformer:
3 | bound_constant: 1
4 | policy_kwargs:
5 | temperature: 8.0
6 | value_function_kwargs:
7 | constant: 0.75
8 | GLN:
9 | bound_constant: 100
10 | policy_kwargs:
11 | temperature: 4.0
12 | value_function_kwargs:
13 | constant: 0.5
14 | LocalRetro:
15 | bound_constant: 1
16 | policy_kwargs:
17 | temperature: 4.0
18 | value_function_kwargs:
19 | constant: 0.5
20 | MEGAN:
21 | bound_constant: 1
22 | policy_kwargs:
23 | clip_probability_max: 0.9999
24 | clip_probability_min: 1.0e-05
25 | temperature: 2.0
26 | value_function_kwargs:
27 | constant: 0.75
28 | MHNreact:
29 | bound_constant: 1
30 | policy_kwargs:
31 | temperature: 8.0
32 | value_function_kwargs:
33 | constant: 0.5
34 | RetroKNN:
35 | bound_constant: 1
36 | policy_kwargs:
37 | temperature: 8.0
38 | value_function_kwargs:
39 | constant: 0.75
40 | RootAligned:
41 | bound_constant: 10
42 | policy_kwargs:
43 | clip_probability_max: 0.999
44 | clip_probability_min: 1.0e-05
45 | temperature: 8.0
46 | value_function_kwargs:
47 | constant: 0.5
48 | retro_star:
49 | MEGAN:
50 | and_node_cost_fn_kwargs:
51 | clip_probability_max: 0.99
52 | clip_probability_min: 1.0e-05
53 | RootAligned:
54 | and_node_cost_fn_kwargs:
55 | clip_probability_max: 0.9999
56 | clip_probability_min: 1.0e-08
57 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Syntheseus
2 | site_url: https://microsoft.github.io/syntheseus/
3 |
4 | repo_name: microsoft/syntheseus
5 | repo_url: https://github.com/microsoft/syntheseus
6 | edit_uri: edit/main/docs/
7 |
8 | theme:
9 | name: material
10 | palette:
11 | - media: "(prefers-color-scheme: light)"
12 | scheme: default
13 | primary: black
14 | accent: red
15 | toggle:
16 | icon: material/toggle-switch
17 | name: "Switch to dark mode"
18 | - media: "(prefers-color-scheme: dark)"
19 | scheme: slate
20 | primary: black
21 | accent: red
22 | toggle:
23 | icon: material/toggle-switch-off-outline
24 | name: "Switch to light mode"
25 | features:
26 | - content.code.copy
27 | - navigation.tabs
28 |
29 | nav:
30 | - Get Started:
31 | - Overview: index.md
32 | - Installation: installation.md
33 | - Single-Step Models: single_step.md
34 | - CLI:
35 | - Single-Step Evaluation: cli/eval_single_step.md
36 | - Running Search: cli/search.md
37 | - Tutorials:
38 | - Quick Start: tutorials/quick_start.ipynb
39 | - Integrating a Custom Model: tutorials/custom_model.ipynb
40 | - Multi-step Search on PaRoutes: tutorials/paroutes_benchmark.ipynb
41 |
42 | plugins:
43 | - mkdocs-jupyter
44 |
45 | markdown_extensions:
46 | - admonition
47 | - attr_list
48 | - md_in_html
49 | - pymdownx.details
50 | - pymdownx.superfences
51 | - pymdownx.tabbed:
52 | alternate_style: true
53 |
54 | extra:
55 | version:
56 | default: stable
57 | provider: mike
58 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/environment_gln/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.8-cudnn8-ubuntu22.04
 2 | LABEL maintainer="krmaziar@microsoft.com"
3 |
4 | # Set bash, as conda doesn't like dash
5 | SHELL [ "/bin/bash", "--login", "-c" ]
6 |
7 | # Make bash aware of conda
8 | RUN echo ". /opt/miniconda/etc/profile.d/conda.sh" >> ~/.profile
9 |
10 | # Turn off caching in pip
11 | ENV PIP_NO_CACHE_DIR=1
12 |
13 | # Install the dependencies into conda's default environment
14 | COPY ./environment.yml /tmp/
15 | RUN conda config --remove channels defaults && \
16 | conda config --add channels conda-forge && \
17 | conda update --all && \
18 | conda install mamba -n base -c conda-forge && \
19 | conda clean --all --yes
20 | RUN mamba env update -p /opt/miniconda -f /tmp/environment.yml && conda clean --all --yes
21 |
22 | # Install RDKit from source
23 | RUN git clone https://github.com/rdkit/rdkit.git
24 | WORKDIR /rdkit
25 | RUN git checkout 8bd74d91118f3fdb370081ef0a18d71715e7c6cf
26 | RUN mkdir build && cd build && cmake -DPy_ENABLE_SHARED=1 \
27 | -DRDK_INSTALL_INTREE=ON \
28 | -DRDK_INSTALL_STATIC_LIBS=OFF \
29 | -DRDK_BUILD_CPP_TESTS=ON \
30 | -DPYTHON_NUMPY_INCLUDE_PATH="$(python -c 'import numpy ; print(numpy.get_include())')" \
31 | -DBOOST_ROOT="$CONDA_PREFIX" \
32 | .. && make && make install
33 | WORKDIR /
34 |
35 | # Install GLN (this relies on `CUDA_HOME` being set correctly)
36 | RUN git clone https://github.com/kmaziarz/GLN.git
37 | WORKDIR /GLN
38 | RUN pip install -e .
39 |
40 | ENV PYTHONPATH=$PYTHONPATH:/rdkit:/GLN
41 | ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/rdkit/lib
42 |
--------------------------------------------------------------------------------
/docs/cli/search.md:
--------------------------------------------------------------------------------
1 | # Running Search
2 |
3 | ## Usage
4 |
5 | ```
6 | syntheseus search \
7 | search_targets_file=[SMILES_FILE_WITH_SEARCH_TARGETS] \
8 | inventory_smiles_file=[SMILES_FILE_WITH_PURCHASABLE_MOLECULES] \
9 | model_class=[MODEL_CLASS] \
10 | model_dir=[MODEL_DIR] \
11 | time_limit_s=[NUMBER_OF_SECONDS_PER_TARGET]
12 | ```
13 |
14 | Both the search targets and the purchasable molecules inventory are expected to be plain SMILES files, with one molecule per line.
15 |
16 | The `search` command accepts further arguments to configure the search algorithm; see `SearchConfig` in `cli/search.py` for the complete list.
17 |
18 | !!! info
19 | When using one of the natively supported single-step models you can omit `model_dir`, which will cause `syntheseus` to use a default checkpoint trained on USPTO-50K (see [here](../single_step.md) for details).
20 |
21 | ## Configuring the search algorithm
22 |
23 | You can select the search algorithm explicitly by setting the `search_algorithm` argument to `retro_star` (default), `mcts`, or `pdvn`.
24 | For all of those algorithms you can vary hyperparameters such as the policy/value functions or MCTS bound type/constant.
25 |
26 | In practice, however, there may be no need to override any hyperparameters, especially if combining Retro\* or MCTS with one of the natively supported models, as for those `syntheseus` will automatically choose sensible hyperparameter defaults (listed in `cli/search_config.yml`).
27 | [In our experience](https://arxiv.org/abs/2310.19796) both Retro* and MCTS show similar performance when tuned properly, but you may want to try both for your particular use case and see which one works best empirically.
28 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/inference/base.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional, Union
3 |
4 | from syntheseus.interface.models import (
5 | BackwardReactionModel,
6 | ForwardReactionModel,
7 | InputType,
8 | ReactionModel,
9 | ReactionType,
10 | )
11 | from syntheseus.reaction_prediction.utils.downloading import get_default_model_dir_from_cache
12 |
13 |
14 | class ExternalReactionModel(ReactionModel[InputType, ReactionType]):
15 | """Base class for the external reaction models, abstracting out common functinality."""
16 |
17 | def __init__(
18 | self, model_dir: Optional[Union[str, Path]] = None, device: Optional[str] = None, **kwargs
19 | ) -> None:
20 | super().__init__(**kwargs)
21 | import torch
22 |
23 | self.model_dir = Path(model_dir or self.get_default_model_dir())
24 | self.device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
25 |
26 | def get_default_model_dir(self) -> Path:
27 | model_dir = get_default_model_dir_from_cache(self.name, is_forward=self.is_forward())
28 |
29 | if model_dir is None:
30 | raise ValueError(
31 | f"Could not obtain a default checkpoint for model {self.name}, "
32 | "please provide an explicit value for `model_dir`"
33 | )
34 |
35 | return model_dir
36 |
37 | @property
38 | def name(self) -> str:
39 | return self.__class__.__name__.removesuffix("Model")
40 |
41 |
42 | class ExternalBackwardReactionModel(ExternalReactionModel, BackwardReactionModel):
43 | pass
44 |
45 |
46 | class ExternalForwardReactionModel(ExternalReactionModel, ForwardReactionModel):
47 | pass
48 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # This should not be necessary, except that `conda<4.11` has a bug dealing with `python>=3.10`
2 | # (see https://github.com/conda/conda/issues/10969), and the below makes that go away.
3 | default_language_version:
4 | python: python3
5 |
6 | repos:
7 | # Generally useful pre-commit hooks
8 | - repo: https://github.com/pre-commit/pre-commit-hooks
9 | rev: v4.3.0 # Use the ref you want to point at
10 | hooks:
11 | - id: check-added-large-files
12 | - id: check-case-conflict
13 | - id: check-docstring-first
14 | - id: check-executables-have-shebangs
15 | - id: check-json
16 | - id: check-merge-conflict
17 | - id: check-symlinks
18 | - id: check-toml
19 | - id: check-yaml
20 | - id: debug-statements
21 | - id: destroyed-symlinks
22 | - id: detect-private-key
23 | - id: end-of-file-fixer
24 | - id: name-tests-test
25 | args: ["--pytest-test-first"]
26 | - id: trailing-whitespace
27 | args: [--markdown-linebreak-ext=md]
28 |
29 | # latest version of black when this pre-commit config is being set up
30 | - repo: https://github.com/psf/black
31 | rev: 23.3.0
32 | hooks:
33 | - id: black
34 | name: "black"
35 | args: ["--config=pyproject.toml"]
36 |
37 | # latest version of mypy at time pre-commit config is being set up
38 | # NOTE: only checks code in "syntheseus" directory.
39 | - repo: https://github.com/pre-commit/mirrors-mypy
40 | rev: v1.2.0
41 | hooks:
42 | - id: mypy
43 | name: "mypy"
44 | files: "syntheseus/"
45 | args: ["--install-types", "--non-interactive"]
46 |
47 | # Latest ruff (does linting + more)
48 | - repo: https://github.com/charliermarsh/ruff-pre-commit
49 | rev: 'v0.2.1'
50 | hooks:
51 | - id: ruff
52 | args: [--fix]
53 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, TypeVar
3 |
4 | import numpy as np
5 |
6 | OutputType = TypeVar("OutputType")
7 |
8 |
9 | class TopKMetricsAccumulator:
10 | """Class to accumulate prediction top-k accuracy and MRR under a given notion of correctness."""
11 |
12 | def __init__(self, max_num_results: int):
13 | self._max_num_results = max_num_results
14 |
15 | # Initialize things we will need to compute the metrics.
16 | self._top_k_correct_cnt = np.zeros(max_num_results)
17 | self._sum_reciprocal_rank = 0.0
18 | self._num_samples = 0
19 |
20 | def add(self, is_output_correct: List[bool]) -> None:
21 | assert len(is_output_correct) <= self._max_num_results
22 |
23 | self._num_samples += 1
24 | for idx, output in enumerate(is_output_correct):
25 | if output:
26 | self._top_k_correct_cnt[idx:] += 1.0
27 | self._sum_reciprocal_rank += 1.0 / (idx + 1)
28 | break
29 |
30 | @property
31 | def num_samples(self) -> int:
32 | return self._num_samples
33 |
34 | @property
35 | def top_k(self) -> List[float]:
36 | return list(self._top_k_correct_cnt / self._num_samples)
37 |
38 | @property
39 | def mrr(self) -> float:
40 | return self._sum_reciprocal_rank / self._num_samples
41 |
42 |
43 | @dataclass(frozen=True)
44 | class ModelTimingResults:
45 | time_model_call: float
46 | time_post_processing: float
47 |
48 |
49 | def compute_total_time(timing_results: List[ModelTimingResults]) -> ModelTimingResults:
50 | return ModelTimingResults(
51 | **{
52 | key: sum(getattr(result, key) for result in timing_results)
53 | for key in ["time_model_call", "time_post_processing"]
54 | }
55 | )
56 |
--------------------------------------------------------------------------------
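
A minimal usage sketch for `TopKMetricsAccumulator`: each call to `add` passes, for one sample, whether each ranked prediction was correct, and the accumulator then exposes top-k accuracy and MRR:

```python
from syntheseus.reaction_prediction.utils.metrics import TopKMetricsAccumulator

acc = TopKMetricsAccumulator(max_num_results=3)
acc.add([True, False, False])   # correct prediction at rank 1
acc.add([False, True, False])   # correct prediction at rank 2
acc.add([False, False, False])  # no correct prediction in the top 3

assert acc.num_samples == 3
assert acc.top_k == [1 / 3, 2 / 3, 2 / 3]  # top-1, top-2, top-3 accuracy
assert abs(acc.mrr - (1.0 + 0.5) / 3) < 1e-9
```
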
/syntheseus/tests/search/analysis/conftest.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pytest
4 |
5 | from syntheseus.search.algorithms.breadth_first import AndOr_BreadthFirstSearch
6 | from syntheseus.search.analysis.route_extraction import iter_routes_cost_order
7 | from syntheseus.search.graph.and_or import AndNode, AndOrGraph
8 | from syntheseus.search.graph.molset import MolSetNode
9 | from syntheseus.search.graph.route import SynthesisGraph
10 | from syntheseus.tests.search.conftest import RetrosynthesisTask
11 |
12 |
13 | def set_uniform_costs(graph) -> None:
14 | """Set a unit cost of 1 for all nodes with reactions."""
15 | for node in graph.nodes():
16 | if isinstance(node, (MolSetNode, AndNode)):
17 | node.data["route_cost"] = 1.0
18 | else:
19 | node.data["route_cost"] = 0.0
20 |
21 |
22 | @pytest.fixture
23 | def andor_graph_with_many_routes(retrosynthesis_task6: RetrosynthesisTask) -> AndOrGraph:
24 | task = retrosynthesis_task6
25 | alg = AndOr_BreadthFirstSearch(
26 | reaction_model=task.reaction_model, mol_inventory=task.inventory, unique_nodes=True
27 | )
28 | output_graph, _ = alg.run_from_mol(task.target_mol)
29 | assert len(output_graph) == 278 # make sure number of nodes is always the same
30 | set_uniform_costs(output_graph)
31 | return output_graph
32 |
33 |
34 | @pytest.fixture
35 | def sample_synthesis_routes(andor_graph_with_many_routes: AndOrGraph) -> list[SynthesisGraph]:
36 | """Return 11 synthesis routes extracted from the graph of length <= 3."""
37 | output = list(
38 | iter_routes_cost_order(andor_graph_with_many_routes, max_routes=10_000, stop_cost=4.0)
39 | )
40 | assert len(output) == 11
41 | return [
42 | andor_graph_with_many_routes.to_synthesis_graph(route) # type: ignore # node type unclear
43 | for route in output
44 | ]
45 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/utils/model_loading.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Union
3 |
4 | from omegaconf import OmegaConf
5 |
6 | from syntheseus.interface.models import ReactionModel
7 | from syntheseus.reaction_prediction.inference.config import BackwardModelConfig, ForwardModelConfig
8 |
 9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def get_model(
13 | config: Union[BackwardModelConfig, ForwardModelConfig],
14 | batch_size: int,
15 | num_gpus: int,
16 | **model_kwargs,
17 | ) -> ReactionModel:
18 | # Check that model kwargs don't overlap
19 | overlapping_kwargs = set(config.model_kwargs.keys()) & set(model_kwargs.keys())
20 | if overlapping_kwargs:
21 | raise ValueError(f"Model kwargs overlap: {overlapping_kwargs}")
22 |
23 | def model_fn(device):
24 | return config.model_class.value(
25 | model_dir=OmegaConf.select(config, "model_dir"),
26 | device=device,
27 | **config.model_kwargs,
28 | **model_kwargs,
29 | )
30 |
31 | if num_gpus == 0:
32 | return model_fn("cpu")
33 | elif num_gpus == 1:
34 | return model_fn("cuda:0")
35 | else:
36 | if batch_size < num_gpus:
37 | raise ValueError(f"Cannot split batch of size {batch_size} across {num_gpus} GPUs")
38 |
39 | batch_size_per_gpu = batch_size // num_gpus
40 |
41 | if batch_size_per_gpu < 16:
42 | logger.warning(f"Batch size per GPU is very small: ~{batch_size_per_gpu}")
43 |
44 | try:
45 | from syntheseus.reaction_prediction.utils.parallel import ParallelReactionModel
46 | except ModuleNotFoundError:
47 | raise ValueError("Multi-GPU evaluation is only supported for torch-based models")
48 |
49 | return ParallelReactionModel(
50 | model_fn=model_fn, devices=[f"cuda:{idx}" for idx in range(num_gpus)]
51 | )
52 |
--------------------------------------------------------------------------------
/syntheseus/tests/search/algorithms/test_random.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pytest
4 |
5 | from syntheseus.search.algorithms.random import (
6 | AndOr_RandomSearch,
7 | MolSet_RandomSearch,
8 | )
9 | from syntheseus.tests.search.algorithms.test_base import BaseAlgorithmTest
10 | from syntheseus.tests.search.conftest import RetrosynthesisTask
11 |
12 |
13 | class BaseRandomSearchTest(BaseAlgorithmTest):
14 | """
15 | Base test for random search.
16 |
17 | We skip `test_found_routesX` because random search is very inefficient.
18 | """
19 |
20 | @pytest.mark.skip(reason="Random search is very inefficient")
21 | def test_found_routes1(self, retrosynthesis_task1: RetrosynthesisTask) -> None:
22 | pass
23 |
24 | @pytest.mark.skip(reason="Random search is very inefficient")
25 | def test_found_routes2(self, retrosynthesis_task2: RetrosynthesisTask) -> None:
26 | pass
27 |
28 | @pytest.mark.flaky(reruns=3)
29 | @pytest.mark.parametrize("limit", [0, 1, 2, 1000])
30 | def test_limit_iterations(
31 | self,
32 | retrosynthesis_task1: RetrosynthesisTask,
33 | retrosynthesis_task2: RetrosynthesisTask,
34 | retrosynthesis_task3: RetrosynthesisTask,
35 | limit: int,
36 | ) -> None:
37 |         # Here we just override the set of limits that are tested.
38 | # Random search is inefficient, so sometimes after 100 iterations not all tasks are solved.
39 | super().test_limit_iterations(
40 | retrosynthesis_task1, retrosynthesis_task2, retrosynthesis_task3, limit
41 | )
42 |
43 |
44 | class TestAndOrRandomSearch(BaseRandomSearchTest):
45 | def setup_algorithm(self, **kwargs) -> AndOr_RandomSearch:
46 | return AndOr_RandomSearch(**kwargs)
47 |
48 |
49 | class TestMolSetRandomSearch(BaseRandomSearchTest):
50 | def setup_algorithm(self, **kwargs) -> MolSet_RandomSearch:
51 | return MolSet_RandomSearch(**kwargs)
52 |
--------------------------------------------------------------------------------
/syntheseus/tests/reaction_prediction/inference/test_models.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from syntheseus.interface.bag import Bag
4 | from syntheseus.interface.molecule import Molecule
5 | from syntheseus.reaction_prediction.inference.base import ExternalBackwardReactionModel
6 | from syntheseus.reaction_prediction.inference.config import BackwardModelClass
7 | from syntheseus.reaction_prediction.utils.testing import are_single_step_models_installed
8 |
9 | pytestmark = pytest.mark.skipif(
10 | not are_single_step_models_installed(),
11 | reason="Model tests require all single-step models to be installed",
12 | )
13 |
14 |
15 | MODEL_CLASSES_TO_TEST = set(BackwardModelClass) - {BackwardModelClass.GLN}
16 |
17 |
18 | @pytest.fixture(scope="module", params=list(MODEL_CLASSES_TO_TEST) * 2)
19 | def model(request) -> ExternalBackwardReactionModel:
20 | model_cls = request.param.value
21 | return model_cls()
22 |
23 |
24 | def test_call(model: ExternalBackwardReactionModel) -> None:
25 | [result] = model([Molecule("Cc1ccc(-c2ccc(C)cc2)cc1")], num_results=20)
26 | model_predictions = [prediction.reactants for prediction in result]
27 |
28 | # Prepare some coupling reactions that are reasonable predictions for the product above.
29 | expected_predictions = [
30 | Bag([Molecule(f"Cc1ccc({leaving_group_1})cc1"), Molecule(f"Cc1ccc({leaving_group_2})cc1")])
31 | for leaving_group_1 in ["Br", "I"]
32 | for leaving_group_2 in ["B(O)O", "I", "[Mg+]"]
33 | ]
34 |
35 | # The model should recover at least two (out of six) in its top-20.
36 | assert len(set(expected_predictions) & set(model_predictions)) >= 2
37 |
38 |
39 | def test_misc(model: ExternalBackwardReactionModel) -> None:
40 | import torch
41 |
42 | assert isinstance(model.name, str)
43 | assert isinstance(model.get_model_info(), dict)
44 | assert model.is_backward() is not model.is_forward()
45 |
46 | for p in model.get_parameters():
47 | assert isinstance(p, torch.Tensor)
48 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/utils/parallel.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Callable, List, Sequence
3 |
4 | import torch
5 | from more_itertools import chunked
6 |
7 | from syntheseus.interface.models import InputType, ReactionModel, ReactionType
8 |
9 |
10 | class ParallelReactionModel(ReactionModel[InputType, ReactionType]):
11 | """Wraps an arbitrary `ReactionModel` to enable multi-GPU inference.
12 |
13 | Unlike most off-the-shelf multi-GPU approaches (e.g. strategies in `pytorch_lightning`,
14 | `nn.DataParallel`, `nn.DistributedDataParallel`), this class only handles inference (not
15 | training), and because of that it can be much looser in terms of the constraints the
16 | parallelized model has to satisfy. It also works with lists of inputs (chunking them up
17 | appropriately), whereas other approaches usually only work with tensors.
18 | """
19 |
20 | def __init__(self, *args, model_fn: Callable, devices: List, **kwargs) -> None:
21 | super().__init__(*args, **kwargs)
22 |
23 | self._devices = devices
24 | self._model_replicas = [model_fn(device=device) for device in devices]
25 |
26 | def _get_reactions(
27 | self, inputs: List[InputType], num_results: int
28 | ) -> List[Sequence[ReactionType]]:
29 |         # Chunk inputs into (roughly) equal chunks; each is wrapped in a 1-tuple of positional args for `parallel_apply`.
30 |         chunk_size = math.ceil(len(inputs) / len(self._devices))
31 |         input_chunks = [(chunk,) for chunk in chunked(inputs, chunk_size)]
32 |
33 |         # If `len(inputs)` is not divisible by `len(self._devices)` there may be fewer chunks than devices.
34 | num_chunks = len(input_chunks)
35 |
36 | outputs = torch.nn.parallel.parallel_apply(
37 | self._model_replicas[:num_chunks],
38 | input_chunks,
39 | tuple({"num_results": num_results} for _ in range(num_chunks)),
40 | self._devices[:num_chunks],
41 | )
42 |
43 |         # Concatenate all outputs from the replicas.
44 | return sum(outputs, [])
45 |
46 | def is_forward(self) -> bool:
47 | return self._model_replicas[0].is_forward()
48 |
--------------------------------------------------------------------------------
/syntheseus/tests/reaction_prediction/chem/test_utils.py:
--------------------------------------------------------------------------------
1 | from rdkit import Chem
2 |
3 | from syntheseus import Bag, Molecule, Reaction, SingleProductReaction
4 | from syntheseus.reaction_prediction.chem.utils import (
5 | remove_atom_mapping,
6 | remove_atom_mapping_from_mol,
7 | remove_stereo_information,
8 | remove_stereo_information_from_reaction,
9 | )
10 |
11 |
12 | def test_remove_mapping() -> None:
13 | smiles_mapped = "[OH:1][CH2:2][c:3]1[cH:4][n:5][cH:6][cH:7][c:8]1[Br:9]"
14 | smiles_unmapped = "OCc1cnccc1Br"
15 |
16 | assert remove_atom_mapping(smiles_mapped) == smiles_unmapped
17 | assert remove_atom_mapping(smiles_unmapped) == smiles_unmapped
18 |
19 | mol = Chem.MolFromSmiles(smiles_mapped)
20 | remove_atom_mapping_from_mol(mol)
21 |
22 | assert Chem.MolToSmiles(mol) == smiles_unmapped
23 |
24 |
25 | def test_remove_stereo_information() -> None:
26 | mol = Molecule("CC(N)C#N")
27 | mols_chiral = [Molecule("C[C@H](N)C#N"), Molecule("C[C@@H](N)C#N")]
28 |
29 | assert len(set([mol] + mols_chiral)) == 3
30 | assert len(set([mol] + [remove_stereo_information(m) for m in mols_chiral])) == 1
31 |
32 |
33 | def test_remove_stereo_information_from_reaction() -> None:
34 | reactants = Bag([Molecule("CCC"), Molecule("CC(N)C#N")])
35 | reactants_chiral = Bag([Molecule("CCC"), Molecule("C[C@H](N)C#N")])
36 |
37 | product = Molecule("CC(N)C#N")
38 | product_chiral = Molecule("C[C@H](N)C#N")
39 |
40 | rxn = Reaction(reactants=reactants, products=Bag([product]))
41 | rxn_chiral = Reaction(reactants=reactants_chiral, products=Bag([product_chiral]))
42 | rxn_stereo_removed = remove_stereo_information_from_reaction(rxn_chiral)
43 |
44 | assert type(rxn_stereo_removed) is Reaction
45 | assert rxn_stereo_removed == rxn
46 |
47 | sp_rxn = SingleProductReaction(reactants=reactants, product=product)
48 | sp_rxn_chiral = SingleProductReaction(reactants=reactants_chiral, product=product_chiral)
49 |     sp_rxn_stereo_removed = remove_stereo_information_from_reaction(sp_rxn_chiral)
50 |
51 |     assert type(sp_rxn_stereo_removed) is SingleProductReaction
52 |     assert sp_rxn_stereo_removed == sp_rxn
53 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/chem/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Union
2 |
3 | from rdkit import Chem
4 |
5 | from syntheseus import Bag, Molecule, SingleProductReaction
6 | from syntheseus.interface.models import ReactionType
7 | from syntheseus.interface.molecule import SMILES_SEPARATOR
8 |
9 | ATOM_MAPPING_PROP_NAME = "molAtomMapNumber"
10 |
11 |
12 | def remove_atom_mapping_from_mol(mol: Chem.Mol) -> None:
13 | """Removed the atom mapping from an rdkit molecule modifying it in place."""
14 | for atom in mol.GetAtoms():
15 | atom.ClearProp(ATOM_MAPPING_PROP_NAME)
16 |
17 |
18 | def remove_atom_mapping(smiles: str) -> str:
19 | """Removes the atom mapping from a SMILES string.
20 |
21 | Args:
22 | smiles: Molecule SMILES to be modified.
23 |
24 | Returns:
25 | str: Input SMILES with atom map numbers stripped away.
26 | """
27 | mol = Chem.MolFromSmiles(smiles)
28 | remove_atom_mapping_from_mol(mol)
29 |
30 | return Chem.MolToSmiles(mol)
31 |
32 |
33 | def remove_stereo_information(mol: Molecule) -> Molecule:
34 | return Molecule(Chem.MolToSmiles(mol.rdkit_mol, isomericSmiles=False))
35 |
36 |
37 | def remove_stereo_information_from_reaction(reaction: ReactionType) -> ReactionType:
38 | mol_kwargs: Dict[str, Union[Molecule, Bag[Molecule]]] = {
39 | "reactants": Bag([remove_stereo_information(mol) for mol in reaction.reactants])
40 | }
41 |
42 | if isinstance(reaction, SingleProductReaction):
43 | mol_kwargs["product"] = remove_stereo_information(reaction.product)
44 | else:
45 | mol_kwargs["products"] = Bag([remove_stereo_information(mol) for mol in reaction.products])
46 |
47 | return reaction.__class__(
48 | **mol_kwargs, identifier=reaction.identifier, metadata=reaction.metadata # type: ignore[arg-type]
49 | )
50 |
51 |
52 | def molecule_bag_from_smiles_strict(smiles: str) -> Bag[Molecule]:
53 | return Bag([Molecule(component) for component in smiles.split(SMILES_SEPARATOR)])
54 |
55 |
56 | def molecule_bag_from_smiles(smiles: str) -> Optional[Bag[Molecule]]:
57 | try:
58 | return molecule_bag_from_smiles_strict(smiles)
59 | except ValueError:
60 | # If any of the components ends up invalid we return `None` instead.
61 | return None
62 |
--------------------------------------------------------------------------------
/syntheseus/search/algorithms/breadth_first.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import collections
4 | import logging
5 | from typing import Generic
6 |
7 | from syntheseus.search.algorithms.base import (
8 | AndOrSearchAlgorithm,
9 | GraphType,
10 | MolSetSearchAlgorithm,
11 | SearchAlgorithm,
12 | )
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class GeneralBreadthFirstSearch(SearchAlgorithm[GraphType, int], Generic[GraphType]):
18 | """Base class for breadth first search algorithms (pseudo-code is the same for all data structures)."""
19 |
20 | @property
21 | def requires_tree(self) -> bool:
22 | return False # can work on any graph
23 |
24 | def _run_from_graph_after_setup(self, graph: GraphType) -> int:
25 | log_level = logging.DEBUG - 1
26 | logger_active = logger.isEnabledFor(log_level)
27 |
28 | queue = collections.deque([node for node in graph._graph.nodes() if not node.is_expanded])
29 | step = 0 # initialize this variable in case loop is not entered
30 | for step in range(self.limit_iterations):
31 | if self.should_stop_search(graph) or len(queue) == 0:
32 | break
33 |
34 | # Pop node and potentially expand it
35 | node = queue.popleft()
36 | if node.is_expanded:
37 | outcome = "already expanded, do nothing"
38 | else:
39 | new_nodes = self.expand_node(node, graph)
40 | self.set_node_values(set(new_nodes) | {node}, graph)
41 | queue.extend([n for n in new_nodes if self.can_expand_node(n, graph)])
42 | outcome = f"expanded, created {len(new_nodes)} new nodes"
43 |
44 | if logger_active:
45 | logger.log(
46 | log_level, f"Step {step}: node {node} {outcome}. Queue size: {len(queue)}"
47 | )
48 |
49 | return step
50 |
51 |
52 | class MolSet_BreadthFirstSearch(GeneralBreadthFirstSearch, MolSetSearchAlgorithm):
53 | @property
54 | def requires_tree(self) -> bool:
55 | # Even though it "could" work, molset graphs with unique nodes are not
56 | # well-supported so we don't allow it at this time.
57 | return True
58 |
59 |
60 | class AndOr_BreadthFirstSearch(GeneralBreadthFirstSearch, AndOrSearchAlgorithm):
61 | pass
62 |
--------------------------------------------------------------------------------
/syntheseus/tests/search/test_visualization.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests that visualization runs. Tests are skipped if graphviz is not installed.
3 | """
4 |
5 | import tempfile
6 |
7 | import pytest
8 |
9 | from syntheseus.search.graph.and_or import AndOrGraph
10 | from syntheseus.search.graph.molset import MolSetGraph
11 |
12 | visualization = pytest.importorskip("syntheseus.search.visualization")
13 |
14 |
15 | @pytest.mark.parametrize("draw_mols", [False, True])
16 | @pytest.mark.parametrize("partial_graph", [False, True])
17 | def test_andor_visualization(
18 | andor_graph_non_minimal: AndOrGraph, draw_mols: bool, partial_graph: bool
19 | ) -> None:
20 | if partial_graph:
21 | nodes = [n for n in andor_graph_non_minimal.nodes() if n.depth <= 2]
22 | else:
23 | nodes = None
24 |
25 | # Visualize, both with and without drawing mols
26 | with tempfile.TemporaryDirectory() as tmp_dir:
27 | visualization.visualize_andor(
28 | graph=andor_graph_non_minimal,
29 | filename=f"{tmp_dir}/tmp.pdf",
30 | draw_mols=draw_mols,
31 | nodes=nodes,
32 | )
33 |
34 |
35 | @pytest.mark.parametrize("draw_mols", [False, True])
36 | @pytest.mark.parametrize("partial_graph", [False, True])
37 | def test_molset_visualization(
38 | molset_tree_non_minimal: MolSetGraph, draw_mols: bool, partial_graph: bool
39 | ) -> None:
40 | if partial_graph:
41 | nodes = [n for n in molset_tree_non_minimal.nodes() if n.depth <= 2]
42 | else:
43 | nodes = None
44 | with tempfile.TemporaryDirectory() as tmp_dir:
45 | visualization.visualize_molset(
46 | graph=molset_tree_non_minimal,
47 | filename=f"{tmp_dir}/tmp.pdf",
48 | draw_mols=draw_mols,
49 | nodes=nodes,
50 | )
51 |
52 |
53 | def test_filename_ends_with_pdf(
54 | molset_tree_non_minimal: MolSetGraph,
55 | andor_graph_non_minimal: AndOrGraph,
56 | ) -> None:
57 | """Test that an error is raised if the file name doesn't end in .pdf"""
58 |
59 | with pytest.raises(ValueError):
60 | visualization.visualize_andor(
61 | graph=andor_graph_non_minimal,
62 | filename="tmp.xyz",
63 | )
64 | with pytest.raises(ValueError):
65 | visualization.visualize_molset(
66 | graph=molset_tree_non_minimal,
67 | filename="tmp.xyz",
68 | )
69 |
--------------------------------------------------------------------------------
/syntheseus/search/algorithms/random.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | from typing import Generic
5 |
6 | from syntheseus.search.algorithms.base import (
7 | AndOrSearchAlgorithm,
8 | GraphType,
9 | MolSetSearchAlgorithm,
10 | SearchAlgorithm,
11 | )
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | class BaseRandomSearch(SearchAlgorithm[GraphType, int], Generic[GraphType]):
17 | """Base class for both AND/OR and MolSet random search algorithms."""
18 |
19 | @property
20 | def requires_tree(self) -> bool:
21 | return False # can work on any graph
22 |
23 | def _run_from_graph_after_setup(self, graph: GraphType) -> int:
24 | log_level = logging.DEBUG - 1
25 | logger_active = logger.isEnabledFor(log_level)
26 |
27 | # Initialize a set of nodes which can be expanded
28 | expandable_nodes = {node for node in graph.nodes() if self.can_expand_node(node, graph)}
29 |
30 | step = 0 # initialize this variable in case loop is not entered
31 | for step in range(self.limit_iterations):
32 | if self.should_stop_search(graph) or len(expandable_nodes) == 0:
33 | break
34 |
35 | # Choose a random node to expand
36 | node = self.random_state.choice(list(expandable_nodes))
37 | expandable_nodes.remove(node)
38 |
39 | # Expand the node
40 | new_nodes = self.expand_node(node, graph)
41 | self.set_node_values(set(new_nodes) | {node}, graph)
42 | for n in new_nodes:
43 | if self.can_expand_node(n, graph):
44 | expandable_nodes.add(n)
45 |
46 | if logger_active:
47 | logger.log(
48 | log_level,
49 | f"Step {step}: node {node} expanded, created {len(new_nodes)} new nodes. "
50 | f"Num expandable nodes: {len(expandable_nodes)}.",
51 | )
52 |
53 | return step
54 |
55 |
56 | class MolSet_RandomSearch(BaseRandomSearch, MolSetSearchAlgorithm):
57 | @property
58 | def requires_tree(self) -> bool:
59 | # Even though it "could" work, molset graphs with unique nodes are not
60 | # well-supported so we don't allow it at this time.
61 | return True
62 |
63 |
64 | class AndOr_RandomSearch(BaseRandomSearch, AndOrSearchAlgorithm):
65 | pass
66 |
--------------------------------------------------------------------------------
/syntheseus/search/node_evaluation/common.py:
--------------------------------------------------------------------------------
1 | """Common node evaluation functions."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import Sequence, Union
6 |
7 | from syntheseus.interface.reaction import SingleProductReaction
8 | from syntheseus.search.graph.and_or import AndNode
9 | from syntheseus.search.graph.molset import MolSetNode
10 | from syntheseus.search.node_evaluation.base import NoCacheNodeEvaluator, ReactionModelBasedEvaluator
11 |
12 |
13 | class ConstantNodeEvaluator(NoCacheNodeEvaluator):
14 | def __init__(self, constant: float, **kwargs):
15 | super().__init__(**kwargs)
16 | self.constant = constant
17 |
18 | def _evaluate_nodes(self, nodes, graph=None):
19 | return [self.constant] * len(nodes)
20 |
21 |
22 | class HasSolutionValueFunction(NoCacheNodeEvaluator):
23 | def _evaluate_nodes(self, nodes, graph=None):
24 | return [float(n.has_solution) for n in nodes]
25 |
26 |
27 | class ReactionModelLogProbCost(ReactionModelBasedEvaluator[AndNode]):
28 | """Evaluator that uses the reactions' negative logprob to form a cost (useful for Retro*)."""
29 |
30 | def __init__(self, **kwargs) -> None:
31 | super().__init__(return_log=True, **kwargs)
32 |
33 | def _get_reaction(self, node: AndNode, graph) -> SingleProductReaction:
34 | return node.reaction
35 |
36 | def _evaluate_nodes(self, nodes, graph=None) -> Sequence[float]:
37 | return [-v for v in super()._evaluate_nodes(nodes, graph)]
38 |
39 |
40 | class ReactionModelProbPolicy(ReactionModelBasedEvaluator[Union[MolSetNode, AndNode]]):
41 | """Evaluator that uses the reactions' probability to form a policy (useful for MCTS)."""
42 |
43 | def __init__(self, **kwargs) -> None:
44 | kwargs["normalize"] = kwargs.get("normalize", True) # set `normalize = True` by default
45 | super().__init__(return_log=False, **kwargs)
46 |
47 | def _get_reaction(self, node: Union[MolSetNode, AndNode], graph) -> SingleProductReaction:
48 | if isinstance(node, MolSetNode):
49 | parents = list(graph.predecessors(node))
50 | assert len(parents) == 1, "Graph must be a tree"
51 | return graph._graph.edges[parents[0], node]["reaction"]
52 | elif isinstance(node, AndNode):
53 | return node.reaction
54 | else:
55 | raise ValueError(f"ReactionModelProbPolicy does not support nodes of type {type(node)}")
56 |
--------------------------------------------------------------------------------
/docs/images/logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
63 |
--------------------------------------------------------------------------------
/docs/single_step.md:
--------------------------------------------------------------------------------
1 | Syntheseus currently supports 8 established single-step models.
2 |
3 | For convenience, for each model we include a default checkpoint trained on USPTO-50K.
4 | If no checkpoint directory is provided during model loading, `syntheseus` will automatically download a default checkpoint and cache it on disk for future use.
5 | The default path for the cache is `$HOME/.cache/torch/syntheseus`, but it can be overridden by setting the `SYNTHESEUS_CACHE_DIR` environment variable.
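   |
   | For example, the cache location can be redirected programmatically before the first model is loaded (a minimal sketch; the target path is hypothetical):
   |
   | ```python
   | import os
   |
   | # Must be set before any checkpoint download is triggered.
   | os.environ["SYNTHESEUS_CACHE_DIR"] = "/data/syntheseus_checkpoints"
   | ```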
6 | See the table below for links to the default checkpoints.
7 |
8 | | Model checkpoint link | Source |
9 | |----------------------------------------------------------------|--------|
10 | | [Chemformer](https://figshare.com/ndownloader/files/42009888) | finetuned by us starting from checkpoint released by authors |
11 | | [GLN](https://figshare.com/ndownloader/files/45882867) | released by authors |
12 | | [Graph2Edits](https://figshare.com/ndownloader/files/44194301) | released by authors |
13 | | [LocalRetro](https://figshare.com/ndownloader/files/42287319) | trained by us |
14 | | [MEGAN](https://figshare.com/ndownloader/files/42012732) | trained by us |
15 | | [MHNreact](https://figshare.com/ndownloader/files/42012777) | trained by us |
16 | | [RetroKNN](https://figshare.com/ndownloader/files/45662430) | trained by us |
17 | | [RootAligned](https://figshare.com/ndownloader/files/42012792) | released by authors |
18 |
19 | ??? note "More advanced datasets"
20 |
21 | The USPTO-50K dataset is well-established but relatively small. Advanced users may prefer to retrain their models of interest on a larger dataset, such as USPTO-FULL or Pistachio. To do that, please follow the instructions in the original model repositories.
22 |
23 | In `reaction_prediction/cli/eval.py` a forward model can be used for computing back-translation (round-trip) accuracy.
24 | See [here](https://figshare.com/ndownloader/files/42012708) for a Chemformer checkpoint finetuned for forward prediction on USPTO-50K. As in the backward direction, pretrained weights released by the original authors were used as a starting point.
25 |
26 | ??? info "Licenses"
27 |     All checkpoints were produced in a way that involved external model repositories, and hence may be affected by the exact license each model was released under.
28 | For more details about a particular model see the top of the corresponding model wrapper file in `reaction_prediction/inference/`.
29 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/utils/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, cast
4 |
5 | from omegaconf import DictConfig, ListConfig, OmegaConf
6 |
7 | R = TypeVar("R")
8 |
9 |
10 | def get_config(
11 | argv: Optional[List[str]],
12 | config_cls: Callable[..., R],
13 | defaults: Optional[Dict[str, Any]] = None,
14 | ) -> R:
15 | """
16 | Utility function to get `OmegaConf` config options.
17 |
18 | Args:
19 | argv: Either a list of command line arguments to parse, or `None`. If `None`, this argument
20 | is set from `sys.argv`.
21 | config_cls: Dataclass object specifying config structure (i.e. which fields to expect in the
22 | config). It should be the class itself, not an instance of the class.
23 |
24 | Returns:
25 | Config object, which will pass as an instance of `config_cls` among other things. Note: the
26 | type for this could be specified more carefully, but `OmegaConf`'s typing system is a bit
27 | complex. Search `OmegaConf`'s docs for "structured" for more info.
28 | """
29 |
30 | if argv is None:
31 | argv = sys.argv[1:]
32 | # Parse command line arguments
33 | parser = argparse.ArgumentParser(allow_abbrev=False) # prevent prefix matching issues
34 | parser.add_argument(
35 | "--config",
36 | type=str,
37 | action="append",
38 | default=list(),
39 | help="Path to a yaml config file. "
40 | "Argument can be repeated multiple times, with later configs overwriting previous ones.",
41 | )
42 | args, config_changes = parser.parse_known_args(argv)
43 |
44 | # Read configs from defaults, file and command line
45 | conf_yamls: List[Union[DictConfig, ListConfig]] = []
46 | if defaults:
47 | conf_yamls = [OmegaConf.create(defaults)]
48 |
49 | conf_yamls += [OmegaConf.load(c) for c in args.config]
50 | conf_cli = OmegaConf.from_cli(config_changes)
51 |
52 | # Make merged config options
53 | # CLI options take priority over YAML file options
54 | schema = OmegaConf.structured(config_cls)
55 | config = OmegaConf.merge(schema, *conf_yamls, conf_cli)
56 | OmegaConf.set_readonly(config, True) # should not be written to
57 | return cast(R, config)
58 |
59 |
60 | def get_error_message_for_missing_value(name: str, possible_values: List[str]) -> str:
61 | return f"{name} should be set to one of [{', '.join(possible_values)}]"
62 |
--------------------------------------------------------------------------------
/syntheseus/tests/search/test_mol_inventory.py:
--------------------------------------------------------------------------------
1 | """Tests for MolInventory objects, focusing on the provided SmilesListInventory."""
2 |
3 | import pytest
4 |
5 | from syntheseus.interface.molecule import Molecule
6 | from syntheseus.search.mol_inventory import SmilesListInventory
7 |
8 | PURCHASABLE_SMILES = ["CC", "c1ccccc1", "CCO"]
9 | NON_PURCHASABLE_SMILES = ["C", "C1CCCCC1", "OCCO"]
10 |
11 |
12 | @pytest.fixture
13 | def example_inventory() -> SmilesListInventory:
14 | """Returns a SmilesListInventory with arbitrary molecules."""
15 | return SmilesListInventory(PURCHASABLE_SMILES)
16 |
17 |
18 | def test_is_purchasable(example_inventory: SmilesListInventory) -> None:
19 | """
20 | Does the 'is_purchasable' method return true only for purchasable SMILES?
21 | """
22 | for sm in PURCHASABLE_SMILES:
23 | assert example_inventory.is_purchasable(Molecule(sm))
24 |
25 | for sm in NON_PURCHASABLE_SMILES:
26 | assert not example_inventory.is_purchasable(Molecule(sm))
27 |
28 |
29 | def test_fill_metadata(example_inventory: SmilesListInventory) -> None:
30 | """
31 | Does the 'fill_metadata' method accurately fill the metadata?
32 | Currently it only checks that the `is_purchasable` key is filled correctly.
33 | At least it should add the 'is_purchasable' key.
34 | """
35 |
36 | for sm in PURCHASABLE_SMILES + NON_PURCHASABLE_SMILES:
37 | # Make initial molecule without any metadata
38 | mol = Molecule(sm)
39 | assert "is_purchasable" not in mol.metadata
40 |
41 | # Fill metadata and check that it is filled accurately.
42 | # To also handle the case where the metadata is filled, we run the test twice.
43 | for _ in range(2):
44 | example_inventory.fill_metadata(mol)
45 | assert mol.metadata["is_purchasable"] == example_inventory.is_purchasable(mol)
46 |
47 | # corrupt metadata so that next iteration the metadata is filled
48 | # and should be overwritten.
49 | # Type ignore is because we fill in random invalid metadata
50 | mol.metadata["is_purchasable"] = "abc" # type: ignore[typeddict-item]
51 |
52 |
53 | def test_to_purchasable_mols(example_inventory: SmilesListInventory) -> None:
54 | """
55 | Does the 'to_purchasable_mols' method work correctly? It should return a collection
56 | of all the purchasable molecules.
57 | """
58 | expected_set = {Molecule(sm) for sm in PURCHASABLE_SMILES}
59 | observed_set = set(example_inventory.to_purchasable_mols())
60 | assert expected_set == observed_set
61 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/models/retro_knn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
6 | class Adapter(nn.Module):
7 | def __init__(self, dim, k=32):
8 | from dgl.nn import GINEConv
9 |
10 | super().__init__()
11 | self.gnn = GINEConv(nn.Linear(dim, dim))
12 | self.node_proj = nn.Linear(dim, 2) # [tmp, p]
13 | self.edge_proj = nn.Linear(dim, 2) # [tmp, p]
14 |
15 | self.node_dist_in_proj = nn.Linear(k, k)
16 | self.edge_dist_in_proj = nn.Linear(k, k)
17 |
18 | self.node_ffn = nn.Linear(dim + k, dim)
19 | self.edge_ffn = nn.Linear(dim * 2 + k, dim)
20 |
21 | nn.init.kaiming_uniform_(self.node_dist_in_proj.weight, a=100)
22 | nn.init.kaiming_uniform_(self.edge_dist_in_proj.weight, a=100)
23 |
24 | nn.init.zeros_(self.node_proj.weight)
25 | nn.init.zeros_(self.edge_proj.weight)
26 |
27 | nn.init.constant_(self.node_proj.bias[0], 10.0)
28 | nn.init.constant_(self.edge_proj.bias[0], 10.0)
29 |
30 | def forward(self, g, nfeat, efeat, ndist, edist):
31 | from local_retro.scripts.model_utils import pair_atom_feats
32 |
33 | x = self.gnn(g, nfeat, efeat)
34 | x = F.relu(x)
35 |
36 | ndist = F.relu(self.node_dist_in_proj(ndist))
37 | edist = F.relu(self.edge_dist_in_proj(edist))
38 |
39 | node_x = torch.cat((x, ndist), dim=-1)
40 | node_x = self.node_ffn(node_x)
41 | node_x = F.relu(node_x)
42 | node_x = self.node_proj(node_x)
43 |
44 | edge_x = pair_atom_feats(g, x)
45 | edge_x = F.relu(edge_x)
46 | edge_x = torch.cat((edge_x, edist), dim=-1)
47 | edge_x = self.edge_ffn(edge_x)
48 | edge_x = F.relu(edge_x)
49 | edge_x = self.edge_proj(edge_x)
50 |
51 | node_t = torch.clamp(node_x[:, 0], 1, 100)
52 | node_p = torch.sigmoid(node_x[:, 1])
53 |
54 | edge_t = torch.clamp(edge_x[:, 0], 1, 100)
55 | edge_p = torch.sigmoid(edge_x[:, 1])
56 |
57 | return (r.unsqueeze(-1) for r in (node_t, node_p, edge_t, edge_p))
58 |
59 |
60 | def knn_prob(feats, store, labels, max_idx, k=32, temperature=5):
61 | from torch_scatter import scatter
62 |
63 | dis, idx = store.search(feats, k) # [B, K]
64 |     pred = labels[idx].unsqueeze(-1)  # [B, K, 1]
65 |
66 | re_compute_dists = -1 * dis
67 | knn_weight = torch.softmax(re_compute_dists / temperature, dim=-1).unsqueeze(-1) # [B, K, 1]
68 |
69 | bsz = feats.shape[0]
70 | output = torch.zeros(bsz, k, max_idx).to(feats)
71 |
72 | scatter(src=knn_weight, out=output, index=pred, dim=-1)
73 |
74 | return output.sum(dim=1)
75 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/utils/downloading.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import urllib.request
4 | import zipfile
5 | from pathlib import Path
6 | from typing import Optional
7 |
8 | import yaml
9 |
10 | logger = logging.getLogger(__file__)
11 |
12 |
13 | def get_cache_dir(key: str) -> Path:
14 | """Get the cache directory for a given key (e.g. model name)."""
15 |
16 | # First, check if the cache directory has been manually overriden.
17 | cache_dir_from_env = os.getenv("SYNTHESEUS_CACHE_DIR")
18 | if cache_dir_from_env is not None:
19 | # If yes, use the path provided.
20 | cache_dir = Path(cache_dir_from_env)
21 | else:
22 | # If not, construct a reasonable default.
23 | cache_dir = Path(os.getenv("HOME", ".")) / ".cache" / "torch" / "syntheseus"
24 |
25 | cache_dir = cache_dir / key
26 | cache_dir.mkdir(parents=True, exist_ok=True)
27 |
28 | return cache_dir
29 |
30 |
31 | def get_cache_dir_download_if_missing(key: str, link: str) -> Path:
32 | """Get the cache directory for a given key, but populate by downloading from link if empty."""
33 |
34 | cache_dir = get_cache_dir(key)
35 | if not any(cache_dir.iterdir()):
36 | cache_zip_path = cache_dir / "model.zip"
37 |
38 | logger.info(f"Downloading data from {link} to {cache_zip_path}")
39 | urllib.request.urlretrieve(link, cache_zip_path)
40 |
41 | with zipfile.ZipFile(cache_zip_path, "r") as f_zip:
42 | f_zip.extractall(cache_dir)
43 |
44 | cache_zip_path.unlink()
45 |
46 | return cache_dir
47 |
48 |
49 | def get_default_model_dir_from_cache(model_name: str, is_forward: bool) -> Optional[Path]:
50 | default_model_links_file_path = (
51 | Path(__file__).parent.parent / "inference" / "default_checkpoint_links.yml"
52 | )
53 |
54 | if not default_model_links_file_path.exists():
55 | logger.info(
56 | f"Could not obtain a default model link: {default_model_links_file_path} does not exist"
57 | )
58 | return None
59 |
60 | with open(default_model_links_file_path, "rt") as f_defaults:
61 | default_model_links = yaml.safe_load(f_defaults)
62 |
63 | assert default_model_links.keys() == {"backward", "forward"}
64 |
65 | forward_backward_key = "forward" if is_forward else "backward"
66 | model_links = default_model_links[forward_backward_key]
67 |
68 | if model_name not in model_links:
69 | logger.info(f"Could not obtain a default model link: no entry for {model_name}")
70 | return None
71 |
72 | return get_cache_dir_download_if_missing(
73 | f"{model_name}_{forward_backward_key}", link=model_links[model_name]
74 | )
75 |
--------------------------------------------------------------------------------
/docs/cli/eval_single_step.md:
--------------------------------------------------------------------------------
1 | # Single-step Evaluation
2 |
3 | ## Usage
4 |
5 | ```
6 | syntheseus eval-single-step \
7 | data_dir=[DATA_DIR] \
8 | fold=[TRAIN, VAL or TEST] \
9 | model_class=[MODEL_CLASS] \
10 | model_dir=[MODEL_DIR]
11 | ```
12 |
13 | The `eval-single-step` command accepts further arguments to customize the evaluation; see `BaseEvalConfig` in `cli/eval_single_step.py` for the complete list.
14 |
15 | The code will scan `data_dir` for files matching `*{train, val, test}.{jsonl, csv, smi}` and select the data format based on the file extension; an error is raised in case of ambiguity. Only the fold selected for evaluation has to be present.
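   |
   | For example (with hypothetical paths; see `BackwardModelClass` for the valid model class names), evaluating a LocalRetro checkpoint on the test fold could look like:
   |
   | ```
   | syntheseus eval-single-step \
   |     data_dir=/data/uspto_50k \
   |     fold=TEST \
   |     model_class=LocalRetro \
   |     model_dir=/checkpoints/local_retro
   | ```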
16 |
17 | ## Data format
18 |
19 | The single-step evaluation script supports reaction data in one of three formats.
20 |
21 | ### JSONL
22 |
23 | Our internal format based on `*.jsonl` files in which each line is a JSON representation of a single reaction, for example:
24 | ```json
25 | {"reactants": [{"smiles": "Cc1ccc(Br)cc1"}, {"smiles": "Cc1ccc(B(O)O)cc1"}], "products": [{"smiles": "Cc1ccc(-c2ccc(C)cc2)cc1"}]}
26 | ```
27 | This format is designed to be flexible at the expense of taking more disk space. The JSON is parsed into a `ReactionSample` object, so it can include additional metadata such as template information, while the reactants and products can include other fields accepted by the `Molecule` object. The evaluation script will only use reactant and product SMILES to compute the metrics.
28 |
29 | Unlike the other formats below, reactants and products in this format are assumed to be already stripped of atom mapping, which leads to slightly faster data loading as it avoids extra calls to `rdkit`.
30 |
31 | ### CSV
32 |
33 | This format is based on `*.csv` files and is commonly used to store raw USPTO data, e.g. as released by [Dai et al.](https://github.com/Hanjun-Dai/GLN):
34 |
35 | ```
36 | id,class,reactants>reagents>production
37 | ID,UNK,[cH:1]1[cH:2][c:3]([CH3:4])[cH:5][cH:6][c:7]1Br.B(O)(O)[c:8]1[cH:9][cH:10][c:11]([CH3:12])[cH:13][cH:14]1>>[cH:1]1[cH:2][c:3]([CH3:4])[cH:5][cH:6][c:7]1[c:8]2[cH:14][cH:13][c:11]([CH3:12])[cH:10][cH:9]2
38 | ```
39 |
40 | The evaluation script will look for the `reactants>reagents>production` column to extract the reaction SMILES, which are stripped of atom mapping and canonicalized before being fed to the model.
41 |
42 | ### SMILES
43 |
44 | The most compact format is to list reaction SMILES line-by-line in a `*.smi` file:
45 |
46 | ```
47 | [cH:1]1[cH:2][c:3]([CH3:4])[cH:5][cH:6][c:7]1Br.B(O)(O)[c:8]1[cH:9][cH:10][c:11]([CH3:12])[cH:13][cH:14]1>>[cH:1]1[cH:2][c:3]([CH3:4])[cH:5][cH:6][c:7]1[c:8]2[cH:14][cH:13][c:11]([CH3:12])[cH:10][cH:9]2
48 | ```
49 |
50 | The data will be handled in the same way as for the CSV format, i.e. it will only be fed into the model after atom mapping removal and canonicalization.
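   |
   | This preprocessing can be reproduced with the helper used internally (a minimal sketch based on `syntheseus.reaction_prediction.chem.utils`):
   |
   | ```python
   | from syntheseus.reaction_prediction.chem.utils import remove_atom_mapping
   |
   | # Strips atom-map numbers; round-tripping through RDKit also canonicalizes the SMILES.
   | print(remove_atom_mapping("[cH:1]1[cH:2][c:3]([CH3:4])[cH:5][cH:6][c:7]1Br"))  # Cc1ccc(Br)cc1
   | ```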
51 |
--------------------------------------------------------------------------------
/syntheseus/search/analysis/solution_time.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from syntheseus.search.graph.and_or import AndNode, OrNode
4 | from syntheseus.search.graph.base_graph import RetrosynthesisSearchGraph
5 | from syntheseus.search.graph.message_passing import run_message_passing
6 | from syntheseus.search.graph.molset import MolSetNode
7 | from syntheseus.search.graph.node import BaseGraphNode
8 |
9 |
10 | def get_first_solution_time(graph: RetrosynthesisSearchGraph) -> float:
11 | """Get the time of the first solution. Also sets 'first_solution_time' node attribute."""
12 | run_message_passing(
13 | graph=graph,
14 | nodes=list(graph._graph.nodes()),
15 | update_fns=[first_solution_time_update],
16 | update_successors=False, # only affects predecessor nodes
17 | )
18 | return graph.root_node.data["first_solution_time"]
19 |
20 |
21 | def first_solution_time_update(node: BaseGraphNode, graph: RetrosynthesisSearchGraph) -> bool:
22 | NO_SOLUTION_TIME = math.inf # being unsolved = inf time until solution found
23 |
24 | # Calculate "intrinsic solution time"
25 | if node._has_intrinsic_solution():
26 | intrinsic_solution_time = node.data["analysis_time"]
27 | else:
28 | intrinsic_solution_time = NO_SOLUTION_TIME
29 |
30 |     # Calculate first solution time from children
31 | children_soln_time_list = [
32 | c.data.get("first_solution_time", NO_SOLUTION_TIME) for c in graph.successors(node)
33 | ]
34 | if len(children_soln_time_list) == 0:
35 | children_solution_time = NO_SOLUTION_TIME
36 | elif isinstance(node, (OrNode, MolSetNode)):
37 | # Or node is first solved when one child is solved,
38 | # so its solution time is the min of its children's
39 | children_solution_time = min(children_soln_time_list)
40 | elif isinstance(node, AndNode):
41 | # AndNode requires all children to be solved,
42 | # so it is first solved when the LAST child is solved
43 | children_solution_time = max(children_soln_time_list)
44 | else:
45 | raise TypeError(f"Node type {type(node)} not supported.")
46 |
47 | # Min solution time is time of first intrinsic solution or solution from children
48 | new_min_soln_time = min(intrinsic_solution_time, children_solution_time)
49 |
50 | # Correct one case that can arise with loops: the children could potentially
51 | # be solved before this node was created.
52 | # Ensure that new min soln time is at least this node's age!
53 | new_min_soln_time = max(new_min_soln_time, node.data["analysis_time"])
54 |
55 | # Perform update
56 | old_min_solution_time = node.data.get("first_solution_time")
57 | node.data["first_solution_time"] = new_min_soln_time
58 | return old_min_solution_time is None or not math.isclose(
59 | old_min_solution_time, new_min_soln_time
60 | )
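   |
   | # Worked example (sketch): if an OrNode's children were first solved at times 3.0 and
   | # 5.0, its `first_solution_time` becomes min(3.0, 5.0) = 3.0; an AndNode needs all
   | # children solved, so it would get max(3.0, 5.0) = 5.0 (floored by its own
   | # "analysis_time" to handle loops).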
61 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/data/reaction_sample.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import inspect
4 | from dataclasses import dataclass, field
5 | from typing import Any, Dict, Optional, Type, TypeVar
6 |
7 | from syntheseus.interface.bag import Bag
8 | from syntheseus.interface.molecule import SMILES_SEPARATOR, Molecule
9 | from syntheseus.interface.reaction import REACTION_SEPARATOR, Reaction
10 | from syntheseus.reaction_prediction.chem.utils import remove_atom_mapping
11 | from syntheseus.reaction_prediction.utils.misc import undictify_bag_of_molecules
12 |
13 | ReactionType = TypeVar("ReactionType", bound="ReactionSample")
14 |
15 |
16 | @dataclass(frozen=True, order=False)
17 | class ReactionSample(Reaction):
18 | """Extends a `Reaction` with fields and methods relevant to data loading and saving."""
19 |
20 | reagents: str = field(default="", hash=True, compare=True)
21 | mapped_reaction_smiles: Optional[str] = field(default=None, hash=False, compare=False)
22 |
23 | @property
24 | def reaction_smiles_with_reagents(self) -> str:
25 | return (
26 | f"{self.reactants_str}{REACTION_SEPARATOR}"
27 | f"{self.reagents}{REACTION_SEPARATOR}"
28 | f"{self.products_str}"
29 | )
30 |
31 | @classmethod
32 | def from_dict(cls: Type[ReactionType], data: Dict[str, Any]) -> ReactionType:
33 | """Creates a sample from the given arguments ignoring superfluous ones."""
34 | for key in ["reactants", "products"]:
35 | data[key] = undictify_bag_of_molecules(data[key])
36 |
37 | return cls(
38 | **{
39 | key: value
40 | for key, value in data.items()
41 | if key in inspect.signature(cls).parameters
42 | }
43 | )
44 |
45 | @classmethod
46 | def from_reaction_smiles_strict(
47 | cls: Type[ReactionType], reaction_smiles: str, mapped: bool, **kwargs
48 | ) -> ReactionType:
49 |         # Split the reaction SMILES into reactants, reagents, and products.
50 | [reactants_smiles, reagents_smiles, products_smiles] = [
51 | smiles_part.split(SMILES_SEPARATOR)
52 | for smiles_part in reaction_smiles.split(REACTION_SEPARATOR)
53 | ]
54 |
55 | if mapped:
56 | assert "mapped_reaction_smiles" not in kwargs
57 | kwargs["mapped_reaction_smiles"] = reaction_smiles
58 |
59 | reactants_smiles = [remove_atom_mapping(smiles) for smiles in reactants_smiles]
60 | products_smiles = [remove_atom_mapping(smiles) for smiles in products_smiles]
61 |
62 | return cls(
63 | reactants=Bag(Molecule(smiles=smiles) for smiles in reactants_smiles),
64 | products=Bag(Molecule(smiles=smiles) for smiles in products_smiles),
65 | reagents=SMILES_SEPARATOR.join(sorted(reagents_smiles)),
66 | **kwargs,
67 | )
68 |
69 | @classmethod
70 | def from_reaction_smiles(cls: Type[ReactionType], *args, **kwargs) -> Optional[ReactionType]:
71 | try:
72 | return cls.from_reaction_smiles_strict(*args, **kwargs)
73 | except Exception:
74 | return None
75 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "syntheseus"
7 | authors = [ # Note: PyPI does not support multiple authors with email addresses included.
8 | {name = "Austin Tripp"}, # ajt212@cam.ac.uk
9 | {name = "Krzysztof Maziarz"}, # krzysztof.maziarz@microsoft.com
10 | {name = "Guoqing Liu"}, # guoqingliu@microsoft.com
11 | {name = "Megan Stanley"}, # meganstanley@microsoft.com
12 | {name = "Marwin Segler"}, # marwinsegler@microsoft.com
13 | ]
14 | description = "A package for retrosynthetic planning."
15 | readme = "README.md"
16 | requires-python = ">=3.8"
17 | license = {file = "LICENSE"}
18 | dynamic = ["version"]
19 | dependencies = [
20 | "more_itertools", # reaction_prediction
21 | "networkx", # search
22 | "numpy", # reaction_prediction, search
23 | "omegaconf", # reaction_prediction
24 | "rdkit", # reaction_prediction, search
25 | "tqdm", # reaction_prediction
26 | ]
27 |
28 | [project.optional-dependencies]
29 | viz = [
30 | "pillow",
31 | "graphviz"
32 | ]
33 | dev = [
34 | "pytest",
35 | "pytest-cov",
36 | "pytest-rerunfailures",
37 | "pre-commit"
38 | ]
39 | chemformer = ["syntheseus-chemformer==0.3.0"]
40 | graph2edits = ["syntheseus-graph2edits==0.2.0"]
41 | local-retro = ["syntheseus-local-retro==0.5.0"]
42 | megan = ["syntheseus-megan==0.2.0"]
43 | mhn-react = ["syntheseus-mhnreact==1.0.0"]
44 | retro-knn = ["syntheseus[local-retro]"]
45 | root-aligned = ["syntheseus-root-aligned==0.2.0"]
46 | all-single-step = [
47 | "syntheseus[chemformer,graph2edits,local-retro,megan,mhn-react,retro-knn,root-aligned]"
48 | ]
49 | all = [
50 | "syntheseus[viz,dev,all-single-step]"
51 | ]
52 |
53 | [project.urls]
54 | Documentation = "https://microsoft.github.io/syntheseus"
55 | Repository = "https://github.com/microsoft/syntheseus"
56 |
57 | [project.scripts]
58 | syntheseus = "syntheseus.cli.main:main"
59 |
60 | [tool.setuptools.packages.find]
61 | # Guidance from: https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html
62 | where = ["."]
63 | include = ["syntheseus*"]
64 | exclude = ["syntheseus.tests*"]
65 | namespaces = false
66 |
67 | [tool.setuptools.package-data]
68 | "syntheseus" = ["py.typed"]
69 |
70 | [tool.setuptools_scm]
71 |
72 | [tool.black]
73 | line-length = 100
74 | include = '\.pyi?$'
75 |
76 | [tool.mypy]
77 | python_version = 3.9 # pin modern python version
78 | ignore_missing_imports = true
79 |
80 | [tool.ruff]
81 | line-length = 100
82 | # Check https://beta.ruff.rs/docs/rules/ for full list of rules
83 | lint.select = [
84 | "E", "W", # pycodestyle
85 | "F", # Pyflakes
86 | "I", # isort
87 | "NPY201", # check for functions/constants deprecated in numpy 2.*
88 | ]
89 | lint.ignore = [
90 | # W605: invalid escape sequence -- triggered by pseudo-LaTeX in comments
91 | "W605",
92 | # E501: Line too long -- triggered by comments and such. black deals with shortening.
93 | "E501",
94 | # E741: Do not use variables named 'l', 'o', or 'i' -- disagree with PEP8
95 | "E741",
96 | ]
97 |
--------------------------------------------------------------------------------
/syntheseus/search/graph/route.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections import Counter
4 | from typing import Sequence, Union
5 |
6 | from syntheseus.interface.molecule import Molecule
7 | from syntheseus.interface.reaction import SingleProductReaction
8 | from syntheseus.search.graph.base_graph import BaseReactionGraph
9 |
10 | MOL_AND_RXN = Union[Molecule, SingleProductReaction]
11 |
12 |
13 | class SynthesisGraph(BaseReactionGraph[SingleProductReaction]):
14 | """
15 | Data structure used to hold a retrosynthesis graph containing only
16 | reaction objects. The purpose of this class is as a minimal container
17 | for route objects, instead of storing them as AndOrGraphs or MolSetGraphs.
18 | """
19 |
20 | def __init__(self, root_node: SingleProductReaction, **kwargs) -> None:
21 | super().__init__(**kwargs)
22 | self._root_node = root_node
23 | self._graph.add_node(self._root_node)
24 |
25 | @property
26 | def root_node(self) -> SingleProductReaction:
27 | return self._root_node
28 |
29 | @property
30 | def root_mol(self) -> Molecule:
31 | return self.root_node.product
32 |
33 | def is_minimal(self) -> bool:
34 | # Check if any product appears more than once
35 | for rxn in self._graph.nodes:
36 |             product_count = Counter(child.product for child in self.successors(rxn))
37 | if any(v > 1 for v in product_count.values()):
38 | return False
39 | return True
40 |
41 | def assert_validity(self) -> None:
42 | # Everything from superclass applies
43 | super().assert_validity()
44 |
45 | for node in self._graph.nodes:
46 | assert isinstance(node, SingleProductReaction)
47 | for parent in self.predecessors(node):
48 | assert isinstance(parent, SingleProductReaction)
49 | assert node.product in parent.reactants
50 | children = list(self.successors(node))
51 | assert len(children) == len(set(children)) # all children should be unique
52 | assert set([child.product for child in children]) <= set(
53 | node.reactants
54 | ) # all children should be reactants
55 |
56 | def expand_with_reactions(
57 | self,
58 | reactions: list[SingleProductReaction],
59 | node: SingleProductReaction,
60 | ensure_tree: bool,
61 | ) -> Sequence[SingleProductReaction]:
62 | raise NotImplementedError
63 |
64 | def get_starting_molecules(self) -> set[Molecule]:
65 | """
66 | Get the 'starting molecules' for this route,
67 | i.e. reactant molecules which are not a product of a child reaction.
68 | """
69 | output: set[Molecule] = set()
70 | for rxn in self._graph.nodes:
71 | successor_products = {child_rxn.product for child_rxn in self.successors(rxn)}
72 | for reactant in rxn.unique_reactants:
73 | if reactant not in successor_products:
74 | output.add(reactant)
75 | return output
76 |
77 | def __str__(self) -> str:
78 | return str([rxn.reaction_smiles for rxn in self.nodes()])
79 |
--------------------------------------------------------------------------------
/syntheseus/interface/reaction.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass, field
4 | from typing import Optional, TypedDict
5 |
6 | from syntheseus.interface.bag import Bag
7 | from syntheseus.interface.molecule import Molecule, molecule_bag_to_smiles
8 |
9 | REACTION_SEPARATOR = ">"
10 |
11 |
12 | class ReactionMetaData(TypedDict, total=False):
13 | """Class to add typing to optional meta-data fields for reactions."""
14 |
15 | cost: float
16 | template: str
17 | source: str # any explanation of the source of this reaction
18 | probability: float # probability for this reaction (e.g. from a model)
19 | log_probability: float # log probability for this reaction (should match log of above)
20 | score: float # any kind of score for this reaction (e.g. softmax value, probability)
21 | confidence: float # confidence (probability) that this reaction is possible
22 | reaction_id: int # template id or other kind of reaction id, if applicable
23 | reaction_smiles: str # reaction smiles for this reaction
24 | ground_truth_match: bool # whether this reaction matches ground truth
25 |
26 |
27 | def reaction_string(reactants_str: str, products_str: str) -> str:
28 | """Produces a consistent string representation of a reaction."""
29 | return f"{reactants_str}{2 * REACTION_SEPARATOR}{products_str}"
30 |
31 |
32 | @dataclass(frozen=True, order=False)
33 | class Reaction:
34 | reactants: Bag[Molecule] = field(hash=True, compare=True)
35 | products: Bag[Molecule] = field(hash=True, compare=True)
36 | identifier: Optional[str] = field(default=None, hash=True, compare=True)
37 |
38 | # Dictionary to hold additional metadata.
39 | metadata: ReactionMetaData = field(
40 | default_factory=lambda: ReactionMetaData(),
41 | hash=False,
42 | compare=False,
43 | )
44 |
45 | @property
46 | def unique_reactants(self) -> set[Molecule]:
47 | return set(self.reactants)
48 |
49 | @property
50 | def unique_products(self) -> set[Molecule]:
51 | return set(self.products)
52 |
53 | @property
54 | def reactants_str(self) -> str:
55 | return molecule_bag_to_smiles(self.reactants)
56 |
57 | @property
58 | def products_str(self) -> str:
59 | return molecule_bag_to_smiles(self.products)
60 |
61 | @property
62 | def reaction_smiles(self) -> str:
63 | return reaction_string(reactants_str=self.reactants_str, products_str=self.products_str)
64 |
65 | def __str__(self) -> str:
66 | output = self.reaction_smiles
67 | if self.identifier is not None:
68 | output += f" ({self.identifier})"
69 | return output
70 |
71 |
72 | @dataclass(frozen=True, order=False)
73 | class SingleProductReaction(Reaction):
74 | def __init__(self, *, reactants: Bag[Molecule], product: Molecule, **kwargs) -> None:
75 | super().__init__(reactants=reactants, products=Bag([product]), **kwargs)
76 |
77 | @property
78 | def product(self) -> Molecule:
79 | """Handle for the single product of this reaction."""
80 | assert len(self.products) == 1 # Guaranteed in `__init__`.
81 | return next(iter(self.products))
82 |
--------------------------------------------------------------------------------
/syntheseus/tests/interface/test_reaction.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import copy
4 | from dataclasses import FrozenInstanceError
5 |
6 | import pytest
7 |
8 | from syntheseus.interface.bag import Bag
9 | from syntheseus.interface.molecule import Molecule
10 | from syntheseus.interface.reaction import Reaction, SingleProductReaction
11 |
12 |
13 | def test_reaction_objects_basic():
14 | """Instantiate a `Reaction` object."""
15 | C2 = Molecule(2 * "C")
16 | C3 = Molecule(3 * "C")
17 | C5 = Molecule(5 * "C")
18 |
19 | # Single product reaction
20 | rxn1 = SingleProductReaction(
21 | reactants=Bag([C2, C3]),
22 | product=C5,
23 | )
24 | assert rxn1.reaction_smiles == "CC.CCC>>CCCCC"
25 |
26 |     # Standard (multi-product) reaction
27 | rxn2 = Reaction(
28 | reactants=Bag([C5]),
29 | products=Bag([C2, C3]),
30 | )
31 | assert rxn2.reaction_smiles == "CCCCC>>CC.CCC"
32 |
33 |
34 | class TestSingleProductReactions:
35 | def test_positive_equality(self, rxn_cocs_from_co_cs: SingleProductReaction) -> None:
36 | """Various tests that 2 reactions with same products and reactants should be equal."""
37 |
38 | rxn_copy = SingleProductReaction(
39 | product=copy.deepcopy(rxn_cocs_from_co_cs.product),
40 | reactants=copy.deepcopy(rxn_cocs_from_co_cs.reactants),
41 | )
42 |
43 | # Test 1: original and copy should be equal
44 | assert rxn_copy == rxn_cocs_from_co_cs
45 |
46 | # Test 2: although equal, they should be distinct objects
47 | assert rxn_cocs_from_co_cs is not rxn_copy
48 |
49 | # Test 3: differences in metadata should not affect equality
50 | rxn_cocs_from_co_cs.metadata["test"] = "str1" # type: ignore[typeddict-unknown-key]
51 | rxn_copy.metadata["test"] = "str2" # type: ignore[typeddict-unknown-key]
52 | assert rxn_copy == rxn_cocs_from_co_cs
53 | assert rxn_cocs_from_co_cs.metadata != rxn_copy.metadata
54 |
55 | def test_negative_equality(self, rxn_cocs_from_co_cs: SingleProductReaction) -> None:
56 | """Various tests that reactions which should not be equal are not equal."""
57 |
58 | # Test 1: changing identifier makes reactions not equal
59 | rxn_with_different_id = SingleProductReaction(
60 | product=copy.deepcopy(rxn_cocs_from_co_cs.product),
61 | reactants=copy.deepcopy(rxn_cocs_from_co_cs.reactants),
62 | identifier="different",
63 | )
64 | assert rxn_cocs_from_co_cs != rxn_with_different_id
65 |
66 | # Test 2: different products and reactions are not equal
67 | diff_rxn = SingleProductReaction(
68 | product=Molecule("CC"), reactants=Bag([Molecule("CO"), Molecule("CS")])
69 | )
70 | assert rxn_cocs_from_co_cs != diff_rxn
71 |
72 | def test_frozen(self, rxn_cocs_from_co_cs: SingleProductReaction) -> None:
73 | """Test that the fields of the reaction are frozen."""
74 | with pytest.raises(FrozenInstanceError):
75 | # type ignore is because mypy complains we are modifying a frozen field, which is the point of the test
76 | rxn_cocs_from_co_cs.identifier = "abc" # type: ignore[misc]
77 |
78 | def test_rxn_smiles(self, rxn_cocs_from_co_cs: SingleProductReaction) -> None:
79 | """Test that the reaction SMILES is as expected."""
80 | assert rxn_cocs_from_co_cs.reaction_smiles == "CO.CS>>COCS"
81 |
--------------------------------------------------------------------------------
/docs/installation.md:
--------------------------------------------------------------------------------
1 | We support two installation modes:
2 |
3 | - *core installation* allows you to build and benchmark your own models or search algorithms
4 | - *full installation* also allows you to perform end-to-end search using the supported models
5 |
6 | There are also two installation sources:
7 |
8 | - *pip*, which provides the most recent released version
9 | - *GitHub*, which provides the latest changes but may be less stable and may not be
10 | backward-compatible with the latest released version
11 |
12 | === "Core (pip)"
13 |
14 | ```bash
15 | conda env create -f environment.yml
16 | conda activate syntheseus
17 |
18 | pip install syntheseus
19 | ```
20 |
21 | === "Full (pip)"
22 |
23 | ```bash
24 | conda env create -f environment_full.yml
25 | conda activate syntheseus-full
26 |
27 | pip install "syntheseus[all]"
28 | ```
29 |
30 | === "Core (GitHub)"
31 |
32 | ```bash
33 | conda env create -f environment.yml
34 | conda activate syntheseus
35 |
36 | pip install -e .
37 | ```
38 |
39 | === "Full (GitHub)"
40 |
41 | ```bash
42 | conda env create -f environment_full.yml
43 | conda activate syntheseus-full
44 |
45 | pip install -e ".[all]"
46 | ```
47 |
48 | !!! note
49 |
50 | Make sure you are viewing the version of the docs matching your `syntheseus` installation.
51 | Select the `x.y.z` version you installed if you used `pip` (go [here](https://microsoft.github.io/syntheseus/stable/) for the latest one),
52 | or [dev](https://microsoft.github.io/syntheseus/dev/) if you installed `syntheseus` directly from GitHub.
53 |
54 | Core installation includes only minimal dependencies (no ML libraries), while full installation includes all supported models and also dependencies for visualization/development.
55 |
56 | The instructions above assume you have already cloned the repository via
57 |
58 | ```bash
59 | git clone https://github.com/microsoft/syntheseus.git
60 | cd syntheseus
61 | ```
62 |
63 | Note that `environment_full.yml` pins the CUDA version (to 11.3) for reproducibility.
64 | If you want to use a different one, make sure to edit the environment file accordingly.
65 |
66 | ??? info "Setting up GLN"
67 |
68 | We also support GLN, but it requires a specialized environment and is thus not installed via `pip`.
69 | See [here](https://github.com/microsoft/syntheseus/blob/main/syntheseus/reaction_prediction/environment_gln/Dockerfile) for a Docker environment necessary for running GLN.
70 |
71 | ## Reducing the number of dependencies
72 |
73 | To keep the environment smaller, you can replace the `all` option with a comma-separated subset of `{chemformer,local-retro,megan,mhn-react,retro-knn,root-aligned,viz,dev}` (`viz` and `dev` correspond to visualization and development dependencies, respectively).
74 | For example, `pip install -e ".[local-retro,root-aligned]"` installs only LocalRetro and RootAligned.
75 | If installing a subset of models, you can also delete the lines in `environment_full.yml` marked with names of models you do not wish to use.
76 |
77 | If you only want to use a very specific part of `syntheseus`, you could also install it without dependencies:
78 |
79 | ```bash
80 | pip install -e . --no-deps
81 | ```
82 |
83 | You would then need to manually install the subset of dependencies required for the particular functionality you want to access.
84 | See `pyproject.toml` for a list of dependencies tied to the `search` and `reaction_prediction` subpackages.
85 |
--------------------------------------------------------------------------------
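As a concrete follow-up to the minimal-install route described above, a short sketch (the extra package listed is an illustrative assumption; check `pyproject.toml` for the authoritative dependency lists):

```bash
# Install syntheseus itself without pulling in any dependencies.
pip install -e . --no-deps

# Then manually add only what the functionality you need requires,
# e.g. (hypothetically) just the core chemistry dependency:
pip install rdkit
```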
/syntheseus/interface/molecule.py:
--------------------------------------------------------------------------------
1 | """
2 | Classes to hold molecules, without reference to the reactions they may take part in.
3 | """
4 |
5 | from dataclasses import InitVar, dataclass, field
6 | from typing import Optional, TypedDict, Union
7 |
8 | from rdkit import Chem
9 |
10 | from syntheseus.interface.bag import Bag
11 |
12 | SMILES_SEPARATOR = "."
13 |
14 |
15 | class MoleculeMetaData(TypedDict, total=False):
16 | """Class to add typing to optional meta-data fields for molecules."""
17 |
18 | rdkit_mol: Chem.Mol
19 |
20 | # Things related to multi-step retrosynthesis
21 | is_purchasable: bool
22 | cost: float
23 | supplier: str
24 |
25 | # Other potentially relevant data
26 | purity: float
27 |
28 |
29 | @dataclass(frozen=True, order=True)
30 | class Molecule:
31 | """
32 | Object representing a molecule with its SMILES string and an optional
33 | identifier to distinguish molecules with identical SMILES strings (usually not used).
34 | Everything else is considered metadata and is stored in a dictionary which is not
35 | compared or hashed.
36 |
37 | The class is frozen since it should not need to be edited,
38 | and this will auto-implement __eq__ and __hash__ methods.
39 |
40 | On initialization it is possible to automatically convert to canonical
41 | SMILES (default True), and to store the rdkit molecule (default True).
42 | If these are set to False, there is no guarantee of canonicalization or
43 | storage of an rdkit mol.
44 | """
45 |
46 | smiles: str = field(hash=True, compare=True)
47 | identifier: Optional[Union[str, int]] = field(default=None, hash=True, compare=True)
48 |
49 | canonicalize: InitVar[bool] = True
50 | make_rdkit_mol: InitVar[bool] = True
51 |
52 | metadata: MoleculeMetaData = field(
53 | default_factory=lambda: MoleculeMetaData(),
54 | hash=False,
55 | compare=False,
56 | )
57 |
58 | def __post_init__(self, canonicalize: bool, make_rdkit_mol: bool) -> None:
59 | if canonicalize or make_rdkit_mol:
60 | try:
61 | rdkit_mol = Chem.MolFromSmiles(self.smiles)
62 | except Exception as e:
63 | raise ValueError(f"Cannot create a molecule with SMILES '{self.smiles}'") from e
64 |
65 | if make_rdkit_mol:
66 | self.metadata["rdkit_mol"] = rdkit_mol
67 |
68 | if canonicalize:
69 | try:
70 | smiles_canonical = Chem.MolToSmiles(rdkit_mol)
71 | except Exception as e:
72 | raise ValueError(
73 | f"Cannot canonicalize a molecule with SMILES '{self.smiles}'"
74 | ) from e
75 |
76 | object.__setattr__(self, "smiles", smiles_canonical)
77 |
78 | @property
79 | def rdkit_mol(self) -> Chem.Mol:
80 | """Makes an rdkit mol if one does yet exist"""
81 | if "rdkit_mol" not in self.metadata:
82 | self.metadata["rdkit_mol"] = Chem.MolFromSmiles(self.smiles)
83 | return self.metadata["rdkit_mol"]
84 |
85 |
86 | def molecule_bag_to_smiles(mols: Bag[Molecule]) -> str:
87 | """Combine SMILES strings of molecules in a `Bag` into a single string.
88 |
89 | For two bags that represent the same multiset of molecules this function will return the same
90 | result, because iteration order over a `Bag` is deterministic (sorted using default comparator).
91 | """
92 | return SMILES_SEPARATOR.join(mol.smiles for mol in mols)
93 |
--------------------------------------------------------------------------------
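A short runnable sketch of the behaviour documented above: canonicalization on construction, lazy creation of the cached rdkit mol, and the deterministic output of `molecule_bag_to_smiles` (assumes `syntheseus` and `rdkit` are installed):

```python
from syntheseus.interface.bag import Bag
from syntheseus.interface.molecule import Molecule, molecule_bag_to_smiles

# SMILES are canonicalized on construction by default.
mol = Molecule("OCC")
assert mol.smiles == "CCO"

# With make_rdkit_mol=False no rdkit mol is stored up front...
lazy_mol = Molecule("CS", make_rdkit_mol=False)
assert "rdkit_mol" not in lazy_mol.metadata

# ...but accessing the property creates and caches one.
_ = lazy_mol.rdkit_mol
assert "rdkit_mol" in lazy_mol.metadata

# Bags iterate in sorted order, so the joined SMILES string is deterministic.
assert molecule_bag_to_smiles(Bag([Molecule("CS"), Molecule("CC")])) == "CC.CS"
```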
/syntheseus/search/mol_inventory.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import abc
4 | import warnings
5 | from collections.abc import Collection
6 | from pathlib import Path
7 | from typing import Union
8 |
9 | from rdkit import Chem
10 |
11 | from syntheseus.interface.molecule import Molecule
12 |
13 |
14 | class BaseMolInventory(abc.ABC):
15 | @abc.abstractmethod
16 | def is_purchasable(self, mol: Molecule) -> bool:
17 | """Whether or not a molecule is purchasable."""
18 | raise NotImplementedError
19 |
20 | def fill_metadata(self, mol: Molecule) -> None:
21 | """
22 | Fills any/all metadata of a molecule. This method should be fast to call,
23 | and many algorithms will assume that it sets `is_purchasable`.
24 | """
25 |
26 | # Default just adds whether the molecule is purchasable
27 | mol.metadata["is_purchasable"] = self.is_purchasable(mol)
28 |
29 |
30 | class ExplicitMolInventory(BaseMolInventory):
31 | """
32 | Base class for MolInventories which store an explicit list of purchasable molecules.
33 | It exposes an additional method to explore this list.
34 |
35 | If it is unclear how a mol inventory might *not* have an explicit list of purchasable
36 | molecules, imagine a toy problem where every molecule with <= 10 atoms is purchasable.
37 | It is easy to check if a molecule has <= 10 atoms, but it is difficult to enumerate
38 | all molecules with <= 10 atoms.
39 | """
40 |
41 | @abc.abstractmethod
42 | def to_purchasable_mols(self) -> Collection[Molecule]:
43 | """Returns an explicit collection of all purchasable molecules.
44 |
45 | Likely expensive for large inventories, should be used mostly for testing or debugging.
46 | """
47 |
48 | def purchasable_mols(self) -> Collection[Molecule]:
49 | warnings.warn(
50 | "purchasable_mols is deprecated, use to_purchasable_mols instead", DeprecationWarning
51 | )
52 | return self.to_purchasable_mols()
53 |
54 | @abc.abstractmethod
55 | def __len__(self) -> int:
56 | """Return the number of purchasable molecules in the inventory."""
57 |
58 |
59 | class SmilesListInventory(ExplicitMolInventory):
60 | """Most common type of inventory: a list of purchasable SMILES."""
61 |
62 | def __init__(self, smiles_list: list[str], canonicalize: bool = True):
63 | if canonicalize:
64 | # For canonicalization we sequence `MolFromSmiles` and `MolToSmiles` to exactly match
65 | # the process employed in the `Molecule` class.
66 | smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(s)) for s in smiles_list]
67 |
68 | self._smiles_set = set(smiles_list)
69 |
70 | def is_purchasable(self, mol: Molecule) -> bool:
71 | if mol.identifier is not None:
72 | warnings.warn(
73 | f"Molecule identifier {mol.identifier} will be ignored during inventory lookup"
74 | )
75 |
76 | return mol.smiles in self._smiles_set
77 |
78 | def to_purchasable_mols(self) -> Collection[Molecule]:
79 | return {Molecule(s, make_rdkit_mol=False, canonicalize=False) for s in self._smiles_set}
80 |
81 | def __len__(self) -> int:
82 | return len(self._smiles_set)
83 |
84 | @classmethod
85 | def load_from_file(cls, path: Union[str, Path], **kwargs) -> SmilesListInventory:
86 | """Load the inventory SMILES from a file."""
87 | with open(path, "rt") as f_inventory:
88 | return cls([line.strip() for line in f_inventory], **kwargs)
89 |
--------------------------------------------------------------------------------
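A minimal usage sketch for `SmilesListInventory`, including the `fill_metadata` hook that search algorithms rely on (all SMILES here are toy examples):

```python
from syntheseus.interface.molecule import Molecule
from syntheseus.search.mol_inventory import SmilesListInventory

# Input SMILES are canonicalized by default, matching the `Molecule` class.
inventory = SmilesListInventory(smiles_list=["CC", "OCC"])

assert inventory.is_purchasable(Molecule("CCO"))  # "OCC" canonicalizes to "CCO"
assert not inventory.is_purchasable(Molecule("CCCC"))
assert len(inventory) == 2

# `fill_metadata` writes `is_purchasable` into the molecule's metadata.
mol = Molecule("CC")
inventory.fill_metadata(mol)
assert mol.metadata["is_purchasable"]
```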
/syntheseus/reaction_prediction/utils/inference.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any, List, Sequence, Union
3 |
4 | from syntheseus.interface.bag import Bag
5 | from syntheseus.interface.molecule import Molecule
6 | from syntheseus.interface.reaction import (
7 | Reaction,
8 | ReactionMetaData,
9 | SingleProductReaction,
10 | )
11 | from syntheseus.reaction_prediction.chem.utils import molecule_bag_from_smiles
12 |
13 |
14 | def process_raw_smiles_outputs_backwards(
15 | input: Molecule, output_list: List[str], metadata_list: List[ReactionMetaData]
16 | ) -> Sequence[SingleProductReaction]:
17 | """Convert raw SMILES outputs into a list of `SingleProductReaction` objects.
18 |
19 | Args:
20 | input: Model input.
21 | output_list: Raw SMILES outputs (including potentially invalid ones).
22 | metadata_list: Additional metadata to attach to the predictions (e.g. probability).
23 |
24 | Returns:
25 | A list of `SingleProductReaction`s; may be shorter than `output_list` if some of the raw
26 | SMILES could not be parsed into valid reactant bags.
27 | """
28 | predictions: List[SingleProductReaction] = []
29 |
30 | for raw_output, metadata in zip(output_list, metadata_list):
31 | reactants = molecule_bag_from_smiles(raw_output)
32 |
33 | # Only consider the prediction if the SMILES can be parsed.
34 | if reactants is not None:
35 | predictions.append(
36 | SingleProductReaction(product=input, reactants=reactants, metadata=metadata)
37 | )
38 |
39 | return predictions
40 |
41 |
42 | def process_raw_smiles_outputs_forwards(
43 | input: Bag[Molecule], output_list: List[str], metadata_list: List[ReactionMetaData]
44 | ) -> Sequence[Reaction]:
45 | """Convert raw SMILES outputs into a list of `Reaction` objects.
46 | Like `process_raw_smiles_outputs_backwards`, but for forward models.
47 |
48 | Args:
49 | input: Model input.
50 | output_list: Raw SMILES outputs (including potentially invalid ones).
51 | metadata_list: Additional metadata to attach to the predictions (e.g. probability).
52 |
53 | Returns:
54 | A list of `Reaction`s; may be shorter than `output_list` if some of the raw
55 | SMILES could not be parsed into valid product bags.
56 | """
57 | predictions: List[Reaction] = []
58 |
59 | for raw_output, metadata in zip(output_list, metadata_list):
60 | products = molecule_bag_from_smiles(raw_output)
61 |
62 | # Only consider the prediction if the SMILES can be parsed.
63 | if products is not None:
64 | predictions.append(Reaction(products=products, reactants=input, metadata=metadata))
65 |
66 | return predictions
67 |
68 |
69 | def get_unique_file_in_dir(dir: Union[str, Path], pattern: str) -> Path:
70 | candidates = list(Path(dir).glob(pattern))
71 | if len(candidates) != 1:
72 | raise ValueError(
73 | f"Expected a unique match for {pattern} in {dir}, found {len(candidates)}: {candidates}"
74 | )
75 |
76 | return candidates[0]
77 |
78 |
79 | def get_module_path(module: Any) -> str:
80 | """Heuristically extract the local path to an imported module."""
81 |
82 | # In some cases, `module.__path__` is already a `List`, while in other cases it may be a
83 | # `_NamespacePath` object. Either way the conversion below leaves us with `List[str]`.
84 | path_list: List[str] = list(module.__path__)
85 |
86 | if len(path_list) != 1:
87 | raise ValueError(f"Cannot extract path to module {module} from {path_list}")
88 |
89 | return path_list[0]
90 |
--------------------------------------------------------------------------------
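A small sketch of `process_raw_smiles_outputs_backwards` on toy data, assuming (as the docstring above implies) that `molecule_bag_from_smiles` returns `None` for unparseable SMILES; note that filtering happens after zipping, so the returned predictions and their metadata stay aligned:

```python
from syntheseus.interface.molecule import Molecule
from syntheseus.reaction_prediction.utils.inference import (
    process_raw_smiles_outputs_backwards,
)

product = Molecule("CCO")
predictions = process_raw_smiles_outputs_backwards(
    input=product,
    output_list=["CC.O", "not-a-smiles"],  # the second output cannot be parsed
    metadata_list=[{"probability": 0.8}, {"probability": 0.2}],
)

# The invalid SMILES is dropped, so only one reaction is returned.
assert len(predictions) == 1
assert predictions[0].reaction_smiles == "CC.O>>CCO"
assert predictions[0].metadata["probability"] == 0.8
```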
/syntheseus/reaction_prediction/inference/gln.py:
--------------------------------------------------------------------------------
1 | """Inference wrapper for the Graph Logic Network (GLN) model.
2 |
3 | Paper: https://arxiv.org/abs/2001.01408
4 | Code: https://github.com/Hanjun-Dai/GLN
5 |
6 | The original GLN code is released under the MIT license.
7 | """
8 |
9 | import sys
10 | from pathlib import Path
11 | from typing import List, Sequence
12 |
13 | from syntheseus.interface.molecule import Molecule
14 | from syntheseus.interface.reaction import SingleProductReaction
15 | from syntheseus.reaction_prediction.inference.base import ExternalBackwardReactionModel
16 | from syntheseus.reaction_prediction.utils.inference import process_raw_smiles_outputs_backwards
17 | from syntheseus.reaction_prediction.utils.misc import suppress_outputs
18 |
19 |
20 | class GLNModel(ExternalBackwardReactionModel):
21 | def __init__(self, *args, dataset_name: str = "schneider50k", **kwargs) -> None:
22 | """Initializes the GLN model wrapper.
23 |
24 | Assumed format of the model directory:
25 | - `model_dir` contains files necessary to build `RetroGLN`
26 | - `model_dir/{dataset_name}.ckpt` is the model checkpoint
27 | - `model_dir/cooked_{dataset_name}/atom_list.txt` is the atom type list
28 | """
29 | super().__init__(*args, **kwargs)
30 |
31 | import torch
32 |
33 | chkpt_path = Path(self.model_dir) / f"{dataset_name}.ckpt"
34 | gln_args = {
35 | "dropbox": self.model_dir,
36 | "data_name": dataset_name,
37 | "model_for_test": chkpt_path,
38 | "tpl_name": "default",
39 | "f_atoms": Path(self.model_dir) / f"cooked_{dataset_name}" / "atom_list.txt",
40 | "gpu": torch.device(self.device).index,
41 | }
42 |
43 | # Suppress most of the prints from GLN's internals. This only works on messages that
44 | # originate from Python, so the C++-based ones slip through.
45 | with suppress_outputs():
46 | # GLN makes heavy use of global state (saved either in `gln.common.cmd_args` or `sys.argv`),
47 | # so we have to hack both of these sources below.
48 | from gln.common.cmd_args import cmd_args
49 |
50 | sys.argv = []
51 | for name, value in gln_args.items():
52 | setattr(cmd_args, name, value)
53 | sys.argv += [f"-{name}", str(value)]
54 |
55 | # The global state hackery has to happen before this.
56 | from gln.test.model_inference import RetroGLN
57 |
58 | self.model = RetroGLN(self.model_dir, chkpt_path)
59 |
60 | @property
61 | def name(self) -> str:
62 | return "GLN"
63 |
64 | def get_parameters(self):
65 | return self.model.gln.parameters()
66 |
67 | def _get_model_predictions(
68 | self, input: Molecule, num_results: int
69 | ) -> Sequence[SingleProductReaction]:
70 | with suppress_outputs():
71 | result = self.model.run(input.smiles, num_results, num_results)
72 |
73 | if result is None:
74 | return []
75 | else:
76 | # `scores` are actually probabilities (produced by running `softmax`).
77 | return process_raw_smiles_outputs_backwards(
78 | input=input,
79 | output_list=result["reactants"],
80 | metadata_list=[
81 | {"probability": probability.item()} for probability in result["scores"]
82 | ],
83 | )
84 |
85 | def _get_reactions(
86 | self, inputs: List[Molecule], num_results: int
87 | ) -> List[Sequence[SingleProductReaction]]:
88 | return [self._get_model_predictions(input, num_results=num_results) for input in inputs]
89 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 | workflow_dispatch:
9 |
10 | jobs:
11 | test-core:
12 | runs-on: ubuntu-latest
13 | strategy:
14 | fail-fast: false
15 | matrix:
16 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
17 | defaults:
18 | run:
19 | shell: bash -l {0}
20 | name: test-core (Python ${{ matrix.python-version }})
21 | steps:
22 | - uses: actions/checkout@v3
23 | - uses: conda-incubator/setup-miniconda@v3
24 | with:
25 | mamba-version: "*"
26 | channels: conda-forge,defaults
27 | channel-priority: true
28 | python-version: ${{ matrix.python-version }}
29 | environment-file: environment.yml
30 | - name: Install syntheseus
31 | run: |
32 | pip install .[dev]
33 | - name: Run pre-commit
34 | run: |
35 | pre-commit run --verbose --all-files
36 | - name: Run unit tests
37 | run: |
38 | coverage run -p -m pytest ./syntheseus/tests/
39 | coverage report --data-file .coverage.*
40 | - name: Upload coverage report
41 | uses: actions/upload-artifact@v4
42 | with:
43 | name: .coverage.core-${{ matrix.python-version }}
44 | path: .coverage.*
45 | include-hidden-files: true
46 | test-models:
47 | strategy:
48 | fail-fast: false
49 | matrix:
50 | runner: ["ubuntu-latest", "Syntheseus-GPU"]
51 | runs-on: ${{ matrix.runner }}
52 | defaults:
53 | run:
54 | shell: bash -l {0}
55 | name: test-models (${{ matrix.runner == 'ubuntu-latest' && 'CPU' || 'GPU' }})
56 | steps:
57 | - name: Free extra disk space
58 | uses: jlumbroso/free-disk-space@main
59 | - uses: actions/checkout@v3
60 | - uses: conda-incubator/setup-miniconda@v3
61 | with:
62 | mamba-version: "*"
63 | channels: conda-forge,defaults
64 | channel-priority: true
65 | environment-file: environment_full.yml
66 | - name: Install syntheseus with all single-step models
67 | run: |
68 | sudo apt install -y graphviz
69 | pip install .[all]
70 | - name: Verify GPU is available
71 | if: ${{ matrix.runner != 'ubuntu-latest' }}
72 | run: |
73 | nvidia-smi || (echo "❌ No GPU detected" && exit 1)
74 | python -c "import torch; assert torch.cuda.is_available(), '❌ GPU not found'; print('✅ Found GPU:', torch.cuda.get_device_name(0))"
75 | - name: Run single-step model tests
76 | run: |
77 | coverage run -p -m pytest \
78 | ./syntheseus/tests/cli/test_cli.py \
79 | ./syntheseus/tests/reaction_prediction/inference/test_models.py \
80 | ./syntheseus/tests/reaction_prediction/utils/test_parallel.py
81 | coverage report --data-file .coverage.*
82 | - name: Upload coverage report
83 | uses: actions/upload-artifact@v4
84 | with:
85 | name: .coverage.models-${{ matrix.runner }}
86 | path: .coverage.*
87 | include-hidden-files: true
88 | coverage:
89 | needs: [test-core, test-models]
90 | runs-on: ubuntu-latest
91 | steps:
92 | - uses: actions/checkout@v2
93 | - name: Set up Python
94 | uses: actions/setup-python@v2
95 | with:
96 | python-version: 3.9
97 | - name: Install dependencies
98 | run: |
99 | python -m pip install --upgrade pip
100 | pip install coverage
101 | - uses: actions/download-artifact@v4
102 | with:
103 | merge-multiple: true
104 | - name: Generate a combined coverage report
105 | run: |
106 | coverage combine
107 | coverage report
108 |
--------------------------------------------------------------------------------
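For reference, a rough local equivalent of the `test-core` job above, with the commands taken from the workflow steps (a sketch; assumes the conda environment from `environment.yml` is active):

```bash
pip install ".[dev]"
pre-commit run --verbose --all-files
coverage run -p -m pytest ./syntheseus/tests/
coverage combine   # merge the parallel-mode .coverage.* data files
coverage report
```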
/syntheseus/tests/cli/test_search.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | from pathlib import Path
5 | from types import SimpleNamespace
6 | from typing import Sequence
7 |
8 | import pytest
9 | from omegaconf import OmegaConf
10 |
11 | from syntheseus import BackwardReactionModel, Bag, Molecule, SingleProductReaction
12 | from syntheseus.cli.search import SearchConfig, run_from_config
13 | from syntheseus.reaction_prediction.inference.config import BackwardModelClass
14 |
15 |
16 | class FlakyReactionModel(BackwardReactionModel):
17 | """Dummy reaction model that only works when called for the first time."""
18 |
19 | def __init__(self, *args, **kwargs) -> None:
20 | super().__init__()
21 | self._used = False
22 |
23 | def _get_reactions(
24 | self, inputs: list[Molecule], num_results: int
25 | ) -> list[Sequence[SingleProductReaction]]:
26 | if self._used:
27 | raise RuntimeError()
28 |
29 | self._used = True
30 | return [
31 | [
32 | SingleProductReaction(
33 | reactants=Bag([Molecule("C")]), product=product, metadata={"probability": 1.0}
34 | )
35 | ]
36 | for product in inputs
37 | ]
38 |
39 |
40 | def test_resume_search(tmpdir: Path) -> None:
41 | search_targets_file_path = tmpdir / "search_targets.smiles"
42 | with open(search_targets_file_path, "wt") as f_search_targets:
43 | f_search_targets.write("CC\nCC\nCC\nCC\n")
44 |
45 | inventory_file_path = tmpdir / "inventory.smiles"
46 | with open(inventory_file_path, "wt") as f_inventory:
47 | f_inventory.write("C\n")
48 |
49 | # Inject our flaky reaction model into the set of supported model classes.
50 | BackwardModelClass._member_map_["FlakyReactionModel"] = SimpleNamespace( # type: ignore
51 | name="FlakyReactionModel", value=FlakyReactionModel
52 | )
53 |
54 | config = OmegaConf.create( # type: ignore
55 | SearchConfig(
56 | model_class="FlakyReactionModel", # type: ignore[arg-type]
57 | search_algorithm="retro_star", # type: ignore[arg-type]
58 | search_targets_file=str(search_targets_file_path),
59 | inventory_smiles_file=str(inventory_file_path),
60 | results_dir=str(tmpdir),
61 | append_timestamp_to_dir=False,
62 | limit_iterations=1,
63 | num_routes_to_plot=0,
64 | )
65 | )
66 |
67 | results_dir = tmpdir / "FlakyReactionModel"
68 |
69 | def file_exist(idx: int, name: str) -> bool:
70 | return (results_dir / str(idx) / name).exists()
71 |
72 | # Try to run search three times; each time we will succeed in solving one target (which requires one
73 | # call) and then fail on the next one.
74 | for trial_idx in range(3):
75 | with pytest.raises(RuntimeError):
76 | run_from_config(config)
77 |
78 | for idx in range(trial_idx + 1):
79 | assert file_exist(idx, "stats.json")
80 | assert not file_exist(idx, ".lock")
81 |
82 | assert not file_exist(trial_idx + 1, "stats.json")
83 | assert file_exist(trial_idx + 1, ".lock")
84 |
85 | run_from_config(config)
86 |
87 | # The last search needs to solve one final target so it will succeed.
88 | for idx in range(4):
89 | assert file_exist(idx, "stats.json")
90 | assert not file_exist(idx, ".lock")
91 |
92 | with open(results_dir / "stats.json", "rt") as f_stats:
93 | stats = json.load(f_stats)
94 |
95 | # Even though each search only solved a single target, final stats should include everything.
96 | assert stats["num_targets"] == stats["num_solved_targets"] == 4
97 |
98 | # Finally change the targets and verify that the discrepancy will be detected.
99 | with open(search_targets_file_path, "wt") as f_search_targets:
100 | f_search_targets.write("CC\nCCCC\nCC\nCC\n")
101 |
102 | with pytest.raises(RuntimeError):
103 | run_from_config(config)
104 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/inference/toy_models.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Sequence
4 |
5 | from syntheseus.interface.bag import Bag
6 | from syntheseus.interface.models import BackwardReactionModel
7 | from syntheseus.interface.molecule import Molecule
8 | from syntheseus.interface.reaction import SingleProductReaction
9 |
10 |
11 | class ListOfReactionsToyModel(BackwardReactionModel):
12 | """A model which returns reactions from a pre-defined list."""
13 |
14 | def __init__(self, reaction_list: Sequence[SingleProductReaction], **kwargs) -> None:
15 | super().__init__(**kwargs)
16 | self.reaction_list = list(reaction_list)
17 |
18 | def _get_reactions(
19 | self, inputs: list[Molecule], num_results: int
20 | ) -> list[Sequence[SingleProductReaction]]:
21 | return [[r for r in self.reaction_list if r.product == mol] for mol in inputs]
22 |
23 |
24 | class LinearMoleculesToyModel(BackwardReactionModel):
25 | """
26 | A simple toy model of "reactions" on linear "ball-and-stick" molecules,
27 | where the possible reactions involve string cuts and substitutions.
28 |
29 | Molecules in this model must be formed entirely from C, S, and O atoms with single bonds.
30 | The reactions allowed are:
31 | - string cuts, e.g. "CCOC" -> "CC" + "OC" (*see note 2 below)
32 | - substitution of the atom on either end of the molecule: e.g. "CCOC" -> "CCOO"
33 |
34 | NOTE 1: molecules formed by this model are mostly unphysical and the reactions
35 | are not actual chemical reactions. This model is intended for testing and debugging.
36 |
37 | NOTE 2: all molecules are returned with canonical SMILES, so the outputs may not look
38 | the same as a string cut. For example, "CCOC" -> "CC" + "OC" will get canonicalized to
39 | "CC" + "CO" (i.e. the "O" and "C" will swap places). Fundamentally this doesn't change anything
40 | since "CO" and "OC" are the same molecule, but it may be confusing when debugging.
41 | """
42 |
43 | def __init__(self, allow_substitution: bool = True, **kwargs) -> None:
44 | super().__init__(**kwargs)
45 | self._allow_substitution = allow_substitution # should not be modified after init
46 |
47 | def _get_single_backward_reactions(self, mol: Molecule) -> list[SingleProductReaction]:
48 | assert set(mol.smiles) <= set("COS"), "Molecules must be formed out of C, O, and S atoms."
49 | assert len(mol.smiles) > 0, "Molecules must have at least 1 atom."
50 | output: list[SingleProductReaction] = []
51 |
52 | # String cuts
53 | for cut_idx in range(1, len(mol.smiles)):
54 | mol1 = Molecule(mol.smiles[:cut_idx], make_rdkit_mol=False)
55 | mol2 = Molecule(mol.smiles[cut_idx:], make_rdkit_mol=False)
56 | output.append(
57 | SingleProductReaction(
58 | product=mol,
59 | reactants=Bag({mol1, mol2}), # don't include duplicates
60 | metadata={"source": f"string cut at idx {cut_idx}"},
61 | )
62 | )
63 |
64 | # Substitutions
65 | if self._allow_substitution:
66 | for sub_idx in {0, len(mol.smiles) - 1}: # use set in case len(mol.smiles) == 1
67 | for sub_atom in "COS":
68 | if mol.smiles[sub_idx] == sub_atom:
69 | continue
70 | else:
71 | new_mol = Molecule(
72 | mol.smiles[:sub_idx] + sub_atom + mol.smiles[sub_idx + 1 :],
73 | make_rdkit_mol=False,
74 | )
75 | output.append(
76 | SingleProductReaction(
77 | product=mol,
78 | reactants=Bag([new_mol]),
79 | metadata={"source": f"substitution idx {sub_idx} with {sub_atom}"},
80 | ),
81 | )
82 |
83 | return output
84 |
85 | def _get_reactions(
86 | self, inputs: list[Molecule], num_results: int
87 | ) -> list[Sequence[SingleProductReaction]]:
88 | return [self._get_single_backward_reactions(mol) for mol in inputs]
89 |
--------------------------------------------------------------------------------
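A quick sketch of `LinearMoleculesToyModel` on a 3-atom molecule. For brevity this calls the internal `_get_reactions` hook shown above; in normal use the model would be invoked through the standard reaction-model interface:

```python
from syntheseus.interface.molecule import Molecule
from syntheseus.reaction_prediction.inference.toy_models import LinearMoleculesToyModel

# Disable substitutions so only string cuts are returned.
model = LinearMoleculesToyModel(allow_substitution=False)

# "CCO" admits two cuts: "C" + "CO" and "CC" + "O".
[reactions] = model._get_reactions([Molecule("CCO")], num_results=10)
assert [r.reaction_smiles for r in reactions] == ["C.CO>>CCO", "CC.O>>CCO"]
```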
/syntheseus/reaction_prediction/inference/graph2edits.py:
--------------------------------------------------------------------------------
1 | """Inference wrapper for the Graph2Edits model.
2 |
3 | Paper: https://www.nature.com/articles/s41467-023-38851-5
4 | Code: https://github.com/Jamson-Zhong/Graph2Edits
5 |
6 | The original Graph2Edits code is released under the MIT license.
7 | Parts of this file are based on code from the GitHub repository above.
8 | """
9 |
10 | from __future__ import annotations
11 |
12 | import sys
13 | from typing import Sequence
14 |
15 | from rdkit import Chem
16 |
17 | from syntheseus.interface.molecule import Molecule
18 | from syntheseus.interface.reaction import SingleProductReaction
19 | from syntheseus.reaction_prediction.inference.base import ExternalBackwardReactionModel
20 | from syntheseus.reaction_prediction.utils.inference import (
21 | get_module_path,
22 | get_unique_file_in_dir,
23 | process_raw_smiles_outputs_backwards,
24 | )
25 | from syntheseus.reaction_prediction.utils.misc import suppress_outputs, suppress_rdkit_outputs
26 |
27 |
28 | class Graph2EditsModel(ExternalBackwardReactionModel):
29 | def __init__(self, *args, max_edit_steps: int = 9, **kwargs) -> None:
30 | """Initializes the Graph2Edits model wrapper.
31 |
32 | Assumed format of the model directory:
33 | - `model_dir` contains the model checkpoint as the only `*.pt` file
34 | """
35 | super().__init__(*args, **kwargs)
36 |
37 | import graph2edits
38 | import torch
39 |
40 | sys.path.insert(0, str(get_module_path(graph2edits)))
41 |
42 | from graph2edits.models import BeamSearch, Graph2Edits
43 |
44 | checkpoint = torch.load(
45 | get_unique_file_in_dir(self.model_dir, pattern="*.pt"), map_location=self.device
46 | )
47 |
48 | model = Graph2Edits(**checkpoint["saveables"], device=self.device)
49 | model.load_state_dict(checkpoint["state"])
50 | model.to(self.device)
51 | model.eval()
52 |
53 | # We set the beam size to a placeholder value for now and override it in `_get_reactions`.
54 | self.model = BeamSearch(model=model, step_beam_size=10, beam_size=None, use_rxn_class=False)
55 | self._max_edit_steps = max_edit_steps
56 |
57 | def get_parameters(self):
58 | return self.model.model.parameters()
59 |
60 | def _get_reactions(
61 | self, inputs: list[Molecule], num_results: int
62 | ) -> list[Sequence[SingleProductReaction]]:
63 | import torch
64 |
65 | self.model.beam_size = num_results
66 |
67 | batch_predictions = []
68 | for input in inputs:
69 | # Copy the `rdkit` molecule, as we modify it in-place below.
70 | mol = Chem.Mol(input.rdkit_mol)
71 |
72 | # Assign a dummy atom mapping as Graph2Edits depends on it. This has no connection to
73 | # the ground-truth atom mapping, which we do not have access to.
74 | for idx, atom in enumerate(mol.GetAtoms()):
75 | atom.SetAtomMapNum(idx + 1)
76 |
77 | with torch.no_grad(), suppress_outputs(), suppress_rdkit_outputs():
78 | try:
79 | raw_results = self.model.run_search(
80 | prod_smi=Chem.MolToSmiles(mol),
81 | max_steps=self._max_edit_steps,
82 | rxn_class=None,
83 | )
84 | except IndexError:
85 | # This can happen in some rare edge cases (e.g. "OBr").
86 | raw_results = []
87 |
88 | # Errors are returned as the string "final_smi_unmapped"; we filter those out here.
89 | raw_results = [
90 | raw_result
91 | for raw_result in raw_results
92 | if raw_result["final_smi"] != "final_smi_unmapped"
93 | ]
94 |
95 | batch_predictions.append(
96 | process_raw_smiles_outputs_backwards(
97 | input=input,
98 | output_list=[raw_result["final_smi"] for raw_result in raw_results],
99 | metadata_list=[
100 | {"probability": raw_result["prob"]} for raw_result in raw_results
101 | ],
102 | )
103 | )
104 |
105 | return batch_predictions
106 |
--------------------------------------------------------------------------------
/syntheseus/search/graph/node.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import abc
4 | import datetime
5 | import math
6 | from collections.abc import Collection
7 | from dataclasses import dataclass, field
8 | from typing import TypedDict
9 |
10 |
11 | class _NodeData_Time(TypedDict, total=False):
12 | """Holds optional data about the time a node was created."""
13 |
14 | # How many times had the rxn model been called when this node was created?
15 | num_calls_rxn_model: int
16 |
17 | # How many times had the value function been called when this node was created?
18 | num_calls_value_function: int
19 |
20 |
21 | class _NodeData_Algorithms(TypedDict, total=False):
22 | """Holds optional data used by specific algorithms."""
23 |
24 | # ==================================================
25 | # General
26 | # ==================================================
27 | policy_score: float
28 |
29 | # ==================================================
30 | # Retro*
31 | # ==================================================
32 | retro_star_min_cost: float # minimum cost found so far
33 | retro_star_reaction_number: float
34 | reaction_number_estimate: float
35 | retro_star_value: float
36 | retro_star_rxn_cost: float
37 | retro_star_mol_cost: float
38 |
39 | # ==================================================
40 | # MCTS
41 | # ==================================================
42 | mcts_value: float
43 | mcts_prev_reward: float # the most recent reward received
44 |
45 | # ==================================================
46 | # PDVN MCTS
47 | # ==================================================
48 | pdvn_reaction_cost: float
49 | pdvn_mcts_v_syn: float
50 | pdvn_mcts_v_cost: float
51 | pdvn_mcts_prev_reward_syn: float
52 | pdvn_mcts_prev_reward_cost: float
53 | pdvn_min_syn_cost: float
54 |
55 |
56 | class _NodeData_Analysis(TypedDict, total=False):
57 | """Holds optional data used during analysis of search results."""
58 |
59 | analysis_time: float # Used to hold a node's creation time (measured in any way) for analysis purposes
60 | first_solution_time: float # time of first solution (according to analysis_time)
61 | route_cost: float # non-negative cost that this node contributes to the entire route
62 |
63 |
64 | class NodeData(_NodeData_Time, _NodeData_Algorithms, _NodeData_Analysis):
65 | """Holds all kinds of node data."""
66 |
67 | pass
68 |
69 |
70 | @dataclass
71 | class BaseGraphNode(abc.ABC):
72 | # Whether the node is "solved" (has a synthesis route leading to it)
73 | has_solution: bool = False
74 |
75 | # How many times has the node been "visited".
76 | # The meaning of a "visit" will be different for different algorithms.
77 | num_visit: int = 0
78 |
79 | # How "deep" is this node, i.e. the length of the path from the root node to this node.
80 | # It is initialized to inf to indicate "not set" (this is the only value which remains
81 | # stable for graphs without a root node, where depth is ill-defined).
82 | depth: int = math.inf # type: ignore
83 |
84 | # Whether the node has been expanded
85 | is_expanded: bool = False
86 |
87 | # Time when this node was created (used for analysis of search results).
88 | creation_time: datetime.datetime = field(
89 | default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
90 | )
91 |
92 | # Any other node data, stored as a TypedDict to allow arbitrary values to be tracked
93 | # while also allowing type-checking.
94 | data: NodeData = field(default_factory=lambda: NodeData())
95 |
96 | def __eq__(self, other):
97 | # No comparison of node values, only identity.
98 | return self is other
99 |
100 | def __hash__(self):
101 | # Hash nodes based on id:
102 | # this ensures distinct nodes always have a distinct hash.
103 | return id(self)
104 |
105 | @abc.abstractmethod
106 | def _has_intrinsic_solution(self) -> bool:
107 | """Whether this node has a solution without considering its children."""
108 | raise NotImplementedError
109 |
110 | @abc.abstractmethod
111 | def _has_solution_from_children(self, children: Collection[BaseGraphNode]) -> bool:
112 | """Whether this node has a solution, exclusively considering its children."""
113 | raise NotImplementedError
114 |
--------------------------------------------------------------------------------
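A minimal sketch of a concrete node class: the two abstract hooks above are all a subclass must implement (this `ToyNode` is illustrative, not one of the real node classes):

```python
from __future__ import annotations

from collections.abc import Collection
from dataclasses import dataclass

from syntheseus.search.graph.node import BaseGraphNode


# eq=False keeps the identity-based __eq__/__hash__ inherited from BaseGraphNode
# (a dataclass-generated __eq__ would also make instances unhashable).
@dataclass(eq=False)
class ToyNode(BaseGraphNode):
    # Toy criterion: a node is intrinsically solved if it is flagged purchasable.
    purchasable: bool = False

    def _has_intrinsic_solution(self) -> bool:
        return self.purchasable

    def _has_solution_from_children(self, children: Collection[BaseGraphNode]) -> bool:
        # OR-semantics: solved if any child has a solution.
        return any(child.has_solution for child in children)


leaf = ToyNode(purchasable=True)
assert leaf._has_intrinsic_solution()
assert not ToyNode()._has_solution_from_children([leaf])  # leaf.has_solution is still False
```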
/syntheseus/tests/search/analysis/test_diversity.py:
--------------------------------------------------------------------------------
1 | """Test diversity analysis."""
2 | from __future__ import annotations
3 |
4 | import random
5 |
6 | import pytest
7 |
8 | from syntheseus.search.analysis.diversity import (
9 | estimate_packing_number,
10 | molecule_jaccard_distance,
11 | molecule_symmetric_difference_distance,
12 | reaction_jaccard_distance,
13 | reaction_symmetric_difference_distance,
14 | )
15 | from syntheseus.search.graph.route import SynthesisGraph
16 |
17 |
18 | def test_empty_input() -> None:
19 | """Test there is no crash when an empty list of routes is input."""
20 | distinct_routes = estimate_packing_number(
21 | routes=[], radius=0.5, distance_metric=reaction_jaccard_distance, num_tries=100
22 | )
23 | assert len(distinct_routes) == 0
24 |
25 |
26 | @pytest.mark.parametrize(
27 | "metric, threshold, expected_packing_number",
28 | [
29 | # Easy case 1: distance threshold is very large so only 1 route can be returned.
30 | # For Jaccard distance choose a value of 1 (the max value the distance can take)
31 | # For symmetric difference distance choose a value of 20 (more than the max number of mols/rxns in a route)
32 | (molecule_jaccard_distance, 1.0, 1),
33 | (reaction_jaccard_distance, 1.0, 1),
34 | (molecule_symmetric_difference_distance, 20, 1),
35 | (reaction_symmetric_difference_distance, 20, 1),
36 | #
37 | # Easy case 2: distance threshold = 0 so it should just count the number of distinct routes (11)
38 | (molecule_jaccard_distance, 0, 11),
39 | (reaction_jaccard_distance, 0, 11),
40 | (molecule_symmetric_difference_distance, 0, 11),
41 | (reaction_symmetric_difference_distance, 0, 11),
42 | #
43 | # Individual harder cases
44 | (molecule_jaccard_distance, 0.8, 2),
45 | (molecule_jaccard_distance, 0.5, 7),
46 | (molecule_jaccard_distance, 0.2, 9),
47 | (
48 | reaction_jaccard_distance,
49 | 0.99,
50 | 7,
51 | ), # routes with completely non-overlapping reaction sets
52 | (reaction_symmetric_difference_distance, 4, 6),
53 | (molecule_symmetric_difference_distance, 2, 7), # differ in at least 2 molecules
54 | ],
55 | )
56 | def test_estimate_packing_number(
57 | sample_synthesis_routes: list[SynthesisGraph],
58 | metric,
59 | threshold: float,
60 | expected_packing_number: int,
61 | ) -> None:
62 | """
63 | Check that after a large number of trials, the correct packing number is found
64 | for a set of routes.
65 | """
66 |
67 | # Run the packing number estimation
68 | distinct_routes = estimate_packing_number(
69 | routes=sample_synthesis_routes,
70 | radius=threshold,
71 | distance_metric=metric,
72 | num_tries=1000,
73 | random_state=random.Random(100),
74 | )
75 |
76 | # Check that routes returned are all a distance > threshold from each other
77 | for i, route1 in enumerate(distinct_routes):
78 | for route2 in distinct_routes[i + 1 :]:
79 | assert metric(route1, route2) > threshold
80 |
81 | # Check that the correct packing number is found
82 | assert len(distinct_routes) == expected_packing_number
83 |
84 |
85 | @pytest.mark.parametrize(
86 | "metric, threshold, max_packing_number, expected_packing_number",
87 | [
88 | # Easy case 1: distance threshold is very large so only 1 route can be returned
89 | (reaction_jaccard_distance, 1.0, 0, 0), # limit is 0 so no routes can be returned
90 | (reaction_jaccard_distance, 1.0, 10, 1), # limit is higher than actual number
91 | (reaction_jaccard_distance, 1e-3, 5, 5),  # there are 11 such routes but the limit is 5
92 | ],
93 | )
94 | def test_max_packing_number(
95 | sample_synthesis_routes: list[SynthesisGraph],
96 | metric,
97 | threshold: float,
98 | max_packing_number: int,
99 | expected_packing_number: int,
100 | ) -> None:
101 | """Test that max packing number actually limits the packing number."""
102 |
103 | # Run the packing number estimation
104 | distinct_routes = estimate_packing_number(
105 | routes=sample_synthesis_routes,
106 | radius=threshold,
107 | distance_metric=metric,
108 | max_packing_number=max_packing_number,
109 | num_tries=100,
110 | random_state=random.Random(100),
111 | )
112 | assert len(distinct_routes) == expected_packing_number
113 |
--------------------------------------------------------------------------------
/syntheseus/tests/interface/test_molecule.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import FrozenInstanceError, dataclass
4 | from typing import Any, Optional
5 |
6 | import pytest
7 |
8 | from syntheseus.interface.molecule import Molecule
9 |
10 |
11 | def test_positive_equality(cocs_mol: Molecule) -> None:
12 | """Various tests that 2 molecules which should be equal are actually equal."""
13 |
14 | mol_copy = Molecule(smiles=str(cocs_mol.smiles))
15 |
16 | # Test 1: original and copy should be equal
17 | assert cocs_mol == mol_copy
18 |
19 | # Test 2: although equal, they should be distinct objects
20 | assert cocs_mol is not mol_copy
21 |
22 | # Test 3: differences in metadata should not affect equality
23 | # (type ignores are because we are adding arbitrary unrealistic metadata)
24 | cocs_mol.metadata["test"] = "str1" # type: ignore[typeddict-unknown-key]
25 | mol_copy.metadata["test"] = "str2" # type: ignore[typeddict-unknown-key]
26 | mol_copy.metadata["other_field"] = "not in other mol" # type: ignore[typeddict-unknown-key]
27 | assert cocs_mol == mol_copy
28 | assert cocs_mol.metadata != mol_copy.metadata
29 |
30 |
31 | def test_negative_equality(cocs_mol: Molecule) -> None:
32 | """Various tests that molecules which should not be equal are not equal."""
33 |
34 | # Test 1: SMILES strings are not mol objects
35 | assert cocs_mol != cocs_mol.smiles
36 |
37 | # Test 2: changing the identifier makes mols not equal
38 | mol_with_different_id = Molecule(smiles=cocs_mol.smiles, identifier="different")
39 | assert cocs_mol != mol_with_different_id
40 |
41 | # Test 3: equality should only be true if class is the same,
42 | # so another object with the same fields should still be not equal
43 |
44 | @dataclass
45 | class FakeMoleculeClass:
46 | smiles: str
47 | identifier: Optional[str]
48 | metadata: dict[str, Any]
49 |
50 | fake_mol = FakeMoleculeClass(smiles=cocs_mol.smiles, identifier=None, metadata=dict())
51 | assert fake_mol != cocs_mol
52 |
53 | # Test 4: same molecule but with non-canonical SMILES will still compare as not equal
54 | non_canonical_mol = Molecule(smiles="SCOC", canonicalize=False)
55 | assert non_canonical_mol != cocs_mol
56 |
57 |
58 | def test_frozen(cocs_mol: Molecule) -> None:
59 | """Test that the fields of the Molecule cannot be modified (i.e. is actually frozen)."""
60 | with pytest.raises(FrozenInstanceError):
61 | # type ignore is because mypy complains we are modifying a frozen field, which is the point of the test
62 | cocs_mol.smiles = "xyz" # type: ignore[misc]
63 |
64 |
65 | def test_canonicalization() -> None:
66 | """
67 | Test that the `canonicalize` argument works as expected,
68 | canonicalizing the SMILES if True and leaving it unchanged if False.
69 | """
70 | non_canonical_smiles = "OCC"
71 | canonical_smiles = "CCO"
72 |
73 | # Test 1: canonicalize=True should canonicalize the SMILES
74 | mol1 = Molecule(smiles=non_canonical_smiles, canonicalize=True)
75 | assert mol1.smiles == canonical_smiles
76 |
77 | # Test 2: canonicalize=False should leave the SMILES unchanged
78 | mol2 = Molecule(smiles=non_canonical_smiles, canonicalize=False)
79 | assert mol2.smiles == non_canonical_smiles
80 |
81 |
82 | def test_make_rdkit_mol() -> None:
83 | """Test that the argument `make_rdkit_mol` works as expected."""
84 |
85 | # Test 1: make_rdkit_mol=True
86 | smiles = "CCO"
87 | mol_with_rdkit_mol = Molecule(smiles=smiles, make_rdkit_mol=True)
88 | assert "rdkit_mol" in mol_with_rdkit_mol.metadata
89 |
90 | # Test 2: make_rdkit_mol=False
91 | mol_without_rdkit_mol = Molecule(smiles=smiles, make_rdkit_mol=False)
92 | assert "rdkit_mol" not in mol_without_rdkit_mol.metadata
93 |
94 | # Test 3: accessing the rdkit mol should create it
95 | mol_without_rdkit_mol.rdkit_mol
96 | assert "rdkit_mol" in mol_without_rdkit_mol.metadata
97 |
98 |
99 | def test_sorting() -> None:
100 | """Test that sorting molecules works as expected: by SMILES, then by identifier."""
101 |
102 | # Make individual molecules
103 | mol1 = Molecule("CC")
104 | mol2 = Molecule("CCC", identifier="")
105 | mol3 = Molecule("CCC", identifier="abc")
106 | mol4 = Molecule("CCC", identifier="def")
107 |
108 | # Test sorting
109 | mol_list = [mol4, mol3, mol2, mol1]
110 | mol_list.sort()
111 | assert mol_list == [mol1, mol2, mol3, mol4]
112 |
113 |
114 | def test_create_and_compare() -> None:
115 | # NOTE: this test has some redundancy with tests from above
116 | mol_1 = Molecule("C")
117 | mol_2 = Molecule("C1=CC(N)=CC=C1")
118 | mol_3 = Molecule("c1cccc(N)c1")
119 |
120 | assert mol_1 < mol_2 # Lexicographical comparison on SMILES.
121 | assert mol_2 == mol_3 # Should be equal after canonicalization.
122 |
123 |
124 | def test_order_of_components() -> None:
125 | assert Molecule("C.CC") == Molecule("CC.C")
126 |
127 |
128 | def test_create_invalid() -> None:
129 | with pytest.raises(ValueError):
130 | Molecule("not-a-real-SMILES")
131 |
--------------------------------------------------------------------------------
/syntheseus/tests/search/graph/test_message_passing.py:
--------------------------------------------------------------------------------
1 | """
2 | Message passing is used in other algorithms and is tested implicitly in the algorithm tests.
3 |
4 | Therefore this file just contains minimal tests of correctness and edge cases.
5 | """
6 |
7 | from __future__ import annotations
8 |
9 | import pytest
10 |
11 | from syntheseus.search.graph.and_or import AndOrGraph
12 | from syntheseus.search.graph.message_passing import (
13 | depth_update,
14 | has_solution_update,
15 | run_message_passing,
16 | )
17 |
18 |
19 | def test_no_update_functions(andor_tree_non_minimal: AndOrGraph) -> None:
20 | """
21 | Test that if no update functions are provided, the message passing algorithm
22 | doesn't actually run.
23 | """
24 | g = andor_tree_non_minimal # rename for brevity
25 | output = run_message_passing(g, g.nodes(), update_fns=[], max_iterations=1)
26 | assert len(output) == 0
27 |
28 |
29 | def test_no_input_nodes(andor_tree_non_minimal: AndOrGraph) -> None:
30 | """
31 | If no input nodes are provided, the message passing algorithm should
32 | terminate without running.
33 | """
34 | g = andor_tree_non_minimal # rename for brevity
35 | output = run_message_passing(g, [], update_fns=[has_solution_update], max_iterations=1)
36 | assert len(output) == 0
37 |
38 |
39 | @pytest.mark.parametrize("update_successors", [True, False])
40 | def test_update_successors(andor_tree_non_minimal: AndOrGraph, update_successors: bool) -> None:
41 | """
42 | Test that the "update successors" function works as expected by
43 | setting the "has_solution" attribute of the root node to False.
44 |
45 | If update_successors=False then message passing should terminate after 1 iteration.
46 | However, if update_successors=True then message passing should terminate after visiting
47 | the root node and its children (3 iterations). In both cases only 1 node should be updated.
48 | """
49 | g = andor_tree_non_minimal # rename for brevity
50 | if update_successors:
51 | enough_iterations = 3
52 | else:
53 | enough_iterations = 1
54 | too_few_iterations = enough_iterations - 1
55 |
56 | # Test 1: in both cases, root node should be updated to has_solution=True
57 | # and that should be the only node updated
58 | g.root_node.has_solution = False
59 | output = run_message_passing(
60 | g,
61 | [g.root_node],
62 | update_fns=[has_solution_update],
63 | update_successors=update_successors,
64 | max_iterations=enough_iterations,
65 | ) # should run without error
66 | assert g.root_node.has_solution
67 | assert len(output) == 1
68 |
69 | # Test 2: should raise error if too few iterations
70 | g.root_node.has_solution = False
71 | with pytest.raises(RuntimeError):
72 | run_message_passing(
73 | g,
74 | [g.root_node],
75 | update_fns=[has_solution_update],
76 | update_successors=update_successors,
77 | max_iterations=too_few_iterations,
78 | )
79 |
80 |
81 | @pytest.mark.parametrize("update_predecessors", [True, False])
82 | def test_update_predecessors(andor_tree_non_minimal: AndOrGraph, update_predecessors: bool) -> None:
83 | """Similar to above but for predecessors, updating the depth of a node."""
84 |
85 | g = andor_tree_non_minimal # rename for brevity
86 | node_to_perturb = list(g.successors(g.root_node))[0]
87 | if update_predecessors: # NOTE: only enough if not updating successors
88 | too_few_iterations = 1
89 | enough_iterations = 2
90 | else:
91 | too_few_iterations = 0
92 | enough_iterations = 1
93 |
94 | # Test 1: set depth of child node incorrectly
95 | # and check that it is updated correctly.
96 | node_to_perturb.depth = 500
97 | output = run_message_passing(
98 | g,
99 | [node_to_perturb],
100 | update_fns=[depth_update],
101 | update_predecessors=update_predecessors,
102 | max_iterations=enough_iterations,
103 | update_successors=False,
104 | ) # should run without error
105 | assert node_to_perturb.depth == 1
106 | assert len(output) == 1
107 |
108 | # Test 2: should raise error if too few iterations
109 | node_to_perturb.depth = 500
110 | with pytest.raises(RuntimeError):
111 | run_message_passing(
112 | g,
113 | [node_to_perturb],
114 | update_fns=[depth_update],
115 | update_predecessors=update_predecessors,
116 | max_iterations=too_few_iterations,
117 | update_successors=False,
118 | )
119 |
120 |
121 | def update_fn_which_can_never_converge(*args, **kwargs) -> bool:
122 | return True
123 |
124 |
125 | def test_no_convergence(andor_tree_non_minimal: AndOrGraph) -> None:
126 | """Test that if no convergence is reached, an error is raised."""
127 | g = andor_tree_non_minimal # rename for brevity
128 | with pytest.raises(RuntimeError):
129 | run_message_passing(
130 | g, g.nodes(), update_fns=[update_fn_which_can_never_converge], max_iterations=1_000
131 | )
132 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/utils/misc.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import multiprocessing
3 | import os
4 | import random
5 | import warnings
6 | from contextlib import contextmanager, redirect_stderr, redirect_stdout
7 | from dataclasses import fields, is_dataclass
8 | from itertools import islice
9 | from os import devnull
10 | from typing import Any, Dict, Iterable, Iterator, List, Optional
11 |
12 | import numpy as np
13 | from rdkit import rdBase
14 |
15 | from syntheseus.interface.bag import Bag
16 | from syntheseus.interface.molecule import Molecule
17 |
18 |
19 | def set_random_seed(seed: int = 0) -> None:
20 | """Set random seed for `Python`, `torch` and `numpy`."""
21 | random.seed(seed)
22 | np.random.seed(seed)
23 |
24 | # If `torch` is installed set its seed as well.
25 | try:
26 | import torch
27 |
28 | torch.manual_seed(seed)
29 | except ModuleNotFoundError:
30 | pass
31 |
32 |
33 | @contextmanager
34 | def suppress_outputs():
35 | """Suppress messages written to both stdout and stderr."""
36 | with open(devnull, "w") as fnull:
37 | with redirect_stderr(fnull), redirect_stdout(fnull):
38 | # Save the root logger handlers in order to restore them afterwards.
39 | root_handlers = list(logging.root.handlers)
40 | logging.disable(logging.CRITICAL)
41 |
42 | yield
43 |
44 | logging.root.handlers = root_handlers
45 | logging.disable(logging.NOTSET)
46 |
47 |
48 | @contextmanager
49 | def suppress_rdkit_outputs():
50 | """Suppress warning messages produced by `rdkit`."""
51 | try:
52 | previous_settings = dict(line.split(":") for line in rdBase.LogStatus().split("\n"))
53 | except Exception:
54 | # If `rdkit` internals change in the future we give up on restoring previous settings
55 | warnings.warn("Could not read rdkit log settings, warnings will be silenced permanently")
56 | previous_settings = {}
57 |
58 | rdBase.DisableLog("rdApp.*")
59 |
60 | yield
61 |
62 | for level, setting in previous_settings.items():
63 | if setting == "enabled":
64 | rdBase.EnableLog(level)
65 |
66 |
67 | def dictify(data: Any) -> Any:
68 | # Ensure that returned objects are fully serializable
69 | if isinstance(data, (int, float, str)) or data is None:
70 | return data
71 | elif isinstance(data, Molecule):
72 | return {"smiles": data.smiles}
73 | elif isinstance(data, (List, tuple, Bag)):
74 | # Captures lists of `Prediction`s
75 | return [dictify(x) for x in data]
76 | elif isinstance(data, dict):
77 | return {k: dictify(v) for k, v in data.items()}
78 | elif is_dataclass(data):
79 | result = {}
80 | for f in fields(data):
81 | value = getattr(data, f.name)
82 | result[f.name] = dictify(value)
83 | return result
84 | else:
85 | raise TypeError(f"Type {type(data)} cannot be handled by `dictify`")
86 |
87 |
88 | def asdict_extended(data) -> Dict[str, Any]:
89 | """Convert a dataclass containing various reaction-related objects into a dict."""
90 | if not is_dataclass(data):
91 | raise TypeError(f"asdict_extended only for use on dataclasses, input is type {type(data)}")
92 |
93 | return dictify(data)
94 |
95 |
96 | def undictify_bag_of_molecules(data: List[Dict[str, str]]) -> Bag[Molecule]:
97 | """Recovers a bag of molecules serialized with `dictify`."""
98 | return Bag(Molecule(d["smiles"]) for d in data)
99 |
100 |
101 | def parallelize(
102 | fn,
103 | inputs: Iterable,
104 | num_processes: int = 0,
105 | chunksize: int = 32,
106 | num_chunks_per_process_per_segment: Optional[int] = 64,
107 | ) -> Iterator:
108 | """Parallelize an appliation of an arbitrary function using a pool of processes."""
109 | if num_processes == 0:
110 | yield from map(fn, inputs)
111 | else:
112 | # Needed for the chunking code to work on repeatable iterables e.g. lists.
113 | inputs = iter(inputs)
114 |
115 | with multiprocessing.Pool(num_processes) as pool:
116 | if num_chunks_per_process_per_segment is None:
117 | yield from pool.imap(fn, inputs, chunksize=chunksize)
118 | else:
119 | # A new segment will only be started if the previous one was consumed; this avoids doing
120 | # all the work upfront and storing it in memory if the consumer of the output is slow.
121 | segmentsize = num_chunks_per_process_per_segment * num_processes * chunksize
122 |
123 | non_empty = True
124 | while non_empty:
125 | non_empty = False
126 |
127 | # Call `imap` segment-by-segment to make sure the consumer of the output keeps up.
128 | for result in pool.imap(fn, islice(inputs, segmentsize), chunksize=chunksize):
129 | yield result
130 | non_empty = True
131 |
132 |
133 | def cpu_count(default: int = 8) -> int:
134 | """Return the number of CPUs, fallback to `default` if it cannot be determined."""
135 | return os.cpu_count() or default
136 |
--------------------------------------------------------------------------------
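The segmented `imap` in `parallelize` above is the subtle part of this file, so here is a hedged usage sketch (only the imports are real syntheseus names; `square` is a hypothetical example function):

```python
# Usage sketch for the helpers in `misc.py` above. Assumes syntheseus is
# installed; `square` is a made-up example function.
from syntheseus.interface.bag import Bag
from syntheseus.interface.molecule import Molecule
from syntheseus.reaction_prediction.utils.misc import (
    dictify,
    parallelize,
    undictify_bag_of_molecules,
)


def square(x: int) -> int:  # top-level so worker processes can pickle it
    return x * x


if __name__ == "__main__":
    # `num_processes=0` degenerates to a plain `map` (handy for debugging).
    # With workers, inputs are consumed one segment at a time
    # (num_chunks_per_process_per_segment * num_processes * chunksize items),
    # so a slow consumer never forces all results to be buffered in memory.
    assert list(parallelize(square, range(5), num_processes=2)) == [0, 1, 4, 9, 16]

    # `dictify` turns molecule bags into plain JSON-serializable lists/dicts;
    # `undictify_bag_of_molecules` reverses the transformation.
    bag = Bag([Molecule("CC"), Molecule("CCO")])
    assert undictify_bag_of_molecules(dictify(bag)) == bag
```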
/syntheseus/reaction_prediction/inference/retro_knn.py:
--------------------------------------------------------------------------------
1 | """Inference wrapper for the RetroKNN model.
2 |
3 | Paper: https://arxiv.org/abs/2306.04123
4 | """
5 |
6 | from pathlib import Path
7 | from typing import List, Optional, Sequence, Union
8 |
9 | import numpy as np
10 |
11 | from syntheseus.interface.molecule import Molecule
12 | from syntheseus.interface.reaction import SingleProductReaction
13 | from syntheseus.reaction_prediction.inference.local_retro import LocalRetroModel
14 | from syntheseus.reaction_prediction.utils.inference import get_unique_file_in_dir
15 |
16 |
17 | class RetroKNNModel(LocalRetroModel):
18 | """Warpper for RetroKNN model."""
19 |
20 | def __init__(self, model_dir: Optional[Union[str, Path]] = None, *args, **kwargs) -> None:
21 | """Initializes the RetroKNN model wrapper.
22 |
23 | Assumed format of the model directory:
24 | - `model_dir/local_retro` contains the files needed to load the LocalRetro wrapper
25 | - `model_dir/knn/` contains the adapter checkpoint as the only `*.pt` file
26 | - `model_dir/knn/datastore` contains the data store files
27 | """
28 | model_dir = Path(model_dir or self.get_default_model_dir())
29 | super().__init__(model_dir / "local_retro", *args, **kwargs)
30 |
31 | import torch
32 |
33 | from syntheseus.reaction_prediction.models.retro_knn import Adapter
34 |
35 | adapter_chkpt_path = get_unique_file_in_dir(Path(model_dir) / "knn", pattern="*.pt")
36 | datastore_path = Path(model_dir) / "knn" / "datastore"
37 |
38 | import faiss
39 | import faiss.contrib.torch_utils # make faiss available for torch tensors
40 |
41 | def load_data_store(path: Path, device: str):
42 | index = faiss.read_index(str(path), faiss.IO_FLAG_ONDISK_SAME_DIR)
43 |
44 | if device == "cpu":
45 | return index
46 | else:
47 | res = faiss.StandardGpuResources()
48 | co = faiss.GpuClonerOptions()
49 | co.useFloat16 = True
50 | return faiss.index_cpu_to_gpu(res, 0, index, co)
51 |
52 | self.atom_store = load_data_store(datastore_path / "data.atom_idx", device=self.device)
53 | self.bond_store = load_data_store(datastore_path / "data.bond_idx", device=self.device)
54 | self.raw_data = np.load(datastore_path / "data.npz")
55 |
56 | self.adapter = Adapter(self.model.linearB.weight.shape[0], k=32).to(self.device)
57 | self.adapter.load_state_dict(torch.load(adapter_chkpt_path, map_location=self.device))
58 | self.adapter.eval()
59 |
60 | def _forward_localretro(self, bg):
61 | from local_retro.scripts.model_utils import pair_atom_feats, unbatch_feats, unbatch_mask
62 |
63 | bg = bg.to(self.device)
64 | node_feats = bg.ndata.pop("h").to(self.device)
65 | edge_feats = bg.edata.pop("e").to(self.device)
66 |
67 | node_feats = self.model.mpnn(bg, node_feats, edge_feats)
68 | atom_feats = node_feats
69 | bond_feats = self.model.linearB(pair_atom_feats(bg, node_feats))
70 | edit_feats, mask = unbatch_mask(bg, atom_feats, bond_feats)
71 | _, edit_feats = self.model.att(edit_feats, mask)
72 |
73 | atom_feats, bond_feats = unbatch_feats(bg, edit_feats)
74 | atom_outs = self.model.atom_linear(atom_feats)
75 | bond_outs = self.model.bond_linear(bond_feats)
76 |
77 | return atom_outs, bond_outs, atom_feats, bond_feats
78 |
79 | def _get_reactions(
80 | self, inputs: List[Molecule], num_results: int
81 | ) -> List[Sequence[SingleProductReaction]]:
82 | import torch
83 |
84 | from syntheseus.reaction_prediction.models.retro_knn import knn_prob
85 |
86 | batch = self._mols_to_batch(inputs)
87 | (
88 | batch_atom_logits,
89 | batch_bond_logits,
90 | atom_feats,
91 | bond_feats,
92 | ) = self._forward_localretro(batch)
93 | sg = batch.remove_self_loop().to(self.device)
94 |
95 | node_dis, _ = self.atom_store.search(atom_feats, k=32)
96 | edge_dis, _ = self.bond_store.search(bond_feats, k=32)
97 |
98 | node_t, node_p, edge_t, edge_p = self.adapter(
99 | sg, atom_feats, bond_feats, node_dis, edge_dis
100 | )
101 |
102 | batch_atom_prob_nn = torch.nn.Softmax(dim=1)(batch_atom_logits)
103 | batch_bond_prob_nn = torch.nn.Softmax(dim=1)(batch_bond_logits)
104 |
105 | atom_output_label = torch.from_numpy(self.raw_data["atom_output_label"]).to(self.device)
106 | bond_output_label = torch.from_numpy(self.raw_data["bond_output_label"]).to(self.device)
107 |
108 | batch_atom_prob_knn = knn_prob(
109 | atom_feats, self.atom_store, atom_output_label, batch_atom_logits.shape[1], 32, node_t
110 | )
111 | batch_bond_prob_knn = knn_prob(
112 | bond_feats, self.bond_store, bond_output_label, batch_bond_logits.shape[1], 32, edge_t
113 | )
114 |
115 | batch_atom_logits = node_p * batch_atom_prob_nn + (1 - node_p) * batch_atom_prob_knn
116 | batch_bond_logits = edge_p * batch_bond_prob_nn + (1 - edge_p) * batch_bond_prob_knn
117 |
118 | return self._build_batch_predictions(
119 | batch, num_results, inputs, batch_atom_logits, batch_bond_logits
120 | )
121 |
--------------------------------------------------------------------------------
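The core of `_get_reactions` above is a learned interpolation: per-atom and per-bond gates (`node_p`, `edge_p`) blend the parametric LocalRetro distribution with the retrieval-based kNN distribution. Below is a minimal sketch of just that blending step; the tensors and shapes are random stand-ins, not real model outputs:

```python
# Sketch of the NN/kNN probability interpolation (stand-in tensors only).
import torch

num_atoms, num_templates = 5, 7
p_nn = torch.softmax(torch.randn(num_atoms, num_templates), dim=1)   # parametric model
p_knn = torch.softmax(torch.randn(num_atoms, num_templates), dim=1)  # datastore retrieval
gate = torch.rand(num_atoms, 1)  # per-atom trust weight (the adapter's `node_p`)

# Convex combination: for each atom, decide how much to trust the neural
# network versus the kNN datastore. Rows still sum to one.
p_final = gate * p_nn + (1 - gate) * p_knn
assert torch.allclose(p_final.sum(dim=1), torch.ones(num_atoms))
```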
/syntheseus/search/graph/base_graph.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import abc
4 | from collections.abc import Container, Iterable, Sized
5 | from typing import Generic, Sequence, TypeVar
6 |
7 | import networkx as nx
8 |
9 | from syntheseus.interface.molecule import Molecule
10 | from syntheseus.interface.reaction import SingleProductReaction
11 | from syntheseus.search.graph.node import BaseGraphNode
12 |
13 | NodeType = TypeVar(
14 | "NodeType",
15 | )
16 | SearchNodeType = TypeVar("SearchNodeType", bound=BaseGraphNode)
17 |
18 |
19 | class BaseReactionGraph(Container, Sized, Generic[NodeType], abc.ABC):
20 | """
21 | Base class for holding a retrosynthesis graph where nodes represent molecules/reactions.
22 | Retrosynthesis graphs have the following properties:
23 |
24 | - Directed: A node usually represents a reaction/molecule used to synthesize its predecessors.
25 | Search generally goes from predecessors to successors,
26 | while the actual synthesis would go from successors to predecessors.
27 | - Root node: there is a root node representing the molecule to be synthesized.
28 | This node should never have a parent.
29 | - Implicit: the graph is a subgraph of a much larger reaction graph,
30 | so nodes can be "unexpanded", meaning that no children nodes are specified,
31 | even though the node should have children. Search typically "expands"
32 | a node by adding reactions.
33 |
34 | The actual implementation of the graph uses networkx's DiGraph
35 | (but it does not inherit from it because this causes many methods to break).
36 | """
37 |
38 | def __init__(self, *args, **kwargs) -> None:
39 | self._graph = nx.DiGraph(*args, **kwargs)
40 |
41 | def __contains__(self, node: object) -> bool:
42 | return node in self._graph
43 |
44 | def __len__(self) -> int:
45 | return len(self._graph)
46 |
47 | @property
48 | @abc.abstractmethod
49 | def root_node(self) -> NodeType:
50 | """Root node of the graph, representing the molecule to be synthesized."""
51 | pass
52 |
53 | @property
54 | @abc.abstractmethod
55 | def root_mol(self) -> Molecule:
56 | """The molecule to be synthesized."""
57 | pass
58 |
59 | @abc.abstractmethod
60 | def is_minimal(self) -> bool:
61 | """Checks whether this is a *minimal* graph (i.e. contains a single synthesis route)."""
62 | pass
63 |
64 | def is_tree(self) -> bool:
65 | """Performs a [possibly expensive] check to see if the graph is a tree."""
66 | return nx.is_arborescence(self._graph)
67 |
68 | def nodes(self) -> Iterable[NodeType]:
69 | return self._graph.nodes()
70 |
71 | def predecessors(self, node: NodeType) -> Iterable[NodeType]:
72 | """Returns the predecessors of a node."""
73 | return self._graph.predecessors(node)
74 |
75 | def successors(self, node: NodeType) -> Iterable[NodeType]:
76 | """Returns the successors of a node."""
77 | return self._graph.successors(node)
78 |
79 | def assert_validity(self) -> None:
80 | """
81 | A (potentially expensive) function to check the graph's validity.
82 | """
83 |
84 | # Check root node is in the graph
85 | assert self.root_node in self
86 |
87 | # Check root node has no parents
88 | assert len(list(self.predecessors(self.root_node))) == 0
89 |
90 | # Graph should be connected
91 | assert nx.is_weakly_connected(self._graph)
92 |
93 | def __eq__(self, __value: object) -> bool:
94 | """Equality is defined as having the same root node, nodes, and edges."""
95 | if isinstance(__value, BaseReactionGraph):
96 | return (self.root_node == __value.root_node) and nx.utils.graphs_equal(
97 | self._graph, __value._graph
98 | )
99 | else:
100 | return False
101 |
102 | @abc.abstractmethod
103 | def expand_with_reactions(
104 | self,
105 | reactions: list[SingleProductReaction],
106 | node: NodeType,
107 | ensure_tree: bool,
108 | ) -> Sequence[NodeType]:
109 | """
110 | Expands a node with a series of reactions.
111 | For reproducibility, it ensures that the order of the nodes is consistent,
112 | which is why a sequence is returned.
113 |         It is also encouraged (but not required) for all returned nodes to be unique.
114 |
115 | Subclass implementations should ensure that the root node never has any predecessors
116 | as a result of this function, and if ensure_tree=True, should also ensure that the
117 | graph remains a tree.
118 | """
119 | pass
120 |
121 |
122 | class RetrosynthesisSearchGraph(BaseReactionGraph[SearchNodeType], Generic[SearchNodeType]):
123 | """Subclass with more specific type requirements for the nodes."""
124 |
125 | @abc.abstractmethod
126 | def _assert_valid_reactions(self) -> None:
127 | """Checks that all reactions are valid."""
128 | pass
129 |
130 | def assert_validity(self) -> None:
131 | super().assert_validity()
132 |
133 | # Check that all nodes with children are marked as expanded
134 | for n in self._graph.nodes:
135 | if not n.is_expanded:
136 | assert len(list(self.successors(n))) == 0
137 |
138 | # Check valid reactions
139 | self._assert_valid_reactions()
140 |
--------------------------------------------------------------------------------
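To make the abstract interface above concrete, here is a hypothetical toy subclass that uses bare molecules as nodes (a sketch only; the real implementations are the AND/OR and MolSet graphs elsewhere in `syntheseus.search.graph`):

```python
# Hypothetical minimal subclass of `BaseReactionGraph`; illustrative only.
from __future__ import annotations

from typing import Sequence

from syntheseus.interface.molecule import Molecule
from syntheseus.interface.reaction import SingleProductReaction
from syntheseus.search.graph.base_graph import BaseReactionGraph


class ToyMoleculeGraph(BaseReactionGraph[Molecule]):
    def __init__(self, root: Molecule) -> None:
        super().__init__()
        self._root = root
        self._graph.add_node(root)

    @property
    def root_node(self) -> Molecule:
        return self._root

    @property
    def root_mol(self) -> Molecule:
        return self._root

    def is_minimal(self) -> bool:
        # Toy criterion: minimal iff no molecule has more than one child.
        return all(len(list(self.successors(n))) <= 1 for n in self.nodes())

    def expand_with_reactions(
        self,
        reactions: list[SingleProductReaction],
        node: Molecule,
        ensure_tree: bool,
    ) -> Sequence[Molecule]:
        # (`ensure_tree` is ignored in this toy sketch.) Add one edge per
        # (node, reactant) pair; `nx.DiGraph.add_edge` creates the reactant
        # node if it is not yet in the graph.
        new_nodes = []
        for rxn in reactions:
            assert rxn.product == node  # reactions must produce the expanded molecule
            for reactant in rxn.reactants:
                assert reactant != self.root_mol  # the root must stay parentless
                self._graph.add_edge(node, reactant, reaction=rxn)
                new_nodes.append(reactant)
        return new_nodes
```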
/README.md:
--------------------------------------------------------------------------------
1 |
2 | 
3 | Navigating the labyrinth of synthesis planning
4 |
5 | ---
6 |
7 |
8 | Docs •
9 | CLI •
10 | Tutorials •
11 | Paper
12 |
13 |
14 | [](https://github.com/microsoft/syntheseus/actions/workflows/ci.yml)
15 | [](https://www.python.org/downloads/)
16 | [](https://pypi.org/project/syntheseus/)
17 | [](https://github.com/ambv/black)
18 | [](https://github.com/microsoft/syntheseus/blob/main/LICENSE)
19 |
20 |
21 |
22 | ## Overview
23 |
24 | Syntheseus is a package for end-to-end retrosynthetic planning.
25 | - ⚒️ Combines search algorithms and reaction models in a standardized way
26 | - 🧭 Includes implementations of common search algorithms
27 | - 🧪 Includes wrappers for state-of-the-art reaction models
28 | - ⚙️ Exposes a simple API to plug in custom models and algorithms
29 | - 📈 Can be used to benchmark components of a retrosynthesis pipeline
30 |
31 | ## Quick Start
32 |
33 | To install `syntheseus` with all the extras, run
34 |
35 | ```bash
36 | conda env create -f environment_full.yml
37 | conda activate syntheseus-full
38 |
39 | pip install "syntheseus[all]"
40 | ```
41 |
42 | This sample environment pins `torch` to version `2.2.2`. To run the models under `torch` `1.x`, please downgrade to `syntheseus 0.6.0`.
43 |
44 | See [here](https://microsoft.github.io/syntheseus/stable/installation) if you prefer a more lightweight installation that only includes the parts you actually need.
45 |
46 | ## Citation and usage
47 |
48 | Since the release of our package, we've been thrilled to see syntheseus used in the following projects:
49 |
50 | | **Project** | **Usage** | **Reference(s)** |
51 | |:--------------|:-----|:-----------|
52 | |Retro-fallback search|Multi-step search|ICLR [paper](https://arxiv.org/abs/2310.09270), [code](https://github.com/AustinT/retro-fallback-iclr24)|
53 | |RetroGFN|Pre-packaged single-step models|arXiv [paper](https://arxiv.org/abs/2406.18739), [code](https://github.com/gmum/RetroGFN)|
54 | |TANGO|Single-step and multi-step|arXiv [paper](https://arxiv.org/abs/2410.11527)|
55 | |SimpRetro|Multi-step search|JCIM [paper](https://pubs.acs.org/doi/10.1021/acs.jcim.4c00432), [code](https://github.com/catalystforyou/SimpRetro)|
56 |
57 | If you use syntheseus in an academic project, please consider citing our
58 | [associated paper from Faraday Discussions](https://pubs.rsc.org/en/content/articlelanding/2024/fd/d4fd00093e)
59 | (bibtex below). You can also message us or submit a PR to have your project added to the table above!
60 |
61 | ```
62 | @article{maziarz2024re,
63 | title={Re-evaluating retrosynthesis algorithms with syntheseus},
64 | author={Maziarz, Krzysztof and Tripp, Austin and Liu, Guoqing and Stanley, Megan and Xie, Shufang and Gainski, Piotr and Seidl, Philipp and Segler, Marwin},
65 | journal={Faraday Discussions},
66 | year={2024},
67 | publisher={Royal Society of Chemistry}
68 | }
69 | ```
70 |
71 | ## Development
72 |
73 | Syntheseus is currently under active development.
74 | If you want to help us develop syntheseus, please install and run `pre-commit`
75 | checks before committing code.
76 |
77 | We use `pytest` for testing. Please make sure tests pass on your branch before
78 | submitting a PR (and try to maintain high test coverage).
79 |
80 | ```bash
81 | python -m pytest --cov syntheseus/tests
82 | ```
83 |
84 | ## Contributing
85 |
86 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
87 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
88 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
89 |
90 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
91 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
92 | provided by the bot. You will only need to do this once across all repos using our CLA.
93 |
94 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
95 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
96 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
97 |
98 | ## Trademarks
99 |
100 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
101 | trademarks or logos is subject to and must follow
102 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
103 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
104 | Any use of third-party trademarks or logos is subject to those third parties' policies.
105 |
--------------------------------------------------------------------------------
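As a first taste of the core data classes mentioned in the README's overview, here is a minimal sketch (the SMILES are the biaryl coupling example used in the test suite; see the docs for wiring up actual models and search algorithms):

```python
# Minimal sketch of the core interface classes; no models are required.
from syntheseus.interface.bag import Bag
from syntheseus.interface.molecule import Molecule
from syntheseus.interface.reaction import SingleProductReaction

product = Molecule("Cc1ccc(-c2ccc(C)cc2)cc1")
reactants = Bag([Molecule("Cc1ccc(Br)cc1"), Molecule("Cc1ccc(B(O)O)cc1")])

# A backward (retrosynthetic) step: the product is made from the reactant bag.
reaction = SingleProductReaction(reactants=reactants, product=product)
print(reaction.product.smiles)
```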
/syntheseus/tests/search/graph/test_molset.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for MolSet graph/nodes.
3 |
4 | Because a lot of the behaviour is implicitly tested when the algorithms are tested,
5 | the tests here are sparse and mainly check edge cases which won't come up in algorithms.
6 | """
7 | import pytest
8 |
9 | from syntheseus.interface.bag import Bag
10 | from syntheseus.interface.molecule import Molecule
11 | from syntheseus.interface.reaction import SingleProductReaction
12 | from syntheseus.search.graph.molset import MolSetGraph, MolSetNode
13 | from syntheseus.search.graph.route import SynthesisGraph
14 | from syntheseus.tests.search.graph.test_base import BaseNodeTest
15 |
16 |
17 | class TestMolSetNode(BaseNodeTest):
18 | def get_node(self):
19 | return MolSetNode(mols=Bag([Molecule("CC")]))
20 |
21 | def test_nodes_not_frozen(self):
22 | node = self.get_node()
23 | node.mols = None
24 |
25 |
26 | class TestMolSetGraph:
27 | """Tests for MolSetGraph: they mostly follow the same pattern as AND/OR graph tests."""
28 |
29 | def test_basic_properties1(
30 | self, cocs_mol: Molecule, molset_tree_non_minimal: MolSetGraph
31 | ) -> None:
32 | """Test some basic properties (len, contains) on a graph."""
33 | assert len(molset_tree_non_minimal) == 10
34 | assert molset_tree_non_minimal.root_node in molset_tree_non_minimal
35 | assert molset_tree_non_minimal.root_mol == cocs_mol
36 |
37 | def _check_not_minimal(self, graph: MolSetGraph) -> None:
38 | """Checks that a graph is not minimal (eliminate code duplication between tests below)."""
39 | assert not graph.is_minimal()
40 | assert graph.is_tree() # should always be a tree
41 | with pytest.raises(AssertionError):
42 | graph.to_synthesis_graph()
43 |
44 | def test_minimal_negative(self, molset_tree_non_minimal: MolSetGraph) -> None:
45 | """Test that a non-minimal molset graph is not identified as minimal."""
46 | self._check_not_minimal(molset_tree_non_minimal)
47 |
48 | def test_minimal_negative_hard(self, molset_tree_almost_minimal: MolSetGraph) -> None:
49 | """Harder test for minimal negative: the 'almost minimal' graph."""
50 | self._check_not_minimal(molset_tree_almost_minimal)
51 |
52 | def test_minimal_positive(
53 | self, molset_tree_minimal: MolSetGraph, minimal_synthesis_graph: SynthesisGraph
54 | ) -> None:
55 | graph = molset_tree_minimal
56 | assert graph.is_minimal()
57 | assert graph.is_tree() # should always be a tree
58 | route = graph.to_synthesis_graph() # should run without error
59 | assert route == minimal_synthesis_graph
60 |
61 | @pytest.mark.parametrize(
62 | "reason", ["root_has_parent", "unexpanded_expanded", "reactions_dont_match"]
63 | )
64 | def test_assert_validity_negative(
65 | self, molset_tree_non_minimal: MolSetGraph, reason: str
66 | ) -> None:
67 | """
68 | Test that an invalid MolSet graph is correctly identified as invalid.
69 |
70 | Different reasons for invalidity are tested.
71 | """
72 | graph = molset_tree_non_minimal
73 |
74 | if reason == "root_has_parent":
75 | # Add a random connection to the root node
76 | random_node = [n for n in graph.nodes() if n is not graph.root_node][0]
77 | graph._graph.add_edge(random_node, graph.root_node)
78 | elif reason == "unexpanded_expanded":
79 | graph.root_node.is_expanded = False
80 | elif reason == "reactions_dont_match":
81 | child = list(graph.successors(graph.root_node))[0]
82 |
83 | # Set the reaction to a random incorrect reaction
84 | graph._graph.edges[graph.root_node, child]["reaction"] = SingleProductReaction(
85 | reactants=Bag([Molecule("OO")]), product=Molecule("CC")
86 | )
87 |
88 | # Not only should the graph not be valid below,
89 | # but specifically the reaction should also be invalid
90 | with pytest.raises(AssertionError):
91 | graph._assert_valid_reactions()
92 | else:
93 | raise ValueError(f"Unsupported reason: {reason}")
94 |
95 | with pytest.raises(AssertionError):
96 | graph.assert_validity()
97 |
98 | @pytest.mark.parametrize("reason", ["already_expanded", "wrong_product"])
99 | def test_invalid_expansions(
100 | self,
101 | molset_tree_non_minimal: MolSetGraph,
102 | rxn_cs_from_cc: SingleProductReaction,
103 | reason: str,
104 | ) -> None:
105 | """
106 | Test that invalid expansions raise an error.
107 | Note that valid expansions are tested implicitly elsewhere.
108 |
109 |         NOTE: because graphs with one node per molset are currently not properly supported,
110 | we don't test the case where the product of a reaction is the root mol.
111 | """
112 | graph = molset_tree_non_minimal
113 | cc_node = [n for n in graph.nodes() if n.mols == {Molecule("CC")}].pop()
114 |
115 | if reason == "already_expanded":
116 | with pytest.raises(AssertionError):
117 | graph.expand_with_reactions(reactions=[], node=graph.root_node, ensure_tree=True)
118 | elif reason == "wrong_product":
119 | with pytest.raises(AssertionError):
120 | graph.expand_with_reactions(
121 | reactions=[rxn_cs_from_cc], node=cc_node, ensure_tree=True
122 | )
123 | else:
124 | raise ValueError(f"Unsupported reason: {reason}")
125 |
--------------------------------------------------------------------------------
/syntheseus/tests/reaction_prediction/data/test_dataset.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | from pathlib import Path
4 | from typing import Generator
5 |
6 | import pytest
7 |
8 | from syntheseus.interface.bag import Bag
9 | from syntheseus.interface.molecule import Molecule
10 | from syntheseus.reaction_prediction.data.dataset import (
11 | CSV_REACTION_SMILES_COLUMN_NAME,
12 | DataFold,
13 | DataFormat,
14 | DiskReactionDataset,
15 | )
16 | from syntheseus.reaction_prediction.data.reaction_sample import ReactionSample
17 |
18 |
19 | @pytest.fixture(params=[False, True])
20 | def temp_path(request, tmp_path: Path) -> Generator[Path, None, None]:
21 | """Fixture to provide a temporary path that is either absolute or relative."""
22 | if request.param:
23 | # The built-in `tmp_path` fixture is absolute; if the fixture parameter is `True` we convert
24 | # that to a relative path. This is done by stripping away the root `/` that signifies an
25 | # absolute path and changing the working directory to `/` so that the path remains correct.
26 |
27 | old_working_dir = os.getcwd()
28 | os.chdir("/")
29 |
30 | yield Path(*list(tmp_path.parts)[1:])
31 |
32 | os.chdir(old_working_dir)
33 | else:
34 | yield tmp_path
35 |
36 |
37 | @pytest.mark.parametrize("mapped", [False, True])
38 | def test_save_and_load(temp_path: Path, mapped: bool) -> None:
39 | samples = [
40 | ReactionSample.from_reaction_smiles_strict(reaction_smiles, mapped=mapped)
41 | for reaction_smiles in [
42 | "O[c:1]1[cH:2][c:3](=[O:4])[nH:5][cH:6][cH:7]1>>[cH:1]1[cH:2][c:3](=[O:4])[nH:5][cH:6][cH:7]1",
43 | "CC(C)(C)OC(=O)[N:1]1[CH2:2][CH2:3][C@H:4]([F:5])[CH2:6]1>>[NH:1]1[CH2:2][CH2:3][C@H:4]([F:5])[CH2:6]1",
44 | ]
45 | ]
46 |
47 | for fold in DataFold:
48 | DiskReactionDataset.save_samples_to_file(data_dir=temp_path, fold=fold, samples=samples)
49 |
50 | for load_format in [None, DataFormat.JSONL]:
51 | # Now try to load the data we just saved.
52 | dataset = DiskReactionDataset(temp_path, sample_cls=ReactionSample, data_format=load_format)
53 |
54 | for fold in DataFold:
55 | assert list(dataset[fold]) == samples
56 |
57 |
58 | @pytest.mark.parametrize("format", [DataFormat.CSV, DataFormat.SMILES])
59 | def test_load_external_format(temp_path: Path, format: DataFormat) -> None:
60 | # Example reaction SMILES, purposefully using non-canonical forms of reactants and product.
61 | reaction_smiles = (
62 | "[cH:1]1[cH:2][c:3]([CH3:4])[cH:5][cH:6][c:7]1Br.B(O)(O)[c:8]1[cH:9][cH:10][c:11]([CH3:12])[cH:13][cH:14]1>>"
63 | "[cH:1]1[cH:2][c:3]([CH3:4])[cH:5][cH:6][c:7]1[c:8]2[cH:14][cH:13][c:11]([CH3:12])[cH:10][cH:9]2"
64 | )
65 |
66 | filename = DiskReactionDataset.get_filename_suffix(format=format, fold=DataFold.TRAIN)
67 | with open(temp_path / filename, "wt") as f:
68 | if format == DataFormat.CSV:
69 | writer = csv.DictWriter(f, fieldnames=["id", "class", CSV_REACTION_SMILES_COLUMN_NAME])
70 | writer.writeheader()
71 | writer.writerow(
72 | {"id": 0, "class": "UNK", CSV_REACTION_SMILES_COLUMN_NAME: reaction_smiles}
73 | )
74 | else:
75 | f.write(f"{reaction_smiles}\n")
76 |
77 | for load_format in [None, format]:
78 | dataset = DiskReactionDataset(temp_path, sample_cls=ReactionSample, data_format=load_format)
79 | assert dataset.get_num_samples(DataFold.TRAIN) == 1
80 |
81 | samples = list(dataset[DataFold.TRAIN])
82 | assert len(samples) == 1
83 |
84 | [sample] = samples
85 |
86 | # After loading, the reactants and products should be in canonical SMILES form, with the
87 | # atom mapping removed.
88 | assert sample == ReactionSample(
89 | reactants=Bag([Molecule("Cc1ccc(Br)cc1"), Molecule("Cc1ccc(B(O)O)cc1")]),
90 | products=Bag([Molecule("Cc1ccc(-c2ccc(C)cc2)cc1")]),
91 | )
92 |
93 | # The original reaction SMILES should have been saved separately.
94 | assert sample.mapped_reaction_smiles == reaction_smiles
95 |
96 |
97 | @pytest.mark.parametrize("format", [DataFormat.JSONL, DataFormat.CSV, DataFormat.SMILES])
98 | def test_format_detection(temp_path: Path, format: DataFormat) -> None:
99 | other_format = (set(DataFormat) - {format}).pop()
100 |
101 | # Create two files with different extensions, so that it is ambiguous which format we want.
102 | (temp_path / DiskReactionDataset.get_filename_suffix(format, DataFold.TRAIN)).touch()
103 | (temp_path / DiskReactionDataset.get_filename_suffix(other_format, DataFold.TEST)).touch()
104 |
105 | # Loading with automatic resolution should fail.
106 | with pytest.raises(ValueError):
107 | DiskReactionDataset(data_dir=temp_path, sample_cls=ReactionSample)
108 |
109 | # Loading with an explicit format should succeed.
110 | for f in [format, other_format]:
111 | DiskReactionDataset(data_dir=temp_path, sample_cls=ReactionSample, data_format=f)
112 |
113 | # Loading with an explicit format but no matching files should fail.
114 | another_format = (set(DataFormat) - {format, other_format}).pop()
115 | with pytest.raises(ValueError):
116 | DiskReactionDataset(
117 | data_dir=temp_path, sample_cls=ReactionSample, data_format=another_format
118 | )
119 |
120 | # Create another file with the right suffix.
121 | (temp_path / f"raw_{DiskReactionDataset.get_filename_suffix(format, DataFold.TRAIN)}").touch()
122 |
123 | # Loading with an explicit format should now fail due to ambiguity.
124 | with pytest.raises(ValueError):
125 | DiskReactionDataset(data_dir=temp_path, sample_cls=ReactionSample, data_format=format)
126 |
--------------------------------------------------------------------------------
/syntheseus/reaction_prediction/data/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import csv
4 | import json
5 | import logging
6 | from abc import abstractmethod
7 | from enum import Enum
8 | from functools import partial
9 | from pathlib import Path
10 | from typing import Dict, Generic, Iterable, List, Optional, Type, TypeVar, Union
11 |
12 | from more_itertools import ilen
13 |
14 | from syntheseus.reaction_prediction.data.reaction_sample import ReactionSample
15 | from syntheseus.reaction_prediction.utils.misc import asdict_extended, parallelize
16 |
17 | logger = logging.getLogger(__file__)
18 |
19 | SampleType = TypeVar("SampleType", bound=ReactionSample)
20 |
21 |
22 | CSV_REACTION_SMILES_COLUMN_NAME = "reactants>reagents>production"
23 |
24 |
25 | class DataFold(Enum):
26 | TRAIN = "train"
27 | VALIDATION = "val"
28 | TEST = "test"
29 |
30 |
31 | class DataFormat(Enum):
32 | JSONL = "jsonl"
33 | CSV = "csv"
34 | SMILES = "smi"
35 |
36 |
37 | class ReactionDataset(Generic[SampleType]):
38 | """Dataset holding raw reactions split into folds."""
39 |
40 | @abstractmethod
41 | def __getitem__(self, fold: DataFold) -> Iterable[SampleType]:
42 | pass
43 |
44 | @abstractmethod
45 | def get_num_samples(self, fold: DataFold) -> int:
46 | pass
47 |
48 |
49 | class DiskReactionDataset(ReactionDataset[SampleType]):
50 | def __init__(
51 | self,
52 | data_dir: Union[str, Path],
53 | sample_cls: Type[SampleType],
54 | num_processes: int = 0,
55 | data_format: Optional[DataFormat] = None,
56 | ):
57 | self._data_dir = Path(data_dir)
58 | self._sample_cls = sample_cls
59 | self._num_processes = num_processes
60 |
61 | paths = list(self._data_dir.iterdir())
62 |
63 | if data_format is None:
64 | logger.info(f"Detecting data format from files: {[path.name for path in paths]}")
65 | formats_to_try = list(DataFormat)
66 | else:
67 | formats_to_try = [data_format]
68 |
69 | matches = {
70 | format: DiskReactionDataset.match_paths_to_folds(format=format, paths=paths)
71 | for format in formats_to_try
72 | }
73 | matches = {key: values for key, values in matches.items() if values}
74 |
75 | if data_format is None:
76 | if len(matches) != 1:
77 | raise ValueError(
78 | f"Format detection failed (formats matching the file list: {[f.name for f in matches]})"
79 | )
80 | elif not matches:
81 | raise ValueError(
82 | f"No files matching *{{train, val, test}}.{data_format.value} were found"
83 | )
84 |
85 | [(self._data_format, self._fold_to_path)] = matches.items()
86 |
87 | if data_format is None:
88 | logger.info(f"Detected format: {self._data_format.name}")
89 |
90 | logger.info(f"Loading data from files {self._fold_to_path}")
91 | self._num_samples: Dict[DataFold, int] = {}
92 |
93 | def _get_lines(self, fold: DataFold) -> Iterable[str]:
94 | if fold not in self._fold_to_path:
95 | return []
96 | else:
97 | with open(self._fold_to_path[fold]) as f:
98 | if self._data_format == DataFormat.CSV:
99 | for row in csv.DictReader(f):
100 | if CSV_REACTION_SMILES_COLUMN_NAME not in row:
101 | raise ValueError(
102 | f"No {CSV_REACTION_SMILES_COLUMN_NAME} column found in the CSV data file"
103 | )
104 | yield row[CSV_REACTION_SMILES_COLUMN_NAME]
105 | else:
106 | for line in f:
107 | yield line.rstrip()
108 |
109 | def __getitem__(self, fold: DataFold) -> Iterable[SampleType]:
110 | if self._data_format == DataFormat.JSONL:
111 | parse_fn = partial(DiskReactionDataset.sample_from_json, sample_cls=self._sample_cls)
112 | else:
113 | parse_fn = partial(self._sample_cls.from_reaction_smiles_strict, mapped=True)
114 |
115 | yield from parallelize(parse_fn, self._get_lines(fold), num_processes=self._num_processes)
116 |
117 | def get_num_samples(self, fold: DataFold) -> int:
118 | if fold not in self._num_samples:
119 | self._num_samples[fold] = ilen(self._get_lines(fold))
120 |
121 | return self._num_samples[fold]
122 |
123 | @staticmethod
124 | def match_paths_to_folds(format: DataFormat, paths: List[Path]) -> Dict[DataFold, Path]:
125 | fold_to_path: Dict[DataFold, Path] = {}
126 | for fold in DataFold:
127 | suffix = DiskReactionDataset.get_filename_suffix(format, fold)
128 | matching_paths = [path for path in paths if path.name.endswith(suffix)]
129 |
130 | if len(matching_paths) > 1:
131 | raise ValueError(
132 | f"Found more than one {format.value} file for fold {fold.name}: {matching_paths}"
133 | )
134 |
135 | if matching_paths:
136 | [path] = matching_paths
137 | fold_to_path[fold] = path
138 |
139 | return fold_to_path
140 |
141 | @staticmethod
142 | def get_filename_suffix(format: DataFormat, fold: DataFold) -> str:
143 | return f"{fold.value}.{format.value}"
144 |
145 | @staticmethod
146 | def sample_from_json(data: str, sample_cls: Type[SampleType]) -> SampleType:
147 | return sample_cls.from_dict(json.loads(data))
148 |
149 | @staticmethod
150 | def save_samples_to_file(
151 | data_dir: Union[str, Path], fold: DataFold, samples: Iterable[SampleType]
152 | ) -> None:
153 | filename = DiskReactionDataset.get_filename_suffix(format=DataFormat.JSONL, fold=fold)
154 |
155 | with open(Path(data_dir) / filename, "wt") as f:
156 | for sample in samples:
157 | f.write(json.dumps(asdict_extended(sample)) + "\n")
158 |
--------------------------------------------------------------------------------
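Since the save/load round trip and format detection above are easiest to see end-to-end, here is a hedged usage sketch mirroring the test suite (the directory and mapped reaction SMILES are arbitrary examples):

```python
# Usage sketch for `DiskReactionDataset`; paths and data are placeholders.
from pathlib import Path

from syntheseus.reaction_prediction.data.dataset import DataFold, DiskReactionDataset
from syntheseus.reaction_prediction.data.reaction_sample import ReactionSample

data_dir = Path("data")  # hypothetical directory
data_dir.mkdir(exist_ok=True)

samples = [
    ReactionSample.from_reaction_smiles_strict(
        "CC(C)(C)OC(=O)[N:1]1[CH2:2][CH2:3][C@H:4]([F:5])[CH2:6]1"
        ">>[NH:1]1[CH2:2][CH2:3][C@H:4]([F:5])[CH2:6]1",
        mapped=True,
    )
]

# Samples are written as `{fold}.jsonl`; on load, format detection requires
# exactly one format to match the files present, otherwise it raises.
DiskReactionDataset.save_samples_to_file(
    data_dir=data_dir, fold=DataFold.TRAIN, samples=samples
)

dataset = DiskReactionDataset(data_dir, sample_cls=ReactionSample)
assert dataset.get_num_samples(DataFold.TRAIN) == 1
```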
/syntheseus/tests/cli/test_cli.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import math
4 | import sys
5 | import tempfile
6 | import urllib
7 | import zipfile
8 | from pathlib import Path
9 | from typing import Generator, List
10 |
11 | import omegaconf
12 | import pytest
13 |
14 | from syntheseus.reaction_prediction.inference.config import BackwardModelClass
15 | from syntheseus.reaction_prediction.utils.testing import are_single_step_models_installed
16 |
17 | pytestmark = pytest.mark.skipif(
18 | not are_single_step_models_installed(),
19 | reason="CLI tests require all single-step models to be installed",
20 | )
21 |
22 |
23 | MODEL_CLASSES_TO_TEST = set(BackwardModelClass) - {BackwardModelClass.GLN}
24 |
25 |
26 | @pytest.fixture(scope="module")
27 | def data_dir() -> Generator[Path, None, None]:
28 | with tempfile.TemporaryDirectory() as raw_tempdir:
29 | tempdir = Path(raw_tempdir)
30 |
31 | # Download the raw USPTO-50K data released by the authors of GLN.
32 | uspto50k_zip_path = tempdir / "uspto50k.zip"
33 | urllib.request.urlretrieve(
34 | "https://figshare.com/ndownloader/files/45206101", uspto50k_zip_path
35 | )
36 |
37 | with zipfile.ZipFile(uspto50k_zip_path, "r") as f_zip:
38 | f_zip.extractall(tempdir)
39 |
40 | # Create a simple search targets file with a single target and a matching inventory following
41 | # the same example class of reactions as those used in `test_call`.
42 |
43 | search_targets_file_path = tempdir / "search_targets.smiles"
44 | with open(search_targets_file_path, "wt") as f_search_targets:
45 | f_search_targets.write("Cc1ccc(-c2ccc(C)cc2)cc1\n")
46 |
47 | inventory_file_path = tempdir / "inventory.smiles"
48 | with open(inventory_file_path, "wt") as f_inventory:
49 | for leaving_group in ["Br", "B(O)O", "I", "[Mg+]"]:
50 | f_inventory.write(f"Cc1ccc({leaving_group})cc1\n")
51 |
52 | yield tempdir
53 |
54 |
55 | @pytest.fixture
56 | def eval_cli_argv() -> List[str]:
57 | import torch
58 |
59 | return ["eval-single-step", f"num_gpus={int(torch.cuda.is_available())}"]
60 |
61 |
62 | @pytest.fixture
63 | def search_cli_argv() -> List[str]:
64 | import torch
65 |
66 | return ["search", f"use_gpu={torch.cuda.is_available()}"]
67 |
68 |
69 | def run_cli_with_argv(argv: List[str]) -> None:
70 | # The import below pulls in some optional dependencies, so do it locally to avoid executing it
71 | # if the test suite is being skipped.
72 | from syntheseus.cli.main import main
73 |
74 | sys.argv = ["syntheseus"] + argv
75 | main()
76 |
77 |
78 | def test_cli_invalid(
79 | data_dir: Path, tmpdir: Path, eval_cli_argv: List[str], search_cli_argv: List[str]
80 | ) -> None:
81 | """Test various incomplete or invalid CLI calls that should all raise an error."""
82 | argv_lists: List[List[str]] = [
83 | [],
84 | ["not-a-real-command"],
85 | eval_cli_argv + ["model_class=LocalRetro"], # No data dir.
86 | eval_cli_argv + ["model_class=LocalRetro", f"data_dir={tmpdir}"], # No data.
87 | eval_cli_argv + ["model_class=FakeModel", f"data_dir={data_dir}"], # Not a real model.
88 | search_cli_argv
89 | + [
90 | "model_class=LocalRetro",
91 | f"search_targets_file={data_dir}/search_targets.smiles",
92 | ], # No inventory.
93 | search_cli_argv
94 | + [
95 | "model_class=LocalRetro",
96 | f"inventory_smiles_file={data_dir}/inventory.smiles",
97 | ], # No search targets.
98 | search_cli_argv
99 | + [
100 | "model_class=FakeModel",
101 | f"search_targets_file={data_dir}/search_targets.smiles",
102 | f"inventory_smiles_file={data_dir}/inventory.smiles",
103 | ], # Not a real model.
104 | ]
105 |
106 | for argv in argv_lists:
107 | with pytest.raises((ValueError, omegaconf.errors.MissingMandatoryValue)):
108 | run_cli_with_argv(argv)
109 |
110 |
111 | @pytest.mark.parametrize("model_class", MODEL_CLASSES_TO_TEST)
112 | def test_cli_eval_single_step(
113 | model_class: BackwardModelClass, data_dir: Path, tmpdir: Path, eval_cli_argv: List[str]
114 | ) -> None:
115 | run_cli_with_argv(
116 | eval_cli_argv
117 | + [
118 | f"model_class={model_class}",
119 | f"data_dir={data_dir}",
120 | f"results_dir={tmpdir}",
121 | "num_top_results=5",
122 | "print_idxs=[1,5]",
123 | "num_dataset_truncation=10",
124 | ]
125 | )
126 |
127 | [results_path] = glob.glob(f"{tmpdir}/{model_class.name}_*.json")
128 |
129 | with open(results_path, "rt") as f:
130 | results = json.load(f)
131 |
132 | top_1_accuracy = results["top_k"][0]
133 |
134 | # We just evaluated a tiny sample of the data, so only make a rough check that the accuracy is
135 |     # ballpark reasonable (full test-set accuracy would be around 50%).
136 | assert 0.2 <= top_1_accuracy <= 0.8
137 |
138 |
139 | @pytest.mark.parametrize("model_class", MODEL_CLASSES_TO_TEST)
140 | @pytest.mark.parametrize("search_algorithm", ["retro_star", "mcts", "pdvn"])
141 | def test_cli_search(
142 | model_class: BackwardModelClass,
143 | search_algorithm: str,
144 | data_dir: Path,
145 | tmpdir: Path,
146 | search_cli_argv: List[str],
147 | ) -> None:
148 | run_cli_with_argv(
149 | search_cli_argv
150 | + [
151 | f"model_class={model_class}",
152 | f"search_algorithm={search_algorithm}",
153 | f"results_dir={tmpdir}",
154 | f"search_targets_file={data_dir}/search_targets.smiles",
155 | f"inventory_smiles_file={data_dir}/inventory.smiles",
156 | "limit_iterations=3",
157 | "num_top_results=5",
158 | ]
159 | )
160 |
161 | results_dir = f"{tmpdir}/{model_class.name}_*/"
162 | [results_path] = glob.glob(f"{results_dir}/stats.json")
163 |
164 | with open(results_path, "rt") as f:
165 | results = json.load(f)
166 |
167 | # Assert that a solution was found.
168 | assert results["soln_time_rxn_model_calls"] < math.inf
169 | assert len(glob.glob(f"{results_dir}/route_*.pdf")) >= 1
170 |
--------------------------------------------------------------------------------
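The `run_cli_with_argv` helper above also doubles as documentation for how the CLI is driven. Below is a hedged sketch of an equivalent programmatic call (the paths are placeholders, and all single-step models must be installed for this to run):

```python
# Programmatic CLI invocation mirroring `run_cli_with_argv`; paths are fake.
import sys

from syntheseus.cli.main import main

sys.argv = [
    "syntheseus",
    "search",
    "model_class=LocalRetro",
    "search_algorithm=retro_star",
    "search_targets_file=data/search_targets.smiles",
    "inventory_smiles_file=data/inventory.smiles",
    "results_dir=results",
    "limit_iterations=3",
    "num_top_results=5",
]
main()
```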
/syntheseus/reaction_prediction/inference/local_retro.py:
--------------------------------------------------------------------------------
1 | """Inference wrapper for the LocalRetro model.
2 |
3 | Paper: https://pubs.acs.org/doi/10.1021/jacsau.1c00246
4 | Code: https://github.com/kaist-amsg/LocalRetro
5 |
6 | The original LocalRetro code is released under the Apache 2.0 license.
7 | Parts of this file are based on code from the GitHub repository above.
8 | """
9 |
10 | from pathlib import Path
11 | from typing import Any, List, Sequence
12 |
13 | from syntheseus.interface.molecule import Molecule
14 | from syntheseus.interface.reaction import SingleProductReaction
15 | from syntheseus.reaction_prediction.inference.base import ExternalBackwardReactionModel
16 | from syntheseus.reaction_prediction.utils.inference import (
17 | get_unique_file_in_dir,
18 | process_raw_smiles_outputs_backwards,
19 | )
20 | from syntheseus.reaction_prediction.utils.misc import suppress_outputs
21 |
22 |
23 | class LocalRetroModel(ExternalBackwardReactionModel):
24 | def __init__(self, *args, **kwargs) -> None:
25 | """Initializes the LocalRetro model wrapper.
26 |
27 | Assumed format of the model directory:
28 | - `model_dir` contains the model checkpoint as the only `*.pth` file
29 | - `model_dir` contains the config as the only `*.json` file
30 | - `model_dir/data` contains `*.csv` data files needed by LocalRetro
31 | """
32 | super().__init__(*args, **kwargs)
33 |
34 | from local_retro.Retrosynthesis import load_templates
35 | from local_retro.scripts.utils import init_featurizer, load_model
36 |
37 | data_dir = Path(self.model_dir) / "data"
38 | self.args = init_featurizer(
39 | {
40 | "mode": "test",
41 | "device": self.device,
42 | "model_path": get_unique_file_in_dir(self.model_dir, pattern="*.pth"),
43 | "config_path": get_unique_file_in_dir(self.model_dir, pattern="*.json"),
44 | "data_dir": data_dir,
45 | "rxn_class_given": False,
46 | }
47 | )
48 |
49 | with suppress_outputs():
50 | self.model = load_model(self.args)
51 |
52 | [
53 | self.args["atom_templates"],
54 | self.args["bond_templates"],
55 | self.args["template_infos"],
56 | ] = load_templates(self.args)
57 |
58 | def get_parameters(self):
59 | return self.model.parameters()
60 |
61 | def _mols_to_batch(self, mols: List[Molecule]) -> Any:
62 | from dgllife.utils import smiles_to_bigraph
63 | from local_retro.scripts.utils import collate_molgraphs_test
64 |
65 | graphs = [
66 | smiles_to_bigraph(
67 | mol.smiles,
68 | node_featurizer=self.args["node_featurizer"],
69 | edge_featurizer=self.args["edge_featurizer"],
70 | add_self_loop=True,
71 | canonical_atom_order=False,
72 | )
73 | for mol in mols
74 | ]
75 |
76 | return collate_molgraphs_test([(None, graph, None) for graph in graphs])[1]
77 |
78 | def _build_batch_predictions(
79 | self, batch, num_results: int, inputs: List[Molecule], batch_atom_logits, batch_bond_logits
80 | ) -> List[Sequence[SingleProductReaction]]:
81 | from local_retro.scripts.Decode_predictions import get_k_predictions
82 | from local_retro.scripts.get_edit import combined_edit, get_bg_partition
83 |
84 | graphs, nodes_sep, edges_sep = get_bg_partition(batch)
85 | start_node = 0
86 | start_edge = 0
87 |
88 | self.args["top_k"] = num_results
89 | self.args["raw_predictions"] = []
90 |
91 | for input, graph, end_node, end_edge in zip(inputs, graphs, nodes_sep, edges_sep):
92 | pred_types, pred_sites, pred_scores = combined_edit(
93 | graph,
94 | batch_atom_logits[start_node:end_node],
95 | batch_bond_logits[start_edge:end_edge],
96 | num_results,
97 | )
98 | start_node, start_edge = end_node, end_edge
99 |
100 | raw_predictions = [
101 | f"({pred_types[i]}, {pred_sites[i][0]}, {pred_sites[i][1]}, {pred_scores[i]:.3f})"
102 | for i in range(num_results)
103 | ]
104 |
105 | self.args["raw_predictions"].append([input.smiles] + raw_predictions)
106 |
107 | batch_predictions = []
108 | for idx, input in enumerate(inputs):
109 | try:
110 | raw_str_results = get_k_predictions(test_id=idx, args=self.args)[1][0]
111 | except RuntimeError:
112 | # In very rare cases we may get `rdkit` errors.
113 | raw_str_results = []
114 |
115 |             # We have to `eval` the predictions, as they arrive rendered as strings. Second tuple
116 | # component is empirically (on USPTO-50K test set) in [0, 1], resembling a probability,
117 | # but does not sum up to 1.0 (usually to something in [0.5, 2.0]).
118 | raw_results = [eval(str_result) for str_result in raw_str_results]
119 |
120 | if raw_results:
121 | raw_outputs, probabilities = zip(*raw_results)
122 | else:
123 | raw_outputs = probabilities = []
124 |
125 | batch_predictions.append(
126 | process_raw_smiles_outputs_backwards(
127 | input=input,
128 | output_list=raw_outputs,
129 | metadata_list=[{"probability": probability} for probability in probabilities],
130 | )
131 | )
132 |
133 | return batch_predictions
134 |
135 | def _get_reactions(
136 | self, inputs: List[Molecule], num_results: int
137 | ) -> List[Sequence[SingleProductReaction]]:
138 | import torch
139 | from local_retro.scripts.utils import predict
140 |
141 | batch = self._mols_to_batch(inputs)
142 | batch_atom_logits, batch_bond_logits, _ = predict(self.args, self.model, batch)
143 |
144 | batch_atom_logits = torch.nn.Softmax(dim=1)(batch_atom_logits)
145 | batch_bond_logits = torch.nn.Softmax(dim=1)(batch_bond_logits)
146 |
147 | return self._build_batch_predictions(
148 | batch, num_results, inputs, batch_atom_logits, batch_bond_logits
149 | )
150 |
--------------------------------------------------------------------------------
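`_build_batch_predictions` above walks the batched atom/bond logits using the cumulative per-graph boundaries returned by LocalRetro's `get_bg_partition`. A minimal sketch of that slicing pattern, with stand-in tensors and made-up sizes:

```python
# Sketch of per-graph slicing of batched logits; all numbers are stand-ins.
import torch

atom_logits = torch.randn(9, 4)   # 9 atoms total across a batch of 2 graphs
bond_logits = torch.randn(12, 4)  # 12 bonds total across the same batch
nodes_sep = [4, 9]   # graph 0 owns atoms [0, 4); graph 1 owns atoms [4, 9)
edges_sep = [5, 12]  # graph 0 owns bonds [0, 5); graph 1 owns bonds [5, 12)

start_node = start_edge = 0
for end_node, end_edge in zip(nodes_sep, edges_sep):
    per_graph_atoms = atom_logits[start_node:end_node]
    per_graph_bonds = bond_logits[start_edge:end_edge]
    start_node, start_edge = end_node, end_edge
    print(per_graph_atoms.shape, per_graph_bonds.shape)
```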
/syntheseus/tests/search/node_evaluation/test_common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import pytest
4 |
5 | from syntheseus.search.graph.and_or import AndNode, AndOrGraph
6 | from syntheseus.search.graph.molset import MolSetGraph
7 | from syntheseus.search.node_evaluation.common import (
8 | ConstantNodeEvaluator,
9 | HasSolutionValueFunction,
10 | ReactionModelLogProbCost,
11 | ReactionModelProbPolicy,
12 | )
13 |
14 |
15 | class TestConstantNodeEvaluator:
16 | @pytest.mark.parametrize("constant", [0.3, 0.8])
17 | def test_values(self, constant: float, andor_graph_non_minimal: AndOrGraph) -> None:
18 | val_fn = ConstantNodeEvaluator(constant)
19 | vals = val_fn(list(andor_graph_non_minimal.nodes()))
20 | assert all([v == constant for v in vals]) # values should match
21 | assert val_fn.num_calls == len(
22 | andor_graph_non_minimal
23 | ) # should have been called once per node
24 |
25 |
26 | class TestHasSolutionValueFunction:
27 | def test_values(self, andor_graph_non_minimal: AndOrGraph) -> None:
28 | val_fn = HasSolutionValueFunction()
29 | nodes = list(andor_graph_non_minimal.nodes())
30 | vals = val_fn(nodes)
31 | assert all([v == float(n.has_solution) for v, n in zip(vals, nodes)]) # values should match
32 | assert val_fn.num_calls == len(
33 | andor_graph_non_minimal
34 | ) # should have been called once per node
35 |
36 |
37 | class TestReactionModelLogProbCost:
38 | @pytest.mark.parametrize("normalize", [False, True])
39 | @pytest.mark.parametrize("temperature", [1.0, 2.0])
40 | @pytest.mark.parametrize("clip_probability_min", [0.1, 0.5])
41 | @pytest.mark.parametrize("clip_probability_max", [0.5, 1.0])
42 | def test_values(
43 | self,
44 | andor_graph_non_minimal: AndOrGraph,
45 | normalize: bool,
46 | temperature: float,
47 | clip_probability_min: float,
48 | clip_probability_max: float,
49 | ) -> None:
50 | val_fn = ReactionModelLogProbCost(
51 | normalize=normalize,
52 | temperature=temperature,
53 | clip_probability_min=clip_probability_min,
54 | clip_probability_max=clip_probability_max,
55 | )
56 | nodes = [node for node in andor_graph_non_minimal.nodes() if isinstance(node, AndNode)]
57 |
58 | # The toy model does not set reaction probabilities, so set these manually.
59 | node_val_expected = {}
60 | for idx, node in enumerate(nodes):
61 | prob = idx / (len(nodes) - 1)
62 | node.reaction.metadata["probability"] = prob # type: ignore
63 |
64 | node_val_expected[node] = (
65 | -math.log(min(clip_probability_max, max(clip_probability_min, prob))) / temperature
66 | )
67 |
68 | if normalize:
69 | normalization_constant = math.log(sum(math.exp(-v) for v in node_val_expected.values()))
70 | node_val_expected = {
71 | key: value + normalization_constant for key, value in node_val_expected.items()
72 | }
73 |
74 | vals = val_fn(nodes)
75 | for val_computed, node in zip(vals, nodes): # values should match
76 | assert math.isclose(val_computed, node_val_expected[node])
77 |
78 | assert val_fn.num_calls == len(nodes) # should have been called once per AND node
79 |
80 | def test_enforces_min_clipping(self) -> None:
81 | with pytest.raises(ValueError):
82 | ReactionModelLogProbCost(clip_probability_min=0.0) # should fail as `return_log = True`
83 |
84 |
85 | class TestReactionModelProbPolicy:
86 | @pytest.mark.parametrize("normalize", [False, True])
87 | @pytest.mark.parametrize("temperature", [1.0, 2.0])
88 | @pytest.mark.parametrize("clip_probability_min", [0.1, 0.5])
89 | @pytest.mark.parametrize("clip_probability_max", [0.5, 1.0])
90 | def test_values(
91 | self,
92 | molset_tree_non_minimal: MolSetGraph,
93 | normalize: bool,
94 | temperature: float,
95 | clip_probability_min: float,
96 | clip_probability_max: float,
97 | ) -> None:
98 | val_fn = ReactionModelProbPolicy(
99 | normalize=normalize,
100 | temperature=temperature,
101 | clip_probability_min=clip_probability_min,
102 | clip_probability_max=clip_probability_max,
103 | )
104 | nodes = [
105 | node
106 | for node in molset_tree_non_minimal.nodes()
107 | if node != molset_tree_non_minimal.root_node
108 | ]
109 |
110 | # The toy model does not set reaction probabilities, so set these manually.
111 | node_val_expected = {}
112 | for idx, node in enumerate(nodes):
113 | [parent] = molset_tree_non_minimal.predecessors(node)
114 | reaction = molset_tree_non_minimal._graph.edges[parent, node]["reaction"]
115 |
116 | # Be careful not to overwrite things as some reactions in the graph are repeated.
117 | if "probability" not in reaction.metadata:
118 | reaction.metadata["probability"] = prob = idx / (len(nodes) - 1)
119 | else:
120 | prob = reaction.metadata["probability"]
121 |
122 | node_val_expected[node] = min(
123 | clip_probability_max, max(clip_probability_min, prob)
124 | ) ** (1.0 / temperature)
125 |
126 | if normalize:
127 | normalization_factor = sum(node_val_expected.values())
128 | node_val_expected = {
129 | key: value / normalization_factor for key, value in node_val_expected.items()
130 | }
131 |
132 | vals = val_fn(nodes, graph=molset_tree_non_minimal)
133 | for val_computed, node in zip(vals, nodes): # values should match
134 | assert math.isclose(val_computed, node_val_expected[node])
135 |
136 | assert (
137 | val_fn.num_calls == len(molset_tree_non_minimal) - 1
138 | ) # should have been called once per non-root node
139 |
140 | def test_enforces_min_clipping(self) -> None:
141 | with pytest.raises(ValueError):
142 | ReactionModelProbPolicy(clip_probability_min=0.0) # should fail as `normalize = True`
143 |
144 | ReactionModelProbPolicy(
145 | normalize=False, clip_probability_min=0.0
146 | ) # should succeed if we explicitly turn off normalization
147 |
--------------------------------------------------------------------------------
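The expected values in these tests encode two small formulas: the log-prob cost is -log(clip(p)) / T (shifted by a log-sum-exp constant when normalized), and the policy is clip(p) ** (1 / T) rescaled into a distribution. A plain-Python numeric sketch (the probabilities are made up):

```python
# Numeric sketch of the node-evaluator formulas checked above.
import math

probs = [0.05, 0.3, 0.9]  # made-up reaction probabilities
p_min, p_max, temperature = 0.1, 1.0, 2.0


def clip(p: float) -> float:
    return min(p_max, max(p_min, p))


# ReactionModelLogProbCost: cost = -log(clip(p)) / T; when normalized, all
# costs are shifted by log(sum(exp(-cost))) so that exp(-cost) sums to one.
costs = [-math.log(clip(p)) / temperature for p in probs]
shift = math.log(sum(math.exp(-c) for c in costs))
normalized_costs = [c + shift for c in costs]
assert math.isclose(sum(math.exp(-c) for c in normalized_costs), 1.0)

# ReactionModelProbPolicy: weight = clip(p) ** (1 / T), then normalize.
weights = [clip(p) ** (1.0 / temperature) for p in probs]
policy = [w / sum(weights) for w in weights]
assert math.isclose(sum(policy), 1.0)
```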