├── .gitattributes ├── .github └── workflows │ └── docs.yml ├── .gitignore ├── .pre-commit-config.yaml ├── ADVANCED_USAGE.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _templates │ ├── custom-class-template.rst │ └── custom-module-template.rst ├── api.md ├── backbones │ ├── eqv2.md │ ├── jmp.md │ ├── m3gnet.md │ ├── mace.md │ ├── mattersim.md │ └── orb.md ├── conf.py ├── contributing.md ├── guides │ ├── datasets.md │ ├── fine-tuning.md │ ├── lightning.md │ ├── model_usage.md │ ├── normalization.md │ ├── recipes.md │ └── training_config.md ├── index.md ├── installation.md ├── introduction.md ├── license.md ├── motivation.md ├── requirements-torch.txt └── requirements.txt ├── examples ├── lora_decompose │ └── Li3PO4-checkpoints │ │ └── mattersim-1m-best-MPx3.ckpt ├── matbench-discovery │ ├── collect.py │ ├── results.md │ └── split_relax.py ├── matbench │ ├── matbenchmark-foldx.py │ ├── results.md │ └── screening.py ├── structure-optimization │ └── bfgs.py └── water-thermodynamics │ ├── README.md │ ├── data │ ├── H2O.xyz │ ├── train_30_water_1000_eVAng.xyz │ ├── train_water_1000_eVAng.xyz │ ├── val_water_1000_eVAng.xyz │ ├── water_1000_eVAng-energy_reference.json │ └── water_1000_eVAng.xyz │ ├── draw_rdf_plots.py │ ├── energy_reference.py │ ├── md.py │ ├── rdf_analysis.py │ ├── run.sh │ └── water-finetune.py ├── notebooks ├── eqv2-omat.ipynb ├── jmp-omat-autosplit.ipynb ├── jmp-omat.ipynb ├── m3gnet-waterthermo.ipynb ├── mattersim-waterthermo.ipynb └── orb-omat.ipynb ├── pyproject.toml ├── requirements.txt └── src └── mattertune ├── .nshconfig.generated.json ├── __init__.py ├── backbones ├── __init__.py ├── eqV2 │ ├── __init__.py │ └── model.py ├── jmp │ ├── __init__.py │ ├── model.py │ └── util.py ├── m3gnet │ ├── __init__.py │ └── model.py ├── mace_foundation │ ├── __init__.py │ └── model.py ├── mattersim │ ├── __init__.py │ └── model.py ├── orb │ ├── __init__.py │ └── model.py └── util.py ├── callbacks ├── early_stopping.py ├── ema.py └── model_checkpoint.py ├── configs ├── .gitattributes ├── __init__.py ├── backbones │ ├── __init__.py │ ├── eqV2 │ │ ├── __init__.py │ │ └── model │ │ │ └── __init__.py │ ├── jmp │ │ ├── __init__.py │ │ └── model │ │ │ └── __init__.py │ ├── m3gnet │ │ ├── __init__.py │ │ └── model │ │ │ └── __init__.py │ ├── mace_foundation │ │ ├── __init__.py │ │ └── model │ │ │ └── __init__.py │ ├── mattersim │ │ ├── __init__.py │ │ └── model │ │ │ └── __init__.py │ └── orb │ │ ├── __init__.py │ │ └── model │ │ └── __init__.py ├── callbacks │ ├── __init__.py │ ├── early_stopping │ │ └── __init__.py │ ├── ema │ │ └── __init__.py │ └── model_checkpoint │ │ └── __init__.py ├── data │ ├── __init__.py │ ├── atoms_list │ │ └── __init__.py │ ├── base │ │ └── __init__.py │ ├── datamodule │ │ └── __init__.py │ ├── db │ │ └── __init__.py │ ├── json_data │ │ └── __init__.py │ ├── matbench │ │ └── __init__.py │ ├── mp │ │ └── __init__.py │ ├── mptraj │ │ └── __init__.py │ ├── omat24 │ │ └── __init__.py │ └── xyz │ │ └── __init__.py ├── finetune │ ├── __init__.py │ ├── base │ │ └── __init__.py │ ├── loss │ │ └── __init__.py │ ├── lr_scheduler │ │ └── __init__.py │ ├── optimizer │ │ └── __init__.py │ └── properties │ │ └── __init__.py ├── loggers │ └── __init__.py ├── main │ └── __init__.py ├── normalization │ └── __init__.py ├── recipes │ ├── __init__.py │ ├── base │ │ └── __init__.py │ ├── ema │ │ └── __init__.py │ ├── lora │ │ └── __init__.py │ └── noop │ │ └── __init__.py ├── registry │ └── __init__.py └── wrappers │ ├── __init__.py │ └── 
property_predictor │ └── __init__.py ├── data ├── __init__.py ├── atoms_list.py ├── base.py ├── datamodule.py ├── db.py ├── json_data.py ├── matbench.py ├── mp.py ├── mptraj.py ├── omat24.py ├── util │ └── split_dataset.py └── xyz.py ├── finetune ├── __init__.py ├── base.py ├── callbacks │ └── freeze_backbone.py ├── data_util.py ├── loader.py ├── loss.py ├── lr_scheduler.py ├── metrics.py ├── optimizer.py └── properties.py ├── loggers.py ├── main.py ├── normalization.py ├── recipes ├── __init__.py ├── base.py ├── ema.py ├── lora.py └── noop.py ├── registry.py ├── util.py └── wrappers ├── ase_calculator.py └── property_predictor.py /.gitattributes: -------------------------------------------------------------------------------- 1 | .nshconfig.generated.json linguist-generated=true 2 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Deploy Documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main # or your default branch 7 | paths: 8 | - "docs/**" 9 | - ".github/workflows/docs.yml" 10 | 11 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 12 | permissions: 13 | contents: read 14 | pages: write 15 | id-token: write 16 | 17 | # Allow only one concurrent deployment 18 | concurrency: 19 | group: "pages" 20 | cancel-in-progress: true 21 | 22 | jobs: 23 | build: 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/checkout@v4 27 | 28 | - name: Set up Python 29 | uses: actions/setup-python@v5 30 | with: 31 | python-version: "3.10" 32 | cache: "pip" 33 | 34 | - name: Install dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | pip install -r docs/requirements-torch.txt 38 | pip install -r docs/requirements.txt 39 | pip install --no-deps -e . 
40 | 41 | - name: Build documentation 42 | run: | 43 | cd docs 44 | # Create _static directory 45 | mkdir -p _static 46 | # Create _autosummary directory 47 | mkdir -p _autosummary 48 | make html 49 | # Create .nojekyll file to allow files and folders starting with an underscore 50 | touch _build/html/.nojekyll 51 | 52 | - name: Upload artifact 53 | uses: actions/upload-pages-artifact@v3 54 | with: 55 | path: docs/_build/html 56 | 57 | deploy: 58 | environment: 59 | name: github-pages 60 | url: ${{ steps.deployment.outputs.page_url }} 61 | runs-on: ubuntu-latest 62 | needs: build 63 | if: github.ref == 'refs/heads/main' # only deploy from main branch 64 | steps: 65 | - name: Deploy to GitHub Pages 66 | id: deployment 67 | uses: actions/deploy-pages@v4 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig 2 | # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,linux,python 4 | 5 | ### Linux ### 6 | *~ 7 | 8 | # temporary files which can be created if a process still has a handle open of a deleted file 9 | .fuse_hidden* 10 | 11 | # KDE directory preferences 12 | .directory 13 | 14 | # Linux trash folder which might appear on any partition or disk 15 | .Trash-* 16 | 17 | # .nfs files are created when an open file is removed but is still being accessed 18 | .nfs* 19 | 20 | ### Python ### 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | share/python-wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | *.py,cover 70 | .hypothesis/ 71 | .pytest_cache/ 72 | cover/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | db.sqlite3-journal 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | 94 | # PyBuilder 95 | .pybuilder/ 96 | target/ 97 | 98 | # Jupyter Notebook 99 | .ipynb_checkpoints 100 | 101 | # IPython 102 | profile_default/ 103 | ipython_config.py 104 | 105 | # pyenv 106 | # For a library or package, you might want to ignore these files since the code is 107 | # intended to run in multiple environments; otherwise, check them in: 108 | # .python-version 109 | 110 | # pipenv 111 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
112 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 113 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 114 | # install all needed dependencies. 115 | #Pipfile.lock 116 | 117 | # poetry 118 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 119 | # This is especially recommended for binary packages to ensure reproducibility, and is more 120 | # commonly ignored for libraries. 121 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 122 | #poetry.lock 123 | 124 | # pdm 125 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 126 | #pdm.lock 127 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 128 | # in version control. 129 | # https://pdm.fming.dev/#use-with-ide 130 | .pdm.toml 131 | 132 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 133 | __pypackages__/ 134 | 135 | # Celery stuff 136 | celerybeat-schedule 137 | celerybeat.pid 138 | 139 | # SageMath parsed files 140 | *.sage.py 141 | 142 | # Environments 143 | .env 144 | .venv 145 | env/ 146 | venv/ 147 | ENV/ 148 | env.bak/ 149 | venv.bak/ 150 | 151 | # Spyder project settings 152 | .spyderproject 153 | .spyproject 154 | 155 | # Rope project settings 156 | .ropeproject 157 | 158 | # mkdocs documentation 159 | /site 160 | 161 | # mypy 162 | .mypy_cache/ 163 | .dmypy.json 164 | dmypy.json 165 | 166 | # Pyre type checker 167 | .pyre/ 168 | 169 | # pytype static type analyzer 170 | .pytype/ 171 | 172 | # Cython debug symbols 173 | cython_debug/ 174 | 175 | # PyCharm 176 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 177 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 178 | # and can be added to the global gitignore or merged into this file. For a more nuclear 179 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
180 | #.idea/ 181 | 182 | ### Python Patch ### 183 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 184 | poetry.toml 185 | 186 | # ruff 187 | .ruff_cache/ 188 | 189 | # LSP config files 190 | pyrightconfig.json 191 | 192 | ### VisualStudioCode ### 193 | .vscode/* 194 | !.vscode/settings.json 195 | !.vscode/tasks.json 196 | !.vscode/launch.json 197 | !.vscode/extensions.json 198 | !.vscode/*.code-snippets 199 | 200 | # Local History for Visual Studio Code 201 | .history/ 202 | 203 | # Built Visual Studio Code Extensions 204 | *.vsix 205 | 206 | ### VisualStudioCode Patch ### 207 | # Ignore all local history of files 208 | .history 209 | .ionide 210 | 211 | # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python 212 | 213 | # Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option) 214 | 215 | lightning_logs/ 216 | nshtrainer/ 217 | wandb/ 218 | abandoned/ 219 | .ruff_cache/ 220 | examples/hidden/ 221 | checkpoints*/ 222 | md_results/ 223 | results_backup/ 224 | ZnMn2O4_*/ 225 | examples/matbench/data/ 226 | bfg.jar 227 | .dir2textignore 228 | docs/_autosummary/ 229 | contents/ 230 | results/ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.8.1 4 | hooks: 5 | - id: ruff 6 | args: [--fix, --extend-select, I] # Fix import sorting 7 | - id: ruff-format # Format code 8 | - id: ruff 9 | name: ruff-check 10 | args: [--no-fix] # Just check for issues 11 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MatterTune 2 | 3 | First off, thank you for considering contributing to MatterTune! We want to make contributing to MatterTune as easy and transparent as possible, whether it's: 4 | 5 | - Reporting a bug 6 | - Discussing the current state of the code 7 | - Submitting a fix 8 | - Proposing new features 9 | - Becoming a maintainer 10 | 11 | ## Development Process 12 | We use GitHub to host code, to track issues and feature requests, as well as accept pull requests. 13 | 14 | 1. Fork the repo and create your branch from `main` 15 | 2. If you've added code that should be tested, add tests 16 | 3. If you've changed APIs, update the documentation 17 | 4. Ensure the test suite passes 18 | 5. Make sure your code lints 19 | 6. Issue that pull request! 20 | 21 | ## Development Setup 22 | 23 | 1. Clone your fork of the repo: 24 | ```bash 25 | git clone https://github.com/Fung-Lab/mattertune.git 26 | ``` 27 | 28 | 2. Install development dependencies: 29 | ```bash 30 | cd mattertune 31 | pip install -e 32 | ``` 33 | 34 | 3. Install pre-commit hooks: 35 | ```bash 36 | pre-commit install 37 | ``` 38 | 39 | ## Code Style 40 | 41 | - We use [ruff](https://github.com/charliermarsh/ruff) for formatting, linting, and import sorting 42 | - We use [pyright](https://github.com/microsoft/pyright) for type checking 43 | 44 | Our pre-commit hooks will automatically format your code when you commit. To run formatting manually: 45 | 46 | ```bash 47 | # Format code + imports 48 | ruff check --select I --fix && ruff check --fix && ruff format 49 | 50 | # Run linting 51 | ruff check . 
52 | 53 | # Run type checking 54 | pyright 55 | ``` 56 | 57 | ## Pull Request Process 58 | 59 | 1. Update the README.md with details of changes to the interface, if applicable 60 | 2. Update the documentation with details of any new functionality 61 | 3. Add or update tests as appropriate 62 | 4. Use clear, descriptive commit messages 63 | 5. The PR should be reviewed by at least one maintainer 64 | 6. Update the CHANGELOG.md with a note describing your changes 65 | 66 | ## Testing 67 | 68 | We use pytest for testing. To run tests: 69 | 70 | ```bash 71 | pytest 72 | ``` 73 | 74 | For coverage report: 75 | 76 | ```bash 77 | pytest --cov=mattertune 78 | ``` 79 | 80 | ## Reporting Bugs 81 | 82 | We use GitHub issues to track public bugs. Report a bug by [opening a new issue](). 83 | 84 | **Great Bug Reports** tend to have: 85 | 86 | - A quick summary and/or background 87 | - Steps to reproduce 88 | - Be specific! 89 | - Give sample code if you can 90 | - What you expected would happen 91 | - What actually happens 92 | - Notes (possibly including why you think this might be happening, or stuff you tried that didn't work) 93 | 94 | ## Feature Requests 95 | 96 | We love feature requests! To submit a feature request: 97 | 98 | 1. Check if the feature has already been requested in the issues 99 | 2. If not, create a new issue with the label "enhancement" 100 | 3. Include: 101 | - Clear description of the feature 102 | - Rationale for the feature 103 | - Example use cases 104 | - Potential implementation approach (optional) 105 | 106 | ## Documentation 107 | 108 | Documentation improvements are always welcome. Our docs are in the `docs/` folder and use Markdown format. 109 | 110 | ## License 111 | 112 | By contributing, you agree that your contributions will be licensed under the MIT License that covers the project. Feel free to contact the maintainers if that's a concern. 113 | 114 | ## Working with Model Backbones 115 | 116 | When working with model backbones, please note: 117 | 118 | - JMP (CC BY-NC 4.0 License): Non-commercial use only 119 | - EquiformerV2 (Meta Research License): Follow Meta's Acceptable Use Policy 120 | - M3GNet (BSD 3-Clause): Include required notices 121 | - ORB (Apache 2.0): Include required notices and attribution 122 | 123 | ## Questions? 124 | 125 | Don't hesitate to ask questions about how to contribute. You can: 126 | 127 | 1. Open an issue with your question 128 | 2. Tag your issue with "question" 129 | 3. We'll get back to you as soon as we can 130 | 131 | ## Attribution and References 132 | 133 | When adding new features or modifying existing ones, please add appropriate references to papers, repositories, or other sources that informed your implementation. 134 | 135 | Thank you for contributing to MatterTune! 
🎉 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Fung Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MatterTune: A Unified Platform for Atomistic Foundation Model Fine-Tuning 2 | 3 | [![Documentation Status](https://github.com/Fung-Lab/MatterTune/actions/workflows/docs.yml/badge.svg)](https://fung-lab.github.io/MatterTune/) 4 | 5 | **[📚 Documentation](https://fung-lab.github.io/MatterTune/) | [🔧 Installation Guide](https://fung-lab.github.io/MatterTune/installation.html)** 6 | 7 | MatterTune is a flexible and powerful machine learning library designed specifically for fine-tuning state-of-the-art chemistry foundation models. It provides intuitive interfaces for computational chemists and materials scientists to fine-tune pre-trained models on their specific use cases. 8 | 9 | ## Features 10 | 11 | - Pre-trained model support: **JMP**, **EquiformerV2**, **MatterSim**, **ORB**, **MACE**, and more to be added. 12 | - Multiple property predictions: energy, forces, stress, and custom properties. 13 | - Various supported dataset formats: XYZ, ASE databases, Materials Project, Matbench, and more. 14 | - Comprehensive training features with automated data splitting and logging. 
15 | 16 | ## Quick Start 17 | 18 | ```python 19 | import mattertune as mt 20 | import mattertune.configs as MC 21 | from pathlib import Path 22 | # Phase 1: Fine-tuning the model 23 | # ----------------------------- 24 | 25 | # Define the configuration for model, data, and training 26 | config = mt.configs.MatterTunerConfig( 27 | # Configure the model: using JMP backbone with energy prediction 28 | model=MC.JMPBackboneConfig( 29 | pretrained_model = "jmp-s", # Select pretrained model type 30 | graph_computer = MC.JMPGraphComputerConfig(pbc=True), 31 | properties = [ 32 | MC.EnergyPropertyConfig( # Configure energy prediction 33 | loss=MC.MAELossConfig(), # Using MAE loss 34 | loss_coefficient=1.0 # Weight for this property's loss 35 | ), 36 | MC.ForcesPropertyConfig( 37 | loss=MC.MSELossConfig(), 38 | conservative=False, 39 | loss_coefficient=1.0 40 | ) 41 | ], 42 | optimizer = MC.AdamWConfig(lr=8.0e-5) 43 | ), 44 | # Configure the data: loading from XYZ file with automatic train/val split 45 | data=MC.AutoSplitDataModuleConfig( 46 | dataset=MC.XYZDatasetConfig( 47 | src=Path("YOUR_XYZFILE_PATH") # Path to your XYZ data 48 | ), 49 | train_split=0.8, # Use 80% of data for training 50 | batch_size=32 # Process 32 structures per batch 51 | ), 52 | # Configure the training process 53 | trainer=MC.TrainerConfig( 54 | max_epochs=10, # Train for 10 epochs 55 | accelerator="gpu", # Use GPU for training 56 | devices=[0], # Use first GPU 57 | additional_trainer_kwargs={ 58 | "inference_mode": False, 59 | } 60 | ), 61 | ) 62 | 63 | # Create tuner and start training 64 | tuner = mt.MatterTune(config) 65 | model, trainer = tuner.tune() 66 | 67 | # Save the fine-tuned model 68 | trainer.save_checkpoint("finetuned_model.ckpt") 69 | 70 | # Phase 2: Using the fine-tuned model 71 | # ---------------------------------- 72 | 73 | from ase.optimize import BFGS 74 | from ase import Atoms 75 | 76 | # Load the fine-tuned model 77 | model = mt.backbones.JMPBackboneModule.load_from_checkpoint("finetuned_model.ckpt") 78 | 79 | # Create an ASE calculator from the model 80 | calculator = model.ase_calculator() 81 | 82 | # Set up an atomic structure 83 | atoms = Atoms('H2O', 84 | positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], 85 | cell=[10, 10, 10], 86 | pbc=True) 87 | atoms.calc = calculator 88 | 89 | # Run geometry optimization 90 | opt = BFGS(atoms) 91 | opt.run(fmax=0.01) 92 | 93 | # Get results 94 | print("Final energy:", atoms.get_potential_energy()) 95 | print("Final forces:", atoms.get_forces()) 96 | ``` 97 | 98 | ## FAQ 99 | 100 | We welcome anyone with questions or suggestions about the MatterTune project to open an issue in this repository’s Issues section. Before creating a new issue, please check the existing ones to see whether a solution has already been posted. 101 | 102 | ## License 103 | 104 | MatterTune's core framework is licensed under the MIT License. Note that each supported model backbone is subject to its own licensing terms - see the [license information page](https://fung-lab.github.io/MatterTune/license.html) of our documentation for more details. 105 | 106 | ## Citation 107 | 108 | Coming soon. 109 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | 3 | # You can set these variables from the command line, and also 4 | # from the environment for the first two. 5 | SPHINXOPTS ?= 6 | SPHINXBUILD ?= sphinx-build 7 | SOURCEDIR = .
8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Clean up the build directory and the autosummary directory 17 | clean: 18 | @$(SPHINXBUILD) -M clean "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 19 | rm -rf _autosummary/ 20 | 21 | 22 | # Catch-all target: route all unknown targets to Sphinx using the new 23 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 24 | %: Makefile 25 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 26 | -------------------------------------------------------------------------------- /docs/_templates/custom-class-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :members: 7 | :show-inheritance: 8 | :inherited-members: 9 | :special-members: __call__, __add__, __mul__ 10 | 11 | {% block methods %} 12 | {% if methods %} 13 | .. rubric:: {{ _('Methods') }} 14 | 15 | .. autosummary:: 16 | :nosignatures: 17 | {% for item in methods %} 18 | ~{{ name }}.{{ item }} 19 | {%- endfor %} 20 | {% endif %} 21 | {% endblock %} 22 | -------------------------------------------------------------------------------- /docs/_templates/custom-module-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | :members: 5 | :undoc-members: 6 | :show-inheritance: 7 | :variables: 8 | :module-variables: 9 | 10 | {% block attributes %} 11 | {% if attributes %} 12 | .. rubric:: Module Variables 13 | 14 | .. autosummary:: 15 | {% for item in attributes %} 16 | {{ item }} 17 | {%- endfor %} 18 | 19 | {% for item in attributes %} 20 | .. autodata:: {{ item }} 21 | :annotation: 22 | {%- endfor %} 23 | {% endif %} 24 | {% endblock %} 25 | 26 | {% block modules %} 27 | {% if modules %} 28 | .. rubric:: Submodules 29 | 30 | .. autosummary:: 31 | :toctree: 32 | :recursive: 33 | 34 | {% for item in modules %} 35 | {{ item }} 36 | {%- endfor %} 37 | {% endif %} 38 | {% endblock %} 39 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # API Reference 2 | 3 | ```{eval-rst} 4 | .. currentmodule:: mattertune 5 | 6 | .. autosummary:: 7 | :toctree: _autosummary 8 | :recursive: 9 | 10 | backbones 11 | callbacks 12 | configs 13 | data 14 | finetune 15 | loggers 16 | main 17 | normalization 18 | registry 19 | wrappers 20 | ``` 21 | -------------------------------------------------------------------------------- /docs/backbones/eqv2.md: -------------------------------------------------------------------------------- 1 | # EquiformerV2 Backbone 2 | 3 | The EquiformerV2 backbone implements Meta AI's EquiformerV2 model architecture in MatterTune. This is a state-of-the-art equivariant transformer model for molecular and materials property prediction, offering excellent performance across a wide range of chemical systems. 
4 | 5 | ## Installation 6 | 7 | Before using the EquiformerV2 backbone, you need to set up the required dependencies in a fresh conda environment: 8 | 9 | ```bash 10 | conda create -n eqv2-tune python=3.10 11 | conda activate eqv2-tune 12 | 13 | # Install fairchem core 14 | pip install "git+https://github.com/FAIR-Chem/fairchem.git@omat24#subdirectory=packages/fairchem-core" --no-deps 15 | 16 | # Install dependencies 17 | pip install ase "e3nn>=0.5" hydra-core lmdb numba "numpy>=1.26,<2.0" orjson \ 18 | "pymatgen>=2023.10.3" submitit tensorboard "torch>=2.4" wandb torch_geometric \ 19 | h5py netcdf4 opt-einsum spglib 20 | ``` 21 | 22 | ## Key Features 23 | 24 | - E(3)-equivariant transformer architecture for robust geometric predictions 25 | - Support for both molecular and periodic systems 26 | - Highly optimized implementation for efficient training and inference 27 | - Pre-trained models available from Meta AI's OMAT24 release 28 | - Support for property predictions: 29 | - Energy (extensive/intensive) 30 | - Forces (non-conservative) 31 | - Stresses (non-conservative) 32 | - System-level graph properties (with sum/mean reduction) 33 | 34 | ## Configuration 35 | 36 | Here's a complete example showing how to configure the EquiformerV2 backbone: 37 | 38 | ```python 39 | from mattertune import configs as MC 40 | from pathlib import Path 41 | 42 | config = MC.MatterTunerConfig( 43 | model=MC.EqV2BackboneConfig( 44 | # Required: Path to pre-trained checkpoint 45 | checkpoint_path="path/to/eqv2_checkpoint.pt", 46 | 47 | # Configure graph construction 48 | atoms_to_graph=MC.FAIRChemAtomsToGraphSystemConfig( 49 | cutoff=12.0, # Angstroms 50 | max_neighbors=50, 51 | pbc=True # Set False for molecules 52 | ), 53 | 54 | # Properties to predict 55 | properties=[ 56 | # Energy prediction 57 | MC.EnergyPropertyConfig( 58 | loss=MC.MAELossConfig(), 59 | loss_coefficient=1.0 60 | ), 61 | 62 | # Force prediction (non-conservative) 63 | MC.ForcesPropertyConfig( 64 | loss=MC.MAELossConfig(), 65 | loss_coefficient=10.0, 66 | conservative=False 67 | ), 68 | 69 | # Stress prediction (non-conservative) 70 | MC.StressesPropertyConfig( 71 | loss=MC.MAELossConfig(), 72 | loss_coefficient=1.0, 73 | conservative=False 74 | ), 75 | 76 | # System-level property prediction 77 | MC.GraphPropertyConfig( 78 | name="bandgap", 79 | loss=MC.MAELossConfig(), 80 | loss_coefficient=1.0, 81 | reduction="mean" # or "sum" 82 | ) 83 | ], 84 | 85 | # Optimizer settings 86 | optimizer=MC.AdamWConfig(lr=1e-4), 87 | 88 | # Optional: Learning rate scheduler 89 | lr_scheduler=MC.CosineAnnealingLRConfig( 90 | T_max=100, 91 | eta_min=1e-6 92 | ) 93 | ), 94 | 95 | # ... data and trainer configs ... 
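    # One possible way to fill in the data and trainer sections, shown here as a
    # hedged sketch that mirrors the README Quick Start (adjust the dataset path,
    # split, batch size, and epoch count to your own setup):
    # data=MC.AutoSplitDataModuleConfig(
    #     dataset=MC.XYZDatasetConfig(src=Path("YOUR_XYZFILE_PATH")),
    #     train_split=0.8,
    #     batch_size=32,
    # ),
    # trainer=MC.TrainerConfig(max_epochs=10, accelerator="gpu", devices=[0]),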
96 | ) 97 | ``` 98 | 99 | ## Property Support 100 | 101 | The EquiformerV2 backbone supports the following property predictions: 102 | 103 | ### Energy Prediction 104 | - Uses `EquiformerV2EnergyHead` for extensive energy predictions 105 | - Always uses "sum" reduction over atomic contributions 106 | 107 | ### Force Prediction 108 | - Uses `EquiformerV2ForceHead` for direct force prediction 109 | - Currently only supports non-conservative forces (energy-derived forces coming soon) 110 | 111 | ### Stress Prediction 112 | - Uses `Rank2SymmetricTensorHead` for stress tensor prediction 113 | - Currently only supports non-conservative stresses 114 | - Returns full 3x3 stress tensor 115 | 116 | ### Graph Properties 117 | - Uses `EquiformerV2EnergyHead` with configurable reduction 118 | - Supports "sum" or "mean" reduction over atomic features 119 | - Suitable for intensive properties like bandgap 120 | 121 | ## Limitations 122 | 123 | - Conservative forces and stresses not yet supported (coming in future release) 124 | - Graph construction parameters must be manually specified (automatic loading from checkpoint coming soon) 125 | - Per-species reference energy normalization not yet implemented 126 | 127 | ## Using Pre-trained Models 128 | 129 | The EquiformerV2 backbone supports loading pre-trained models from Meta AI's OMAT24 release. Here's how to use them: 130 | 131 | ```python 132 | config = C.MatterTunerConfig( 133 | model=C.EqV2BackboneConfig( 134 | checkpoint_path=C.CachedPath( 135 | uri='hf://fairchem/OMAT24/eqV2_31M_mp.pt' 136 | ), 137 | # ... rest of config ... 138 | ) 139 | ) 140 | ``` 141 | 142 | Please visit the [Meta AI OMAT24 model release page on Hugging Face](https://huggingface.co/fairchem/OMAT24) for more details. 143 | 144 | ## Examples & Notebooks 145 | 146 | A notebook tutorial about how to fine-tune and use EquiformerV2 model can be found in ```notebooks/eqv2-omat.ipynb```([link](https://github.com/Fung-Lab/MatterTune/blob/main/notebooks/eqv2-omat.ipynb)). 147 | 148 | Under ```matbench```([link](https://github.com/Fung-Lab/MatterTune/tree/main/examples/matbench)), we gave an advanced usage example fine-tuning EquiformerV2 on property prediction data and applying to property screening task. 149 | 150 | ## License 151 | 152 | The EquiformerV2 backbone is subject to Meta's Research License. Please ensure compliance with the license terms when using this backbone, especially for commercial applications. 153 | -------------------------------------------------------------------------------- /docs/backbones/jmp.md: -------------------------------------------------------------------------------- 1 | # JMP Backbone 2 | 3 | The JMP backbone implements the Joint Multi-domain Pre-training (JMP) framework in MatterTune. This is a high-performance model architecture that combines message passing neural networks with transformers for accurate prediction of molecular and materials properties. 4 | 5 | ## Installation 6 | 7 | Before using the JMP backbone, follow the installation instructions in the [jmp-backbone repository](https://github.com/nimashoghi/jmp-backbone/blob/lingyu-grad/README.md). 
8 | 9 | For development work with MatterTune, it's recommended to create a fresh conda environment: 10 | 11 | ```bash 12 | conda create -n jmp-tune python=3.10 13 | conda activate jmp-tune 14 | ``` 15 | 16 | ## Key Features 17 | 18 | - Hybrid architecture combining message passing networks with transformers 19 | - Supports both molecular and periodic systems with flexible boundary conditions 20 | - Highly optimized for both training and inference 21 | - Support for property predictions: 22 | - Energy (extensive/intensive) 23 | - Forces (both conservative and non-conservative) 24 | - Stresses (both conservative and non-conservative) 25 | - Graph-level properties with customizable reduction 26 | 27 | ## Configuration 28 | 29 | Here's a complete example showing how to configure the JMP backbone: 30 | 31 | ```python 32 | from mattertune import configs as MC 33 | from pathlib import Path 34 | 35 | config = MC.MatterTunerConfig( 36 | model=MC.JMPBackboneConfig( 37 | # Required: Path to pre-trained checkpoint 38 | ckpt_path="path/to/jmp_checkpoint.pt", 39 | 40 | # Graph construction settings 41 | graph_computer=MC.JMPGraphComputerConfig( 42 | pbc=True, # Set False for molecules 43 | ), 44 | 45 | # Properties to predict 46 | properties=[ 47 | # Energy prediction 48 | MC.EnergyPropertyConfig( 49 | loss=MC.MAELossConfig(), 50 | loss_coefficient=1.0 51 | ), 52 | 53 | # Force prediction (conservative) 54 | MC.ForcesPropertyConfig( 55 | loss=MC.MAELossConfig(), 56 | loss_coefficient=10.0, 57 | conservative=True 58 | ), 59 | 60 | # Stress prediction (conservative) 61 | MC.StressesPropertyConfig( 62 | loss=MC.MAELossConfig(), 63 | loss_coefficient=1.0, 64 | conservative=True 65 | ), 66 | 67 | # System-level property prediction 68 | MC.GraphPropertyConfig( 69 | name="bandgap", 70 | loss=MC.MAELossConfig(), 71 | loss_coefficient=1.0, 72 | reduction="mean" # or "sum" 73 | ) 74 | ], 75 | 76 | # Optimizer settings 77 | optimizer=MC.AdamWConfig(lr=1e-4), 78 | 79 | # Optional: Learning rate scheduler 80 | lr_scheduler=MC.CosineAnnealingLRConfig( 81 | T_max=100, 82 | eta_min=1e-6 83 | ) 84 | ) 85 | ) 86 | ``` 87 | 88 | ## Property Support 89 | 90 | The JMP backbone supports the following property predictions: 91 | 92 | ### Energy Prediction 93 | - Full support for extensive energy predictions 94 | - Automated per-atom energy normalization 95 | - Optional atomic reference energy subtraction 96 | 97 | ### Force Prediction 98 | - Supports both conservative (energy-derived) and direct force prediction 99 | - Configurable force scaling during training 100 | - Automatic handling of periodic boundary conditions 101 | 102 | ### Stress Prediction 103 | - Full support for stress tensor prediction 104 | - Conservative (energy-derived) or direct stress computation 105 | - Returns full 3x3 stress tensor with proper PBC handling 106 | 107 | ### Graph Properties 108 | - Support for system-level property prediction 109 | - Configurable "sum" or "mean" reduction over atomic features 110 | - Suitable for both extensive and intensive properties 111 | 112 | ## Graph Construction Parameters 113 | 114 | The JMP backbone uses a sophisticated multi-scale graph construction approach with several key parameters: 115 | 116 | - `cutoffs`: Distance cutoffs for different interaction types 117 | - `main`: Primary interaction cutoff (typically 12.0 Å) 118 | - `aeaint`: Atomic energy interaction cutoff 119 | - `qint`: Charge interaction cutoff 120 | - `aint`: Auxiliary interaction cutoff 121 | 122 | - `max_neighbors`: Maximum number of 
neighbors per interaction type 123 | - `main`: Primary interaction neighbors (typically 30) 124 | - `aeaint`: Atomic energy interaction neighbors 125 | - `qint`: Charge interaction neighbors 126 | - `aint`: Auxiliary interaction neighbors 127 | 128 | ## Examples & Notebooks 129 | 130 | A notebook tutorial about how to fine-tune and use JMP model can be found in ```notebooks/jmp-omat.ipynb```([link](https://github.com/Fung-Lab/MatterTune/blob/main/notebooks/jmp-omat.ipynb)) and ```notebooks/jmp-omat-autosplit.ipynb```([link](https://github.com/Fung-Lab/MatterTune/blob/main/notebooks/jmp-omat-autosplit.ipynb)). 131 | 132 | Under ```matbench```([link](https://github.com/Fung-Lab/MatterTune/tree/main/examples/matbench)), we gave an advanced usage example fine-tuning JMP on property prediction data and applying to property screening task. 133 | 134 | Under ```structure-optimization```([link](https://github.com/Fung-Lab/MatterTune/blob/main/examples/structure-optimization)), we fine-tuned JMP on a subset of MPTraj and applied it to perform structure relaxation. 135 | 136 | ## License 137 | 138 | The JMP backbone is available under the CC BY-NC 4.0 License, which means: 139 | - Free for academic and non-commercial use 140 | - Required attribution when using or modifying the code 141 | - Commercial use requires separate licensing 142 | 143 | Please ensure compliance with the license terms before using this backbone in your projects. 144 | -------------------------------------------------------------------------------- /docs/backbones/m3gnet.md: -------------------------------------------------------------------------------- 1 | # M3GNet Backbone 2 | 3 | The M3GNet backbone implements the M3GNet model architecture in MatterTune. It provides a powerful graph neural network designed specifically for materials science applications. In MatterTune, we chose the M3GNet model implemented by MatGL and pretrained on MPTraj dataset. 4 | 5 | ## Installation 6 | 7 | ```bash 8 | conda create -n matgl-tune python=3.10 -y 9 | pip install matgl 10 | pip install torch==2.2.1+cu121 -f https://download.pytorch.org/whl/torch_stable.html 11 | pip uninstall dgl 12 | pip install dgl -f https://data.dgl.ai/wheels/torch-2.2/cu121/repo.html 13 | pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html 14 | ``` 15 | 16 | ## Key Features 17 | 18 | M3GNet supports predicting: 19 | - Total energy (with energy conservation) 20 | - Atomic forces (derived from energy) 21 | - Stress tensors (derived from energy) 22 | 23 | The model uses a three-body graph neural network architecture that captures both two-body and three-body interactions between atoms. 24 | 25 | ## Configuration 26 | 27 | Configure the M3GNet backbone using: 28 | 29 | ```python 30 | model = mt.configs.M3GNetBackboneConfig( 31 | # Path to pretrained checkpoint 32 | ckpt_path="path/to/checkpoint", 33 | 34 | # Graph computer settings 35 | graph_computer=mt.configs.M3GNetGraphComputerConfig( 36 | # Cutoff distance for neighbor list. If None, the cutoff is loaded from the checkpoint. 37 | cutoff=5.0, 38 | 39 | # Cutoff for three-body interactions. If None, the cutoff is loaded from the checkpoint. 
40 | threebody_cutoff=4.0, 41 | 42 | # Whether to precompute line graphs 43 | pre_compute_line_graph=False 44 | ), 45 | 46 | # Properties to predict 47 | properties=[ 48 | mt.configs.EnergyPropertyConfig( 49 | loss=mt.configs.MAELossConfig(), 50 | loss_coefficient=1.0 51 | ), 52 | mt.configs.ForcesPropertyConfig( 53 | loss=mt.configs.MAELossConfig(), 54 | loss_coefficient=0.1, 55 | conservative=True # Forces derived from energy 56 | ) 57 | ], 58 | 59 | # Training settings 60 | optimizer=mt.configs.AdamConfig(lr=1e-4) 61 | ) 62 | ``` 63 | 64 | ### Key Parameters 65 | 66 | - `ckpt_path`: Path to pretrained model checkpoint 67 | - `graph_computer`: Controls graph construction: 68 | - `element_types`: Elements to include (defaults to all) 69 | - `cutoff`: Distance cutoff for neighbor list 70 | - `threebody_cutoff`: Cutoff for three-body interactions 71 | - `pre_compute_line_graph`: Whether to precompute line graphs 72 | - `properties`: List of properties to predict 73 | - `optimizer`: Optimizer configuration 74 | 75 | ## Implementation Details 76 | 77 | The backbone is implemented in `M3GNetBackboneModule` which: 78 | 79 | 1. Loads the pretrained model using MatGL 80 | 2. Constructs atomic graphs with both two-body and three-body interactions 81 | 3. Handles property prediction with energy conservation 82 | 4. Manages normalization of inputs/outputs 83 | 84 | Key features: 85 | - Energy-conserving force prediction 86 | - Three-body interactions for improved accuracy 87 | - Efficient graph construction 88 | - Support for periodic boundary conditions 89 | 90 | ## Examples & Notebooks 91 | 92 | A notebook tutorial about how to fine-tune M3GNet model can be found in ```notebooks/m3gnet-waterthermo.ipynb```([link](https://github.com/Fung-Lab/MatterTune/blob/main/notebooks/m3gnet-waterthermo.ipynb)). 93 | 94 | For advanced usage regarding fine-tuning models and applying them to downstream tasks (MD simulation for example), please refer to ```water-thermodynamics```([link](https://github.com/Fung-Lab/MatterTune/tree/main/examples/water-thermodynamics)) 95 | 96 | ## License 97 | 98 | We used the M3GNet model implemented in MatGL package, which is available under BSD 3-Clause License, which means redistribution and use in source and binary forms, with or without 99 | modification, are permitted provided that the following conditions are met: 100 | 101 | 1. Redistributions of source code must retain the above copyright notice, this 102 | list of conditions and the following disclaimer. 103 | 104 | 2. Redistributions in binary form must reproduce the above copyright notice, 105 | this list of conditions and the following disclaimer in the documentation 106 | and/or other materials provided with the distribution. 107 | 108 | 3. Neither the name of the copyright holder nor the names of its 109 | contributors may be used to endorse or promote products derived from 110 | this software without specific prior written permission. 111 | -------------------------------------------------------------------------------- /docs/backbones/mace.md: -------------------------------------------------------------------------------- 1 | # MACE Backbone 2 | 3 | MACE is a series of fast and accurate machine learning interatomic potentials with higher order equivariant message passing developed by Ilyes Batatia, Gregor Simm, David Kovacs, and the group of Gabor Csanyi in University of Cambridge. 
The MACE series released its first foundation model, [MACE-MP-0](https://arxiv.org/abs/2401.00096), in 2023, making it one of the earliest foundation models in the materials domain. 4 | 5 | ## Installation 6 | 7 | MACE can be directly installed with pip: 8 | 9 | ```bash 10 | pip install --upgrade pip 11 | pip install mace-torch 12 | ``` 13 | 14 | or it can be installed from source code: 15 | 16 | ```bash 17 | git clone https://github.com/ACEsuit/mace.git 18 | pip install ./mace 19 | ``` 20 | 21 | ## Key Features 22 | 23 | MACE adopts an equivariant neural-network paradigm and delivers energy-conserving predictions of forces and stresses. For details on the specific features of each MACE version, please consult the introduction here: [MACE versions](https://github.com/ACEsuit/mace-foundations) 24 | 25 | ## License 26 | 27 | The MACE backbone is available under the MIT License. -------------------------------------------------------------------------------- /docs/backbones/mattersim.md: -------------------------------------------------------------------------------- 1 | # MatterSim Backbone 2 | 3 | > Note: As of the latest MatterTune update, MatterSim has only released the M3GNet model. 4 | 5 | The MatterSim backbone integrates the MatterSim model architecture into MatterTune. MatterSim is a foundational atomistic model designed to simulate materials properties across a wide range of elements, temperatures and pressures. 6 | 7 | ## Installation 8 | 9 | We strongly recommend installing MatterSim from source code: 10 | 11 | ```bash 12 | git clone git@github.com:microsoft/mattersim.git 13 | cd mattersim 14 | ``` 15 | 16 | Find line 41 of the pyproject.toml in MatterSim, which is ```"pydantic==2.9.2",```, and change it to ```"pydantic>=2.9.2",```. After making this modification, install MatterSim by running: 17 | 18 | ```bash 19 | mamba env create -f environment.yaml 20 | mamba activate mattersim 21 | uv pip install -e . 22 | python setup.py build_ext --inplace 23 | ``` 24 | 25 | ## Key Features 26 | 27 | - Pretrained on materials data across a wide range of elements, temperatures and pressures. 28 | - Flexible model architecture selection 29 | - MatterSim-v1.0.0-1M: A mini version of the M3GNet that is faster to run. 30 | - MatterSim-v1.0.0-5M: A larger version of the M3GNet that is more accurate. 31 | - TO BE RELEASED: Graphormer model with even larger parameter scale 32 | - Support for property predictions: 33 | - Energy (extensive/intensive) 34 | - Forces (conservative for M3GNet and non-conservative for Graphormer) 35 | - Stresses (conservative for M3GNet and non-conservative for Graphormer) 36 | - Graph-level properties (available on Graphormer) 37 | 38 | ## Configuration 39 | 40 | Here's a complete example showing how to configure the MatterSim backbone: 41 | 42 | ```python 43 | from mattertune import configs as MC 44 | from pathlib import Path 45 | 46 | config = MC.MatterTunerConfig( 47 | model=MC.MatterSimBackboneConfig( 48 | # Required: Name of the pre-trained MatterSim model 49 | pretrained_model="MatterSim-v1.0.0-5M", 50 | 51 | # Graph construction settings 52 | graph_convertor=MC.MatterSimGraphConvertorConfig( 53 | twobody_cutoff = 5.0, ## The cutoff distance for the two-body interactions. 54 | has_threebody = True, ## Whether to include three-body interactions.
55 | threebody_cutoff = 4.0 ## The cutoff distance for the three-body interactions. 56 | ), 57 | 58 | # Properties to predict 59 | properties=[ 60 | # Energy prediction 61 | MC.EnergyPropertyConfig( 62 | loss=MC.MAELossConfig(), 63 | loss_coefficient=1.0 64 | ), 65 | 66 | # Force prediction (conservative) 67 | MC.ForcesPropertyConfig( 68 | loss=MC.MAELossConfig(), 69 | loss_coefficient=10.0, 70 | conservative=True 71 | ), 72 | 73 | # Stress prediction (conservative) 74 | MC.StressesPropertyConfig( 75 | loss=MC.MAELossConfig(), 76 | loss_coefficient=1.0, 77 | conservative=True 78 | ), 79 | ], 80 | 81 | # Optimizer settings 82 | optimizer=MC.AdamWConfig(lr=1e-4), 83 | 84 | # Optional: Learning rate scheduler 85 | lr_scheduler=MC.CosineAnnealingLRConfig( 86 | T_max=100, 87 | eta_min=1e-6 88 | ) 89 | ) 90 | ) 91 | ``` 92 | 93 | ## Examples & Notebooks 94 | 95 | A notebook tutorial about how to fine-tune and use the MatterSim model can be found in ```notebooks/mattersim-waterthermo.ipynb```([link](https://github.com/Fung-Lab/MatterTune/blob/main/notebooks/mattersim-waterthermo.ipynb)). 96 | 97 | Under ```water-thermodynamics```([link](https://github.com/Fung-Lab/MatterTune/tree/main/examples/water-thermodynamics)), we give an advanced usage example of fine-tuning MatterSim on PES data and applying it to MD simulation. 98 | 99 | ## License 100 | 101 | The MatterSim backbone is available under the MIT License. -------------------------------------------------------------------------------- /docs/backbones/orb.md: -------------------------------------------------------------------------------- 1 | # ORB Backbone 2 | 3 | The ORB backbone implements the Orbital Neural Networks model architecture in MatterTune. This is a state-of-the-art graph neural network designed specifically for molecular and materials property prediction, with excellent performance across diverse chemical systems.
4 | 5 | ## Installation 6 | 7 | Before using the ORB backbone, you need to install the required dependencies: 8 | 9 | ```bash 10 | pip install "orb_models@git+https://github.com/nimashoghi/orb-models.git" 11 | ``` 12 | 13 | ## Key Features 14 | 15 | - Advanced graph neural network architecture optimized for materials 16 | - Support for both molecular and periodic systems 17 | - Highly efficient implementation for fast training and inference 18 | - Pre-trained models available from the orb-models package 19 | - Support for property predictions: 20 | - Energy (extensive/intensive) 21 | - Forces (non-conservative) 22 | - Stresses (non-conservative) 23 | - System-level graph properties (with configurable reduction) 24 | 25 | ## Configuration 26 | 27 | Here's a complete example showing how to configure the ORB backbone: 28 | 29 | ```python 30 | from mattertune import configs as MC 31 | from pathlib import Path 32 | 33 | config = MC.MatterTunerConfig( 34 | model=MC.ORBBackboneConfig( 35 | # Required: Name of pre-trained model 36 | pretrained_model="orb-v2", 37 | 38 | # Configure graph construction 39 | system=MC.ORBSystemConfig( 40 | radius=10.0, # Angstroms 41 | max_num_neighbors=20 42 | ), 43 | 44 | # Properties to predict 45 | properties=[ 46 | # Energy prediction 47 | MC.EnergyPropertyConfig( 48 | loss=MC.MAELossConfig(), 49 | loss_coefficient=1.0 50 | ), 51 | 52 | # Force prediction (non-conservative) 53 | MC.ForcesPropertyConfig( 54 | loss=MC.MAELossConfig(), 55 | loss_coefficient=10.0, 56 | conservative=False 57 | ), 58 | 59 | # Stress prediction (non-conservative) 60 | MC.StressesPropertyConfig( 61 | loss=MC.MAELossConfig(), 62 | loss_coefficient=1.0, 63 | conservative=False 64 | ), 65 | 66 | # System-level property prediction 67 | MC.GraphPropertyConfig( 68 | name="bandgap", 69 | loss=MC.MAELossConfig(), 70 | loss_coefficient=1.0, 71 | reduction="mean" # or "sum" 72 | ) 73 | ], 74 | 75 | # Optimizer settings 76 | optimizer=MC.AdamWConfig(lr=1e-4), 77 | 78 | # Optional: Learning rate scheduler 79 | lr_scheduler=MC.CosineAnnealingLRConfig( 80 | T_max=100, 81 | eta_min=1e-6 82 | ) 83 | ), 84 | 85 | # ... data and trainer configs ... 
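    # A minimal sketch of the omitted data and trainer settings, following the same
    # pattern as the README Quick Start (the dataset path and hyperparameters are
    # placeholders to adapt to your own data):
    # data=MC.AutoSplitDataModuleConfig(
    #     dataset=MC.XYZDatasetConfig(src=Path("YOUR_XYZFILE_PATH")),
    #     train_split=0.8,
    #     batch_size=32,
    # ),
    # trainer=MC.TrainerConfig(max_epochs=10, accelerator="gpu", devices=[0]),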
86 | ) 87 | ``` 88 | 89 | ## Property Support 90 | 91 | The ORB backbone supports the following property predictions: 92 | 93 | ### Energy Prediction 94 | - Uses `EnergyHead` for extensive energy predictions 95 | - Supports automated per-atom energy normalization 96 | - Optional atomic reference energy subtraction 97 | 98 | ### Force Prediction 99 | - Uses `NodeHead` for direct force prediction 100 | - Currently only supports non-conservative forces 101 | - Configurable force scaling during training 102 | 103 | ### Stress Prediction 104 | - Uses `GraphHead` for stress tensor prediction 105 | - Currently only supports non-conservative stresses 106 | - Returns full 3x3 stress tensor 107 | 108 | ### Graph Properties 109 | - Uses `GraphHead` with configurable reduction 110 | - Supports "sum" or "mean" reduction over atomic features 111 | - Suitable for both extensive and intensive properties 112 | 113 | ## Graph Construction Parameters 114 | 115 | The ORB backbone uses a sophisticated graph construction approach with two key parameters: 116 | 117 | - `radius`: The cutoff distance for including neighbors in the graph (typically 10.0 Å) 118 | - `max_num_neighbors`: Maximum number of neighbors per atom to include (typically 20) 119 | 120 | ## Limitations 121 | 122 | - Conservative forces and stresses not supported 123 | - Limited to fixed graph construction parameters 124 | - No direct support for charge predictions 125 | - Reference energy normalization requires manual configuration 126 | 127 | ## Using Pre-trained Models 128 | 129 | The ORB backbone supports loading pre-trained models from the orb-models package. Available models include: 130 | 131 | - `orb-v2`: General-purpose model trained on materials data 132 | - `orb-qm9`: Model specialized for molecular systems 133 | - `orb-mp`: Model specialized for crystalline materials 134 | 135 | ```python 136 | config = MC.MatterTunerConfig( 137 | model=MC.ORBBackboneConfig( 138 | pretrained_model="orb-v2", 139 | # ... rest of config ... 140 | ) 141 | ) 142 | ``` 143 | 144 | ## Examples & Notebooks 145 | 146 | A notebook tutorial about how to fine-tune and use ORB model can be found in ```notebooks/orb-omat.ipynb```([link](https://github.com/Fung-Lab/MatterTune/blob/main/notebooks/orb-omat.ipynb)). 147 | 148 | Under ```matbench```([link](https://github.com/Fung-Lab/MatterTune/tree/main/examples/matbench)), we gave an advanced usage example fine-tuning ORB on property prediction data and applying to property screening task. 149 | 150 | ## License 151 | 152 | The ORB backbone is available under the Apache 2.0 License, which allows both academic and commercial use with proper attribution. 
153 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import sys 5 | 6 | sys.path.insert(0, os.path.abspath("..")) 7 | 8 | project = "MatterTune" 9 | copyright = "2024, MatterTune Team" 10 | author = "MatterTune Team" 11 | 12 | extensions = [ 13 | "sphinx.ext.autodoc", 14 | "sphinx.ext.napoleon", 15 | "sphinx.ext.viewcode", 16 | "sphinx.ext.githubpages", 17 | "myst_parser", 18 | "sphinx_copybutton", 19 | "sphinx.ext.autosummary", 20 | ] 21 | 22 | # MyST Markdown settings 23 | myst_enable_extensions = [ 24 | "colon_fence", 25 | "deflist", 26 | "dollarmath", 27 | "fieldlist", 28 | "html_admonition", 29 | "html_image", 30 | "replacements", 31 | "smartquotes", 32 | "tasklist", 33 | ] 34 | 35 | # Theme settings 36 | html_theme = "sphinx_rtd_theme" 37 | html_static_path = ["_static"] 38 | html_logo = None 39 | html_favicon = None 40 | 41 | # General settings 42 | source_suffix = { 43 | ".rst": "restructuredtext", 44 | ".md": "markdown", 45 | } 46 | master_doc = "index" 47 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 48 | 49 | # AutoDoc settings 50 | autodoc_default_options = { 51 | "members": True, 52 | "member-order": "bysource", 53 | "special-members": "__init__", 54 | "undoc-members": True, 55 | "exclude-members": "__weakref__, __pydantic_core_schema__, __pydantic_validator__, __pydantic_serializer__, \ 56 | __pydantic_fields_set__, __pydantic_extra__, __pydantic_private__, __pydantic_post_init__, __pydantic_decorators__, \ 57 | __pydantic_parent_namespace__, __pydantic_generic_metadata__, __pydantic_custom_init__, __pydantic_complete__, \ 58 | __fields__, __fields_set__, model_fields, model_config, model_computed_fields, __class_vars__, __private_attributes__, \ 59 | __signature__, __pydantic_root_model__, __slots__, __dict__, model_extra, model_fields_set, model_post_init", 60 | "show-module-summary": True, 61 | } 62 | 63 | 64 | # To exclude private members (those starting with _) 65 | def skip_private_members(app, what, name, obj, skip, options): 66 | if ( 67 | name.startswith("_") 68 | and not name.startswith("__") 69 | and not name.startswith("__pydantic") 70 | ): 71 | return True 72 | return skip 73 | 74 | 75 | def setup(app): 76 | app.connect("autodoc-skip-member", skip_private_members) 77 | 78 | 79 | autodoc_mock_imports = [] 80 | 81 | # Type hints settings 82 | autodoc_typehints = "description" 83 | autodoc_typehints_format = "short" 84 | typehints_use_rtype = False 85 | typehints_defaults = "comma" 86 | 87 | # Enable automatic doc generation 88 | autosummary_generate = True 89 | autodoc_member_order = "bysource" 90 | add_module_names = True 91 | 92 | # Custom templates 93 | autodoc_template_dir = ["_templates"] 94 | templates_path = ["_templates"] 95 | 96 | # always_use_bars_union for type hints 97 | always_use_bars_union = True 98 | always_document_param_types = True 99 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing to MatterTune 2 | 3 | We welcome contributions to MatterTune! Whether you're fixing bugs, adding new features, improving documentation, or reporting issues, your help is appreciated. 4 | 5 | ## How to Contribute 6 | 7 | 1. Fork the repository 8 | 2. Create a new branch for your feature 9 | 3. 
Make your changes 10 | 4. Submit a pull request 11 | 12 | For detailed contribution guidelines, please see our [Contributing Guidelines](https://github.com/Fung-Lab/MatterTune/blob/main/CONTRIBUTING.md) in the repository. 13 | 14 | ## Development Setup 15 | 16 | 1. Fork and clone the repository 17 | 2. Create a virtual environment 18 | 3. Install development dependencies 19 | 4. Run tests to ensure everything is working 20 | 21 | ## Code Style 22 | 23 | We follow standard Python code style guidelines: 24 | - Use Black for code formatting 25 | - Follow PEP 8 guidelines 26 | - Write descriptive docstrings 27 | - Add type hints where appropriate 28 | 29 | ## Testing 30 | 31 | Please ensure all tests pass before submitting a pull request: 32 | ```bash 33 | pytest tests/ 34 | ``` 35 | 36 | ## Documentation 37 | 38 | When contributing new features, please: 39 | 1. Add docstrings to new functions and classes 40 | 2. Update relevant documentation files 41 | 3. Add examples where appropriate 42 | -------------------------------------------------------------------------------- /docs/guides/lightning.md: -------------------------------------------------------------------------------- 1 | # Advanced: Lightning Integration 2 | 3 | MatterTune uses PyTorch Lightning as its core training framework. This document outlines how Lightning is integrated and what functionality it provides. 4 | 5 | ## Core Components 6 | 7 | ### LightningModule Integration 8 | 9 | The base model class `FinetuneModuleBase` inherits from `LightningModule` and provides: 10 | 11 | - Automatic device management (GPU/CPU handling) 12 | - Distributed training support 13 | - Built-in training/validation/test loops 14 | - Logging and metrics tracking 15 | - Checkpoint management 16 | 17 | ```python 18 | class FinetuneModuleBase(LightningModule): 19 | def training_step(self, batch, batch_idx): 20 | output = self(batch) 21 | loss = self._compute_loss(output["predicted_properties"], self.batch_to_labels(batch)) 22 | return loss 23 | 24 | def validation_step(self, batch, batch_idx): 25 | output = self(batch) 26 | self._compute_loss(output["predicted_properties"], self.batch_to_labels(batch)) 27 | 28 | def test_step(self, batch, batch_idx): 29 | output = self(batch) 30 | self._compute_loss(output["predicted_properties"], self.batch_to_labels(batch)) 31 | 32 | def configure_optimizers(self): 33 | return create_optimizer(self.hparams.optimizer, self.parameters()) 34 | ``` 35 | 36 | ### Data Handling 37 | 38 | MatterTune uses Lightning's DataModule system for standardized data loading: 39 | 40 | ```python 41 | class MatterTuneDataModule(LightningDataModule): 42 | def prepare_data(self): 43 | # Download data if needed 44 | pass 45 | 46 | def setup(self, stage): 47 | # Create train/val/test splits 48 | self.datasets = self.hparams.create_datasets() 49 | 50 | def train_dataloader(self): 51 | return self.lightning_module.create_dataloader( 52 | self.datasets["train"], 53 | has_labels=True 54 | ) 55 | 56 | def val_dataloader(self): 57 | return self.lightning_module.create_dataloader( 58 | self.datasets["validation"], 59 | has_labels=True 60 | ) 61 | ``` 62 | 63 | ## Key Features 64 | 65 | ### 1. 
Checkpoint Management 66 | 67 | Lightning automatically handles model checkpointing: 68 | 69 | ```python 70 | checkpoint_callback = ModelCheckpointConfig( 71 | monitor="val/forces_mae", 72 | dirpath="./checkpoints", 73 | filename="best-model", 74 | save_top_k=1, 75 | mode="min" 76 | ).create_callback() 77 | 78 | trainer = Trainer(callbacks=[checkpoint_callback]) 79 | ``` 80 | 81 | ### 2. Early Stopping 82 | 83 | Built-in early stopping support: 84 | 85 | ```python 86 | early_stopping = EarlyStoppingConfig( 87 | monitor="val/forces_mae", 88 | patience=20, 89 | mode="min" 90 | ).create_callback() 91 | 92 | trainer = Trainer(callbacks=[early_stopping]) 93 | ``` 94 | 95 | ### 3. Multi-GPU Training 96 | 97 | Lightning handles distributed training with minimal code changes: 98 | 99 | ```python 100 | # Single GPU 101 | trainer = Trainer(accelerator="gpu", devices=[0]) 102 | 103 | # Multiple GPUs with DDP 104 | trainer = Trainer(accelerator="gpu", devices=[0,1], strategy="ddp") 105 | ``` 106 | 107 | ### 4. Logging 108 | 109 | Lightning provides unified logging interfaces: 110 | 111 | ```python 112 | def training_step(self, batch, batch_idx): 113 | loss = ... 114 | self.log("train_loss", loss) 115 | self.log_dict({ 116 | "energy_mae": energy_mae, 117 | "forces_mae": forces_mae 118 | }) 119 | ``` 120 | 121 | ### 5. Precision Settings 122 | 123 | Easy configuration of precision: 124 | 125 | ```python 126 | # 32-bit training 127 | trainer = Trainer(precision="32-true") 128 | 129 | # Mixed precision training 130 | trainer = Trainer(precision="16-mixed") 131 | ``` 132 | 133 | ## Available Trainer Configurations 134 | 135 | The `TrainerConfig` class exposes common Lightning Trainer settings: 136 | 137 | ```python 138 | trainer_config = TrainerConfig( 139 | # Hardware 140 | accelerator="gpu", 141 | devices=[0,1], 142 | precision="16-mixed", 143 | 144 | # Training 145 | max_epochs=100, 146 | gradient_clip_val=1.0, 147 | 148 | # Validation 149 | val_check_interval=1.0, 150 | check_val_every_n_epoch=1, 151 | 152 | # Callbacks 153 | early_stopping=EarlyStoppingConfig(...), 154 | checkpoint=ModelCheckpointConfig(...), 155 | 156 | # Logging 157 | loggers=["tensorboard", "wandb"] 158 | ) 159 | ``` 160 | 161 | ## Best Practices 162 | 163 | 1. Use `self.log()` for tracking metrics during training 164 | 2. Enable checkpointing to save model states 165 | 3. Set appropriate early stopping criteria 166 | 4. Use appropriate precision settings for your hardware 167 | 5. Configure multi-GPU training based on available resources 168 | 169 | ## Advanced Usage 170 | 171 | For advanced use cases: 172 | 173 | ```python 174 | # Custom training loop 175 | @override 176 | def training_step(self, batch, batch_idx): 177 | if self.trainer.global_rank == 0: 178 | # Do something only on main process 179 | pass 180 | 181 | # Access trainer properties 182 | if self.trainer.is_last_batch: 183 | # Special handling for last batch 184 | pass 185 | 186 | # Custom validation 187 | @override 188 | def validation_epoch_end(self, outputs): 189 | # Compute epoch-level metrics 190 | pass 191 | ``` 192 | 193 | This integration provides a robust foundation for training atomistic models while handling common ML engineering concerns automatically. 194 | -------------------------------------------------------------------------------- /docs/guides/model_usage.md: -------------------------------------------------------------------------------- 1 | # Model Usage Guide 2 | 3 | After training, you can use your model for predictions in several ways. 
This guide covers loading models and making predictions. 4 | 5 | ## Loading a Model 6 | 7 | To load a saved model checkpoint (from a previous fine-tuning run), use the `load_from_checkpoint` method: 8 | 9 | ```python 10 | from mattertune.backbones import JMPBackboneModule 11 | 12 | model = JMPBackboneModule.load_from_checkpoint("path/to/checkpoint.ckpt") 13 | ``` 14 | 15 | ## Making Predictions 16 | 17 | The `MatterTunePropertyPredictor` interface provides a simple way to make predictions for a single or batch of atoms: 18 | 19 | ```python 20 | from ase import Atoms 21 | import torch 22 | 23 | # Create ASE Atoms objects 24 | atoms1 = Atoms('H2O', 25 | positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], 26 | cell=[10, 10, 10], 27 | pbc=True) 28 | 29 | atoms2 = Atoms('H2O', 30 | positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], 31 | cell=[10, 10, 10], 32 | pbc=True) 33 | 34 | # Get predictions using the model's property predictor interface 35 | property_predictor = model.property_predictor() 36 | predictions = property_predictor.predict([atoms1, atoms2], ["energy", "forces"]) 37 | 38 | print("Energy:", predictions[0]["energy"], predictions[1]["energy"]) 39 | print("Forces:", predictions[0]["forces"], predictions[1]["forces"]) 40 | ``` 41 | 42 | ## Using as ASE Calculator 43 | 44 | Our ASE calculator interface allows you to use the model for molecular dynamics or geometry optimization: 45 | 46 | ```python 47 | from ase.optimize import BFGS 48 | 49 | # Create calculator from model 50 | calculator = model.ase_calculator() 51 | 52 | # Set up atoms and calculator 53 | atoms = Atoms('H2O', 54 | positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], 55 | cell=[10, 10, 10], 56 | pbc=True) 57 | atoms.calc = calculator 58 | 59 | # Run geometry optimization 60 | opt = BFGS(atoms) 61 | opt.run(fmax=0.01) 62 | 63 | # Get optimized results 64 | print("Final energy:", atoms.get_potential_energy()) 65 | print("Final forces:", atoms.get_forces()) 66 | ``` 67 | -------------------------------------------------------------------------------- /docs/guides/normalization.md: -------------------------------------------------------------------------------- 1 | # Normalization 2 | 3 | MatterTune provides flexible property normalization capabilities through the `mattertune.normalization` module. Normalization is crucial for improving training stability and convergence when fine-tuning models, especially for properties that can vary widely in scale. 4 | 5 | ## Overview 6 | 7 | The normalization system consists of: 8 | - A `NormalizationContext` that provides per-batch information needed for normalization 9 | - Multiple normalizer types that can be composed together 10 | - CLI tools for computing normalization parameters from datasets 11 | 12 | ## Supported Normalizers 13 | 14 | ### Mean-Standard Deviation Normalization 15 | 16 | Normalizes values using mean and standard deviation: `(x - mean) / std` 17 | 18 | ```python 19 | config = mt.configs.MatterTunerConfig( 20 | model=mt.configs.JMPBackboneConfig( 21 | # ... other configs ... 22 | normalizers={ 23 | "energy": [ 24 | mt.configs.MeanStdNormalizerConfig( 25 | mean=-13.6, # mean of your property 26 | std=2.4 # standard deviation 27 | ) 28 | ] 29 | } 30 | ), 31 | # ... other configs ... 32 | ) 33 | ``` 34 | 35 | ### RMS Normalization 36 | 37 | Normalizes values by dividing by the root mean square value: `x / rms` 38 | 39 | ```python 40 | config = mt.configs.MatterTunerConfig( 41 | model=mt.configs.JMPBackboneConfig( 42 | # ... other configs ... 
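# The RMS normalizer rescales the force labels before the loss is computed; predictions are denormalized again on output.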
43 | normalizers={ 44 | "forces": [ 45 | mt.configs.RMSNormalizerConfig( 46 | rms=2.5 # RMS value of your property 47 | ) 48 | ] 49 | } 50 | ), 51 | # ... other configs ... 52 | ) 53 | ``` 54 | 55 | ### Per-Atom Reference Normalization 56 | 57 | Subtracts composition-weighted atomic reference values. This is particularly useful for energy predictions where you want to remove the baseline atomic contributions. 58 | 59 | ```python 60 | config = mt.configs.MatterTunerConfig( 61 | model=mt.configs.JMPBackboneConfig( 62 | # ... other configs ... 63 | normalizers={ 64 | "energy": [ 65 | mt.configs.PerAtomReferencingNormalizerConfig( 66 | # Option 1: Direct dictionary mapping 67 | per_atom_references={ 68 | 1: -13.6, # H 69 | 8: -2000.0 # O 70 | } 71 | # Option 2: List indexed by atomic number 72 | # per_atom_references=[0.0, -13.6, 0.0, ..., -2000.0] 73 | # Option 3: Path to JSON file 74 | # per_atom_references="path/to/references.json" 75 | ) 76 | ] 77 | } 78 | ), 79 | # ... other configs ... 80 | ) 81 | ``` 82 | 83 | ## Computing Normalization Parameters 84 | 85 | ### Per-Atom References 86 | 87 | MatterTune provides a CLI tool to compute per-atom reference values using either linear regression or ridge regression: 88 | 89 | ```bash 90 | python -m mattertune.normalization \ 91 | config.json \ 92 | energy \ 93 | references.json \ 94 | --reference-model linear 95 | ``` 96 | 97 | Arguments: 98 | - `config.json`: Path to your MatterTune configuration file 99 | - `energy`: Name of the property to compute references for 100 | - `references.json`: Output path for the computed references 101 | - `--reference-model`: Model type (`linear` or `ridge`) 102 | - `--reference-model-kwargs`: Optional JSON string of kwargs for the regression model 103 | 104 | The tool will: 105 | 1. Load your dataset from the config 106 | 2. Fit a linear model to predict property values from atomic compositions 107 | 3. Save the computed per-atom references to the specified JSON file 108 | 109 | ## Composing Multiple Normalizers 110 | 111 | You can combine multiple normalizers for a single property. They will be applied in sequence: 112 | 113 | ```python 114 | config = mt.configs.MatterTunerConfig( 115 | model=mt.configs.JMPBackboneConfig( 116 | # ... other configs ... 117 | normalizers={ 118 | "energy": [ 119 | # First subtract atomic references 120 | mt.configs.PerAtomReferencingNormalizerConfig( 121 | per_atom_references="references.json" 122 | ), 123 | # Then apply mean-std normalization 124 | mt.configs.MeanStdNormalizerConfig( 125 | mean=0.0, 126 | std=1.0 127 | ) 128 | ] 129 | } 130 | ), 131 | # ... other configs ... 132 | ) 133 | ``` 134 | 135 | ## Technical Details 136 | 137 | All normalizers implement the `NormalizerModule` protocol which requires: 138 | - `normalize(value: torch.Tensor, ctx: NormalizationContext) -> torch.Tensor` 139 | - `denormalize(value: torch.Tensor, ctx: NormalizationContext) -> torch.Tensor` 140 | 141 | The `NormalizationContext` provides composition information needed for per-atom normalization: 142 | ```python 143 | @dataclass(frozen=True) 144 | class NormalizationContext: 145 | compositions: torch.Tensor # shape: (batch_size, num_elements) 146 | ``` 147 | 148 | Each row in `compositions` represents the element counts for one structure, where the index corresponds to the atomic number (e.g., index 1 for hydrogen). 
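As an illustration of this protocol (not a built-in class — the name below is hypothetical, and it assumes `NormalizationContext` can be imported from `mattertune.normalization`), a minimal per-atom normalizer that divides a per-structure property by its atom count might look like the following sketch:

```python
import torch

from mattertune.normalization import NormalizationContext  # assumed import path


class PerAtomScaleNormalizer:
    """Sketch of a NormalizerModule-compatible class: divides a per-structure
    value (e.g. total energy) by the number of atoms and multiplies it back
    when denormalizing."""

    def normalize(self, value: torch.Tensor, ctx: NormalizationContext) -> torch.Tensor:
        # Atom counts per structure, obtained by summing the element counts in each row.
        n_atoms = ctx.compositions.sum(dim=-1).clamp(min=1)
        return value / n_atoms

    def denormalize(self, value: torch.Tensor, ctx: NormalizationContext) -> torch.Tensor:
        n_atoms = ctx.compositions.sum(dim=-1).clamp(min=1)
        return value * n_atoms
```

Note that the built-in normalizers are configured through the `*Config` classes shown above; this sketch only demonstrates the `normalize`/`denormalize` contract.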
149 | 150 | ## Implementation Notes 151 | 152 | - Normalization is applied automatically during training 153 | - Loss is computed on normalized values for numerical stability 154 | - Predictions are automatically denormalized before metric computation and output 155 | - The property predictor and ASE calculator interfaces return denormalized values 156 | 157 | {py:mod}`mattertune.normalization` 158 | - {py:class}`mattertune.normalization.NormalizerModule` 159 | - {py:class}`mattertune.normalization.MeanStdNormalizerConfig` 160 | - {py:class}`mattertune.normalization.RMSNormalizerConfig` 161 | - {py:class}`mattertune.normalization.PerAtomReferencingNormalizerConfig` 162 | -------------------------------------------------------------------------------- /docs/guides/recipes.md: -------------------------------------------------------------------------------- 1 | # Recipes 2 | 3 | Recipes are modular components that modify the fine-tuning process in MatterTune. They provide a standardized way to implement advanced training techniques, particularly parameter-efficient fine-tuning methods, through Lightning callbacks. 4 | 5 | ## Overview 6 | 7 | Recipes are configurable components that modify how models are trained in MatterTune. Each recipe provides a specific capability - like making training more memory-efficient, adding regularization, or enabling advanced optimization techniques. 8 | 9 | Using a recipe is as simple as adding its configuration to your training setup: 10 | 11 | ```python 12 | config = mt.configs.MatterTunerConfig( 13 | model=mt.configs.JMPBackboneConfig(...), 14 | data=mt.configs.AutoSplitDataModuleConfig(...), 15 | recipes=[ 16 | # List of recipe configurations 17 | mt.configs.MyRecipeConfig(...), 18 | mt.configs.AnotherRecipeConfig(...) 19 | ] 20 | ) 21 | ``` 22 | 23 | When training starts, each recipe is applied in order to modify the model, optimizer, or training loop. Recipes can be combined to create custom training pipelines that suit your specific needs. 24 | 25 | ## Available Recipes 26 | 27 | ### LoRA (Low-Rank Adaptation) 28 | 29 | LoRA is a parameter-efficient fine-tuning technique that adds trainable rank decomposition matrices to model weights while keeping the original weights frozen. 30 | 31 | API Reference: {py:class}`mattertune.configs.LoRARecipeConfig` 32 | 33 | ```python 34 | import mattertune as mt 35 | 36 | config = mt.configs.MatterTunerConfig( 37 | model=mt.configs.JMPBackboneConfig(...), 38 | data=mt.configs.AutoSplitDataModuleConfig(...), 39 | recipes=[ 40 | mt.configs.LoRARecipeConfig( 41 | lora=mt.configs.LoraConfig( 42 | r=8, # LoRA rank 43 | target_modules=["linear1", "linear2"], # Layers to apply LoRA to 44 | lora_alpha=8, # LoRA scaling factor 45 | lora_dropout=0.1 # Dropout probability 46 | ) 47 | ) 48 | ] 49 | ) 50 | ``` 51 | 52 | ## Creating Custom Recipes 53 | 54 | A recipe consists of two main components: 55 | 56 | 1. A configuration class that defines the parameters 57 | 2. A Lightning callback that integrates it into training 58 | 59 | Here's how to create your own recipe: 60 | 61 | 1. Define a configuration class: 62 | ```python 63 | class MyRecipeConfig(RecipeConfigBase): 64 | param1: int 65 | param2: float 66 | 67 | @classmethod 68 | @override 69 | def ensure_dependencies(cls): 70 | # Check for required packages 71 | if importlib.util.find_spec("some_package") is None: 72 | raise ImportError("Required package not found") 73 | ``` 74 | 75 | 2. Implement the callback: 76 | ```python 77 | class MyRecipeConfig(RecipeConfigBase): 78 | # ... 
configuration fields from step 1 (param1, param2) 79 | 
 80 |     def create_lightning_callback(self): 
 81 |         from lightning.pytorch.callbacks import LambdaCallback 
 82 | 
 83 |         return LambdaCallback( 
 84 |             on_train_start=lambda trainer, pl_module: print(f"Training started (param1={self.param1})"), 
 85 |             on_train_end=lambda trainer, pl_module: print(f"Training ended (param2={self.param2})") 
 86 |         ) 
 87 | ``` 
 88 | 
 89 | ## Best Practices 
 90 | 
 91 | 1. **Configuration Validation**: Validate recipe parameters in `__post_init__` 
 92 | 2. **Dependency Management**: Use `ensure_dependencies` to check for required packages 
 93 | 3. **Error Handling**: Provide clear error messages for configuration issues 
 94 | 4. **Documentation**: Include docstrings explaining parameters and their effects 
 95 | 5. **Type Safety**: Use type hints for all parameters and return values 
 96 | 
 97 | ## Integration with Training 
 98 | 
 99 | Recipes are automatically applied when training starts: 
 100 | 
 101 | ```python 
 102 | tuner = mt.MatterTuner(config) 
 103 | model, trainer = tuner.tune()  # Recipes are applied here 
 104 | ``` 
 105 | 
 106 | ## Advanced Usage 
 107 | 
 108 | Recipes can be combined and will be applied in order: 
 109 | 
 110 | ```python 
 111 | config = mt.configs.MatterTunerConfig( 
 112 |     model=mt.configs.JMPBackboneConfig(...), 
 113 |     data=mt.configs.AutoSplitDataModuleConfig(...), 
 114 |     recipes=[ 
 115 |         mt.configs.LoRARecipeConfig(...), 
 116 |         mt.configs.MyRecipeConfig(...), 
 117 |     ] 
 118 | ) 
 119 | ``` 
 120 | -------------------------------------------------------------------------------- /docs/guides/training_config.md: -------------------------------------------------------------------------------- 1 | # Training Configuration Guide 2 | 3 | MatterTune uses a comprehensive configuration system to control all aspects of training. This guide covers the key components and how to use them effectively.
4 | 5 | ## Model Configuration 6 | 7 | Control the model architecture and training parameters: 8 | 9 | ```python 10 | model = mt.configs.JMPBackboneConfig( 11 | # Specify pre-trained model checkpoint 12 | ckpt_path="path/to/pretrained/model.pt", 13 | 14 | # Define properties to predict 15 | properties=[ 16 | mt.configs.EnergyPropertyConfig( 17 | loss=mt.configs.MAELossConfig(), 18 | loss_coefficient=1.0 19 | ), 20 | mt.configs.ForcesPropertyConfig( 21 | loss=mt.configs.MAELossConfig(), 22 | loss_coefficient=10.0, 23 | conservative=True # Use energy-conserving force prediction 24 | ) 25 | ], 26 | 27 | # Configure optimizer 28 | optimizer=mt.configs.AdamWConfig(lr=1e-4), 29 | 30 | # Optional: Configure learning rate scheduler 31 | lr_scheduler=mt.configs.CosineAnnealingLRConfig( 32 | T_max=100, # Number of epochs 33 | eta_min=1e-6 # Minimum learning rate 34 | ) 35 | ) 36 | ``` 37 | 38 | ## Data Configuration 39 | 40 | Configure data loading and processing: 41 | 42 | ```python 43 | data = mt.configs.AutoSplitDataModuleConfig( 44 | # Specify dataset source 45 | dataset=mt.configs.XYZDatasetConfig( 46 | src="path/to/your/data.xyz" 47 | ), 48 | 49 | # Control data splitting 50 | train_split=0.8, # 80% for training 51 | 52 | # Configure batch size and loading 53 | batch_size=32, 54 | num_workers=4, # Number of data loading workers 55 | pin_memory=True # Optimize GPU transfer 56 | ) 57 | ``` 58 | 59 | ## Training Process Configuration 60 | 61 | Control the training loop behavior: 62 | 63 | ```python 64 | trainer = mt.configs.TrainerConfig( 65 | # Hardware configuration 66 | accelerator="gpu", 67 | devices=[0, 1], # Use GPUs 0 and 1 68 | 69 | # Training stopping criteria 70 | max_epochs=100, 71 | # OR: max_steps=1000, # Stop after 1000 steps 72 | # OR: max_time=datetime.timedelta(hours=1), # Stop after 1 hour 73 | 74 | # Validation frequency 75 | check_val_every_n_epoch=1, 76 | 77 | # Gradient clipping: Prevent exploding gradients 78 | gradient_clip_val=1.0, 79 | 80 | # Early stopping configuration 81 | early_stopping=mt.configs.EarlyStoppingConfig( 82 | monitor="val/energy_mae", 83 | patience=20, 84 | mode="min" 85 | ), 86 | 87 | # Model checkpointing 88 | checkpoint=mt.configs.ModelCheckpointConfig( 89 | monitor="val/energy_mae", 90 | save_top_k=1, 91 | mode="min" 92 | ), 93 | 94 | # Configure logging 95 | loggers=[ 96 | mt.configs.WandbLoggerConfig( 97 | project="my-project", 98 | name="experiment-1" 99 | ) 100 | ] 101 | ) 102 | 103 | # Combine all configurations 104 | config = mt.configs.MatterTunerConfig( 105 | model=model, 106 | data=data, 107 | trainer=trainer 108 | ) 109 | ``` 110 | 111 | ## Configuration Management 112 | 113 | MatterTune uses [`nshconfig`](https://github.com/nimashoghi/nshconfig) for configuration management, providing several ways to create and load configurations: 114 | 115 | ### 1. Direct Construction 116 | 117 | ```python 118 | config = mt.configs.MatterTunerConfig( 119 | model=mt.configs.JMPBackboneConfig(...), 120 | data=mt.configs.AutoSplitDataModuleConfig(...), 121 | trainer=mt.configs.TrainerConfig(...) 122 | ) 123 | ``` 124 | 125 | ### 2. 
Loading from Files/Dictionaries 126 | 127 | ```python 128 | # Load from YAML 129 | config = mt.configs.MatterTunerConfig.from_yaml('/path/to/config.yaml') 130 | 131 | # Load from JSON 132 | config = mt.configs.MatterTunerConfig.from_json('/path/to/config.json') 133 | 134 | # Load from dictionary 135 | config = mt.configs.MatterTunerConfig.from_dict({ 136 | 'model': {...}, 137 | 'data': {...}, 138 | 'trainer': {...} 139 | }) 140 | ``` 141 | 142 | ### 3. Using Draft Configs 143 | 144 | ```python 145 | # Create a draft config 146 | config = mt.configs.MatterTunerConfig.draft() 147 | 148 | # Set values progressively 149 | config.model = mt.configs.JMPBackboneConfig.draft() 150 | config.model.ckpt_path = "path/to/model.pt" 151 | # ... set other values ... 152 | 153 | # Finalize the config 154 | final_config = config.finalize() 155 | ``` 156 | 157 | For more advanced configuration management features, see the [nshconfig documentation](https://github.com/nimashoghi/nshconfig). 158 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # MatterTune Documentation 2 | 3 | MatterTune is a flexible and powerful machine learning library designed specifically for fine-tuning state-of-the-art chemistry models. It provides intuitive interfaces for computational chemists to fine-tune pre-trained models on their specific use cases. 4 | 5 | ```{toctree} 6 | :maxdepth: 1 7 | :caption: Getting Started 8 | 9 | introduction 10 | motivation 11 | installation 12 | ``` 13 | 14 | ```{toctree} 15 | :maxdepth: 2 16 | :caption: User Guide 17 | 18 | guides/datasets 19 | guides/fine-tuning 20 | guides/model_usage 21 | guides/training_config 22 | guides/normalization 23 | guides/lightning 24 | ``` 25 | 26 | ```{toctree} 27 | :maxdepth: 2 28 | :caption: Model Backbones 29 | 30 | backbones/jmp 31 | backbones/m3gnet 32 | backbones/orb 33 | backbones/eqv2 34 | backbones/mattersim 35 | ``` 36 | 37 | ```{toctree} 38 | :maxdepth: 1 39 | :caption: Development 40 | 41 | api 42 | contributing 43 | license 44 | ``` 45 | 46 | ## Indices and tables 47 | 48 | * {ref}`genindex` 49 | * {ref}`modindex` 50 | * {ref}`search` 51 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation Guide 2 | 3 | The installation of MatterTune consists of three stages: 4 | 5 | 1. Configure environment dependencies for one specific backbone model 6 | 2. Install the MatterTune package 7 | 3. Set up additional dependencies for external datasets and data sources 8 | 9 | ```{warning} 10 | MatterTune must be installed on environments with python>=3.10 11 | ``` 12 | 13 | ```{warning} 14 | Since there are dependency conflicts between different backbone models, we strongly recommend creating separate virtual environments for each backbone model you plan to use. 15 | ``` 16 | 17 | ## Backbone Installation 18 | 19 | Below are the installation instructions for our currently supported backbone models using conda and pip. 20 | 21 | 35 | 36 | ### MatterSim 37 | 38 | We strongly recommand to install MatterSim from source code 39 | 40 | ```bash 41 | git clone git@github.com:microsoft/mattersim.git 42 | cd mattersim 43 | ``` 44 | 45 | Find the line 41 of the pyproject.toml in MatterSim, which is ```"pydantic==2.9.2",```. 
Change it to ```"pydantic>=2.9.2",``` and ```python=3.9``` in environment.yaml to ```python=3.10```. After finishing this modification, install MatterSim by running: 46 | 47 | ```bash 48 | mamba env create -f environment.yaml 49 | mamba activate mattersim 50 | uv pip install -e . 51 | python setup.py build_ext --inplace 52 | ``` 53 | 54 | ### MACE 55 | 56 | MACE can be directly installed with pip: 57 | 58 | ```bash 59 | pip install --upgrade pip 60 | pip install mace-torch 61 | ``` 62 | 63 | or it can be installed from source code: 64 | 65 | ```bash 66 | git clone https://github.com/ACEsuit/mace.git 67 | pip install ./mace 68 | ``` 69 | 70 | 71 | ### JMP 72 | 73 | Please follow the installation instructions in the [jmp-backbone repository](https://github.com/nimashoghi/jmp-backbone/blob/lingyu-grad/README.md). 74 | 75 | ### ORB 76 | 77 | Please follow the installation instructions in the [orb-models repository](https://github.com/orbital-materials/orb-models). 78 | 79 | ### EquiformerV2 80 | 81 | ```bash 82 | conda create -n eqv2-tune python=3.10 83 | conda activate eqv2-tune 84 | pip install "git+https://github.com/FAIR-Chem/fairchem.git@omat24#subdirectory=packages/fairchem-core" --no-deps 85 | pip install ase "e3nn>=0.5" hydra-core lmdb numba "numpy>=1.26,<2.0" orjson "pymatgen>=2023.10.3" submitit tensorboard "torch==2.5.0" wandb torch_geometric h5py netcdf4 opt-einsum spglib 86 | ``` 87 | 88 | ## MatterTune Package Installation 89 | 90 | ```{important} 91 | MatterTune should be installed after setting up the backbone model dependencies. 92 | ``` 93 | 94 | Clone the repository and install MatterTune by: 95 | 96 | ```bash 97 | pip install -e . 98 | ``` 99 | 100 | ## External Dataset Installation 101 | 102 | ### Matbench 103 | 104 | Clone the repo and install by: 105 | ```bash 106 | git clone https://github.com/hackingmaterials/matbench 107 | cd matbench 108 | pip install -e . -r requirements-dev.txt 109 | ``` 110 | 111 | ### Materials Project 112 | 113 | Install mp-api: 114 | ```bash 115 | pip install mp-api 116 | ``` 117 | 118 | ```{note} 119 | There are currently dependency conflicts between mp-api and matbench packages. You may not be able to use both simultaneously in a single virtual environment. 120 | ``` 121 | 122 | ### Materials Project Trajectories 123 | 124 | To access MPTraj data from our Hugging Face dataset: 125 | ```bash 126 | pip install datasets 127 | ``` 128 | -------------------------------------------------------------------------------- /docs/introduction.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | ## Motivation 4 | 5 | Atomistic Foundation Models have emerged as powerful tools in molecular and materials science. However, the diverse implementations of these open-source models, with their varying architectures and interfaces, create significant barriers for customized fine-tuning and downstream applications. 6 | 7 | MatterTune is a comprehensive platform that addresses these challenges through systematic yet general abstraction of Atomistic Foundation Model architectures. By adopting a modular design philosophy, MatterTune provides flexible and concise user interfaces that enable intuitive and efficient fine-tuning workflows. 
8 | 9 | ## Key Features 10 | 11 | ### Pre-trained Model Support 12 | Seamlessly work with multiple state-of-the-art pre-trained models including: 13 | - JMP 14 | - EquiformerV2 15 | - M3GNet 16 | - ORB 17 | - MatterSim 18 | - MACE 19 | 20 | ### Flexible Property Predictions 21 | Support for various molecular and materials properties: 22 | - Energy prediction 23 | - Force prediction (both conservative and non-conservative) 24 | - Stress tensor prediction 25 | - Custom system-level property predictions 26 | 27 | ### Data Processing 28 | Built-in support for multiple data formats: 29 | - XYZ files 30 | - ASE databases 31 | - Materials Project database 32 | - Matbench datasets 33 | - Custom datasets 34 | 35 | ### Training Features 36 | - Automated train/validation splitting 37 | - Multiple loss functions (MAE, MSE, Huber, L2-MAE) 38 | - Property normalization and scaling 39 | - Early stopping and model checkpointing 40 | - Comprehensive logging with WandB, TensorBoard, and CSV support 41 | -------------------------------------------------------------------------------- /docs/license.md: -------------------------------------------------------------------------------- 1 | # License Information 2 | 3 | MatterTune's core framework is licensed under the MIT License - see the [LICENSE](https://github.com/Fung-Lab/MatterTune/blob/main/LICENSE) file for details. 4 | 5 | ## Backbone Licenses 6 | 7 | Each supported model backbone is subject to its own licensing terms: 8 | 9 | ### JMP Backbone 10 | Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0) 11 | [JMP License](https://github.com/facebookresearch/JMP/blob/main/LICENSE.md) 12 | 13 | ### EquiformerV2 Backbone 14 | Meta Research License 15 | [EquiformerV2 License](https://huggingface.co/fairchem/OMAT24/blob/main/LICENSE) 16 | 17 | ### M3GNet Backbone 18 | BSD 3-Clause License 19 | [M3GNet License](https://github.com/materialsvirtuallab/m3gnet/blob/main/LICENSE) 20 | 21 | ### ORB Backbone 22 | Apache License 2.0 23 | [ORB License](https://github.com/orbital-materials/orb-models/blob/main/LICENSE) 24 | 25 | ### MatterSim Backbone 26 | MIT License 27 | [MatterSim License](https://github.com/microsoft/mattersim/blob/main/LICENSE.txt) 28 | 29 | 30 | ```{important} 31 | Please ensure compliance with the respective licenses when using specific model backbones in your project. For commercial use cases, carefully review each backbone's license terms or contact the respective authors for licensing options. 
32 | ``` 33 | -------------------------------------------------------------------------------- /docs/requirements-torch.txt: -------------------------------------------------------------------------------- 1 | --index-url https://download.pytorch.org/whl/cpu 2 | torch 3 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx-autoapi 3 | sphinx-autodoc-typehints 4 | sphinx-rtd-theme 5 | myst-parser 6 | sphinx-autodoc-typehints 7 | sphinx-copybutton 8 | nshconfig[extra] 9 | nshconfig-extra[extra] 10 | numpy 11 | ase 12 | scikit-learn 13 | lightning 14 | torchmetrics 15 | -------------------------------------------------------------------------------- /examples/lora_decompose/Li3PO4-checkpoints/mattersim-1m-best-MPx3.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/MatterTune/e82b3b8ed5f4cacd2d2df273caf039a79d380583/examples/lora_decompose/Li3PO4-checkpoints/mattersim-1m-best-MPx3.ckpt -------------------------------------------------------------------------------- /examples/matbench-discovery/collect.py: -------------------------------------------------------------------------------- 1 | import os 2 | import rich 3 | 4 | import numpy as np 5 | from ase.io import read 6 | from matbench_discovery.metrics.discovery import stable_metrics 7 | 8 | PATH = "/net/csefiles/coc-fung-cluster/lingyu/matbench-discovery" 9 | MODEL = "eqv2" 10 | 11 | dir_path = os.path.join(PATH, MODEL) 12 | 13 | xyz_files = os.listdir(dir_path) 14 | 15 | atoms_list = [] 16 | for file in xyz_files: 17 | atoms_list.extend(read(os.path.join(dir_path, file), ":")) 18 | 19 | print(len(atoms_list)) 20 | assert len(atoms_list) == 256963 21 | 22 | e_hull_preds = np.array([atoms.info["e_hull_pred"] for atoms in atoms_list]) 23 | e_hull_trues = np.array([atoms.info["e_hull_true"] for atoms in atoms_list]) 24 | 25 | rich.print(stable_metrics(e_hull_trues, e_hull_preds)) -------------------------------------------------------------------------------- /examples/matbench-discovery/results.md: -------------------------------------------------------------------------------- 1 | | | EqV2 S DeNS Baseline | EqV2 S DeNS | MatterSimV1 5M Baseline| MatterSimV1 5M | Orb-V2 Baseline | Orb-V2 | 2 | |-|-|-|-|-|-|-| 3 | | F1|0.815|0.792|0.862|0.842|0.880|0.866| 4 | | DAF|5.042|4.718|5.852|5.255|6.041|5.395| 5 | |Prec|0.771|0.756|0.895|0.876|0.924|0.899| 6 | | Acc|0.941|0.925|0.959|0.949|0.965|0.957| 7 | | MAE|0.036|0.035|0.024|0.024|0.028|0.027| 8 | | R2|0.788|0.780|0.863|0.848|0.824|0.817| -------------------------------------------------------------------------------- /examples/matbench/results.md: -------------------------------------------------------------------------------- 1 | # matbench 2 | 3 | | | jmp-mattertune | jmp-ori | orb-mattertune | eqv2-mattertune| 4 | |-|-|-|-|-| 5 | |matbench_dielectric|0.146 | 0.133 / 0.252 | 0.142 | 0.111 | 6 | |matbench_jdft2d| 19.42 | 20.72 / 30.16 | 21.44 | 23.45 | 7 | |matbench_log_gvrh| 0.059 | 0.06 / 0.062 | 0.053 | 0.056 | 8 | |matbench_log_kvrh| 0.033 | 0.044 / 0.046| 0.046 | 0.046 | 9 | |matbench_mp_e_form| 25.2 | 13.6 / 13.3 | 9.4 | 24.5 | 10 | |matbench_mp_gap| 0.119 | 0.119 / 0.121 | 0.093 | 0.098 | 11 | |matbench_perovskites| 0.029 | 0.029 / 0.028 | 0.033 | 0.027 | 12 | |matbench_phonons| 42.23 | 26.6 / 22.77 | 67.92 | 50.57 | 
-------------------------------------------------------------------------------- /examples/matbench/screening.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from ase.io import read 4 | from ase import Atoms 5 | import numpy as np 6 | import wandb 7 | 8 | from mattertune.backbones import ( 9 | EqV2BackboneModule, 10 | JMPBackboneModule, 11 | ORBBackboneModule, 12 | ) 13 | 14 | 15 | def main(args_dict: dict): 16 | ## Load Checkpoint 17 | if "jmp" in args_dict["ckpt_path"]: 18 | model_type = "jmp" 19 | model = JMPBackboneModule.load_from_checkpoint( 20 | checkpoint_path=args_dict["ckpt_path"], map_location="cpu" 21 | ) 22 | elif "orb" in args_dict["ckpt_path"]: 23 | model_type = "orb" 24 | model = ORBBackboneModule.load_from_checkpoint( 25 | checkpoint_path=args_dict["ckpt_path"], map_location="cpu" 26 | ) 27 | elif "eqv2" in args_dict["ckpt_path"]: 28 | model_type = "eqv2" 29 | model = EqV2BackboneModule.load_from_checkpoint( 30 | checkpoint_path=args_dict["ckpt_path"], map_location="cpu", strict=False 31 | ) 32 | else: 33 | raise ValueError( 34 | "Invalid fine-tuning model, must be one of 'jmp', 'orb', or 'eqv2'." 35 | ) 36 | 37 | ## Load Screening Data 38 | atoms_list: list[Atoms] = read("/net/csefiles/coc-fung-cluster/lingyu/gnome_Bandgap.xyz", index=":") # type: ignore 39 | true_properties = np.array([atoms.info["bandgap"] for atoms in atoms_list]) 40 | exclude_inf_indices = np.where(np.isinf(true_properties))[0] 41 | atoms_list = [atoms_list[i] for i in range(len(atoms_list)) if i not in exclude_inf_indices] 42 | true_properties = np.array([true_properties[i] for i in range(len(true_properties)) if i not in exclude_inf_indices]) 43 | 44 | 45 | ## Run Property Prediction 46 | 47 | 48 | wandb.init( 49 | project="MatterTune-Examples", 50 | name="GNoME-Bandgap-Screening-{}".format( 51 | args_dict["ckpt_path"].split("/")[-1].split(".")[0] 52 | ), 53 | config=args_dict, 54 | ) 55 | property_predictor = model.property_predictor( 56 | lightning_trainer_kwargs={ 57 | "accelerator": "gpu", 58 | "devices": args_dict["devices"], 59 | "precision": "32", 60 | "inference_mode": False, 61 | "enable_progress_bar": True, 62 | "enable_model_summary": False, 63 | "logger": False, 64 | "barebones": False, 65 | } 66 | ) 67 | model_outs = property_predictor.predict( 68 | atoms_list, batch_size=args_dict["batch_size"] 69 | ) 70 | pred_properties = [out["matbench_mp_gap"].item() for out in model_outs] 71 | print(pred_properties) 72 | 73 | ## Compare Predictions 74 | from sklearn.metrics import ( 75 | accuracy_score, 76 | confusion_matrix, 77 | f1_score, 78 | mean_absolute_error, 79 | mean_squared_error, 80 | recall_score, 81 | ) 82 | 83 | true_properties = np.array(true_properties) 84 | pred_properties = np.array(pred_properties) 85 | 86 | # Regression Metrics 87 | mae = mean_absolute_error(true_properties, pred_properties) 88 | mse = mean_squared_error(true_properties, pred_properties) 89 | rmse = np.sqrt(mse) 90 | 91 | print(f"MAE: {mae:.4f}") 92 | print(f"MSE: {mse:.4f}") 93 | print(f"RMSE: {rmse:.4f}") 94 | 95 | ## Screening Metrics 96 | thresholds = sorted(args_dict["thresholds"]) 97 | lower_bound, upper_bound = thresholds[0], thresholds[1] 98 | 99 | true_labels = ( 100 | (true_properties >= lower_bound) & (true_properties < upper_bound) 101 | ).astype(int) 102 | pred_labels = ( 103 | (pred_properties >= lower_bound) & (pred_properties < upper_bound) 104 | ).astype(int) 105 | 106 | tn, fp, fn, tp = 
confusion_matrix(true_labels, pred_labels).ravel() 107 | 108 | print(f"True Positives: {tp}") 109 | print(f"True Negatives: {tn}") 110 | print(f"False Positives: {fp}") 111 | print(f"False Negatives: {fn}") 112 | 113 | accuracy = accuracy_score(true_labels, pred_labels) 114 | recall = recall_score(true_labels, pred_labels) 115 | f1 = f1_score(true_labels, pred_labels) 116 | 117 | print(f"Accuracy: {accuracy*100:.2f}%") 118 | print(f"Recall: {recall*100:.2f}%") 119 | print(f"F1 Score: {f1:.4f}") 120 | 121 | wandb.log( 122 | { 123 | "MAE": mae, 124 | "MSE": mse, 125 | "RMSE": rmse, 126 | "Accuracy": accuracy, 127 | "Recall": recall, 128 | "F1 Score": f1, 129 | } 130 | ) 131 | 132 | ## Plot Bandgap Distribution 133 | sorted_indices = np.argsort(true_properties) 134 | true_properties = true_properties[sorted_indices] 135 | pred_properties = pred_properties[sorted_indices] 136 | import matplotlib.pyplot as plt 137 | 138 | plt.figure(figsize=(6, 3)) 139 | plt.plot(pred_properties, label="Predicted Bandgap", alpha=0.5) 140 | plt.plot(true_properties, label="True Bandgap", alpha=0.5) 141 | plt.xlabel("Index") 142 | plt.ylabel("Bandgap (eV)") 143 | plt.legend() 144 | # plt.yscale("log") 145 | 146 | wandb.log({"Bandgap Distribution": plt}) 147 | 148 | plt.savefig(f"./plots/{model_type}-gnome.png") 149 | 150 | 151 | if __name__ == "__main__": 152 | import argparse 153 | 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument( 156 | "--ckpt_path", type=str, default="./checkpoints-matbench_mp_gap/orb-best-fold0.ckpt" 157 | ) 158 | parser.add_argument("--devices", type=int, nargs="+", default=[2]) 159 | parser.add_argument("--batch_size", type=int, default=12) 160 | parser.add_argument("--thresholds", type=float, nargs="+", default=[1.0, 3.0]) 161 | args = parser.parse_args() 162 | 163 | main(vars(args)) 164 | -------------------------------------------------------------------------------- /examples/structure-optimization/bfgs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import copy 4 | import logging 5 | import os 6 | from typing import cast 7 | 8 | from ase import Atoms 9 | from ase.constraints import UnitCellFilter 10 | from ase.io import read, write 11 | from ase.optimize import BFGS 12 | 13 | from mattertune.backbones import JMPBackboneModule 14 | 15 | logging.basicConfig(level=logging.ERROR) 16 | logging.getLogger("lightning.pytorch").setLevel(logging.ERROR) 17 | 18 | 19 | def main(args_dict: dict): 20 | ## Load Checkpoint and Create ASE Calculator 21 | model = JMPBackboneModule.load_from_checkpoint( 22 | checkpoint_path=args_dict["checkpoint_path"], map_location="cpu" 23 | ) 24 | calc = model.ase_calculator(device=f"cuda:{args_dict['devices']}") 25 | 26 | ## Perform Structure Optimization with BFGS 27 | os.makedirs(args_dict["save_dir"], exist_ok=True) 28 | files = os.listdir(args_dict["init_structs"]) 29 | for file in files: 30 | if os.path.exists(os.path.join(args_dict["save_dir"], file)): 31 | continue 32 | atoms = read(os.path.join(args_dict["init_structs"], file)) 33 | assert isinstance(atoms, Atoms), "Expected an Atoms object" 34 | relax_traj = [] 35 | atoms.pbc = True 36 | atoms.calc = calc 37 | ucf = UnitCellFilter(atoms, scalar_pressure=0.0) 38 | ucf_as_atoms = cast(Atoms, ucf) 39 | # UnitCellFilter is not a subclass of Atoms, but it can fill the role of Atoms in nearly all contexts 40 | opt = BFGS(ucf_as_atoms, logfile=None) 41 | 42 | def write_traj(): 43 | 
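            # Snapshot the current geometry at every optimizer step so the full relaxation trajectory can be written out at the end.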
relax_traj.append(copy.deepcopy(atoms)) 44 | 
 45 |         opt.attach(write_traj) 
 46 |         opt.run(fmax=0.01) 
 47 |         write(os.path.join(args_dict["save_dir"], file), relax_traj) 
 48 | 
 49 | 
 50 | if __name__ == "__main__": 
 51 |     import argparse 
 52 | 
 53 |     parser = argparse.ArgumentParser() 
 54 |     parser.add_argument( 
 55 |         "--checkpoint_path", type=str, default="./checkpoints/jmp-best.ckpt" 
 56 |     ) 
 57 |     parser.add_argument("--init_structs", type=str, default="./ZnMn2O4_random") 
 58 |     parser.add_argument("--save_dir", type=str, default="./ZnMn2O4_mlrelaxed") 
 59 |     parser.add_argument("--max_steps", type=int, default=None) 
 60 |     parser.add_argument("--devices", type=int, default=0) 
 61 |     args_dict = vars(parser.parse_args()) 
 62 |     main(args_dict) 
 63 | -------------------------------------------------------------------------------- /examples/water-thermodynamics/README.md: -------------------------------------------------------------------------------- 1 | # Few-shot Ambient Water Experiment 
 2 | 
 3 | In this folder, we show an example of fine-tuning Foundation Models using only 30 samples of ambient water structures. 
 4 | 
 5 | To run the experiment, first set up the environment by following the [installation guide](https://fung-lab.github.io/MatterTune/installation.html). 
 6 | 
 7 | ## Quick Start 
 8 | 
 9 | Here we give an example of how to fine-tune the MatterSim model, the best-performing model on this experiment so far, with our scripts. 
 10 | 
 11 | First, set up the environment: 
 12 | 
 13 | ``` 
 14 | git clone https://github.com/microsoft/mattersim.git 
 15 | cd mattersim 
 16 | conda create -n mattersim python=3.10 -y 
 17 | conda activate mattersim 
 18 | pip install -e . 
 19 | pip install cython>=0.29.32 setuptools>=45 
 20 | python setup.py build_ext --inplace 
 21 | cd .. 
 22 | git clone https://github.com/Fung-Lab/MatterTune.git 
 23 | cd MatterTune 
 24 | pip install -e . 
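# note: MatterTune is installed into the same mattersim environment created above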
25 | ``` 26 | 
 27 | Then run the training script: 
 28 | 
 29 | ``` 
 30 | python water-finetune.py \ 
 31 |     --model_type "mattersim-1m" \ 
 32 |     --batch_size 16 \ 
 33 |     --lr 1e-4 \ 
 34 |     --devices 0 1 2 3 \ 
 35 |     --conservative 
 36 | ``` 
 37 | 
 38 | You can use the fine-tuned checkpoint to run an MD simulation: 
 39 | 
 40 | ``` 
 41 | python md.py --ckpt_path PATH_TO_CKPT 
 42 | ``` -------------------------------------------------------------------------------- /examples/water-thermodynamics/data/water_1000_eVAng-energy_reference.json: -------------------------------------------------------------------------------- 1 | { 2 | "1": -187.5707682110115, 3 | "2": 0.0, 4 | "3": 0.0, 5 | "4": 0.0, 6 | "5": 0.0, 7 | "6": 0.0, 8 | "7": 0.0, 9 | "8": -93.7853844982933 10 | } -------------------------------------------------------------------------------- /examples/water-thermodynamics/draw_rdf_plots.py: -------------------------------------------------------------------------------- 1 | import numpy as np 
 2 | import pandas as pd 
 3 | 
 4 | import matplotlib.pyplot as plt 
 5 | 
 6 | gOO_benchmark_data = pd.read_csv("./results/g_OO(r)-SkinnerBenmore2014.csv") 
 7 | gOO_benchmark_r_values = gOO_benchmark_data["r_values"] 
 8 | gOO_benchmark_g_values = gOO_benchmark_data["295.1K-g_oo"] 
 9 | r_max = 6.0 
 10 | indices = gOO_benchmark_r_values <= r_max 
 11 | gOO_benchmark_r_values = gOO_benchmark_r_values[indices] 
 12 | gOO_benchmark_g_values = gOO_benchmark_g_values[indices] 
 13 | 
 14 | gOO_mattersim_data = np.load("./results/mattersim-1m-30-refill-g_OO(r).npz") 
 15 | gOO_mattersim_r_values = gOO_mattersim_data["rdf_x"] 
 16 | gOO_mattersim_g_values = gOO_mattersim_data["rdf_y"] 
 17 | 
 18 | gOO_jmp_data = np.load("./results/jmp-s-30-refill-g_OO(r).npz") 
 19 | gOO_jmp_r_values = gOO_jmp_data["rdf_x"] 
 20 | gOO_jmp_g_values = gOO_jmp_data["rdf_y"] 
 21 | 
 22 | gOO_orb_data = np.load("./results/orb-v2-30-refill-g_OO(r).npz") 
 23 | gOO_orb_r_values = gOO_orb_data["rdf_x"] 
 24 | gOO_orb_g_values = gOO_orb_data["rdf_y"] 
 25 | 
 26 | gOO_eqv2_data = np.load("./results/eqv2-30-refill-g_OO(r).npz") 
 27 | gOO_eqv2_r_values = gOO_eqv2_data["rdf_x"] 
 28 | gOO_eqv2_g_values = gOO_eqv2_data["rdf_y"] 
 29 | 
 30 | gOO_mace_medium_data = np.load("./results/mace_medium-30-refill-g_OO(r).npz") 
 31 | gOO_mace_medium_r_values = gOO_mace_medium_data["rdf_x"] 
 32 | gOO_mace_medium_g_values = gOO_mace_medium_data["rdf_y"] 
 33 | 
 34 | gOO_mattersim_mpx2_data = np.load("./results/mattersim-1m-mpx2-g_OO(r).npz") 
 35 | gOO_mattersim_mpx2_r_values = gOO_mattersim_mpx2_data["rdf_x"] 
 36 | gOO_mattersim_mpx2_g_values = gOO_mattersim_mpx2_data["rdf_y"] 
 37 | 
 38 | plt.scatter(gOO_benchmark_r_values, gOO_benchmark_g_values, label="Experiment", color="black", marker="o", s=10) 
 39 | # plt.plot(gOO_mattersim_r_values, gOO_mattersim_g_values, label="MatterSim-V1-1M (30 samples)", color="#EA8379", linewidth=2) 
 40 | # plt.plot(gOO_jmp_r_values, gOO_jmp_g_values, label="JMP-S (30 samples)", color="#7DAEE0", linestyle="dashed", linewidth=2) 
 41 | # plt.plot(gOO_orb_r_values, gOO_orb_g_values, label="ORB-V2 (30 samples)", color="#B395BD", linestyle=":", linewidth=2) 
 42 | # plt.plot(gOO_eqv2_r_values, gOO_eqv2_g_values, label="EqV2-31M (30 samples)", color="#1B7C3D", linestyle="-.", linewidth=2) 
 43 | plt.plot(gOO_mace_medium_r_values, gOO_mace_medium_g_values, label="MACE-Medium (30 samples)", color="#EA8379", linestyle="--", linewidth=2) 
 44 | # plt.plot(gOO_mattersim_r_values, gOO_mattersim_mpx2_g_values, label="MatterSim-V1-1M-MPX2 (30 samples)", color="#7DAEE0", linestyle="dotted", linewidth=2) 
 45 | 
 46 | 
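# Label the axes, fix the plot range, and save the O-O radial distribution function comparison figure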
plt.xlabel(r"$r$ ($\AA$)") 47 | plt.ylabel(r"$g_{OO}(r)$") 48 | plt.xlim(0, r_max) 49 | plt.ylim(0, 6.0) 50 | plt.legend() 51 | plt.tight_layout() 52 | plt.savefig("./plots/g_OO(r)-comparison.png", dpi=300) -------------------------------------------------------------------------------- /examples/water-thermodynamics/energy_reference.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import logging 5 | 6 | import nshutils as nu 7 | import numpy as np 8 | from ase.calculators.singlepoint import SinglePointCalculator 9 | from ase.io import read, write 10 | 11 | import mattertune.configs as MC 12 | from mattertune.normalization import compute_per_atom_references 13 | 14 | nu.pretty() 15 | 16 | xyz_path = "./data/water_1000_eVAng.xyz" 17 | 18 | dataset_config = MC.XYZDatasetConfig(src=xyz_path) 19 | dataset = dataset_config.create_dataset() 20 | 21 | ref_dict = compute_per_atom_references( 22 | dataset=dataset, 23 | property=MC.EnergyPropertyConfig(loss=MC.MAELossConfig()), 24 | reference_model="ridge", 25 | ) 26 | 27 | filename = xyz_path.split("/")[-1].split(".")[0] 28 | json.dump(ref_dict, open(f"./data/{filename}-energy_reference.json", "w"), indent=4) 29 | logging.info(f"Saved energy reference to energy_reference.json") 30 | -------------------------------------------------------------------------------- /examples/water-thermodynamics/run.sh: -------------------------------------------------------------------------------- 1 | # #!/bin/bash 2 | # # Source the conda.sh script to enable 'conda' command 3 | # source /net/csefiles/coc-fung-cluster/lingyu/miniconda3/etc/profile.d/conda.sh 4 | 5 | # models=( 6 | # "mattersim-1m" 7 | # # "orb-v2" 8 | # # "jmp-s" 9 | # # "eqv2" 10 | # ) 11 | 12 | # batch_size=16 13 | 14 | # for model in "${models[@]}"; do 15 | # if [ $model == "mattersim-1m" ]; then 16 | # conda activate mattersim-tune 17 | # batch_size=16 18 | # elif [ $model == "orb-v2" ]; then 19 | # conda activate orb-tune 20 | # batch_size=8 21 | # elif [ $model == "jmp-s" ]; then 22 | # conda activate jmp-tune 23 | # batch_size=1 24 | # elif [ $model == "eqv2" ]; then 25 | # conda activate eqv2-tune 26 | # batch_size=4 27 | # fi 28 | # python water-finetune.py --down_sample_refill \ 29 | # --model_type $model \ 30 | # --train_down_sample 30 \ 31 | # --batch_size $batch_size \ 32 | # --conservative 33 | # python water-finetune.py \ 34 | # --model_type $model \ 35 | # --train_down_sample 900 \ 36 | # --batch_size $batch_size \ 37 | # --conservative 38 | # done 39 | 40 | #!/bin/bash 41 | # Source the conda.sh script to enable 'conda' command 42 | source /net/csefiles/coc-fung-cluster/lingyu/miniconda3/etc/profile.d/conda.sh 43 | 44 | # conda activate orbv3-tune 45 | # batch_size=4 46 | # python water-finetune.py \ 47 | # --model_type "orb-v3-conservative-inf-omat" \ 48 | # --batch_size $batch_size \ 49 | # --lr 1e-4 \ 50 | # --devices 0 1 2 3 \ 51 | # --conservative 52 | 53 | # conda activate orbv3-tune 54 | # batch_size=6 55 | # python water-finetune.py \ 56 | # --model_type "orb-v2" \ 57 | # --batch_size $batch_size \ 58 | # --lr 1e-4 \ 59 | # --devices 0 1 2 3 60 | 61 | 62 | conda activate mattersim-tune 63 | batch_size=16 64 | python water-finetune.py \ 65 | --model_type "mattersim-1m" \ 66 | --batch_size 16 \ 67 | --lr 1e-4 \ 68 | --devices 2 \ 69 | --conservative -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [project] 2 | name = "mattertune" 3 | version = "0.1.0" 4 | description = "" 5 | authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }] 6 | readme = "README.md" 7 | requires-python = ">=3.10" 8 | dependencies = [ 9 | "torch", 10 | "ase", 11 | "scikit-learn", 12 | "lightning", 13 | "numpy", 14 | "torchmetrics", 15 | "nshconfig[extra]", 16 | "nshconfig-extra[extra]", 17 | "rich", 18 | "wandb", 19 | "nshutils", 20 | ] 21 | 22 | [project.optional-dependencies] 23 | dev = ["pytest", "pre-commit", "ruff"] 24 | 25 | [build-system] 26 | requires = ["setuptools >= 61.0"] 27 | build-backend = "setuptools.build_meta" 28 | 29 | [tool.pyright] 30 | typeCheckingMode = "standard" 31 | deprecateTypingAliases = true 32 | strictListInference = true 33 | strictDictionaryInference = true 34 | strictSetInference = true 35 | reportPrivateImportUsage = false 36 | reportMatchNotExhaustive = "error" 37 | reportImplicitOverride = "warning" 38 | reportShadowedImports = "warning" 39 | 40 | [tool.ruff.lint] 41 | select = ["FA102", "FA100", "F401"] 42 | ignore = ["F722", "F821", "E731", "E741"] 43 | 44 | [tool.ruff.lint.isort] 45 | required-imports = ["from __future__ import annotations"] 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nshconfig 2 | torch 3 | ase 4 | scikit-learn 5 | lightning 6 | numpy 7 | torchmetrics 8 | nshconfig[extra] 9 | nshconfig-extra[extra] 10 | -------------------------------------------------------------------------------- /src/mattertune/.nshconfig.generated.json: -------------------------------------------------------------------------------- 1 | { 2 | "module": "mattertune", 3 | "output": "configs", 4 | "typed_dicts": null, 5 | "json_schemas": null 6 | } -------------------------------------------------------------------------------- /src/mattertune/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .finetune.base import FinetuneModuleBase as FinetuneModuleBase 4 | from .main import MatterTuner as MatterTuner 5 | from .registry import backbone_registry as backbone_registry 6 | from .registry import data_registry as data_registry 7 | 8 | try: 9 | from . 
import configs as configs 10 | except ImportError: 11 | pass 12 | -------------------------------------------------------------------------------- /src/mattertune/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Annotated 4 | 5 | from typing_extensions import TypeAliasType 6 | 7 | from ..finetune.base import FinetuneModuleBaseConfig 8 | from ..registry import backbone_registry 9 | from .eqV2 import EqV2BackboneConfig as EqV2BackboneConfig 10 | from .eqV2 import EqV2BackboneModule as EqV2BackboneModule 11 | from .jmp import JMPBackboneConfig as JMPBackboneConfig 12 | from .jmp import JMPBackboneModule as JMPBackboneModule 13 | from .m3gnet import M3GNetBackboneConfig as M3GNetBackboneConfig 14 | from .m3gnet import M3GNetBackboneModule as M3GNetBackboneModule 15 | from .mattersim import MatterSimBackboneConfig as MatterSimBackboneConfig 16 | from .mattersim import MatterSimM3GNetBackboneModule as MatterSimM3GNetBackboneModule 17 | from .orb import ORBBackboneConfig as ORBBackboneConfig 18 | from .orb import ORBBackboneModule as ORBBackboneModule 19 | from .mace_foundation import MACEBackboneConfig as MACEBackboneConfig 20 | from .mace_foundation import MACEBackboneModule as MACEBackboneModule 21 | 22 | ModelConfig = TypeAliasType( 23 | "ModelConfig", 24 | Annotated[ 25 | FinetuneModuleBaseConfig, 26 | backbone_registry.DynamicResolution(), 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /src/mattertune/backbones/eqV2/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .model import EqV2BackboneConfig as EqV2BackboneConfig 4 | from .model import EqV2BackboneModule as EqV2BackboneModule 5 | -------------------------------------------------------------------------------- /src/mattertune/backbones/jmp/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .model import JMPBackboneConfig as JMPBackboneConfig 4 | from .model import JMPBackboneModule as JMPBackboneModule 5 | -------------------------------------------------------------------------------- /src/mattertune/backbones/jmp/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch.nn as nn 4 | 5 | from ...util import optional_import_error_message 6 | 7 | 8 | def get_activation_cls(activation: str) -> type[nn.Module]: 9 | """ 10 | Get the activation class from the activation name 11 | """ 12 | match activation.lower(): 13 | case "relu": 14 | return nn.ReLU 15 | case "silu" | "swish": 16 | return nn.SiLU 17 | case "scaled_silu" | "scaled_swish": 18 | with optional_import_error_message("jmp"): 19 | from jmp.models.gemnet.layers.base_layers import ScaledSiLU # type: ignore[reportMissingImports] # noqa 20 | 21 | return ScaledSiLU 22 | case "tanh": 23 | return nn.Tanh 24 | case "sigmoid": 25 | return nn.Sigmoid 26 | case "identity": 27 | return nn.Identity 28 | case _: 29 | raise ValueError(f"Activation {activation} is not supported") 30 | -------------------------------------------------------------------------------- /src/mattertune/backbones/m3gnet/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .model 
import M3GNetBackboneConfig as M3GNetBackboneConfig 4 | from .model import M3GNetBackboneModule as M3GNetBackboneModule 5 | from .model import M3GNetGraphComputerConfig as M3GNetGraphComputerConfig 6 | -------------------------------------------------------------------------------- /src/mattertune/backbones/mace_foundation/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .model import MACEBackboneConfig, MACEBackboneModule 4 | -------------------------------------------------------------------------------- /src/mattertune/backbones/mattersim/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .model import MatterSimBackboneConfig as MatterSimBackboneConfig 4 | from .model import MatterSimGraphConvertorConfig as MatterSimGraphConvertorConfig 5 | from .model import MatterSimM3GNetBackboneModule as MatterSimM3GNetBackboneModule 6 | -------------------------------------------------------------------------------- /src/mattertune/backbones/orb/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .model import ORBBackboneConfig as ORBBackboneConfig 4 | from .model import ORBBackboneModule as ORBBackboneModule 5 | -------------------------------------------------------------------------------- /src/mattertune/backbones/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | 5 | 6 | def voigt_6_to_full_3x3_stress_torch(stress_vector: torch.Tensor) -> torch.Tensor: 7 | """ 8 | Form a 3x3 stress matrix from a 6 component vector in Voigt notation 9 | 10 | Args: 11 | stress_vector: Tensor of shape (B, 6) where B is the batch size 12 | 13 | Returns: 14 | Tensor of shape (B, 3, 3) 15 | """ 16 | # Unpack the components 17 | s1, s2, s3, s4, s5, s6 = stress_vector.unbind(dim=1) 18 | 19 | # Stack the components into a 3x3 matrix 20 | # Each s_i is of shape (B,) 21 | stress_matrix = torch.stack( 22 | [ 23 | torch.stack([s1, s6, s5], dim=1), 24 | torch.stack([s6, s2, s4], dim=1), 25 | torch.stack([s5, s4, s3], dim=1), 26 | ], 27 | dim=1, 28 | ) 29 | 30 | return stress_matrix 31 | -------------------------------------------------------------------------------- /src/mattertune/callbacks/early_stopping.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Literal 4 | 5 | import nshconfig as C 6 | 7 | 8 | class EarlyStoppingConfig(C.Config): 9 | monitor: str = "val/total_loss" 10 | """Quantity to be monitored.""" 11 | 12 | min_delta: float = 0.0 13 | """Minimum change in monitored quantity to qualify as an improvement. Changes of less than or equal to 14 | `min_delta` will count as no improvement. Default: ``0.0``.""" 15 | 16 | patience: int = 3 17 | """Number of validation checks with no improvement after which training will be stopped. Default: ``3``.""" 18 | 19 | verbose: bool = False 20 | """Whether to print messages when improvement is found or early stopping is triggered. Default: ``False``.""" 21 | 22 | mode: Literal["min", "max"] = "min" 23 | """One of 'min' or 'max'. In 'min' mode, training stops when monitored quantity stops decreasing; 24 | in 'max' mode it stops when the quantity stops increasing. 
Default: ``'min'``.""" 25 | 26 | strict: bool = True 27 | """Whether to raise an error if monitored metric is not found in validation metrics. Default: ``True``.""" 28 | 29 | check_finite: bool = True 30 | """Whether to stop training when the monitor becomes NaN or infinite. Default: ``True``.""" 31 | 32 | stopping_threshold: float | None = None 33 | """Stop training immediately once the monitored quantity reaches this threshold. Default: ``None``.""" 34 | 35 | divergence_threshold: float | None = None 36 | """Stop training as soon as the monitored quantity becomes worse than this threshold. Default: ``None``.""" 37 | 38 | check_on_train_epoch_end: bool | None = None 39 | """Whether to run early stopping at the end of training epoch. If False, check runs at validation end. 40 | Default: ``None``.""" 41 | 42 | log_rank_zero_only: bool = False 43 | """Whether to log the status of early stopping only for rank 0 process. Default: ``False``.""" 44 | 45 | def create_callback(self): 46 | from lightning.pytorch.callbacks.early_stopping import EarlyStopping 47 | 48 | """Creates an EarlyStopping callback instance from this config.""" 49 | return EarlyStopping( 50 | monitor=self.monitor, 51 | min_delta=self.min_delta, 52 | patience=self.patience, 53 | verbose=self.verbose, 54 | mode=self.mode, 55 | strict=self.strict, 56 | check_finite=self.check_finite, 57 | stopping_threshold=self.stopping_threshold, 58 | divergence_threshold=self.divergence_threshold, 59 | check_on_train_epoch_end=self.check_on_train_epoch_end, 60 | log_rank_zero_only=self.log_rank_zero_only, 61 | ) 62 | -------------------------------------------------------------------------------- /src/mattertune/callbacks/model_checkpoint.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import timedelta 4 | from typing import Literal 5 | 6 | import nshconfig as C 7 | 8 | 9 | class ModelCheckpointConfig(C.Config): 10 | dirpath: str | None = None 11 | """Directory to save the model file. Default: ``None``.""" 12 | 13 | filename: str | None = None 14 | """Checkpoint filename. Can contain named formatting options. Default: ``None``.""" 15 | 16 | monitor: str | None = None 17 | """Quantity to monitor. Default: ``None``.""" 18 | 19 | verbose: bool = False 20 | """Verbosity mode. Default: ``False``.""" 21 | 22 | save_last: Literal[True, False, "link"] | None = None 23 | """When True or "link", saves a 'last.ckpt' checkpoint when a checkpoint is saved. Default: ``None``.""" 24 | 25 | save_top_k: int = 1 26 | """If save_top_k=k, save k models with best monitored quantity. Default: ``1``.""" 27 | 28 | save_weights_only: bool = False 29 | """If True, only save model weights. Default: ``False``.""" 30 | 31 | mode: Literal["min", "max"] = "min" 32 | """One of {'min', 'max'}. For 'min' training stops when monitored quantity stops decreasing. Default: ``'min'``.""" 33 | 34 | auto_insert_metric_name: bool = True 35 | """Whether to automatically insert metric name in checkpoint filename. Default: ``True``.""" 36 | 37 | every_n_train_steps: int | None = None 38 | """Number of training steps between checkpoints. Default: ``None``.""" 39 | 40 | train_time_interval: timedelta | None = None 41 | """Checkpoints are monitored at the specified time interval. Default: ``None``.""" 42 | 43 | every_n_epochs: int | None = None 44 | """Number of epochs between checkpoints. 
Default: ``None``.""" 45 | 46 | save_on_train_epoch_end: bool | None = None 47 | """Whether to run checkpointing at end of training epoch. Default: ``None``.""" 48 | 49 | enable_version_counter: bool = True 50 | """Whether to append version to existing filenames. Default: ``True``.""" 51 | 52 | def create_callback(self): 53 | """Creates a ModelCheckpoint callback instance from this config.""" 54 | from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint 55 | 56 | return ModelCheckpoint( 57 | dirpath=self.dirpath, 58 | filename=self.filename, 59 | monitor=self.monitor, 60 | verbose=self.verbose, 61 | save_last=self.save_last, 62 | save_top_k=self.save_top_k, 63 | save_weights_only=self.save_weights_only, 64 | mode=self.mode, 65 | auto_insert_metric_name=self.auto_insert_metric_name, 66 | every_n_train_steps=self.every_n_train_steps, 67 | train_time_interval=self.train_time_interval, 68 | every_n_epochs=self.every_n_epochs, 69 | save_on_train_epoch_end=self.save_on_train_epoch_end, 70 | enable_version_counter=self.enable_version_counter, 71 | ) 72 | -------------------------------------------------------------------------------- /src/mattertune/configs/.gitattributes: -------------------------------------------------------------------------------- 1 | * linguist-generated=true 2 | -------------------------------------------------------------------------------- /src/mattertune/configs/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.backbones.jmp.model import CutoffsConfig as CutoffsConfig 4 | from mattertune.backbones import EqV2BackboneConfig as EqV2BackboneConfig 5 | from mattertune.backbones.eqV2.model import FAIRChemAtomsToGraphSystemConfig as FAIRChemAtomsToGraphSystemConfig 6 | from mattertune.backbones import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 7 | from mattertune.backbones import JMPBackboneConfig as JMPBackboneConfig 8 | from mattertune.backbones.jmp.model import JMPGraphComputerConfig as JMPGraphComputerConfig 9 | from mattertune.backbones import M3GNetBackboneConfig as M3GNetBackboneConfig 10 | from mattertune.backbones.m3gnet import M3GNetGraphComputerConfig as M3GNetGraphComputerConfig 11 | from mattertune.backbones.mace_foundation.model import MACEBackboneConfig as MACEBackboneConfig 12 | from mattertune.backbones import MatterSimBackboneConfig as MatterSimBackboneConfig 13 | from mattertune.backbones.mattersim import MatterSimGraphConvertorConfig as MatterSimGraphConvertorConfig 14 | from mattertune.backbones.jmp.model import MaxNeighborsConfig as MaxNeighborsConfig 15 | from mattertune.backbones import ORBBackboneConfig as ORBBackboneConfig 16 | from mattertune.backbones.orb.model import ORBSystemConfig as ORBSystemConfig 17 | 18 | from mattertune.backbones.jmp.model import CutoffsConfig as CutoffsConfig 19 | from mattertune.backbones import EqV2BackboneConfig as EqV2BackboneConfig 20 | from mattertune.backbones.eqV2.model import FAIRChemAtomsToGraphSystemConfig as FAIRChemAtomsToGraphSystemConfig 21 | from mattertune.backbones import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 22 | from mattertune.backbones import JMPBackboneConfig as JMPBackboneConfig 23 | from mattertune.backbones.jmp.model import JMPGraphComputerConfig as JMPGraphComputerConfig 24 | from mattertune.backbones import M3GNetBackboneConfig as M3GNetBackboneConfig 25 | from mattertune.backbones.m3gnet import M3GNetGraphComputerConfig as M3GNetGraphComputerConfig 26 | 
from mattertune.backbones.mace_foundation.model import MACEBackboneConfig as MACEBackboneConfig 27 | from mattertune.backbones import MatterSimBackboneConfig as MatterSimBackboneConfig 28 | from mattertune.backbones.mattersim import MatterSimGraphConvertorConfig as MatterSimGraphConvertorConfig 29 | from mattertune.backbones.jmp.model import MaxNeighborsConfig as MaxNeighborsConfig 30 | from mattertune.backbones import ModelConfig as ModelConfig 31 | from mattertune.backbones import ORBBackboneConfig as ORBBackboneConfig 32 | from mattertune.backbones.orb.model import ORBSystemConfig as ORBSystemConfig 33 | 34 | from mattertune.backbones import backbone_registry as backbone_registry 35 | 36 | from . import eqV2 as eqV2 37 | from . import jmp as jmp 38 | from . import m3gnet as m3gnet 39 | from . import mace_foundation as mace_foundation 40 | from . import mattersim as mattersim 41 | from . import orb as orb 42 | 43 | __all__ = [ 44 | "CutoffsConfig", 45 | "EqV2BackboneConfig", 46 | "FAIRChemAtomsToGraphSystemConfig", 47 | "FinetuneModuleBaseConfig", 48 | "JMPBackboneConfig", 49 | "JMPGraphComputerConfig", 50 | "M3GNetBackboneConfig", 51 | "M3GNetGraphComputerConfig", 52 | "MACEBackboneConfig", 53 | "MatterSimBackboneConfig", 54 | "MatterSimGraphConvertorConfig", 55 | "MaxNeighborsConfig", 56 | "ModelConfig", 57 | "ORBBackboneConfig", 58 | "ORBSystemConfig", 59 | "backbone_registry", 60 | "eqV2", 61 | "jmp", 62 | "m3gnet", 63 | "mace_foundation", 64 | "mattersim", 65 | "orb", 66 | ] 67 | -------------------------------------------------------------------------------- /src/mattertune/configs/backbones/eqV2/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.backbones.eqV2 import EqV2BackboneConfig as EqV2BackboneConfig 4 | from mattertune.backbones.eqV2.model import FAIRChemAtomsToGraphSystemConfig as FAIRChemAtomsToGraphSystemConfig 5 | from mattertune.backbones.eqV2.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 6 | 7 | from mattertune.backbones.eqV2 import EqV2BackboneConfig as EqV2BackboneConfig 8 | from mattertune.backbones.eqV2.model import FAIRChemAtomsToGraphSystemConfig as FAIRChemAtomsToGraphSystemConfig 9 | from mattertune.backbones.eqV2.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 10 | 11 | from mattertune.backbones.eqV2.model import backbone_registry as backbone_registry 12 | 13 | from . 
import model as model 14 | 15 | __all__ = [ 16 | "EqV2BackboneConfig", 17 | "FAIRChemAtomsToGraphSystemConfig", 18 | "FinetuneModuleBaseConfig", 19 | "backbone_registry", 20 | "model", 21 | ] 22 | -------------------------------------------------------------------------------- /src/mattertune/configs/backbones/eqV2/model/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.backbones.eqV2.model import EqV2BackboneConfig as EqV2BackboneConfig 4 | from mattertune.backbones.eqV2.model import FAIRChemAtomsToGraphSystemConfig as FAIRChemAtomsToGraphSystemConfig 5 | from mattertune.backbones.eqV2.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 6 | 7 | from mattertune.backbones.eqV2.model import EqV2BackboneConfig as EqV2BackboneConfig 8 | from mattertune.backbones.eqV2.model import FAIRChemAtomsToGraphSystemConfig as FAIRChemAtomsToGraphSystemConfig 9 | from mattertune.backbones.eqV2.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 10 | 11 | from mattertune.backbones.eqV2.model import backbone_registry as backbone_registry 12 | 13 | 14 | __all__ = [ 15 | "EqV2BackboneConfig", 16 | "FAIRChemAtomsToGraphSystemConfig", 17 | "FinetuneModuleBaseConfig", 18 | "backbone_registry", 19 | ] 20 | -------------------------------------------------------------------------------- /src/mattertune/configs/backbones/jmp/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.backbones.jmp.model import CutoffsConfig as CutoffsConfig 4 | from mattertune.backbones.jmp.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 5 | from mattertune.backbones.jmp import JMPBackboneConfig as JMPBackboneConfig 6 | from mattertune.backbones.jmp.model import JMPGraphComputerConfig as JMPGraphComputerConfig 7 | from mattertune.backbones.jmp.model import MaxNeighborsConfig as MaxNeighborsConfig 8 | 9 | from mattertune.backbones.jmp.model import CutoffsConfig as CutoffsConfig 10 | from mattertune.backbones.jmp.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 11 | from mattertune.backbones.jmp import JMPBackboneConfig as JMPBackboneConfig 12 | from mattertune.backbones.jmp.model import JMPGraphComputerConfig as JMPGraphComputerConfig 13 | from mattertune.backbones.jmp.model import MaxNeighborsConfig as MaxNeighborsConfig 14 | 15 | from mattertune.backbones.jmp.model import backbone_registry as backbone_registry 16 | 17 | from . 
import model as model 18 | 19 | __all__ = [ 20 | "CutoffsConfig", 21 | "FinetuneModuleBaseConfig", 22 | "JMPBackboneConfig", 23 | "JMPGraphComputerConfig", 24 | "MaxNeighborsConfig", 25 | "backbone_registry", 26 | "model", 27 | ] 28 | -------------------------------------------------------------------------------- /src/mattertune/configs/backbones/jmp/model/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.backbones.jmp.model import CutoffsConfig as CutoffsConfig 4 | from mattertune.backbones.jmp.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 5 | from mattertune.backbones.jmp.model import JMPBackboneConfig as JMPBackboneConfig 6 | from mattertune.backbones.jmp.model import JMPGraphComputerConfig as JMPGraphComputerConfig 7 | from mattertune.backbones.jmp.model import MaxNeighborsConfig as MaxNeighborsConfig 8 | 9 | from mattertune.backbones.jmp.model import CutoffsConfig as CutoffsConfig 10 | from mattertune.backbones.jmp.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 11 | from mattertune.backbones.jmp.model import JMPBackboneConfig as JMPBackboneConfig 12 | from mattertune.backbones.jmp.model import JMPGraphComputerConfig as JMPGraphComputerConfig 13 | from mattertune.backbones.jmp.model import MaxNeighborsConfig as MaxNeighborsConfig 14 | 15 | from mattertune.backbones.jmp.model import backbone_registry as backbone_registry 16 | 17 | 18 | __all__ = [ 19 | "CutoffsConfig", 20 | "FinetuneModuleBaseConfig", 21 | "JMPBackboneConfig", 22 | "JMPGraphComputerConfig", 23 | "MaxNeighborsConfig", 24 | "backbone_registry", 25 | ] 26 | -------------------------------------------------------------------------------- /src/mattertune/configs/backbones/m3gnet/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.backbones.m3gnet.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 4 | from mattertune.backbones.m3gnet import M3GNetBackboneConfig as M3GNetBackboneConfig 5 | from mattertune.backbones.m3gnet import M3GNetGraphComputerConfig as M3GNetGraphComputerConfig 6 | 7 | from mattertune.backbones.m3gnet.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 8 | from mattertune.backbones.m3gnet import M3GNetBackboneConfig as M3GNetBackboneConfig 9 | from mattertune.backbones.m3gnet import M3GNetGraphComputerConfig as M3GNetGraphComputerConfig 10 | 11 | from mattertune.backbones.m3gnet.model import backbone_registry as backbone_registry 12 | 13 | from . 
import model as model 14 | 15 | __all__ = [ 16 | "FinetuneModuleBaseConfig", 17 | "M3GNetBackboneConfig", 18 | "M3GNetGraphComputerConfig", 19 | "backbone_registry", 20 | "model", 21 | ] 22 | -------------------------------------------------------------------------------- /src/mattertune/configs/backbones/m3gnet/model/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.backbones.m3gnet.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 4 | from mattertune.backbones.m3gnet.model import M3GNetBackboneConfig as M3GNetBackboneConfig 5 | from mattertune.backbones.m3gnet.model import M3GNetGraphComputerConfig as M3GNetGraphComputerConfig 6 | 7 | from mattertune.backbones.m3gnet.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 8 | from mattertune.backbones.m3gnet.model import M3GNetBackboneConfig as M3GNetBackboneConfig 9 | from mattertune.backbones.m3gnet.model import M3GNetGraphComputerConfig as M3GNetGraphComputerConfig 10 | 11 | from mattertune.backbones.m3gnet.model import backbone_registry as backbone_registry 12 | 13 | 14 | __all__ = [ 15 | "FinetuneModuleBaseConfig", 16 | "M3GNetBackboneConfig", 17 | "M3GNetGraphComputerConfig", 18 | "backbone_registry", 19 | ] 20 | -------------------------------------------------------------------------------- /src/mattertune/configs/backbones/mace_foundation/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.backbones.mace_foundation.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 4 | from mattertune.backbones.mace_foundation.model import MACEBackboneConfig as MACEBackboneConfig 5 | 6 | from mattertune.backbones.mace_foundation.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 7 | from mattertune.backbones.mace_foundation.model import MACEBackboneConfig as MACEBackboneConfig 8 | 9 | from mattertune.backbones.mace_foundation.model import backbone_registry as backbone_registry 10 | 11 | from . 
import model as model 12 | 13 | __all__ = [ 14 | "FinetuneModuleBaseConfig", 15 | "MACEBackboneConfig", 16 | "backbone_registry", 17 | "model", 18 | ] 19 | -------------------------------------------------------------------------------- /src/mattertune/configs/backbones/mace_foundation/model/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.backbones.mace_foundation.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 4 | from mattertune.backbones.mace_foundation.model import MACEBackboneConfig as MACEBackboneConfig 5 | 6 | from mattertune.backbones.mace_foundation.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 7 | from mattertune.backbones.mace_foundation.model import MACEBackboneConfig as MACEBackboneConfig 8 | 9 | from mattertune.backbones.mace_foundation.model import backbone_registry as backbone_registry 10 | 11 | 12 | __all__ = [ 13 | "FinetuneModuleBaseConfig", 14 | "MACEBackboneConfig", 15 | "backbone_registry", 16 | ] 17 | -------------------------------------------------------------------------------- /src/mattertune/configs/backbones/mattersim/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.backbones.mattersim.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 4 | from mattertune.backbones.mattersim import MatterSimBackboneConfig as MatterSimBackboneConfig 5 | from mattertune.backbones.mattersim import MatterSimGraphConvertorConfig as MatterSimGraphConvertorConfig 6 | 7 | from mattertune.backbones.mattersim.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 8 | from mattertune.backbones.mattersim import MatterSimBackboneConfig as MatterSimBackboneConfig 9 | from mattertune.backbones.mattersim import MatterSimGraphConvertorConfig as MatterSimGraphConvertorConfig 10 | 11 | from mattertune.backbones.mattersim.model import backbone_registry as backbone_registry 12 | 13 | from . 
import model as model 14 | 15 | __all__ = [ 16 | "FinetuneModuleBaseConfig", 17 | "MatterSimBackboneConfig", 18 | "MatterSimGraphConvertorConfig", 19 | "backbone_registry", 20 | "model", 21 | ] 22 | -------------------------------------------------------------------------------- /src/mattertune/configs/backbones/mattersim/model/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.backbones.mattersim.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 4 | from mattertune.backbones.mattersim.model import MatterSimBackboneConfig as MatterSimBackboneConfig 5 | from mattertune.backbones.mattersim.model import MatterSimGraphConvertorConfig as MatterSimGraphConvertorConfig 6 | 7 | from mattertune.backbones.mattersim.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 8 | from mattertune.backbones.mattersim.model import MatterSimBackboneConfig as MatterSimBackboneConfig 9 | from mattertune.backbones.mattersim.model import MatterSimGraphConvertorConfig as MatterSimGraphConvertorConfig 10 | 11 | from mattertune.backbones.mattersim.model import backbone_registry as backbone_registry 12 | 13 | 14 | __all__ = [ 15 | "FinetuneModuleBaseConfig", 16 | "MatterSimBackboneConfig", 17 | "MatterSimGraphConvertorConfig", 18 | "backbone_registry", 19 | ] 20 | -------------------------------------------------------------------------------- /src/mattertune/configs/backbones/orb/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.backbones.orb.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 4 | from mattertune.backbones.orb import ORBBackboneConfig as ORBBackboneConfig 5 | from mattertune.backbones.orb.model import ORBSystemConfig as ORBSystemConfig 6 | 7 | from mattertune.backbones.orb.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 8 | from mattertune.backbones.orb import ORBBackboneConfig as ORBBackboneConfig 9 | from mattertune.backbones.orb.model import ORBSystemConfig as ORBSystemConfig 10 | 11 | from mattertune.backbones.orb.model import backbone_registry as backbone_registry 12 | 13 | from . 
import model as model 14 | 15 | __all__ = [ 16 | "FinetuneModuleBaseConfig", 17 | "ORBBackboneConfig", 18 | "ORBSystemConfig", 19 | "backbone_registry", 20 | "model", 21 | ] 22 | -------------------------------------------------------------------------------- /src/mattertune/configs/backbones/orb/model/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.backbones.orb.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 4 | from mattertune.backbones.orb.model import ORBBackboneConfig as ORBBackboneConfig 5 | from mattertune.backbones.orb.model import ORBSystemConfig as ORBSystemConfig 6 | 7 | from mattertune.backbones.orb.model import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 8 | from mattertune.backbones.orb.model import ORBBackboneConfig as ORBBackboneConfig 9 | from mattertune.backbones.orb.model import ORBSystemConfig as ORBSystemConfig 10 | 11 | from mattertune.backbones.orb.model import backbone_registry as backbone_registry 12 | 13 | 14 | __all__ = [ 15 | "FinetuneModuleBaseConfig", 16 | "ORBBackboneConfig", 17 | "ORBSystemConfig", 18 | "backbone_registry", 19 | ] 20 | -------------------------------------------------------------------------------- /src/mattertune/configs/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.callbacks.ema import EMAConfig as EMAConfig 4 | from mattertune.callbacks.early_stopping import EarlyStoppingConfig as EarlyStoppingConfig 5 | from mattertune.callbacks.model_checkpoint import ModelCheckpointConfig as ModelCheckpointConfig 6 | 7 | from mattertune.callbacks.ema import EMAConfig as EMAConfig 8 | from mattertune.callbacks.early_stopping import EarlyStoppingConfig as EarlyStoppingConfig 9 | from mattertune.callbacks.model_checkpoint import ModelCheckpointConfig as ModelCheckpointConfig 10 | 11 | 12 | from . import early_stopping as early_stopping 13 | from . import ema as ema 14 | from . 
import model_checkpoint as model_checkpoint 15 | 16 | __all__ = [ 17 | "EMAConfig", 18 | "EarlyStoppingConfig", 19 | "ModelCheckpointConfig", 20 | "early_stopping", 21 | "ema", 22 | "model_checkpoint", 23 | ] 24 | -------------------------------------------------------------------------------- /src/mattertune/configs/callbacks/early_stopping/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.callbacks.early_stopping import EarlyStoppingConfig as EarlyStoppingConfig 4 | 5 | from mattertune.callbacks.early_stopping import EarlyStoppingConfig as EarlyStoppingConfig 6 | 7 | 8 | 9 | __all__ = [ 10 | "EarlyStoppingConfig", 11 | ] 12 | -------------------------------------------------------------------------------- /src/mattertune/configs/callbacks/ema/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.callbacks.ema import EMAConfig as EMAConfig 4 | 5 | from mattertune.callbacks.ema import EMAConfig as EMAConfig 6 | 7 | 8 | 9 | __all__ = [ 10 | "EMAConfig", 11 | ] 12 | -------------------------------------------------------------------------------- /src/mattertune/configs/callbacks/model_checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.callbacks.model_checkpoint import ModelCheckpointConfig as ModelCheckpointConfig 4 | 5 | from mattertune.callbacks.model_checkpoint import ModelCheckpointConfig as ModelCheckpointConfig 6 | 7 | 8 | 9 | __all__ = [ 10 | "ModelCheckpointConfig", 11 | ] 12 | -------------------------------------------------------------------------------- /src/mattertune/configs/data/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.data.atoms_list import AtomsListDatasetConfig as AtomsListDatasetConfig 4 | from mattertune.data.datamodule import AutoSplitDataModuleConfig as AutoSplitDataModuleConfig 5 | from mattertune.data.db import DBDatasetConfig as DBDatasetConfig 6 | from mattertune.data.datamodule import DataModuleBaseConfig as DataModuleBaseConfig 7 | from mattertune.data import DatasetConfigBase as DatasetConfigBase 8 | from mattertune.data import JSONDatasetConfig as JSONDatasetConfig 9 | from mattertune.data import MPDatasetConfig as MPDatasetConfig 10 | from mattertune.data.mptraj import MPTrajDatasetConfig as MPTrajDatasetConfig 11 | from mattertune.data.datamodule import ManualSplitDataModuleConfig as ManualSplitDataModuleConfig 12 | from mattertune.data import MatbenchDatasetConfig as MatbenchDatasetConfig 13 | from mattertune.data import OMAT24DatasetConfig as OMAT24DatasetConfig 14 | from mattertune.data import XYZDatasetConfig as XYZDatasetConfig 15 | 16 | from mattertune.data.atoms_list import AtomsListDatasetConfig as AtomsListDatasetConfig 17 | from mattertune.data.datamodule import AutoSplitDataModuleConfig as AutoSplitDataModuleConfig 18 | from mattertune.data.db import DBDatasetConfig as DBDatasetConfig 19 | from mattertune.data.datamodule import DataModuleBaseConfig as DataModuleBaseConfig 20 | from mattertune.data import DataModuleConfig as DataModuleConfig 21 | from mattertune.data import DatasetConfig as DatasetConfig 22 | from mattertune.data import DatasetConfigBase as DatasetConfigBase 23 | from mattertune.data import JSONDatasetConfig as JSONDatasetConfig 24 | from mattertune.data import 
MPDatasetConfig as MPDatasetConfig 25 | from mattertune.data.mptraj import MPTrajDatasetConfig as MPTrajDatasetConfig 26 | from mattertune.data.datamodule import ManualSplitDataModuleConfig as ManualSplitDataModuleConfig 27 | from mattertune.data import MatbenchDatasetConfig as MatbenchDatasetConfig 28 | from mattertune.data import OMAT24DatasetConfig as OMAT24DatasetConfig 29 | from mattertune.data import XYZDatasetConfig as XYZDatasetConfig 30 | 31 | from mattertune.data.db import data_registry as data_registry 32 | 33 | from . import atoms_list as atoms_list 34 | from . import base as base 35 | from . import datamodule as datamodule 36 | from . import db as db 37 | from . import json_data as json_data 38 | from . import matbench as matbench 39 | from . import mp as mp 40 | from . import mptraj as mptraj 41 | from . import omat24 as omat24 42 | from . import xyz as xyz 43 | 44 | __all__ = [ 45 | "AtomsListDatasetConfig", 46 | "AutoSplitDataModuleConfig", 47 | "DBDatasetConfig", 48 | "DataModuleBaseConfig", 49 | "DataModuleConfig", 50 | "DatasetConfig", 51 | "DatasetConfigBase", 52 | "JSONDatasetConfig", 53 | "MPDatasetConfig", 54 | "MPTrajDatasetConfig", 55 | "ManualSplitDataModuleConfig", 56 | "MatbenchDatasetConfig", 57 | "OMAT24DatasetConfig", 58 | "XYZDatasetConfig", 59 | "atoms_list", 60 | "base", 61 | "data_registry", 62 | "datamodule", 63 | "db", 64 | "json_data", 65 | "matbench", 66 | "mp", 67 | "mptraj", 68 | "omat24", 69 | "xyz", 70 | ] 71 | -------------------------------------------------------------------------------- /src/mattertune/configs/data/atoms_list/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.data.atoms_list import AtomsListDatasetConfig as AtomsListDatasetConfig 4 | from mattertune.data.atoms_list import DatasetConfigBase as DatasetConfigBase 5 | 6 | from mattertune.data.atoms_list import AtomsListDatasetConfig as AtomsListDatasetConfig 7 | from mattertune.data.atoms_list import DatasetConfigBase as DatasetConfigBase 8 | 9 | from mattertune.data.atoms_list import data_registry as data_registry 10 | 11 | 12 | __all__ = [ 13 | "AtomsListDatasetConfig", 14 | "DatasetConfigBase", 15 | "data_registry", 16 | ] 17 | -------------------------------------------------------------------------------- /src/mattertune/configs/data/base/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.data.base import DatasetConfigBase as DatasetConfigBase 4 | 5 | from mattertune.data.base import DatasetConfig as DatasetConfig 6 | from mattertune.data.base import DatasetConfigBase as DatasetConfigBase 7 | 8 | from mattertune.data.base import data_registry as data_registry 9 | 10 | 11 | __all__ = [ 12 | "DatasetConfig", 13 | "DatasetConfigBase", 14 | "data_registry", 15 | ] 16 | -------------------------------------------------------------------------------- /src/mattertune/configs/data/datamodule/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.data.datamodule import AutoSplitDataModuleConfig as AutoSplitDataModuleConfig 4 | from mattertune.data.datamodule import DataModuleBaseConfig as DataModuleBaseConfig 5 | from mattertune.data.datamodule import ManualSplitDataModuleConfig as ManualSplitDataModuleConfig 6 | 7 | from mattertune.data.datamodule import AutoSplitDataModuleConfig as AutoSplitDataModuleConfig 8 | from 
mattertune.data.datamodule import DataModuleBaseConfig as DataModuleBaseConfig 9 | from mattertune.data.datamodule import DataModuleConfig as DataModuleConfig 10 | from mattertune.data.datamodule import DatasetConfig as DatasetConfig 11 | from mattertune.data.datamodule import ManualSplitDataModuleConfig as ManualSplitDataModuleConfig 12 | 13 | from mattertune.data.datamodule import data_registry as data_registry 14 | 15 | 16 | __all__ = [ 17 | "AutoSplitDataModuleConfig", 18 | "DataModuleBaseConfig", 19 | "DataModuleConfig", 20 | "DatasetConfig", 21 | "ManualSplitDataModuleConfig", 22 | "data_registry", 23 | ] 24 | -------------------------------------------------------------------------------- /src/mattertune/configs/data/db/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.data.db import DBDatasetConfig as DBDatasetConfig 4 | from mattertune.data.db import DatasetConfigBase as DatasetConfigBase 5 | 6 | from mattertune.data.db import DBDatasetConfig as DBDatasetConfig 7 | from mattertune.data.db import DatasetConfigBase as DatasetConfigBase 8 | 9 | from mattertune.data.db import data_registry as data_registry 10 | 11 | 12 | __all__ = [ 13 | "DBDatasetConfig", 14 | "DatasetConfigBase", 15 | "data_registry", 16 | ] 17 | -------------------------------------------------------------------------------- /src/mattertune/configs/data/json_data/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.data.json_data import DatasetConfigBase as DatasetConfigBase 4 | from mattertune.data.json_data import JSONDatasetConfig as JSONDatasetConfig 5 | 6 | from mattertune.data.json_data import DatasetConfigBase as DatasetConfigBase 7 | from mattertune.data.json_data import JSONDatasetConfig as JSONDatasetConfig 8 | 9 | from mattertune.data.json_data import data_registry as data_registry 10 | 11 | 12 | __all__ = [ 13 | "DatasetConfigBase", 14 | "JSONDatasetConfig", 15 | "data_registry", 16 | ] 17 | -------------------------------------------------------------------------------- /src/mattertune/configs/data/matbench/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.data.matbench import DatasetConfigBase as DatasetConfigBase 4 | from mattertune.data.matbench import MatbenchDatasetConfig as MatbenchDatasetConfig 5 | 6 | from mattertune.data.matbench import DatasetConfigBase as DatasetConfigBase 7 | from mattertune.data.matbench import MatbenchDatasetConfig as MatbenchDatasetConfig 8 | 9 | from mattertune.data.matbench import data_registry as data_registry 10 | 11 | 12 | __all__ = [ 13 | "DatasetConfigBase", 14 | "MatbenchDatasetConfig", 15 | "data_registry", 16 | ] 17 | -------------------------------------------------------------------------------- /src/mattertune/configs/data/mp/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.data.mp import DatasetConfigBase as DatasetConfigBase 4 | from mattertune.data.mp import MPDatasetConfig as MPDatasetConfig 5 | 6 | from mattertune.data.mp import DatasetConfigBase as DatasetConfigBase 7 | from mattertune.data.mp import MPDatasetConfig as MPDatasetConfig 8 | 9 | from mattertune.data.mp import data_registry as data_registry 10 | 11 | 12 | __all__ = [ 13 | "DatasetConfigBase", 14 | "MPDatasetConfig", 15 | 
"data_registry", 16 | ] 17 | -------------------------------------------------------------------------------- /src/mattertune/configs/data/mptraj/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.data.mptraj import DatasetConfigBase as DatasetConfigBase 4 | from mattertune.data.mptraj import MPTrajDatasetConfig as MPTrajDatasetConfig 5 | 6 | from mattertune.data.mptraj import DatasetConfigBase as DatasetConfigBase 7 | from mattertune.data.mptraj import MPTrajDatasetConfig as MPTrajDatasetConfig 8 | 9 | from mattertune.data.mptraj import data_registry as data_registry 10 | 11 | 12 | __all__ = [ 13 | "DatasetConfigBase", 14 | "MPTrajDatasetConfig", 15 | "data_registry", 16 | ] 17 | -------------------------------------------------------------------------------- /src/mattertune/configs/data/omat24/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.data.omat24 import DatasetConfigBase as DatasetConfigBase 4 | from mattertune.data.omat24 import OMAT24DatasetConfig as OMAT24DatasetConfig 5 | 6 | from mattertune.data.omat24 import DatasetConfigBase as DatasetConfigBase 7 | from mattertune.data.omat24 import OMAT24DatasetConfig as OMAT24DatasetConfig 8 | 9 | from mattertune.data.omat24 import data_registry as data_registry 10 | 11 | 12 | __all__ = [ 13 | "DatasetConfigBase", 14 | "OMAT24DatasetConfig", 15 | "data_registry", 16 | ] 17 | -------------------------------------------------------------------------------- /src/mattertune/configs/data/xyz/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.data.xyz import DatasetConfigBase as DatasetConfigBase 4 | from mattertune.data.xyz import XYZDatasetConfig as XYZDatasetConfig 5 | 6 | from mattertune.data.xyz import DatasetConfigBase as DatasetConfigBase 7 | from mattertune.data.xyz import XYZDatasetConfig as XYZDatasetConfig 8 | 9 | from mattertune.data.xyz import data_registry as data_registry 10 | 11 | 12 | __all__ = [ 13 | "DatasetConfigBase", 14 | "XYZDatasetConfig", 15 | "data_registry", 16 | ] 17 | -------------------------------------------------------------------------------- /src/mattertune/configs/finetune/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.finetune.optimizer import AdamConfig as AdamConfig 4 | from mattertune.finetune.optimizer import AdamWConfig as AdamWConfig 5 | from mattertune.finetune.lr_scheduler import ConstantLRConfig as ConstantLRConfig 6 | from mattertune.finetune.lr_scheduler import CosineAnnealingLRConfig as CosineAnnealingLRConfig 7 | from mattertune.finetune.properties import EnergyPropertyConfig as EnergyPropertyConfig 8 | from mattertune.finetune.lr_scheduler import ExponentialConfig as ExponentialConfig 9 | from mattertune.finetune.base import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 10 | from mattertune.finetune.properties import ForcesPropertyConfig as ForcesPropertyConfig 11 | from mattertune.finetune.properties import GraphPropertyConfig as GraphPropertyConfig 12 | from mattertune.finetune.loss import HuberLossConfig as HuberLossConfig 13 | from mattertune.finetune.loss import L2MAELossConfig as L2MAELossConfig 14 | from mattertune.finetune.lr_scheduler import LinearLRConfig as LinearLRConfig 15 | from mattertune.finetune.loss import 
MAELossConfig as MAELossConfig 16 | from mattertune.finetune.loss import MSELossConfig as MSELossConfig 17 | from mattertune.finetune.lr_scheduler import MultiStepLRConfig as MultiStepLRConfig 18 | from mattertune.finetune.optimizer import OptimizerConfigBase as OptimizerConfigBase 19 | from mattertune.finetune.properties import PropertyConfigBase as PropertyConfigBase 20 | from mattertune.finetune.base import ReduceOnPlateauConfig as ReduceOnPlateauConfig 21 | from mattertune.finetune.optimizer import SGDConfig as SGDConfig 22 | from mattertune.finetune.lr_scheduler import StepLRConfig as StepLRConfig 23 | from mattertune.finetune.properties import StressesPropertyConfig as StressesPropertyConfig 24 | 25 | from mattertune.finetune.optimizer import AdamConfig as AdamConfig 26 | from mattertune.finetune.optimizer import AdamWConfig as AdamWConfig 27 | from mattertune.finetune.lr_scheduler import ConstantLRConfig as ConstantLRConfig 28 | from mattertune.finetune.lr_scheduler import CosineAnnealingLRConfig as CosineAnnealingLRConfig 29 | from mattertune.finetune.properties import EnergyPropertyConfig as EnergyPropertyConfig 30 | from mattertune.finetune.lr_scheduler import ExponentialConfig as ExponentialConfig 31 | from mattertune.finetune.base import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 32 | from mattertune.finetune.properties import ForcesPropertyConfig as ForcesPropertyConfig 33 | from mattertune.finetune.properties import GraphPropertyConfig as GraphPropertyConfig 34 | from mattertune.finetune.loss import HuberLossConfig as HuberLossConfig 35 | from mattertune.finetune.loss import L2MAELossConfig as L2MAELossConfig 36 | from mattertune.finetune.lr_scheduler import LinearLRConfig as LinearLRConfig 37 | from mattertune.finetune.loss import LossConfig as LossConfig 38 | from mattertune.finetune.loss import MAELossConfig as MAELossConfig 39 | from mattertune.finetune.loss import MSELossConfig as MSELossConfig 40 | from mattertune.finetune.lr_scheduler import MultiStepLRConfig as MultiStepLRConfig 41 | from mattertune.finetune.base import NormalizerConfig as NormalizerConfig 42 | from mattertune.finetune.base import OptimizerConfig as OptimizerConfig 43 | from mattertune.finetune.optimizer import OptimizerConfigBase as OptimizerConfigBase 44 | from mattertune.finetune.base import PropertyConfig as PropertyConfig 45 | from mattertune.finetune.properties import PropertyConfigBase as PropertyConfigBase 46 | from mattertune.finetune.base import ReduceOnPlateauConfig as ReduceOnPlateauConfig 47 | from mattertune.finetune.optimizer import SGDConfig as SGDConfig 48 | from mattertune.finetune.lr_scheduler import SingleLRSchedulerConfig as SingleLRSchedulerConfig 49 | from mattertune.finetune.lr_scheduler import StepLRConfig as StepLRConfig 50 | from mattertune.finetune.properties import StressesPropertyConfig as StressesPropertyConfig 51 | 52 | 53 | from . import base as base 54 | from . import loss as loss 55 | from . import lr_scheduler as lr_scheduler 56 | from . import optimizer as optimizer 57 | from . 
import properties as properties 58 | 59 | __all__ = [ 60 | "AdamConfig", 61 | "AdamWConfig", 62 | "ConstantLRConfig", 63 | "CosineAnnealingLRConfig", 64 | "EnergyPropertyConfig", 65 | "ExponentialConfig", 66 | "FinetuneModuleBaseConfig", 67 | "ForcesPropertyConfig", 68 | "GraphPropertyConfig", 69 | "HuberLossConfig", 70 | "L2MAELossConfig", 71 | "LinearLRConfig", 72 | "LossConfig", 73 | "MAELossConfig", 74 | "MSELossConfig", 75 | "MultiStepLRConfig", 76 | "NormalizerConfig", 77 | "OptimizerConfig", 78 | "OptimizerConfigBase", 79 | "PropertyConfig", 80 | "PropertyConfigBase", 81 | "ReduceOnPlateauConfig", 82 | "SGDConfig", 83 | "SingleLRSchedulerConfig", 84 | "StepLRConfig", 85 | "StressesPropertyConfig", 86 | "base", 87 | "loss", 88 | "lr_scheduler", 89 | "optimizer", 90 | "properties", 91 | ] 92 | -------------------------------------------------------------------------------- /src/mattertune/configs/finetune/base/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.finetune.base import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 4 | from mattertune.finetune.base import ReduceOnPlateauConfig as ReduceOnPlateauConfig 5 | 6 | from mattertune.finetune.base import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 7 | from mattertune.finetune.base import NormalizerConfig as NormalizerConfig 8 | from mattertune.finetune.base import OptimizerConfig as OptimizerConfig 9 | from mattertune.finetune.base import PropertyConfig as PropertyConfig 10 | from mattertune.finetune.base import ReduceOnPlateauConfig as ReduceOnPlateauConfig 11 | 12 | 13 | 14 | __all__ = [ 15 | "FinetuneModuleBaseConfig", 16 | "NormalizerConfig", 17 | "OptimizerConfig", 18 | "PropertyConfig", 19 | "ReduceOnPlateauConfig", 20 | ] 21 | -------------------------------------------------------------------------------- /src/mattertune/configs/finetune/loss/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.finetune.loss import HuberLossConfig as HuberLossConfig 4 | from mattertune.finetune.loss import L2MAELossConfig as L2MAELossConfig 5 | from mattertune.finetune.loss import MAELossConfig as MAELossConfig 6 | from mattertune.finetune.loss import MSELossConfig as MSELossConfig 7 | 8 | from mattertune.finetune.loss import HuberLossConfig as HuberLossConfig 9 | from mattertune.finetune.loss import L2MAELossConfig as L2MAELossConfig 10 | from mattertune.finetune.loss import LossConfig as LossConfig 11 | from mattertune.finetune.loss import MAELossConfig as MAELossConfig 12 | from mattertune.finetune.loss import MSELossConfig as MSELossConfig 13 | 14 | 15 | 16 | __all__ = [ 17 | "HuberLossConfig", 18 | "L2MAELossConfig", 19 | "LossConfig", 20 | "MAELossConfig", 21 | "MSELossConfig", 22 | ] 23 | -------------------------------------------------------------------------------- /src/mattertune/configs/finetune/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.finetune.lr_scheduler import ConstantLRConfig as ConstantLRConfig 4 | from mattertune.finetune.lr_scheduler import CosineAnnealingLRConfig as CosineAnnealingLRConfig 5 | from mattertune.finetune.lr_scheduler import ExponentialConfig as ExponentialConfig 6 | from mattertune.finetune.lr_scheduler import LinearLRConfig as LinearLRConfig 7 | from mattertune.finetune.lr_scheduler import 
MultiStepLRConfig as MultiStepLRConfig 8 | from mattertune.finetune.lr_scheduler import ReduceOnPlateauConfig as ReduceOnPlateauConfig 9 | from mattertune.finetune.lr_scheduler import StepLRConfig as StepLRConfig 10 | 11 | from mattertune.finetune.lr_scheduler import ConstantLRConfig as ConstantLRConfig 12 | from mattertune.finetune.lr_scheduler import CosineAnnealingLRConfig as CosineAnnealingLRConfig 13 | from mattertune.finetune.lr_scheduler import ExponentialConfig as ExponentialConfig 14 | from mattertune.finetune.lr_scheduler import LinearLRConfig as LinearLRConfig 15 | from mattertune.finetune.lr_scheduler import MultiStepLRConfig as MultiStepLRConfig 16 | from mattertune.finetune.lr_scheduler import ReduceOnPlateauConfig as ReduceOnPlateauConfig 17 | from mattertune.finetune.lr_scheduler import SingleLRSchedulerConfig as SingleLRSchedulerConfig 18 | from mattertune.finetune.lr_scheduler import StepLRConfig as StepLRConfig 19 | 20 | 21 | 22 | __all__ = [ 23 | "ConstantLRConfig", 24 | "CosineAnnealingLRConfig", 25 | "ExponentialConfig", 26 | "LinearLRConfig", 27 | "MultiStepLRConfig", 28 | "ReduceOnPlateauConfig", 29 | "SingleLRSchedulerConfig", 30 | "StepLRConfig", 31 | ] 32 | -------------------------------------------------------------------------------- /src/mattertune/configs/finetune/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.finetune.optimizer import AdamConfig as AdamConfig 4 | from mattertune.finetune.optimizer import AdamWConfig as AdamWConfig 5 | from mattertune.finetune.optimizer import OptimizerConfigBase as OptimizerConfigBase 6 | from mattertune.finetune.optimizer import SGDConfig as SGDConfig 7 | 8 | from mattertune.finetune.optimizer import AdamConfig as AdamConfig 9 | from mattertune.finetune.optimizer import AdamWConfig as AdamWConfig 10 | from mattertune.finetune.optimizer import OptimizerConfig as OptimizerConfig 11 | from mattertune.finetune.optimizer import OptimizerConfigBase as OptimizerConfigBase 12 | from mattertune.finetune.optimizer import SGDConfig as SGDConfig 13 | 14 | 15 | 16 | __all__ = [ 17 | "AdamConfig", 18 | "AdamWConfig", 19 | "OptimizerConfig", 20 | "OptimizerConfigBase", 21 | "SGDConfig", 22 | ] 23 | -------------------------------------------------------------------------------- /src/mattertune/configs/finetune/properties/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.finetune.properties import EnergyPropertyConfig as EnergyPropertyConfig 4 | from mattertune.finetune.properties import ForcesPropertyConfig as ForcesPropertyConfig 5 | from mattertune.finetune.properties import GraphPropertyConfig as GraphPropertyConfig 6 | from mattertune.finetune.properties import PropertyConfigBase as PropertyConfigBase 7 | from mattertune.finetune.properties import StressesPropertyConfig as StressesPropertyConfig 8 | 9 | from mattertune.finetune.properties import EnergyPropertyConfig as EnergyPropertyConfig 10 | from mattertune.finetune.properties import ForcesPropertyConfig as ForcesPropertyConfig 11 | from mattertune.finetune.properties import GraphPropertyConfig as GraphPropertyConfig 12 | from mattertune.finetune.properties import LossConfig as LossConfig 13 | from mattertune.finetune.properties import PropertyConfig as PropertyConfig 14 | from mattertune.finetune.properties import PropertyConfigBase as PropertyConfigBase 15 | from 
mattertune.finetune.properties import StressesPropertyConfig as StressesPropertyConfig 16 | 17 | 18 | 19 | __all__ = [ 20 | "EnergyPropertyConfig", 21 | "ForcesPropertyConfig", 22 | "GraphPropertyConfig", 23 | "LossConfig", 24 | "PropertyConfig", 25 | "PropertyConfigBase", 26 | "StressesPropertyConfig", 27 | ] 28 | -------------------------------------------------------------------------------- /src/mattertune/configs/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.loggers import CSVLoggerConfig as CSVLoggerConfig 4 | from mattertune.loggers import TensorBoardLoggerConfig as TensorBoardLoggerConfig 5 | from mattertune.loggers import WandbLoggerConfig as WandbLoggerConfig 6 | 7 | from mattertune.loggers import CSVLoggerConfig as CSVLoggerConfig 8 | from mattertune.loggers import LoggerConfig as LoggerConfig 9 | from mattertune.loggers import TensorBoardLoggerConfig as TensorBoardLoggerConfig 10 | from mattertune.loggers import WandbLoggerConfig as WandbLoggerConfig 11 | 12 | 13 | 14 | __all__ = [ 15 | "CSVLoggerConfig", 16 | "LoggerConfig", 17 | "TensorBoardLoggerConfig", 18 | "WandbLoggerConfig", 19 | ] 20 | -------------------------------------------------------------------------------- /src/mattertune/configs/main/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.main import CSVLoggerConfig as CSVLoggerConfig 4 | from mattertune.main import EMAConfig as EMAConfig 5 | from mattertune.main import EarlyStoppingConfig as EarlyStoppingConfig 6 | from mattertune.main import MatterTunerConfig as MatterTunerConfig 7 | from mattertune.main import ModelCheckpointConfig as ModelCheckpointConfig 8 | from mattertune.main import TrainerConfig as TrainerConfig 9 | 10 | from mattertune.main import CSVLoggerConfig as CSVLoggerConfig 11 | from mattertune.main import DataModuleConfig as DataModuleConfig 12 | from mattertune.main import EMAConfig as EMAConfig 13 | from mattertune.main import EarlyStoppingConfig as EarlyStoppingConfig 14 | from mattertune.main import LoggerConfig as LoggerConfig 15 | from mattertune.main import MatterTunerConfig as MatterTunerConfig 16 | from mattertune.main import ModelCheckpointConfig as ModelCheckpointConfig 17 | from mattertune.main import ModelConfig as ModelConfig 18 | from mattertune.main import RecipeConfig as RecipeConfig 19 | from mattertune.main import TrainerConfig as TrainerConfig 20 | 21 | from mattertune.main import backbone_registry as backbone_registry 22 | from mattertune.main import data_registry as data_registry 23 | 24 | 25 | __all__ = [ 26 | "CSVLoggerConfig", 27 | "DataModuleConfig", 28 | "EMAConfig", 29 | "EarlyStoppingConfig", 30 | "LoggerConfig", 31 | "MatterTunerConfig", 32 | "ModelCheckpointConfig", 33 | "ModelConfig", 34 | "RecipeConfig", 35 | "TrainerConfig", 36 | "backbone_registry", 37 | "data_registry", 38 | ] 39 | -------------------------------------------------------------------------------- /src/mattertune/configs/normalization/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.normalization import MeanStdNormalizerConfig as MeanStdNormalizerConfig 4 | from mattertune.normalization import NormalizerConfigBase as NormalizerConfigBase 5 | from mattertune.normalization import PerAtomNormalizerConfig as PerAtomNormalizerConfig 6 | from mattertune.normalization 
import PerAtomReferencingNormalizerConfig as PerAtomReferencingNormalizerConfig 7 | from mattertune.normalization import RMSNormalizerConfig as RMSNormalizerConfig 8 | 9 | from mattertune.normalization import MeanStdNormalizerConfig as MeanStdNormalizerConfig 10 | from mattertune.normalization import NormalizerConfig as NormalizerConfig 11 | from mattertune.normalization import NormalizerConfigBase as NormalizerConfigBase 12 | from mattertune.normalization import PerAtomNormalizerConfig as PerAtomNormalizerConfig 13 | from mattertune.normalization import PerAtomReferencingNormalizerConfig as PerAtomReferencingNormalizerConfig 14 | from mattertune.normalization import PropertyConfig as PropertyConfig 15 | from mattertune.normalization import RMSNormalizerConfig as RMSNormalizerConfig 16 | 17 | 18 | 19 | __all__ = [ 20 | "MeanStdNormalizerConfig", 21 | "NormalizerConfig", 22 | "NormalizerConfigBase", 23 | "PerAtomNormalizerConfig", 24 | "PerAtomReferencingNormalizerConfig", 25 | "PropertyConfig", 26 | "RMSNormalizerConfig", 27 | ] 28 | -------------------------------------------------------------------------------- /src/mattertune/configs/recipes/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.recipes import EMARecipeConfig as EMARecipeConfig 4 | from mattertune.recipes import LoRARecipeConfig as LoRARecipeConfig 5 | from mattertune.recipes.lora import LoraConfig as LoraConfig 6 | from mattertune.recipes import NoOpRecipeConfig as NoOpRecipeConfig 7 | from mattertune.recipes.lora import PeftConfig as PeftConfig 8 | from mattertune.recipes import RecipeConfigBase as RecipeConfigBase 9 | 10 | from mattertune.recipes import EMARecipeConfig as EMARecipeConfig 11 | from mattertune.recipes import LoRARecipeConfig as LoRARecipeConfig 12 | from mattertune.recipes.lora import LoraConfig as LoraConfig 13 | from mattertune.recipes import NoOpRecipeConfig as NoOpRecipeConfig 14 | from mattertune.recipes.lora import PeftConfig as PeftConfig 15 | from mattertune.recipes import RecipeConfig as RecipeConfig 16 | from mattertune.recipes import RecipeConfigBase as RecipeConfigBase 17 | 18 | from mattertune.recipes import recipe_registry as recipe_registry 19 | 20 | from . import base as base 21 | from . import ema as ema 22 | from . import lora as lora 23 | from . 
import noop as noop 24 | 25 | __all__ = [ 26 | "EMARecipeConfig", 27 | "LoRARecipeConfig", 28 | "LoraConfig", 29 | "NoOpRecipeConfig", 30 | "PeftConfig", 31 | "RecipeConfig", 32 | "RecipeConfigBase", 33 | "base", 34 | "ema", 35 | "lora", 36 | "noop", 37 | "recipe_registry", 38 | ] 39 | -------------------------------------------------------------------------------- /src/mattertune/configs/recipes/base/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.recipes.base import RecipeConfigBase as RecipeConfigBase 4 | 5 | from mattertune.recipes.base import RecipeConfig as RecipeConfig 6 | from mattertune.recipes.base import RecipeConfigBase as RecipeConfigBase 7 | 8 | from mattertune.recipes.base import recipe_registry as recipe_registry 9 | 10 | 11 | __all__ = [ 12 | "RecipeConfig", 13 | "RecipeConfigBase", 14 | "recipe_registry", 15 | ] 16 | -------------------------------------------------------------------------------- /src/mattertune/configs/recipes/ema/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.recipes.ema import EMARecipeConfig as EMARecipeConfig 4 | from mattertune.recipes.ema import RecipeConfigBase as RecipeConfigBase 5 | 6 | from mattertune.recipes.ema import EMARecipeConfig as EMARecipeConfig 7 | from mattertune.recipes.ema import RecipeConfigBase as RecipeConfigBase 8 | 9 | from mattertune.recipes.ema import recipe_registry as recipe_registry 10 | 11 | 12 | __all__ = [ 13 | "EMARecipeConfig", 14 | "RecipeConfigBase", 15 | "recipe_registry", 16 | ] 17 | -------------------------------------------------------------------------------- /src/mattertune/configs/recipes/lora/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.recipes.lora import LoRARecipeConfig as LoRARecipeConfig 4 | from mattertune.recipes.lora import LoraConfig as LoraConfig 5 | from mattertune.recipes.lora import PeftConfig as PeftConfig 6 | from mattertune.recipes.lora import RecipeConfigBase as RecipeConfigBase 7 | 8 | from mattertune.recipes.lora import LoRARecipeConfig as LoRARecipeConfig 9 | from mattertune.recipes.lora import LoraConfig as LoraConfig 10 | from mattertune.recipes.lora import PeftConfig as PeftConfig 11 | from mattertune.recipes.lora import RecipeConfigBase as RecipeConfigBase 12 | 13 | from mattertune.recipes.lora import recipe_registry as recipe_registry 14 | 15 | 16 | __all__ = [ 17 | "LoRARecipeConfig", 18 | "LoraConfig", 19 | "PeftConfig", 20 | "RecipeConfigBase", 21 | "recipe_registry", 22 | ] 23 | -------------------------------------------------------------------------------- /src/mattertune/configs/recipes/noop/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.recipes.noop import NoOpRecipeConfig as NoOpRecipeConfig 4 | from mattertune.recipes.noop import RecipeConfigBase as RecipeConfigBase 5 | 6 | from mattertune.recipes.noop import NoOpRecipeConfig as NoOpRecipeConfig 7 | from mattertune.recipes.noop import RecipeConfigBase as RecipeConfigBase 8 | 9 | from mattertune.recipes.noop import recipe_registry as recipe_registry 10 | 11 | 12 | __all__ = [ 13 | "NoOpRecipeConfig", 14 | "RecipeConfigBase", 15 | "recipe_registry", 16 | ] 17 | -------------------------------------------------------------------------------- 
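The generated ``mattertune.configs`` tree shown above only re-exports config classes that are defined next to their implementations (here, the recipe configs from ``mattertune.recipes``), so both import paths are expected to resolve to the same class objects. A minimal sketch of that equivalence, assuming the package and the LoRA recipe's optional dependencies (e.g. ``peft``) are installed; the alias name below is illustrative only:

    from mattertune.recipes.lora import LoRARecipeConfig
    from mattertune.configs.recipes.lora import LoRARecipeConfig as LoRARecipeConfigAlias

    # The configs package re-exports the original definition rather than
    # redefining it, so both names should point at the identical class object.
    assert LoRARecipeConfig is LoRARecipeConfigAlias
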
/src/mattertune/configs/registry/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | from mattertune.registry import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 4 | 5 | from mattertune.registry import FinetuneModuleBaseConfig as FinetuneModuleBaseConfig 6 | 7 | from mattertune.registry import backbone_registry as backbone_registry 8 | from mattertune.registry import data_registry as data_registry 9 | 10 | 11 | __all__ = [ 12 | "FinetuneModuleBaseConfig", 13 | "backbone_registry", 14 | "data_registry", 15 | ] 16 | -------------------------------------------------------------------------------- /src/mattertune/configs/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | 4 | from mattertune.wrappers.property_predictor import PropertyConfig as PropertyConfig 5 | 6 | 7 | from . import property_predictor as property_predictor 8 | 9 | __all__ = [ 10 | "PropertyConfig", 11 | "property_predictor", 12 | ] 13 | -------------------------------------------------------------------------------- /src/mattertune/configs/wrappers/property_predictor/__init__.py: -------------------------------------------------------------------------------- 1 | __codegen__ = True 2 | 3 | 4 | from mattertune.wrappers.property_predictor import PropertyConfig as PropertyConfig 5 | 6 | 7 | 8 | __all__ = [ 9 | "PropertyConfig", 10 | ] 11 | -------------------------------------------------------------------------------- /src/mattertune/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .base import DatasetConfig as DatasetConfig 4 | from .base import DatasetConfigBase as DatasetConfigBase 5 | from .json_data import JSONDataset as JSONDataset 6 | from .json_data import JSONDatasetConfig as JSONDatasetConfig 7 | from .matbench import MatbenchDataset as MatbenchDataset 8 | from .matbench import MatbenchDatasetConfig as MatbenchDatasetConfig 9 | from .mp import MPDataset as MPDataset 10 | from .mp import MPDatasetConfig as MPDatasetConfig 11 | from .omat24 import OMAT24Dataset as OMAT24Dataset 12 | from .omat24 import OMAT24DatasetConfig as OMAT24DatasetConfig 13 | from .xyz import XYZDataset as XYZDataset 14 | from .xyz import XYZDatasetConfig as XYZDatasetConfig 15 | 16 | if True: 17 | pass 18 | 19 | from .datamodule import DataModuleConfig as DataModuleConfig 20 | from .datamodule import MatterTuneDataModule as MatterTuneDataModule 21 | -------------------------------------------------------------------------------- /src/mattertune/data/atoms_list.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import copy 4 | import logging 5 | from pathlib import Path 6 | from typing import Literal 7 | 8 | import ase 9 | import numpy as np 10 | from ase import Atoms 11 | from torch.utils.data import Dataset 12 | from typing_extensions import override 13 | 14 | from ..registry import data_registry 15 | from .base import DatasetConfigBase 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | 20 | @data_registry.register 21 | class AtomsListDatasetConfig(DatasetConfigBase): 22 | type: Literal["atoms_list"] = "atoms_list" 23 | """Discriminator for the atoms_list dataset.""" 24 | 25 | atoms_list: list[ase.Atoms] 26 | """The list of Atoms objects.""" 27 | 28 | @override 29 | def create_dataset(self): 30 | 
return AtomsListDataset(self) 31 | 32 | 33 | class AtomsListDataset(Dataset[ase.Atoms]): 34 | def __init__(self, config: AtomsListDatasetConfig): 35 | super().__init__() 36 | self.config = config 37 | 38 | atoms_list = self.config.atoms_list 39 | assert isinstance(atoms_list, list), "Expected a list of Atoms objects" 40 | shuffle_indices = np.random.permutation(len(atoms_list)) 41 | self.atoms_list = [atoms_list[i] for i in shuffle_indices] 42 | 43 | @override 44 | def __getitem__(self, idx: int) -> ase.Atoms: 45 | return self.atoms_list[idx] 46 | 47 | def __len__(self) -> int: 48 | return len(self.atoms_list) 49 | -------------------------------------------------------------------------------- /src/mattertune/data/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Annotated 5 | 6 | import nshconfig as C 7 | from ase import Atoms 8 | from torch.utils.data import Dataset 9 | from typing_extensions import TypeAliasType 10 | 11 | from ..registry import data_registry 12 | 13 | 14 | class DatasetConfigBase(C.Config, ABC): 15 | @abstractmethod 16 | def create_dataset(self) -> Dataset[Atoms]: ... 17 | 18 | def prepare_data(self): 19 | """ 20 | Prepare the dataset for training. 21 | 22 | Use this to download and prepare data. Downloading and saving data with multiple processes (distributed 23 | settings) will result in corrupted data. Lightning ensures this method is called only within a single process, 24 | so you can safely add your downloading logic within this method. 25 | """ 26 | pass 27 | 28 | @classmethod 29 | def ensure_dependencies(cls): 30 | """ 31 | Ensure that all dependencies are installed. 32 | 33 | This method should raise an exception if any dependencies are missing, 34 | with a message indicating which dependencies are missing and 35 | how to install them. 
36 | """ 37 | pass 38 | 39 | 40 | DatasetConfig = TypeAliasType( 41 | "DatasetConfig", 42 | Annotated[DatasetConfigBase, data_registry.DynamicResolution()], 43 | ) 44 | -------------------------------------------------------------------------------- /src/mattertune/data/db.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from pathlib import Path 5 | from typing import Literal 6 | 7 | import ase 8 | import numpy as np 9 | from ase.calculators.calculator import all_properties 10 | from ase.calculators.singlepoint import SinglePointCalculator 11 | from ase.db import connect 12 | from ase.db.core import Database 13 | from ase.stress import full_3x3_to_voigt_6_stress 14 | from torch.utils.data import Dataset 15 | from typing_extensions import override 16 | 17 | from ..registry import data_registry 18 | from .base import DatasetConfigBase 19 | 20 | log = logging.getLogger(__name__) 21 | 22 | 23 | @data_registry.register 24 | class DBDatasetConfig(DatasetConfigBase): 25 | """Configuration for a dataset stored in an ASE database.""" 26 | 27 | type: Literal["db"] = "db" 28 | """Discriminator for the DB dataset.""" 29 | 30 | src: Database | str | Path 31 | """Path to the ASE database file or a database object.""" 32 | 33 | energy_key: str | None = None 34 | """Key for the energy label in the database.""" 35 | 36 | forces_key: str | None = None 37 | """Key for the force label in the database.""" 38 | 39 | stress_key: str | None = None 40 | """Key for the stress label in the database.""" 41 | 42 | preload: bool = True 43 | """Whether to load all the data at once or not.""" 44 | 45 | @override 46 | def create_dataset(self): 47 | return DBDataset(self) 48 | 49 | 50 | class DBDataset(Dataset[ase.Atoms]): 51 | def __init__(self, config: DBDatasetConfig): 52 | super().__init__() 53 | self.config = config 54 | if isinstance(config.src, Database): 55 | self.db = config.src 56 | else: 57 | self.db = connect(config.src) 58 | if self.config.preload: 59 | self.atoms_list = [] 60 | for row in self.db.select(): 61 | atoms = self._load_atoms_from_row(row) 62 | self.atoms_list.append(atoms) 63 | 64 | def _load_atoms_from_row(self, row): 65 | atoms = row.toatoms() 66 | labels = dict(row.data) 67 | unrecognized_labels = {} 68 | if self.config.energy_key: 69 | labels["energy"] = labels.pop(self.config.energy_key) 70 | if self.config.forces_key: 71 | labels["forces"] = np.array(labels.pop(self.config.forces_key)) 72 | if self.config.stress_key: 73 | labels["stress"] = np.array(labels.pop(self.config.stress_key)) 74 | if labels["stress"].shape == (3, 3): 75 | labels["stress"] = full_3x3_to_voigt_6_stress(labels["stress"]) 76 | elif labels["stress"].shape != (6,): 77 | raise ValueError( 78 | f"Stress has unexpected shape: {labels['stress'].shape}, expected (3, 3) or (6,)" 79 | ) 80 | for key in list(labels.keys()): 81 | if key not in all_properties: 82 | unrecognized_labels[key] = labels.pop(key) 83 | calc = SinglePointCalculator(atoms, **labels) 84 | atoms.calc = calc 85 | atoms.info = unrecognized_labels 86 | return atoms 87 | 88 | @override 89 | def __getitem__(self, idx): 90 | if self.config.preload: 91 | return self.atoms_list[idx] 92 | else: 93 | row = self.db.get(idx=idx) 94 | return self._load_atoms_from_row(row) 95 | 96 | def __len__(self): 97 | return len(self.db) 98 | -------------------------------------------------------------------------------- /src/mattertune/data/json_data.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import logging 5 | from pathlib import Path 6 | from typing import Literal 7 | 8 | import numpy as np 9 | import torch 10 | from ase import Atoms 11 | from ase.calculators.singlepoint import SinglePointCalculator 12 | from torch.utils.data import Dataset 13 | from typing_extensions import override 14 | 15 | from ..registry import data_registry 16 | from .base import DatasetConfigBase 17 | 18 | log = logging.getLogger(__name__) 19 | 20 | 21 | @data_registry.register 22 | class JSONDatasetConfig(DatasetConfigBase): 23 | type: Literal["json"] = "json" 24 | """Discriminator for the JSON dataset.""" 25 | 26 | src: str | Path 27 | """The path to the JSON dataset.""" 28 | 29 | tasks: dict[str, str] 30 | """Attributes in the JSON file that correspond to the tasks to be predicted.""" 31 | 32 | @override 33 | def create_dataset(self): 34 | return JSONDataset(self) 35 | 36 | 37 | class JSONDataset(Dataset[Atoms]): 38 | def __init__(self, config: JSONDatasetConfig): 39 | super().__init__() 40 | self.config = config 41 | 42 | with open(str(self.config.src), "r") as f: 43 | raw_data = json.load(f) 44 | 45 | self.atoms_list = [] 46 | for entry in raw_data: 47 | atoms = Atoms( 48 | numbers=np.array(entry["atomic_numbers"]), 49 | positions=np.array(entry["positions"]), 50 | cell=np.array(entry["cell"]), 51 | pbc=True, 52 | ) 53 | 54 | energy, forces, stress = None, None, None 55 | if "energy" in self.config.tasks: 56 | energy = torch.tensor(entry[self.config.tasks["energy"]]) 57 | if "forces" in self.config.tasks: 58 | forces = torch.tensor(entry[self.config.tasks["forces"]]) 59 | if "stress" in self.config.tasks: 60 | stress = torch.tensor(entry[self.config.tasks["stress"]]) 61 | # ASE requires stress to be of shape (3, 3) or (6,) 62 | # Some datasets store stress with shape (1, 3, 3) 63 | if stress.ndim == 3: 64 | stress = stress.squeeze(0) 65 | 66 | single_point_calc = SinglePointCalculator( 67 | atoms, energy=energy, forces=forces, stress=stress 68 | ) 69 | 70 | atoms.calc = single_point_calc 71 | self.atoms_list.append(atoms) 72 | 73 | log.info(f"Loaded {len(self.atoms_list)} structures from {self.config.src}") 74 | 75 | @override 76 | def __getitem__(self, idx: int) -> Atoms: 77 | return self.atoms_list[idx] 78 | 79 | def __len__(self) -> int: 80 | return len(self.atoms_list) 81 | -------------------------------------------------------------------------------- /src/mattertune/data/matbench.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from typing import TYPE_CHECKING, Literal 5 | 6 | import ase 7 | from torch.utils.data import Dataset 8 | from typing_extensions import override 9 | 10 | from ..registry import data_registry 11 | from ..util import optional_import_error_message 12 | from .base import DatasetConfigBase 13 | 14 | if TYPE_CHECKING: 15 | from pymatgen.core.structure import Structure 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | 20 | @data_registry.register 21 | class MatbenchDatasetConfig(DatasetConfigBase): 22 | """Configuration for the Matbench dataset.""" 23 | 24 | type: Literal["matbench"] = "matbench" 25 | """Discriminator for the Matbench dataset.""" 26 | 27 | task: str | None = None 28 | """The name of the self.tasks to include in the dataset.""" 29 | 30 | property_name: str | None = None 31 | """Assign a property name for the 
self.task. Must match the property head in the model.""" 32 | 33 | fold_idx: int = 0 34 | """The index of the fold to be used in the dataset.""" 35 | 36 | @override 37 | def create_dataset(self): 38 | return MatbenchDataset(self) 39 | 40 | 41 | class MatbenchDataset(Dataset[ase.Atoms]): 42 | def __init__(self, config: MatbenchDatasetConfig): 43 | super().__init__() 44 | self.config = config 45 | self._initialize_benchmark() 46 | self._load_data() 47 | 48 | def _initialize_benchmark(self) -> None: 49 | """Initialize the Matbench benchmark and task.""" 50 | 51 | with optional_import_error_message("matbench"): 52 | from matbench.bench import MatbenchBenchmark # type: ignore[reportMissingImports] # noqa 53 | 54 | if self.config.task is None: 55 | mb = MatbenchBenchmark(autoload=False) 56 | all_tasks = list(mb.metadata.keys()) 57 | raise ValueError(f"Please specify a task from {all_tasks}") 58 | else: 59 | mb = MatbenchBenchmark(autoload=False, subset=[self.config.task]) 60 | self._task = list(mb.tasks)[0] 61 | self._task.load() 62 | 63 | def _load_data(self) -> None: 64 | """Load and process the dataset split.""" 65 | assert ( 66 | self.config.fold_idx >= 0 and self.config.fold_idx < 5 67 | ), "Invalid fold index, should be within [0, 1, 2, 3, 4]" 68 | fold = self._task.folds[self.config.fold_idx] 69 | inputs_data, outputs_data = self._task.get_train_and_val_data(fold) 70 | 71 | self._atoms_list = self._convert_structures_to_atoms(inputs_data, outputs_data) 72 | log.info( 73 | f"Loaded {len(self._atoms_list)} samples " f"(fold {self.config.fold_idx})" 74 | ) 75 | 76 | def _convert_structures_to_atoms( 77 | self, 78 | structures: list[Structure], 79 | property_values: list[float] | None = None, 80 | ) -> list[ase.Atoms]: 81 | """Convert pymatgen structures to ASE atoms. 82 | 83 | Args: 84 | structures: List of pymatgen Structure objects. 85 | property_values: Optional list of property values to add to atoms.info. 86 | 87 | Returns: 88 | List of ASE ase.Atoms objects. 89 | """ 90 | with optional_import_error_message("pymatgen"): 91 | from pymatgen.io.ase import AseAtomsAdaptor # type: ignore[reportMissingImports] # noqa 92 | 93 | adapter = AseAtomsAdaptor() 94 | atoms_list = [] 95 | prop_name = ( 96 | self.config.property_name 97 | if self.config.property_name is not None 98 | else self.config.task 99 | ) 100 | for i, structure in enumerate(structures): 101 | atoms = adapter.get_atoms(structure) 102 | assert isinstance(atoms, ase.Atoms), "Expected an Atoms object" 103 | if property_values is not None: 104 | atoms.info[prop_name] = property_values[i] 105 | atoms_list.append(atoms) 106 | 107 | return atoms_list 108 | 109 | @override 110 | def __getitem__(self, idx: int) -> ase.Atoms: 111 | """Get an item from the dataset by index.""" 112 | return self._atoms_list[idx] 113 | 114 | def __len__(self) -> int: 115 | """Get the total number of items in the dataset.""" 116 | return len(self._atoms_list) 117 | 118 | def get_test_data(self) -> list[ase.Atoms]: 119 | """Load the test data for the current task and fold. 120 | 121 | Returns: 122 | List of ASE ase.Atoms objects from the test set. 
123 | """ 124 | test_inputs = self._task.get_test_data( 125 | self._task.folds[self.config.fold_idx], include_target=False 126 | ) 127 | return self._convert_structures_to_atoms(test_inputs) 128 | -------------------------------------------------------------------------------- /src/mattertune/data/mp.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from typing import Literal 5 | 6 | import ase 7 | from ase import Atoms 8 | from torch.utils.data import Dataset 9 | from typing_extensions import override 10 | 11 | from ..registry import data_registry 12 | from ..util import optional_import_error_message 13 | from .base import DatasetConfigBase 14 | 15 | log = logging.getLogger(__name__) 16 | 17 | 18 | @data_registry.register 19 | class MPDatasetConfig(DatasetConfigBase): 20 | """Configuration for a dataset stored in the Materials Project database.""" 21 | 22 | type: Literal["mp"] = "mp" 23 | """Discriminator for the MP dataset.""" 24 | 25 | api: str 26 | """Input API key for the Materials Project database.""" 27 | 28 | fields: list[str] 29 | """Fields to retrieve from the Materials Project database.""" 30 | 31 | query: dict 32 | """Query to filter the data from the Materials Project database.""" 33 | 34 | @override 35 | def create_dataset(self): 36 | return MPDataset(self) 37 | 38 | 39 | class MPDataset(Dataset[ase.Atoms]): 40 | def __init__(self, config: MPDatasetConfig): 41 | super().__init__() 42 | self.config = config 43 | 44 | with optional_import_error_message("mp_api"): 45 | from mp_api.client import MPRester # type: ignore[reportMissingImports] 46 | 47 | self.mpr = MPRester(config.api) 48 | if "material_id" not in config.fields: 49 | config.fields.append("material_id") 50 | self.docs = self.mpr.summary.search(fields=config.fields, **config.query) 51 | 52 | @override 53 | def __getitem__(self, idx: int) -> Atoms: 54 | from pymatgen.io.ase import AseAtomsAdaptor 55 | 56 | doc = self.docs[idx] 57 | mid = doc.material_id 58 | structure = self.mpr.get_structure_by_material_id(mid) 59 | adaptor = AseAtomsAdaptor() 60 | atoms = adaptor.get_atoms(structure) 61 | assert isinstance(atoms, Atoms), "Expected an Atoms object" 62 | doc_labels = dict(doc) 63 | atoms.info = { 64 | key: doc_labels[key] 65 | for key in self.config.fields 66 | if key in doc_labels and key != "material_id" 67 | } 68 | return atoms 69 | 70 | def __len__(self) -> int: 71 | return len(self.docs) 72 | -------------------------------------------------------------------------------- /src/mattertune/data/mptraj.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from typing import Literal 5 | 6 | import ase 7 | from ase import Atoms 8 | from ase.calculators.singlepoint import SinglePointCalculator 9 | from ase.stress import full_3x3_to_voigt_6_stress 10 | from torch.utils.data import Dataset 11 | from tqdm import tqdm 12 | from typing_extensions import override 13 | 14 | from ..registry import data_registry 15 | from ..util import optional_import_error_message 16 | from .base import DatasetConfigBase 17 | 18 | log = logging.getLogger(__name__) 19 | 20 | 21 | @data_registry.register 22 | class MPTrajDatasetConfig(DatasetConfigBase): 23 | """Configuration for a dataset stored in the Materials Project database.""" 24 | 25 | type: Literal["mptraj"] = "mptraj" 26 | """Discriminator for the MPTraj dataset.""" 27 | 28 | split: 
Literal["train", "val", "test"] = "train" 29 | """Split of the dataset to use.""" 30 | 31 | min_num_atoms: int | None = 5 32 | """Minimum number of atoms to be considered. Drops structures with fewer atoms.""" 33 | 34 | max_num_atoms: int | None = None 35 | """Maximum number of atoms to be considered. Drops structures with more atoms.""" 36 | 37 | elements: list[str] | None = None 38 | """ 39 | List of elements to be considered. Drops structures with elements not in the list. 40 | Subsets are also allowed. For example, ["Li", "Na"] will keep structures with either Li or Na. 41 | """ 42 | 43 | @override 44 | def create_dataset(self): 45 | return MPTrajDataset(self) 46 | 47 | 48 | class MPTrajDataset(Dataset[ase.Atoms]): 49 | def __init__(self, config: MPTrajDatasetConfig): 50 | super().__init__() 51 | 52 | with optional_import_error_message("datasets"): 53 | import datasets # type: ignore[reportMissingImports] # noqa 54 | 55 | self.config = config 56 | 57 | dataset = datasets.load_dataset("nimashoghi/mptrj", split=self.config.split) 58 | assert isinstance(dataset, datasets.Dataset) 59 | dataset.set_format("numpy") 60 | self.atoms_list = [] 61 | pbar = tqdm(dataset, desc="Loading dataset...") 62 | for entry in dataset: 63 | atoms = self._load_atoms_from_entry(dict(entry)) 64 | if self._filter_atoms(atoms): 65 | self.atoms_list.append(atoms) 66 | pbar.update(1) 67 | pbar.close() 68 | 69 | def _load_atoms_from_entry(self, entry: dict) -> Atoms: 70 | atoms = Atoms( 71 | positions=entry["positions"], 72 | numbers=entry["numbers"], 73 | cell=entry["cell"], 74 | pbc=True, 75 | ) 76 | labels = { 77 | "energy": entry["corrected_total_energy"].item(), 78 | "forces": entry["forces"], 79 | "stress": full_3x3_to_voigt_6_stress(entry["stress"]), 80 | } 81 | calc = SinglePointCalculator(atoms, **labels) 82 | atoms.calc = calc 83 | return atoms 84 | 85 | def _filter_atoms(self, atoms: Atoms) -> bool: 86 | if ( 87 | self.config.min_num_atoms is not None 88 | and len(atoms) < self.config.min_num_atoms 89 | ): 90 | return False 91 | if ( 92 | self.config.max_num_atoms is not None 93 | and len(atoms) > self.config.max_num_atoms 94 | ): 95 | return False 96 | if self.config.elements is not None: 97 | elements = set(atoms.get_chemical_symbols()) 98 | if not set(self.config.elements) >= elements: 99 | return False 100 | return True 101 | 102 | @override 103 | def __getitem__(self, idx: int) -> Atoms: 104 | return self.atoms_list[idx] 105 | 106 | def __len__(self): 107 | return len(self.atoms_list) 108 | -------------------------------------------------------------------------------- /src/mattertune/data/omat24.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from typing import Literal 5 | 6 | import ase 7 | from torch.utils.data import Dataset 8 | from typing_extensions import override 9 | 10 | from ..registry import data_registry 11 | from ..util import optional_import_error_message 12 | from .base import DatasetConfigBase 13 | 14 | 15 | @data_registry.register 16 | class OMAT24DatasetConfig(DatasetConfigBase): 17 | type: Literal["omat24"] = "omat24" 18 | """Discriminator for the OMAT24 dataset.""" 19 | 20 | src: Path 21 | """The path to the OMAT24 dataset.""" 22 | 23 | @override 24 | def create_dataset(self): 25 | return OMAT24Dataset(self) 26 | 27 | 28 | class OMAT24Dataset(Dataset[ase.Atoms]): 29 | def __init__(self, config: OMAT24DatasetConfig): 30 | super().__init__() 31 | self.config = config 32 
| 33 | with optional_import_error_message("fairchem"): 34 | from fairchem.core.datasets import AseDBDataset # type: ignore[reportMissingImports] # noqa 35 | 36 | self.dataset = AseDBDataset(config={"src": str(self.config.src)}) 37 | 38 | @override 39 | def __getitem__(self, idx: int) -> ase.Atoms: 40 | atoms = self.dataset.get_atoms(idx) 41 | return atoms 42 | 43 | def __len__(self) -> int: 44 | return len(self.dataset) 45 | -------------------------------------------------------------------------------- /src/mattertune/data/util/split_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Sized 4 | from typing import Generic, TypeVar 5 | 6 | import ase 7 | import numpy as np 8 | from ase import Atoms 9 | from torch.utils.data import Dataset 10 | from typing_extensions import override 11 | 12 | TDataset = TypeVar("TDataset", bound=Dataset[ase.Atoms], covariant=True) 13 | 14 | 15 | class SplitDataset(Dataset[ase.Atoms], Generic[TDataset]): 16 | @override 17 | def __init__(self, dataset: TDataset, indices: np.ndarray): 18 | super().__init__() 19 | 20 | self.dataset = dataset 21 | self.indices = indices 22 | 23 | # Make sure the underlying dataset is a sized mappable dataset. 24 | if not isinstance(dataset, Sized): 25 | raise TypeError( 26 | f"The underlying dataset must be sized, but got {dataset!r}." 27 | ) 28 | 29 | # Make sure the indices are valid. 30 | if not np.issubdtype(indices.dtype, np.integer): 31 | raise TypeError(f"The indices must be integers, but got {indices.dtype!r}.") 32 | 33 | if not ((0 <= indices).all() and (indices < len(dataset)).all()): 34 | raise ValueError( 35 | f"The indices must be in the range [0, {len(dataset)}), but got [{indices.min()}, {indices.max()}]."
36 | ) 37 | 38 | def __len__(self) -> int: 39 | return len(self.indices) 40 | 41 | @override 42 | def __getitem__(self, index: int) -> Atoms: 43 | index = int(self.indices[index]) 44 | return self.dataset[index] 45 | -------------------------------------------------------------------------------- /src/mattertune/data/xyz.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from pathlib import Path 5 | from typing import Literal 6 | 7 | import ase 8 | from ase import Atoms 9 | from ase.io import read 10 | import numpy as np 11 | from torch.utils.data import Dataset 12 | from typing_extensions import override 13 | import copy 14 | 15 | from ..registry import data_registry 16 | from .base import DatasetConfigBase 17 | 18 | log = logging.getLogger(__name__) 19 | 20 | 21 | @data_registry.register 22 | class XYZDatasetConfig(DatasetConfigBase): 23 | type: Literal["xyz"] = "xyz" 24 | """Discriminator for the XYZ dataset.""" 25 | 26 | src: str | Path 27 | """The path to the XYZ dataset.""" 28 | 29 | down_sample: int | None = None 30 | """Down sample the dataset""" 31 | 32 | down_sample_refill: bool = False 33 | """Refill the dataset after down sampling to achieve the same length as the original dataset""" 34 | 35 | @override 36 | def create_dataset(self): 37 | return XYZDataset(self) 38 | 39 | 40 | class XYZDataset(Dataset[ase.Atoms]): 41 | def __init__(self, config: XYZDatasetConfig): 42 | super().__init__() 43 | self.config = config 44 | 45 | atoms_list = read(str(self.config.src), index=":") 46 | assert isinstance(atoms_list, list), "Expected a list of Atoms objects" 47 | if self.config.down_sample is not None: 48 | ori_length = len(atoms_list) 49 | down_indices = np.random.choice(ori_length, self.config.down_sample, replace=False) 50 | if self.config.down_sample_refill: 51 | refilled_down_indices = [] 52 | for _ in range((ori_length // self.config.down_sample)): 53 | refilled_down_indices.extend(copy.deepcopy(down_indices)) 54 | if len(refilled_down_indices) != ori_length: 55 | res = np.random.choice(len(down_indices), ori_length - len(refilled_down_indices), replace=False) 56 | refilled_down_indices.extend([down_indices[i] for i in res]) 57 | new_atoms_list = [copy.deepcopy(atoms_list[i]) for i in refilled_down_indices] 58 | atoms_list = new_atoms_list 59 | else: 60 | new_atoms_list = [copy.deepcopy(atoms_list[i]) for i in down_indices] 61 | atoms_list = new_atoms_list 62 | self.atoms_list: list[Atoms] = atoms_list 63 | log.info(f"Loaded {len(self.atoms_list)} atoms from {self.config.src}") 64 | 65 | @override 66 | def __getitem__(self, idx: int) -> ase.Atoms: 67 | return self.atoms_list[idx] 68 | 69 | def __len__(self) -> int: 70 | return len(self.atoms_list) 71 | -------------------------------------------------------------------------------- /src/mattertune/finetune/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/MatterTune/e82b3b8ed5f4cacd2d2df273caf039a79d380583/src/mattertune/finetune/__init__.py -------------------------------------------------------------------------------- /src/mattertune/finetune/callbacks/freeze_backbone.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 5 | from lightning.pytorch import LightningModule, Trainer 6 | from lightning.pytorch.callbacks import Callback 7 | from 
typing_extensions import override 8 | 9 | log = logging.getLogger(__name__) 10 | 11 | 12 | class FreezeBackboneCallback(Callback): 13 | @override 14 | def on_fit_start(self, trainer: Trainer, pl_module: LightningModule): 15 | super().on_fit_start(trainer, pl_module) 16 | 17 | # Make sure the pl_module is our MatterTune model 18 | from ..base import FinetuneModuleBase 19 | 20 | if not isinstance(pl_module, FinetuneModuleBase): 21 | log.warning( 22 | "The model is not a MatterTune model. The backbone will not be frozen." 23 | ) 24 | return 25 | 26 | # Freeze the backbone 27 | num_backbone_params = 0 28 | for backbone_param in pl_module.pretrained_backbone_parameters(): 29 | backbone_param.requires_grad = False 30 | num_backbone_params += backbone_param.numel() # count frozen weights, not first-dimension length 31 | 32 | log.info(f"Froze {num_backbone_params} backbone parameters.") 33 | -------------------------------------------------------------------------------- /src/mattertune/finetune/data_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable, Iterator, Sized 4 | from typing import Any, Generic 5 | 6 | from torch.utils.data import Dataset, IterableDataset 7 | from typing_extensions import TypeVar, override 8 | 9 | TDataIn = TypeVar("TDataIn", default=Any, infer_variance=True) 10 | TDataOut = TypeVar("TDataOut", default=Any, infer_variance=True) 11 | 12 | 13 | class MapDatasetWrapper(Dataset[TDataOut], Generic[TDataIn, TDataOut]): 14 | @override 15 | def __init__( 16 | self, 17 | dataset: Dataset[TDataIn], 18 | map_fn: Callable[[TDataIn], TDataOut], 19 | ): 20 | assert isinstance( 21 | dataset, Sized 22 | ), "The dataset must be sized. Otherwise, use IterableDatasetWrapper." 23 | self.dataset = dataset 24 | self.map_fn = map_fn 25 | 26 | def __len__(self) -> int: 27 | return len(self.dataset) 28 | 29 | @override 30 | def __getitem__(self, idx: int) -> TDataOut: 31 | return self.map_fn(self.dataset[idx]) 32 | 33 | 34 | class IterableDatasetWrapper(IterableDataset[TDataOut], Generic[TDataIn, TDataOut]): 35 | @override 36 | def __init__( 37 | self, 38 | dataset: IterableDataset[TDataIn], 39 | map_fn: Callable[[TDataIn], TDataOut], 40 | ): 41 | self.dataset = dataset 42 | self.map_fn = map_fn 43 | 44 | @override 45 | def __iter__(self) -> Iterator[TDataOut]: 46 | for data in self.dataset: 47 | yield self.map_fn(data) 48 | -------------------------------------------------------------------------------- /src/mattertune/finetune/loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Iterable 4 | from typing import TYPE_CHECKING, Any 5 | 6 | import ase 7 | from torch.utils.data import DataLoader, Dataset, IterableDataset, Sampler 8 | from torch.utils.data.dataloader import _worker_init_fn_t 9 | from typing_extensions import TypedDict, Unpack 10 | 11 | from .data_util import IterableDatasetWrapper, MapDatasetWrapper 12 | 13 | if TYPE_CHECKING: 14 | from .base import FinetuneModuleBase, TBatch, TData, TFinetuneModuleConfig 15 | 16 | 17 | class DataLoaderKwargs(TypedDict, total=False): 18 | """Keyword arguments for creating a DataLoader. 19 | 20 | Args: 21 | batch_size: How many samples per batch to load (default: 1). 22 | shuffle: Set to True to have the data reshuffled at every epoch (default: False). 23 | sampler: Defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ 24 | implemented.
If specified, shuffle must not be specified. 25 | batch_sampler: Like sampler, but returns a batch of indices at a time. Mutually exclusive with 26 | batch_size, shuffle, sampler, and drop_last. 27 | num_workers: How many subprocesses to use for data loading. 0 means that the data will be loaded 28 | in the main process (default: 0). 29 | pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory before 30 | returning them. 31 | drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by 32 | the batch size (default: False). 33 | timeout: If positive, the timeout value for collecting a batch from workers. Should always be 34 | non-negative (default: 0). 35 | worker_init_fn: If not None, this will be called on each worker subprocess with the worker id 36 | as input, after seeding and before data loading. 37 | multiprocessing_context: If None, the default multiprocessing context of your operating system 38 | will be used. 39 | generator: If not None, this RNG will be used by RandomSampler to generate random indexes and 40 | multiprocessing to generate base_seed for workers. 41 | prefetch_factor: Number of batches loaded in advance by each worker. 42 | persistent_workers: If True, the data loader will not shut down the worker processes after a 43 | dataset has been consumed once. 44 | pin_memory_device: The device to pin_memory to if pin_memory is True. 45 | """ 46 | 47 | batch_size: int | None 48 | shuffle: bool | None 49 | sampler: Sampler | Iterable | None 50 | batch_sampler: Sampler[list[int]] | Iterable[list[int]] | None 51 | num_workers: int 52 | pin_memory: bool 53 | drop_last: bool 54 | timeout: float 55 | worker_init_fn: _worker_init_fn_t | None 56 | multiprocessing_context: Any # type: ignore 57 | generator: Any # type: ignore 58 | prefetch_factor: int | None 59 | persistent_workers: bool 60 | pin_memory_device: str 61 | 62 | 63 | def create_dataloader( 64 | dataset: Dataset[ase.Atoms], 65 | has_labels: bool, 66 | *, 67 | lightning_module: FinetuneModuleBase[TData, TBatch, TFinetuneModuleConfig], 68 | **kwargs: Unpack[DataLoaderKwargs], 69 | ): 70 | def map_fn(ase_data: ase.Atoms): 71 | data = lightning_module.atoms_to_data(ase_data, has_labels) 72 | data = lightning_module.cpu_data_transform(data) 73 | return data 74 | 75 | # Wrap the dataset with the CPU data transform 76 | dataset_mapped = ( 77 | IterableDatasetWrapper(dataset, map_fn) 78 | if isinstance(dataset, IterableDataset) 79 | else MapDatasetWrapper(dataset, map_fn) 80 | ) 81 | # Create the data loader with the model's collate function 82 | dl = DataLoader(dataset_mapped, collate_fn=lightning_module.collate_fn, **kwargs) 83 | return dl 84 | -------------------------------------------------------------------------------- /src/mattertune/finetune/loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Annotated, Literal 4 | 5 | import nshconfig as C 6 | import torch 7 | import torch.nn.functional as F 8 | from typing_extensions import TypeAliasType, assert_never 9 | 10 | 11 | class MAELossConfig(C.Config): 12 | name: Literal["mae"] = "mae" 13 | reduction: Literal["mean", "sum"] = "mean" 14 | """How to reduce the loss values across the batch. 15 | 16 | - ``"mean"``: The mean of the loss values. 17 | - ``"sum"``: The sum of the loss values. 
18 | """ 19 | 20 | 21 | class MSELossConfig(C.Config): 22 | name: Literal["mse"] = "mse" 23 | reduction: Literal["mean", "sum"] = "mean" 24 | """How to reduce the loss values across the batch. 25 | 26 | - ``"mean"``: The mean of the loss values. 27 | - ``"sum"``: The sum of the loss values. 28 | """ 29 | 30 | 31 | class HuberLossConfig(C.Config): 32 | name: Literal["huber"] = "huber" 33 | delta: float = 1.0 34 | """The threshold value for the Huber loss function.""" 35 | reduction: Literal["mean", "sum"] = "mean" 36 | """How to reduce the loss values across the batch. 37 | 38 | - ``"mean"``: The mean of the loss values. 39 | - ``"sum"``: The sum of the loss values. 40 | """ 41 | 42 | 43 | class L2MAELossConfig(C.Config): 44 | name: Literal["l2_mae"] = "l2_mae" 45 | reduction: Literal["mean", "sum"] = "mean" 46 | """How to reduce the loss values across the batch. 47 | 48 | - ``"mean"``: The mean of the loss values. 49 | - ``"sum"``: The sum of the loss values. 50 | """ 51 | 52 | 53 | def l2_mae_loss( 54 | output: torch.Tensor, 55 | target: torch.Tensor, 56 | reduction: Literal["mean", "sum", "none"] = "mean", 57 | ) -> torch.Tensor: 58 | distances = F.pairwise_distance(output, target, p=2) 59 | match reduction: 60 | case "mean": 61 | return distances.mean() 62 | case "sum": 63 | return distances.sum() 64 | case "none": 65 | return distances 66 | case _: 67 | assert_never(reduction) 68 | 69 | 70 | LossConfig = TypeAliasType( 71 | "LossConfig", 72 | Annotated[ 73 | MAELossConfig | MSELossConfig | HuberLossConfig | L2MAELossConfig, 74 | C.Field(discriminator="name"), 75 | ], 76 | ) 77 | 78 | 79 | def compute_loss( 80 | config: LossConfig, 81 | prediction: torch.Tensor, 82 | label: torch.Tensor, 83 | ) -> torch.Tensor: 84 | """ 85 | Compute the loss value given the model output, ``prediction``, 86 | and the target label, ``label``. 87 | 88 | The loss value should be a scalar tensor. 89 | 90 | Args: 91 | config: The loss configuration. 92 | prediction: The model output. 93 | label: The target label. 94 | 95 | Returns: 96 | The computed loss value. 
97 | """ 98 | try: 99 | prediction = prediction.reshape(label.shape) 100 | except RuntimeError: 101 | raise ValueError( 102 | f"Prediction shape {prediction.shape} does not match ground truth shape {label.shape}" 103 | ) 104 | 105 | match config: 106 | case MAELossConfig(): 107 | return F.l1_loss(prediction, label, reduction=config.reduction) 108 | 109 | case MSELossConfig(): 110 | return F.mse_loss(prediction, label, reduction=config.reduction) 111 | 112 | case HuberLossConfig(): 113 | return F.huber_loss( 114 | prediction, label, delta=config.delta, reduction=config.reduction 115 | ) 116 | 117 | case L2MAELossConfig(): 118 | return l2_mae_loss(prediction, label, reduction=config.reduction) 119 | 120 | case _: 121 | assert_never(config) 122 | -------------------------------------------------------------------------------- /src/mattertune/finetune/metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from collections.abc import Mapping, Sequence 5 | from typing import TYPE_CHECKING, Any 6 | 7 | import torch.nn as nn 8 | import torchmetrics 9 | from typing_extensions import override 10 | 11 | if TYPE_CHECKING: 12 | from .properties import PropertyConfig 13 | 14 | 15 | class MetricBase(nn.Module, ABC): 16 | @override 17 | def __init__(self, property_name: str): 18 | super().__init__() 19 | 20 | self.property_name = property_name 21 | 22 | @abstractmethod 23 | @override 24 | def forward( 25 | self, prediction: dict[str, Any], ground_truth: dict[str, Any] 26 | ) -> Mapping[str, torchmetrics.Metric]: ... 27 | 28 | 29 | class PropertyMetrics(MetricBase): 30 | @override 31 | def __init__(self, property_name: str): 32 | super().__init__(property_name) 33 | 34 | self.mae = torchmetrics.MeanAbsoluteError() 35 | self.mse = torchmetrics.MeanSquaredError(squared=True) 36 | self.rmse = torchmetrics.MeanSquaredError(squared=False) 37 | # self.r2 = torchmetrics.R2Score() 38 | 39 | @override 40 | def forward( 41 | self, 42 | prediction: dict[str, Any], 43 | ground_truth: dict[str, Any], 44 | ): 45 | y_hat, y = prediction[self.property_name], ground_truth[self.property_name] 46 | try: 47 | y_hat = y_hat.reshape(y.shape) 48 | except RuntimeError: 49 | raise ValueError( 50 | f"Prediction shape {y_hat.shape} does not match ground truth shape {y.shape}" 51 | ) 52 | self.mae(y_hat, y) 53 | self.mse(y_hat, y) 54 | self.rmse(y_hat, y) 55 | # self.r2(y_hat, y) 56 | 57 | return { 58 | f"{self.property_name}_mae": self.mae, 59 | f"{self.property_name}_mse": self.mse, 60 | f"{self.property_name}_rmse": self.rmse, 61 | # f"{self.property_name}_r2": self.r2, 62 | } 63 | 64 | 65 | class FinetuneMetrics(nn.Module): 66 | def __init__( 67 | self, 68 | properties: Sequence[PropertyConfig], 69 | metric_prefix: str = "", 70 | ): 71 | super().__init__() 72 | 73 | self.metric_modules = nn.ModuleList( 74 | [prop.metric_cls()(prop.name) for prop in properties] 75 | ) 76 | 77 | self.metric_prefix = metric_prefix 78 | 79 | @override 80 | def forward( 81 | self, predictions: dict[str, Any], labels: dict[str, Any] 82 | ) -> Mapping[str, torchmetrics.Metric]: 83 | metrics = {} 84 | 85 | for metric_module in self.metric_modules: 86 | metrics.update(metric_module(predictions, labels)) 87 | 88 | return { 89 | f"{self.metric_prefix}{metric_name}": metric 90 | for metric_name, metric in metrics.items() 91 | } 92 | -------------------------------------------------------------------------------- 
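As a quick illustration of the loss configs defined in src/mattertune/finetune/loss.py above, here is a minimal sketch (not part of the repository) of calling compute_loss with two of the discriminated configs; the tensor values are arbitrary:

    import torch

    from mattertune.finetune.loss import HuberLossConfig, MAELossConfig, compute_loss

    prediction = torch.tensor([1.0, 2.0, 3.0])
    label = torch.tensor([1.5, 2.0, 2.0])

    # Mean absolute error with the default "mean" reduction.
    mae = compute_loss(MAELossConfig(), prediction, label)

    # Huber loss: quadratic below `delta`, linear above it.
    huber = compute_loss(HuberLossConfig(delta=0.5), prediction, label)

    print(float(mae), float(huber))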
/src/mattertune/loggers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Annotated, Any, Literal 4 | 5 | import nshconfig as C 6 | from typing_extensions import TypeAliasType 7 | 8 | 9 | class CSVLoggerConfig(C.Config): 10 | type: Literal["csv"] = "csv" 11 | 12 | save_dir: str 13 | """Save directory for logs.""" 14 | 15 | name: str = "lightning_logs" 16 | """Experiment name. Default: ``'lightning_logs'``.""" 17 | 18 | version: int | str | None = None 19 | """Experiment version. If not specified, automatically assigns the next available version. 20 | Default: ``None``.""" 21 | 22 | prefix: str = "" 23 | """String to put at the beginning of metric keys. Default: ``''``.""" 24 | 25 | flush_logs_every_n_steps: int = 100 26 | """How often to flush logs to disk. Default: ``100``.""" 27 | 28 | def create_logger(self): 29 | """Creates a CSVLogger instance from this config.""" 30 | from lightning.pytorch.loggers.csv_logs import CSVLogger 31 | 32 | return CSVLogger( 33 | save_dir=self.save_dir, 34 | name=self.name, 35 | version=self.version, 36 | prefix=self.prefix, 37 | flush_logs_every_n_steps=self.flush_logs_every_n_steps, 38 | ) 39 | 40 | 41 | class WandbLoggerConfig(C.Config): 42 | type: Literal["wandb"] = "wandb" 43 | 44 | name: str | None = None 45 | """Display name for the run. Default: ``None``.""" 46 | 47 | save_dir: str = "." 48 | """Path where data is saved. Default: ``.``.""" 49 | 50 | version: str | None = None 51 | """Sets the version, mainly used to resume a previous run. Default: ``None``.""" 52 | 53 | offline: bool = False 54 | """Run offline (data can be streamed later to wandb servers). Default: ``False``.""" 55 | 56 | dir: str | None = None 57 | """Same as save_dir. Default: ``None``.""" 58 | 59 | id: str | None = None 60 | """Same as version. Default: ``None``.""" 61 | 62 | anonymous: bool | None = None 63 | """Enables or explicitly disables anonymous logging. Default: ``None``.""" 64 | 65 | project: str | None = None 66 | """The name of the project to which this run will belong. Default: ``None``.""" 67 | 68 | log_model: Literal["all"] | bool = False 69 | """Whether/how to log model checkpoints as W&B artifacts. Default: ``False``. 70 | If 'all', checkpoints are logged during training. 71 | If True, checkpoints are logged at the end of training. 72 | If False, no checkpoints are logged.""" 73 | 74 | prefix: str = "" 75 | """A string to put at the beginning of metric keys. Default: ``''``.""" 76 | 77 | experiment: Any | None = None # Run | RunDisabled | None 78 | """WandB experiment object. Automatically set when creating a run. Default: ``None``.""" 79 | 80 | checkpoint_name: str | None = None 81 | """Name of the model checkpoint artifact being logged. Default: ``None``.""" 82 | 83 | additional_init_parameters: dict[str, Any] = {} 84 | """Additional parameters to pass to wandb.init(). 
Default: ``{}``.""" 85 | 86 | def create_logger(self): 87 | """Creates a WandbLogger instance from this config.""" 88 | from lightning.pytorch.loggers.wandb import WandbLogger 89 | 90 | # Pass all parameters except additional_init_parameters to constructor 91 | base_params = { 92 | k: v 93 | for k, v in self.model_dump().items() 94 | if k != "additional_init_parameters" and k != "type" 95 | } 96 | 97 | # Merge with additional init parameters 98 | return WandbLogger(**base_params, **self.additional_init_parameters) 99 | 100 | 101 | class TensorBoardLoggerConfig(C.Config): 102 | type: Literal["tensorboard"] = "tensorboard" 103 | 104 | save_dir: str 105 | """Save directory where TensorBoard logs will be saved.""" 106 | 107 | name: str | None = "lightning_logs" 108 | """Experiment name. Default: ``'lightning_logs'``. If empty string, no per-experiment subdirectory is used.""" 109 | 110 | version: int | str | None = None 111 | """Experiment version. If not specified, logger auto-assigns next available version. 112 | If string, used as run-specific subdirectory name. Default: ``None``.""" 113 | 114 | log_graph: bool = False 115 | """Whether to add computational graph to tensorboard. Requires model.example_input_array to be defined. 116 | Default: ``False``.""" 117 | 118 | default_hp_metric: bool = True 119 | """Enables placeholder metric with key `hp_metric` when logging hyperparameters without a metric. 120 | Default: ``True``.""" 121 | 122 | prefix: str = "" 123 | """String to put at beginning of metric keys. Default: ``''``.""" 124 | 125 | sub_dir: str | None = None 126 | """Sub-directory to group TensorBoard logs. If provided, logs are saved in 127 | ``/save_dir/name/version/sub_dir/``. Default: ``None``.""" 128 | 129 | additional_params: dict[str, Any] = {} 130 | """Additional parameters passed to tensorboardX.SummaryWriter. 
Default: ``{}``.""" 131 | 132 | def create_logger(self): 133 | """Creates a TensorBoardLogger instance from this config.""" 134 | from lightning.pytorch.loggers.tensorboard import TensorBoardLogger 135 | 136 | # Pass all parameters except additional_params to constructor 137 | base_params = { 138 | k: v 139 | for k, v in self.model_dump().items() 140 | if k != "additional_params" and k != "type" 141 | } 142 | 143 | # Merge with additional tensorboard parameters 144 | return TensorBoardLogger(**base_params, **self.additional_params) 145 | 146 | 147 | LoggerConfig = TypeAliasType( 148 | "LoggerConfig", 149 | Annotated[ 150 | CSVLoggerConfig | WandbLoggerConfig | TensorBoardLoggerConfig, 151 | C.Field(description="Logger configuration.", discriminator="type"), 152 | ], 153 | ) 154 | -------------------------------------------------------------------------------- /src/mattertune/recipes/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .base import RecipeConfig as RecipeConfig 4 | from .base import RecipeConfigBase as RecipeConfigBase 5 | from .base import recipe_registry as recipe_registry 6 | from .ema import EMARecipeConfig as EMARecipeConfig 7 | from .lora import LoRARecipeConfig as LoRARecipeConfig 8 | from .noop import NoOpRecipeConfig as NoOpRecipeConfig 9 | -------------------------------------------------------------------------------- /src/mattertune/recipes/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Annotated 5 | 6 | import nshconfig as C 7 | from lightning.pytorch.callbacks import Callback 8 | from typing_extensions import TypeAliasType 9 | 10 | 11 | class RecipeConfigBase(C.Config, ABC): 12 | """ 13 | Base configuration for recipes. 14 | """ 15 | 16 | @abstractmethod 17 | def create_lightning_callback(self) -> Callback | None: 18 | """ 19 | Creates the PyTorch Lightning callback for this recipe, or returns 20 | `None` if no callback is needed. 21 | """ 22 | ... 23 | 24 | @classmethod 25 | def ensure_dependencies(cls): 26 | """ 27 | Ensure that all dependencies are installed. 28 | 29 | This method should raise an exception if any dependencies are missing, 30 | with a message indicating which dependencies are missing and 31 | how to install them. 32 | """ 33 | return 34 | 35 | 36 | recipe_registry = C.Registry(RecipeConfigBase, discriminator="name") 37 | 38 | RecipeConfig = TypeAliasType( 39 | "RecipeConfig", 40 | Annotated[ 41 | RecipeConfigBase, 42 | recipe_registry.DynamicResolution(), 43 | ], 44 | ) 45 | -------------------------------------------------------------------------------- /src/mattertune/recipes/noop.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Literal 4 | 5 | from typing_extensions import override 6 | 7 | from .base import RecipeConfigBase, recipe_registry 8 | 9 | 10 | @recipe_registry.register 11 | class NoOpRecipeConfig(RecipeConfigBase): 12 | """ 13 | Example recipe that does nothing. 
14 | """ 15 | 16 | name: Literal["no-op"] = "no-op" 17 | """Discriminator for the no-op recipe.""" 18 | 19 | @override 20 | def create_lightning_callback(self) -> None: 21 | return None 22 | -------------------------------------------------------------------------------- /src/mattertune/registry.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import nshconfig as C 4 | 5 | from .finetune.base import FinetuneModuleBaseConfig 6 | 7 | backbone_registry = C.Registry(FinetuneModuleBaseConfig, discriminator="name") 8 | """Registry for backbone modules.""" 9 | 10 | data_registry = C.Registry(C.Config, discriminator="type") 11 | """Registry for data modules.""" 12 | __all__ = [ 13 | "backbone_registry", 14 | "data_registry", 15 | ] 16 | -------------------------------------------------------------------------------- /src/mattertune/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import contextlib 4 | 5 | 6 | @contextlib.contextmanager 7 | def optional_import_error_message(pip_package_name: str, /): 8 | try: 9 | yield 10 | except ImportError as e: 11 | raise ImportError( 12 | f"The `{pip_package_name}` package is not installed. Please install it by running " 13 | f"`pip install {pip_package_name}`." 14 | ) from e 15 | -------------------------------------------------------------------------------- /src/mattertune/wrappers/ase_calculator.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import copy 4 | from typing import TYPE_CHECKING 5 | 6 | import numpy as np 7 | import torch 8 | from ase import Atoms 9 | from ase.calculators.calculator import Calculator 10 | from typing_extensions import override 11 | 12 | if TYPE_CHECKING: 13 | from ..finetune.properties import PropertyConfig 14 | from .property_predictor import MatterTunePropertyPredictor 15 | from ..finetune.base import FinetuneModuleBase 16 | 17 | 18 | class MatterTuneCalculator(Calculator): 19 | """ 20 | A fast version of the MatterTuneCalculator that uses the `predict_step` method directly without creating a trainer. 21 | """ 22 | 23 | @override 24 | def __init__(self, model: FinetuneModuleBase, device: torch.device): 25 | super().__init__() 26 | 27 | self.model = model.to(device) 28 | 29 | self.implemented_properties: list[str] = [] 30 | self._ase_prop_to_config: dict[str, PropertyConfig] = {} 31 | 32 | for prop in self.model.hparams.properties: 33 | # Ignore properties not marked as ASE calculator properties. 34 | if (ase_prop_name := prop.ase_calculator_property_name()) is None: 35 | continue 36 | self.implemented_properties.append(ase_prop_name) 37 | self._ase_prop_to_config[ase_prop_name] = prop 38 | 39 | @override 40 | def calculate( 41 | self, 42 | atoms: Atoms | None = None, 43 | properties: list[str] | None = None, 44 | system_changes: list[str] | None = None, 45 | ): 46 | if properties is None: 47 | properties = copy.deepcopy(self.implemented_properties) 48 | 49 | # Call the parent class to set `self.atoms`. 50 | Calculator.calculate(self, atoms) 51 | 52 | # Make sure `self.atoms` is set. 53 | assert self.atoms is not None, ( 54 | "`MatterTuneCalculator.atoms` is not set. " 55 | "This should have been set by the parent class. " 56 | "Please report this as a bug." 57 | ) 58 | assert isinstance(self.atoms, Atoms), ( 59 | "`MatterTuneCalculator.atoms` is not an `ase.Atoms` object. 
" 60 | "This should have been set by the parent class. " 61 | "Please report this as a bug." 62 | ) 63 | 64 | prop_configs = [self._ase_prop_to_config[prop] for prop in properties] 65 | 66 | normalized_atoms = copy.deepcopy(self.atoms) 67 | scaled_pos = normalized_atoms.get_scaled_positions() 68 | scaled_pos = np.mod(scaled_pos, 1.0) 69 | normalized_atoms.set_scaled_positions(scaled_pos) 70 | 71 | data = self.model.atoms_to_data(normalized_atoms, has_labels=False) 72 | batch = self.model.collate_fn([data]) 73 | batch = batch.to(self.model.device) 74 | 75 | pred = self.model.predict_step( 76 | batch = batch, 77 | batch_idx = 0, 78 | ) 79 | pred = pred[0] # type: ignore 80 | 81 | for prop in prop_configs: 82 | ase_prop_name = prop.ase_calculator_property_name() 83 | assert ase_prop_name is not None, ( 84 | f"Property '{prop.name}' does not have an ASE calculator property name. " 85 | "This should have been checked when creating the MatterTuneCalculator. " 86 | "Please report this as a bug." 87 | ) 88 | 89 | value = pred[prop.name].detach().to(torch.float32).cpu().numpy() # type: ignore 90 | value = value.astype(prop._numpy_dtype()) 91 | value = prop.prepare_value_for_ase_calculator(value) 92 | 93 | self.results[ase_prop_name] = value --------------------------------------------------------------------------------