├── .github ├── release-drafter.yml └── workflows │ ├── publish-release.yml │ ├── quality.yml │ ├── release-drafter.yml │ ├── update-major-minor-tags.yml │ └── version-bump.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .python-version ├── LICENSE ├── README.md ├── codecov.yml ├── data ├── README.md ├── __init__.py ├── configs │ ├── dms_dictionary.yaml │ └── logbook │ │ └── van_sm_6x3_6x3_256_noboth_seed=42 │ │ ├── model_config.yaml │ │ └── training_config.yaml ├── download.py ├── pharma_compounds.json ├── process.py ├── process_eMolecules.py └── route_separation.json ├── docs ├── CNAME ├── DirectMultiStep │ ├── components │ │ ├── attention.md │ │ ├── decoder.md │ │ ├── encoder.md │ │ └── moe.md │ ├── evaluation.md │ ├── model-init.md │ ├── training.md │ ├── utils │ │ ├── io.md │ │ ├── post-process.md │ │ ├── pre-process.md │ │ └── torch-dataset.md │ └── visualizations.md ├── analysis │ ├── monitoring-training.md │ ├── paper-figures.md │ └── style-settings.md ├── dev │ └── logging.md ├── index.md └── stylesheets │ └── extra.css ├── download_files.sh ├── mkdocs.yml ├── pyproject.toml ├── scripts └── solve_compounds.py ├── src └── directmultistep │ ├── __init__.py │ ├── analysis │ ├── __init__.py │ ├── paper │ │ ├── __init__.py │ │ ├── dataset_analysis.py │ │ └── linear_vs_convergent.py │ ├── style.py │ └── training.py │ ├── generate.py │ ├── generation │ ├── __init__.py │ ├── eval.py │ ├── generation.py │ └── tensor_gen.py │ ├── helpers.py │ ├── model │ ├── __init__.py │ ├── architecture.py │ ├── components │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ └── moe.py │ ├── config.py │ ├── default_configs │ │ ├── deep_40M.yaml │ │ ├── explorer_19M.yaml │ │ ├── explorer_xl_50M.yaml │ │ ├── flash_10M.yaml │ │ ├── flash_20M.yaml │ │ ├── flex_20M.yaml │ │ └── wide_40M.yaml │ └── factory.py │ ├── training │ ├── __init__.py │ ├── config.py │ ├── lightning.py │ └── trainer.py │ └── utils │ ├── __init__.py │ ├── dataset.py │ ├── io.py │ ├── logging_config.py │ ├── post_process.py │ ├── pre_process.py │ └── web_visualize.py ├── tests ├── __init__.py ├── test_data.py └── test_preprocess.py ├── use-examples ├── data │ └── configs │ │ └── logbook │ │ └── van_sm_6x3_6x3_256_noboth_seed=42 │ │ ├── model_config.yaml │ │ └── training_config.yaml ├── eval-subset.py ├── generate-route.py ├── paper-figures.py ├── train-model.py └── visualize-train-curves.py └── uv.lock /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: "v$RESOLVED_VERSION ⛰️" 2 | tag-template: "v$RESOLVED_VERSION" 3 | categories: 4 | - title: "🚨 Breaking changes" 5 | labels: 6 | - "breaking-change" 7 | - title: "✨ New features" 8 | labels: 9 | - "new-feature" 10 | - title: "🐛 Bug fixes" 11 | labels: 12 | - "bugfix" 13 | - title: "🚀 Enhancements" 14 | labels: 15 | - "enhancement" 16 | - "refactor" 17 | - "performance" 18 | - title: "🧰 Maintenance" 19 | labels: 20 | - "maintenance" 21 | - "ci" 22 | - "update-known-checksums" 23 | - title: "📚 Documentation" 24 | labels: 25 | - "documentation" 26 | - title: "⬆️ Dependency updates" 27 | labels: 28 | - "dependencies" 29 | change-template: "- $TITLE @$AUTHOR (#$NUMBER)" 30 | change-title-escapes: '\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks. 
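# Map PR labels to the semver component bumped for the next release; unlabeled PRs fall back to `default`.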
31 | version-resolver: 32 | major: 33 | labels: 34 | - "major" 35 | - "breaking-change" 36 | minor: 37 | labels: 38 | - "minor" 39 | - "new-feature" 40 | - "enhancement" 41 | patch: 42 | labels: 43 | - "patch" 44 | - "bugfix" 45 | - "default-version-update" 46 | default: patch 47 | template: | 48 | ## Changes 49 | 50 | $CHANGES -------------------------------------------------------------------------------- /.github/workflows/publish-release.yml: -------------------------------------------------------------------------------- 1 | name: Publish Release 2 | on: 3 | release: 4 | types: [published] # Runs when a draft is published 5 | 6 | permissions: 7 | contents: write 8 | 9 | jobs: 10 | create_python_package: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Install uv 15 | uses: astral-sh/setup-uv@v3 16 | with: 17 | enable-cache: true 18 | cache-dependency-glob: uv.lock 19 | 20 | - name: Set up Python 21 | run: uv python install 3.11.4 # Or whatever version I want to use. 22 | 23 | - name: Build 24 | run: uv build 25 | 26 | - name: Upload artifacts to release 27 | uses: softprops/action-gh-release@v1 28 | with: 29 | files: dist/* -------------------------------------------------------------------------------- /.github/workflows/quality.yml: -------------------------------------------------------------------------------- 1 | name: Code Quality 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | qualitycheck: 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - name: Checkout 11 | uses: actions/checkout@v4 12 | 13 | - name: Set up Python 3.11 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: '3.11' 17 | 18 | - name: Install uv 19 | run: | 20 | curl -LsSf https://astral.sh/uv/install.sh | sh 21 | echo "$HOME/.cargo/bin" >> $GITHUB_PATH 22 | 23 | - name: Cache dependencies 24 | uses: actions/cache@v4 25 | with: 26 | path: | 27 | ~/.cache/uv 28 | ~/.uv 29 | .venv 30 | key: ${{ runner.os }}-uv-${{ hashFiles('pyproject.toml') }} 31 | restore-keys: | 32 | ${{ runner.os }}-uv- 33 | 34 | - name: Create and activate virtual environment 35 | run: | 36 | uv venv 37 | echo "$PWD/.venv/bin" >> $GITHUB_PATH 38 | 39 | - name: Install dependencies 40 | run: uv pip install -e ".[dev]" 41 | 42 | - name: Run ruff (linter) 43 | run: ruff check 44 | 45 | - name: Run ruff (formatter) 46 | run: ruff format --check 47 | 48 | - name: Run isort 49 | run: isort --check --profile black . 50 | 51 | - name: Run mypy 52 | run: mypy . 
53 | 54 | - name: Run tests 55 | run: pytest -v -------------------------------------------------------------------------------- /.github/workflows/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name: Release Drafter 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | types: [opened, reopened, synchronize] 9 | branches: 10 | - main 11 | workflow_dispatch: 12 | 13 | permissions: 14 | contents: read 15 | pull-requests: write 16 | 17 | jobs: 18 | process_pr: 19 | permissions: 20 | contents: write 21 | pull-requests: write 22 | runs-on: ubuntu-latest 23 | steps: 24 | - name: Assign labels from commits 25 | if: github.event_name == 'pull_request' 26 | id: assign-labels 27 | uses: mauroalderete/action-assign-labels@v1 28 | with: 29 | pull-request-number: ${{ github.event.pull_request.number }} 30 | github-token: ${{ secrets.GITHUB_TOKEN }} 31 | conventional-commits: | 32 | conventional-commits: 33 | - type: 'breaking_change' 34 | nouns: ['BREAKING CHANGE', 'BREAKING', 'MAJOR'] 35 | labels: ['breaking-change'] 36 | - type: 'feat' 37 | nouns: ['FEATURE', 'Feature', 'feature', 'FEAT', 'Feat', 'feat'] 38 | labels: ['new-feature'] 39 | - type: 'fix' 40 | nouns: ['FIX', 'Fix', 'fix', 'FIXED', 'Fixed', 'fixed'] 41 | labels: ['bugfix'] 42 | - type: 'enhance' 43 | nouns: ['ENHANCE', 'Enhance', 'enhance', 'IMPROVEMENT', 'improvement'] 44 | labels: ['enhancement'] 45 | - type: 'refactor' 46 | nouns: ['REFACTOR', 'Refactor', 'refactor'] 47 | labels: ['refactor'] 48 | - type: 'perf' 49 | nouns: ['PERF', 'Perf', 'perf', 'PERFORMANCE', 'Performance', 'performance'] 50 | labels: ['performance'] 51 | - type: 'docs' 52 | nouns: ['DOCS', 'Docs', 'docs', 'DOC', 'Doc', 'doc'] 53 | labels: ['documentation'] 54 | maintain-labels-not-matched: false 55 | apply-changes: true 56 | 57 | - uses: release-drafter/release-drafter@v6.0.0 58 | env: 59 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/update-major-minor-tags.yml: -------------------------------------------------------------------------------- 1 | name: Release Management 2 | 3 | on: 4 | release: 5 | types: [published] # Runs when a release is published 6 | 7 | 8 | permissions: 9 | contents: write 10 | pull-requests: write 11 | 12 | jobs: 13 | update_version: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | with: 18 | fetch-depth: 0 19 | ref: ${{ github.event.repository.default_branch }} 20 | token: ${{ secrets.GITHUB_TOKEN }} 21 | 22 | - name: Get version from published release 23 | id: get_version 24 | uses: actions/github-script@v7 25 | with: 26 | script: | 27 | const release = context.payload.release; 28 | if (!release) { 29 | throw new Error('No release found in event payload'); 30 | } 31 | // Extract version without 'v' prefix 32 | const version = release.tag_name.replace(/^v/, ''); 33 | core.setOutput('version', version); 34 | 35 | - name: Update pyproject.toml version 36 | run: | 37 | VERSION="${{ steps.get_version.outputs.version }}" 38 | # Use sed to update version only under [project] section 39 | sed -i '/^\[project\]/,/^\[.*\]/ s/^version = .*/version = "'$VERSION'"/' pyproject.toml 40 | 41 | - name: Commit and push version update 42 | run: | 43 | git config --global user.name 'github-actions' 44 | git config --global user.email 'github-actions@github.com' 45 | git add pyproject.toml 46 | git commit -m "chore: update version to ${{ 
steps.get_version.outputs.version }}" 47 | git push origin ${{ github.event.repository.default_branch }} -------------------------------------------------------------------------------- /.github/workflows/version-bump.yml: -------------------------------------------------------------------------------- 1 | name: Release Management 2 | 3 | on: 4 | release: 5 | types: [published] # Runs when a draft is published 6 | 7 | permissions: 8 | contents: write 9 | pull-requests: write 10 | 11 | jobs: 12 | update_version: 13 | if: github.event.workflow_run.conclusion == 'success' 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | with: 18 | fetch-depth: 0 19 | token: ${{ secrets.GITHUB_TOKEN }} 20 | 21 | - name: Get next version from draft release 22 | id: get_version 23 | uses: actions/github-script@v7 24 | with: 25 | script: | 26 | const releases = await github.rest.repos.listReleases({ 27 | owner: context.repo.owner, 28 | repo: context.repo.repo 29 | }); 30 | const draftRelease = releases.data.find(release => release.draft); 31 | if (!draftRelease) { 32 | throw new Error('No draft release found'); 33 | } 34 | // Extract version without 'v' prefix 35 | const version = draftRelease.tag_name.replace(/^v/, ''); 36 | core.setOutput('version', version); 37 | 38 | - name: Update pyproject.toml version 39 | run: | 40 | VERSION="${{ steps.get_version.outputs.version }}" 41 | # Use sed to update version only under [project] section 42 | sed -i '/^\[project\]/,/^\[.*\]/ s/^version = .*/version = "'$VERSION'"/' pyproject.toml 43 | 44 | - name: Commit and push version update 45 | run: | 46 | git config --global user.name 'github-actions' 47 | git config --global user.email 'github-actions@github.com' 48 | git add pyproject.toml 49 | git commit -m "chore: update version to ${{ steps.get_version.outputs.version }}" 50 | git push -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *.so 5 | *.egg 6 | *.egg-info/ 7 | dist/ 8 | build/ 9 | eggs/ 10 | .coverage 11 | htmlcov/ 12 | 13 | # Virtual Environments 14 | venv/ 15 | .venv/ 16 | 17 | # IDE and System Files 18 | .DS_Store 19 | .env 20 | *_cache/ 21 | 22 | # Data and Model Files 23 | site/ 24 | data/checkpoints/ 25 | data/paroutes/ 26 | data/paroutes/*.json 27 | data/processed/pre_perms/ 28 | data/processed/*.pkl 29 | data/training/ 30 | data/figures/ 31 | data/evaluation/ 32 | data/datasets/ 33 | *.pkl 34 | *.ckpt 35 | *.pt 36 | *.pickle 37 | 38 | # Logs and Documentation 39 | lightning_logs/ 40 | *.log 41 | *.pdf 42 | *.tar.gz 43 | *.zip 44 | 45 | # Jupyter Notebooks 46 | *.ipynb 47 | 48 | # Temporary and Debug Files 49 | debug* 50 | slurm*.out 51 | submit*.sh 52 | qual.sh 53 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.11 3 | 4 | repos: 5 | - repo: local 6 | hooks: 7 | - id: ruff 8 | name: ruff (linter) 9 | entry: ruff check --fix 10 | language: python 11 | types: [python] 12 | 13 | - id: ruff-format 14 | name: ruff (formatter) 15 | entry: ruff format 16 | language: python 17 | types: [python] 18 | 19 | - id: isort 20 | name: isort 21 | entry: isort 22 | language: python 23 | types: [python] 24 | args: [--profile=black] 25 | 26 | - id: mypy 27 | name: mypy 28 | entry: mypy 29 
| language: system
30 | types: [python]
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.11.4
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Batista Lab (Yale University)
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DirectMultiStep: Direct Route Generation for Multi-Step Retrosynthesis
2 | 
3 | [![Python Version](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/)
4 | [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
5 | [![Checked with mypy](https://www.mypy-lang.org/static/mypy_badge.svg)](https://mypy-lang.org/)
6 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
7 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/batistagroup/DirectMultiStep/blob/main/LICENSE)
8 | [![arXiv](https://img.shields.io/badge/arXiv-2405.13983-b31b1b.svg)](https://arxiv.org/abs/2405.13983)
9 | [![image](https://img.shields.io/pypi/v/DirectMultiStep.svg)](https://pypi.org/project/DirectMultiStep/)
10 | [![PyPI - Downloads](https://img.shields.io/pypi/dm/directmultistep)](https://pypi.org/project/DirectMultiStep/)
11 | 
12 | ## Overview
13 | 
14 | This work has been published in [*J. Chem. Inf. Model*](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01982). The preprint for this work was posted on [arXiv](https://arxiv.org/abs/2405.13983).
15 | 
16 | You can use DMS models without installation through our web interface at [models.batistalab.com](https://models.batistalab.com). Alternatively, install the package from PyPI with `pip install directmultistep`. Check out [dms.batistalab.com](https://dms.batistalab.com) for the full documentation.
17 | 
18 | ## How to use
19 | 
20 | Here's a quick example that generates a retrosynthesis route (you can get the relevant checkpoints by running `bash download_files.sh`).
21 | 
22 | ```python
23 | from directmultistep.generate import generate_routes
24 | from pathlib import Path
25 | 
26 | data_path = Path(__file__).resolve().parents[1] / "data"
27 | ckpt_path = data_path / "checkpoints"
28 | fig_path = data_path / "figures"
29 | config_path = data_path / "configs" / "dms_dictionary.yaml"
30 | 
31 | # Generate a route for a target molecule
32 | target = "CNCc1cc(-c2ccccc2F)n(S(=O)(=O)c2cccnc2)c1"
33 | starting_material = "CN"
34 | 
35 | # Find routes with different models:
36 | # Using flash model with starting material
37 | paths = generate_routes(
38 |     target,
39 |     n_steps=2,
40 |     starting_material=starting_material,
41 |     model="flash", beam_size=5,
42 |     config_path=config_path, ckpt_dir=ckpt_path
43 | )
44 | 
45 | # Or use explorer model to automatically determine steps
46 | paths = generate_routes(
47 |     target,
48 |     starting_material=starting_material,
49 |     model="explorer",
50 |     beam_size=5,
51 |     config_path=config_path, ckpt_dir=ckpt_path
52 | )
53 | ```
54 | 
55 | See `use-examples/generate-route.py` for more examples with other models. Other example scripts include:
56 | 
57 | - `train-model.py`: Train a new model with customizable configuration for local or cluster environments
58 | - `eval-subset.py`: Evaluate a trained model on a subset of data
59 | - `paper-figures.py`: Reproduce figures from the paper
60 | - `visualize-train-curves.py`: Plot training curves and metrics
61 | 
62 | ## Citing
63 | 
64 | If you use DirectMultiStep in an academic project, please consider citing our publication in [*J. Chem. Inf. Model*](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01982):
65 | 
66 | ```tex
67 | @article{directmultistep,
68 |     author = {Shee, Yu and Morgunov, Anton and Li, Haote and Batista, Victor S.},
69 |     title = {DirectMultiStep: Direct Route Generation for Multistep Retrosynthesis},
70 |     journal = {Journal of Chemical Information and Modeling},
71 |     volume = {65},
72 |     number = {8},
73 |     pages = {3903-3914},
74 |     year = {2025},
75 |     doi = {10.1021/acs.jcim.4c01982},
76 |     note = {PMID: 40197023},
77 |     URL = {https://doi.org/10.1021/acs.jcim.4c01982},
78 |     eprint = {https://doi.org/10.1021/acs.jcim.4c01982}
79 | }
80 | ```
81 | 
82 | ## Extra Materials
83 | 
84 | Through [download_files.sh](./download_files.sh) you can download canonicalized versions of eMols (23M SMILES), Buyables (329k SMILES), ChEMBL-5000 (5k SMILES), and USPTO-190 (190 SMILES). Using the pre-canonicalized versions saves roughly a day of CPU time. If you use these canonicalized versions, consider citing the figshare record:
85 | 
86 | ```tex
87 | @misc{shee2025figshare,
88 |     author = {Yu Shee and Anton Morgunov},
89 |     title = {Data for ``DirectMultiStep: Direct Route Generation for Multistep Retrosynthesis''},
90 |     year = {2025},
91 |     month = {3},
92 |     howpublished = {\url{https://figshare.com/articles/dataset/Data_for_DirectMultiStep_Direct_Route_Generation_for_Multistep_Retrosynthesis_/28629470}},
93 |     doi = {10.6084/m9.figshare.28629470.v1},
94 |     note = {Accessed: 20xx-xx-xx}
95 | }
96 | ```
97 | 
98 | Also check out the [HigherLev Retro](https://github.com/jihye-roh/higherlev_retro) repo, which is the source of the Buyables stock set. [route-distances](https://github.com/MolecularAI/route-distances?tab=readme-ov-file) is the source of ChEMBL-5000. [Retro*](https://github.com/binghong-ml/retro_star) is the source of the eMols stock set and USPTO-190.
99 | 
100 | ## Licenses
101 | 
102 | All code is licensed under the MIT License.
The content of the [pre-print on arXiv](https://arxiv.org/abs/2405.13983) is licensed under CC-BY 4.0.
103 | 
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | codecov:
2 |   require_ci_to_pass: yes
3 | 
4 | coverage:
5 |   precision: 2
6 |   round: down
7 |   range: "70...100"
8 | 
9 |   status:
10 |     project: yes
11 |     patch: yes
12 |     changes: no
13 | 
14 | parsers:
15 |   gcov:
16 |     branch_detection:
17 |       conditional: yes
18 |       loop: yes
19 |       method: no
20 |       macro: no
21 | 
22 | comment:
23 |   layout: "header, diff, tree"
24 |   behavior: default
25 |   require_changes: no
26 | 
27 | ignore:
28 |   - "tests/*"
29 |   - "Tutorials/*"
30 | 
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Preparing PaRoutes Dataset (Multi-Step)
2 | 
3 | This model is trained on the [PaRoutes](https://github.com/MolecularAI/PaRoutes) dataset. To get started, please:
4 | 
5 | 1. Run `python download.py`. This creates a `paroutes` folder and downloads `n1-routes.json` (64.2 MB), `n1-stock.txt` (0.4 MB), `n5-routes.json` (82.1 MB), `n5-stock.txt` (0.4 MB), and `all_routes.json.gz` (1.44 GB).
6 | 2. Run `python process.py`. This takes roughly 7 minutes and creates pickle files in the `processed` folder.
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/batistagroup/DirectMultiStep/bb445196ce743317c179ccf8a7f5ee3966051cda/data/__init__.py
--------------------------------------------------------------------------------
/data/configs/dms_dictionary.yaml:
--------------------------------------------------------------------------------
1 | invdict:
2 |   0:
3 |   1: '{'
4 |   2: '''smiles'':'
5 |   3: ''''
6 |   4: C
7 |   5: O
8 |   6: (
9 |   7: '='
10 |   8: )
11 |   9: c
12 |   10: '1'
13 |   11: '2'
14 |   12: '['
15 |   13: N
16 |   14: +
17 |   15: ']'
18 |   16: '-'
19 |   17: ','
20 |   18: '''children'':'
21 |   19: B
22 |   20: r
23 |   21: '}'
24 |   22: '?'
25 | 23: l 26 | 24: '#' 27 | 25: n 28 | 26: o 29 | 27: '3' 30 | 28: s 31 | 29: S 32 | 30: F 33 | 31: I 34 | 32: H 35 | 33: P 36 | 34: '@' 37 | 35: '4' 38 | 36: L 39 | 37: i 40 | 38: '5' 41 | 39: / 42 | 40: \ 43 | 41: M 44 | 42: g 45 | 43: '6' 46 | 44: '7' 47 | 45: u 48 | 46: K 49 | 47: e 50 | 48: Z 51 | 49: a 52 | 50: p 53 | 51: J 54 | 52: ' ' 55 | product_max_length: 145 56 | seq_out_maxlength: 1074 57 | sm_max_length: 135 58 | smiledict: 59 | '#': 24 60 | '''': 3 61 | '''children'':': 18 62 | '''smiles'':': 2 63 | (: 6 64 | ): 8 65 | +: 14 66 | ',': 17 67 | '-': 16 68 | /: 39 69 | '1': 10 70 | '2': 11 71 | '3': 27 72 | '4': 35 73 | '5': 38 74 | '6': 43 75 | '7': 44 76 | : 0 77 | '=': 7 78 | '?': 22 79 | '@': 34 80 | B: 19 81 | C: 4 82 | F: 30 83 | H: 32 84 | I: 31 85 | J: 51 86 | K: 46 87 | L: 36 88 | M: 41 89 | N: 13 90 | O: 5 91 | P: 33 92 | S: 29 93 | Z: 48 94 | '[': 12 95 | \: 40 96 | ']': 15 97 | a: 49 98 | c: 9 99 | e: 47 100 | g: 42 101 | i: 37 102 | l: 23 103 | n: 25 104 | o: 26 105 | p: 50 106 | r: 20 107 | s: 28 108 | u: 45 109 | '{': 1 110 | '}': 21 111 | ' ': 52 -------------------------------------------------------------------------------- /data/configs/logbook/van_sm_6x3_6x3_256_noboth_seed=42/model_config.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | vocab_dim: 53 3 | hid_dim: 256 4 | n_layers: 6 5 | n_heads: 8 6 | ff_mult: 3 7 | ff_activation: gelu 8 | dropout: 0.1 9 | attn_bias: false 10 | context_window: 280 11 | start_idx: 0 12 | mask_idx: 51 13 | pad_idx: 52 14 | initiate_steps: true 15 | include_steps: true 16 | model_type: EncoderAConfig 17 | decoder: 18 | vocab_dim: 53 19 | hid_dim: 256 20 | n_layers: 6 21 | n_heads: 8 22 | ff_mult: 3 23 | ff_activation: gelu 24 | dropout: 0.1 25 | attn_bias: false 26 | context_window: 1075 27 | start_idx: 0 28 | mask_idx: 51 29 | pad_idx: 52 30 | model_type: TransformerConfig 31 | -------------------------------------------------------------------------------- /data/configs/logbook/van_sm_6x3_6x3_256_noboth_seed=42/training_config.yaml: -------------------------------------------------------------------------------- 1 | data_path: /Users/morgunov/batista/DirectMultiStep/data 2 | run_name: van_sm_6x3_6x3_256_noboth_seed=42 3 | train_fname: unique_dataset_nperms=3_nsms=all_noboth_train=0.95.pkl 4 | val_fname: unique_dataset_nperms=3_nsms=all_noboth_val=0.05.pkl 5 | metadata_fname: dms_dictionary.yaml 6 | batch_size: 8 7 | learning_rate: 0.0002 8 | max_epochs: 40 9 | warmup_steps: 3000 10 | decay_steps: 80000 11 | decay_factor: 0.1 12 | pad_idx: 52 13 | mask_idx: 51 14 | save_top_k: -1 15 | checkpoint_every_n_epochs: 2 16 | num_workers: 1 17 | n_devices: 1 18 | seed: 42 19 | accelerator: cpu 20 | matmul_precision: high 21 | summary_depth: 2 22 | dist_strategy: ddp_find_unused_parameters_true 23 | gradient_clip_val: 1.0 24 | gradient_clip_algorithm: value 25 | -------------------------------------------------------------------------------- /data/download.py: -------------------------------------------------------------------------------- 1 | """Module with script to download public data 2 | 3 | Adapted from https://github.com/MolecularAI/PaRoutes/blob/main/data/download_data.py 4 | """ 5 | 6 | import os 7 | import sys 8 | from pathlib import Path 9 | 10 | import requests 11 | import tqdm 12 | 13 | FILES_TO_DOWNLOAD = [ 14 | { 15 | "filename": "n1-routes.json", 16 | "url": "https://zenodo.org/record/7341155/files/ref_routes_n1.json?download=1", 17 | }, 18 | { 19 | "filename": 
"n1-stock.txt", 20 | "url": "https://zenodo.org/record/7341155/files/stock_n1.txt?download=1", 21 | }, 22 | { 23 | "filename": "n5-routes.json", 24 | "url": "https://zenodo.org/record/7341155/files/ref_routes_n5.json?download=1", 25 | }, 26 | { 27 | "filename": "n5-stock.txt", 28 | "url": "https://zenodo.org/record/7341155/files/stock_n5.txt?download=1", 29 | }, 30 | { 31 | "filename": "all_routes.json.gz", 32 | "url": "https://zenodo.org/record/7341155/files/all_loaded_routes.json.gz?download=1", 33 | }, 34 | ] 35 | 36 | 37 | def _download_file(url: str | Path, filename: str | Path) -> None: 38 | with requests.get(str(url), stream=True) as response: 39 | response.raise_for_status() 40 | total_size = int(response.headers.get("content-length", 0)) 41 | pbar = tqdm.tqdm(total=total_size, desc=os.path.basename(filename), unit="B", unit_scale=True) 42 | with open(filename, "wb") as fileobj: 43 | for chunk in response.iter_content(chunk_size=1024): 44 | fileobj.write(chunk) 45 | pbar.update(len(chunk)) 46 | pbar.close() 47 | 48 | 49 | def main() -> None: 50 | """Entry-point for CLI""" 51 | path = Path(__file__).parent / "paroutes" 52 | path.mkdir(parents=True, exist_ok=True) 53 | for filespec in FILES_TO_DOWNLOAD: 54 | try: 55 | _download_file(filespec["url"], path / filespec["filename"]) 56 | except requests.HTTPError as err: 57 | print(f"Download failed with message {str(err)}") 58 | sys.exit(1) 59 | 60 | 61 | if __name__ == "__main__": 62 | main() 63 | -------------------------------------------------------------------------------- /data/pharma_compounds.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "Vonoprazan-1", 4 | "path": "{'smiles':'CNCc1cc(-c2ccccc2F)n(S(=O)(=O)c2cccnc2)c1','children':[{'smiles':'O=Cc1cc(-c2ccccc2F)n(S(=O)(=O)c2cccnc2)c1','children':[{'smiles':'O=Cc1c[nH]c(-c2ccccc2F)c1'},{'smiles':'O=S(=O)(Cl)c1cccnc1'}]},{'smiles':'CN'}]}" 5 | }, 6 | { 7 | "name": "Vonoprazan-2", 8 | "path": "{'smiles':'CNCc1cc(-c2ccccc2F)n(S(=O)(=O)c2cccnc2)c1','children':[{'smiles':'CNC(=O)c1cc(-c2ccccc2F)n(S(=O)(=O)c2cccnc2)c1','children':[{'smiles':'CNC(=O)c1c[nH]c(-c2ccccc2F)c1','children':[{'smiles':'O=C(O)c1c[nH]c(-c2ccccc2F)c1','children':[{'smiles':'CCOC(=O)c1c[nH]c(-c2ccccc2F)c1'}]},{'smiles':'CN'}]},{'smiles':'O=S(=O)(Cl)c1cccnc1'}]}]}" 9 | }, 10 | { 11 | "name": "Vonoprazan-2 partial", 12 | "path": "{'smiles':'CNC(=O)c1cc(-c2ccccc2F)n(S(=O)(=O)c2cccnc2)c1','children':[{'smiles':'CNC(=O)c1c[nH]c(-c2ccccc2F)c1','children':[{'smiles':'O=C(O)c1c[nH]c(-c2ccccc2F)c1','children':[{'smiles':'CCOC(=O)c1c[nH]c(-c2ccccc2F)c1'}]},{'smiles':'CN'}]},{'smiles':'O=S(=O)(Cl)c1cccnc1'}]}" 13 | }, 14 | { 15 | "name": "Mitapivat-1", 16 | "path": "{'smiles':'O=C(c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1)N1CCN(CC2CC2)CC1','children':[{'smiles':'O=C(c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1)N1CCNCC1','children':[{'smiles':'CC(C)(C)OC(=O)N1CCN(C(=O)c2ccc(NS(=O)(=O)c3cccc4cccnc34)cc2)CC1','children':[{'smiles':'O=C(O)c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1','children':[{'smiles':'CCOC(=O)c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1','children':[{'smiles':'CCOC(=O)c1ccc(N)cc1'},{'smiles':'O=S(=O)(Cl)c1cccc2cccnc12'}]}]},{'smiles':'CC(C)(C)OC(=O)N1CCNCC1'}]}]},{'smiles':'O=CC1CC1'}]}" 17 | }, 18 | { 19 | "name": "Mitapivat-2", 20 | "path": 
"{'smiles':'O=C(c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1)N1CCN(CC2CC2)CC1','children':[{'smiles':'O=C(O)c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1','children':[{'smiles':'CCOC(=O)c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1','children':[{'smiles':'CCOC(=O)c1ccc(N)cc1'},{'smiles':'O=S(=O)(Cl)c1cccc2cccnc12'}]}]},{'smiles':'C1CN(CC2CC2)CCN1'}]}" 21 | }, 22 | { 23 | "name": "Daridorexant", 24 | "path": "{'smiles':'COc1ccc(-n2nccn2)c(C(=O)N2CCC[C@@]2(C)c2nc3c(C)c(Cl)ccc3[nH]2)c1','children':[{'smiles':'Cc1c(Cl)ccc2[nH]c([C@]3(C)CCCN3)nc12','children':[{'smiles':'Cc1c(Cl)ccc2[nH]c([C@]3(C)CCCN3C(=O)OC(C)(C)C)nc12','children':[{'smiles':'CC(C)(C)OC(=O)N1CCC[C@@]1(C)C(=O)O','children':[{'smiles':'C[C@@]1(C(=O)O)CCCN1'},{'smiles':'CC(C)(C)OC(=O)OC(=O)OC(C)(C)C'}]},{'smiles':'Cc1c(Cl)ccc(N)c1N'}]}]},{'smiles':'COc1ccc(-n2nccn2)c(C(=O)O)c1','children':[{'smiles':'COc1ccc(I)c(C(=O)O)c1'},{'smiles':'c1cn[nH]n1'}]}]}" 25 | } 26 | ] -------------------------------------------------------------------------------- /data/process.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import random 4 | from pathlib import Path 5 | 6 | from tqdm import tqdm 7 | 8 | from directmultistep.utils.io import ( 9 | convert_dict_of_lists_to_list_of_dicts, 10 | convert_list_of_dicts_to_dict_of_lists, 11 | load_dataset_sm, 12 | save_dataset_sm, 13 | ) 14 | from directmultistep.utils.pre_process import ( 15 | FilteredDict, 16 | canonicalize_smiles, 17 | filter_mol_nodes, 18 | find_leaves, 19 | generate_permutations, 20 | max_tree_depth, 21 | ) 22 | 23 | data_path = Path(__file__).parent / "paroutes" 24 | save_path = Path(__file__).parent / "processed" 25 | 26 | ProductsType = list[str] 27 | FilteredType = list[FilteredDict] 28 | DatasetEntry = dict[str, str | int | list[str]] 29 | Dataset = list[DatasetEntry] 30 | 31 | 32 | class PaRoutesDataset: 33 | def __init__(self, data_path: Path, filename: str, verbose: bool = True) -> None: 34 | self.data_path = data_path 35 | self.filename = filename 36 | self.dataset = json.load(open(data_path.joinpath(filename), "r")) 37 | 38 | self.verbose = verbose 39 | 40 | self.products: list[str] = [] 41 | self.filtered_data: FilteredType = [] 42 | # self.path_strings: List[str] = [] 43 | # self.max_steps: List[int] = [] 44 | # self.SMs: List[List[str]] = [] 45 | 46 | # self.non_permuted_path_strings: List[str] = [] 47 | 48 | def filter_dataset(self) -> None: 49 | if self.verbose: 50 | print("- Filtering all_routes to remove meta data") 51 | for route in tqdm(self.dataset): 52 | filtered_node = filter_mol_nodes(route) 53 | self.filtered_data.append(filtered_node) 54 | self.products.append(filtered_node["smiles"]) 55 | 56 | def prepare_final_dataset_v2( 57 | self, 58 | save_path: Path, 59 | n_perms: int | None = None, 60 | exclude_path_strings: set[str] | None = None, 61 | n_sms: int | None = None, 62 | ) -> set[str]: 63 | self.filter_dataset() 64 | products: list[str] = [] 65 | starting_materials: list[str] = [] 66 | path_strings: list[str] = [] 67 | n_steps_list: list[int] = [] 68 | non_permuted_paths: set[str] = set() 69 | 70 | if exclude_path_strings is None: 71 | exclude_path_strings = set() 72 | 73 | for filtered_route in tqdm(self.filtered_data): 74 | non_permuted_string = str(filtered_route).replace(" ", "") 75 | non_permuted_paths.add(non_permuted_string) 76 | permuted_path_strings = generate_permutations(filtered_route, max_perm=None) 77 | for permuted_path_string in permuted_path_strings: 78 | if permuted_path_string in 
exclude_path_strings: 79 | break 80 | else: 81 | n_steps = max_tree_depth(filtered_route) 82 | all_SMs = find_leaves(filtered_route) 83 | if n_perms == 1: 84 | permuted_path_strings = [non_permuted_string] 85 | else: 86 | permuted_path_strings = generate_permutations(filtered_route, max_perm=n_perms) 87 | 88 | for path_string in permuted_path_strings: 89 | for sm_count, starting_material in enumerate(all_SMs): 90 | products.append(filtered_route["smiles"]) 91 | starting_materials.append(starting_material) 92 | path_strings.append(path_string) 93 | n_steps_list.append(n_steps) 94 | if n_sms is not None and sm_count + 1 >= n_sms: 95 | break 96 | print(f"Created dataset with {len(products)} entries") 97 | pickle.dump( 98 | (products, starting_materials, path_strings, n_steps_list), 99 | open(save_path, "wb"), 100 | ) 101 | return non_permuted_paths 102 | 103 | 104 | # ------- Dataset Processing ------- 105 | print("--- Processing of the PaRoutes dataset begins!") 106 | print("-- starting to canonicalize n1 and n5 stocks") 107 | n1_stock = open(data_path / "n1-stock.txt").read().splitlines() 108 | n5_stock = open(data_path / "n5-stock.txt").read().splitlines() 109 | 110 | n1_stock_canon = [canonicalize_smiles(smi) for smi in n1_stock] 111 | n5_stock_canon = [canonicalize_smiles(smi) for smi in n5_stock] 112 | 113 | with open(data_path / "n1-stock.txt", "w") as f: 114 | f.write("\n".join(n1_stock_canon)) 115 | 116 | with open(data_path / "n5-stock.txt", "w") as f: 117 | f.write("\n".join(n5_stock_canon)) 118 | 119 | 120 | print("-- starting to process n1 Routes") 121 | n_perms: int | None = None # None for all 122 | n_sms: int | None = 1 # None for all 123 | perm_suffix = "all" if n_perms is None else str(n_perms) 124 | sm_suffix = "all" if n_sms is None else str(n_sms) 125 | n1_routes_obj = PaRoutesDataset(data_path, "n1-routes.json") 126 | n1_path_set = n1_routes_obj.prepare_final_dataset_v2( 127 | save_path / f"n1_dataset_nperms={perm_suffix}_nsms={sm_suffix}.pkl", 128 | n_perms=n_perms, 129 | n_sms=n_sms, 130 | ) 131 | pickle.dump(n1_path_set, open(save_path / f"n1_nperms={perm_suffix}_nsms={sm_suffix}_path_set.pkl", "wb")) 132 | 133 | print("-- starting to process n5 Routes") 134 | n5_routes_obj = PaRoutesDataset(data_path, "n5-routes.json") 135 | n5_path_set = n5_routes_obj.prepare_final_dataset_v2( 136 | save_path / f"n5_dataset_nperms={perm_suffix}_nsms={sm_suffix}.pkl", n_perms=n_perms, n_sms=n_sms 137 | ) 138 | pickle.dump(n5_path_set, open(save_path / f"n5_nperms={perm_suffix}_nsms={sm_suffix}_path_set.pkl", "wb")) 139 | 140 | n1_path_set = pickle.load(open(save_path / "n1_nperms=all_nsms=1_path_set.pkl", "rb")) 141 | n5_path_set = pickle.load(open(save_path / "n5_nperms=all_nsms=1_path_set.pkl", "rb")) 142 | 143 | print("-- starting to process All Routes") 144 | all_routes_obj = PaRoutesDataset(data_path, "all_routes.json") 145 | all_routes_obj.prepare_final_dataset_v2( 146 | save_path / "all_dataset_nperms=3_nsms=1.pkl", 147 | n_perms=1, 148 | n_sms=1, 149 | exclude_path_strings=n1_path_set | n5_path_set, 150 | ) 151 | 152 | 153 | # ------- Prepare Evaluation Subsets ------- 154 | # testing_dataset = "n5" 155 | 156 | # (_products, _sms, _path_strings, _steps_list) = pickle.load( 157 | # open(save_path / f"{testing_dataset}_dataset_nperms=1_nsms=1.pkl", "rb") 158 | # ) 159 | # combined = [{"product": p, "SM": s, "path_string": ps, "steps": st} for p, s, ps, st in zip(_products, _sms, _path_strings, _steps_list)] 160 | 161 | # # shuffle the list 162 | # import random 163 | # 
random.seed(42) 164 | # random.shuffle(combined) 165 | # _sh_prods = [x["product"] for x in combined] 166 | # _sh_sms = [x["SM"] for x in combined] 167 | # _sh_paths = [x["path_string"] for x in combined] 168 | # _sh_steps = [x["steps"] for x in combined] 169 | # for n_elts in [10, 50,]: 170 | # pickle.dump((_sh_prods[:n_elts], _sh_sms[:n_elts], _sh_paths[:n_elts], _sh_steps[:n_elts]), open(save_path / f"{testing_dataset}_shuffled_seed42_n{n_elts}.pkl", "wb")) 171 | 172 | 173 | # ------- Prepare Evaluation Subsets ------- 174 | # testing_dataset = "n1" 175 | 176 | # (_products, _sms, _path_strings, _steps_list) = pickle.load( 177 | # open(save_path / f"{testing_dataset}_dataset_nperms=1_nsms=1.pkl", "rb") 178 | # ) 179 | 180 | # first, second, third, fourth = 2500, 5000, 7500, 10000 181 | # pickle.dump((_products[:first], _sms[:first], _path_strings[:first], _steps_list[:first]), open(save_path / f"{testing_dataset}_dataset_nperms=1_nsms=1_n{first}.pkl", "wb")) 182 | # pickle.dump((_products[first:second], _sms[first:second], _path_strings[first:second], _steps_list[first:second]), open(save_path / f"{testing_dataset}_dataset_nperms=1_nsms=1_n{second}.pkl", "wb")) 183 | # pickle.dump((_products[second:third], _sms[second:third], _path_strings[second:third], _steps_list[second:third]), open(save_path / f"{testing_dataset}_dataset_nperms=1_nsms=1_n{third}.pkl", "wb")) 184 | # pickle.dump((_products[third:fourth], _sms[third:fourth], _path_strings[third:fourth], _steps_list[third:fourth]), open(save_path / f"{testing_dataset}_dataset_nperms=1_nsms=1_n{fourth}.pkl", "wb")) 185 | 186 | 187 | # ------- Remove SM info from datasets ------- 188 | 189 | 190 | def remove_sm_from_ds(load_path: Path, save_path: Path) -> None: 191 | products, _, path_strings, n_steps_lists = pickle.load(open(load_path, "rb")) 192 | pickle.dump((products, path_strings, n_steps_lists), open(save_path, "wb")) 193 | 194 | 195 | remove_sm_from_ds( 196 | load_path=save_path / "all_dataset_nperms=1_nsms=1.pkl", save_path=save_path / "all_dataset_nperms=1_nosm.pkl" 197 | ) 198 | remove_sm_from_ds( 199 | load_path=save_path / "n1_dataset_nperms=1_nsms=1.pkl", save_path=save_path / "n1_dataset_nperms=1_nosm.pkl" 200 | ) 201 | remove_sm_from_ds( 202 | load_path=save_path / "n5_dataset_nperms=1_nsms=1.pkl", save_path=save_path / "n5_dataset_nperms=1_nosm.pkl" 203 | ) 204 | 205 | # ------- Create train/val partitions ------- 206 | train_fname = "unique_dataset_nperms=3_nsms=all_noboth.pkl" 207 | ds_dict = load_dataset_sm(save_path / train_fname) 208 | ds_list = convert_dict_of_lists_to_list_of_dicts(ds_dict) 209 | 210 | random.seed(42) 211 | random.shuffle(ds_list) 212 | 213 | val_frac = 0.05 214 | train_ds = ds_list[: int(len(ds_list) * (1 - val_frac))] 215 | val_ds = ds_list[int(len(ds_list) * (1 - val_frac)) :] 216 | print(f"Train dataset size: {len(train_ds)}") 217 | print(f"Validation dataset size: {len(val_ds)}") 218 | 219 | train_ds_dict = convert_list_of_dicts_to_dict_of_lists(train_ds) 220 | val_ds_dict = convert_list_of_dicts_to_dict_of_lists(val_ds) 221 | 222 | save_dataset_sm(train_ds_dict, save_path / f"{train_fname.split('.')[0]}_train={1-val_frac}.pkl") 223 | save_dataset_sm(val_ds_dict, save_path / f"{train_fname.split('.')[0]}_val={val_frac}.pkl") 224 | -------------------------------------------------------------------------------- /data/process_eMolecules.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pandas as pd 4 | from tqdm 
import tqdm 5 | 6 | from directmultistep.utils.logging_config import logger 7 | from directmultistep.utils.pre_process import canonicalize_smiles 8 | 9 | DATA_PATH = Path(__file__).parent 10 | COMPOUND_PATH = DATA_PATH / "compounds" 11 | 12 | if __name__ == "__main__": 13 | logger.info("Loading eMolecules csv...") 14 | # `origin_dict.csv` from `github.com/binghong-ml/retro_star` 15 | emol_df = pd.read_csv(COMPOUND_PATH / "origin_dict.csv", index_col=0) 16 | emol_smiles = emol_df["mol"].tolist() 17 | logger.info(f"Number of eMolecules smiles: {len(emol_smiles)}") 18 | del emol_df 19 | 20 | logger.info("Canonicalizing SMILES strings...") 21 | canonicalized_smiles = [] 22 | for smiles in tqdm(emol_smiles): 23 | try: 24 | canonicalized_smiles.append(canonicalize_smiles(smiles)) 25 | except ValueError as e: 26 | logger.error(f"Error canonicalizing SMILES '{smiles}': {e} during canonicalizing buyables") 27 | 28 | logger.info("Saving unique canonicalized SMILES strings...") 29 | unique_smiles = list(set(canonicalized_smiles)) 30 | logger.info(f"Number of unique eMolecules smiles: {len(unique_smiles)}") 31 | with open(COMPOUND_PATH / "eMolecules.txt", "w") as f: 32 | for smiles in unique_smiles: 33 | f.write(smiles + "\n") 34 | -------------------------------------------------------------------------------- /data/route_separation.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "First Half", 4 | "path": "{'smiles':'NC[C@H]1CC[C@@H](C(O)CCc2ccccc2)CC1','children':[{'smiles':'CC(C)(C)OC(=O)NC[C@H]1CC[C@@H](C(O)CCc2ccccc2)CC1','children':[{'smiles':'CC(C)(C)OC(=O)NC[C@H]1CC[C@@H](C(=O)CCc2ccccc2)CC1','children':[{'smiles':'Br[Mg]CCc1ccccc1'},{'smiles':'CON(C)C(=O)[C@H]1CC[C@@H](CNC(=O)OC(C)(C)C)CC1','children':[{'smiles':'CC(C)(C)OC(=O)NC[C@H]1CC[C@@H](C(=O)O)CC1','children':[{'smiles':'CCCCOC(=O)[C@H]1CC[C@@H](CNC(=O)OC(C)(C)C)CC1'}]},{'smiles':'CNOC'}]}]}]}]}" 5 | }, 6 | { 7 | "name": "Second Half", 8 | "path": "{'smiles':'CCCCOC(=O)[C@H]1CC[C@@H](CNC(=O)OC(C)(C)C)CC1','children':[{'smiles':'CC(C)(C)OC(=O)OC(=O)OC(C)(C)C'},{'smiles':'CCCCOC(=O)[C@H]1CC[C@@H](CN)CC1','children':[{'smiles':'CCCCOC(=O)[C@H]1CC[C@@H](CN=[N+]=[N-])CC1','children':[{'smiles':'CCCCOC(=O)[C@H]1CC[C@@H](COS(C)(=O)=O)CC1','children':[{'smiles':'CCCCOC(=O)[C@H]1CC[C@@H](CO)CC1','children':[{'smiles':'CCCCOC(=O)[C@H]1CC[C@@H](C(=O)O)CC1'}]},{'smiles':'CS(=O)(=O)Cl'}]},{'smiles':'[N-]=[N+]=[N-]'}]}]}]}" 9 | } 10 | ] -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | dms.batistalab.com -------------------------------------------------------------------------------- /docs/DirectMultiStep/components/attention.md: -------------------------------------------------------------------------------- 1 | # Attention 2 | 3 | This document describes the attention mechanisms used in the DMS model. 4 | 5 | ## Summary 6 | 7 | The core mechanism of attention emerges from needing to selectively focus on relevant information while processing sequences. When encoding tokens, each position must consider its relationship with all others to capture context. Attention computes similarity scores between each query position and all possible key positions, essentially asking "how relevant is each key to my current query?" 
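In tensor terms, that question is a scaled dot product (a minimal sketch, assuming PyTorch; the shape suffixes follow the convention listed below, and all sizes here are made-up examples):

```python
import torch

B, H, L, M, D_head = 2, 8, 10, 12, 32  # batch, heads, query len, key len, per-head dim
q_BHLD = torch.randn(B, H, L, D_head)  # queries
k_BHMD = torch.randn(B, H, M, D_head)  # keys
# Raw similarity scores: one row per query position, one column per key position
scores_BHLM = q_BHLD @ k_BHMD.transpose(-2, -1) / D_head**0.5
```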
These raw similarity scores are normalized through softmax to produce attention weights that sum to 1, creating a probability distribution over the keys for each query. The weighted sum of values according to these attention weights produces the final attention output, allowing the model to synthesize information from multiple positions with varying degrees of influence. 8 | 9 | ### Flash Attention 10 | 11 | Flash Attention reformulates attention computation to maximize use of fast SRAM cache while minimizing slower DRAM memory access. Rather than computing and storing the full attention matrix at once, it splits the computation into smaller blocks that fit in SRAM, computing partial attention scores and incrementally aggregating them. This tiling approach, combined with local softmax normalization within blocks, achieves mathematically equivalent results while drastically reducing memory bandwidth requirements. The key insight is maintaining rolling statistics of softmax normalization terms across blocks, allowing processing of long sequences without materializing the full attention matrix in memory – trading increased computation for reduced memory usage, which is favorable on modern hardware where memory bandwidth often constrains performance more than computational capacity. 12 | 13 | ### Shape Convention 14 | 15 | The shape suffixes follow a consistent convention: 16 | 17 | - `B`: Batch size 18 | - `L`: Target sequence length 19 | - `M`: Memory/source sequence length 20 | - `D`: Model hidden dimension 21 | - `H`: Number of attention heads 22 | 23 | ## Source Code 24 | 25 | ::: directmultistep.model.components.attention 26 | handler: python 27 | members: MultiHeadAttentionLayer 28 | options: 29 | show_root_heading: true 30 | show_source: true 31 | -------------------------------------------------------------------------------- /docs/DirectMultiStep/components/decoder.md: -------------------------------------------------------------------------------- 1 | # Decoder 2 | 3 | This document describes the decoder components used in the DMS model. 4 | 5 | ## Base Decoder Layer 6 | 7 | The basic building block of the decoder that processes target sequences. 8 | 9 | ### Components 10 | 11 | #### **Self-Attention Block** 12 | 13 | - Multi-head self-attention mechanism 14 | - Causal masking to prevent looking ahead 15 | - Layer normalization 16 | - Residual connection 17 | 18 | #### **Cross-Attention Block** 19 | 20 | - Multi-head attention over encoder outputs 21 | - Allows decoder to focus on relevant input parts 22 | - Layer normalization 23 | - Residual connection 24 | 25 | #### **Feed-Forward Block** 26 | 27 | - Two-layer feed-forward network 28 | - Configurable activation function (ReLU or GELU) 29 | - Layer normalization 30 | - Residual connection 31 | 32 | ## Source Code 33 | 34 | ::: directmultistep.model.components.decoder 35 | handler: python 36 | options: 37 | show_root_heading: true 38 | show_source: true 39 | -------------------------------------------------------------------------------- /docs/DirectMultiStep/components/encoder.md: -------------------------------------------------------------------------------- 1 | # Encoder 2 | 3 | This document describes the encoder components used in the DMS model. 4 | 5 | ## Base Encoder Layer 6 | 7 | The basic building block of the encoder that processes input sequences. 
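In pseudocode, a layer chains the blocks described below (an illustrative sketch only, assuming PyTorch and post-norm; the class name is hypothetical, and the real implementation lives in `directmultistep.model.components.encoder`):

```python
import torch.nn as nn

class EncoderLayerSketch(nn.Module):
    """Self-attention block followed by a feed-forward block, each with residual + layer norm."""

    def __init__(self, hid_dim: int, n_heads: int, ff_mult: int, dropout: float) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(hid_dim, n_heads, dropout=dropout, batch_first=True)
        self.attn_norm = nn.LayerNorm(hid_dim)
        self.ff = nn.Sequential(
            nn.Linear(hid_dim, hid_dim * ff_mult),
            nn.GELU(),  # configurable: ReLU or GELU
            nn.Linear(hid_dim * ff_mult, hid_dim),
        )
        self.ff_norm = nn.LayerNorm(hid_dim)

    def forward(self, src, pad_mask=None):
        attn_out, _ = self.attn(src, src, src, key_padding_mask=pad_mask)
        src = self.attn_norm(src + attn_out)  # residual connection, then layer norm
        return self.ff_norm(src + self.ff(src))  # same pattern for the feed-forward block
```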
8 | 9 | ### Components 10 | 11 | #### **Self-Attention Block** 12 | 13 | - Multi-head self-attention mechanism 14 | - Layer normalization 15 | - Residual connection 16 | 17 | #### **Feed-Forward Block** 18 | 19 | - Two-layer feed-forward network 20 | - Configurable activation function (ReLU or GELU) 21 | - Layer normalization 22 | - Residual connection 23 | 24 | ## Source Code 25 | 26 | ::: directmultistep.model.components.encoder 27 | handler: python 28 | options: 29 | show_root_heading: true 30 | show_source: true 31 | -------------------------------------------------------------------------------- /docs/DirectMultiStep/components/moe.md: -------------------------------------------------------------------------------- 1 | # Mixture of Experts 2 | 3 | This document describes the Mixture of Experts (MoE) components used in the DMS model. MoE is a technique that improves model capacity and efficiency by routing different inputs to specialized sub-networks (experts). 4 | 5 | ## Position-wise Feed-forward Layer 6 | 7 | The standard feed-forward network serves as our baseline for comparison with MoE layers. It processes each position in the sequence independently through a simple two-layer network with expansion and projection. This is the traditional architecture used in transformer models. 8 | 9 | ## Noisy Top-k Router 10 | 11 | The router is the brain of the MoE system - it decides which experts should process each token. Key features: 12 | 13 | - Uses learned routing weights to match tokens with relevant experts 14 | - Adds learned noise to encourage exploration and prevent expert collapse 15 | - Selects top-k experts per token to enable specialization while maintaining redundancy 16 | - Produces sparse routing probabilities to enable efficient computation 17 | 18 | The noise mechanism is particularly important as it: 19 | 20 | 1. Prevents tokens from always taking the same path 21 | 2. Helps balance load across experts 22 | 3. Improves training stability 23 | 24 | ## Expert Network 25 | 26 | Each expert is a specialized feed-forward network that becomes tuned to handle specific types of tokens or patterns. The expert architecture mirrors the standard feed-forward layer, but each expert can learn different specializations. For example: 27 | 28 | - Some experts might focus on syntax 29 | - Others on specific vocabulary domains 30 | - Others on particular transformation patterns 31 | 32 | ## Sparse MoE Layer 33 | 34 | This is where everything comes together into an efficient, scalable system: 35 | 36 | 1. **Token Routing**: The router examines each token and decides which experts should process it 37 | 2. **Load Balancing**: 38 | - Uses capacity factors to prevent expert overload 39 | - Ensures even utilization of experts 40 | - Handles cases where too many tokens want the same expert 41 | 3. **Parallel Processing**: 42 | - Tokens are grouped by assigned expert 43 | - Each expert processes its assigned group 44 | - Results are combined based on routing weights 45 | 46 | The sparse computation pattern makes MoE layers much more efficient than simply running multiple full-size feed-forward layers. 
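The routing step is compact enough to sketch (illustrative only, assuming PyTorch; the class name is hypothetical, and `n_experts`/`top_k` mirror the fields of `MoEDecoderConfig`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NoisyTopkRouterSketch(nn.Module):
    """Select top-k experts per token, with learned noise to spread the load."""

    def __init__(self, hid_dim: int, n_experts: int, top_k: int) -> None:
        super().__init__()
        self.top_k = top_k
        self.route = nn.Linear(hid_dim, n_experts)  # learned routing weights
        self.noise = nn.Linear(hid_dim, n_experts)  # learned per-expert noise scale

    def forward(self, x):  # x: (batch, seq, hid_dim)
        logits = self.route(x)
        # Learned Gaussian noise encourages exploration and prevents expert collapse
        noisy = logits + torch.randn_like(logits) * F.softplus(self.noise(x))
        top_vals, top_idx = noisy.topk(self.top_k, dim=-1)
        # Sparse routing probabilities: non-selected experts get exactly zero weight
        masked = torch.full_like(noisy, float("-inf")).scatter(-1, top_idx, top_vals)
        return F.softmax(masked, dim=-1), top_idx
```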
47 | 48 | ### Intuition Behind MoE 49 | 50 | Think of MoE like a team of specialists: 51 | 52 | - Instead of every token going through the same general-purpose network 53 | - Tokens are routed to experts that are best suited to process them 54 | - Each expert becomes specialized in handling certain types of patterns 55 | - The router learns to match tokens with the right experts 56 | 57 | This specialization allows the model to: 58 | 59 | - Handle a wider range of patterns effectively 60 | - Scale capacity without scaling computation for every token 61 | - Develop focused expertise in different aspects of the task 62 | 63 | ## Source Code 64 | 65 | ::: directmultistep.model.components.moe 66 | handler: python 67 | options: 68 | show_root_heading: true 69 | show_source: true 70 | -------------------------------------------------------------------------------- /docs/DirectMultiStep/evaluation.md: -------------------------------------------------------------------------------- 1 | # Subset Evaluation 2 | 3 | This documentation covers how to evaluate model performance on specific subsets of data using beam search. 4 | 5 | ## Example Use 6 | 7 | Evaluating a model on a subset involves several steps: 8 | 9 | 1. Configure the evaluation parameters using `EvalConfig` 10 | 2. Load the model using `ModelFactory` 11 | 3. Initialize `ModelEvaluator` and run evaluation 12 | 13 | See `use-examples/eval-subset.py` for a full example. 14 | 15 | ## Source Code 16 | 17 | ::: directmultistep.generation.eval 18 | handler: python 19 | options: 20 | show_root_heading: true 21 | show_source: true 22 | members: 23 | - EvalConfig 24 | - ModelEvaluator 25 | -------------------------------------------------------------------------------- /docs/DirectMultiStep/model-init.md: -------------------------------------------------------------------------------- 1 | # Creating a model instance 2 | 3 | There are several ways to create a DMS model instance, ranging from using preset configurations to custom configurations. 4 | 5 | ## Using Preset Configurations 6 | 7 | The simplest way to create a model is using one of the preset configurations: 8 | 9 | ```py 10 | from directmultistep.model import ModelFactory 11 | 12 | factory = ModelFactory.from_preset("flash_10M", compile_model=True) 13 | model = factory.create_model() 14 | ``` 15 | 16 | Available presets include: `deep_40M`, `explorer_xl_50M`, `flash_10M`, `flash_20M`, `flex_20M`, and `wide_40M`. 
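These names map one-to-one onto the YAML files in `src/directmultistep/model/default_configs/`, which also ships an `explorer_19M` preset.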
17 | 
18 | ## Custom Configuration
19 | 
20 | For more control, you can create a custom configuration:
21 | 
22 | ```python
23 | from directmultistep.model import ModelFactory
24 | from directmultistep.model.config import Seq2SeqConfig, EncoderAConfig, MoEDecoderConfig
25 | 
26 | config = Seq2SeqConfig(
27 |     encoder=EncoderAConfig(
28 |         vocab_dim=53,
29 |         hid_dim=256,
30 |         n_layers=6,
31 |         n_heads=8,
32 |         ff_mult=3,
33 |         ff_activation="gelu",
34 |         dropout=0.1,
35 |         attn_bias=False,
36 |         context_window=280,
37 |         start_idx=0,
38 |         mask_idx=51,
39 |         pad_idx=52,
40 |         initiate_steps=True,
41 |         include_steps=True
42 |     ),
43 |     decoder=MoEDecoderConfig(
44 |         vocab_dim=53,
45 |         hid_dim=256,
46 |         n_layers=6,
47 |         n_heads=8,
48 |         ff_mult=3,
49 |         ff_activation="gelu",
50 |         dropout=0.1,
51 |         attn_bias=False,
52 |         context_window=1075,
53 |         start_idx=0,
54 |         mask_idx=51,
55 |         pad_idx=52,
56 |         n_experts=3,
57 |         top_k=2,
58 |         capacity_factor=1.0,
59 |     ),
60 | )
61 | 
62 | factory = ModelFactory(config, device=None, compile_model=True)
63 | model = factory.create_model()
64 | ```
65 | 
66 | ## Configuration Types
67 | 
68 | The model supports different types of encoders and decoders:
69 | 
70 | - Encoders:
71 |     - `EncoderAConfig`: the standard encoder (used in the preset example above)
72 |     - `MoEEncoderConfig`: Mixture of Experts encoder
73 | 
74 | - Decoders:
75 |     - `TransformerConfig`: Standard transformer decoder
76 |     - `MoEDecoderConfig`: Mixture of Experts decoder
77 | 
78 | ## Saving and Loading Configurations
79 | 
80 | Configurations can be saved to and loaded from YAML files:
81 | 
82 | ```python
83 | # Save configuration
84 | config.save("model_config.yaml")
85 | 
86 | # Load configuration and create model
87 | factory = ModelFactory.from_config_file("model_config.yaml")
88 | model = factory.create_model()
89 | ```
90 | 
91 | ## Source Code
92 | 
93 | ::: directmultistep.model.config
94 |     handler: python
95 |     options:
96 |         show_root_heading: true
97 |         show_source: true
98 |         members:
99 |             - TransformerConfig
100 |             - MoEDecoderConfig
101 |             - EncoderAConfig
102 |             - MoEEncoderConfig
103 |             - Seq2SeqConfig
104 | 
105 | ::: directmultistep.model.factory
106 |     handler: python
107 |     options:
108 |         show_root_heading: true
109 |         show_source: true
110 |         members:
111 |             - ModelFactory
112 | 
--------------------------------------------------------------------------------
/docs/DirectMultiStep/training.md:
--------------------------------------------------------------------------------
1 | # Training
2 | 
3 | ## Example Use
4 | 
5 | Training a model involves three main steps:
6 | 
7 | 1. Create a model configuration and instance using `ModelFactory`
8 | 2. Configure the training parameters using `TrainingConfig`
9 | 3. Initialize the `ModelTrainer` and start training
10 | 
11 | See `use-examples/train-model.py` for a full example; a minimal sketch follows.
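A condensed sketch of those three steps (hedged: the field names come from `data/configs/logbook/van_sm_6x3_6x3_256_noboth_seed=42/training_config.yaml`, while the exact `TrainingConfig` and `ModelTrainer` call signatures may differ; the example script is authoritative):

```python
from pathlib import Path

from directmultistep.model import ModelFactory
from directmultistep.training.config import TrainingConfig
from directmultistep.training.trainer import ModelTrainer

# 1. Model configuration and instance from a preset
model = ModelFactory.from_preset("flash_10M", compile_model=True).create_model()

# 2. Training parameters (values mirror the logbook config; paths and run name are placeholders)
config = TrainingConfig(
    data_path=Path("data"),
    run_name="flash_10M_demo",
    train_fname="unique_dataset_nperms=3_nsms=all_noboth_train=0.95.pkl",
    val_fname="unique_dataset_nperms=3_nsms=all_noboth_val=0.05.pkl",
    metadata_fname="dms_dictionary.yaml",
    batch_size=8,
    learning_rate=2e-4,
    max_epochs=40,
    pad_idx=52,
    mask_idx=51,
)

# 3. Trainer wiring (hypothetical call pattern)
trainer = ModelTrainer(config)
trainer.train(model)
```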
12 | 13 | ## Source Code 14 | 15 | ::: directmultistep.training.config 16 | handler: python 17 | options: 18 | show_root_heading: true 19 | show_source: true 20 | members: 21 | - TrainingConfig 22 | 23 | ::: directmultistep.training.trainer 24 | handler: python 25 | options: 26 | show_root_heading: true 27 | show_source: true 28 | members: 29 | - ModelTrainer 30 | 31 | ::: directmultistep.training.lightning 32 | handler: python 33 | options: 34 | show_root_heading: true 35 | show_source: true 36 | members: 37 | - warmup_and_cosine_decay 38 | - LTraining 39 | -------------------------------------------------------------------------------- /docs/DirectMultiStep/utils/io.md: -------------------------------------------------------------------------------- 1 | # Input/Output Utilities 2 | 3 | This module provides functions for loading and saving datasets, as well as converting between different data formats. It is useful for preparing data for training and testing DirectMultiStep models. 4 | 5 | ## Example Use 6 | 7 | The most useful functions are `load_dataset_sm`, `load_dataset_nosm`, `save_dataset_sm`, and `load_pharma_compounds`. These functions allow you to load and save datasets in a variety of formats. 8 | 9 | ```python 10 | from pathlib import Path 11 | from directmultistep.utils.io import load_pharma_compounds 12 | 13 | data_path = Path.cwd() / "data" 14 | 15 | _products, _sms, _path_strings, _steps_list, nameToIdx = load_pharma_compounds(data_path / "pharma_compounds.json") 16 | ``` 17 | 18 | ## Source Code 19 | 20 | ::: directmultistep.utils.io 21 | handler: python 22 | options: 23 | show_root_heading: true 24 | show_source: true 25 | members: 26 | - DatasetDict 27 | - load_dataset_sm 28 | - load_dataset_nosm 29 | - save_dataset_sm 30 | - convert_dict_of_lists_to_list_of_dicts 31 | - convert_list_of_dicts_to_dict_of_lists 32 | - load_pharma_compounds 33 | - load_commercial_stock -------------------------------------------------------------------------------- /docs/DirectMultiStep/utils/post-process.md: -------------------------------------------------------------------------------- 1 | # Multistep Route Post-processing 2 | 3 | This module provides useful data structure classes and helper functions for postprocessing beam search results and multistep routes generated by DirectMultiStep models. 
4 | 5 | ## Example Use 6 | 7 | The most useful functions are `canonicalize_path_dict`, `canonicalize_path_string`, and the functions that start with `find_`. 8 | 9 | ```python 10 | from directmultistep.utils.pre_process import stringify_dict 11 | from directmultistep.utils.post_process import canonicalize_path_dict, canonicalize_path_string 12 | 13 | path_string = "{'smiles':'CNCc1cc(-c2ccccc2F)n(S(=O)(=O)c2cccnc2)c1','children':[{'smiles':'O=Cc1cc(-c2ccccc2F)n(S(=O)(=O)c2cccnc2)c1','children':[{'smiles':'O=Cc1c[nH]c(-c2ccccc2F)c1'},{'smiles':'O=S(=O)(Cl)c1cccnc1'}]},{'smiles':'CN'}]}" 14 | 15 | cano_path_dict = canonicalize_path_dict(eval(path_string))  # eval is used here on a trusted literal; prefer ast.literal_eval for untrusted input 16 | cano_path_string = stringify_dict(cano_path_dict) 17 | 18 | print(cano_path_string == canonicalize_path_string(path_string)) 19 | ``` 20 | 21 | ## Source Code 22 | 23 | ::: directmultistep.utils.post_process 24 | handler: python 25 | options: 26 | show_root_heading: true 27 | show_source: true 28 | members: 29 | - count_unsolved_targets 30 | - find_valid_paths 31 | - find_matching_paths 32 | - find_top_n_accuracy 33 | - remove_repetitions_within_beam_result 34 | - find_paths_with_commercial_sm 35 | - find_paths_with_correct_product_and_reactants 36 | - canonicalize_path_dict 37 | - canonicalize_path_string 38 | - process_paths 39 | - process_path_single 40 | - process_paths_post 41 | - calculate_top_k_counts_by_step_length -------------------------------------------------------------------------------- /docs/DirectMultiStep/utils/pre-process.md: -------------------------------------------------------------------------------- 1 | # Multistep Route Pre-processing 2 | 3 | This module provides useful data structure classes and helper functions for preprocessing multistep routes for training and testing DirectMultiStep models. 4 | 5 | ## Example Use 6 | 7 | The most frequently used data structure is `FilteredDict`, a dictionary format for multistep routes used in DirectMultiStep models. Several useful functions are available, such as `canonicalize_smiles`, `max_tree_depth`, `find_leaves`, `stringify_dict`, and `generate_permutations`, among others. For example: 8 | 9 | ```python 10 | from directmultistep.utils.pre_process import stringify_dict 11 | 12 | path_string = "{'smiles':'CNCc1cc(-c2ccccc2F)n(S(=O)(=O)c2cccnc2)c1','children':[{'smiles':'O=Cc1cc(-c2ccccc2F)n(S(=O)(=O)c2cccnc2)c1','children':[{'smiles':'O=Cc1c[nH]c(-c2ccccc2F)c1'},{'smiles':'O=S(=O)(Cl)c1cccnc1'}]},{'smiles':'CN'}]}" 13 | 14 | # This should evaluate to True, as it compares the stringified version of your FilteredDict 15 | print(stringify_dict(eval(path_string)) == path_string) 16 | ``` 17 | 18 | ## Source Code 19 | 20 | ::: directmultistep.utils.pre_process 21 | handler: python 22 | options: 23 | show_root_heading: true 24 | show_source: true 25 | members: 26 | - PaRoutesDict 27 | - FilteredDict 28 | - filter_mol_nodes 29 | - max_tree_depth 30 | - find_leaves 31 | - canonicalize_smiles 32 | - stringify_dict 33 | - generate_permutations 34 | - is_convergent 35 | -------------------------------------------------------------------------------- /docs/DirectMultiStep/utils/torch-dataset.md: -------------------------------------------------------------------------------- 1 | # Torch Dataset for Routes 2 | 3 | This module provides a custom PyTorch Dataset class for handling reaction routes. It includes functionality for tokenizing SMILES strings, reaction paths, and context information, as well as preparing data for training and generation.
4 | 5 | ## Example Use 6 | 7 | `tokenize_path_string` is the most important function: it tokenizes a reaction path string by splitting it into tokens with a regular expression, and it can optionally add start-of-sequence and end-of-sequence tokens. 8 | 9 | ```python 10 | from directmultistep.utils.dataset import tokenize_path_string 11 | 12 | path_string = "{'smiles':'CC','children':[{'smiles':'CC(=O)O'}]}" 13 | tokens = tokenize_path_string(path_string) 14 | print(tokens) 15 | ``` 16 | 17 | ## Notes on Path Start 18 | 19 | In the `RoutesDataset` class, the `get_generation_with_sm` and `get_generation_no_sm` methods return an initial path tensor. This tensor is created from a `path_start` string, which is a partial path string that the model will start generating from. The `path_start` is `"{'smiles':'<product_smiles>','children':[{'smiles':'"`. The model will generate the rest of the path string from this starting point. 20 | 21 | This design is important because a trained model always generates this `path_start` at the beginning of the sequence. By providing it as the initial input, we avoid wasting time generating this part and can focus on generating the rest of the reaction path. 22 | 23 | The `path_start` string is constructed in the `prepare_input_tensors` function in `directmultistep.generate`; supplying a custom `path_start` there is useful when you want to initiate the generation process from a specific point in the reaction path instead of the default starting point. By modifying the `path_start`, you can control the initial state of the generation and explore different reaction pathways with user-defined intermediates. 24 | 25 | ## Source Code 26 | 27 | ::: directmultistep.utils.dataset 28 | handler: python 29 | options: 30 | show_root_heading: true 31 | show_source: true 32 | members: 33 | - tokenize_smile 34 | - tokenize_smile_atom 35 | - tokenize_context 36 | - tokenize_path_string 37 | - RoutesDataset 38 | 39 | ::: directmultistep.generate 40 | handler: python 41 | options: 42 | show_root_heading: true 43 | show_source: true 44 | members: 45 | - prepare_input_tensors -------------------------------------------------------------------------------- /docs/DirectMultiStep/visualizations.md: -------------------------------------------------------------------------------- 1 | # Visualizing Routes 2 | 3 | ## Example Use 4 | 5 | To visualize a path string, you can use the following snippet: 6 | 7 | ```python 8 | from directmultistep.utils.web_visualize import draw_tree_from_path_string 9 | from pathlib import Path 10 | path = "{'smiles':'O=C(c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1)N1CCN(CC2CC2)CC1','children':[{'smiles':'O=C(O)c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1','children':[{'smiles':'CCOC(=O)c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1','children':[{'smiles':'CCOC(=O)c1ccc(N)cc1'},{'smiles':'O=S(=O)(Cl)c1cccc2cccnc12'}]}]},{'smiles':'C1CN(CC2CC2)CCN1'}]}" 11 | 12 | svg_str = draw_tree_from_path_string( 13 | path_string=path, 14 | save_path=Path("data/figures/desired_file_name"), 15 | width=400, 16 | height=400, 17 | x_margin=50, 18 | y_margin=100, 19 | theme="light", 20 | ) 21 | ``` 22 | 23 | ## Source Code 24 | 25 | ::: directmultistep.utils.web_visualize 26 | handler: python 27 | options: 28 | show_root_heading: true 29 | show_source: true 30 | members: 31 | - FilteredDict 32 | - ThemeType 33 | - ColorPalette 34 | - RetroSynthesisTree 35 | - TreeDimensions 36 | - compute_subtree_dimensions 37 | - compute_canvas_dimensions 38 | - check_overlap 39 | - draw_molecule 40 | - draw_tree_svg 41 | - 
create_tree_from_path_string 42 | - draw_tree_from_path_string 43 | -------------------------------------------------------------------------------- /docs/analysis/monitoring-training.md: -------------------------------------------------------------------------------- 1 | # Monitoring Training 2 | 3 | This guide explains how to monitor and visualize training progress for DMS models. 4 | 5 | ## Basic Usage 6 | 7 | The simplest way to visualize training progress is to use the provided plotting utilities in `use-examples/visualize-train-curves.py`. 8 | 9 | ## Run Configuration 10 | 11 | Use `RunConfig` to specify which training runs to visualize: 12 | 13 | ```python 14 | from directmultistep.analysis.training import RunConfig 15 | 16 | run = RunConfig( 17 | run_name="flash_10M", # Folder name of the run 18 | trace_name="Flash Model", # Display name for the traces 19 | include_val=True # Whether to include validation curve 20 | ) 21 | ``` 22 | 23 | ## Training Curves 24 | 25 | The `plot_training_curves` function creates a figure showing: 26 | 27 | - Training loss curves (solid lines) 28 | - Validation loss curves (dotted lines with markers) 29 | - X-axis shows number of processed tokens 30 | - Hovering over validation points shows epoch information 31 | 32 | ## Learning Rate Curves 33 | 34 | The `plot_learning_rates` function visualizes the learning rate schedule: 35 | 36 | - Shows learning rate vs. training step 37 | - Useful for verifying learning rate schedules 38 | - Multiple runs can be compared on the same plot 39 | 40 | ## Advanced Usage 41 | 42 | For more control over visualization, you can load the training data directly: 43 | 44 | ```python 45 | from directmultistep.analysis.training import load_training_df 46 | 47 | # Load training data 48 | df = load_training_df(train_path, "flash_10M") 49 | 50 | # Ignore specific training runs by ID 51 | df = load_training_df(train_path, "flash_10M", ignore_ids=[0, 1]) 52 | ``` 53 | 54 | The returned DataFrame contains columns: 55 | 56 | - `processed_tokens`: Number of tokens processed 57 | - `train_loss`: Training loss 58 | - `val_loss`: Validation loss (if available) 59 | - `train_lr`: Learning rate 60 | - `epoch`: Current epoch 61 | - Additional metrics depending on the training configuration 62 | 63 | ## Source Code 64 | 65 | ::: directmultistep.analysis.training 66 | handler: python 67 | options: 68 | show_root_heading: true 69 | show_source: true 70 | members: 71 | - RunConfig 72 | - plot_training_curves 73 | - plot_learning_rates 74 | - load_training_df 75 | -------------------------------------------------------------------------------- /docs/analysis/paper-figures.md: -------------------------------------------------------------------------------- 1 | # Paper Figures 2 | 3 | This document describes the figures that can be generated using the `paper-figures.py` script. 4 | 5 | ## Available Figures 6 | 7 | ### 1. Route Length Distribution 8 | 9 | - **File**: `route_length_distribution.{pdf,html}` 10 | - **Description**: Visualizes the distribution of route lengths across different datasets (training, n1, and n5 datasets). 11 | - **Generated by**: `plot_route_length_distribution()` 12 | 13 | ### 2. Leaf Distribution 14 | 15 | - **File**: `leaf_distribution.{pdf,html}` 16 | - **Description**: Shows the distribution of leaf nodes (end states) across different datasets. 17 | - **Generated by**: `plot_leaf_distribution()` 18 | 19 | ### 3. 
Convergent Route Analysis 20 | 21 | Two figures are generated for convergent route analysis: 22 | 23 | - **Files**: 24 | - `convergent_fraction_by_length.{pdf,html}` 25 | - `convergent_fraction_overall.{pdf,html}` 26 | - **Description**: Analyzes the fraction of convergent routes by length and overall convergent fraction across datasets. 27 | - **Generated by**: `plot_convergent_fraction_by_length()` and `plot_convergent_fraction_overall()` 28 | 29 | ### 4. Top-K Accuracy Analysis 30 | 31 | - **File**: `{dataset_name}_topk_accuracy_subplots.{pdf,html}` 32 | - **Description**: Comparative bar plots showing top-k accuracy metrics for different models and configurations. 33 | - **Features**: Shows accuracy for k values [1, 2, 3, 4, 5, 10] 34 | - **Generated separately** for n1 and n5 datasets 35 | 36 | ### 5. Route Processing Stages 37 | 38 | - **File**: `{dataset_name}_route_processing_stages_{config}.{pdf,html}` 39 | - **Description**: Visualizes different stages of route processing, comparing: 40 | - Valid routes 41 | - Processed routes without stock 42 | - Processed routes with stock 43 | - True routes 44 | 45 | ### 6. Accuracy by Route Length 46 | 47 | - **File**: `accuracy_by_length_subplots_{config}.{pdf,html}` 48 | - **Description**: Shows top-k accuracy metrics broken down by route length 49 | - **Features**: 50 | - Compares performance across different datasets (n1, n5) 51 | - Shows accuracy for k=1 and k=10 52 | 53 | ## Usage 54 | 55 | To generate these figures, modify the `rerun` dictionary in `paper-figures.py` to specify which figures you want to generate: 56 | 57 | ```python 58 | rerun = { 59 | "route-distribution": False, 60 | "leaf-distribution": False, 61 | "convergent-fraction": False, 62 | "topk-accuracy": False, 63 | "extraction-distribution": True, 64 | "accuracy-by-length": False, 65 | } 66 | ``` 67 | 68 | Set the corresponding flag to `True` for the figures you want to generate. All figures will be saved in both PDF and HTML formats in the `data/figures/paper` directory. 69 | 70 | ## Source Code 71 | 72 | ::: directmultistep.analysis.paper.dataset_analysis 73 | handler: python 74 | options: 75 | show_root_heading: true 76 | show_source: true 77 | 78 | ::: directmultistep.analysis.paper.linear_vs_convergent 79 | handler: python 80 | options: 81 | show_root_heading: true 82 | show_source: true 83 | -------------------------------------------------------------------------------- /docs/analysis/style-settings.md: -------------------------------------------------------------------------------- 1 | # Visualization Style Settings 2 | 3 | This guide explains the available style settings for visualizations in the analysis tools. 4 | 5 | ## Color Palettes 6 | 7 | The analysis tools provide several predefined color palettes for consistent visualization: 8 | 9 | - `style.colors_names` 10 | - `style.colors_light` 11 | - `style.colors_dark` 12 | 13 | ## Plot Settings 14 | 15 | The default plot settings use a dark theme: 16 | 17 | ```python 18 | template = "plotly_dark" # Plotly dark theme 19 | plot_bgcolor = "#000000" # Black plot background 20 | paper_bgcolor = "#000000" # Black paper background 21 | ``` 22 | 23 | ## Usage in Visualizations 24 | 25 | The style settings are automatically applied in visualization functions like `plot_training_curves` and `plot_learning_rates`. 
The color palettes are used cyclically when plotting multiple runs: 26 | 27 | - Training curves use `colors_light` 28 | - Validation curves use `colors_dark` 29 | - Special visualizations can use specific colors from `colors_names` 30 | -------------------------------------------------------------------------------- /docs/dev/logging.md: -------------------------------------------------------------------------------- 1 | # Logging Best Practices 2 | 3 | This guide explains how to effectively use Python's logging module in our codebase, whether you're writing modules, running scripts from CLI, or working in Jupyter notebooks. 4 | 5 | ## Environment Variables 6 | 7 | The application's log level can be controlled using the `DIRECTMULTISTEP_LOG_LEVEL` environment variable: 8 | 9 | ```bash 10 | # Set log level for the current session 11 | export DIRECTMULTISTEP_LOG_LEVEL=DEBUG 12 | python your_script.py 13 | 14 | # Or set it for a single command 15 | DIRECTMULTISTEP_LOG_LEVEL=DEBUG python your_script.py 16 | ``` 17 | 18 | Valid log levels are: 19 | 20 | - `DEBUG`: Most verbose, detailed debugging information 21 | - `INFO`: General operational information (default) 22 | - `WARNING`: Unexpected situations that aren't errors 23 | - `ERROR`: Serious problems that need attention 24 | - `CRITICAL`: Critical issues that may cause program failure 25 | 26 | ## Module Development 27 | 28 | When writing a module, follow these guidelines: 29 | 30 | ```python 31 | from directmultistep.utils.logging_config import logger 32 | 33 | def my_function(): 34 | # Use appropriate log levels 35 | logger.debug("Detailed information for debugging") 36 | logger.info("General information about progress") 37 | logger.warning("Something unexpected but not an error") 38 | logger.error("A more serious problem") 39 | logger.critical("Program may not be able to continue") 40 | ``` 41 | 42 | Key points: 43 | 44 | - Don't configure the logger in your modules 45 | - Always use `from directmultistep.utils.logging_config import logger` 46 | - Choose appropriate log levels 47 | - Don't use print statements for debugging 48 | - Don't add parameters like `verbose` to your functions 49 | 50 | ## Jupyter Notebook Usage 51 | 52 | For Jupyter notebooks, put this in your first cell: 53 | 54 | ```python 55 | import logging 56 | from directmultistep.utils.logging_config import logger 57 | logger.setLevel(logging.DEBUG) # To see debug messages 58 | logger.setLevel(logging.INFO) # Back to info only 59 | 60 | ``` 61 | 62 | ## Log Levels Guide 63 | 64 | Choose the appropriate level based on the message importance: 65 | 66 | - **DEBUG**: Detailed information for diagnosing problems 67 | 68 | ```python 69 | logger.debug(f"Processing data frame with shape {df.shape}") 70 | ``` 71 | 72 | - **INFO**: Confirmation that things are working as expected 73 | 74 | ```python 75 | logger.info("Model training started") 76 | ``` 77 | 78 | - **WARNING**: Indication that something unexpected happened 79 | 80 | ```python 81 | logger.warning("Using fallback parameter value") 82 | ``` 83 | 84 | - **ERROR**: More serious problem that prevented function from working 85 | 86 | ```python 87 | logger.error("Failed to load model weights") 88 | ``` 89 | 90 | - **CRITICAL**: Program may not be able to continue 91 | 92 | ```python 93 | logger.critical("Out of memory - cannot continue processing") 94 | ``` 95 | 96 | ## Common Pitfalls 97 | 98 | 1. **Configuring Loggers in Modules**: Only configure logging in your entry points (main scripts, notebooks) 99 | 100 | 2. 
**Using Print Statements**: Avoid print statements for debugging; use logger.debug instead 101 | 102 | 3. **Hard-coding Log Levels**: Don't set log levels in your modules; let the application control them 103 | 104 | 4. **Creating Multiple Handlers**: Clear existing handlers in notebooks to avoid duplicate logs 105 | 106 | 5. **Using f-strings for Debug Messages**: For expensive operations, check level first: 107 | 108 | ```python 109 | # Bad (string formatting happens regardless of level) 110 | logger.debug(f"Expensive operation result: {expensive_operation()}") 111 | 112 | # Good (string formatting only happens if needed) 113 | if logger.isEnabledFor(logging.DEBUG): 114 | logger.debug(f"Expensive operation result: {expensive_operation()}") 115 | ``` 116 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # DirectMultiStep: Direct Route Generation for Multi-Step Retrosynthesis 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2405.13983-b31b1b.svg)](https://arxiv.org/abs/2405.13983) 4 | 5 | DirectMultiStep is a novel multi-step first approach for generating retrosynthesis routes in chemistry. The project provides multiple models for different retrosynthesis generation approaches. 6 | 7 | ## Quick Start 8 | 9 | ### Online Demo 10 | 11 | Try out our deployed models without any installation at [models.batistalab.com](https://models.batistalab.com). 12 | 13 | ### Installation 14 | 15 | You can install the package directly from PyPI: 16 | 17 | ```bash 18 | pip install directmultistep 19 | ``` 20 | 21 | ### Development 22 | 23 | We welcome any contributions; feel free to clone the repo and create a PR. We recommend using [uv](https://docs.astral.sh/uv/getting-started/installation/): 24 | 25 | ```bash 26 | uv venv --python 3.11 27 | source .venv/bin/activate 28 | uv pip install -e ".[dev]" 29 | ``` 30 | 31 | ### Usage Example 32 | 33 | Here's a quick example to generate a retrosynthesis route: 34 | 35 | ```python 36 | from directmultistep.generate import generate_routes 37 | from pathlib import Path 38 | 39 | # Generate a route for a target molecule 40 | target = "CNCc1cc(-c2ccccc2F)n(S(=O)(=O)c2cccnc2)c1" 41 | starting_material = "CN" 42 | 43 | # Using flash model with starting material 44 | paths = generate_routes( 45 | target, 46 | n_steps=2, 47 | starting_material=starting_material, 48 | model="flash", 49 | beam_size=5, 50 | config_path="path/to/config.yaml", 51 | ckpt_dir="path/to/checkpoints" 52 | ) 53 | 54 | # Or use explorer model to automatically determine steps 55 | paths = generate_routes( 56 | target, n_steps=None, 57 | starting_material=starting_material, 58 | model="explorer", 59 | beam_size=5, 60 | config_path="path/to/config.yaml", 61 | ckpt_dir="path/to/checkpoints" 62 | ) 63 | ``` 64 | 65 | ## License 66 | 67 | - Code: MIT License 68 | - Paper content ([arXiv preprint](https://arxiv.org/abs/2405.13983)): CC-BY 4.0 69 | -------------------------------------------------------------------------------- /docs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | html { 2 | font-size: 18px; 3 | } 4 | 5 | 6 | h1 { 7 | font-weight: 600 !important; 8 | color: #333 !important; 9 | } 10 | 11 | [data-md-color-scheme="slate"] h1 { 12 | font-weight: 600 !important; 13 | color: #d7d7d7 !important; 14 | } 15 | 16 | h2 { 17 | font-weight: 600 !important; 18 | color: #333 !important; 19 | } 20 | 
[data-md-color-scheme="slate"] h2 { 22 | font-weight: 600 !important; 23 | color: #d7d7d7 !important; 24 | } 25 | -------------------------------------------------------------------------------- /download_files.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Create necessary directories 4 | echo "Creating necessary directories..." 5 | mkdir -p data/checkpoints \ 6 | data/processed \ 7 | data/datasets/compounds \ 8 | 9 | # Define URLs 10 | CKPT_URL="https://files.batistalab.com/DirectMultiStep/ckpts" 11 | DATASET_URL="https://files.batistalab.com/DirectMultiStep/datasets" 12 | 13 | # Model checkpoint configurations 14 | declare -A models 15 | models=( 16 | ["Flash"]="flash.ckpt|38" 17 | ["Flex"]="flex.ckpt|74" 18 | ["Deep"]="deep.ckpt|159" 19 | ["Wide"]="wide.ckpt|147" 20 | ["Explorer"]="explorer.ckpt|74" 21 | ["Explorer-XL"]="explorer_xl.ckpt|192" 22 | ["Flash-20"]="flash_20.ckpt|74" 23 | ) 24 | 25 | # Download model checkpoints 26 | read -p "Do you want to download all model checkpoints? [y/N]: " all_choice 27 | case "$all_choice" in 28 | y|Y ) 29 | for model in "${!models[@]}"; do 30 | IFS="|" read -r filename size <<< "${models[$model]}" 31 | echo "Downloading ${model} model ckpt (${size} MB)..." 32 | curl -o "data/checkpoints/${filename}" "${CKPT_URL}/${filename}" 33 | done 34 | ;; 35 | * ) 36 | for model in "${!models[@]}"; do 37 | IFS="|" read -r filename size <<< "${models[$model]}" 38 | read -p "Do you want to download ${model} model ckpt? (${size} MB) [y/N]: " choice 39 | case "$choice" in 40 | y|Y ) 41 | curl -o "data/checkpoints/${filename}" "${CKPT_URL}/${filename}" 42 | ;; 43 | * ) 44 | echo "Skipping ${model} ckpt." 45 | ;; 46 | esac 47 | done 48 | ;; 49 | esac 50 | 51 | # Download preprocessed datasets 52 | read -p "Do you want to download preprocessed datasets? (19 MB) [y/N]: " choice 53 | case "$choice" in 54 | y|Y ) 55 | curl -o data/processed/proc_ds.tar.gz ${DATASET_URL}/proc_ds.tar.gz 56 | (cd data/processed && tar -xvf proc_ds.tar.gz) 57 | ;; 58 | * ) 59 | echo "Skipping preprocessed datasets." 60 | ;; 61 | esac 62 | 63 | # Download canonicalized eMols, buyables, ChEMBL-5000, and USPTO-190 64 | read -p "Do you want to download canonicalized eMols, buyables, and target datasets? (244 MB) [y/N]: " choice 65 | case "$choice" in 66 | y|Y ) 67 | echo "Downloading canonicalized eMols, buyables, ChEMBL-5000, and USPTO-190 ..." 68 | wget -O "data/compounds.zip" "https://figshare.com/ndownloader/files/53117957" 69 | (cd data && unzip -o compounds.zip && rm compounds.zip) 70 | ;; 71 | * ) 72 | echo "Skipping canonicalized eMols and buyables." 
73 | ;; 74 | esac -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: DirectMultiStep Docs 2 | repo_url: https://github.com/batistagroup/DirectMultiStep 3 | repo_name: batistagroup/DirectMultiStep 4 | copyright: CC-BY 4.0 © 2025 Batista Group 5 | theme: 6 | name: material 7 | features: 8 | - content.code.copy 9 | - navigation.footer 10 | palette: 11 | # Palette toggle for light mode 12 | - media: "(prefers-color-scheme: light)" 13 | scheme: default 14 | toggle: 15 | icon: material/brightness-7 16 | name: Switch to dark mode 17 | 18 | # Palette toggle for dark mode 19 | - media: "(prefers-color-scheme: dark)" 20 | scheme: slate 21 | toggle: 22 | icon: material/brightness-4 23 | name: Switch to light mode 24 | 25 | extra_css: 26 | - stylesheets/extra.css 27 | 28 | plugins: 29 | - search 30 | - mkdocstrings 31 | - material-plausible 32 | 33 | markdown_extensions: 34 | - pymdownx.highlight: 35 | anchor_linenums: true 36 | line_spans: __span 37 | pygments_lang_class: true 38 | - pymdownx.inlinehilite 39 | - pymdownx.snippets 40 | - pymdownx.superfences 41 | - def_list 42 | - pymdownx.tasklist: 43 | custom_checkbox: true 44 | 45 | - admonition 46 | - pymdownx.details 47 | 48 | extra: 49 | analytics: 50 | provider: plausible 51 | domain: dms.batistalab.com 52 | 53 | # : If using custom domain proxy or self-hosting Plausible, 54 | # : uncomment and specify script path here: 55 | src: "https://analytics.batistalab.com/js/script.js" 56 | 57 | feedback: 58 | title: Was this page helpful? 59 | ratings: 60 | - icon: material/emoticon-happy-outline 61 | name: This page was helpful 62 | data: good 63 | note: >- 64 | Thanks for your feedback! 65 | 66 | - icon: material/emoticon-sad-outline 67 | name: This page could be improved 68 | data: bad 69 | note: >- 70 | Thanks for your feedback! 
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "directmultistep" 3 | version = "1.1.2" 4 | requires-python = ">=3.11" 5 | dependencies = [ 6 | "numpy==1.26.4", 7 | "pandas==1.5.3", 8 | "pdoc3==0.11.5", 9 | "plotly==5.24.1", 10 | "lightning==2.2.5", 11 | "pyyaml==6.0.1", 12 | "rdkit==2023.9.3", 13 | "torch==2.3.0", 14 | "torchmetrics==1.6.0", 15 | "tqdm==4.67.1", 16 | "svgwrite==1.4.3", 17 | "svglib==1.5.1", 18 | "tomli>=2.2.1", 19 | ] 20 | authors = [ 21 | { name = "Anton Morgunov", email = "anton@ischemist.com" }, 22 | { name = "Yu Shee", email = "yu.shee@yale.edu" }, 23 | ] 24 | license = { text = "MIT" } 25 | readme = "README.md" 26 | 27 | [project.urls] 28 | Homepage = "https://github.com/batistagroup/DirectMultiStep" 29 | Issues = "https://github.com/batistagroup/DirectMultiStep/issues" 30 | 31 | [tool.setuptools] 32 | package-dir = { "" = "src" } 33 | 34 | 35 | [project.optional-dependencies] 36 | dev = [ 37 | "ipykernel>=6.29.5", 38 | "nbformat>=5.10.4", 39 | "rich>=13.9.4", 40 | "kaleido==0.2.1", 41 | "pre-commit==4.0.1", 42 | "mkdocs==1.6.1", 43 | "mkdocstrings-python==1.12.2", 44 | "mkdocs-material==9.5.49", 45 | "material-plausible-plugin>=0.3.0", 46 | "pytest==8.3.4", 47 | "ruff==0.4.7", 48 | "mypy==1.13.0", 49 | "isort==5.13.2", 50 | "typing-extensions==4.12.2", 51 | "mypy-extensions==1.0.0", 52 | "types-pyyaml", 53 | "types-tqdm", 54 | "types-requests", 55 | ] 56 | 57 | [tool.mypy] 58 | strict = true 59 | ignore_missing_imports = true 60 | exclude = ["tests"] 61 | disable_error_code = ["unused-ignore"] 62 | 63 | [[tool.mypy.overrides]] 64 | module = ["rdkit-stubs.*", "rdkit.*"] 65 | ignore_errors = true 66 | 67 | [tool.ruff] 68 | line-length = 120 69 | 70 | -------------------------------------------------------------------------------- /scripts/solve_compounds.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pickle 4 | import time 5 | from pathlib import Path 6 | 7 | from tqdm import tqdm 8 | 9 | from directmultistep.generate import generate_routes 10 | from directmultistep.utils.logging_config import logger 11 | from directmultistep.utils.post_process import find_path_strings_with_commercial_sm 12 | from directmultistep.utils.pre_process import canonicalize_smiles 13 | 14 | DATA_PATH = Path(__file__).parent.parent / "data" 15 | CKPT_PATH = DATA_PATH / "checkpoints" 16 | FIG_PATH = DATA_PATH / "figures" 17 | CONFIG_PATH = DATA_PATH / "configs" / "dms_dictionary.yaml" 18 | COMPOUND_PATH = DATA_PATH / "compounds" 19 | EVAL_PATH = DATA_PATH / "evaluations" 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument("--part", type=int, required=True, help="Part of the targets to process") 26 | parser.add_argument("--model_name", type=str, required=True, help="Name of the model") 27 | parser.add_argument("--use_fp16", action="store_true", help="Whether to use FP16") 28 | parser.add_argument("--num_part", type=int, required=True, help="Number of parts to split the targets into") 29 | parser.add_argument("--target_name", type=str, required=True, help="Name of the target dataset") 30 | args = parser.parse_args() 31 | part = args.part 32 | model_name = args.model_name 33 | use_fp16 = args.use_fp16 34 | num_part = args.num_part 35 | target_name = args.target_name
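    # Example invocation (all flags are defined by the argparse setup above):
    #   python scripts/solve_compounds.py --part 1 --num_part 4 --model_name flash --target_name uspto_190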
36 | 37 | logger.info(f"part: {part}") 38 | logger.info(f"model_name: {model_name}") 39 | logger.info(f"use_fp16: {use_fp16}") 40 | logger.info(f"num_part: {num_part}") 41 | logger.info(f"target_name: {target_name}") 42 | 43 | logger.info("Loading targets and stock compounds") 44 | if target_name == "uspto_190": 45 | with open(COMPOUND_PATH / "uspto_190.txt", "r") as f: 46 | targets = f.read().splitlines() 47 | elif target_name == "chembl": 48 | with open(COMPOUND_PATH / "chembl_targets.json", "r") as f: 49 | targets = json.load(f) 50 | else: 51 | logger.error(f"{target_name} is not a valid target name") 52 | raise ValueError(f"Invalid target_name: {target_name}") 53 | 54 | # eMols is available at https://github.com/binghong-ml/retro_star 55 | # make sure to canonicalize the SMILES strings before using them 56 | with open(COMPOUND_PATH / "eMolecules.txt", "r") as f: 57 | emol_stock_set = set(f.read().splitlines()) 58 | # buyables-stock is available at https://github.com/jihye-roh/higherlev_retro 59 | # make sure to canonicalize the SMILES strings before using them 60 | with open(COMPOUND_PATH / "buyables-stock.txt", "r") as f: 61 | buyables_stock_set = set(f.read().splitlines()) 62 | 63 | chunk_size = len(targets) // num_part 64 | start_index = (part - 1) * chunk_size 65 | end_index = part * chunk_size if part < num_part else len(targets) 66 | targets = targets[start_index:end_index] 67 | 68 | folder_name = f"{target_name}_{model_name}_fp16" if use_fp16 else f"{target_name}_{model_name}" 69 | save_dir = EVAL_PATH / folder_name 70 | save_dir.mkdir(parents=True, exist_ok=True) 71 | SAVED_PATH = save_dir / f"paths_part_{part}.pkl" 72 | SAVED_COUNT_PATH = save_dir / f"count_part_{part}.json" 73 | 74 | logger.info("Retrosynthesis starting") 75 | start = time.time() 76 | 77 | all_paths = [] 78 | raw_solved_count = 0 79 | buyable_solved_count = 0 80 | emol_solved_count = 0 81 | 82 | for target in tqdm(targets): 83 | target = canonicalize_smiles(target) 84 | raw_paths = [] 85 | if model_name == "explorer XL" or model_name == "explorer": 86 | raw_paths += generate_routes( 87 | target, 88 | n_steps=None, 89 | starting_material=None, 90 | beam_size=50, 91 | model=model_name, 92 | config_path=CONFIG_PATH, 93 | ckpt_dir=CKPT_PATH, 94 | commercial_stock=None, 95 | use_fp16=use_fp16, 96 | ) 97 | else: 98 | for step in range(2, 9): 99 | raw_paths += generate_routes( 100 | target, 101 | n_steps=step, 102 | starting_material=None, 103 | beam_size=50, 104 | model=model_name, 105 | config_path=CONFIG_PATH, 106 | ckpt_dir=CKPT_PATH, 107 | commercial_stock=None, 108 | use_fp16=use_fp16, 109 | ) 110 | buyables_paths = find_path_strings_with_commercial_sm(raw_paths, commercial_stock=buyables_stock_set) 111 | emol_paths = find_path_strings_with_commercial_sm(raw_paths, commercial_stock=emol_stock_set) 112 | if len(raw_paths) > 0: 113 | raw_solved_count += 1 114 | if len(buyables_paths) > 0: 115 | buyable_solved_count += 1 116 | if len(emol_paths) > 0: 117 | emol_solved_count += 1 118 | logger.info(f"Current raw solved count: {raw_solved_count}") 119 | logger.info(f"Current buyable solved count: {buyable_solved_count}") 120 | logger.info(f"Current emol solved count: {emol_solved_count}") 121 | all_paths.append([raw_paths, buyables_paths, emol_paths]) 122 | 123 | end = time.time() 124 | 125 | results = { 126 | "raw_solved_count": raw_solved_count, 127 | "buyable_solved_count": buyable_solved_count, 128 | "emol_solved_count": emol_solved_count, 129 | "time_elapsed": end - start, 130 | } 131 | logger.info(f"Results: {results}") 132 | 
with open(SAVED_COUNT_PATH, "w") as f: 133 | json.dump(results, f) 134 | with open(SAVED_PATH, "wb") as f: 135 | pickle.dump(all_paths, f) 136 | -------------------------------------------------------------------------------- /src/directmultistep/__init__.py: -------------------------------------------------------------------------------- 1 | """DirectMultiStep - Direct Route Generation for Multi-Step Retrosynthesis.""" 2 | 3 | from directmultistep.utils.logging_config import setup_logging 4 | 5 | setup_logging() 6 | -------------------------------------------------------------------------------- /src/directmultistep/analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/batistagroup/DirectMultiStep/bb445196ce743317c179ccf8a7f5ee3966051cda/src/directmultistep/analysis/__init__.py -------------------------------------------------------------------------------- /src/directmultistep/analysis/paper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/batistagroup/DirectMultiStep/bb445196ce743317c179ccf8a7f5ee3966051cda/src/directmultistep/analysis/paper/__init__.py -------------------------------------------------------------------------------- /src/directmultistep/analysis/style.py: -------------------------------------------------------------------------------- 1 | """Color palettes and styling utilities for plotly figures.""" 2 | 3 | from typing import Any 4 | 5 | import plotly.graph_objects as go 6 | 7 | # fmt:off 8 | # Color palettes 9 | colors_light = [ 10 | '#ff4d4d', '#ff7f50', '#ffff00', '#00ff7f', '#00ffff', 11 | '#1e90ff', '#9370db', '#ff69b4', '#cd5c5c', '#8fbc8f', 12 | '#ffd700', '#32cd32', '#00bfff', '#ff00ff', '#ff8c00' 13 | ] 14 | 15 | colors_dark = [ 16 | '#cc0000', '#cc5500', '#cccc00', '#00cc66', '#00cccc', 17 | '#0066cc', '#6a5acd', '#ff1493', '#8b0000', '#2e8b57', 18 | '#daa520', '#228b22', '#0099cc', '#cc00cc', '#d2691e' 19 | ] 20 | 21 | colors_names: dict[str, str] = { 22 | "yellow": "#ffdd00", "red": "#ff006d", "cyan": "#00ffff", "purple": "#8f00ff", "orange": "#ff7d00", 23 | "lime": "#adff02", "green": "#04e762", "pink": "#ff00cc", "white": "#ffffff", "blue": "#0d41e1", 24 | "sky": "#0080ff", "spring": "#00F59B" 25 | } 26 | 27 | publication_colors: dict[str, str] = { 28 | "primary_blue": "#6A7BC8", 29 | "dark_blue": "#4C61BD", 30 | "light_blue": "#8AA1E9", 31 | "purple": "#A064B9", 32 | "dark_purple": "#763F8D" 33 | } 34 | 35 | colors_gray : list[str] = ["#333333", "#666666", "#999999", "#CCCCCC"] 36 | colors_blue: list[str] = ["#3a0ca3", "#3f37c9", "#4361ee", "#4895ef"] 37 | colors_purple: list[str] = ["#6411ad", "#822faf", "#973aa8", "#c05299"] 38 | colors_red: list[str] = ["#800f2f", "#a4133c", "#c9184a"] 39 | # fmt:on 40 | 41 | # Universal font settings 42 | FONT_FAMILY = "Helvetica" 43 | FONT_COLOR = "#333333" 44 | 45 | # Font sizes for different elements 46 | FONT_SIZES: dict[str, int] = { 47 | "title": 20, 48 | "axis_title": 16, 49 | "tick_label": 16, 50 | "subtitle": 11, 51 | "legend": 12, 52 | "subplot_title": 16, 53 | } 54 | 55 | # Universal axis style settings 56 | AXIS_STYLE: dict[str, Any] = { 57 | "showgrid": True, 58 | "gridwidth": 1, 59 | "gridcolor": "#E7E7E7", 60 | "zeroline": False, 61 | "linewidth": 2, 62 | "linecolor": "#333333", 63 | } 64 | 65 | # Universal layout settings 66 | LAYOUT_STYLE: dict[str, Any] = { 67 | "plot_bgcolor": "#FBFCFF", 68 | "paper_bgcolor": "#FBFCFF", 69 | "margin": 
dict(t=40, b=40, r=40), 70 | } 71 | 72 | # Development style settings 73 | DEVELOPMENT_STYLE: dict[str, Any] = { 74 | "template": "plotly_dark", 75 | "plot_bgcolor": "black", 76 | "paper_bgcolor": "black", 77 | "font": dict(color="white"), 78 | } 79 | 80 | 81 | def get_font_dict(size: int) -> dict[str, Any]: 82 | """Helper function to create consistent font dictionaries. 83 | 84 | Args: 85 | size: Font size to use 86 | 87 | Returns: 88 | Dictionary with font settings 89 | """ 90 | return dict(family=FONT_FAMILY, size=size, color=FONT_COLOR) 91 | 92 | 93 | def apply_publication_fonts(fig: go.Figure) -> None: 94 | """Apply publication-quality font settings to a figure. 95 | 96 | Args: 97 | fig: A plotly figure 98 | """ 99 | # Update global font 100 | fig.update_layout(font=get_font_dict(FONT_SIZES["tick_label"])) 101 | 102 | # Update title font if title exists 103 | if fig.layout.title is not None: 104 | fig.layout.title.update(font=get_font_dict(FONT_SIZES["title"])) 105 | 106 | 107 | def update_axis(axis: go.layout.XAxis | go.layout.YAxis, axis_style: dict[str, Any]) -> None: 108 | """Helper function to update a single axis with publication styling. 109 | 110 | Args: 111 | axis: Axis to update 112 | axis_style: Style parameters to apply 113 | """ 114 | axis.update( 115 | axis_style, title_font=get_font_dict(FONT_SIZES["axis_title"]), tickfont=get_font_dict(FONT_SIZES["tick_label"]) 116 | ) 117 | 118 | 119 | def apply_axis_style(fig: go.Figure, row: int | None = None, col: int | None = None, **kwargs: Any) -> None: 120 | """Apply publication-quality axis styling to a figure. 121 | 122 | Args: 123 | fig: A plotly figure 124 | row: Optional row index for subplots 125 | col: Optional column index for subplots 126 | **kwargs: Additional axis style parameters to override defaults 127 | """ 128 | axis_style = AXIS_STYLE.copy() 129 | axis_style.update(kwargs) 130 | 131 | if row is not None and col is not None: 132 | for axis in list(fig.select_xaxes(row=row, col=col)) + list(fig.select_yaxes(row=row, col=col)): 133 | update_axis(axis, axis_style) # go.Figure has no get_xaxes(); select_* targets the given subplot's axes 134 | else: 135 | update_axis(fig.layout.xaxis, axis_style) 136 | update_axis(fig.layout.yaxis, axis_style) 137 | 138 | 139 | def apply_publication_style(fig: go.Figure, **kwargs: Any) -> None: 140 | """Apply all publication-quality styling to a figure. 141 | 142 | Args: 143 | fig: A plotly figure 144 | 145 | **kwargs: Additional layout parameters to override defaults 146 | """ 147 | # Apply fonts 148 | apply_publication_fonts(fig) 149 | 150 | # Apply axis style to all axes 151 | # Handle both single plot and subplot cases by looking for axis objects in layout 152 | for key in fig.layout: 153 | if key.startswith("xaxis") or key.startswith("yaxis"): 154 | update_axis(getattr(fig.layout, key), AXIS_STYLE) 155 | 156 | layout_style: dict[str, Any] = LAYOUT_STYLE.copy() 157 | layout_style.update(kwargs) 158 | fig.update_layout(layout_style) 159 | 160 | 161 | def apply_development_style(fig: go.Figure) -> None: 162 | """Apply dark theme development styling to a figure. 163 | 164 | This applies a dark theme with black background, suitable for development 165 | and debugging visualizations.
166 | 167 | Args: 168 | fig: A plotly figure 169 | """ 170 | fig.update_layout(**DEVELOPMENT_STYLE) 171 | -------------------------------------------------------------------------------- /src/directmultistep/analysis/training.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | 4 | import pandas as pd 5 | import plotly.graph_objects as go 6 | 7 | from directmultistep.analysis import style 8 | from directmultistep.utils.logging_config import logger 9 | 10 | 11 | def load_training_df(train_path: Path, run_name: str, ignore_ids: list[int] | None = None) -> pd.DataFrame: 12 | logger.debug(f"Loading {run_name=}") 13 | log_path = train_path / run_name / "lightning_logs" 14 | dfs = [] 15 | versions = [log.name for log in log_path.glob("version_*")] 16 | logger.debug(f"Found versions: {versions} for {run_name}") 17 | if ignore_ids is not None: 18 | ignored_folders = {f"version_{i}" for i in ignore_ids} 19 | else: 20 | ignored_folders = set() 21 | for version in sorted(versions, key=lambda x: int(x.split("_")[1])): 22 | if version in ignored_folders: 23 | continue 24 | temp_df = pd.read_csv(log_path / version / "metrics.csv") 25 | logger.debug(f"Loaded df with shape {temp_df.shape}") 26 | dfs.append(temp_df) 27 | df = pd.concat(dfs) 28 | df = df.reset_index(drop=True) 29 | return df 30 | 31 | 32 | def create_train_trace(df: pd.DataFrame, run_name: str, color: str, x_axis: str) -> go.Scatter: 33 | return go.Scatter( 34 | x=df[x_axis], 35 | y=df["train_loss"], 36 | mode="lines", 37 | name=f"train_loss {run_name}", 38 | line_color=color, 39 | showlegend=True, 40 | legendgroup=run_name, 41 | ) 42 | 43 | 44 | def create_val_trace(df: pd.DataFrame, run_name: str, color: str, x_axis: str) -> go.Scatter: 45 | val_df = df.dropna(subset=["val_loss"]) 46 | return go.Scatter( 47 | x=val_df[x_axis], 48 | y=val_df["val_loss"], 49 | mode="lines+markers", 50 | name=f"val_loss {run_name}", 51 | line_color=color, 52 | showlegend=True, 53 | hovertemplate="%{fullData.name}<br>" 54 | + "epoch=%{customdata}<br>" 55 | + f"{x_axis}=%{{x}}<br>
" 56 | + "val_loss=%{y}", 57 | customdata=val_df["epoch"], 58 | ) 59 | 60 | 61 | @dataclass 62 | class RunConfig: 63 | """Configuration for a training run visualization.""" 64 | 65 | run_name: str # Folder name of the run 66 | trace_name: str # Display name for the traces 67 | include_val: bool = True # Whether to include validation curve 68 | ignore_ids: list[int] | None = None # Version IDs to ignore when loading data 69 | 70 | 71 | def plot_training_curves( 72 | train_path: Path, 73 | runs: list[RunConfig], 74 | x_axis: str = "processed_tokens", 75 | ) -> go.Figure: 76 | """Create a figure showing training and validation curves for multiple runs. 77 | 78 | Args: 79 | train_path: Path to training data directory 80 | runs: List of run configurations specifying what and how to plot 81 | x_axis: Column to use for x-axis values ("processed_tokens", "epoch", or "step") 82 | 83 | Returns: 84 | Plotly figure with training and validation curves 85 | """ 86 | traces = [] 87 | for i, run in enumerate(runs): 88 | df = load_training_df(train_path, run.run_name, run.ignore_ids) 89 | color_idx = i % len(style.colors_light) 90 | traces.append( 91 | create_train_trace(df, run.trace_name, style.colors_light[color_idx % len(style.colors_light)], x_axis) 92 | ) 93 | if run.include_val: 94 | traces.append( 95 | create_val_trace(df, run.trace_name, style.colors_dark[color_idx % len(style.colors_dark)], x_axis) 96 | ) 97 | 98 | fig = go.Figure(data=traces) 99 | 100 | fig.update_layout( 101 | title="Training Loss", 102 | xaxis_title=x_axis, 103 | yaxis_title="Loss", 104 | width=1000, 105 | ) 106 | style.apply_development_style(fig) 107 | 108 | return fig 109 | 110 | 111 | def get_lr_trace(df: pd.DataFrame, run_name: str) -> go.Scatter: 112 | return go.Scatter( 113 | x=df["step"], 114 | y=df["train_lr"], 115 | mode="lines", 116 | name=f"learning rate {run_name}", 117 | showlegend=True, 118 | legendgroup=run_name, 119 | ) 120 | 121 | 122 | def plot_learning_rates( 123 | train_path: Path, 124 | runs: list[RunConfig], 125 | ) -> go.Figure: 126 | """Create a figure showing learning rate curves for multiple runs. 
127 | 128 | Args: 129 | train_path: Path to training data directory 130 | runs: List of run configurations specifying what and how to plot 131 | 132 | Returns: 133 | Plotly figure with learning rate curves 134 | """ 135 | traces = [] 136 | for run in runs: 137 | df = load_training_df(train_path, run.run_name, run.ignore_ids) 138 | traces.append(get_lr_trace(df, run.trace_name)) 139 | 140 | fig = go.Figure(data=traces) 141 | 142 | fig.update_layout( 143 | title="Learning Rate", 144 | xaxis_title="Step", 145 | yaxis_title="Learning Rate", 146 | width=800, 147 | ) 148 | style.apply_development_style(fig) 149 | 150 | return fig 151 | 152 | 153 | if __name__ == "__main__": 154 | train_path = Path("data/training") 155 | 156 | runs = [ 157 | RunConfig(run_name="baseline_run", trace_name="Baseline Model"), 158 | RunConfig(run_name="improved_run", trace_name="Improved Model", include_val=True), 159 | RunConfig( 160 | run_name="experimental_run", 161 | trace_name="Experimental Model", 162 | include_val=False, # Only show training curve 163 | ), 164 | ] 165 | 166 | fig = plot_training_curves(train_path, runs) 167 | fig.show() 168 | -------------------------------------------------------------------------------- /src/directmultistep/generate.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Literal, cast 3 | 4 | import torch 5 | import torch.nn as nn 6 | import yaml 7 | 8 | from directmultistep.generation.tensor_gen import BeamSearchOptimized as BeamSearch 9 | from directmultistep.model import ModelFactory 10 | from directmultistep.utils.dataset import RoutesProcessing 11 | from directmultistep.utils.post_process import find_valid_paths, process_path_single 12 | 13 | ModelName = Literal["flash", "flash-20M", "flex-20M", "deep", "wide", "explorer", "explorer XL"] 14 | 15 | MODEL_CHECKPOINTS = { 16 | "flash": ("flash_10M", "flash.ckpt"), 17 | "flash-20M": ("flash_20M", "flash_20.ckpt"), 18 | "flex-20M": ("flex_20M", "flex.ckpt"), 19 | "deep": ("deep_40M", "deep.ckpt"), 20 | "wide": ("wide_40M", "wide.ckpt"), 21 | "explorer": ("explorer_19M", "explorer.ckpt"), 22 | "explorer XL": ("explorer_xl_50M", "explorer_xl.ckpt"), 23 | } 24 | 25 | 26 | def validate_model_constraints(model_name: ModelName, n_steps: int | None, starting_material: str | None) -> None: 27 | """Validate model-specific constraints for route generation.""" 28 | if model_name in ["deep", "wide"] and starting_material is not None: 29 | raise ValueError(f"{model_name} model does not support starting material specification") 30 | if model_name == "explorer" and n_steps is not None: 31 | raise ValueError("explorer model does not support step count specification") 32 | if model_name == "explorer XL" and (n_steps is not None or starting_material is not None): 33 | raise ValueError("explorer XL model does not support step count or starting material specification") 34 | 35 | 36 | def load_model(model_name: ModelName, ckpt_dir: Path, use_fp16: bool = False) -> torch.nn.Module: 37 | """Load a model by name from the available checkpoints. 38 | 39 | Args: 40 | model_name: Name of the model to load 41 | ckpt_dir: Directory containing model checkpoints 42 | use_fp16: Whether to use half precision (FP16) for model weights 43 | """ 44 | if model_name not in MODEL_CHECKPOINTS: 45 | raise ValueError(f"Unknown model name: {model_name}. 
Available models: {list(MODEL_CHECKPOINTS.keys())}") 46 | 47 | preset_name, ckpt_file = MODEL_CHECKPOINTS[model_name] 48 | device = ModelFactory.determine_device() 49 | model = ModelFactory.from_preset(preset_name, compile_model=False).create_model() 50 | model = ModelFactory.load_checkpoint(model, ckpt_dir / ckpt_file, device) 51 | 52 | if use_fp16: 53 | model = model.half() # Convert to FP16 54 | 55 | return cast(nn.Module, model) 56 | 57 | 58 | def create_beam_search(model: torch.nn.Module, beam_size: int, config_path: Path) -> tuple[int, int, BeamSearch]: 59 | """Create a beam search object and return product/sm max lengths and the beam search object.""" 60 | device = next(model.parameters()).device 61 | with open(config_path, "rb") as file: 62 | data = yaml.safe_load(file) 63 | idx_to_token = data["invdict"] 64 | product_max_length = data["product_max_length"] 65 | sm_max_length = data["sm_max_length"] 66 | 67 | beam = BeamSearch( 68 | model=model, 69 | beam_size=beam_size, 70 | start_idx=0, 71 | pad_idx=52, 72 | end_idx=22, 73 | max_length=1074, 74 | idx_to_token=idx_to_token, 75 | device=device, 76 | ) 77 | return product_max_length, sm_max_length, beam 78 | 79 | 80 | def prepare_input_tensors( 81 | target: str, 82 | n_steps: int | None, 83 | starting_material: str | None, 84 | rds: RoutesProcessing, 85 | product_max_length: int, 86 | sm_max_length: int, 87 | use_fp16: bool = False, 88 | ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: 89 | """Prepare input tensors for the model. 90 | Args: 91 | target: SMILES string of the target molecule. 92 | n_steps: Number of synthesis steps. 93 | starting_material: SMILES string of the starting material, if any. 94 | rds: RoutesProcessing object for tokenization. 95 | product_max_length: Maximum length of the product SMILES sequence. 96 | sm_max_length: Maximum length of the starting material SMILES sequence. 97 | use_fp16: Whether to use half precision (FP16) for tensors. 98 | 99 | Returns: 100 | A tuple containing: 101 | - encoder_inp: Input tensor for the encoder. 102 | - steps_tens: Tensor of the number of steps, or None if not provided. 103 | - path_tens: Initial path tensor for the decoder. 104 | """ 105 | prod_tens = rds.smile_to_tokens(target, product_max_length) 106 | if starting_material: 107 | sm_tens = rds.smile_to_tokens(starting_material, sm_max_length) 108 | encoder_inp = torch.cat([prod_tens, sm_tens], dim=0).unsqueeze(0) 109 | else: 110 | encoder_inp = torch.cat([prod_tens], dim=0).unsqueeze(0) 111 | 112 | steps_tens = torch.tensor([n_steps]).unsqueeze(0) if n_steps is not None else None 113 | path_start = "{'smiles':'" + target + "','children':[{'smiles':'" 114 | path_tens = rds.path_string_to_tokens(path_start, max_length=None, add_eos=False).unsqueeze(0) 115 | 116 | if use_fp16: 117 | encoder_inp = encoder_inp.half() 118 | if steps_tens is not None: 119 | steps_tens = steps_tens.half() 120 | path_tens = path_tens.half() 121 | 122 | return encoder_inp, steps_tens, path_tens 123 | 124 | 125 | def generate_routes( 126 | target: str, 127 | n_steps: int | None, 128 | starting_material: str | None, 129 | beam_size: int, 130 | model: ModelName | torch.nn.Module, 131 | config_path: Path, 132 | ckpt_dir: Path | None = None, 133 | commercial_stock: set[str] | None = None, 134 | use_fp16: bool = False, 135 | ) -> list[str]: 136 | """Generate synthesis routes using the model.
137 | 138 | Args: 139 | target: SMILES string of the target molecule 140 | n_steps: Number of synthesis steps. If None, will try multiple steps 141 | starting_material: Optional SMILES string of the starting material 142 | beam_size: Beam size for the beam search 143 | model: Either a model name or a torch.nn.Module 144 | config_path: Path to the model configuration file 145 | ckpt_dir: Directory containing model checkpoints (required if model is a string) 146 | commercial_stock: Set of commercially available starting materials (SMILES). 147 | use_fp16: Whether to use half precision (FP16) for model weights and computations 148 | """ 149 | # Handle model loading and validation 150 | if isinstance(model, str): 151 | if ckpt_dir is None: 152 | raise ValueError("ckpt_dir must be provided when model is specified by name") 153 | validate_model_constraints(model, n_steps, starting_material) 154 | model = load_model(model, ckpt_dir, use_fp16) 155 | 156 | rds = RoutesProcessing(metadata_path=config_path) 157 | product_max_length, sm_max_length, beam_obj = create_beam_search(model, beam_size, config_path) 158 | 159 | # Prepare input tensors 160 | encoder_inp, steps_tens, path_tens = prepare_input_tensors( 161 | target, n_steps, starting_material, rds, product_max_length, sm_max_length, use_fp16 162 | ) 163 | 164 | # Run beam search 165 | device = ModelFactory.determine_device() 166 | all_beam_results_NS2: list[list[tuple[str, float]]] = [] 167 | beam_result_BS2 = beam_obj.decode( 168 | src_BC=encoder_inp.to(device), 169 | steps_B1=steps_tens.to(device) if steps_tens is not None else None, 170 | path_start_BL=path_tens.to(device), 171 | ) 172 | for beam_result_S2 in beam_result_BS2: 173 | all_beam_results_NS2.append(beam_result_S2) 174 | 175 | # Process results 176 | valid_paths_NS2n = find_valid_paths(all_beam_results_NS2) 177 | correct_paths_NS2n = process_path_single( 178 | paths_NS2n=valid_paths_NS2n, 179 | true_products=[target], 180 | true_reacs=[starting_material] if starting_material else None, 181 | commercial_stock=commercial_stock, 182 | ) 183 | return [beam_result[0] for beam_result in correct_paths_NS2n[0]] 184 | -------------------------------------------------------------------------------- /src/directmultistep/generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/batistagroup/DirectMultiStep/bb445196ce743317c179ccf8a7f5ee3966051cda/src/directmultistep/generation/__init__.py -------------------------------------------------------------------------------- /src/directmultistep/generation/generation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shape suffixes convention inspired by 3 | https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd 4 | 5 | B: batch size 6 | C: the length of the input on which conditioning is done 7 | in our case input_max_length 8 | L: sequence length for decoder, in our case output_max_length 9 | M: memory length (length of sequence being attended to) 10 | D: model dimension (sometimes called d_model or embedding_dim) 11 | V: vocabulary size 12 | F: feed-forward subnetwork hidden size 13 | H: number of attention heads in a layer 14 | K: size of each attention key or value (sometimes called d_kv) 15 | """ 16 | 17 | import numpy as np 18 | import numpy.typing as npt 19 | import torch 20 | import torch.nn as nn 21 | from tqdm import tqdm 22 | 23 | # Define types 24 | Tensor = torch.Tensor 25 | BeamSearchOutput = 
list[list[tuple[str, float]]] 26 | 27 | 28 | class BeamSearch: 29 | def __init__( 30 | self, 31 | model: nn.Module, 32 | beam_size: int, 33 | start_idx: int, 34 | pad_idx: int, 35 | end_idx: int, 36 | max_length: int, 37 | idx_to_token: dict[int, str], 38 | device: torch.device, 39 | ): 40 | self.model = model 41 | self.beam_size = beam_size 42 | self.start_idx = start_idx 43 | self.pad_idx = pad_idx 44 | self.end_idx = end_idx 45 | self.device = device 46 | self.max_length = max_length 47 | self.idx_to_token = idx_to_token 48 | 49 | def _prepare_beam_tensors( 50 | self, src_BC: Tensor, enc_src_BCD: Tensor 51 | ) -> tuple[Tensor, Tensor, list[list[list[int]]], npt.NDArray[np.float64], Tensor]: 52 | B = enc_src_BCD.shape[0] 53 | S = self.beam_size 54 | 55 | beam_enc_WCD = enc_src_BCD.repeat_interleave(S, dim=0) # W = B * S 56 | beam_src_WC = src_BC.repeat_interleave(S, dim=0) 57 | beam_src_mask_W11C = (beam_src_WC != self.pad_idx).unsqueeze(1).unsqueeze(2) 58 | beam_idxs_BS1_nt = [[[self.start_idx] for _ in range(S)] for _ in range(B)] 59 | beam_log_probs_BS_nt = np.zeros((B, S)) 60 | cur_targets_B1 = torch.LongTensor([[self.start_idx] for _ in range(B)]).to(self.device) 61 | 62 | return ( 63 | beam_enc_WCD, 64 | beam_src_mask_W11C, 65 | beam_idxs_BS1_nt, 66 | beam_log_probs_BS_nt, 67 | cur_targets_B1, 68 | ) 69 | 70 | def _expand_and_normalize_candidates( 71 | self, 72 | output_BSLS: Tensor, 73 | beam_idxs_BSL_nt: list[list[list[int]]], 74 | beam_log_probs_BS_nt: npt.NDArray[np.float64], 75 | ) -> tuple[list[list[float]], list[list[list[int]]]]: 76 | """Generate expanded candidate sequences and their probabilities.""" 77 | B = len(beam_idxs_BSL_nt) 78 | S = self.beam_size 79 | 80 | candidate_probs_BS_nt: list[list[float]] = [[] for _ in range(B)] 81 | candidate_seqs_BSL_nt: list[list[list[int]]] = [[] for _ in range(B)] # S is actually 2S 82 | # candidate_probs_B_nt, candidate_seqs_B_nt = 83 | # _candidates(output_BSLS, beam_idxs_BS1_nt, beam_log_probs_BS_nt) 84 | 85 | for B_idx in range(B): # former n 86 | for S_idx in range(S): # former k 87 | if self.end_idx not in beam_idxs_BSL_nt[B_idx][S_idx]: 88 | k_prob = beam_log_probs_BS_nt[B_idx][S_idx] 89 | normalized_probs_LS = torch.log_softmax(output_BSLS[B_idx, S_idx, -1, :], dim=-1) 90 | k_prob_vec_S = k_prob + normalized_probs_LS 91 | top_k_S = torch.topk(k_prob_vec_S, S).indices 92 | 93 | # find if any of the top_k_S is greater than pad_idx 94 | assert torch.any(top_k_S > self.pad_idx).item() is False 95 | # top_k_S[top_k_S > self.pad_idx] = self.pad_idx 96 | 97 | for idx in top_k_S: 98 | candidate_seqs_BSL_nt[B_idx].append(beam_idxs_BSL_nt[B_idx][S_idx] + [idx.item()]) 99 | candidate_probs_BS_nt[B_idx].append(k_prob_vec_S[idx].item()) 100 | else: 101 | candidate_seqs_BSL_nt[B_idx].append(beam_idxs_BSL_nt[B_idx][S_idx] + [self.pad_idx]) 102 | candidate_probs_BS_nt[B_idx].append(beam_log_probs_BS_nt[B_idx][S_idx]) 103 | 104 | return candidate_probs_BS_nt, candidate_seqs_BSL_nt 105 | 106 | def _select_top_k_candidates( 107 | self, 108 | candidate_probs_BS_nt: list[list[float]], 109 | candidate_seqs_BSL_nt: list[list[list[int]]], # S is actually 2S 110 | debug: bool = False, 111 | ) -> list[npt.NDArray[np.int_]]: 112 | """Normalize probabilities and select top-k candidates.""" 113 | B = len(candidate_probs_BS_nt) 114 | best_k_B_nt = [] 115 | 116 | for B_idx in range(B): 117 | seq_lengths = [ 118 | len([token for token in seq if token != self.pad_idx]) for seq in candidate_seqs_BSL_nt[B_idx] 119 | ] 120 | normalized_probs_S = 
np.array(candidate_probs_BS_nt[B_idx]) / (np.sqrt(seq_lengths) + 1e-6) 121 | if debug: 122 | breakpoint() 123 | 124 | best_k = np.argsort(normalized_probs_S)[-self.beam_size :][::-1] 125 | best_k_B_nt.append(best_k) 126 | return best_k_B_nt 127 | 128 | def _generate_final_outputs( 129 | self, 130 | beam_idxs_BSL_nt: list[list[list[int]]], 131 | beam_log_probs_BS_nt: npt.NDArray[np.float64], 132 | ) -> BeamSearchOutput: 133 | """Convert index sequences to final outputs.""" 134 | B = len(beam_idxs_BSL_nt) 135 | outputs_B2_nt: list[list[tuple[str, float]]] = [[] for _ in range(B)] 136 | 137 | for B_idx in range(B): 138 | for S_idx in range(self.beam_size): 139 | output_str = "" 140 | for L_idx in beam_idxs_BSL_nt[B_idx][S_idx]: 141 | if L_idx == self.end_idx: 142 | break 143 | output_str += self.idx_to_token[L_idx] 144 | outputs_B2_nt[B_idx].append((output_str[5:], beam_log_probs_BS_nt[B_idx][S_idx])) 145 | 146 | return outputs_B2_nt 147 | 148 | def decode(self, src_BC: Tensor, steps_B1: Tensor) -> BeamSearchOutput: 149 | """ 150 | src_BC: product + one_sm 151 | steps_B1: number of steps 152 | 153 | define S as beam_size 154 | define W as B*S (W for Window, a window for output) 155 | _nt stands for not a tensor, a regular list 156 | """ 157 | self.model.eval() 158 | src_mask_B11C = (src_BC != self.pad_idx).unsqueeze(1).unsqueeze(2) 159 | enc_src_BCD = self.model.encoder(src_BC.long(), src_mask_B11C, steps_B1) 160 | # prepare tensors for beam search 161 | ( 162 | beam_enc_WCD, 163 | beam_src_mask_W11C, 164 | beam_idxs_BS1_nt, 165 | beam_log_probs_BS_nt, 166 | cur_targets_B1, 167 | ) = self._prepare_beam_tensors(src_BC, enc_src_BCD) 168 | with torch.no_grad(): 169 | output_BLV = self.model.decoder( 170 | trg_BL=cur_targets_B1, 171 | enc_src_BCD=enc_src_BCD, 172 | src_mask_B11C=src_mask_B11C, 173 | trg_mask_B1LL=None, 174 | ) 175 | 176 | B, L, V = output_BLV.shape 177 | S = self.beam_size 178 | if self.beam_size > output_BLV.shape[-1]: 179 | # beam_size is greater than vocabulary, add padding 180 | pad_tensor = torch.full((B, L, S - V), float("-inf"), device=output_BLV.device) 181 | output_BLS = torch.cat([output_BLV, pad_tensor], dim=-1) 182 | else: 183 | output_BLS = output_BLV 184 | 185 | normalized_probs_BLS = torch.softmax(output_BLS, dim=-1) 186 | # Update initial target sequences with the first step probabilities 187 | top_k_BS_nt = [] 188 | for B_idx in range(B): 189 | sorted_idx_top_S = torch.argsort(normalized_probs_BLS[B_idx, -1, :])[-1 * torch.arange(1, 1 + S, 1)] 190 | sorted_idx_top_S[sorted_idx_top_S > self.pad_idx] = self.pad_idx 191 | top_k_BS_nt.append(sorted_idx_top_S.tolist()) 192 | for B_idx in range(B): 193 | for S_idx in range(S): 194 | beam_idxs_BS1_nt[B_idx][S_idx].append((chosen_idx := top_k_BS_nt[B_idx][S_idx])) 195 | beam_log_probs_BS_nt[B_idx][S_idx] += np.log(normalized_probs_BLS[B_idx, -1, chosen_idx].item()) 196 | 197 | # Expand beam search over multiple decoding steps 198 | beam_idxs_BSL_nt = beam_idxs_BS1_nt 199 | for step in tqdm(range(self.max_length - 2)): 200 | trg_idxs_WL = torch.LongTensor(beam_idxs_BSL_nt).view(B * S, -1).to(self.device) 201 | with torch.no_grad(): 202 | output_WLV = self.model.decoder( 203 | trg_BL=trg_idxs_WL, 204 | enc_src_BCD=beam_enc_WCD, 205 | src_mask_B11C=beam_src_mask_W11C, 206 | trg_mask_B1LL=None, 207 | ) 208 | W, L, V = output_WLV.shape 209 | output_BSLV = output_WLV.view(B, S, L, V) 210 | if self.beam_size > output_WLV.shape[-1]: 211 | # beam_size is greater than vocabulary, add padding 212 | pad_tensor = torch.full((B, S, L, 
S - V), float("-inf"), device=output_WLV.device) 213 | output_BSLS = torch.cat([output_BSLV, pad_tensor], dim=-1) 214 | else: 215 | output_BSLS = output_BSLV 216 | ( 217 | candidate_probs_BS_nt, 218 | candidate_seqs_BSL_nt, 219 | ) = self._expand_and_normalize_candidates(output_BSLS, beam_idxs_BS1_nt, beam_log_probs_BS_nt) 220 | 221 | best_k_B_nt = self._select_top_k_candidates( 222 | candidate_probs_BS_nt, 223 | candidate_seqs_BSL_nt, 224 | debug=False, # step>113-2 225 | ) 226 | 227 | for B_idx in range(B): 228 | for S_idx in range(S): 229 | beam_idxs_BSL_nt[B_idx][S_idx] = candidate_seqs_BSL_nt[B_idx][best_k_B_nt[B_idx][S_idx]] 230 | beam_log_probs_BS_nt[B_idx][S_idx] = candidate_probs_BS_nt[B_idx][best_k_B_nt[B_idx][S_idx]] 231 | if step > 150: 232 | break 233 | # if step > 113-2: 234 | # breakpoint() 235 | return self._generate_final_outputs(beam_idxs_BSL_nt, beam_log_probs_BS_nt) 236 | -------------------------------------------------------------------------------- /src/directmultistep/generation/tensor_gen.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shape suffixes convention inspired by 3 | https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd 4 | 5 | B: batch size 6 | C: the length of the input on which conditioning is done 7 | in our case input_max_length 8 | L: sequence length for decoder, in our case output_max_length 9 | M: memory length (length of sequence being attended to) 10 | D: model dimension (sometimes called d_model or embedding_dim) 11 | V: vocabulary size 12 | F: feed-forward subnetwork hidden size 13 | H: number of attention heads in a layer 14 | K: size of each attention key or value (sometimes called d_kv) 15 | """ 16 | 17 | import torch 18 | import torch.nn as nn 19 | from tqdm import tqdm 20 | 21 | from directmultistep.utils.logging_config import logger 22 | 23 | # Define types 24 | Tensor = torch.Tensor 25 | BeamSearchOutput = list[list[tuple[str, float]]] 26 | 27 | 28 | class BeamSearchOptimized: 29 | def __init__( 30 | self, 31 | model: nn.Module, 32 | beam_size: int, 33 | start_idx: int, 34 | pad_idx: int, 35 | end_idx: int, 36 | max_length: int, 37 | idx_to_token: dict[int, str], 38 | device: torch.device, 39 | ): 40 | self.model = model 41 | self.beam_size = beam_size 42 | self.start_idx = start_idx 43 | self.pad_idx = pad_idx 44 | self.end_idx = end_idx 45 | self.device = device 46 | self.max_length = max_length 47 | self.idx_to_token = idx_to_token 48 | 49 | def __repr__(self) -> str: 50 | return f"BeamSearchOptimized(beam_width={self.beam_size}, max_length={self.max_length})" 51 | 52 | def decode(self, src_BC: Tensor, steps_B1: Tensor | None, path_start_BL: Tensor | None = None) -> BeamSearchOutput: 53 | """ 54 | src_BC: product + one_sm (B, C) 55 | steps_B1: number of steps (B, 1) 56 | 57 | Define S as beam_size. 58 | Define W as B*S (W for Window, a window for output). 59 | _nt stands for not a tensor, a regular list. 
60 | """ 61 | B, C = src_BC.shape 62 | S = self.beam_size 63 | L = self.max_length 64 | 65 | # Prepare mask and encoder outputs 66 | src_mask_B11C = (src_BC != self.pad_idx).unsqueeze(1).unsqueeze(2) 67 | enc_src_BCD = self.model.encoder(src_BC.long(), src_mask_B11C, steps_B1) 68 | beam_enc_WCD = enc_src_BCD.repeat_interleave(S, dim=0) # W = B * S 69 | 70 | beam_src_WC = src_BC.repeat_interleave(S, dim=0) 71 | beam_src_mask_W11C = (beam_src_WC != self.pad_idx).unsqueeze(1).unsqueeze(2) 72 | 73 | beam_idxs_WL = torch.full((B * S, L), self.pad_idx, dtype=torch.long, device=self.device) 74 | if path_start_BL is None: 75 | beam_idxs_WL[:, 0] = self.start_idx 76 | first_step = 1 77 | beam_log_probs_W = torch.zeros(B * S, device=self.device) 78 | else: 79 | beam_idxs_WL[:, : path_start_BL.size(1)] = path_start_BL 80 | first_step = path_start_BL.size(1) 81 | beam_log_probs_W = torch.zeros(B * S, device=self.device) 82 | 83 | finished_sequences_W = torch.zeros(B * S, dtype=torch.bool, device=self.device) 84 | logger.info( 85 | f"Generating routes with beam size {S}. The progress bar may end early if all beams find end token." 86 | ) 87 | for step in tqdm(range(first_step, L - 1)): 88 | with torch.no_grad(): 89 | output_WLV = self.model.decoder( 90 | trg_BL=beam_idxs_WL[:, :step], 91 | enc_src_BCD=beam_enc_WCD, 92 | src_mask_B11C=beam_src_mask_W11C, 93 | trg_mask_B1LL=None, # trg_mask_W1LL[:, :, :step, :step] 94 | ) 95 | W, _, V = output_WLV.shape 96 | output_WV = output_WLV[:, -1, :] # Get the last token's logits 97 | log_probs_WV = torch.log_softmax(output_WV, dim=-1) 98 | 99 | finished_sequences_W = torch.any(beam_idxs_WL == self.end_idx, dim=-1) 100 | active_mask_W = ~finished_sequences_W 101 | if finished_sequences_W.all(): 102 | break 103 | # finished_mask_WV = finished_sequences_W.unsqueeze(-1).expand(-1, V) 104 | # log_probs_WV = log_probs_WV.masked_fill(finished_mask_WV, float('-inf')) 105 | 106 | if step == first_step: 107 | log_probs_BSV = log_probs_WV.view(B, S, -1) 108 | log_probs_WS, top_k_idxs_WS = torch.topk(log_probs_BSV[:, 0, :], S, dim=-1) 109 | beam_log_probs_W = log_probs_WS.view(B * S) 110 | beam_idxs_WL[:, step] = top_k_idxs_WS.view(B * S) 111 | else: 112 | active_WV = active_mask_W.unsqueeze(1).expand(-1, V) 113 | cur_log_probs_WV = beam_log_probs_W.unsqueeze(1) + log_probs_WV 114 | 115 | _, act_top_k_idxs_WS = torch.topk(cur_log_probs_WV[active_WV].view(-1, V), S, dim=-1) 116 | act_top_k_idxs_BSS = act_top_k_idxs_WS.view(B, -1, S) 117 | 118 | active_WL = active_mask_W.unsqueeze(-1).repeat(1, L) 119 | active_beams_WL = beam_idxs_WL[active_WL].view(-1, L) 120 | active_beams_BSL = active_beams_WL.view(B, -1, L) 121 | _S = active_beams_BSL.size(1) 122 | active_beams_BSSL = active_beams_BSL.unsqueeze(2).repeat(1, 1, S, 1) 123 | active_beams_BSSL[..., step] = act_top_k_idxs_BSS 124 | active_beams_BSsqL = active_beams_BSSL.view(B, -1, L) # my candidate_seqs_BSL_nt 125 | cur_log_probs_WS = cur_log_probs_WV[active_mask_W].view(-1, V).gather(1, act_top_k_idxs_WS) 126 | 127 | # cur_log_probs_BSsq = cur_log_probs_WS.view(B, -1) # my candidate_probs_BS_nt 128 | sequence_lengths_WL = (active_beams_WL.ne(self.pad_idx).sum(dim=1).float()).unsqueeze(1) 129 | 130 | normalized_act_log_probs_WS = cur_log_probs_WS / (sequence_lengths_WL.sqrt() + 1e-6) 131 | normalized_act_log_probs_BSsq = normalized_act_log_probs_WS.view(B, -1) 132 | _, best_idxs_BS = normalized_act_log_probs_BSsq.topk(S, dim=-1) 133 | 134 | active_beams_WL = active_beams_BSsqL.view(-1, L).gather( 135 | 0, 
best_idxs_BS.view(-1).unsqueeze(-1).expand(-1, L) 136 | ) 137 | active_log_probs_W = cur_log_probs_WS.view(-1).gather(0, best_idxs_BS.view(-1)) 138 | 139 | active_beams_BSL = active_beams_WL.view(B, -1, L) 140 | active_log_probs_BS = active_log_probs_W.view(B, -1) 141 | 142 | inactive_beams_WL = beam_idxs_WL[~active_WL] 143 | inactive_log_probs_W = beam_log_probs_W[~active_mask_W] 144 | inactive_beams_BSL = inactive_beams_WL.view(B, -1, L) 145 | inactive_log_probs_BS = inactive_log_probs_W.view(B, -1) 146 | 147 | both_beams_BSL = torch.cat([active_beams_BSL, inactive_beams_BSL], dim=1) 148 | both_log_probs_BS = torch.cat([active_log_probs_BS, inactive_log_probs_BS], dim=1) 149 | 150 | both_beams_WL = both_beams_BSL.view(-1, L) 151 | both_log_probs_W = both_log_probs_BS.view(-1) 152 | 153 | both_seq_lengths_W = both_beams_WL.ne(self.pad_idx).sum(dim=1).float() 154 | both_normalized_log_probs_WS = both_log_probs_W / (both_seq_lengths_W.sqrt() + 1e-6) 155 | both_normalized_log_probs_BSsq = both_normalized_log_probs_WS.view(B, -1) 156 | _, best_idxs_BS = both_normalized_log_probs_BSsq.topk(S, dim=-1) 157 | 158 | beam_idxs_WL = both_beams_BSL.gather(1, best_idxs_BS.unsqueeze(-1).expand(-1, -1, L)).view(-1, L) 159 | beam_log_probs_W = both_log_probs_BS.gather(1, best_idxs_BS).view(-1) 160 | 161 | beam_idxs_BSL = beam_idxs_WL.view(B, S, L) 162 | beam_log_probs_BS = beam_log_probs_W.view(B, S) 163 | 164 | outputs_BS2_nt: list[list[tuple[str, float]]] = [[] for _ in range(B)] 165 | 166 | for b in range(B): 167 | for s in range(S): 168 | output_str = "" 169 | for L_idx in beam_idxs_BSL[b, s]: 170 | if L_idx == self.start_idx: 171 | continue 172 | if L_idx == self.end_idx: 173 | break 174 | output_str += self.idx_to_token[L_idx.item()] 175 | log_prob = beam_log_probs_BS[b, s].item() 176 | outputs_BS2_nt[b].append((output_str, log_prob)) 177 | 178 | return outputs_BS2_nt 179 | -------------------------------------------------------------------------------- /src/directmultistep/helpers.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import re 3 | from pathlib import Path 4 | 5 | from directmultistep.utils.dataset import RoutesDataset 6 | 7 | 8 | def prepare_datasets( 9 | train_data_path: Path, 10 | val_data_path: Path, 11 | metadata_path: Path, 12 | load_sm: bool = True, 13 | mode: str = "training", 14 | ) -> tuple[RoutesDataset, ...]: 15 | with open(train_data_path, "rb") as f: 16 | (products, starting_materials, path_strings, n_steps_list) = pickle.load(f) 17 | if not load_sm: 18 | starting_materials = None 19 | ds_train = RoutesDataset( 20 | metadata_path=metadata_path, 21 | products=products, 22 | starting_materials=starting_materials, 23 | path_strings=path_strings, 24 | n_steps_list=n_steps_list, 25 | mode=mode, 26 | ) 27 | with open(val_data_path, "rb") as f: 28 | (val_products, val_sms, val_path_strings, val_steps_list) = pickle.load(f) 29 | if not load_sm: 30 | val_sms = None 31 | ds_val = RoutesDataset( 32 | metadata_path=metadata_path, 33 | products=val_products, 34 | starting_materials=val_sms, 35 | path_strings=val_path_strings, 36 | n_steps_list=val_steps_list, 37 | mode=mode, 38 | ) 39 | return ds_train, ds_val 40 | 41 | 42 | def find_checkpoint(train_path: Path, run_name: str) -> Path | None: 43 | ckpt_path = train_path / run_name 44 | checkpoints = list(ckpt_path.glob("*.ckpt")) 45 | 46 | # First, check if there's a file with "last" in its name 47 | last_checkpoints = [ckpt for ckpt in checkpoints if "last" in ckpt.stem] 48 | 
if last_checkpoints: 49 | # Extract version number if present, else default to 0 (e.g., last.ckpt is treated as v0) 50 | def parse_version(ckpt: Path) -> int: 51 | match = re.search(r"last-v(\d+)", ckpt.stem) 52 | return int(match.group(1)) if match else 0 53 | 54 | # Sort by version number in descending order and return the latest 55 | return sorted(last_checkpoints, key=parse_version, reverse=True)[0] 56 | 57 | # If no "last" file, find the checkpoint with the largest epoch and step 58 | def parse_epoch_step(filename: str) -> tuple[int, int]: 59 | # This pattern will match 'epoch=X-step=Y.ckpt' and extract X and Y 60 | match = re.search(r"epoch=(\d+)-step=(\d+)\.ckpt", filename) 61 | if match: 62 | return int(match.group(1)), int(match.group(2)) 63 | return -1, -1 # Default to -1 if no match found 64 | 65 | checkpoints.sort(key=lambda ckpt: parse_epoch_step(ckpt.name), reverse=True) 66 | return checkpoints[0] if checkpoints else None 67 | 68 | 69 | if __name__ == "__main__": 70 | train_path = Path(__file__).resolve().parent / "Data" / "Training" 71 | run_name = "moe_3x2_3x3_002_local" 72 | o = find_checkpoint(train_path, run_name) 73 | print(f"{o=}") 74 | -------------------------------------------------------------------------------- /src/directmultistep/model/__init__.py: -------------------------------------------------------------------------------- 1 | from directmultistep.model.factory import ModelFactory 2 | 3 | __all__ = ["ModelFactory"] 4 | -------------------------------------------------------------------------------- /src/directmultistep/model/architecture.py: -------------------------------------------------------------------------------- 1 | # Shape suffixes convention inspired by 2 | # https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd 3 | 4 | # B: batch size 5 | # C: the length of the input on which conditioning is done 6 | # in our case input_max_length 7 | # L: sequence length for decoder, in our case output_max_length 8 | # M: memory length (length of sequence being attended to) 9 | # D: model dimension (sometimes called d_model or embedding_dim) 10 | # V: vocabulary size 11 | # F: feed-forward subnetwork hidden size 12 | # H: number of attention heads in a layer 13 | # K: size of each attention key or value (sometimes called d_kv) 14 | 15 | 16 | from typing import cast 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | Tensor = torch.Tensor 22 | 23 | 24 | class Seq2Seq(nn.Module): 25 | def __init__( 26 | self, 27 | encoder: nn.Module, 28 | decoder: nn.Module, 29 | src_pad_idx: int, 30 | trg_pad_idx: int, 31 | ): 32 | super().__init__() 33 | 34 | self.decoder = decoder 35 | self.encoder = encoder 36 | self.src_pad_idx = src_pad_idx 37 | self.trg_pad_idx = trg_pad_idx 38 | 39 | def make_src_mask(self, src_BC: Tensor) -> Tensor: 40 | src_mask_B11C = (src_BC != self.src_pad_idx).unsqueeze(1).unsqueeze(2) 41 | return src_mask_B11C 42 | 43 | def forward(self, src_BC: Tensor, trg_BL: Tensor, steps_B1: Tensor) -> Tensor: 44 | """ 45 | src_BC is the product_item + one_sm_item combined 46 | trg_BL is the path_string of the corresponding route 47 | """ 48 | src_mask_B11C = self.make_src_mask(src_BC.long()) 49 | 50 | enc_src_BCD = self.encoder(src_BC.long(), src_mask_B11C, steps_B1) 51 | trg_mask = None # this will trigger is_causal=True 52 | output_BLV = self.decoder(trg_BL, enc_src_BCD, src_mask_B11C, trg_mask_B1LL=trg_mask) 53 | return cast(Tensor, output_BLV) 54 | 
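A minimal wiring sketch for Seq2Seq (illustrative only: the stub encoder below merely mimics the encoder(src_BC, src_mask_B11C, steps_B1) -> (B, C, D) interface that Seq2Seq.forward expects; the real encoder lives in model/components/encoder.py and is assembled via ModelFactory):

    import torch
    import torch.nn as nn

    from directmultistep.model.architecture import Seq2Seq
    from directmultistep.model.components.decoder import Decoder

    class StubEncoder(nn.Module):
        """Stand-in exposing the interface Seq2Seq.forward calls."""

        def __init__(self, vocab_dim: int, hid_dim: int) -> None:
            super().__init__()
            self.emb = nn.Embedding(vocab_dim, hid_dim)

        def forward(self, src_BC: torch.Tensor, src_mask_B11C: torch.Tensor, steps_B1: torch.Tensor) -> torch.Tensor:
            return self.emb(src_BC)  # (B, C, D); mask and steps are ignored in this stub

    decoder = Decoder(
        vocab_dim=53, hid_dim=256, context_window=1075, n_layers=6, n_heads=8,
        dropout=0.1, attn_bias=False, ff_mult=3, ff_activation="gelu",
    )
    model = Seq2Seq(encoder=StubEncoder(53, 256), decoder=decoder, src_pad_idx=52, trg_pad_idx=52)

    src_BC = torch.randint(0, 51, (2, 145))   # product + one starting material tokens
    trg_BL = torch.randint(0, 51, (2, 100))   # route path string tokens
    steps_B1 = torch.tensor([[3], [5]])       # number of synthesis steps
    output_BLV = model(src_BC, trg_BL, steps_B1)  # (2, 100, 53)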
-------------------------------------------------------------------------------- /src/directmultistep/model/components/__init__.py: -------------------------------------------------------------------------------- 1 | from directmultistep.model.components.attention import MultiHeadAttentionLayer 2 | from directmultistep.model.components.decoder import DecoderLayer, MoEDecoderLayer 3 | from directmultistep.model.components.encoder import EncoderLayer, MoEEncoderLayer 4 | from directmultistep.model.components.moe import PositionwiseFeedforwardLayer, SparseMoE 5 | 6 | __all__ = [ 7 | "MultiHeadAttentionLayer", 8 | "SparseMoE", 9 | "PositionwiseFeedforwardLayer", 10 | "EncoderLayer", 11 | "MoEEncoderLayer", 12 | "DecoderLayer", 13 | "MoEDecoderLayer", 14 | ] 15 | -------------------------------------------------------------------------------- /src/directmultistep/model/components/attention.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | Tensor = torch.Tensor 7 | 8 | 9 | class MultiHeadAttentionLayer(nn.Module): 10 | """ 11 | Multi-head attention layer. 12 | 13 | This layer applies multi-head attention to the input tensors. 14 | 15 | Shape suffixes convention: 16 | B: batch size 17 | L: sequence length for decoder 18 | M: memory length (length of sequence being attended to) 19 | D: model dimension (sometimes called d_model or embedding_dim) 20 | H: number of attention heads in a layer 21 | 22 | Args: 23 | hid_dim: The hidden dimension size. 24 | n_heads: The number of attention heads. 25 | dropout: The dropout rate. 26 | attn_bias: Whether to use bias in the linear layers. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | hid_dim: int, 32 | n_heads: int, 33 | dropout: float, 34 | attn_bias: bool, 35 | # device: torch.device, 36 | ): 37 | super().__init__() 38 | 39 | self.hid_dim = hid_dim 40 | self.n_heads = n_heads 41 | self.head_dim = hid_dim // n_heads 42 | 43 | self.query = nn.Linear(hid_dim, hid_dim, bias=attn_bias) 44 | self.key = nn.Linear(hid_dim, hid_dim, bias=attn_bias) 45 | self.value = nn.Linear(hid_dim, hid_dim, bias=attn_bias) 46 | 47 | self.projection = nn.Linear(hid_dim, hid_dim) 48 | 49 | self.dropout = nn.Dropout(dropout) 50 | # self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device) 51 | 52 | def forward( 53 | self, 54 | query_BLD: Tensor, 55 | key_BMD: Tensor, 56 | value_BMD: Tensor, 57 | mask_B11M: Tensor | None = None, 58 | ) -> Tensor: 59 | """ 60 | Forward pass of the multi-head attention layer. 61 | 62 | Shape suffixes convention: 63 | B: batch size 64 | L: sequence length for decoder 65 | M: memory length (length of sequence being attended to) 66 | D: model dimension (sometimes called d_model or embedding_dim) 67 | H: number of attention heads in a layer 68 | 69 | Args: 70 | query_BLD: The query tensor of shape (B, L, D). 71 | key_BMD: The key tensor of shape (B, M, D). 72 | value_BMD: The value tensor of shape (B, M, D). 73 | mask_B11M: The attention mask of shape (B, 1, 1, M). 74 | 75 | Returns: 76 | The output tensor of shape (B, L, D). 
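        Example (a usage sketch with random tensors; shapes follow the suffix
        convention above):

            layer = MultiHeadAttentionLayer(hid_dim=256, n_heads=8, dropout=0.1, attn_bias=False)
            x_BLD = torch.randn(2, 10, 256)
            mask_B11M = torch.ones(2, 1, 1, 10, dtype=torch.bool)
            out_BLD = layer(x_BLD, x_BLD, x_BLD, mask_B11M)  # self-attention, (2, 10, 256)
            # passing mask_B11M=None switches to causal attention (is_causal=True)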
77 | """ 78 | B, L, _ = query_BLD.shape 79 | Q_BLD = self.query(query_BLD) 80 | K_BMD = self.key(key_BMD) 81 | V_BMD = self.value(value_BMD) 82 | # Reshape into multiple heads 83 | Q_BHLD = Q_BLD.view(B, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 84 | K_BHMD = K_BMD.view(B, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 85 | V_BHMD = V_BMD.view(B, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 86 | 87 | if mask_B11M is not None: 88 | # Expand mask for all heads 89 | mask_BHLM = mask_B11M.expand(B, self.n_heads, L, -1) 90 | is_causal = False 91 | else: 92 | mask_BHLM = None 93 | is_causal = True 94 | 95 | attn_output_BHLD = nn.functional.scaled_dot_product_attention( 96 | query=Q_BHLD, 97 | key=K_BHMD, 98 | value=V_BHMD, 99 | attn_mask=mask_BHLM, 100 | dropout_p=self.dropout.p if self.training else 0.0, 101 | is_causal=is_causal, 102 | # scale=self.scale.item(), 103 | ) 104 | attn_output_BLD = attn_output_BHLD.permute(0, 2, 1, 3).contiguous().view(B, L, self.hid_dim) 105 | output_BLD = cast(Tensor, self.projection(attn_output_BLD)) 106 | return output_BLD 107 | -------------------------------------------------------------------------------- /src/directmultistep/model/components/decoder.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from directmultistep.model.components.attention import MultiHeadAttentionLayer 7 | from directmultistep.model.components.moe import PositionwiseFeedforwardLayer, SparseMoE 8 | 9 | Tensor = torch.Tensor 10 | activation_dict = { 11 | "relu": nn.ReLU(), 12 | "gelu": nn.GELU(), 13 | } 14 | 15 | 16 | class DecoderLayer(nn.Module): 17 | """A single layer of the decoder. 18 | 19 | Shape suffixes convention: 20 | B: batch size 21 | C: the length of the input on which conditioning is done 22 | (in our case input_max_length) 23 | L: sequence length for decoder, in our case output_max_length 24 | D: model dimension (sometimes called d_model or embedding_dim) 25 | """ 26 | 27 | def __init__( 28 | self, 29 | hid_dim: int, 30 | n_heads: int, 31 | dropout: float, 32 | attn_bias: bool, 33 | ff_mult: int, 34 | ff_activation: str, 35 | ) -> None: 36 | """Initializes the DecoderLayer. 37 | 38 | Args: 39 | hid_dim: The hidden dimension size. 40 | n_heads: The number of attention heads. 41 | dropout: The dropout rate. 42 | attn_bias: Whether to use bias in the attention layers. 43 | ff_mult: The feed-forward expansion factor. 44 | ff_activation: The activation function type. 45 | """ 46 | super().__init__() 47 | self.self_attn_ln = nn.LayerNorm(hid_dim) 48 | self.enc_attn_ln = nn.LayerNorm(hid_dim) 49 | self.ff_ln = nn.LayerNorm(hid_dim) 50 | self.self_attn = MultiHeadAttentionLayer( 51 | hid_dim=hid_dim, 52 | n_heads=n_heads, 53 | dropout=dropout, 54 | attn_bias=attn_bias, 55 | ) 56 | self.encoder_attn = MultiHeadAttentionLayer( 57 | hid_dim=hid_dim, 58 | n_heads=n_heads, 59 | dropout=dropout, 60 | attn_bias=attn_bias, 61 | ) 62 | self.mlp: nn.Module = PositionwiseFeedforwardLayer( 63 | hid_dim=hid_dim, 64 | ff_mult=ff_mult, 65 | ff_activation=activation_dict[ff_activation], 66 | dropout=dropout, 67 | ) 68 | self.dropout = nn.Dropout(dropout) 69 | 70 | def forward( 71 | self, 72 | trg_BLD: Tensor, 73 | enc_src_BCD: Tensor, 74 | src_mask_B11C: Tensor, 75 | trg_mask_B1LL: Tensor, 76 | ) -> Tensor: 77 | """Forward pass of the DecoderLayer. 78 | 79 | Args: 80 | trg_BLD: The target sequence tensor of shape (B, L, D). 
81 | enc_src_BCD: The encoder output tensor of shape (B, C, D). 82 | src_mask_B11C: The source mask tensor of shape (B, 1, 1, C). 83 | trg_mask_B1LL: The target mask tensor of shape (B, 1, L, L). 84 | 85 | Returns: 86 | The output tensor of shape (B, L, D). 87 | """ 88 | self_attn_BLD = self.self_attn(trg_BLD, trg_BLD, trg_BLD, trg_mask_B1LL) 89 | trg_BLD = self.self_attn_ln(trg_BLD + self.dropout(self_attn_BLD)) 90 | # Encoder-Decoder Attention 91 | enc_attn_BLD = self.encoder_attn(trg_BLD, enc_src_BCD, enc_src_BCD, src_mask_B11C) 92 | trg_BLD = self.enc_attn_ln(trg_BLD + self.dropout(enc_attn_BLD)) 93 | ff_out_BLD = self.mlp(trg_BLD) 94 | trg_BLD = self.ff_ln(trg_BLD + self.dropout(ff_out_BLD)) 95 | return trg_BLD 96 | 97 | 98 | class MoEDecoderLayer(DecoderLayer): 99 | """A single layer of the decoder with Mixture of Experts in the feedforward layer.""" 100 | 101 | def __init__( 102 | self, 103 | hid_dim: int, 104 | n_heads: int, 105 | dropout: float, 106 | attn_bias: bool, 107 | ff_mult: int, 108 | ff_activation: str, 109 | n_experts: int, 110 | top_k: int, 111 | capacity_factor: float, 112 | ) -> None: 113 | """Initializes the MoEDecoderLayer. 114 | 115 | Args: 116 | hid_dim: The hidden dimension size. 117 | n_heads: The number of attention heads. 118 | dropout: The dropout rate. 119 | attn_bias: Whether to use bias in the attention layers. 120 | ff_mult: The feed-forward expansion factor. 121 | ff_activation: The activation function type. 122 | n_experts: The number of experts in the MoE layer. 123 | top_k: The number of experts to use in the MoE layer. 124 | capacity_factor: The capacity factor for the MoE layer. 125 | """ 126 | super().__init__( 127 | hid_dim=hid_dim, 128 | n_heads=n_heads, 129 | dropout=dropout, 130 | attn_bias=attn_bias, 131 | ff_mult=ff_mult, 132 | ff_activation=ff_activation, 133 | ) 134 | # Override the MLP with MoE 135 | self.mlp = SparseMoE( 136 | hid_dim=hid_dim, 137 | n_experts=n_experts, 138 | top_k=top_k, 139 | ff_mult=ff_mult, 140 | ff_activation=ff_activation, 141 | dropout=dropout, 142 | capacity_factor=capacity_factor, 143 | ) 144 | 145 | 146 | class Decoder(nn.Module): 147 | """The decoder module. 148 | 149 | Shape suffixes convention: 150 | B: batch size 151 | C: the length of the input on which conditioning is done 152 | (in our case input_max_length) 153 | L: sequence length for decoder, in our case output_max_length 154 | D: model dimension (sometimes called d_model or embedding_dim) 155 | V: vocabulary size 156 | """ 157 | 158 | def __init__( 159 | self, 160 | vocab_dim: int, 161 | hid_dim: int, 162 | context_window: int, 163 | n_layers: int, 164 | n_heads: int, 165 | dropout: float, 166 | attn_bias: bool, 167 | ff_mult: int, 168 | ff_activation: str, 169 | ) -> None: 170 | """Initializes the Decoder. 171 | 172 | Args: 173 | vocab_dim: The vocabulary size. 174 | hid_dim: The hidden dimension size. 175 | context_window: The context window size. 176 | n_layers: The number of decoder layers. 177 | n_heads: The number of attention heads. 178 | dropout: The dropout rate. 179 | attn_bias: Whether to use bias in the attention layers. 180 | ff_mult: The feed-forward expansion factor. 181 | ff_activation: The activation function type. 
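        Example (values taken from the shipped flash_10M decoder config):

            decoder = Decoder(
                vocab_dim=53, hid_dim=256, context_window=1075, n_layers=6, n_heads=8,
                dropout=0.1, attn_bias=False, ff_mult=3, ff_activation="gelu",
            )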
182 | """ 183 | super().__init__() 184 | self.hid_dim = hid_dim 185 | self.tok_embedding = nn.Embedding(vocab_dim, hid_dim) 186 | self.pos_embedding = nn.Embedding(context_window, hid_dim) 187 | 188 | self.layers = nn.ModuleList( 189 | [ 190 | DecoderLayer( 191 | hid_dim=hid_dim, 192 | n_heads=n_heads, 193 | dropout=dropout, 194 | attn_bias=attn_bias, 195 | ff_mult=ff_mult, 196 | ff_activation=ff_activation, 197 | ) 198 | for _ in range(n_layers) 199 | ] 200 | ) 201 | 202 | self.fc_out = nn.Linear(hid_dim, vocab_dim) 203 | self.dropout = nn.Dropout(dropout) 204 | self.scale = torch.sqrt(torch.FloatTensor([hid_dim])) 205 | 206 | def forward( 207 | self, 208 | trg_BL: Tensor, 209 | enc_src_BCD: Tensor, 210 | src_mask_B11C: Tensor, 211 | trg_mask_B1LL: Tensor | None = None, 212 | ) -> Tensor: 213 | """Forward pass of the Decoder. 214 | 215 | Args: 216 | trg_BL: The target sequence tensor of shape (B, L). 217 | enc_src_BCD: The encoder output tensor of shape (B, C, D). 218 | src_mask_B11C: The source mask tensor of shape (B, 1, 1, C). 219 | trg_mask_B1LL: The target mask tensor of shape (B, 1, L, L). 220 | 221 | Returns: 222 | The output tensor of shape (B, L, V). 223 | """ 224 | B, L = trg_BL.shape 225 | # below: [L] -> [1, L] -> [B, L] 226 | pos_BL = torch.arange(0, L).unsqueeze(0).repeat(B, 1).to(trg_BL) 227 | tok_emb_BLD = self.tok_embedding(trg_BL) * self.scale.to(trg_BL) 228 | pos_emb_BLD = self.pos_embedding(pos_BL) 229 | trg_BLD = self.dropout(tok_emb_BLD + pos_emb_BLD) 230 | for layer in self.layers: 231 | trg_BLD = layer(trg_BLD, enc_src_BCD, src_mask_B11C, trg_mask_B1LL) 232 | output_BLV = self.fc_out(trg_BLD) 233 | return cast(Tensor, output_BLV) 234 | 235 | 236 | class MoEDecoder(Decoder): 237 | """The decoder module with Mixture of Experts in the feedforward layers.""" 238 | 239 | def __init__( 240 | self, 241 | vocab_dim: int, 242 | hid_dim: int, 243 | context_window: int, 244 | n_layers: int, 245 | n_heads: int, 246 | dropout: float, 247 | attn_bias: bool, 248 | ff_mult: int, 249 | ff_activation: str, 250 | n_experts: int, 251 | top_k: int, 252 | capacity_factor: float, 253 | ): 254 | """Initializes the MoEDecoder. 255 | 256 | Args: 257 | vocab_dim: The vocabulary size. 258 | hid_dim: The hidden dimension size. 259 | context_window: The context window size. 260 | n_layers: The number of decoder layers. 261 | n_heads: The number of attention heads. 262 | dropout: The dropout rate. 263 | attn_bias: Whether to use bias in the attention layers. 264 | ff_mult: The feed-forward expansion factor. 265 | ff_activation: The activation function type. 266 | n_experts: The number of experts in the MoE layer. 267 | top_k: The number of experts to use in the MoE layer. 268 | capacity_factor: The capacity factor for the MoE layer. 
269 | """ 270 | super().__init__( 271 | vocab_dim=vocab_dim, 272 | hid_dim=hid_dim, 273 | context_window=context_window, 274 | n_layers=n_layers, 275 | n_heads=n_heads, 276 | dropout=dropout, 277 | attn_bias=attn_bias, 278 | ff_mult=ff_mult, 279 | ff_activation=ff_activation, 280 | ) 281 | # Override layers with MoE layers 282 | self.layers = nn.ModuleList( 283 | [ 284 | MoEDecoderLayer( 285 | hid_dim=hid_dim, 286 | n_heads=n_heads, 287 | dropout=dropout, 288 | attn_bias=attn_bias, 289 | ff_mult=ff_mult, 290 | ff_activation=ff_activation, 291 | n_experts=n_experts, 292 | top_k=top_k, 293 | capacity_factor=capacity_factor, 294 | ) 295 | for _ in range(n_layers) 296 | ] 297 | ) 298 | -------------------------------------------------------------------------------- /src/directmultistep/model/components/moe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | Tensor = torch.Tensor 6 | activation_dict = { 7 | "relu": nn.ReLU(), 8 | "gelu": nn.GELU(), 9 | } 10 | 11 | 12 | class PositionwiseFeedforwardLayer(nn.Module): 13 | """Positionwise feedforward layer. 14 | 15 | Applies a two-layer feedforward network to the input. 16 | 17 | Shape suffixes: 18 | B: batch size 19 | L: sequence length 20 | D: model dimension 21 | F: feed-forward subnetwork hidden size 22 | """ 23 | 24 | def __init__( 25 | self, 26 | hid_dim: int, 27 | ff_mult: int, 28 | ff_activation: nn.Module, 29 | dropout: float, 30 | ): 31 | """Initializes the PositionwiseFeedforwardLayer. 32 | 33 | Args: 34 | hid_dim: The hidden dimension size (D). 35 | ff_mult: The feed-forward expansion factor. 36 | ff_activation: The activation function. 37 | dropout: The dropout rate. 38 | """ 39 | super().__init__() 40 | 41 | self.fc_1 = nn.Linear(hid_dim, ff_mult * hid_dim) 42 | self.activ = ff_activation 43 | self.fc_2 = nn.Linear(hid_dim * ff_mult, hid_dim) 44 | 45 | self.dropout = nn.Dropout(dropout) 46 | 47 | def forward(self, x_BLD: Tensor) -> Tensor: 48 | """Forward pass of the PositionwiseFeedforwardLayer. 49 | 50 | Args: 51 | x_BLD: The input tensor of shape (B, L, D). 52 | 53 | Returns: 54 | The output tensor of shape (B, L, D). 55 | """ 56 | x_BLF = self.dropout(self.activ(self.fc_1(x_BLD))) 57 | x_BLD = self.fc_2(x_BLF) 58 | return x_BLD 59 | 60 | 61 | class NoisyTopkRouter(nn.Module): 62 | """Noisy top-k router for MoE. 63 | 64 | Routes inputs to the top-k experts based on noisy logits. 65 | 66 | Shape suffixes: 67 | B: batch size 68 | L: sequence length 69 | D: model dimension 70 | E: number of experts 71 | K: top_k 72 | """ 73 | 74 | def __init__(self, hid_dim: int, n_experts: int, top_k: int): 75 | """Initializes the NoisyTopkRouter. 76 | 77 | Args: 78 | hid_dim: The hidden dimension size (D). 79 | n_experts: The number of experts (E). 80 | top_k: The number of top experts to route to (K). 81 | """ 82 | super().__init__() 83 | self.top_k = top_k 84 | self.topkroute_linear = nn.Linear(hid_dim, n_experts) 85 | self.noise_linear = nn.Linear(hid_dim, n_experts) 86 | 87 | def forward(self, x_BLD: Tensor) -> tuple[Tensor, Tensor]: 88 | """Forward pass of the NoisyTopkRouter. 89 | 90 | Args: 91 | x_BLD: The input tensor of shape (B, L, D). 92 | 93 | Returns: 94 | A tuple containing: 95 | - The router output tensor of shape (B, L, E). 96 | - The indices of the top-k experts of shape (B, L, K). 
97 | """ 98 | logits_BLE = self.topkroute_linear(x_BLD) 99 | noise_logits_BLE = self.noise_linear(x_BLD) 100 | # Adding scaled unit gaussian noise to the logits 101 | noise_BLE = torch.randn_like(logits_BLE) * F.softplus(noise_logits_BLE) 102 | noisy_logits_BLE = logits_BLE + noise_BLE 103 | 104 | top_k_logits_BLE, indices_BLK = noisy_logits_BLE.topk(self.top_k, dim=-1) 105 | zeros_BLE = torch.full_like(noisy_logits_BLE, float("-inf")) 106 | # creating a sparse tensor with top-k logits 107 | sparse_logits_BLE = zeros_BLE.scatter(-1, indices_BLK, top_k_logits_BLE) 108 | router_output_BLE = F.softmax(sparse_logits_BLE, dim=-1) 109 | return router_output_BLE, indices_BLK 110 | 111 | 112 | class Expert(nn.Module): 113 | """A single expert in the MoE layer. 114 | 115 | Applies a two-layer feedforward network to the input. 116 | 117 | Shape suffixes: 118 | B: batch size 119 | L: sequence length 120 | D: model dimension 121 | F: feed-forward subnetwork hidden size 122 | """ 123 | 124 | def __init__( 125 | self, 126 | hid_dim: int, 127 | ff_mult: int, 128 | ff_activation: str, 129 | dropout: float, 130 | ): 131 | """Initializes the Expert. 132 | 133 | Args: 134 | hid_dim: The hidden dimension size (D). 135 | ff_mult: The feed-forward expansion factor. 136 | ff_activation: The activation function type. 137 | dropout: The dropout rate. 138 | """ 139 | super().__init__() 140 | self.net = nn.Sequential( 141 | nn.Linear(hid_dim, ff_mult * hid_dim), 142 | activation_dict[ff_activation], 143 | nn.Linear(ff_mult * hid_dim, hid_dim), 144 | nn.Dropout(dropout), 145 | ) 146 | 147 | def forward(self, x_BLD: Tensor) -> Tensor: 148 | """Forward pass of the Expert. 149 | 150 | Args: 151 | x_BLD: The input tensor of shape (B, L, D). 152 | 153 | Returns: 154 | The output tensor of shape (B, L, D). 155 | """ 156 | return self.net(x_BLD) # type: ignore 157 | 158 | 159 | class SparseMoE(nn.Module): 160 | """Sparse Mixture of Experts layer. 161 | 162 | Routes inputs to a subset of experts and combines their outputs. 163 | 164 | Shape suffixes: 165 | B: batch size 166 | L: sequence length 167 | D: model dimension 168 | E: number of experts 169 | K: top_k 170 | S: number of selected tokens for an expert 171 | """ 172 | 173 | def __init__( 174 | self, 175 | hid_dim: int, 176 | n_experts: int, 177 | top_k: int, 178 | ff_mult: int, 179 | ff_activation: str, 180 | dropout: float, 181 | capacity_factor: float, 182 | ): 183 | """Initializes the SparseMoE layer. 184 | 185 | Args: 186 | hid_dim: The hidden dimension size (D). 187 | n_experts: The number of experts (E). 188 | top_k: The number of top experts to route to (K). 189 | ff_mult: The feed-forward expansion factor. 190 | ff_activation: The activation function type. 191 | dropout: The dropout rate. 192 | capacity_factor: The capacity factor for each expert. 193 | """ 194 | super(SparseMoE, self).__init__() 195 | self.router = NoisyTopkRouter(hid_dim, n_experts, top_k) 196 | self.experts = nn.ModuleList([Expert(hid_dim, ff_mult, ff_activation, dropout) for _ in range(n_experts)]) 197 | self.n_experts = n_experts 198 | self.top_k = top_k 199 | self.capacity_factor = capacity_factor 200 | 201 | def forward(self, x_BLD: Tensor) -> Tensor: 202 | """Forward pass of the SparseMoE layer. 203 | 204 | Args: 205 | x_BLD: The input tensor of shape (B, L, D). 206 | 207 | Returns: 208 | The output tensor of shape (B, L, D). 
209 | """ 210 | B, L, _ = x_BLD.shape 211 | gating_output_BLE, indices_BLK = self.router(x_BLD) 212 | final_output_BLD = torch.zeros_like(x_BLD) 213 | 214 | flat_x_FD = x_BLD.view(-1, x_BLD.size(-1)) # [B*L, D], define B*L=F 215 | flat_gating_output_FE = gating_output_BLE.view(-1, gating_output_BLE.size(-1)) 216 | n_tkns = B * L * self.top_k 217 | capacity = int((n_tkns / self.n_experts) * self.capacity_factor) 218 | 219 | updates_FD = torch.zeros_like(flat_x_FD) 220 | for i, expert in enumerate(self.experts): 221 | # Create a mask for the inputs where the current expert is in top-k 222 | expert_mask_BL = (indices_BLK == i).any(dim=-1) 223 | flat_mask_F = expert_mask_BL.view(-1) 224 | selected_idxs_F = torch.nonzero(flat_mask_F).squeeze(-1) 225 | 226 | if selected_idxs_F.numel() > capacity: 227 | limited_idxs_F = selected_idxs_F[:capacity] 228 | else: 229 | limited_idxs_F = selected_idxs_F 230 | 231 | if limited_idxs_F.numel() > 0: 232 | expert_input_SD = flat_x_FD[limited_idxs_F] # S = sum(flat_mask_F) 233 | expert_output_SD = expert(expert_input_SD) 234 | 235 | # Extract and apply gating scores, [S] -> [S, 1] 236 | gating_scores_S1 = flat_gating_output_FE[limited_idxs_F, i].unsqueeze(1) 237 | weighted_output_SD = expert_output_SD * gating_scores_S1 238 | 239 | updates_FD.index_add_(0, limited_idxs_F, weighted_output_SD) 240 | 241 | final_output_BLD += updates_FD.view(B, L, -1) 242 | 243 | return final_output_BLD 244 | -------------------------------------------------------------------------------- /src/directmultistep/model/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass 2 | from pathlib import Path 3 | from typing import Literal, Type, TypeVar 4 | 5 | import yaml 6 | 7 | T = TypeVar("T") 8 | 9 | 10 | @dataclass 11 | class TransformerConfig: 12 | """Configuration for transformer components. 13 | 14 | Attributes: 15 | vocab_dim: Vocabulary dimension. 16 | hid_dim: Hidden dimension. 17 | n_layers: Number of layers. 18 | n_heads: Number of attention heads. 19 | ff_mult: Feedforward multiplier. 20 | ff_activation: Feedforward activation function ('gelu' or 'relu'). 21 | dropout: Dropout probability. 22 | attn_bias: Whether to use attention bias. 23 | context_window: Context window size. 24 | start_idx: Start token index. 25 | mask_idx: Mask token index. 26 | pad_idx: Padding token index. 27 | """ 28 | 29 | vocab_dim: int 30 | hid_dim: int 31 | n_layers: int 32 | n_heads: int 33 | ff_mult: int 34 | ff_activation: Literal["gelu", "relu"] 35 | dropout: float 36 | attn_bias: bool 37 | context_window: int 38 | start_idx: int 39 | mask_idx: int 40 | pad_idx: int 41 | 42 | def __post_init__(self) -> None: 43 | if self.hid_dim % self.n_heads != 0: 44 | raise ValueError(f"{self.hid_dim=} must be divisible by {self.n_heads=}") 45 | if self.ff_activation not in ["gelu", "relu"]: 46 | raise ValueError(f"{self.ff_activation=} must be either 'gelu' or 'relu'") 47 | 48 | def save(self, path: Path) -> None: 49 | """Save config to yaml file. 50 | 51 | Args: 52 | path: Path to save the config to. 53 | """ 54 | data = asdict(self) 55 | data["model_type"] = self.__class__.__name__ 56 | with open(path, "w") as f: 57 | yaml.dump(data, f, sort_keys=False, default_flow_style=False) 58 | 59 | @classmethod 60 | def load(cls: Type[T], path: Path) -> T: 61 | """Load config from yaml file. 62 | 63 | Args: 64 | path: Path to load the config from. 65 | 66 | Returns: 67 | Loaded config. 
68 | """ 69 | with open(path) as f: 70 | data = yaml.safe_load(f) 71 | return cls(**data) 72 | 73 | 74 | @dataclass 75 | class MoEDecoderConfig(TransformerConfig): 76 | """Configuration for Mixture of Experts decoder components. 77 | 78 | Attributes: 79 | n_experts: Number of experts. 80 | top_k: Number of experts to use in forward pass. 81 | capacity_factor: Capacity factor for experts. 82 | """ 83 | 84 | n_experts: int 85 | top_k: int 86 | capacity_factor: float 87 | 88 | 89 | @dataclass 90 | class EncoderAConfig(TransformerConfig): 91 | """Configuration for EncoderA components. 92 | 93 | Attributes: 94 | initiate_steps: Whether to initiate steps. 95 | include_steps: Whether to include steps. 96 | """ 97 | 98 | initiate_steps: bool 99 | include_steps: bool 100 | 101 | 102 | @dataclass 103 | class MoEEncoderConfig(EncoderAConfig): 104 | """Configuration for Mixture of Experts encoder components. 105 | 106 | Attributes: 107 | n_experts: Number of experts. 108 | top_k: Number of experts to use in forward pass. 109 | capacity_factor: Capacity factor for experts. 110 | """ 111 | 112 | n_experts: int 113 | top_k: int 114 | capacity_factor: float 115 | 116 | 117 | @dataclass 118 | class Seq2SeqConfig: 119 | """Complete model configuration. 120 | 121 | Attributes: 122 | encoder: Encoder configuration. 123 | decoder: Decoder configuration. 124 | """ 125 | 126 | encoder: TransformerConfig 127 | decoder: TransformerConfig 128 | 129 | def save(self, path: Path) -> None: 130 | """Save config to yaml file. 131 | 132 | Args: 133 | path: Path to save the config to. 134 | """ 135 | config_dict = { 136 | "encoder": asdict(self.encoder) | {"model_type": self.encoder.__class__.__name__}, 137 | "decoder": asdict(self.decoder) | {"model_type": self.decoder.__class__.__name__}, 138 | } 139 | with open(path, "w") as f: 140 | yaml.dump(config_dict, f, sort_keys=False) 141 | 142 | @classmethod 143 | def load(cls, path: Path) -> "Seq2SeqConfig": 144 | """Load config from yaml file. 145 | 146 | Args: 147 | path: Path to load the config from. 148 | 149 | Returns: 150 | Loaded Seq2SeqConfig. 
151 | """ 152 | with open(path) as f: 153 | data = yaml.safe_load(f) 154 | 155 | # Determine correct encoder/decoder types based on model_type 156 | encoder_data = data.pop("encoder") 157 | decoder_data = data.pop("decoder") 158 | 159 | model_type_to_config = { 160 | "TransformerConfig": TransformerConfig, 161 | "MoEDecoderConfig": MoEDecoderConfig, 162 | "EncoderAConfig": EncoderAConfig, 163 | "MoEEncoderConfig": MoEEncoderConfig, 164 | } 165 | 166 | encoder_model_type = encoder_data.pop("model_type") 167 | decoder_model_type = decoder_data.pop("model_type") 168 | 169 | encoder_type = model_type_to_config[encoder_model_type] 170 | decoder_type = model_type_to_config[decoder_model_type] 171 | 172 | encoder = encoder_type(**encoder_data) 173 | decoder = decoder_type(**decoder_data) 174 | 175 | return cls(encoder=encoder, decoder=decoder, **data) 176 | 177 | 178 | if __name__ == "__main__": 179 | config = Seq2SeqConfig( 180 | encoder=TransformerConfig( 181 | vocab_dim=53, 182 | hid_dim=256, 183 | n_layers=6, 184 | n_heads=8, 185 | ff_mult=3, 186 | ff_activation="gelu", 187 | dropout=0.1, 188 | attn_bias=False, 189 | context_window=280, 190 | start_idx=0, 191 | mask_idx=51, 192 | pad_idx=52, 193 | ), 194 | decoder=MoEDecoderConfig( 195 | vocab_dim=53, 196 | hid_dim=256, 197 | n_layers=6, 198 | n_heads=8, 199 | ff_mult=3, 200 | ff_activation="gelu", 201 | dropout=0.1, 202 | attn_bias=False, 203 | context_window=1075, 204 | start_idx=0, 205 | mask_idx=51, 206 | pad_idx=52, 207 | n_experts=3, 208 | top_k=2, 209 | capacity_factor=1.0, 210 | ), 211 | ) 212 | -------------------------------------------------------------------------------- /src/directmultistep/model/default_configs/deep_40M.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | vocab_dim: 53 3 | hid_dim: 256 4 | n_layers: 12 5 | n_heads: 8 6 | ff_mult: 3 7 | ff_activation: gelu 8 | dropout: 0.1 9 | attn_bias: false 10 | context_window: 145 11 | start_idx: 0 12 | mask_idx: 51 13 | pad_idx: 52 14 | initiate_steps: true 15 | include_steps: true 16 | model_type: EncoderAConfig 17 | decoder: 18 | vocab_dim: 53 19 | hid_dim: 256 20 | n_layers: 36 21 | n_heads: 8 22 | ff_mult: 3 23 | ff_activation: gelu 24 | dropout: 0.1 25 | attn_bias: false 26 | context_window: 1075 27 | start_idx: 0 28 | mask_idx: 51 29 | pad_idx: 52 30 | model_type: TransformerConfig 31 | -------------------------------------------------------------------------------- /src/directmultistep/model/default_configs/explorer_19M.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | vocab_dim: 53 3 | hid_dim: 256 4 | n_layers: 6 5 | n_heads: 8 6 | ff_mult: 3 7 | ff_activation: gelu 8 | dropout: 0.1 9 | attn_bias: false 10 | context_window: 280 11 | start_idx: 0 12 | mask_idx: 51 13 | pad_idx: 52 14 | initiate_steps: true 15 | include_steps: false 16 | n_experts: 3 17 | top_k: 2 18 | capacity_factor: 1 19 | model_type: MoEEncoderConfig 20 | decoder: 21 | vocab_dim: 53 22 | hid_dim: 256 23 | n_layers: 6 24 | n_heads: 8 25 | ff_mult: 3 26 | ff_activation: gelu 27 | dropout: 0.1 28 | attn_bias: false 29 | context_window: 1075 30 | start_idx: 0 31 | mask_idx: 51 32 | pad_idx: 52 33 | n_experts: 3 34 | top_k: 2 35 | capacity_factor: 1 36 | model_type: MoEDecoderConfig 37 | -------------------------------------------------------------------------------- /src/directmultistep/model/default_configs/explorer_xl_50M.yaml: 
-------------------------------------------------------------------------------- 1 | encoder: 2 | vocab_dim: 53 3 | hid_dim: 256 4 | n_layers: 6 5 | n_heads: 8 6 | ff_mult: 3 7 | ff_activation: gelu 8 | dropout: 0.1 9 | attn_bias: false 10 | context_window: 145 11 | start_idx: 0 12 | mask_idx: 51 13 | pad_idx: 52 14 | initiate_steps: true 15 | include_steps: false 16 | n_experts: 3 17 | top_k: 2 18 | capacity_factor: 1 19 | model_type: MoEEncoderConfig 20 | decoder: 21 | vocab_dim: 53 22 | hid_dim: 256 23 | n_layers: 24 24 | n_heads: 8 25 | ff_mult: 3 26 | ff_activation: gelu 27 | dropout: 0.1 28 | attn_bias: false 29 | context_window: 1075 30 | start_idx: 0 31 | mask_idx: 51 32 | pad_idx: 52 33 | n_experts: 3 34 | top_k: 2 35 | capacity_factor: 1 36 | model_type: MoEDecoderConfig 37 | -------------------------------------------------------------------------------- /src/directmultistep/model/default_configs/flash_10M.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | vocab_dim: 53 3 | hid_dim: 256 4 | n_layers: 6 5 | n_heads: 8 6 | ff_mult: 3 7 | ff_activation: gelu 8 | dropout: 0.1 9 | attn_bias: false 10 | context_window: 280 11 | start_idx: 0 12 | mask_idx: 51 13 | pad_idx: 52 14 | initiate_steps: true 15 | include_steps: true 16 | model_type: EncoderAConfig 17 | decoder: 18 | vocab_dim: 53 19 | hid_dim: 256 20 | n_layers: 6 21 | n_heads: 8 22 | ff_mult: 3 23 | ff_activation: gelu 24 | dropout: 0.1 25 | attn_bias: false 26 | context_window: 1075 27 | start_idx: 0 28 | mask_idx: 51 29 | pad_idx: 52 30 | model_type: TransformerConfig 31 | -------------------------------------------------------------------------------- /src/directmultistep/model/default_configs/flash_20M.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | vocab_dim: 53 3 | hid_dim: 256 4 | n_layers: 12 5 | n_heads: 8 6 | ff_mult: 3 7 | ff_activation: gelu 8 | dropout: 0.1 9 | attn_bias: false 10 | context_window: 280 11 | start_idx: 0 12 | mask_idx: 51 13 | pad_idx: 52 14 | initiate_steps: true 15 | include_steps: true 16 | model_type: EncoderAConfig 17 | decoder: 18 | vocab_dim: 53 19 | hid_dim: 256 20 | n_layers: 12 21 | n_heads: 8 22 | ff_mult: 3 23 | ff_activation: gelu 24 | dropout: 0.1 25 | attn_bias: false 26 | context_window: 1075 27 | start_idx: 0 28 | mask_idx: 51 29 | pad_idx: 52 30 | model_type: TransformerConfig 31 | -------------------------------------------------------------------------------- /src/directmultistep/model/default_configs/flex_20M.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | vocab_dim: 53 3 | hid_dim: 256 4 | n_layers: 6 5 | n_heads: 8 6 | ff_mult: 3 7 | ff_activation: gelu 8 | dropout: 0.1 9 | attn_bias: false 10 | context_window: 280 11 | start_idx: 0 12 | mask_idx: 51 13 | pad_idx: 52 14 | initiate_steps: true 15 | include_steps: true 16 | n_experts: 3 17 | top_k: 2 18 | capacity_factor: 1 19 | model_type: MoEEncoderConfig 20 | decoder: 21 | vocab_dim: 53 22 | hid_dim: 256 23 | n_layers: 6 24 | n_heads: 8 25 | ff_mult: 3 26 | ff_activation: gelu 27 | dropout: 0.1 28 | attn_bias: false 29 | context_window: 1075 30 | start_idx: 0 31 | mask_idx: 51 32 | pad_idx: 52 33 | n_experts: 3 34 | top_k: 2 35 | capacity_factor: 1 36 | model_type: MoEDecoderConfig 37 | -------------------------------------------------------------------------------- /src/directmultistep/model/default_configs/wide_40M.yaml: 
-------------------------------------------------------------------------------- 1 | encoder: 2 | vocab_dim: 53 3 | hid_dim: 256 4 | n_layers: 12 5 | n_heads: 8 6 | ff_mult: 3 7 | ff_activation: gelu 8 | dropout: 0.1 9 | attn_bias: false 10 | context_window: 145 11 | start_idx: 0 12 | mask_idx: 51 13 | pad_idx: 52 14 | initiate_steps: true 15 | include_steps: true 16 | n_experts: 3 17 | top_k: 2 18 | capacity_factor: 1 19 | model_type: MoEEncoderConfig 20 | decoder: 21 | vocab_dim: 53 22 | hid_dim: 256 23 | n_layers: 12 24 | n_heads: 8 25 | ff_mult: 3 26 | ff_activation: gelu 27 | dropout: 0.1 28 | attn_bias: false 29 | context_window: 1075 30 | start_idx: 0 31 | mask_idx: 51 32 | pad_idx: 52 33 | n_experts: 3 34 | top_k: 2 35 | capacity_factor: 1 36 | model_type: MoEDecoderConfig 37 | -------------------------------------------------------------------------------- /src/directmultistep/training/__init__.py: -------------------------------------------------------------------------------- 1 | from directmultistep.training.config import TrainingConfig 2 | from directmultistep.training.trainer import ModelTrainer 3 | 4 | __all__ = ["ModelTrainer", "TrainingConfig"] 5 | -------------------------------------------------------------------------------- /src/directmultistep/training/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass 2 | from pathlib import Path 3 | 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class TrainingConfig: 9 | # Data configs 10 | data_path: Path 11 | 12 | # Training setup 13 | run_name: str 14 | train_fname: str 15 | val_fname: str 16 | metadata_fname: str 17 | 18 | # Training hyperparameters 19 | batch_size: int 20 | learning_rate: float 21 | max_epochs: int 22 | 23 | # Scheduler configs 24 | warmup_steps: int 25 | decay_steps: int 26 | decay_factor: float 27 | 28 | pad_idx: int 29 | mask_idx: int 30 | 31 | # Checkpointing 32 | save_top_k: int = -1 33 | checkpoint_every_n_epochs: int = 2 34 | 35 | num_workers: int = 1 36 | n_devices: int = 1 37 | seed: int = 42 38 | 39 | accelerator: str = "auto" 40 | matmul_precision: str = "high" 41 | summary_depth: int = 2 42 | dist_strategy: str = "ddp_find_unused_parameters_true" 43 | 44 | gradient_clip_val: float = 1.0 45 | gradient_clip_algorithm: str = "value" 46 | 47 | def __post_init__(self) -> None: 48 | self.data_path.mkdir(parents=True, exist_ok=True) 49 | self.run_name = f"{self.run_name}_seed={self.seed}" 50 | 51 | if self.matmul_precision not in ["high", "medium", "low"]: 52 | raise ValueError(f"{self.matmul_precision=} must be one of 'high', 'medium', or 'low'") 53 | 54 | if self.dist_strategy not in ["auto", "fsdp", "ddp", "ddp_spawn", "ddp_find_unused_parameters_true"]: 55 | raise ValueError( 56 | f"{self.dist_strategy=} must be one of 'fsdp', 'ddp', 'ddp_spawn', or 'ddp_find_unused_parameters_true'" 57 | ) 58 | 59 | if self.gradient_clip_algorithm not in ["norm", "value"]: 60 | raise ValueError(f"{self.gradient_clip_algorithm=} must be one of 'norm' or 'value'") 61 | 62 | def save(self, path: Path) -> None: 63 | """Save config to YAML file. 64 | 65 | Args: 66 | path: Path to save config file 67 | """ 68 | config_dict = asdict(self) 69 | config_dict["data_path"] = str(config_dict["data_path"]) 70 | 71 | with open(path, "w") as f: 72 | yaml.safe_dump(config_dict, f, default_flow_style=False, sort_keys=False) 73 | 74 | @classmethod 75 | def load(cls, path: Path) -> "TrainingConfig": 76 | """Load config from YAML file. 
77 | 78 | Args: 79 | path: Path to config file 80 | 81 | Returns: 82 | Loaded config object 83 | """ 84 | with open(path) as f: 85 | config_dict = yaml.safe_load(f) 86 | 87 | config_dict["data_path"] = Path(config_dict["data_path"]) 88 | instance = cls.__new__(cls)  # bypass __init__ so __post_init__ does not re-append the seed to run_name 89 | for key, value in config_dict.items(): 90 | setattr(instance, key, value) 91 | return instance 92 | -------------------------------------------------------------------------------- /src/directmultistep/training/lightning.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, cast 2 | 3 | import lightning as pl 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | Tensor = torch.Tensor 9 | 10 | 11 | def warmup_and_cosine_decay(warmup_steps: int, decay_steps: int, decay_factor: float) -> Callable[[int], float]: 12 | """Creates a learning rate schedule with warmup and cosine decay. 13 | 14 | The learning rate increases linearly during the warmup phase, then 15 | decreases following a cosine function during the decay phase, and 16 | finally remains constant at the decay factor. 17 | 18 | Args: 19 | warmup_steps: The number of steps for the warmup phase. 20 | decay_steps: The number of steps for the decay phase. 21 | decay_factor: The final learning rate factor after decay. 22 | 23 | Returns: 24 | A function that takes the current step as input and returns the 25 | corresponding learning rate factor. 26 | """ 27 | 28 | def _get_new_lr(step: int) -> float: 29 | if step < warmup_steps: 30 | return step / warmup_steps 31 | elif step >= warmup_steps and step < warmup_steps + decay_steps: 32 | factor = 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / decay_steps)) 33 | return cast(float, max(factor, decay_factor)) 34 | else: 35 | return decay_factor 36 | 37 | return _get_new_lr 38 | 39 | 40 | class LTraining(pl.LightningModule): 41 | """A PyTorch Lightning module for training sequence-to-sequence models.""" 42 | 43 | def __init__( 44 | self, 45 | pad_idx: int, 46 | mask_idx: int, 47 | lr: float, 48 | batch_size: int, 49 | warmup_steps: int = 4000, 50 | decay_steps: int = 24000, 51 | decay_factor: float = 0.1, 52 | model: nn.Module | None = None, 53 | criterion: nn.Module | None = None, 54 | processed_tokens: int = 0, 55 | start_idx: int = 0, 56 | ): 57 | """Initializes the LTraining module. 58 | 59 | Args: 60 | pad_idx: The index of the padding token. 61 | mask_idx: The index of the mask token. 62 | lr: The initial learning rate. 63 | batch_size: The batch size. 64 | warmup_steps: The number of warmup steps for the learning rate scheduler. 65 | decay_steps: The number of decay steps for the learning rate scheduler. 66 | decay_factor: The decay factor for the learning rate scheduler. 67 | model: The sequence-to-sequence model. 68 | criterion: The loss function. 69 | processed_tokens: The number of tokens processed so far. 70 | start_idx: The index of the start token. 
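        Example (a minimal sketch: `model` is any Seq2Seq-style module; the
        pad/mask indices follow the shipped configs, the criterion mirrors the
        one built in ModelTrainer, and lr/batch_size are illustrative):

            criterion = nn.CrossEntropyLoss(ignore_index=52, reduction="mean")
            module = LTraining(
                pad_idx=52, mask_idx=51, lr=2e-4, batch_size=32,
                model=model, criterion=criterion,
            )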
71 | """ 72 | super().__init__() 73 | if model is not None: 74 | self.model = model 75 | if criterion is not None: 76 | self.criterion = criterion 77 | self.start_idx = start_idx 78 | self.pad_idx = pad_idx 79 | self.mask_idx = mask_idx 80 | self.learning_rate = lr 81 | self.batch_size = batch_size 82 | self.warmup_steps = warmup_steps 83 | self.decay_steps = decay_steps 84 | self.decay_factor = decay_factor 85 | self.processed_tokens = processed_tokens 86 | self.save_hyperparameters(ignore=["criterion", "model"]) 87 | self.compute_loss = self.compute_loss_full 88 | 89 | def mask_src(self, src_BC: Tensor, masking_prob: float) -> Tensor: 90 | """Masks the source sequence with a given probability. 91 | 92 | Args: 93 | src_BC: The source sequence tensor of shape [B, C]. 94 | masking_prob: The probability of masking a token. 95 | 96 | Returns: 97 | The masked source sequence tensor of shape [B, C]. 98 | """ 99 | mask_idx_BC = torch.rand(src_BC.shape).to(src_BC.device) < masking_prob 100 | not_pad_BC = src_BC != self.pad_idx 101 | final_mask_BC = mask_idx_BC & not_pad_BC 102 | masked_src_BC = src_BC.clone() 103 | masked_src_BC[final_mask_BC] = self.mask_idx 104 | return masked_src_BC 105 | 106 | def compute_loss_full(self, batch: Tensor, batch_idx: int) -> Tensor: 107 | """Computes the loss for the full sequence training. 108 | 109 | This method calculates the loss for all tokens in the sequence. 110 | 111 | Args: 112 | batch: The input batch tensor. 113 | batch_idx: The index of the batch. 114 | 115 | Returns: 116 | The computed loss tensor. 117 | """ 118 | src_item_BC = batch[0] 119 | tgt_item_BL = batch[1].long() 120 | steps_B1 = batch[2].view(-1, 1) 121 | masked_src_BC = self.mask_src(src_item_BC, masking_prob=0.05) 122 | # the output actually is [B, L-1, V] given slicing of tgt_item_BL 123 | output_BLV = self.model(masked_src_BC, tgt_item_BL[:, :-1], steps_B1) 124 | output_blV = output_BLV.view(-1, output_BLV.shape[-1]) # [B*(L-1), V] 125 | tgt_bl = tgt_item_BL[:, 1:].reshape(-1) # [B*(L-1)] 126 | loss = self.criterion(output_blV, tgt_bl) 127 | self.processed_tokens += tgt_item_BL.shape[0] * tgt_item_BL.shape[1] 128 | return cast(Tensor, loss) 129 | 130 | def log_step_info(self, loss: Tensor, mode: str, prog_bar: bool) -> None: 131 | """Logs the loss and other training information. 132 | 133 | Args: 134 | loss: The loss tensor. 135 | mode: The mode of training ('train' or 'val'). 136 | prog_bar: Whether to display the loss in the progress bar. 137 | """ 138 | self.log( 139 | f"{mode}_loss", 140 | loss, 141 | batch_size=self.batch_size, 142 | prog_bar=prog_bar, 143 | sync_dist=True, 144 | ) 145 | self.log("processed_tokens", self.processed_tokens, sync_dist=True) 146 | if mode == "train": 147 | current_lr = self.trainer.optimizers[0].param_groups[0]["lr"] 148 | self.log(f"{mode}_lr", current_lr, batch_size=self.batch_size, sync_dist=True) 149 | 150 | def training_step(self, batch: Tensor, batch_idx: int) -> Tensor: 151 | """Performs a single training step. 152 | 153 | Args: 154 | batch: The input batch tensor. 155 | batch_idx: The index of the batch. 156 | 157 | Returns: 158 | The computed loss tensor. 159 | """ 160 | loss = self.compute_loss(batch, batch_idx) 161 | self.log_step_info(loss, "train", prog_bar=True) 162 | return loss 163 | 164 | def validation_step(self, batch: Tensor, batch_idx: int) -> Tensor: 165 | """Performs a single validation step. 166 | 167 | Args: 168 | batch: The input batch tensor. 169 | batch_idx: The index of the batch. 
170 | 171 | Returns: 172 | The computed loss tensor. 173 | """ 174 | loss = self.compute_loss(batch, batch_idx) 175 | self.log_step_info(loss, "val", prog_bar=True) 176 | return loss 177 | 178 | def configure_optimizers( 179 | self, 180 | ) -> tuple[list[torch.optim.Optimizer], list[dict[str, Any]]]: 181 | """Configures the optimizer and learning rate scheduler. 182 | 183 | Returns: 184 | A tuple containing the list of optimizers and the list of 185 | learning rate schedulers. 186 | """ 187 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate) 188 | # return optimizer 189 | scheduler = torch.optim.lr_scheduler.LambdaLR( 190 | optimizer, 191 | lr_lambda=warmup_and_cosine_decay( 192 | warmup_steps=self.warmup_steps, 193 | decay_steps=self.decay_steps, 194 | decay_factor=self.decay_factor, 195 | ), 196 | verbose=False, 197 | ) 198 | lr_scheduler = { 199 | "scheduler": scheduler, # The LR scheduler instance (required) 200 | "interval": "step", # The unit of the scheduler's step size 201 | "frequency": 1, # The frequency of the scheduler 202 | } 203 | return [optimizer], [lr_scheduler] 204 | 205 | def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: 206 | """Adds the processed tokens to the checkpoint. 207 | 208 | Args: 209 | checkpoint: The checkpoint dictionary. 210 | """ 211 | # Add processed_tokens to the checkpoint dictionary 212 | checkpoint["processed_tokens"] = self.processed_tokens 213 | 214 | def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: 215 | """Loads the processed tokens from the checkpoint. 216 | 217 | Args: 218 | checkpoint: The checkpoint dictionary. 219 | """ 220 | # Load processed_tokens from the checkpoint dictionary 221 | self.processed_tokens = checkpoint.get("processed_tokens", 0) 222 | -------------------------------------------------------------------------------- /src/directmultistep/training/trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import lightning as L 4 | import torch 5 | from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary 6 | from torch.utils.data import DataLoader 7 | 8 | from directmultistep import helpers 9 | from directmultistep.training.config import TrainingConfig 10 | from directmultistep.training.lightning import LTraining 11 | from directmultistep.utils.dataset import RoutesDataset 12 | 13 | Tensor = torch.Tensor 14 | 15 | 16 | class ModelTrainer: 17 | """High-level trainer class that orchestrates the training process.""" 18 | 19 | def __init__(self, config: TrainingConfig): 20 | """Initialize trainer with configuration. 21 | 22 | Args: 23 | config: Training configuration 24 | """ 25 | self.config = config 26 | self._setup_environment() 27 | 28 | def _setup_environment(self) -> None: 29 | """Configure training environment.""" 30 | L.seed_everything(self.config.seed) 31 | torch.set_float32_matmul_precision(self.config.matmul_precision) 32 | 33 | def _create_lightning_module(self, model: torch.nn.Module) -> LTraining: 34 | """Create the Lightning training module. 
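`configure_optimizers` above returns the scheduler in a dict with `"interval": "step"`, so Lightning advances the `LambdaLR` once per optimizer step rather than once per epoch. A plain-PyTorch sketch of the equivalent wiring (toy linear model, not the repo's training loop):

```python
import torch

from directmultistep.training.lightning import warmup_and_cosine_decay

model = torch.nn.Linear(4, 4)  # stand-in for the real model
opt = torch.optim.AdamW(model.parameters(), lr=2e-4)
sched = torch.optim.lr_scheduler.LambdaLR(
    opt,
    lr_lambda=warmup_and_cosine_decay(warmup_steps=4000, decay_steps=24000, decay_factor=0.1),
)

for step in range(3):
    opt.step()    # in real training, loss.backward() precedes this
    sched.step()  # multiplies the base lr by the schedule factor for the next step
    print(step, opt.param_groups[0]["lr"])
```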
35 | 36 | Args: 37 | model: The model to train 38 | 39 | Returns: 40 | Configured LTraining module 41 | """ 42 | criterion = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_idx, reduction="mean") 43 | 44 | return LTraining( 45 | model=model, 46 | pad_idx=self.config.pad_idx, 47 | mask_idx=self.config.mask_idx, 48 | criterion=criterion, 49 | lr=self.config.learning_rate, 50 | batch_size=self.config.batch_size, 51 | warmup_steps=self.config.warmup_steps, 52 | decay_steps=self.config.decay_steps, 53 | decay_factor=self.config.decay_factor, 54 | ) 55 | 56 | def _setup_callbacks(self) -> list[Any]: 57 | """Configure training callbacks. 58 | 59 | Returns: 60 | List of Lightning callbacks 61 | """ 62 | checkpoint_callback = ModelCheckpoint( 63 | monitor="val_loss", 64 | dirpath=self.config.data_path / "training" / self.config.run_name, 65 | save_last=True, 66 | save_top_k=self.config.save_top_k, 67 | every_n_epochs=self.config.checkpoint_every_n_epochs, 68 | ) 69 | 70 | return [checkpoint_callback, RichModelSummary(max_depth=self.config.summary_depth)] 71 | 72 | def _create_trainer(self) -> L.Trainer: 73 | """Create Lightning trainer. 74 | 75 | Returns: 76 | Configured Lightning trainer 77 | """ 78 | return L.Trainer( 79 | default_root_dir=self.config.data_path / "training" / self.config.run_name, 80 | max_epochs=self.config.max_epochs, 81 | accelerator=self.config.accelerator, 82 | devices=self.config.n_devices, 83 | num_nodes=1, 84 | strategy=self.config.dist_strategy, 85 | callbacks=self._setup_callbacks(), 86 | gradient_clip_val=self.config.gradient_clip_val, 87 | gradient_clip_algorithm=self.config.gradient_clip_algorithm, 88 | ) 89 | 90 | def _create_dataloaders( 91 | self, 92 | train_dataset: RoutesDataset, 93 | val_dataset: RoutesDataset, 94 | ) -> tuple[DataLoader[tuple[Tensor, ...]], DataLoader[tuple[Tensor, ...]]]: 95 | """Create training and validation dataloaders. 96 | 97 | Args: 98 | train_dataset: Training dataset 99 | val_dataset: Validation dataset 100 | 101 | Returns: 102 | Tuple of (train_dataloader, val_dataloader) 103 | """ 104 | train_loader = DataLoader( 105 | dataset=train_dataset, 106 | batch_size=self.config.batch_size, 107 | shuffle=True, 108 | num_workers=self.config.num_workers, 109 | persistent_workers=True, 110 | pin_memory=True, 111 | ) 112 | 113 | val_loader = DataLoader( 114 | dataset=val_dataset, 115 | batch_size=self.config.batch_size, 116 | shuffle=False, 117 | num_workers=self.config.num_workers, 118 | persistent_workers=True, 119 | pin_memory=True, 120 | ) 121 | 122 | return train_loader, val_loader 123 | 124 | def train( 125 | self, 126 | model: torch.nn.Module, 127 | train_dataset: RoutesDataset, 128 | val_dataset: RoutesDataset, 129 | ) -> None: 130 | """Train the model.
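One footgun in the dataloader settings above: `persistent_workers=True` is only valid together with `num_workers >= 1`; PyTorch raises a `ValueError` for `num_workers=0`. A toy illustration with a stand-in dataset (not the repo's `RoutesDataset`):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(8))  # stand-in dataset
ok = DataLoader(ds, batch_size=2, num_workers=1, persistent_workers=True)

# DataLoader(ds, batch_size=2, num_workers=0, persistent_workers=True)
# ^ raises ValueError: persistent_workers option needs num_workers > 0
```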
131 | 132 | Args: 133 | model: Model to train 134 | train_dataset: Training dataset 135 | val_dataset: Validation dataset 136 | """ 137 | lightning_model = self._create_lightning_module(model) 138 | trainer = self._create_trainer() 139 | dl_train, dl_val = self._create_dataloaders(train_dataset, val_dataset) 140 | latest_ckpt = helpers.find_checkpoint(self.config.data_path / "training", self.config.run_name) 141 | 142 | if latest_ckpt is not None: 143 | print(f"Loading model from {latest_ckpt}") 144 | trainer.fit(lightning_model, dl_train, dl_val, ckpt_path=latest_ckpt) 145 | else: 146 | trainer.fit(lightning_model, dl_train, dl_val) 147 | -------------------------------------------------------------------------------- /src/directmultistep/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/batistagroup/DirectMultiStep/bb445196ce743317c179ccf8a7f5ee3966051cda/src/directmultistep/utils/__init__.py -------------------------------------------------------------------------------- /src/directmultistep/utils/io.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | from pathlib import Path 4 | from typing import Any, TypedDict 5 | 6 | from directmultistep.utils.pre_process import ( 7 | canonicalize_smiles, 8 | find_leaves, 9 | max_tree_depth, 10 | ) 11 | 12 | 13 | class DatasetDict(TypedDict, total=False): 14 | """ 15 | A dictionary type for storing dataset information. 16 | 17 | Attributes: 18 | products: List of product SMILES strings. 19 | starting_materials: List of starting material SMILES strings. 20 | path_strings: List of string representations of reaction paths. 21 | n_steps_list: List of integers representing the number of steps in each path. 22 | ds_name: Name of the dataset. 23 | nameToIdx: A dictionary mapping names to lists of indices. 24 | """ 25 | 26 | products: list[str] 27 | starting_materials: list[str] 28 | path_strings: list[str] 29 | n_steps_list: list[int] 30 | ds_name: str 31 | nameToIdx: dict[str, list[int]] | None 32 | 33 | 34 | def load_dataset_sm(path: Path) -> DatasetDict: 35 | """Loads a dataset from a pickle file containing starting materials. 36 | 37 | Args: 38 | path: The path to the pickle file. 39 | 40 | Returns: 41 | A dictionary containing the loaded dataset. 42 | """ 43 | with open(path, "rb") as file: 44 | products, starting_materials, path_strings, n_steps_list = pickle.load(file) 45 | ds_name = path.stem.split("_")[0] 46 | return { 47 | "products": products, 48 | "starting_materials": starting_materials, 49 | "path_strings": path_strings, 50 | "n_steps_list": n_steps_list, 51 | "ds_name": ds_name, 52 | } 53 | 54 | 55 | def load_dataset_nosm(path: Path) -> DatasetDict: 56 | """Loads a dataset from a pickle file without starting materials. 57 | 58 | Args: 59 | path: The path to the pickle file. 60 | 61 | Returns: 62 | A dictionary containing the loaded dataset. 63 | """ 64 | with open(path, "rb") as file: 65 | products, _, path_strings, n_steps_list = pickle.load(file) 66 | ds_name = path.stem.split("_")[0] 67 | return { 68 | "products": products, 69 | "path_strings": path_strings, 70 | "n_steps_list": n_steps_list, 71 | "ds_name": ds_name, 72 | } 73 | 74 | 75 | def save_dataset_sm(data: dict[str, Any], path: Path) -> None: 76 | """Saves a dataset to a pickle file, including starting materials. 77 | 78 | Args: 79 | data: The dataset dictionary to save. 80 | path: The path to save the pickle file. 
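The pickle layout shared by `save_dataset_sm` (directly below) and `load_dataset_sm` is a plain 4-tuple of `(products, starting_materials, path_strings, n_steps_list)`, with the dataset name recovered from the file stem. A toy round trip; the SMILES and path string are placeholders:

```python
from pathlib import Path

from directmultistep.utils.io import load_dataset_sm, save_dataset_sm

toy = {
    "products": ["CCO"],
    "starting_materials": ["CC"],
    "path_strings": ["{'smiles':'CCO','children':[{'smiles':'CC'}]}"],
    "n_steps_list": [1],
}
path = Path("toy_dataset_sm.pkl")
save_dataset_sm(toy, path)

loaded = load_dataset_sm(path)
assert loaded["products"] == ["CCO"]
assert loaded["ds_name"] == "toy"  # the file stem before the first "_"
```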
81 | """ 82 | with open(path, "wb") as file: 83 | p, sm, ps, ns = data["products"], data.get("starting_materials", []), data["path_strings"], data["n_steps_list"] 84 | pickle.dump((p, sm, ps, ns), file) 85 | 86 | 87 | def convert_dict_of_lists_to_list_of_dicts(dict_of_lists: DatasetDict) -> list[dict[str, str]]: 88 | """Converts a dictionary of lists to a list of dictionaries. 89 | 90 | Args: 91 | dict_of_lists: The dictionary of lists to convert. 92 | 93 | Returns: 94 | A list of dictionaries. 95 | """ 96 | return [dict(zip(dict_of_lists.keys(), values)) for values in zip(*dict_of_lists.values())] 97 | 98 | 99 | def convert_list_of_dicts_to_dict_of_lists(list_of_dicts: list[dict[str, str]]) -> dict[str, list[str]]: 100 | """Converts a list of dictionaries to a dictionary of lists. 101 | 102 | Args: 103 | list_of_dicts: The list of dictionaries to convert. 104 | 105 | Returns: 106 | A dictionary of lists. 107 | """ 108 | return {key: [item[key] for item in list_of_dicts] for key in list_of_dicts[0].keys()} 109 | 110 | 111 | def load_pharma_compounds( 112 | path_to_json: Path, 113 | load_sm: bool = True, 114 | ) -> DatasetDict: 115 | """Loads pharmaceutical compounds from a JSON file. 116 | 117 | Args: 118 | path_to_json: The path to the JSON file. 119 | load_sm: Whether to load starting materials. 120 | 121 | Returns: 122 | A dictionary containing the loaded dataset. 123 | """ 124 | with open(path_to_json, "r") as file: 125 | data = json.load(file) 126 | _products, _sms, _path_strings, _steps_list = [], [], [], [] 127 | name_idx: dict[str, list[int]] = {} 128 | idx = 0 129 | for item in data: 130 | path_dict = eval(item["path"]) 131 | all_sm = find_leaves(path_dict) 132 | if load_sm: 133 | for sm in all_sm: 134 | name_idx.setdefault(item["name"], []).append(idx) 135 | _path_strings.append(item["path"]) 136 | _products.append(eval(item["path"])["smiles"]) 137 | _sms.append(sm) 138 | _steps_list.append(max_tree_depth(path_dict)) 139 | idx += 1 140 | else: 141 | name_idx.setdefault(item["name"], []).append(idx) 142 | _path_strings.append(item["path"]) 143 | _products.append(eval(item["path"])["smiles"]) 144 | _steps_list.append(max_tree_depth(path_dict)) 145 | idx += 1 146 | 147 | if load_sm: 148 | return { 149 | "products": _products, 150 | "starting_materials": _sms, 151 | "path_strings": _path_strings, 152 | "n_steps_list": _steps_list, 153 | "nameToIdx": name_idx, 154 | } 155 | else: 156 | return { 157 | "products": _products, 158 | "path_strings": _path_strings, 159 | "n_steps_list": _steps_list, 160 | "nameToIdx": name_idx, 161 | } 162 | 163 | 164 | def load_commercial_stock(path: Path) -> set[str]: 165 | """Loads a set of molecules from a file, canonicalizes them, and returns a set. 166 | 167 | Args: 168 | path: The path to the file containing molecules. 169 | 170 | Returns: 171 | A set of canonicalized SMILES strings. 
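`convert_dict_of_lists_to_list_of_dicts` and its counterpart above are exact inverses for rectangular data (all lists the same length). They are annotated against `DatasetDict`, but the logic is generic, as a toy dict shows:

```python
from directmultistep.utils.io import (
    convert_dict_of_lists_to_list_of_dicts,
    convert_list_of_dicts_to_dict_of_lists,
)

dol = {"products": ["CCO", "CCN"], "path_strings": ["p1", "p2"]}  # toy values
lod = convert_dict_of_lists_to_list_of_dicts(dol)
# [{'products': 'CCO', 'path_strings': 'p1'}, {'products': 'CCN', 'path_strings': 'p2'}]
assert convert_list_of_dicts_to_dict_of_lists(lod) == dol
```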
172 | """ 173 | with open(path, "r") as file: 174 | stock = file.readlines() 175 | canonical_stock = set() 176 | for molecule in stock: 177 | canonical_stock.add(canonicalize_smiles(molecule.strip())) 178 | return canonical_stock 179 | -------------------------------------------------------------------------------- /src/directmultistep/utils/logging_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | import os 4 | from typing import Any 5 | 6 | # --- Hardcoded Configuration --- 7 | LOGGING_CONFIG: dict[str, Any] = { 8 | "version": 1, 9 | "disable_existing_loggers": False, 10 | "formatters": { 11 | "standard": { 12 | "format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s", 13 | "datefmt": "%Y-%m-%d %H:%M:%S", 14 | } 15 | }, 16 | "handlers": { 17 | "console": { 18 | "class": "logging.StreamHandler", 19 | "formatter": "standard", 20 | "stream": "ext://sys.stdout", 21 | } 22 | }, 23 | "loggers": { 24 | "directmultistep": { 25 | "handlers": ["console"], 26 | "propagate": False, 27 | "level": "INFO", # Default level 28 | } 29 | }, 30 | } 31 | # --- End Hardcoded Configuration --- 32 | 33 | 34 | def setup_logging() -> None: 35 | """Set up logging configuration from the hardcoded dict, with an environment-variable override.""" 36 | 37 | # Get log level from environment variable, default to INFO if not set 38 | log_level = os.getenv("DIRECTMULTISTEP_LOG_LEVEL", "INFO").upper() 39 | 40 | # Validate the log level 41 | valid_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} 42 | if log_level not in valid_levels: 43 | print(f"Invalid log level {log_level}, defaulting to INFO") 44 | log_level = "INFO" # Reset to a valid level 45 | 46 | # Override the log level in the module-level config (dictConfig reads it in place) 47 | LOGGING_CONFIG["loggers"]["directmultistep"]["level"] = log_level 48 | 49 | logging.config.dictConfig(LOGGING_CONFIG) 50 | 51 | 52 | logger = logging.getLogger("directmultistep") 53 | -------------------------------------------------------------------------------- /src/directmultistep/utils/pre_process.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from itertools import islice, permutations 3 | from typing import TypedDict, cast 4 | 5 | from rdkit import Chem, RDLogger 6 | 7 | RDLogger.DisableLog("rdApp.*") 8 | 9 | PaRoutesDict = dict[str, str | bool | list["PaRoutesDict"]] 10 | 11 | 12 | class FilteredDict(TypedDict, total=False): 13 | """A dictionary format for multistep routes, used in DirectMultiStep models. 14 | 15 | This dictionary is designed to represent a node in a synthetic route tree. 16 | It contains the SMILES string of a molecule and a list of its child nodes. 17 | To get its string format, use `stringify_dict`. 18 | 19 | Attributes: 20 | smiles: SMILES string of the molecule. 21 | children: List of child nodes, each a FilteredDict. 22 | """ 23 | 24 | smiles: str 25 | children: list["FilteredDict"] 26 | 27 | 28 | def filter_mol_nodes(node: PaRoutesDict) -> FilteredDict: 29 | """Filters a PaRoutes dictionary to keep only 'smiles' and 'children' keys. 30 | 31 | This function removes extra information like 'metadata', 'rsmi', and 32 | 'reaction_hash', keeping only the 'smiles' and 'children' keys. It also 33 | canonicalizes the SMILES string using RDKit. 34 | 35 | Args: 36 | node: A dictionary representing a node in a PaRoutes data structure. 37 | 38 | Returns: 39 | A FilteredDict containing the canonicalized SMILES and filtered children.
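Because `setup_logging` reads `DIRECTMULTISTEP_LOG_LEVEL` at call time, the verbosity can be switched without touching `LOGGING_CONFIG`:

```python
import os

from directmultistep.utils.logging_config import logger, setup_logging

os.environ["DIRECTMULTISTEP_LOG_LEVEL"] = "DEBUG"
setup_logging()              # copies the env level onto the "directmultistep" logger
logger.debug("now visible")  # emitted, since the logger is at DEBUG

os.environ["DIRECTMULTISTEP_LOG_LEVEL"] = "not-a-level"
setup_logging()              # prints a warning and falls back to INFO
```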
40 | 41 | Raises: 42 | ValueError: If the 'type' of the node is not 'mol' or if 'children' is not a list. 43 | """ 44 | # canonicalize smiles by passing through RDKit 45 | canonical_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(node["smiles"])) 46 | if "children" not in node: 47 | return {"smiles": canonical_smiles} 48 | if node.get("type") != "mol": 49 | raise ValueError(f"Expected 'type' to be 'mol', got {node.get('type', 'empty')}") 50 | 51 | filtered_node: FilteredDict = {"smiles": canonical_smiles, "children": []} 52 | # we skip one level of the PaRoutes dictionary as it contains the reaction meta data 53 | # assert isinstance(node["children"], list), f"Expected 'children' to be a list, got {type(node['children'])}" 54 | if not isinstance(node["children"], list): 55 | raise ValueError(f"Expected 'children' to be a list, got {type(node['children'])}") 56 | reaction_meta: list[PaRoutesDict] = node["children"] 57 | first_child = reaction_meta[0] 58 | for child in cast(list[PaRoutesDict], first_child["children"]): 59 | filtered_node["children"].append(filter_mol_nodes(child)) 60 | return filtered_node 61 | 62 | 63 | def max_tree_depth(node: FilteredDict) -> int: 64 | """Calculates the maximum depth of a synthetic route tree. 65 | 66 | Args: 67 | node: A FilteredDict representing a node in the route tree. 68 | 69 | Returns: 70 | The maximum depth of the tree. Returns 0 for a leaf node. 71 | """ 72 | if "children" not in node: 73 | return 0 # Leaf node, depth is 0 74 | else: 75 | child_depths = [ 76 | max_tree_depth(child) 77 | for child in node["children"] 78 | # if isinstance(child, dict) 79 | ] 80 | return 1 + max(child_depths) 81 | 82 | 83 | def find_leaves(node: FilteredDict) -> list[str]: 84 | """Finds the SMILES strings of all leaf nodes (starting materials) in a route tree. 85 | 86 | Args: 87 | node: A FilteredDict representing a node in the route tree. 88 | 89 | Returns: 90 | A list of SMILES strings representing the starting materials. 91 | """ 92 | leaves = [] 93 | if "children" in node: 94 | for child in node["children"]: 95 | leaves.extend(find_leaves(child)) 96 | else: 97 | leaves.append(node["smiles"]) 98 | return leaves 99 | 100 | 101 | def canonicalize_smiles(smiles: str) -> str: 102 | """Canonicalizes a SMILES string using RDKit. 103 | 104 | Args: 105 | smiles: The SMILES string to canonicalize. 106 | 107 | Returns: 108 | The canonicalized SMILES string. 109 | 110 | Raises: 111 | ValueError: If the SMILES string cannot be parsed by RDKit. 112 | """ 113 | mol = Chem.MolFromSmiles(smiles) 114 | if mol is None: 115 | raise ValueError(f"Failed to parse SMILES: {smiles}") 116 | return cast(str, Chem.MolToSmiles(mol)) 117 | 118 | 119 | def stringify_dict(data: FilteredDict) -> str: 120 | """Converts a FilteredDict to a string, removing spaces. 121 | 122 | Args: 123 | data: The FilteredDict to convert. 124 | 125 | Returns: 126 | A string representation of the FilteredDict with no spaces. 127 | """ 128 | return str(data).replace(" ", "") 129 | 130 | 131 | def generate_permutations(data: FilteredDict, max_perm: int | None = None) -> list[str]: 132 | """Generates permutations of a synthetic route by permuting the order of children. 133 | 134 | This function generates all possible permutations of a synthetic route by 135 | rearranging the order of child nodes at each level of the tree. It can 136 | optionally limit the number of permutations generated. 137 | 138 | Args: 139 | data: A FilteredDict representing the synthetic route. 
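A toy route tree makes the helpers above concrete. The SMILES are placeholders, and `is_convergent` (defined just below in this module) is included to show the structural distinction it tests:

```python
from directmultistep.utils.pre_process import (
    find_leaves,
    is_convergent,  # defined just below in this module
    max_tree_depth,
    stringify_dict,
)

# Placeholder route: product CCO made from CC and O, where O has its own precursor.
linear = {
    "smiles": "CCO",
    "children": [
        {"smiles": "CC"},
        {"smiles": "O", "children": [{"smiles": "OO"}]},
    ],
}
print(max_tree_depth(linear))  # 2 -> longest branch: CCO -> O -> OO
print(find_leaves(linear))     # ['CC', 'OO'] -> the starting materials
print(stringify_dict(linear))  # the same dict as a string with all spaces stripped
print(is_convergent(linear))   # False: only one reactant of CCO has children

convergent = {
    "smiles": "CCO",
    "children": [
        {"smiles": "CC", "children": [{"smiles": "C"}]},
        {"smiles": "O", "children": [{"smiles": "OO"}]},
    ],
}
print(is_convergent(convergent))  # True: two reactants both have children
```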
140 | max_perm: An optional integer to limit the number of permutations generated. 141 | 142 | Returns: 143 | A list of stringified FilteredDicts representing the permuted routes. 144 | """ 145 | if "children" not in data or not data["children"]: 146 | return [stringify_dict(data)] 147 | 148 | child_permutations = [] 149 | for child in data["children"]: 150 | child_permutations.append(generate_permutations(child, max_perm)) 151 | 152 | all_combos = [] 153 | # Conditionally apply permutation limit 154 | permutation_generator = permutations(range(len(child_permutations))) 155 | if max_perm is not None: 156 | permutation_generator = islice(permutation_generator, max_perm) # type:ignore 157 | 158 | for combo in permutation_generator: 159 | for product in itertools.product(*(child_permutations[i] for i in combo)): 160 | new_data = data.copy() 161 | new_data["children"] = [eval(child_str) for child_str in product] 162 | all_combos.append(stringify_dict(new_data)) 163 | if max_perm is not None and len(all_combos) >= max_perm: 164 | return all_combos # Return early if maximum number of permutations is reached 165 | return all_combos 166 | 167 | 168 | def is_convergent(route: FilteredDict) -> bool: 169 | """Determines if a synthesis route is convergent (non-linear). 170 | 171 | A route is linear if for every transformation, at most one reactant has children 172 | (i.e., all other reactants are leaf nodes). A route is convergent if there exists 173 | at least one transformation where two or more reactants have children. 174 | 175 | Args: 176 | route: The synthesis route to analyze. 177 | 178 | Returns: 179 | True if the route is convergent (non-linear), False if it's linear. 180 | """ 181 | if "children" not in route: 182 | return False 183 | 184 | # Check if current node's transformation has 2 or more children with their own children 185 | children = route["children"] 186 | if len(children) >= 2: # Need at least 2 children for a transformation 187 | children_with_children = sum(1 for child in children if "children" in child) 188 | if children_with_children >= 2: 189 | return True 190 | 191 | # Recursively check children 192 | return any(is_convergent(child) for child in children) 193 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/batistagroup/DirectMultiStep/bb445196ce743317c179ccf8a7f5ee3966051cda/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_preprocess.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from directmultistep.utils.dataset import tokenize_path_string, tokenize_smile 4 | from directmultistep.utils.pre_process import ( 5 | filter_mol_nodes, 6 | find_leaves, 7 | generate_permutations, 8 | max_tree_depth, 9 | ) 10 | 11 | from .test_data import ( 12 | test1_leaves, 13 | test2_depth1, 14 | test3_depth2, 15 | test4_n1route0, 16 | test5_depth0_leaves, 17 | test6_depth1_leaves, 18 | test7_depth2_leaves, 19 | test8_n1route_leaves, 20 | test9_tknz_smiles, 21 | test10_tknz_path, 22 | ) 23 | 24 | test_filtering_and_depth = [ 25 | pytest.param(test1_leaves, 0, id="leaves"), 26 | pytest.param(test2_depth1, 1, id="depth1"), 27 | pytest.param(test3_depth2, 2, id="depth2"), 28 | pytest.param(test4_n1route0, 8, id="n1_routes_idx0"), 29 | ] 30 | 31 | 32 | @pytest.mark.parametrize("data, _", 
test_filtering_and_depth) 33 | def test_filter_mol_nodes(data, _): 34 | for item in data: 35 | assert filter_mol_nodes(item["paRoute"]) == item["filtered"] 36 | 37 | 38 | def test_filter_mol_nodes_invalid_type(): 39 | node = { 40 | "smiles": "COC(=O)c1ccc2c(c1)OCCO2", 41 | "children": [{"smiles": "BrCCBr"}, {"smiles": "COC(=O)c1ccc(O)c(O)c1"}], 42 | "type": "invalid", 43 | } 44 | with pytest.raises(ValueError) as exc_info: 45 | filter_mol_nodes(node) 46 | assert str(exc_info.value) == "Expected 'type' to be 'mol', got invalid" 47 | 48 | 49 | @pytest.mark.parametrize("data, expected_depth", test_filtering_and_depth) 50 | def test_max_tree_depth(data, expected_depth): 51 | for item in data: 52 | assert max_tree_depth(item["filtered"]) == expected_depth 53 | 54 | 55 | test_leaves = [ 56 | pytest.param(test5_depth0_leaves, id="depth0"), 57 | pytest.param(test6_depth1_leaves, id="depth1"), 58 | pytest.param(test7_depth2_leaves, id="depth2"), 59 | pytest.param(test8_n1route_leaves, id="n1route_idx0"), 60 | ] 61 | 62 | 63 | @pytest.mark.parametrize("data", test_leaves) 64 | def test_find_leaves(data): 65 | for item in data: 66 | assert find_leaves(item["filtered"]) == item["leaves"] 67 | 68 | 69 | @pytest.mark.parametrize("data", test9_tknz_smiles) 70 | def test_tokenize_smile(data): 71 | assert tokenize_smile(data[0]) == data[1] 72 | 73 | 74 | @pytest.mark.parametrize("data", test10_tknz_path) 75 | def test_tokenize_path(data): 76 | assert tokenize_path_string(data[0]) == data[1] 77 | 78 | 79 | def test_generate_permutations_no_children(): 80 | # Test data with no children 81 | data = {"smiles": "A"} 82 | assert generate_permutations(data) == [str(data).replace(" ", "")] 83 | 84 | 85 | def test_generate_permutations_single_child(): 86 | # Test data with one child 87 | data = {"smiles": "A", "children": [{"smiles": "B"}]} 88 | expected_output = [str(data).replace(" ", "")] 89 | assert generate_permutations(data) == expected_output 90 | 91 | 92 | def test_generate_permutations_multiple_children(): 93 | # Test data with multiple children 94 | data = {"smiles": "A", "children": [{"smiles": "B"}, {"smiles": "C"}]} 95 | expected_output = [ 96 | str({"smiles": "A", "children": [{"smiles": "B"}, {"smiles": "C"}]}).replace(" ", ""), 97 | str({"smiles": "A", "children": [{"smiles": "C"}, {"smiles": "B"}]}).replace(" ", ""), 98 | ] 99 | assert sorted(generate_permutations(data)) == sorted(expected_output) 100 | 101 | 102 | def test_generate_permutations_nested_children(): 103 | # Test data with nested children 104 | data = {"smiles": "A", "children": [{"smiles": "B", "children": [{"smiles": "C"}]}]} 105 | expected_output = [ 106 | str( 107 | { 108 | "smiles": "A", 109 | "children": [{"smiles": "B", "children": [{"smiles": "C"}]}], 110 | } 111 | ).replace(" ", "") 112 | ] 113 | assert generate_permutations(data) == expected_output 114 | 115 | 116 | def test_generate_permutations_with_limit(): 117 | # Test data with a permutation limit 118 | data = { 119 | "smiles": "A", 120 | "children": [{"smiles": "B"}, {"smiles": "C"}, {"smiles": "D"}], 121 | } 122 | # Limit to 2 permutations 123 | results = generate_permutations(data, max_perm=2) 124 | assert len(results) == 2 125 | 126 | 127 | def test_generate_permutations_complex_case(): 128 | # More complex structure with depth and multiple children at different levels 129 | data = { 130 | "smiles": "A", 131 | "children": [ 132 | {"smiles": "B", "children": [{"smiles": "C"}, {"smiles": "D"}]}, 133 | {"smiles": "E"}, 134 | ], 135 | } 136 | results = 
generate_permutations(data) 137 | # Test that the correct number of permutations is generated (factorial of children count at each level) 138 | assert len(results) == 4 # Since there are 2 at top level (B with its children, and E) 139 | 140 | 141 | @pytest.mark.parametrize( 142 | "data, expected", 143 | [ 144 | ({"smiles": "X"}, [str({"smiles": "X"}).replace(" ", "")]), 145 | ( 146 | {"smiles": "Y", "children": [{"smiles": "Z"}]}, 147 | [str({"smiles": "Y", "children": [{"smiles": "Z"}]}).replace(" ", "")], 148 | ), 149 | ], 150 | ) 151 | def test_generate_permutations_parametrized(data, expected): 152 | assert generate_permutations(data) == expected 153 | 154 | 155 | if __name__ == "__main__": 156 | pytest.main(["-v", "-s", __file__]) 157 | -------------------------------------------------------------------------------- /use-examples/data/configs/logbook/van_sm_6x3_6x3_256_noboth_seed=42/model_config.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | vocab_dim: 53 3 | hid_dim: 256 4 | n_layers: 6 5 | n_heads: 8 6 | ff_mult: 3 7 | ff_activation: gelu 8 | dropout: 0.1 9 | attn_bias: false 10 | context_window: 280 11 | start_idx: 0 12 | mask_idx: 51 13 | pad_idx: 52 14 | initiate_steps: true 15 | include_steps: true 16 | model_type: EncoderAConfig 17 | decoder: 18 | vocab_dim: 53 19 | hid_dim: 256 20 | n_layers: 6 21 | n_heads: 8 22 | ff_mult: 3 23 | ff_activation: gelu 24 | dropout: 0.1 25 | attn_bias: false 26 | context_window: 1075 27 | start_idx: 0 28 | mask_idx: 51 29 | pad_idx: 52 30 | model_type: TransformerConfig 31 | -------------------------------------------------------------------------------- /use-examples/data/configs/logbook/van_sm_6x3_6x3_256_noboth_seed=42/training_config.yaml: -------------------------------------------------------------------------------- 1 | data_path: /Users/morgunov/batista/DirectMultiStep/use-examples/data 2 | run_name: van_sm_6x3_6x3_256_noboth_seed=42 3 | train_fname: unique_dataset_nperms=3_nsms=all_noboth_train=0.95.pkl 4 | val_fname: unique_dataset_nperms=3_nsms=all_noboth_val=0.05.pkl 5 | metadata_fname: dms_dictionary.yaml 6 | batch_size: 8 7 | learning_rate: 0.0002 8 | max_epochs: 40 9 | warmup_steps: 3000 10 | decay_steps: 80000 11 | decay_factor: 0.1 12 | pad_idx: 52 13 | mask_idx: 51 14 | save_top_k: -1 15 | checkpoint_every_n_epochs: 2 16 | num_workers: 1 17 | n_devices: 1 18 | seed: 42 19 | accelerator: cpu 20 | matmul_precision: high 21 | summary_depth: 2 22 | dist_strategy: fsdp 23 | gradient_clip_val: 1.0 24 | gradient_clip_algorithm: value 25 | -------------------------------------------------------------------------------- /use-examples/eval-subset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from directmultistep.generation.eval import EvalConfig, ModelEvaluator 4 | from directmultistep.model import ModelFactory 5 | from directmultistep.training import TrainingConfig 6 | 7 | __mode__ = "local" 8 | assert __mode__ in ["local", "cluster"] 9 | 10 | if __mode__ == "local": 11 | base_path = Path(__file__).resolve().parent.parent 12 | elif __mode__ in ["cluster"]: 13 | base_path = Path(__file__).resolve().parent.parent 14 | 15 | data_path = base_path / "data" 16 | 17 | run_name = "van_sm_6x3_6x3_256_noboth_seed=42" 18 | logbook_path = data_path / "configs" / "logbook" / run_name 19 | train_conf = TrainingConfig.load(logbook_path / "training_config.yaml") 20 | factory = ModelFactory.from_config_file(logbook_path / 
"model_config.yaml", compile_model=False) 21 | 22 | ec = EvalConfig( 23 | data_path=data_path, 24 | run_name=run_name, 25 | eval_dataset="n1_50", 26 | epoch=46, 27 | use_sm=True, 28 | use_steps=True, 29 | beam_width=50, 30 | enc_active_experts=None, 31 | dec_active_experts=None, 32 | ) 33 | ec.save(logbook_path / f"{ec.eval_name}_config.yaml") 34 | 35 | 36 | if __name__ == "__main__": 37 | factory.check_for_eval_config_updates(ec) 38 | model = factory.create_model() 39 | device = ModelFactory.determine_device() 40 | # model = factory.load_lightning_checkpoint(model, ec.checkpoint_path, device=device) 41 | pblshd = data_path / "checkpoints" / "flash_ep=46.ckpt" 42 | model = factory.load_checkpoint(model, pblshd, device=device) 43 | 44 | evalObj = ModelEvaluator(model, ec, train_conf, device=device) 45 | 46 | evalObj.load_eval_dataset() 47 | evalObj.prepare_beam_search() 48 | 49 | all_beam_results_NS2 = evalObj.run_beam_search() 50 | 51 | top_ks = evalObj.calculate_top_k_accuracy() 52 | print(top_ks) 53 | 54 | # for pharma 55 | name_to_rank = evalObj.prepare_name_to_rank() 56 | print(name_to_rank) 57 | -------------------------------------------------------------------------------- /use-examples/generate-route.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from directmultistep.generate import generate_routes 4 | from directmultistep.utils.web_visualize import draw_tree_from_path_string 5 | 6 | data_path = Path(__file__).resolve().parents[1] / "data" 7 | ckpt_path = data_path / "checkpoints" 8 | fig_path = data_path / "figures" 9 | config_path = data_path / "configs" / "dms_dictionary.yaml" 10 | 11 | 12 | def visualize_routes(path_strings: list[str], theme: str = "light") -> list[str]: 13 | """Visualize synthesis routes and return SVG strings.""" 14 | import random 15 | 16 | request_id = "rnd_" + str(random.randint(0, 1000000)) 17 | save_folder = fig_path / request_id 18 | save_folder.mkdir(parents=True, exist_ok=True) 19 | 20 | svg_results = [] 21 | for i, path_string in enumerate(path_strings): 22 | svg_tree = draw_tree_from_path_string( 23 | path_string=path_string, 24 | save_path=save_folder / f"result_{i}", 25 | width=600, 26 | height=600, 27 | x_margin=40, 28 | y_margin=120, 29 | theme=theme, 30 | ) 31 | svg_results.append(svg_tree) 32 | 33 | return svg_results 34 | 35 | 36 | if __name__ == "__main__": 37 | # Example usage 38 | target = "CNCc1cc(-c2ccccc2F)n(S(=O)(=O)c2cccnc2)c1" 39 | sm = "CN" 40 | 41 | # Find routes with starting material using flash model 42 | paths = generate_routes( 43 | target, n_steps=2, starting_material=sm, model="flash", beam_size=5, config_path=config_path, ckpt_dir=ckpt_path 44 | ) 45 | # paths = generate_routes(target, n_steps=2, starting_material=sm, model="flash-20M", beam_size=5) 46 | # paths = generate_routes(target, n_steps=2, starting_material=sm, model="flex-20M", beam_size=5) 47 | 48 | # # Find routes without starting material using deep model 49 | # paths = generate_routes(target, n_steps=2, model="deep") 50 | # paths = generate_routes(target, n_steps=2, model="wide", beam_size=20) 51 | 52 | # # Find routes using explorer model (automatically determines steps) 53 | # paths = generate_routes(target, starting_material=sm, model="explorer", beam_size=5) 54 | # paths = generate_routes(target, model="explorer XL", beam_size=5) 55 | 56 | svg_contents = visualize_routes(paths) 57 | -------------------------------------------------------------------------------- 
/use-examples/paper-figures.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | import plotly.io as pio 5 | 6 | from directmultistep.analysis.paper.dataset_analysis import ( 7 | plot_convergent_fraction_by_length, 8 | plot_convergent_fraction_overall, 9 | plot_leaf_distribution, 10 | plot_route_length_distribution, 11 | ) 12 | from directmultistep.analysis.paper.linear_vs_convergent import ( 13 | ModelPlotConfig, 14 | RouteAnalyzer, 15 | process_model_configs, 16 | ) 17 | from directmultistep.utils.io import load_dataset_sm 18 | from directmultistep.utils.logging_config import logger 19 | 20 | pio.kaleido.scope.mathjax = None 21 | 22 | base_path = Path(__file__).resolve().parent 23 | save_path = base_path / "data" / "figures" / "paper" 24 | base_path = Path("/Users/morgunov/batista/RetroChallenge") 25 | prcsd_path = base_path / "data" / "processed" 26 | eval_path = base_path / "data" / "evaluation" 27 | save_path.mkdir(parents=True, exist_ok=True) 28 | folders = [f.name for f in eval_path.glob("*/")] 29 | 30 | 31 | if __name__ == "__main__": 32 | # Load datasets 33 | train_dataset = load_dataset_sm(prcsd_path / "unique_dataset_nperms=3_nsms=all_noboth.pkl") 34 | n1_dataset = load_dataset_sm(prcsd_path / "n1_dataset_nperms=1_nsms=1.pkl") 35 | n5_dataset = load_dataset_sm(prcsd_path / "n5_dataset_nperms=1_nsms=1.pkl") 36 | 37 | rerun = { 38 | "route-distribution": False, 39 | "leaf-distribution": False, 40 | "convergent-fraction": False, 41 | "topk-accuracy": False, 42 | "extraction-distribution": False, 43 | "accuracy-by-length": True, 44 | } 45 | 46 | # ------------ Route Length Distribution in Datasets ------------ 47 | if rerun["route-distribution"]: 48 | fig = plot_route_length_distribution( 49 | train_dataset["n_steps_list"], 50 | n1_dataset["n_steps_list"], 51 | n5_dataset["n_steps_list"], 52 | ) 53 | fig.write_image(save_path / "route_length_distribution.pdf") 54 | # fig.write_html(save_path / "route_length_distribution.html", include_plotlyjs="cdn") 55 | 56 | # ------------ Leaf Distribution in Datasets ------------ 57 | if rerun["leaf-distribution"]: 58 | fig = plot_leaf_distribution( 59 | train_dataset["path_strings"], 60 | n1_dataset["path_strings"], 61 | n5_dataset["path_strings"], 62 | ) 63 | fig.write_image(save_path / "leaf_distribution.pdf") 64 | # fig.write_html(save_path / "leaf_distribution.html", include_plotlyjs="cdn") 65 | 66 | # ------------ Convergent Route Fraction by Length ------------ 67 | if rerun["convergent-fraction"]: 68 | fig = plot_convergent_fraction_by_length( 69 | train_dataset["path_strings"], 70 | train_dataset["n_steps_list"], 71 | n1_dataset["path_strings"], 72 | n1_dataset["n_steps_list"], 73 | n5_dataset["path_strings"], 74 | n5_dataset["n_steps_list"], 75 | ) 76 | # fig.show() 77 | fig.write_image(save_path / "convergent_fraction_by_length.pdf") 78 | # fig.write_html(save_path / "convergent_fraction_by_length.html", include_plotlyjs="cdn") 79 | 80 | fig = plot_convergent_fraction_overall( 81 | train_dataset["path_strings"], 82 | n1_dataset["path_strings"], 83 | n5_dataset["path_strings"], 84 | ) 85 | fig.write_image(save_path / "convergent_fraction_overall.pdf") 86 | # fig.write_html(save_path / "convergent_fraction_overall.html", include_plotlyjs="cdn") 87 | 88 | # ---------------------------------------------------------------- 89 | # fmt:off 90 | model_configs = [ 91 | ModelPlotConfig(model_name="flex_20M", epoch="epoch=20",
variant_base="b50_sm_st_ea=1_da=1"), 92 | ModelPlotConfig(model_name="flash_10M", epoch="epoch=46", variant_base="b50_sm_st"), 93 | ModelPlotConfig(model_name="flash_20M", epoch="epoch=31", variant_base="b50_sm_st"), 94 | ModelPlotConfig(model_name="flex_20M", epoch="epoch=20", variant_base="b50_sm_st_ea=2_da=2"), 95 | ModelPlotConfig(model_name="flash_10M", epoch="epoch=46", variant_base="b50_nosm_st"), 96 | ModelPlotConfig(model_name="flex_20M", epoch="epoch=20", variant_base="b50_nosm_st_ea=2_da=2"), 97 | ModelPlotConfig(model_name="deep_40M", epoch="epoch=47", variant_base="b50_nosm_st"), 98 | ModelPlotConfig(model_name="wide_40M", epoch="epoch=31", variant_base="b50_nosm_st_ea=2_da=2"), 99 | ModelPlotConfig(model_name="explorer_19M", epoch="epoch=18", variant_base="b50_sm_nost_ea=2_da=2"), 100 | ModelPlotConfig(model_name="explorer_19M", epoch="epoch=18", variant_base="b50_nosm_nost_ea=2_da=2"), 101 | ModelPlotConfig(model_name="explorer_50M", epoch="epoch=16", variant_base="b50_nosm_nost_ea=2_da=2"), 102 | ] 103 | # fmt:on 104 | 105 | # ------------ Top-K Accuracy for Convergent/Non-Convergent Routes ------------ 106 | if rerun["topk-accuracy"]: 107 | for dataset in [n1_dataset, n5_dataset]: 108 | configs = [config.with_dataset(dataset["ds_name"]) for config in model_configs] 109 | result_paths, trace_names = process_model_configs(eval_path, configs, dataset) 110 | fig = RouteAnalyzer.create_comparative_bar_plots( 111 | result_paths, 112 | trace_names, 113 | k_vals=[1, 2, 3, 4, 5, 10], # Only show these k values 114 | # title="Top-k Accuracy Comparison - n5 Dataset", 115 | ) 116 | fig.write_image( 117 | save_path / f"{dataset['ds_name']}_topk_accuracy_subplots.pdf", 118 | ) 119 | # fig.write_html(save_path / f"{dataset['ds_name']}_topk_accuracy_subplots.html", include_plotlyjs="cdn") 120 | 121 | # ------------ Route Distribution Plots ------------ 122 | if rerun["extraction-distribution"]: 123 | for dataset in [n1_dataset, n5_dataset]: 124 | for config in model_configs: 125 | config = config.with_dataset(dataset["ds_name"]) 126 | logger.info(f"Processing {config.model_name} evaluation {config.variant}") 127 | res_path = config.get_result_path(eval_path) 128 | 129 | with open(res_path / "valid_paths_NS2n.pkl", "rb") as f: 130 | valid_routes = pickle.load(f) 131 | with open(res_path / "processed_paths_NS2n_true_reacs=False_stock=False.pkl", "rb") as f: 132 | processed_no_stock = pickle.load(f) 133 | with open(res_path / "processed_paths_NS2n_true_reacs=False_stock=True.pkl", "rb") as f: 134 | processed_with_stock = pickle.load(f) 135 | 136 | fig = RouteAnalyzer.visualize_route_processing_stages( 137 | valid_routes=valid_routes, 138 | processed_routes_no_stock=processed_no_stock, 139 | processed_routes_with_stock=processed_with_stock, 140 | true_routes=dataset["path_strings"], 141 | dataset_name=f"{dataset['ds_name']} Dataset ({config.model_name}, {config.display_name})", 142 | show_filtered_stats=False, 143 | ) 144 | fig.write_image(save_path / f"{dataset['ds_name']}_route_processing_stages_{config.save_suffix}.pdf") 145 | # fig.write_html( 146 | # save_path / f"{dataset['ds_name']}_route_processing_stages_{config.save_suffix}.html", 147 | # include_plotlyjs="cdn", 148 | # ) 149 | 150 | # ------------ Top-K Accuracy by Route Length ------------ 151 | if rerun["accuracy-by-length"]: 152 | for base_config in model_configs: 153 | datasets = [n1_dataset, n5_dataset] 154 | configs = [base_config.with_dataset(ds["ds_name"]) for ds in datasets] 155 | 156 | result_paths = [] 157 | trace_names 
= [] 158 | for config, dataset in zip(configs, datasets): 159 | paths, names = process_model_configs(eval_path, [config], dataset) 160 | result_paths.extend(paths) 161 | trace_names.extend(names) 162 | 163 | # Create single plot 164 | fig = RouteAnalyzer.create_accuracy_by_length_plot( 165 | result_paths=result_paths, 166 | datasets=datasets, 167 | configs=configs, 168 | k_vals=[1, 10], 169 | title="Top-k Accuracy by Route Length", 170 | ) 171 | fig.write_image(save_path / f"accuracy_by_length_{configs[0].save_suffix}.pdf") 172 | # fig.write_html(save_path / f"accuracy_by_length_{configs[0].save_suffix}.html", include_plotlyjs="cdn") 173 | 174 | # Create subplot figure 175 | fig = RouteAnalyzer.create_accuracy_by_length_subplots( 176 | result_paths=result_paths, 177 | datasets=datasets, 178 | configs=configs, 179 | k_vals=[1, 10], 180 | title="Top-k Accuracy by Route Length - Route Type Comparison", 181 | ) 182 | fig.write_image(save_path / f"accuracy_by_length_subplots_{configs[0].save_suffix}.pdf") 183 | # fig.write_html( 184 | # save_path / f"accuracy_by_length_subplots_{configs[0].save_suffix}.html", include_plotlyjs="cdn" 185 | # ) 186 | -------------------------------------------------------------------------------- /use-examples/train-model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from directmultistep import helpers 4 | from directmultistep.model import ModelFactory 5 | from directmultistep.training import ModelTrainer, TrainingConfig 6 | 7 | __mode__ = "local" 8 | assert __mode__ in ["local", "cluster"] 9 | 10 | if __mode__ == "local": 11 | # use .parent if you place train-model.py in the root folder, 12 | # or keep .parents[1] if it stays in use-examples 13 | base_path = Path(__file__).resolve().parents[1] 14 | n_workers = 1 15 | n_devices = 1 16 | accelerator = "cpu" 17 | batch_size = 8 18 | elif __mode__ == "cluster": 19 | base_path = Path(__file__).resolve().parent # change it to your path 20 | n_workers = 64 21 | n_devices = 1 22 | accelerator = "auto" 23 | batch_size = 32 * 4 24 | data_path = base_path / "data" 25 | 26 | factory = ModelFactory.from_preset("flash_10M", compile_model=False) 27 | # or any other preset name from src/directmultistep/model/default_configs 28 | # or create your own config, see src/directmultistep/model/factory.py for examples 29 | config = TrainingConfig( 30 | data_path=data_path, 31 | run_name="van_sm_6x3_6x3_256_noboth", 32 | train_fname="unique_dataset_nperms=3_nsms=all_noboth_train=0.95.pkl", 33 | val_fname="unique_dataset_nperms=3_nsms=all_noboth_val=0.05.pkl", 34 | metadata_fname="dms_dictionary.yaml", 35 | batch_size=batch_size, 36 | learning_rate=2e-4, 37 | max_epochs=40, 38 | warmup_steps=3000, 39 | decay_steps=80_000, 40 | decay_factor=0.1, 41 | pad_idx=factory.config.decoder.pad_idx, 42 | mask_idx=factory.config.decoder.mask_idx, 43 | save_top_k=-1, # -1 will save all 44 | checkpoint_every_n_epochs=2, # every 2 epochs 45 | summary_depth=2, 46 | accelerator=accelerator, 47 | num_workers=n_workers, 48 | n_devices=n_devices, 49 | ) 50 | 51 | # Save configs to logbook 52 | logbook_path = data_path / "configs" / "logbook" / config.run_name 53 | logbook_path.mkdir(parents=True, exist_ok=True) 54 | 55 | config.save(logbook_path / "training_config.yaml") 56 | factory.config.save(logbook_path / "model_config.yaml") 57 | 58 | 59 | train_dataset, val_dataset = helpers.prepare_datasets( 60 | train_data_path=data_path / "processed" / config.train_fname, 61 |
val_data_path=data_path / "processed" / config.val_fname, 62 | metadata_path=data_path / "configs" / config.metadata_fname, 63 | load_sm=True, 64 | mode="training", 65 | ) 66 | 67 | model = factory.create_model() 68 | trainer = ModelTrainer(config) 69 | 70 | if __name__ == "__main__": 71 | # this has to be in main block to avoid issues with multiprocessing 72 | trainer.train(model, train_dataset, val_dataset) 73 | -------------------------------------------------------------------------------- /use-examples/visualize-train-curves.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from directmultistep.analysis.training import ( 4 | RunConfig, 5 | plot_learning_rates, 6 | plot_training_curves, 7 | ) 8 | 9 | data_path = Path(__file__).resolve().parent / "data" 10 | train_path = data_path / "training" 11 | eval_path = data_path / "evaluation" 12 | 13 | 14 | runs = [ 15 | RunConfig("sm_6x3_6x3_256_noboth_unique", "Flash"), 16 | RunConfig("nosm_6x3_6x3_256_noboth_unique", "Flash (no SM)"), 17 | RunConfig("nosm_12x3_36x3_256_noboth", "Deep"), 18 | RunConfig("moe_sm_2x3_6x3_6x3_256_cap_3.5e-4", "Flex"), 19 | RunConfig("moe_nosm_2x3_12x3_12x3_256_cap_3.5e-4", "Wide"), 20 | RunConfig("moe_nosm_2x3_6x3_24x3_256_cap_3e-4_nosteps_v2", "Explorer XL"), 21 | ] 22 | if __name__ == "__main__": 23 | train_fig_tokens = plot_training_curves(train_path, runs, x_axis="processed_tokens") 24 | train_fig_tokens.show() 25 | 26 | train_fig_epoch = plot_training_curves(train_path, runs, x_axis="epoch") 27 | train_fig_epoch.show() 28 | 29 | train_fig_step = plot_training_curves(train_path, runs, x_axis="step") 30 | train_fig_step.show() 31 | 32 | lr_fig = plot_learning_rates(train_path, runs) 33 | lr_fig.show() 34 | --------------------------------------------------------------------------------
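A closing note on the plotting script above: assuming `plot_training_curves` and `plot_learning_rates` return Plotly figures, as the `write_image`/`write_html` usage in `paper-figures.py` suggests, the interactive `.show()` calls can be swapped for file output (a hedged sketch, not code from the repo; the output filenames are hypothetical):

```python
# Hedged sketch: assumes the analysis helpers return Plotly figures.
fig = plot_training_curves(train_path, runs, x_axis="step")
(data_path / "figures").mkdir(parents=True, exist_ok=True)
fig.write_image(data_path / "figures" / "train_loss_by_step.pdf")
fig.write_html(data_path / "figures" / "train_loss_by_step.html", include_plotlyjs="cdn")
```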