├── CODE_OF_CONDUCT.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── docker └── Dockerfile ├── env.yaml ├── evaluation.py ├── pyproject.toml ├── reactot-pretrained.ckpt ├── reactot ├── Figures │ ├── F1_prime.pdf │ ├── F2_resub.pdf │ ├── F3_pprime.pdf │ ├── F4.pdf │ ├── F5_resub_v1.pdf │ ├── figure1.jpg │ ├── figure2.jpg │ └── reaction_stat_v2.svg ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── pre_process.cpython-310.pyc │ └── run_model.cpython-310.pyc ├── analyze │ ├── BDAGIHXWWSANSR_95_p.xyz │ ├── BDAGIHXWWSANSR_95_r.xyz │ ├── __pycache__ │ │ └── rmsd.cpython-310.pyc │ ├── geomopt.py │ └── rmsd.py ├── appmain.py ├── dataset │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── base_dataset.cpython-310.pyc │ │ ├── datasets_config.cpython-310.pyc │ │ ├── qm9.cpython-310.pyc │ │ ├── sampler.cpython-310.pyc │ │ ├── transition1x.cpython-310.pyc │ │ └── zeolite.cpython-310.pyc │ ├── base_dataset.py │ ├── datasets_config.py │ ├── ff_lmdb.py │ ├── qm9.py │ ├── sampler.py │ ├── transition1x.py │ └── zeolite.py ├── diffusion │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── _normalizer.cpython-310.pyc │ │ ├── _schedule.cpython-310.pyc │ │ ├── _utils.cpython-310.pyc │ │ ├── en_diffusion.cpython-310.pyc │ │ └── en_sb.cpython-310.pyc │ ├── _node_dist.py │ ├── _normalizer.py │ ├── _schedule.py │ ├── _utils.py │ ├── en_diffusion.py │ └── en_sb.py ├── dynamics │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── _base.cpython-310.pyc │ │ ├── confidence.cpython-310.pyc │ │ ├── egnn_dynamics.cpython-310.pyc │ │ └── potential.cpython-310.pyc │ ├── _base.py │ ├── confidence.py │ ├── egnn_dynamics.py │ └── potential.py ├── evaluate │ ├── evaluate_rmsd_vs_ediff.py │ ├── evaluate_ts_w_rp.py │ ├── generate_confidence_sample.py │ ├── generate_on_example_path.py │ ├── run_confidence_sample.sh │ ├── run_eva_ts_e_rp.sh │ └── utils.py ├── model │ ├── __init__.py │ ├── __pycache__ │ │ ├── 
__init__.cpython-310.pyc │ │ ├── block.cpython-310.pyc │ │ ├── core.cpython-310.pyc │ │ ├── egnn.cpython-310.pyc │ │ ├── leftnet.cpython-310.pyc │ │ └── util_funcs.cpython-310.pyc │ ├── block.py │ ├── core.py │ ├── egnn.py │ ├── leftnet.py │ └── util_funcs.py ├── pre_process.py ├── run_model.py ├── sampling │ └── sample_datasets.py ├── trainer │ ├── __pycache__ │ │ ├── _metrics.cpython-310.pyc │ │ └── pl_trainer.cpython-310.pyc │ ├── _metrics.py │ ├── ema.py │ ├── pl_trainer.py │ ├── potential_module.py │ ├── test_integrator.ipynb │ └── train_rpsb_ts1x.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── _graph_tools.cpython-310.pyc │ ├── bond_analyze.cpython-310.pyc │ ├── sampling_tools.cpython-310.pyc │ └── training_tools.cpython-310.pyc │ ├── _graph_tools.py │ ├── bond_analyze.py │ ├── examples │ ├── H2O_dissociated.xyz │ ├── acetate.xyz │ ├── chiral_stereo_test.xyz │ ├── ethane.xyz │ ├── ethane_radical.xyz │ └── propylbenzene.xyz │ ├── sampling_tools.py │ ├── training_tools.py │ └── xyz2mol.py ├── requirements.txt └── setup.cfg /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, 8 | body size, disability, ethnicity, gender identity and expression, level of 9 | experience, nationality, personal appearance, race, religion, or sexual 10 | identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment include: 15 | 16 | * Using welcoming and inclusive language 17 | * Being respectful of differing viewpoints and experiences 18 | * Gracefully accepting constructive criticism 19 | * Focusing on what is best for the community 20 | * Showing empathy towards other community members 21 | 22 | Examples of unacceptable behavior by participants include: 23 | 24 | * The use of sexualized language or imagery and unwelcome sexual attention or advances 25 | * Trolling, insulting/derogatory comments, and personal or political attacks 26 | * Public or private harassment 27 | * Publishing others' private information, such as a physical or electronic address, without explicit permission 28 | * Other conduct which could reasonably be considered inappropriate in a professional setting 29 | 30 | ## Our Responsibilities 31 | 32 | Project maintainers are responsible for clarifying the standards of acceptable 33 | behavior and are expected to take appropriate and fair corrective action in 34 | response to any instances of unacceptable behavior. 35 | 36 | Project maintainers have the right and responsibility to remove, edit, or 37 | reject comments, commits, code, wiki edits, issues, and other contributions 38 | that are not aligned to this Code of Conduct, or to ban temporarily or 39 | permanently any contributor for other behaviors that they deem inappropriate, 40 | threatening, offensive, or harmful. 41 | 42 | Moreover, project maintainers will strive to offer feedback and advice to 43 | ensure quality and consistency of contributions to the code. Contributions 44 | from outside the group of project maintainers are strongly welcomed but the 45 | final decision as to whether commits are merged into the codebase rests with 46 | the team of project maintainers. 
47 | 48 | ## Scope 49 | 50 | This Code of Conduct applies both within project spaces and in public spaces 51 | when an individual is representing the project or its community. Examples of 52 | representing a project or community include using an official project e-mail 53 | address, posting via an official social media account, or acting as an 54 | appointed representative at an online or offline event. Representation of a 55 | project may be further defined and clarified by project maintainers. 56 | 57 | ## Enforcement 58 | 59 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 60 | reported by contacting the project team at 'duanchenru@gmail.com, nandy@mit.edu, and yuanqidu@cs.cornell.edu'. The project team will 61 | review and investigate all complaints, and will respond in a way that it deems 62 | appropriate to the circumstances. The project team is obligated to maintain 63 | confidentiality with regard to the reporter of an incident. Further details of 64 | specific enforcement policies may be posted separately. 65 | 66 | Project maintainers who do not follow or enforce the Code of Conduct in good 67 | faith may face temporary or permanent repercussions as determined by other 68 | members of the project's leadership. 
69 | 70 | ## Attribution 71 | 72 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 73 | version 1.4, available at 74 | [http://contributor-covenant.org/version/1/4][version] 75 | 76 | [homepage]: http://contributor-covenant.org 77 | [version]: http://contributor-covenant.org/version/1/4/ 78 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include CODE_OF_CONDUCT.md 2 | 3 | global-exclude *.py[cod] __pycache__ *.so 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # React-OT: Optimal Transport for Generating Transition State in Chemical Reactions 2 | 3 | In this work, we developed React-OT, an optimal transport approach to generate TSs of an elementary reaction in a fully deterministic manner. It is based on our previously developed diffusion-based generative model for generating 3D chemical reactions, [OA-ReactDiff](https://github.com/chenruduan/OAReactDiff). React-OT has been improved for generating transition state (TS) structures for given reactants and products (double-ended search problem), enabling it to generate highly accurate transition state structures while maintaining an extremely high inference speed. 4 | 5 | ![image](https://github.com/deepprinciple/react-ot/blob/main/reactot/Figures/figure1.jpg) 6 | Fig. 1 | Overview of the diffusion model and optimal transport framework for generating TS. a. Learning the joint distribution of structures in elementary reactions (reactant in red, TS in yellow, and product in blue). b. Stochastic inference with inpainting in OA-ReactDiff. c. Deterministic inference with React-OT.
7 | 8 | We trained React-OT on Transition1x, a dataset that contains paired reactants, TSs, and products calculated from climbing-image NEB obtained with DFT (ωB97x/6-31G(d)). In React-OT, the object-aware version of LEFTNet is used as the scoring network to fit the transition kernel (see [LEFTNet](https://arxiv.org/abs/2304.04757)). React-OT achieves a mean RMSD of 0.103 Å between generated and true TS structures on the set-aside test reactions of Transition1x, significantly improving upon previous state-of-the-art TS prediction methods. 9 | 10 | ![image](https://github.com/deepprinciple/react-ot/blob/main/reactot/Figures/figure2.jpg) 11 | Fig. 2 | Structural and energetic performance of diffusion and optimal transport generated TS structures. a. Cumulative probability for structure root mean square deviation (RMSD) (left) and absolute energy error (|∆E TS|) (right) between the true and generated TS on 1,073 set-aside test reactions. b. Reference TS structure, OA-ReactDiff TS sample (red), and React-OT structure (orange) for select reactions. c. Histogram (gray, left y axis) and cumulative probability (blue, right y axis) showing the difference of RMSD (left) and |∆E TS| (right) between OA-ReactDiff recommended and React-OT structures compared to reference TS. d. Inference time in seconds for single-shot OA-ReactDiff, 40-shot OA-ReactDiff with recommender, and React-OT. 12 | 13 | We envision that the remarkable accuracy and rapid inference of React-OT will be highly useful when integrated with current high-throughput TS search workflows. This integration will facilitate the exploration of chemical reactions with unknown mechanisms. 14 | 15 | ## Environment set-up 16 | ``` 17 | conda env create -f env.yaml 18 | conda activate reactot && pip install -e . 19 | ``` 20 | 21 | ## Download data 22 | The processed data is available on Zenodo, [download link](https://zenodo.org/records/13131875). Put both pickle files under the `reactot/data/transition1x` directory:
23 | 24 | ``` 25 | mkdir reactot/data 26 | mkdir reactot/data/transition1x 27 | mv PATH_TO_PKL_FILES reactot/data/transition1x/ 28 | ``` 29 | 30 | ## Evaluation using a pre-trained model 31 | The pre-trained model can be downloaded through the [download link](https://zenodo.org/records/13131875). 32 | ``` 33 | python evaluation.py --checkpoint PATH_TO_CHECKPOINT --solver ode --nfe 10 34 | ``` 35 | 36 | ## Training 37 | ``` 38 | python -m reactot.trainer.train_rpsb_ts1x 39 | ``` 40 | Note that the default parameters and model types are used in the current command. More detailed instructions on model training will be updated soon. 41 | 42 | ## Data used in this work 43 | 1. [Transition1x](https://gitlab.com/matschreiner/Transition1x) 44 | 2. [RGD1](https://figshare.com/articles/dataset/model_reaction_database/21066901) 45 | 3. [Berkholz-15](https://onlinelibrary.wiley.com/doi/abs/10.1002/jcc.23910) 46 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM deepprinciple-cn-beijing.cr.volces.com/public-registry/python:3.10.16 2 | # Proxies are build-time ARGs (visible to RUN steps) so they are not baked into the final image's environment; override with --build-arg. 3 | ARG http_proxy=http://100.68.166.13:3128 4 | ARG https_proxy=http://100.68.166.13:3128 5 | ARG HTTP_PROXY=http://100.68.166.13:3128 6 | ARG HTTPS_PROXY=http://100.68.166.13:3128 7 | 8 | ARG no_proxy=mirrors.tuna.tsinghua.edu.cn 9 | 10 | COPY . /app 11 | 12 | WORKDIR /app 13 | 14 | RUN pip install --no-cache-dir -r requirements.txt -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple 15 | RUN pip install --no-cache-dir torch -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple 16 | RUN pip install --no-cache-dir torch_geometric -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple 17 | RUN pip install --no-cache-dir pytorch-lightning -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple 18 | RUN pip install --no-cache-dir torchdiffeq -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple 19 | RUN apt-get update && apt-get -y install build-essential python3-dev python3-pip libomp-dev && rm -rf /var/lib/apt/lists/* 20 | RUN pip install --no-cache-dir torch_scatter -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple 21 | 22 | # No "RUN unset ..." needed: each RUN is a separate shell, and the ARG proxies above do not persist into the image. 23 | 24 | # RUN TMPDIR=/home/ubuntu/tmp/ pip install torch-sparse -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple 25 | 26 | # RUN TMPDIR=/home/ubuntu/tmp/ pip install torch-cluster -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: reactot 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | # Base depends 8 | - python=3.10.14 9 | - pip 10 | # - rdkit 11 | # - biopython 12 | # - imageio 13 | # - openbabel 14 | # Testing 15 | - pytest 16 | # - codecov 17 | # - pytest-cov 18 | # Pip-only installs 19 | - pip: 20 | - numpy==1.26.4 21 | - pytorch-lightning==2.4.0 22 | - torch==2.2.1 23 | - torch_geometric 24 | - torch-scatter 25 | - torch-sparse 26 | - torch-cluster 27 | - pymatgen 28 | - ase 29 | - wandb 30 | - networkx 31 | # - ipykernel 32 | # - timm 33 | # - e3nn 34 | # - plotly 35 | # - nbformat 36 | # - pyscf 37 | - torchdiffeq 38 | - colored-traceback 39 | - ipdb 40 | - lmdb 41 | - rich 42 | - timm 43 |
-------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | import sys 4 | 5 | import argparse 6 | import numpy as np 7 | from reactot.trainer.pl_trainer import SBModule 8 | 9 | from rich.console import Console 10 | from rich.logging import RichHandler 11 | 12 | device = "cuda" 13 | 14 | def setup_logger(log_dir: pathlib.Path) -> None: 15 | log_dir.mkdir(exist_ok=True, parents=True) 16 | 17 | log_file = open(log_dir / "log.txt", "w") 18 | file_console = Console(file=log_file, width=150) 19 | logging.basicConfig( 20 | level=logging.INFO, 21 | format="%(message)s", 22 | datefmt="[%X]", 23 | force=True, 24 | handlers=[RichHandler(), RichHandler(console=file_console)], 25 | ) 26 | 27 | def load_model( 28 | checkpoint_path="/mnt/beegfs/bulk/mirror/yuanqi/reactOT/reactot/checkpoint/RPSB-FT-Schedule/leftnet-ts_guess_NEBCI-xtb-ema-785cb2e522c3/sb-epoch=009-val_ep_scaled_err=0.0645.ckpt", 29 | ): 30 | print (checkpoint_path) 31 | model = SBModule.load_from_checkpoint( 32 | checkpoint_path=checkpoint_path, 33 | map_location=device, 34 | ) 35 | model = model.eval() 36 | model = model.to(device) 37 | 38 | model.training_config["use_sampler"] = False 39 | model.training_config["swapping_react_prod"] = False 40 | model.training_config["datadir"] = "./reactot/data/transition1x" 41 | 42 | model.setup(stage="fit", device=device, swapping_react_prod=False) 43 | return model 44 | 45 | def main(opt): 46 | setup_logger(pathlib.Path(".log")) 47 | log = logging.getLogger(__name__) 48 | 49 | log.info("===== Start =====") 50 | log.info("Command used:\n{}".format(" ".join(sys.argv))) 51 | 52 | model = load_model(opt.checkpoint) 53 | 54 | val_loader = model.val_dataloader(bz=opt.batch_size, shuffle=False) 55 | model.nfe = opt.nfe 56 | model.ddpm.opt = opt # hack :) 57 | 58 | if opt.dryrun: 59 | # dryrun: just first batch 60 | batch = 
next(iter(val_loader)) 61 | r_pos, ts_pos, p_pos, x0_size, x0_other, rmsds = model.eval_sample_batch( 62 | batch, 63 | return_all=True, 64 | ) # 30s for nfe=100 65 | else: 66 | # full val set 67 | res, rmsds = model.eval_rmsd( 68 | val_loader, 69 | write_xyz=False, 70 | bz=opt.batch_size, 71 | refpath="ref_ts", 72 | localpath=f"{opt.solver}-{opt.method}/nfe{opt.nfe}/", 73 | # max_num_batch=10, 74 | ) 75 | # np.savez(f"data/{opt.save}.npz", res=res, rmsds=rmsds) 76 | 77 | log.info(f"mean={np.mean(rmsds):.5f}, median={np.median(rmsds):.5f}, {len(rmsds)=}") 78 | log.info("===== End =====") 79 | 80 | if __name__ == '__main__': 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument("--batch-size", type=int, default=72) 83 | parser.add_argument("--nfe", type=int, default=100) 84 | parser.add_argument("--save", type=str, default="debug") 85 | parser.add_argument("--dryrun", action="store_true") 86 | 87 | parser.add_argument("--solver", type=str, choices=["ddpm", "ei", "ode"]) 88 | parser.add_argument("--checkpoint", type=str) 89 | 90 | # ei 91 | parser.add_argument("--order", type=int, default=1) 92 | parser.add_argument("--diz", type=str, default="linear", choices=["linear", "quad"]) 93 | parser.add_argument("--normalize", action="store_true") 94 | 95 | # ode 96 | parser.add_argument("--method", type=str, default="midpoint") 97 | parser.add_argument("--atol", type=float, default=1e-2) 98 | parser.add_argument("--rtol", type=float, default=1e-2) 99 | 100 | opt = parser.parse_args() 101 | main(opt) 102 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "versioningit~=2.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | # Self-descriptive entries which should always be present 6 | # https://packaging.python.org/en/latest/specifications/declaring-project-metadata/ 7 | [project] 
8 | name = "oa_reactdiff" 9 | description = "A object-wise SE(3) diffusion model for elementary reaction generation." 10 | dynamic = ["version"] 11 | readme = "README.md" 12 | authors = [ 13 | { name = "Chenru Duan, and Yuanqi Du", email = "duanchenru@gmail.com" } 14 | ] 15 | license = { text = "MIT" } 16 | # See https://pypi.org/classifiers/ 17 | classifiers = [ 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python :: 3", 20 | ] 21 | requires-python = ">=3.10" 22 | # Declare any run-time dependencies that should be installed with the package. 23 | #dependencies = [ 24 | # "importlib-resources;python_version<'3.10'", 25 | #] 26 | 27 | # Update the urls once the hosting is set up. 28 | #[project.urls] 29 | #"Source" = "https://github.com//OA_ReactDiff/" 30 | #"Documentation" = "https://OA_ReactDiff.readthedocs.io/" 31 | 32 | [project.optional-dependencies] 33 | test = [ 34 | "pytest>=6.1.2", 35 | "pytest-runner" 36 | ] 37 | 38 | [tool.setuptools] 39 | # This subkey is a beta stage development and keys may change in the future, see https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html for more details 40 | # 41 | # As of version 0.971, mypy does not support type checking of installed zipped 42 | # packages (because it does not actually import the Python packages). 43 | # We declare the package not-zip-safe so that our type hints are also available 44 | # when checking client code that uses our (installed) package. 45 | # Ref: 46 | # https://mypy.readthedocs.io/en/stable/installed_packages.html?highlight=zip#using-installed-packages-with-mypy-pep-561 47 | zip-safe = false 48 | # Let setuptools discover the package in the current directory, 49 | # but be explicit about non-Python files. 
50 | # See also: 51 | # https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html#setuptools-specific-configuration 52 | # Note that behavior is currently evolving with respect to how to interpret the 53 | # "data" and "tests" subdirectories. As of setuptools 63, both are automatically 54 | # included if namespaces is true (default), even if the package is named explicitly 55 | # (instead of using 'find'). With 'find', the 'tests' subpackage is discovered 56 | # recursively because of its __init__.py file, but the data subdirectory is excluded 57 | # with include-package-data = false and namespaces = false. 58 | include-package-data = false 59 | [tool.setuptools.packages.find] 60 | namespaces = false 61 | where = ["."] 62 | 63 | # Ref https://setuptools.pypa.io/en/latest/userguide/datafiles.html#package-data 64 | [tool.setuptools.package-data] 65 | oa_reactdiff = [ 66 | "py.typed" 67 | ] 68 | 69 | [tool.versioningit] 70 | default-version = "0.0.1" 71 | 72 | # [tool.versioningit.format] 73 | # distance = "{base_version}+{distance}.{vcs}{rev}" 74 | # dirty = "{base_version}+{distance}.{vcs}{rev}.dirty" 75 | # distance-dirty = "{base_version}+{distance}.{vcs}{rev}.dirty" 76 | 77 | [tool.versioningit.vcs] 78 | # The method key: 79 | method = "git" # <- The method name 80 | # Parameters to pass to the method: 81 | match = ["*"] 82 | default-tag = "0.0.1" 83 | 84 | [tool.versioningit.write] 85 | file = "oa_reactdiff/_version.py" 86 | -------------------------------------------------------------------------------- /reactot-pretrained.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot-pretrained.ckpt -------------------------------------------------------------------------------- /reactot/Figures/F1_prime.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/Figures/F1_prime.pdf -------------------------------------------------------------------------------- /reactot/Figures/F2_resub.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/Figures/F2_resub.pdf -------------------------------------------------------------------------------- /reactot/Figures/F3_pprime.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/Figures/F3_pprime.pdf -------------------------------------------------------------------------------- /reactot/Figures/F4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/Figures/F4.pdf -------------------------------------------------------------------------------- /reactot/Figures/F5_resub_v1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/Figures/F5_resub_v1.pdf -------------------------------------------------------------------------------- /reactot/Figures/figure1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/Figures/figure1.jpg -------------------------------------------------------------------------------- /reactot/Figures/figure2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/Figures/figure2.jpg -------------------------------------------------------------------------------- /reactot/__init__.py: -------------------------------------------------------------------------------- 1 | """An object-wise SE(3) optimal transport model for elementary reaction generation.""" 2 | -------------------------------------------------------------------------------- /reactot/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/__pycache__/pre_process.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/__pycache__/pre_process.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/__pycache__/run_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/__pycache__/run_model.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/analyze/BDAGIHXWWSANSR_95_p.xyz: -------------------------------------------------------------------------------- 1 | 5 2 | Properties=species:S:1:pos:R:3 Coordinates=T from=T ORCA-job=T BDAGIHXWWSANSR_95_p-OPT=T pbc="F F F" 3 | C -0.45048513 -1.08365292 0.48842500 4 | O 0.37423922 -1.59632488 1.13234369 5 | O -1.27611845 -0.57030607 -0.15573152 6 | H 1.01203679 1.76379061 -0.80985883 7 | H 0.34032756 1.48659325 -0.65517833 8 |
-------------------------------------------------------------------------------- /reactot/analyze/BDAGIHXWWSANSR_95_r.xyz: -------------------------------------------------------------------------------- 1 | 5 2 | Properties=species:S:1:pos:R:3 Coordinates=T from=T ORCA-job=T BDAGIHXWWSANSR_95_r-OPT=T pbc="F F F" 3 | C -0.41197499 -0.08701124 -0.07312436 4 | O 0.81428584 -0.33945489 0.42022200 5 | O -1.33026086 -0.83923915 0.06512702 6 | H 1.41060650 0.38977084 0.20101890 7 | H -0.48265650 0.87603445 -0.61324356 8 | -------------------------------------------------------------------------------- /reactot/analyze/__pycache__/rmsd.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/analyze/__pycache__/rmsd.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/analyze/geomopt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pyscf import gto, dft 4 | from pyscf.geomopt.geometric_solver import optimize 5 | from pyscf.hessian import thermo 6 | # from pyscf import dftd3 7 | 8 | from molSimplify.Classes.mol3D import mol3D 9 | from pymatgen.core import Molecule 10 | from pymatgen.analysis.molecule_matcher import BruteForceOrderMatcher 11 | from pymatgen.io.xyz import XYZ 12 | 13 | from .rmsd import pymatgen_rmsd 14 | 15 | AU2KCALMOL = 627.509608 16 | AU2EV = 27.2114 17 | BOHR = 0.52917721092 18 | 19 | 20 | def count_negative_eig(x: list): 21 | count = 0 22 | for _x in x: 23 | if _x.imag > 0: 24 | count += 1 25 | return count 26 | 27 | 28 | def compute_efh(geomfile, f=True, hess=False, return_metrics=False, xc="wb97x", basis="631g*", d3=False): 29 | spin = 0 30 | mol = gto.M( 31 | atom=geomfile, 32 | unit="Ang", 33 | basis=basis, 34 | ) 35 | mol.build() 36 | 37 | mf = dft.RKS(mol) if not spin else 
dft.UKS(mol) 38 | mf.xc = xc 39 | # if d3: 40 | # mf = dftd3.dftd3(dft.RKS(mol, xc=xc)) 41 | mf.conv_tol = 1e-6 42 | # mf.damp = 0.2 43 | mf.max_cycle = 200 44 | mf.max_memory = 32000 45 | mf.run() 46 | 47 | force = None 48 | force_rms = np.nan 49 | if mf.converged and f: 50 | force = mf.nuc_grad_method().kernel() * -1. / BOHR 51 | force_rms = np.sqrt(np.mean(force ** 2)) * AU2EV 52 | print("force rms (ev/A): ", force_rms) 53 | 54 | hessian = None 55 | if hess: 56 | hessian = mf.Hessian().kernel() 57 | freq_info = thermo.harmonic_analysis(mf.mol, hessian) 58 | print("freq: ", freq_info['freq_wavenumber']) 59 | 60 | if return_metrics: 61 | return mf, force, hessian, force_rms, count_negative_eig(freq_info['freq_wavenumber']) 62 | return mf, force, hessian 63 | 64 | 65 | def compute_rmsd_with_optgeom(mf, transition=False, xyzfile=None): 66 | e_generated = mf.e_tot 67 | mol_eq = optimize(mf, maxsteps=300, transition=transition) 68 | 69 | if xyzfile is None: 70 | xyzfile = ".opt.xyz" if not transition else ".ts.xyz" 71 | mf.mol.tofile(".tmp.xyz", format="xyz") 72 | mol_eq.tofile(xyzfile, format="xyz") 73 | 74 | mf_eq, _, _ = compute_efh(xyzfile, f=False, hess=False) 75 | e_eq = mf_eq.e_tot 76 | e_diff = (e_eq - e_generated) * AU2KCALMOL 77 | 78 | rmsd = pymatgen_rmsd( 79 | mol1=".tmp.xyz", 80 | mol2=xyzfile, 81 | ignore_chirality=True, 82 | threshold=0.5, 83 | ) 84 | 85 | return rmsd, e_diff 86 | 87 | 88 | def compute_irc(mf, hessian, ts_xyz=".ts.xyz", dq=0.1): 89 | ms_mol_eq = mol3D() 90 | ms_mol_eq.readfromxyz(ts_xyz) 91 | freq_info = thermo.harmonic_analysis(mf.mol, hessian) 92 | 93 | # Left 94 | new_coords = ms_mol_eq.coordsvect() - freq_info["norm_mode"][0] * dq 95 | for ii, atom in enumerate(ms_mol_eq.atoms): 96 | atom.setcoords(new_coords[ii]) 97 | ms_mol_eq.writexyz("ts-.xyz") 98 | _mf, _, _ = compute_efh( 99 | "ts-.xyz", hess=False, return_metrics=False) 100 | _, _ = compute_rmsd_with_optgeom(_mf, transition=False, xyzfile="opt-.xyz") 101 | 102 | # Right 103 | 
new_coords = ms_mol_eq.coordsvect() + 2 * freq_info["norm_mode"][0] * dq 104 | for ii, atom in enumerate(ms_mol_eq.atoms): 105 | atom.setcoords(new_coords[ii]) 106 | ms_mol_eq.writexyz("ts+.xyz") 107 | _mf, _, _ = compute_efh( 108 | "ts+.xyz", hess=False, return_metrics=False) 109 | _, _ = compute_rmsd_with_optgeom(_mf, transition=False, xyzfile="opt+.xyz") 110 | 111 | 112 | def compute_barrier(opt1_xyz, ts_xyz, opt2_xyz): 113 | mf_1, _, _ = compute_efh(opt1_xyz, f=True, hess=False, return_metrics=False) 114 | mf_2, _, _ = compute_efh(opt2_xyz, f=True, hess=False, return_metrics=False) 115 | mf_ts, _, _ = compute_efh(ts_xyz, f=True, hess=False, return_metrics=False) 116 | barrier_left = (mf_ts.e_tot - mf_1.e_tot) * AU2EV 117 | barrier_right = (mf_ts.e_tot - mf_2.e_tot) * AU2EV 118 | return barrier_left, barrier_right 119 | 120 | 121 | def calc_deltaE(xyzfile1, xyzfile2, f=False, xc="wb97x"): 122 | mf_1, _, _ = compute_efh(xyzfile1, f=f, hess=False, return_metrics=False, xc=xc) 123 | mf_2, _, _ = compute_efh(xyzfile2, f=f, hess=False, return_metrics=False, xc=xc) 124 | return (mf_2.e_tot - mf_1.e_tot) * AU2EV 125 | -------------------------------------------------------------------------------- /reactot/analyze/rmsd.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import numpy as np 3 | 4 | from pymatgen.core import Molecule 5 | from pymatgen.analysis.molecule_matcher import BruteForceOrderMatcher, GeneticOrderMatcher, HungarianOrderMatcher, KabschMatcher 6 | from pymatgen.io.xyz import XYZ 7 | 8 | from torch import Tensor 9 | 10 | 11 | def xh2pmg(xh): 12 | mol = Molecule( 13 | species=xh[:, -1].long().cpu().numpy(), 14 | coords=xh[:, :3].cpu().numpy(), 15 | ) 16 | return mol 17 | 18 | 19 | def xyz2pmg(xyzfile): 20 | xyz_converter = XYZ(mol=None) 21 | mol = xyz_converter.from_file(xyzfile).molecule 22 | return mol 23 | 24 | 25 | def rmsd_core(mol1, mol2, threshold=0.5, same_order=False): 26 | _, 
count = np.unique(mol1.atomic_numbers, return_counts=True) 27 | if same_order: 28 | bfm = KabschMatcher(mol1) 29 | _, rmsd = bfm.fit(mol2) 30 | return rmsd 31 | total_permutations = 1 32 | for c in count: 33 | total_permutations *= np.math.factorial(c) # type: ignore 34 | if total_permutations < 1e4: 35 | bfm = BruteForceOrderMatcher(mol1) 36 | _, rmsd = bfm.fit(mol2) 37 | else: 38 | bfm = GeneticOrderMatcher(mol1, threshold=threshold) 39 | pairs = bfm.fit(mol2) 40 | rmsd = threshold 41 | for pair in pairs: 42 | rmsd = min(rmsd, pair[-1]) 43 | if not len(pairs): 44 | bfm = HungarianOrderMatcher(mol1) 45 | _, rmsd = bfm.fit(mol2) 46 | return rmsd 47 | 48 | 49 | def pymatgen_rmsd( 50 | mol1, 51 | mol2, 52 | ignore_chirality: bool = False, 53 | threshold: float = 0.5, 54 | same_order: bool = True, 55 | ): 56 | if isinstance(mol1, str): 57 | mol1 = xyz2pmg(mol1) 58 | if isinstance(mol2, str): 59 | mol2 = xyz2pmg(mol2) 60 | rmsd = rmsd_core(mol1, mol2, threshold, same_order=same_order) 61 | if ignore_chirality: 62 | coords = mol2.cart_coords 63 | coords[:, -1] = -coords[:, -1] 64 | mol2_reflect = Molecule( 65 | species=mol2.species, 66 | coords=coords, 67 | ) 68 | rmsd_reflect = rmsd_core( 69 | mol1, mol2_reflect, threshold, same_order=same_order) 70 | rmsd = min(rmsd, rmsd_reflect) 71 | return rmsd 72 | 73 | 74 | def batch_rmsd( 75 | fragments_nodes: List[Tensor], 76 | out_samples: List[Tensor], 77 | xh: List[Tensor], 78 | idx: int = 1, 79 | threshold: float = 0.5, 80 | same_order: bool = False, 81 | ) -> List[float]: 82 | rmsds = [] 83 | out_samples_use = out_samples[idx] 84 | xh_use = xh[idx] 85 | nodes = fragments_nodes[idx].long().cpu().numpy() 86 | start_ind, end_ind = 0, 0 87 | for jj, natoms in enumerate(nodes): 88 | end_ind += natoms 89 | mol1 = xh2pmg(out_samples_use[start_ind:end_ind]) 90 | mol2 = xh2pmg(xh_use[start_ind:end_ind]) 91 | try: 92 | rmsd = pymatgen_rmsd( 93 | mol1, 94 | mol2, 95 | ignore_chirality=True, 96 | threshold=threshold, 97 | 
same_order=same_order, 98 | ) 99 | except: 100 | rmsd = 1 101 | rmsds.append(min(rmsd, 1.0)) 102 | start_ind = end_ind 103 | return rmsds 104 | 105 | def batch_rmsd_sb( 106 | fragments_node: Tensor, 107 | pred_xh: Tensor, 108 | target_xh: Tensor, 109 | threshold: float = 0.5, 110 | same_order: bool = True, 111 | ) -> List[float]: 112 | 113 | rmsds = [] 114 | 115 | end_ind = np.cumsum(fragments_node.long().cpu().numpy()) 116 | start_ind = np.concatenate([np.int64(np.zeros(1)), end_ind[:-1]]) 117 | 118 | for start, end in zip(start_ind, end_ind): 119 | mol1 = xh2pmg(pred_xh[start : end]) 120 | mol2 = xh2pmg(target_xh[start : end]) 121 | rmsd = pymatgen_rmsd( 122 | mol1, 123 | mol2, 124 | ignore_chirality=True, 125 | threshold=threshold, 126 | same_order=same_order, 127 | ) 128 | rmsds.append(min(rmsd, 1.0)) 129 | return rmsds 130 | -------------------------------------------------------------------------------- /reactot/appmain.py: -------------------------------------------------------------------------------- 1 | from yarp.parsers import xyz_parse 2 | from yarp.utils import opt_geo 3 | from yarp.taffi_functions import table_generator 4 | from yarp.find_lewis import find_lewis 5 | from yarp.parsers import xyz_write 6 | from run_model import pred_ts 7 | import argparse 8 | 9 | 10 | def modify_radj_mat(Radj_mat, bond_break, bond_form): 11 | # Convert bond_break and bond_form from string to list of tuples 12 | bond_break = eval(bond_break) 13 | bond_form = eval(bond_form) 14 | 15 | for (i, j) in bond_break: 16 | Radj_mat[i-1][j-1] = 0 17 | Radj_mat[j-1][i-1] = 0 18 | 19 | 20 | for (i, j) in bond_form: 21 | Radj_mat[i-1][j-1] = 1 22 | Radj_mat[j-1][i-1] = 1 23 | 24 | return Radj_mat 25 | 26 | def main(opt): 27 | if opt.bond_break and opt.bond_form: 28 | [[E,G]] = xyz_parse(opt.rxyz, multiple=True) 29 | R_adj_mat = table_generator(E,G) 30 | P_adj_mat = modify_radj_mat(R_adj_mat, opt.bond_break, opt.bond_form) 31 | P_bond_mats, scores = find_lewis(E, P_adj_mat) 32 | 
P_bond_mat = P_bond_mats[0] 33 | P_G1 = opt_geo(E, G, P_bond_mat) 34 | xyz_write("product.xyz",E,P_G1) 35 | pred_ts(opt.rxyz, "product.xyz", opt, opt.output_path) 36 | elif opt.pxyz: 37 | pred_ts(opt.rxyz, opt.pxyz, opt, opt.output_path) 38 | else: 39 | raise ValueError("Either --bond_break and --bond_form or --pxyz must be provided") 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--rxyz', type=str, help='Specify the input file path') 44 | parser.add_argument('--bond_break', type=str, help='Specify the bonds to break as a list of tuples') 45 | parser.add_argument('--bond_form', type=str, help='Specify the bonds to form as a list of tuples') 46 | parser.add_argument("--pxyz", type=str, default='', help='Specify the product file path') 47 | parser.add_argument("--output_path", type=str) 48 | 49 | parser.add_argument("--batch-size", type=int, default=72) 50 | parser.add_argument("--nfe", type=int, default=100) 51 | 52 | parser.add_argument("--solver", type=str, default='ddpm', choices=["ddpm", "ei", "ode"]) 53 | parser.add_argument("--checkpoint_path", type=str, default='/root/react-ot/reactot-pretrained.ckpt') 54 | 55 | # ei 56 | parser.add_argument("--order", type=int, default=1) 57 | parser.add_argument("--diz", type=str, default="linear", choices=["linear", "quad"]) 58 | 59 | # ode 60 | parser.add_argument("--method", type=str, default="midpoint") 61 | parser.add_argument("--atol", type=float, default=1e-2) 62 | parser.add_argument("--rtol", type=float, default=1e-2) 63 | 64 | opt = parser.parse_args() 65 | 66 | # Ensure --pxyz cannot be used with --bond_break and --bond_form 67 | if opt.pxyz and (opt.bond_break or opt.bond_form): 68 | parser.error("--pxyz cannot be used with --bond_break or --bond_form") 69 | 70 | main(opt) 71 | 72 | -------------------------------------------------------------------------------- /reactot/dataset/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .qm9 import ProcessedQM9, ProcessedDoubleQM9, ProcessedTripleQM9 2 | from .base_dataset import BaseDataset 3 | from .transition1x import ProcessedTS1x 4 | from .zeolite import ProcessedZeolite 5 | from .sampler import DynamicBatchSampler 6 | -------------------------------------------------------------------------------- /reactot/dataset/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/dataset/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/dataset/__pycache__/base_dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/dataset/__pycache__/base_dataset.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/dataset/__pycache__/datasets_config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/dataset/__pycache__/datasets_config.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/dataset/__pycache__/qm9.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/dataset/__pycache__/qm9.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/dataset/__pycache__/sampler.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/dataset/__pycache__/sampler.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/dataset/__pycache__/transition1x.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/dataset/__pycache__/transition1x.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/dataset/__pycache__/zeolite.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/dataset/__pycache__/zeolite.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | import torch.nn.functional as F 7 | 8 | from reactot.dataset.datasets_config import ATOM_MAPPING, SAM_CHARGED_ATOM_MAPPING 9 | 10 | 11 | class BaseDataset(Dataset): 12 | def __init__( 13 | self, 14 | npz_path, 15 | center=True, 16 | zero_charge=False, 17 | device="cpu", 18 | remove_h=False, 19 | n_fragment=3, 20 | atom_mapping=ATOM_MAPPING, 21 | ) -> None: 22 | super().__init__() 23 | 24 | if ".npz" in str(npz_path): 25 | with np.load(npz_path, allow_pickle=True) as f: 26 | data = {key: val for key, val in f.items()} 27 | elif ".pkl" in str(npz_path): 28 | data = pickle.load(open(npz_path, "rb")) 29 | else: 30 | raise ValueError("data file should be either .npz or .pkl") 31 | 32 | self.raw_dataset = data 33 | 
self.n_samples = -1 34 | self.data = {} 35 | self.n_fragment = n_fragment 36 | 37 | self.remove_h = remove_h 38 | self.zero_charge = zero_charge 39 | self.center = center 40 | self.device = device 41 | 42 | self.atom_mapping = atom_mapping 43 | self.n_element = len(list(atom_mapping.keys())) 44 | 45 | def __len__(self): 46 | return len(self.data["size_0"]) 47 | 48 | def __getitem__(self, idx): 49 | return {key: val[idx] for key, val in self.data.items()} 50 | 51 | @staticmethod 52 | def collate_fn(batch): 53 | sizes = [] 54 | for k in batch[0].keys(): 55 | if "size" in k: 56 | sizes.append(int(k.split("_")[-1])) 57 | n_fragment = len(sizes) 58 | out = [{} for _ in range(n_fragment)] 59 | res = {} 60 | for prop in batch[0].keys(): 61 | # print(prop) 62 | if prop not in ["condition", "target", "rmsd", "ediff", "ts_guess"]: 63 | idx = int(prop.split("_")[-1]) 64 | _prop = prop.replace(f"_{idx}", "") 65 | if "size" in prop: 66 | out[idx][_prop] = torch.tensor( 67 | [x[prop] for x in batch], 68 | device=batch[0][prop].device, 69 | ) 70 | elif "mask" in prop: 71 | # make sure indices in batch start at zero (needed for 72 | # torch_scatter) 73 | out[idx][_prop] = torch.cat( 74 | [ 75 | i * torch.ones(len(x[prop]), device=x[prop].device).long() 76 | for i, x in enumerate(batch) 77 | ], 78 | dim=0, 79 | ) 80 | elif prop in ["condition", "target", "rmsd", "ediff", "ts_guess"]: 81 | res[prop] = torch.cat([x[prop] for x in batch], dim=0) 82 | else: 83 | out[idx][_prop] = torch.cat([x[prop] for x in batch], dim=0) 84 | 85 | if len(list(res.keys())) == 1: 86 | return out, res["condition"] 87 | return out, res 88 | 89 | def patch_dummy_molecules(self, idx): 90 | self.data[f"size_{idx}"] = torch.ones_like( 91 | self.data[f"size_0"], device=self.device, 92 | ) 93 | self.data[f"pos_{idx}"] = [ 94 | torch.tensor([[0, 0, 0]], device=self.device,) 95 | for _ in range(self.n_samples) 96 | ] 97 | 98 | self.data[f"one_hot_{idx}"] = [ 99 | torch.tensor([0], device=self.device,) 100 | for _ 
in range(self.n_samples) 101 | ] 102 | self.data[f"one_hot_{idx}"] = [ 103 | F.one_hot(_z, num_classes=self.n_element) for _z in self.data[f"one_hot_{idx}"] 104 | ] 105 | 106 | if self.zero_charge: 107 | self.data[f"charge_{idx}"] = [ 108 | torch.zeros(size=(1, 1), dtype=torch.int64, device=self.device,) 109 | for _ in range(self.n_samples) 110 | ] 111 | else: 112 | self.data[f"charge_{idx}"] = [ 113 | torch.ones(size=(1, 1), dtype=torch.int64, device=self.device,) 114 | for _ in range(self.n_samples) 115 | ] 116 | 117 | self.data[f"mask_{idx}"] = [ 118 | torch.zeros(size=(1,), dtype=torch.int64, device=self.device,) 119 | for _ in range(self.n_samples) 120 | ] 121 | 122 | def process_molecules(self, dataset_name, n_samples, idx, append_charge=None, 123 | position_key="positions"): 124 | data = getattr(self, dataset_name) 125 | self.data[f"size_{idx}"] = torch.tensor(data["num_atoms"], device=self.device) 126 | self.data[f"pos_{idx}"] = [ 127 | torch.tensor( 128 | data[position_key][ii][: data["num_atoms"][ii]], 129 | device=self.device, 130 | dtype=torch.float32, 131 | ) 132 | for ii in range(n_samples) 133 | ] 134 | 135 | self.data[f"one_hot_{idx}"] = [ 136 | torch.tensor( 137 | [ 138 | self.atom_mapping[_at] 139 | for _at in data["charges"][ii][: data["num_atoms"][ii]] 140 | ], 141 | device=self.device, 142 | ) 143 | for ii in range(n_samples) 144 | ] 145 | self.data[f"one_hot_{idx}"] = [ 146 | F.one_hot(_z, num_classes=self.n_element) 147 | for _z in self.data[f"one_hot_{idx}"] 148 | ] 149 | 150 | if self.zero_charge: 151 | self.data[f"charge_{idx}"] = [ 152 | torch.zeros(size=(_size, 1), dtype=torch.int64, device=self.device,) 153 | for _size in data["num_atoms"] 154 | ] 155 | else: 156 | if append_charge is None: 157 | self.data[f"charge_{idx}"] = [ 158 | torch.tensor( 159 | data["charges"][ii][: data["num_atoms"][ii]], 160 | device=self.device, 161 | ).view(-1, 1) 162 | for ii in range(n_samples) 163 | ] 164 | else: 165 | self.data[f"charge_{idx}"] = [ 166 | 
torch.cat( 167 | [ 168 | torch.tensor( 169 | data["charges"][ii][: data["num_atoms"][ii]], 170 | device=self.device, 171 | ).view(-1, 1), 172 | torch.tensor( 173 | [append_charge for _ in range(data["num_atoms"][ii])], 174 | device=self.device, 175 | ).view(-1, 1), 176 | ], 177 | dim=1, 178 | ) 179 | for ii in range(n_samples) 180 | ] 181 | 182 | self.data[f"mask_{idx}"] = [ 183 | torch.zeros(size=(_size,), dtype=torch.int64, device=self.device,) 184 | for _size in data["num_atoms"] 185 | ] 186 | 187 | if self.center: 188 | self.data[f"pos_{idx}"] = [ 189 | pos - torch.mean(pos, dim=0) for pos in self.data[f"pos_{idx}"] 190 | ] 191 | -------------------------------------------------------------------------------- /reactot/dataset/ff_lmdb.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import bisect 9 | import pickle 10 | from pathlib import Path 11 | 12 | import lmdb 13 | import numpy as np 14 | from torch.utils.data import Dataset 15 | from torch_geometric.data import Batch 16 | 17 | 18 | class LmdbDataset(Dataset): 19 | r"""Dataset class to load from LMDB files containing relaxation 20 | trajectories or single point computations. 21 | 22 | Useful for Structure to Energy & Force (S2EF), Initial State to 23 | Relaxed State (IS2RS), and Initial State to Relaxed Energy (IS2RE) tasks. 24 | 25 | Args: 26 | config (dict): Dataset configuration 27 | transform (callable, optional): Data transform function. 
28 | (default: :obj:`None`) 29 | """ 30 | 31 | def __init__(self, src, transform=None, **kwargs): 32 | super(LmdbDataset, self).__init__() 33 | 34 | self.path = Path(src) 35 | if not self.path.is_file(): 36 | db_paths = sorted(self.path.glob("*.lmdb")) 37 | assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'" 38 | 39 | self.metadata_path = self.path / "metadata.npz" 40 | 41 | self._keys, self.envs = [], [] 42 | for db_path in db_paths: 43 | self.envs.append(self.connect_db(db_path)) 44 | length = pickle.loads( 45 | self.envs[-1].begin().get("length".encode("ascii")) 46 | ) 47 | self._keys.append(list(range(length))) 48 | 49 | keylens = [len(k) for k in self._keys] 50 | self._keylen_cumulative = np.cumsum(keylens).tolist() 51 | self.num_samples = sum(keylens) 52 | else: 53 | self.metadata_path = self.path.parent / "metadata.npz" 54 | self.env = self.connect_db(self.path) 55 | self._keys = [ 56 | f"{j}".encode("ascii") 57 | for j in range(self.env.stat()["entries"]) 58 | ] 59 | self.num_samples = len(self._keys) 60 | 61 | self.transform = transform 62 | 63 | def __len__(self): 64 | return self.num_samples 65 | 66 | def __getitem__(self, idx): 67 | if not self.path.is_file(): 68 | # Figure out which db this should be indexed from. 69 | db_idx = bisect.bisect(self._keylen_cumulative, idx) 70 | # Extract index of element within that db. 71 | el_idx = idx 72 | if db_idx != 0: 73 | el_idx = idx - self._keylen_cumulative[db_idx - 1] 74 | assert el_idx >= 0 75 | 76 | # Return features. 
77 | datapoint_pickled = ( 78 | self.envs[db_idx] 79 | .begin() 80 | .get(f"{self._keys[db_idx][el_idx]}".encode("ascii")) 81 | ) 82 | data_object = pickle.loads(datapoint_pickled) 83 | data_object.id = f"{db_idx}_{el_idx}" 84 | else: 85 | datapoint_pickled = self.env.begin().get(self._keys[idx]) 86 | data_object = pickle.loads(datapoint_pickled) 87 | 88 | if self.transform is not None: 89 | data_object = self.transform(data_object) 90 | 91 | return data_object 92 | 93 | def connect_db(self, lmdb_path=None): 94 | env = lmdb.open( 95 | str(lmdb_path), 96 | subdir=False, 97 | readonly=True, 98 | lock=False, 99 | readahead=False, 100 | meminit=False, 101 | max_readers=1, 102 | map_size=1099511627776 * 2, 103 | ) 104 | return env 105 | 106 | def close_db(self): 107 | if not self.path.is_file(): 108 | for env in self.envs: 109 | env.close() 110 | else: 111 | self.env.close() 112 | -------------------------------------------------------------------------------- /reactot/dataset/qm9.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from reactot.dataset.base_dataset import BaseDataset, ATOM_MAPPING 5 | 6 | n_element = len(list(ATOM_MAPPING.keys())) 7 | 8 | 9 | class BaseQM9(BaseDataset): 10 | def __init__( 11 | self, 12 | npz_path, 13 | center=True, 14 | zero_charge=False, 15 | device="cpu", 16 | remove_h=False, 17 | ) -> None: 18 | super().__init__( 19 | npz_path=npz_path, 20 | center=center, 21 | zero_charge=zero_charge, 22 | device=device, 23 | remove_h=remove_h, 24 | ) 25 | if self.remove_h: 26 | pos = self.raw_dataset['positions'] 27 | charges = self.raw_dataset['charges'] 28 | num_atoms = self.raw_dataset['num_atoms'] 29 | 30 | mask = self.raw_dataset['charges'] > 1 31 | new_positions = np.zeros_like(pos) 32 | new_charges = np.zeros_like(charges) 33 | for i in range(new_positions.shape[0]): 34 | m = mask[i] 35 | p = pos[i][m] # positions to keep 36 | c = charges[i][m] # Charges to keep 
37 | n = np.sum(m) 38 | new_positions[i, :n, :] = p 39 | new_charges[i, :n] = c 40 | 41 | self.raw_dataset['positions'] = new_positions 42 | self.raw_dataset['charges'] = new_charges 43 | self.raw_dataset['num_atoms'] = np.sum( 44 | self.raw_dataset['charges'] > 0, axis=1) 45 | 46 | self.n_samples = len(self.raw_dataset["charges"]) 47 | self.data = {} 48 | 49 | def get_subsets(self): 50 | hasN, hasO, hasF = [], [], [] 51 | for ii in range(self.n_samples): 52 | charges = self.raw_dataset["charges"][ii] 53 | unique_charges = np.unique(charges) 54 | if set(unique_charges) <= set([0, 1, 6, 8]) and 8 in set(unique_charges): 55 | hasO.append(ii) 56 | if set(unique_charges) <= set([0, 1, 6, 7]) and 7 in set(unique_charges): 57 | hasN.append(ii) 58 | if set(unique_charges) <= set([0, 1, 6, 9]) and 9 in set(unique_charges): 59 | hasF.append(ii) 60 | self.hasO_set = {key: val[hasO] for key, val in self.raw_dataset.items()} 61 | self.hasN_set = {key: val[hasN] for key, val in self.raw_dataset.items()} 62 | self.hasF_set = {key: val[hasF] for key, val in self.raw_dataset.items()} 63 | 64 | 65 | class ProcessedQM9(BaseQM9): 66 | def __init__( 67 | self, 68 | npz_path, 69 | center=True, 70 | pad_fragments=2, 71 | device="cpu", 72 | zero_charge=False, 73 | remove_h=False, 74 | **kwargs, 75 | ): 76 | super().__init__( 77 | npz_path=npz_path, 78 | center=center, 79 | device=device, 80 | zero_charge=zero_charge, 81 | remove_h=remove_h, 82 | ) 83 | 84 | self.n_fragments = pad_fragments + 1 85 | self.device = torch.device(device) 86 | 87 | n_samples = len(self.raw_dataset["charges"]) 88 | self.n_samples = n_samples 89 | 90 | self.data = {} 91 | self.process_molecules("raw_dataset", n_samples, idx=0) 92 | 93 | for idx in range(pad_fragments): 94 | self.patch_dummy_molecules(idx + 1) 95 | 96 | self.data["condition"] = [ 97 | torch.zeros(size=(1, 1), dtype=torch.int64, device=self.device,) 98 | for _ in range(self.n_samples) 99 | ] 100 | 101 | 102 | class ProcessedDoubleQM9(BaseQM9): 103 
| def __init__( 104 | self, 105 | npz_path, 106 | center=True, 107 | pad_fragments=1, 108 | device="cpu", 109 | zero_charge=False, 110 | remove_h=False, 111 | **kwargs, 112 | ): 113 | super().__init__( 114 | npz_path=npz_path, 115 | center=center, 116 | device=device, 117 | zero_charge=zero_charge, 118 | remove_h=remove_h, 119 | ) 120 | 121 | self.n_fragments = pad_fragments + 2 122 | self.device = torch.device(device) 123 | n_samples = len(self.raw_dataset["charges"]) 124 | self.n_samples = len(self.raw_dataset["charges"]) 125 | 126 | self.get_subsets() 127 | self.get_pairs() 128 | 129 | self.data = {} 130 | self.process_molecules("frag1_data", n_samples, idx=0) 131 | self.process_molecules("frag2_data", n_samples, idx=1) 132 | 133 | for idx in range(pad_fragments): 134 | self.patch_dummy_molecules(idx + 2) 135 | 136 | self.data["condition"] = [ 137 | torch.zeros(size=(1, 1), dtype=torch.int64, device=self.device,) 138 | for _ in range(self.n_samples) 139 | ] 140 | 141 | def get_pairs(self): 142 | self.frag1_data, self.frag2_data = {}, {} 143 | frag1_O_idx_1sthalf = np.random.choice( 144 | len(self.hasO_set["charges"]), 145 | int(self.n_samples / 2), 146 | replace=True, 147 | ) 148 | frag2_N_idx_1sthalf = np.random.choice( 149 | len(self.hasN_set["charges"]), 150 | int(self.n_samples / 2), 151 | replace=True, 152 | ) 153 | frag1_N_idx_2ndhalf = np.random.choice( 154 | len(self.hasN_set["charges"]), 155 | int(self.n_samples / 2), 156 | replace=True, 157 | ) 158 | frag2_O_idx_2ndhalf = np.random.choice( 159 | len(self.hasO_set["charges"]), 160 | int(self.n_samples / 2), 161 | replace=True, 162 | ) 163 | self.frag1_data = { 164 | key: np.concatenate( 165 | [ 166 | self.hasO_set[key][frag1_O_idx_1sthalf], 167 | self.hasN_set[key][frag1_N_idx_2ndhalf], 168 | ], 169 | axis=0 170 | ) for key in self.raw_dataset 171 | } 172 | self.frag2_data = { 173 | key: np.concatenate( 174 | [ 175 | self.hasN_set[key][frag2_N_idx_1sthalf], 176 | self.hasO_set[key][frag2_O_idx_2ndhalf], 
177 | ], 178 | axis=0 179 | ) for key in self.raw_dataset 180 | } 181 | 182 | 183 | class ProcessedTripleQM9(BaseQM9): 184 | def __init__( 185 | self, 186 | npz_path, 187 | center=True, 188 | pad_fragments=0, 189 | device="cpu", 190 | zero_charge=False, 191 | remove_h=False, 192 | **kwargs, 193 | ): 194 | super().__init__( 195 | npz_path=npz_path, 196 | center=center, 197 | device=device, 198 | zero_charge=zero_charge, 199 | remove_h=remove_h, 200 | ) 201 | 202 | self.n_fragments = pad_fragments + 3 203 | self.device = torch.device(device) 204 | n_samples = len(self.raw_dataset["charges"]) 205 | self.n_samples = len(self.raw_dataset["charges"]) 206 | 207 | self.get_subsets() 208 | self.get_pairs() 209 | 210 | self.data = {} 211 | self.process_molecules("frag1_data", n_samples, idx=0) 212 | self.process_molecules("frag2_data", n_samples, idx=1) 213 | self.process_molecules("frag3_data", n_samples, idx=2) 214 | 215 | for idx in range(pad_fragments): 216 | self.patch_dummy_molecules(idx + 3) 217 | 218 | self.data["condition"] = [ 219 | torch.zeros(size=(1, 1), dtype=torch.int64, device=self.device,) 220 | for _ in range(self.n_samples) 221 | ] 222 | 223 | def get_pairs(self): 224 | n1 = int(self.n_samples / 3) 225 | n2 = int(self.n_samples / 3) 226 | n3 = self.n_samples - n1 - n2 227 | self.frag1_data, self.frag2_data = {}, {} 228 | frag1_O_idx_1_3 = np.random.choice( 229 | len(self.hasO_set["charges"]), 230 | n1, 231 | replace=True, 232 | ) 233 | frag2_N_idx_1_3 = np.random.choice( 234 | len(self.hasN_set["charges"]), 235 | n1, 236 | replace=True, 237 | ) 238 | frag3_F_idx_1_3 = np.random.choice( 239 | len(self.hasF_set["charges"]), 240 | n1, 241 | replace=True, 242 | ) 243 | frag1_F_idx_2_3 = np.random.choice( 244 | len(self.hasF_set["charges"]), 245 | n2, 246 | replace=True, 247 | ) 248 | frag2_O_idx_2_3 = np.random.choice( 249 | len(self.hasO_set["charges"]), 250 | n2, 251 | replace=True, 252 | ) 253 | frag3_N_idx_2_3 = np.random.choice( 254 | 
len(self.hasN_set["charges"]), 255 | n2, 256 | replace=True, 257 | ) 258 | frag1_N_idx_3_3 = np.random.choice( 259 | len(self.hasN_set["charges"]), 260 | n3, 261 | replace=True, 262 | ) 263 | frag2_F_idx_3_3 = np.random.choice( 264 | len(self.hasF_set["charges"]), 265 | n3, 266 | replace=True, 267 | ) 268 | frag3_O_idx_3_3 = np.random.choice( 269 | len(self.hasO_set["charges"]), 270 | n3, 271 | replace=True, 272 | ) 273 | self.frag1_data = { 274 | key: np.concatenate( 275 | [ 276 | self.hasO_set[key][frag1_O_idx_1_3], 277 | self.hasF_set[key][frag1_F_idx_2_3], 278 | self.hasN_set[key][frag1_N_idx_3_3], 279 | ], 280 | axis=0 281 | ) for key in self.raw_dataset 282 | } 283 | self.frag2_data = { 284 | key: np.concatenate( 285 | [ 286 | self.hasN_set[key][frag2_N_idx_1_3], 287 | self.hasO_set[key][frag2_O_idx_2_3], 288 | self.hasF_set[key][frag2_F_idx_3_3], 289 | ], 290 | axis=0 291 | ) for key in self.raw_dataset 292 | } 293 | self.frag3_data = { 294 | key: np.concatenate( 295 | [ 296 | self.hasF_set[key][frag3_F_idx_1_3], 297 | self.hasN_set[key][frag3_N_idx_2_3], 298 | self.hasO_set[key][frag3_O_idx_3_3], 299 | ], 300 | axis=0 301 | ) for key in self.raw_dataset 302 | } -------------------------------------------------------------------------------- /reactot/dataset/sampler.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Iterator, List, Optional 3 | 4 | import torch 5 | 6 | from torch.utils.data import Dataset, RandomSampler 7 | from torch.utils.data.sampler import Sampler, BatchSampler 8 | from torch.utils.data.distributed import DistributedSampler 9 | 10 | 11 | class DynamicBatchSampler(Sampler): 12 | r"""Dynamically adds samples to a mini-batch up to a maximum size (either 13 | based on number of nodes or number of edges). When data samples have a 14 | wide range in sizes, specifying a mini-batch size in terms of number of 15 | samples is not ideal and can cause CUDA OOM errors. 
16 | 17 | Within the :class:`DynamicBatchSampler`, the number of steps per epoch is 18 | ambiguous, depending on the order of the samples. By default the 19 | :meth:`__len__` will be undefined. This is fine for most cases but 20 | progress bars will be infinite. Alternatively, :obj:`num_steps` can be 21 | supplied to cap the number of mini-batches produced by the sampler. 22 | 23 | .. code-block:: python 24 | 25 | from torch_geometric.loader import DataLoader, DynamicBatchSampler 26 | 27 | sampler = DynamicBatchSampler(dataset, max_num=10000, mode="node") 28 | loader = DataLoader(dataset, batch_sampler=sampler, ...) 29 | 30 | Args: 31 | dataset (Dataset): Dataset to sample from. 32 | max_num (int): Size of mini-batch to aim for in number of nodes or 33 | edges. 34 | mode (str, optional): :obj:`"node"` or :obj:`"edge"` to measure 35 | batch size. (default: :obj:`"node"`) 36 | shuffle (bool, optional): If set to :obj:`True`, will have the data 37 | reshuffled at every epoch. (default: :obj:`False`) 38 | skip_too_big (bool, optional): If set to :obj:`True`, skip samples 39 | which cannot fit in a batch by itself. (default: :obj:`False`) 40 | num_steps (int, optional): The number of mini-batches to draw for a 41 | single epoch. If set to :obj:`None`, will iterate through all the 42 | underlying examples, but :meth:`__len__` will be :obj:`None` since 43 | it is be ambiguous. 
(default: :obj:`None`) 44 | """ 45 | def __init__( 46 | self, 47 | dataset: Dataset, 48 | max_num: int, 49 | mode: str = 'node', 50 | shuffle: bool = False, 51 | skip_too_big: bool = False, 52 | num_steps: Optional[int] = None, 53 | drop_last: bool = True, 54 | max_batch: Optional[int] = None, 55 | ddp: bool = False, 56 | **kwargs 57 | ): 58 | if not isinstance(max_num, int) or max_num <= 0: 59 | raise ValueError(f"`max_num` should be a positive integer value " 60 | "(got {max_num}).") 61 | self.mode_avail = ['node', 'node^2'] 62 | self.mode_calc_map = { 63 | "node": self.node_calc, 64 | "node^2": self.node_square_calc, 65 | } 66 | 67 | if not mode in self.mode_avail: 68 | raise ValueError(f"mode {self.mode} is not available.") 69 | self.mode_calc = self.mode_calc_map[mode] 70 | 71 | if num_steps is None: 72 | num_steps = len(dataset) 73 | 74 | self.dataset = dataset 75 | self.num_samples = len(dataset) 76 | self.max_num = max_num 77 | self.mode = mode 78 | self.shuffle = shuffle 79 | self.skip_too_big = skip_too_big 80 | self.num_steps = num_steps 81 | self.drop_last = drop_last 82 | self.max_batch = max_batch 83 | if not ddp: 84 | self.sampler = RandomSampler( 85 | dataset, 86 | generator=torch.Generator().manual_seed(42), 87 | ) 88 | else: 89 | self.sampler = DistributedSampler( 90 | dataset, shuffle=shuffle, seed=42, 91 | ) 92 | 93 | self.batch_size = self.max_num // 400 94 | if self.max_batch is None: 95 | self.max_batch = len(dataset) // self.batch_size 96 | 97 | @staticmethod 98 | def node_calc(x): 99 | return x 100 | 101 | @staticmethod 102 | def node_square_calc(x): 103 | return x ** 2 104 | 105 | def __iter__(self) -> Iterator[List[int]]: 106 | batch: List[int] = [] 107 | batch_n = 0 108 | num_batch = 0 109 | for idx in self.sampler: 110 | data = self.dataset[idx] 111 | n = self.mode_calc(data["size_0"].item()) 112 | 113 | if len(batch) and batch_n + n > self.max_num: 114 | # Mini-batch filled 115 | # print("batch: ", batch) 116 | yield batch 117 | batch 
= [] 118 | batch_n = 0 119 | num_batch += 1 120 | 121 | if (self.max_batch is not None) \ 122 | and (num_batch > self.max_batch): 123 | break 124 | 125 | if n > self.max_num: 126 | if self.skip_too_big: 127 | continue 128 | else: 129 | warnings.warn( 130 | f"Size of data sample at index {idx} is larger than" 131 | f"{self.max_num} at {self.mode}s (got {n})." 132 | "This warning suugests that some systems you have does" 133 | "not fit in one GPU." 134 | ) 135 | batch.append(idx) 136 | batch_n += n 137 | 138 | if not self.drop_last and len(batch): 139 | yield batch 140 | 141 | def __len__(self) -> int: 142 | return self.max_batch -------------------------------------------------------------------------------- /reactot/dataset/transition1x.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | 5 | from reactot.dataset.base_dataset import BaseDataset, ATOM_MAPPING, SAM_CHARGED_ATOM_MAPPING 6 | 7 | 8 | FRAG_MAPPING = { 9 | "reactant": "product", 10 | "transition_state": "transition_state", 11 | "product": "reactant", 12 | } 13 | 14 | 15 | def reflect_z(x): 16 | x = np.array(x) 17 | x[:, -1] = - x[:, -1] 18 | return x 19 | 20 | 21 | class ProcessedTS1x(BaseDataset): 22 | def __init__( 23 | self, 24 | npz_path, 25 | center=True, 26 | pad_fragments=0, 27 | device="cuda", 28 | zero_charge=False, 29 | remove_h=False, 30 | single_frag_only=True, 31 | multi_frag_only=False, 32 | swapping_react_prod=False, 33 | append_frag=False, 34 | reflection=False, 35 | use_by_ind=False, 36 | only_ts=False, 37 | only_rp=False, 38 | confidence_model=False, 39 | position_key="positions", 40 | ediff=None, 41 | ts_guess=False, 42 | react_type=None, 43 | atom_mapping=ATOM_MAPPING, 44 | **kwargs, 45 | ): 46 | super().__init__( 47 | npz_path=npz_path, 48 | center=center, 49 | device=device, 50 | zero_charge=zero_charge, 51 | remove_h=remove_h, 52 | atom_mapping=atom_mapping, 53 | ) 54 | if confidence_model: 
55 | use_by_ind = False 56 | if remove_h: 57 | print("remove_h is ignored because it is not reasonble for TS.") 58 | if single_frag_only: 59 | print("Filtering: Maintain only uni-molecular reactions") 60 | single_frag_inds = np.where( 61 | np.array(self.raw_dataset["single_fragment"]) == 1)[0] 62 | elif multi_frag_only: 63 | print("Filtering: Maintain only multi-molecular reactions") 64 | single_frag_inds = np.where( 65 | np.array(self.raw_dataset["single_fragment"]) == 0)[0] 66 | else: 67 | single_frag_inds = np.array(range(len(self.raw_dataset["single_fragment"]))) 68 | if use_by_ind: 69 | print("Filtering: Maintain based on data partitioning") 70 | use_inds = self.raw_dataset["use_ind"] 71 | else: 72 | use_inds = range(len(self.raw_dataset["single_fragment"])) 73 | if react_type is not None: 74 | print(f"Filtering: Maintain reactions only with type {react_type}") 75 | intended_inds = np.where( 76 | np.array(self.raw_dataset["type"]) == react_type)[0] 77 | else: 78 | intended_inds = range(len(self.raw_dataset["single_fragment"])) 79 | single_frag_inds = list( 80 | set(single_frag_inds).intersection( 81 | set(use_inds)).intersection( 82 | set(intended_inds)) 83 | ) 84 | print(f"position key: {position_key}, # of data: {len(single_frag_inds)}") 85 | 86 | data_duplicated = copy.deepcopy(self.raw_dataset) 87 | for k, mapped_k in FRAG_MAPPING.items(): 88 | for v, val in data_duplicated[k].items(): 89 | self.raw_dataset[k][v] = [val[ii] for ii in single_frag_inds] 90 | if swapping_react_prod: 91 | mapped_val = data_duplicated[mapped_k][v] 92 | self.raw_dataset[k][v] += [mapped_val[ii] for ii in single_frag_inds] 93 | if reflection: 94 | for k, mapped_k in FRAG_MAPPING.items(): 95 | for v, val in self.raw_dataset[k].items(): 96 | if v in ["wB97x_6-31G(d).forces", position_key]: 97 | self.raw_dataset[k][v] += [ 98 | reflect_z(_val) for _val in val] 99 | else: 100 | self.raw_dataset[k][v] += val 101 | 102 | self.reactant = self.raw_dataset["reactant"] 103 | 
self.transition_state = self.raw_dataset["transition_state"] 104 | self.product = self.raw_dataset["product"] 105 | 106 | self.n_fragments = pad_fragments + 3 107 | self.device = torch.device(device) 108 | n_samples = len(self.reactant["charges"]) 109 | self.n_samples = len(self.reactant["charges"]) 110 | 111 | self.data = {} 112 | repeat = 2 if swapping_react_prod else 1 113 | if confidence_model: 114 | self.data["target"] = torch.tensor(self.raw_dataset["target"] * repeat).unsqueeze(1) 115 | self.data["rmsd"] = torch.tensor(self.raw_dataset["rmsd"] * repeat).unsqueeze(1) 116 | if ediff is not None: 117 | self.data["ediff"] = torch.tensor(self.raw_dataset[ediff]["ediff"] * repeat).unsqueeze(1) 118 | if ts_guess: 119 | self.data["ts_guess"] = [torch.tensor(self.raw_dataset[ts_guess][ii]) for ii in single_frag_inds] * repeat 120 | if not only_ts: 121 | if not only_rp: 122 | if not append_frag: 123 | self.process_molecules("reactant", n_samples, idx=0, position_key=position_key) 124 | self.process_molecules("transition_state", n_samples, idx=1) 125 | self.process_molecules("product", n_samples, idx=2, position_key=position_key) 126 | else: 127 | self.process_molecules("reactant", n_samples, idx=0, append_charge=0, position_key=position_key) 128 | self.process_molecules("transition_state", n_samples, idx=1, append_charge=1) 129 | self.process_molecules("product", n_samples, idx=2, append_charge=0, position_key=position_key) 130 | 131 | for idx in range(pad_fragments): 132 | self.patch_dummy_molecules(idx + 3) 133 | else: 134 | self.process_molecules("reactant", n_samples, idx=0, position_key=position_key) 135 | self.process_molecules("product", n_samples, idx=1, position_key=position_key) 136 | else: 137 | if not append_frag: 138 | self.process_molecules("transition_state", n_samples, idx=0) 139 | else: 140 | self.process_molecules("transition_state", n_samples, idx=0, append_charge=1) 141 | # for idx in range(2): 142 | # self.patch_dummy_molecules(idx + 1) 143 | 144 
| # if "charge" in self.raw_dataset: 145 | if False: 146 | charge_duplicated = copy.deepcopy(self.raw_dataset["charge"]) 147 | self.data["condition"] = [ 148 | torch.tensor([charge_duplicated[ii]], dtype=torch.int64, device=self.device,).reshape(1, 1) 149 | for ii in single_frag_inds 150 | ] 151 | if swapping_react_prod: 152 | self.data["condition"] += self.data["condition"] 153 | assert len(self.data["condition"]) == self.n_samples 154 | else: 155 | self.data["condition"] = [ 156 | torch.zeros(size=(1, 1), dtype=torch.int64, device=self.device,) 157 | for _ in range(self.n_samples) 158 | ] 159 | -------------------------------------------------------------------------------- /reactot/dataset/zeolite.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | 5 | from reactot.dataset.base_dataset import BaseDataset 6 | from reactot.dataset.datasets_config import ZEOLITE_ATOM_MAPPING 7 | 8 | 9 | FRAG_MAPPING = { 10 | "reactant": "reactant", 11 | } 12 | 13 | 14 | 15 | def reflect_z(x): 16 | x = np.array(x) 17 | x[:, -1] = -x[:, -1] 18 | return x 19 | 20 | 21 | class ProcessedZeolite(BaseDataset): 22 | def __init__( 23 | self, 24 | npz_path, 25 | center=True, 26 | pad_fragments=0, 27 | device="cuda", 28 | zero_charge=False, 29 | remove_h=False, 30 | swapping_react_prod=False, 31 | append_frag=False, 32 | reflection=False, 33 | use_by_ind=True, 34 | only_one=True, 35 | for_scalar_target=False, 36 | position_key="positions", 37 | atom_mapping=ZEOLITE_ATOM_MAPPING, 38 | **kwargs, 39 | ): 40 | super().__init__( 41 | npz_path=npz_path, 42 | center=center, 43 | device=device, 44 | zero_charge=zero_charge, 45 | remove_h=remove_h, 46 | atom_mapping=atom_mapping, 47 | ) 48 | if remove_h: 49 | print("remove_h is ignored because it is not reasonble for TS.") 50 | 51 | single_frag_inds = np.array(range(len(self.raw_dataset["reactant"]["idx"]))) 52 | if use_by_ind: 53 | use_inds = 
self.raw_dataset["use_ind"] 54 | else: 55 | try: 56 | use_inds = range(len(self.raw_dataset["single_fragment"])) 57 | except: 58 | use_inds = np.array(range(len(self.raw_dataset["reactant"]["idx"]))) 59 | single_frag_inds = list(set(single_frag_inds).intersection(set(use_inds))) 60 | 61 | data_duplicated = copy.deepcopy(self.raw_dataset) 62 | for k, mapped_k in FRAG_MAPPING.items(): 63 | for v, val in data_duplicated[k].items(): 64 | self.raw_dataset[k][v] = [val[ii] for ii in single_frag_inds] 65 | if swapping_react_prod: 66 | mapped_val = data_duplicated[mapped_k][v] 67 | self.raw_dataset[k][v] += [ 68 | mapped_val[ii] for ii in single_frag_inds 69 | ] 70 | if reflection: 71 | for k, mapped_k in FRAG_MAPPING.items(): 72 | for v, val in self.raw_dataset[k].items(): 73 | if v in ["wB97x_6-31G(d).forces", position_key]: 74 | self.raw_dataset[k][v] += [reflect_z(_val) for _val in val] 75 | else: 76 | self.raw_dataset[k][v] += val 77 | 78 | self.reactant = self.raw_dataset["reactant"] 79 | 80 | self.n_fragments = pad_fragments + 3 81 | self.device = torch.device(device) 82 | n_samples = len(self.reactant["charges"]) 83 | self.n_samples = len(self.reactant["charges"]) 84 | 85 | self.data = {} 86 | repeat = 2 if swapping_react_prod else 1 87 | if for_scalar_target: 88 | self.data["target"] = torch.tensor( 89 | [ 90 | self.raw_dataset["target"][ii] for ii in single_frag_inds 91 | ] * repeat 92 | ).unsqueeze(1).float() 93 | 94 | if not only_one: 95 | if not append_frag: 96 | self.process_molecules( 97 | "reactant", n_samples, idx=0, position_key=position_key 98 | ) 99 | self.process_molecules("transition_state", n_samples, idx=1) 100 | self.process_molecules( 101 | "product", n_samples, idx=2, position_key=position_key 102 | ) 103 | else: 104 | self.process_molecules( 105 | "reactant", 106 | n_samples, 107 | idx=0, 108 | append_charge=0, 109 | position_key=position_key, 110 | ) 111 | self.process_molecules( 112 | "transition_state", n_samples, idx=1, append_charge=1 113 | 
) 114 | self.process_molecules( 115 | "product", 116 | n_samples, 117 | idx=2, 118 | append_charge=0, 119 | position_key=position_key, 120 | ) 121 | 122 | for idx in range(pad_fragments): 123 | self.patch_dummy_molecules(idx + 3) 124 | else: 125 | if not append_frag: 126 | self.process_molecules("reactant", n_samples, idx=0) 127 | else: 128 | self.process_molecules( 129 | "reactant", n_samples, idx=0, append_charge=1 130 | ) 131 | 132 | self.data["condition"] = [ 133 | torch.zeros(size=(1, 1), dtype=torch.int64, device=self.device,) 134 | for _ in range(self.n_samples) 135 | ] -------------------------------------------------------------------------------- /reactot/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from . import _utils as utils 2 | from ._schedule import DiffSchedule, PredefinedNoiseSchedule 3 | from ._normalizer import Normalizer 4 | from .en_diffusion import EnVariationalDiffusion 5 | from .en_sb import EnSB 6 | -------------------------------------------------------------------------------- /reactot/diffusion/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/diffusion/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/diffusion/__pycache__/_normalizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/diffusion/__pycache__/_normalizer.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/diffusion/__pycache__/_schedule.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/diffusion/__pycache__/_schedule.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/diffusion/__pycache__/_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/diffusion/__pycache__/_utils.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/diffusion/__pycache__/en_diffusion.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/diffusion/__pycache__/en_diffusion.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/diffusion/__pycache__/en_sb.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/diffusion/__pycache__/en_sb.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/diffusion/_node_dist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions.categorical import Categorical 3 | 4 | import numpy as np 5 | 6 | 7 | # TODO: This code is just copied over diffSBDD and has not been modified at all... 
8 | class DoubleDistributionNodes: 9 | def __init__(self, histogram): 10 | 11 | histogram = torch.tensor(histogram).float() 12 | histogram = histogram + 1e-3 # for numerical stability 13 | 14 | prob = histogram / histogram.sum() 15 | 16 | self.idx_to_n_nodes = torch.tensor( 17 | [[(i, j) for j in range(prob.shape[1])] for i in range(prob.shape[0])] 18 | ).view(-1, 2) 19 | 20 | self.n_nodes_to_idx = { 21 | tuple(x.tolist()): i for i, x in enumerate(self.idx_to_n_nodes) 22 | } 23 | 24 | self.prob = prob 25 | self.m = torch.distributions.Categorical(self.prob.view(-1), validate_args=True) 26 | 27 | self.n1_given_n2 = [ 28 | torch.distributions.Categorical(prob[:, j], validate_args=True) 29 | for j in range(prob.shape[1]) 30 | ] 31 | self.n2_given_n1 = [ 32 | torch.distributions.Categorical(prob[i, :], validate_args=True) 33 | for i in range(prob.shape[0]) 34 | ] 35 | 36 | # entropy = -torch.sum(self.prob.view(-1) * torch.log(self.prob.view(-1) + 1e-30)) 37 | entropy = self.m.entropy() 38 | print("Entropy of n_nodes: H[N]", entropy.item()) 39 | 40 | def sample(self, n_samples=1): 41 | idx = self.m.sample((n_samples,)) 42 | num_nodes_lig, num_nodes_pocket = self.idx_to_n_nodes[idx].T 43 | return num_nodes_lig, num_nodes_pocket 44 | 45 | def sample_conditional(self, n1=None, n2=None): 46 | assert (n1 is None) ^ (n2 is None), "Exactly one input argument must be None" 47 | 48 | m = self.n1_given_n2 if n2 is not None else self.n2_given_n1 49 | c = n2 if n2 is not None else n1 50 | 51 | return torch.tensor([m[i].sample() for i in c], device=c.device) 52 | 53 | def log_prob(self, batch_n_nodes_1, batch_n_nodes_2): 54 | assert len(batch_n_nodes_1.size()) == 1 55 | assert len(batch_n_nodes_2.size()) == 1 56 | 57 | idx = torch.tensor( 58 | [ 59 | self.n_nodes_to_idx[(n1, n2)] 60 | for n1, n2 in zip(batch_n_nodes_1.tolist(), batch_n_nodes_2.tolist()) 61 | ] 62 | ) 63 | 64 | # log_probs = torch.log(self.prob.view(-1)[idx] + 1e-30) 65 | log_probs = self.m.log_prob(idx) 66 | 67 | 
return log_probs.to(batch_n_nodes_1.device) 68 | 69 | def log_prob_n1_given_n2(self, n1, n2): 70 | assert len(n1.size()) == 1 71 | assert len(n2.size()) == 1 72 | log_probs = torch.stack( 73 | [self.n1_given_n2[c].log_prob(i.cpu()) for i, c in zip(n1, n2)] 74 | ) 75 | return log_probs.to(n1.device) 76 | 77 | def log_prob_n2_given_n1(self, n2, n1): 78 | assert len(n2.size()) == 1 79 | assert len(n1.size()) == 1 80 | log_probs = torch.stack( 81 | [self.n2_given_n1[c].log_prob(i.cpu()) for i, c in zip(n2, n1)] 82 | ) 83 | return log_probs.to(n2.device) 84 | 85 | 86 | class SingleDistributionNodes: 87 | def __init__(self, histogram): 88 | 89 | self.n_nodes = [] 90 | prob = [] 91 | self.keys = {} 92 | for i, nodes in enumerate(histogram): 93 | self.n_nodes.append(nodes) 94 | self.keys[nodes] = i 95 | prob.append(histogram[nodes]) 96 | self.n_nodes = torch.tensor(self.n_nodes) 97 | prob = np.array(prob) 98 | prob = prob/np.sum(prob) 99 | 100 | self.prob = torch.from_numpy(prob).float() 101 | 102 | entropy = torch.sum(self.prob * torch.log(self.prob + 1e-30)) 103 | print("Entropy of n_nodes: H[N]", entropy.item()) 104 | 105 | self.m = Categorical(torch.tensor(prob)) 106 | 107 | def sample(self, n_samples=1): 108 | idx = self.m.sample((n_samples,)) 109 | return self.n_nodes[idx] 110 | 111 | def log_prob(self, batch_n_nodes): 112 | assert len(batch_n_nodes.size()) == 1 113 | 114 | idcs = [self.keys[i.item()] for i in batch_n_nodes] 115 | idcs = torch.tensor(idcs).to(batch_n_nodes.device) 116 | 117 | log_p = torch.log(self.prob + 1e-30) 118 | 119 | log_p = log_p.to(batch_n_nodes.device) 120 | 121 | log_probs = log_p[idcs] 122 | 123 | return log_probs 124 | -------------------------------------------------------------------------------- /reactot/diffusion/_normalizer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List, Dict 2 | 3 | import torch 4 | from torch import nn, Tensor 5 | 6 | FEATURE_MAPPING = 
["pos", "one_hot", "charge"] 7 | 8 | 9 | class Normalizer(nn.Module): 10 | def __init__( 11 | self, 12 | norm_values: Tuple = (1.0, 1.0, 1.0), 13 | norm_biases: Tuple = (0.0, 0.0, 0.0), 14 | pos_dim: int = 3, 15 | ) -> None: 16 | super().__init__() 17 | self.norm_values = norm_values 18 | self.norm_biases = norm_biases 19 | self.pos_dim = pos_dim 20 | 21 | def normalize(self, representations: List[Dict]) -> List[Dict]: 22 | for ii in range(len(representations)): 23 | for jj, feature_type in enumerate(FEATURE_MAPPING): 24 | representations[ii][feature_type] = ( 25 | representations[ii][feature_type] - self.norm_biases[jj] 26 | ) / self.norm_values[jj] 27 | return representations 28 | 29 | def unnormalize(self, x: Tensor, ind: int) -> Tensor: 30 | return x * self.norm_values[ind] + self.norm_biases[ind] 31 | 32 | def unnormalize_z(self, z_combined: List[Tensor]) -> List[Tensor]: 33 | for ii in range(len(z_combined)): 34 | z_combined[ii][:, : self.pos_dim] = self.unnormalize( 35 | z_combined[ii][:, : self.pos_dim], 0 36 | ) 37 | z_combined[ii][:, self.pos_dim : -1] = self.unnormalize( 38 | z_combined[ii][:, self.pos_dim : -1], 1 39 | ) 40 | z_combined[ii][:, -1:] = self.unnormalize(z_combined[ii][:, -1:], 2) 41 | return z_combined 42 | -------------------------------------------------------------------------------- /reactot/diffusion/_schedule.py: -------------------------------------------------------------------------------- 1 | """t schedule used in diffusion process.""" 2 | from typing import Tuple 3 | from functools import partial 4 | import numpy as np 5 | import torch 6 | from torch import Tensor, nn 7 | import torch.nn.functional as F 8 | 9 | from ._utils import unsqueeze_xdim 10 | 11 | 12 | def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1): 13 | r""" 14 | cosine schedule 15 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 16 | """ 17 | steps = timesteps + 2 18 | x = np.linspace(0, steps, steps) 19 | alphas_cumprod = np.cos(((x 
/ steps) + s) / (1 + s) * np.pi * 0.5) ** 2 20 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 21 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 22 | betas = np.clip(betas, a_min=0, a_max=0.999) 23 | alphas = 1.0 - betas 24 | alphas_cumprod = np.cumprod(alphas, axis=0) 25 | 26 | if raise_to_power != 1: 27 | alphas_cumprod = np.power(alphas_cumprod, raise_to_power) 28 | 29 | return alphas_cumprod 30 | 31 | 32 | def ccosine_schedule(timesteps, start=0, end=1, tau=1, clip_min=1e-9): 33 | t = np.linspace(0, 1, timesteps+1) 34 | v_start = np.cos(start * np.pi / 2) ** (2 * tau) 35 | v_end = np.cos(end * np.pi / 2) ** (2 * tau) 36 | output = np.cos((t * (end - start) + start) * np.pi /2) ** (2 * tau) 37 | output = (v_end - output) / (v_end - v_start) 38 | return np.clip(output, clip_min, 1 - clip_min) 39 | 40 | 41 | def linear_schedule(timesteps, clip_min=1e-9): 42 | t = np.linspace(0, 1, timesteps+1) 43 | output = 1 - t 44 | return np.clip(output, clip_min, 1 - clip_min) 45 | 46 | 47 | def clip_noise_schedule(alphas2, clip_value=0.001): 48 | r""" 49 | For a noise schedule given by alpha^2, this clips alpha_t / alpha_t-1. 50 | This may help improve stability during 51 | sampling. 52 | """ 53 | alphas2 = np.concatenate([np.ones(1), alphas2], axis=0) 54 | 55 | alphas_step = alphas2[1:] / alphas2[:-1] 56 | 57 | alphas_step = np.clip(alphas_step, a_min=clip_value, a_max=1.0) 58 | alphas2 = np.cumprod(alphas_step, axis=0) 59 | 60 | return alphas2 61 | 62 | 63 | def polynomial_schedule(timesteps: int, s=1e-4, power=3.0): 64 | r""" 65 | A noise schedule based on a simple polynomial equation: 1 - x^power. 
66 | """ 67 | steps = timesteps + 1 68 | x = np.linspace(0, steps, steps) 69 | alphas2 = (1 - np.power(x / steps, power)) ** 2 70 | 71 | alphas2 = clip_noise_schedule(alphas2, clip_value=0.001) 72 | 73 | precision = 1 - 2 * s 74 | 75 | alphas2 = precision * alphas2 + s 76 | 77 | return alphas2 78 | 79 | 80 | class PredefinedNoiseSchedule(nn.Module): 81 | r""" 82 | Predefined noise schedule. Essentially creates a lookup array for predefined 83 | (non-learned) noise schedules. 84 | """ 85 | 86 | def __init__( 87 | self, 88 | noise_schedule: str, 89 | timesteps: int, 90 | precision: float, 91 | ): 92 | super().__init__() 93 | self.timesteps = timesteps 94 | 95 | if "cosine" in noise_schedule: 96 | splits = noise_schedule.split("_") 97 | assert len(splits) <= 2 98 | power = 1 if len(splits) == 1 else float(splits[1]) 99 | alphas2 = cosine_beta_schedule(timesteps, raise_to_power=power) 100 | elif "polynomial" in noise_schedule: 101 | splits = noise_schedule.split("_") 102 | assert len(splits) == 2 103 | power = float(splits[1]) 104 | alphas2 = polynomial_schedule(timesteps, s=precision, power=power) 105 | elif "csin" in noise_schedule: 106 | splits = noise_schedule.split("_") 107 | assert len(splits) == 4 108 | start, end, tau = float(splits[1]), float(splits[2]), float(splits[3]) 109 | alphas2 = ccosine_schedule(timesteps, start=start, end=end, tau=tau) 110 | elif "linear" in noise_schedule: 111 | alphas2 = linear_schedule(timesteps) 112 | else: 113 | raise ValueError(noise_schedule) 114 | 115 | # print("alphas2", alphas2) 116 | 117 | sigmas2 = 1 - alphas2 118 | 119 | log_alphas2 = np.log(alphas2) 120 | log_sigmas2 = np.log(sigmas2) 121 | 122 | log_alphas2_to_sigmas2 = log_alphas2 - log_sigmas2 123 | 124 | # print("gamma", -log_alphas2_to_sigmas2) 125 | 126 | self.gamma = torch.nn.Parameter( 127 | torch.from_numpy(-log_alphas2_to_sigmas2).float(), requires_grad=False 128 | ) 129 | 130 | def forward(self, t): 131 | t_int = torch.round(t * self.timesteps).long() 132 | 
return self.gamma[t_int] 133 | 134 | 135 | class DiffSchedule(nn.Module): 136 | def __init__(self, gamma_module: nn.Module, norm_values: Tuple[float]) -> None: 137 | super().__init__() 138 | self.gamma_module = gamma_module 139 | self.norm_values = norm_values 140 | self.check_issues_norm_values() 141 | 142 | @staticmethod 143 | def inflate_batch_array(array, target): 144 | r""" 145 | Inflates the batch array (array) with only a single axis 146 | (i.e. shape = (batch_size,), or possibly more empty axes 147 | (i.e. shape (batch_size, 1, ..., 1)) to match the target shape. 148 | """ 149 | target_shape = (array.size(0),) + (1,) * (len(target.size()) - 1) 150 | return array.view(target_shape) 151 | 152 | def sigma(self, gamma, target_tensor): 153 | r"""Computes sigma given gamma.""" 154 | return self.inflate_batch_array(torch.sqrt(torch.sigmoid(gamma)), target_tensor) 155 | 156 | def alpha(self, gamma, target_tensor): 157 | r"""Computes alpha given gamma.""" 158 | return self.inflate_batch_array( 159 | torch.sqrt(torch.sigmoid(-gamma)), target_tensor 160 | ) 161 | 162 | @staticmethod 163 | def SNR(gamma): 164 | r"""Computes signal to noise ratio (alpha^2/sigma^2) given gamma.""" 165 | return torch.exp(-gamma) 166 | 167 | def sigma_and_alpha_t_given_s( 168 | self, gamma_t: Tensor, gamma_s: Tensor, target_tensor: Tensor 169 | ) -> tuple[Tensor, Tensor, Tensor]: 170 | r""" 171 | Computes sigma t given s, using gamma_t and gamma_s. Used during sampling. 172 | These are defined as: 173 | alpha t given s = alpha t / alpha s, 174 | sigma t given s = sqrt(1 - (alpha t given s) ^2 ). 
175 | """ 176 | sigma2_t_given_s = self.inflate_batch_array( 177 | -torch.expm1(F.softplus(gamma_s) - F.softplus(gamma_t)), target_tensor 178 | ) 179 | 180 | # alpha_t_given_s = alpha_t / alpha_s 181 | log_alpha2_t = F.logsigmoid(-gamma_t) 182 | log_alpha2_s = F.logsigmoid(-gamma_s) 183 | log_alpha2_t_given_s = log_alpha2_t - log_alpha2_s 184 | 185 | alpha_t_given_s = torch.exp(0.5 * log_alpha2_t_given_s) 186 | alpha_t_given_s = self.inflate_batch_array(alpha_t_given_s, target_tensor) 187 | 188 | sigma_t_given_s = torch.sqrt(sigma2_t_given_s) 189 | 190 | return sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s 191 | 192 | def check_issues_norm_values(self, num_stdevs=8): 193 | zeros = torch.zeros((1, 1)) 194 | gamma_0 = self.gamma_module(zeros) 195 | sigma_0 = self.sigma(gamma_0, target_tensor=zeros).item() 196 | 197 | # Checked if 1 / norm_value is still larger than 10 * standard 198 | # deviation. 199 | norm_value = self.norm_values[1] 200 | 201 | if sigma_0 * num_stdevs > 1.0 / norm_value: 202 | raise ValueError( 203 | f"Value for normalization value {norm_value} probably too " 204 | f"large with sigma_0 {sigma_0:.5f} and " 205 | f"1 / norm_value = {1. / norm_value}" 206 | ) 207 | 208 | 209 | def get_repaint_schedule(resamplings, jump_length, timesteps): 210 | """ 211 | Each integer in the schedule list describes how many denoising steps 212 | need to be applied before jumping back. 
213 | 214 | sum(out) - (len(out) -1) * jump_length = timesteps 215 | 216 | """ 217 | repaint_schedule = [] 218 | curr_t = 0 219 | while curr_t < timesteps: 220 | if curr_t + jump_length < timesteps: 221 | if len(repaint_schedule) > 0: 222 | repaint_schedule[-1] += jump_length 223 | repaint_schedule.extend([jump_length] * (resamplings - 1)) 224 | else: 225 | repaint_schedule.extend([jump_length] * resamplings) 226 | curr_t += jump_length 227 | else: 228 | residual = (timesteps - curr_t) 229 | if len(repaint_schedule) > 0: 230 | repaint_schedule[-1] += residual 231 | else: 232 | repaint_schedule.append(residual) 233 | curr_t += residual 234 | 235 | return list(reversed(repaint_schedule)) 236 | 237 | 238 | def make_beta_schedule(n_timestep=1000, linear_start=1e-4, linear_end=2e-2, 239 | power: float = 1., inv_power: float = 1.): 240 | """ 241 | betas for schrodinger bridge 242 | """ 243 | betas = torch.linspace( 244 | linear_start ** inv_power, 245 | linear_end ** inv_power, 246 | n_timestep, 247 | dtype=torch.float64 248 | ) ** power 249 | return betas.numpy() 250 | 251 | 252 | def compute_gaussian_product_coef(sigma1, sigma2): 253 | """ 254 | Given p1 = N(x_t|x_0, sigma_1**2) and p2 = N(x_t|x_1, sigma_2**2) 255 | return p1 * p2 = N(x_t| coef1 * x0 + coef2 * x1, var) 256 | """ 257 | 258 | # denom = np.sqrt(sigma1**2 + sigma2**2) 259 | # coef1 = sigma2 / denom 260 | # coef2 = sigma1 / denom 261 | 262 | denom = sigma1**2 + sigma2**2 263 | coef1 = sigma2**2 / denom 264 | coef2 = sigma1**2 / denom 265 | var = (sigma1**2 * sigma2**2) / denom 266 | return coef1, coef2, var 267 | 268 | 269 | class SBSchedule(): 270 | def __init__( 271 | self, 272 | timesteps: int = 1000, 273 | beta_max: float = 0.3, 274 | power: float = 1., 275 | inv_power: float = 1. 
276 | ): 277 | betas = make_beta_schedule( 278 | n_timestep=timesteps, 279 | linear_end=beta_max/timesteps, 280 | power=power, 281 | inv_power=inv_power, 282 | ) 283 | betas = np.concatenate([betas[:timesteps//2], np.flip(betas[:timesteps//2])]) 284 | betas = (beta_max / timesteps) / np.max(betas) * betas * 0.5 285 | 286 | self.timesteps = betas.shape[0] 287 | 288 | # compute analytic std: eq 11 289 | std_fwd = np.sqrt(np.cumsum(betas)) 290 | std_bwd = np.sqrt(np.flip(np.cumsum(np.flip(betas)))) 291 | 292 | mu_x0, mu_x1, var = compute_gaussian_product_coef(std_fwd, std_bwd) 293 | 294 | # alphas2 = polynomial_schedule(timesteps=3000 - 1, s=1e-5, power=1) 295 | # alphas = np.sqrt(alphas2) 296 | 297 | # mu_x0, mu_x1, var = compute_gaussian_product_coef(alphas) 298 | 299 | std_sb = np.sqrt(var) 300 | 301 | # tensorize everything 302 | to_torch = partial(torch.tensor, dtype=torch.float32) 303 | self.betas = to_torch(betas) 304 | self.std_fwd = to_torch(std_fwd) 305 | self.std_bwd = to_torch(std_bwd) 306 | self.std_sb = to_torch(std_sb) 307 | self.mu_x0 = to_torch(mu_x0) 308 | self.mu_x1 = to_torch(mu_x1) 309 | # self.alphas = to_torch(alphas) 310 | 311 | @staticmethod 312 | def inflate_batch_array(array, target): 313 | r""" 314 | Inflates the batch array (array) with only a single axis 315 | (i.e. shape = (batch_size,), or possibly more empty axes 316 | (i.e. shape (batch_size, 1, ..., 1)) to match the target shape. 
317 | """ 318 | target_shape = (array.size(0),) + (1,) * (len(target.size()) - 1) 319 | return array.view(target_shape) 320 | 321 | def get_std_fwd(self, step, xdim=None): 322 | device = self.mu_x0.device 323 | step=step.to(device) 324 | std_fwd = self.std_fwd[step] 325 | return std_fwd if xdim is None else unsqueeze_xdim(std_fwd, xdim) 326 | -------------------------------------------------------------------------------- /reactot/diffusion/_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import math 4 | import torch 5 | from torch import Tensor 6 | from torch_scatter import scatter_add, scatter_mean 7 | 8 | import ase 9 | from ase.calculators.emt import EMT 10 | from ase.neb import NEB 11 | from ase import Atoms 12 | 13 | 14 | def remove_mean_batch(x, indices): 15 | mean = scatter_mean(x, indices, dim=0) 16 | x = x - mean[indices] 17 | return x 18 | 19 | 20 | def assert_mean_zero_with_mask(x, node_mask, eps=1e-10): 21 | largest_value = x.abs().max().item() 22 | error = scatter_add(x, node_mask, dim=0).abs().max().item() 23 | rel_error = error / (largest_value + eps) 24 | assert rel_error < 1e-2, f"Mean is not zero, relative_error {rel_error}" 25 | 26 | 27 | def sample_center_gravity_zero_gaussian_batch( 28 | size: List[int], indices: List[Tensor] 29 | ) -> Tensor: 30 | assert len(size) == 2 31 | x = torch.randn(size, device=indices[0].device) 32 | 33 | # This projection only works because Gaussian is rotation invariant 34 | # around zero and samples are independent! 
35 | x_projected = remove_mean_batch(x, torch.cat(indices)) 36 | return x_projected 37 | 38 | 39 | def sum_except_batch(x, indices, dim_size): 40 | return scatter_add(x.sum(-1), indices, dim=0, dim_size=dim_size) 41 | 42 | 43 | def cdf_standard_gaussian(x): 44 | return 0.5 * (1.0 + torch.erf(x / math.sqrt(2))) 45 | 46 | 47 | def sample_gaussian(size, device): 48 | x = torch.randn(size, device=device) 49 | return x 50 | 51 | 52 | def num_nodes_to_batch_mask(n_samples, num_nodes, device): 53 | assert isinstance(num_nodes, int) or len(num_nodes) == n_samples 54 | 55 | if isinstance(num_nodes, torch.Tensor): 56 | num_nodes = num_nodes.to(device) 57 | 58 | sample_inds = torch.arange(n_samples, device=device) 59 | 60 | return torch.repeat_interleave(sample_inds, num_nodes) 61 | 62 | 63 | def unsqueeze_xdim(z, xdim): 64 | bc_dim = (...,) + (None,) * len(xdim) 65 | return z[bc_dim] 66 | 67 | 68 | def space_indices(num_steps, count): 69 | assert count <= num_steps 70 | 71 | if count <= 1: 72 | frac_stride = 1 73 | else: 74 | frac_stride = (num_steps - 1) / (count - 1) 75 | 76 | cur_idx = 0.0 77 | taken_steps = [] 78 | for _ in range(count): 79 | taken_steps.append(round(cur_idx)) 80 | cur_idx += frac_stride 81 | 82 | return taken_steps 83 | 84 | 85 | def idpp_guess(r_pos, p_pos, x0_size, x0_other, n_images=3, interpolate="idpp"): 86 | _r_pos = torch.tensor_split( 87 | r_pos, 88 | torch.cumsum(x0_size, dim=0).to("cpu")[:-1] 89 | ) 90 | _p_pos = torch.tensor_split( 91 | p_pos, 92 | torch.cumsum(x0_size, dim=0).to("cpu")[:-1] 93 | ) 94 | z = torch.tensor_split( 95 | x0_other[:, -1], 96 | torch.cumsum(x0_size, dim=0).to("cpu")[:-1] 97 | ) 98 | z = [_z.long().cpu().numpy() for _z in z] 99 | 100 | ts_pos = [] 101 | for x_r, x_p, atom_number in zip(_r_pos, _p_pos, z): 102 | mol_r = Atoms( 103 | numbers=atom_number, 104 | positions=x_r.cpu().numpy(), 105 | ) 106 | mol_p = Atoms( 107 | numbers=atom_number, 108 | positions=x_p.cpu().numpy(), 109 | ) 110 | 111 | images = 
[mol_r.copy()] 112 | for _ in range(n_images - 2): 113 | images.append(mol_r.copy()) 114 | images.append(mol_p.copy()) 115 | 116 | for image in images: 117 | image.calc = EMT() 118 | 119 | neb = NEB(images) 120 | if interpolate == "idpp": 121 | neb.idpp_interpolate( 122 | traj=None, log=None, fmax=1000, optimizer=ase.optimize.MDMin, mic=False, steps=0) 123 | elif interpolate == "linear": 124 | neb.interpolate('linear') 125 | else: 126 | raise ValueError("interpolate can only be idpp or linear") 127 | x_ts = torch.tensor( 128 | neb.images[n_images // 2].arrays["positions"], 129 | dtype=torch.float32, 130 | ) 131 | ts_pos.append(x_ts) 132 | 133 | ts_pos = torch.concat(ts_pos).to(x0_size.device) 134 | return ts_pos -------------------------------------------------------------------------------- /reactot/dynamics/__init__.py: -------------------------------------------------------------------------------- 1 | from .egnn_dynamics import EGNNDynamics 2 | from .confidence import Confidence 3 | from .potential import Potential 4 | -------------------------------------------------------------------------------- /reactot/dynamics/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/dynamics/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/dynamics/__pycache__/_base.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/dynamics/__pycache__/_base.cpython-310.pyc -------------------------------------------------------------------------------- /reactot/dynamics/__pycache__/confidence.cpython-310.pyc: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/dynamics/__pycache__/confidence.cpython-310.pyc
# --------------------------------------------------------------------------------
# /reactot/dynamics/__pycache__/egnn_dynamics.cpython-310.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/dynamics/__pycache__/egnn_dynamics.cpython-310.pyc
# --------------------------------------------------------------------------------
# /reactot/dynamics/__pycache__/potential.cpython-310.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/dynamics/__pycache__/potential.cpython-310.pyc
# --------------------------------------------------------------------------------
# /reactot/dynamics/_base.py:
# --------------------------------------------------------------------------------
"""Base class for assembling fragments and performing model updates."""
from typing import Dict, List, Optional
import torch
from torch import nn

from reactot.model import MLP, EGNN


class BaseDynamics(nn.Module):
    def __init__(
        self,
        model_config: Dict,
        fragment_names: List[str],
        node_nfs: List[int],
        edge_nf: int,
        condition_nf: int = 0,
        pos_dim: int = 3,
        update_pocket_coords: bool = True,
        condition_time: bool = True,
        edge_cutoff: Optional[float] = None,
        model: nn.Module = EGNN,
        device: torch.device = torch.device("cuda"),
        enforce_same_encoding: Optional[List] = None,
        source: Optional[Dict] = None,
    ) -> None:
        r"""Set up the shared machinery for fragment-wise dynamics models.

        Every fragment gets its own MLP encoder/decoder pair that maps the
        fragment's non-positional node features to and from the node
        embedding consumed by the equivariant model.

        Args:
            model_config (Dict): keyword arguments for the equivariant model.
                Modified in place: ``act_fn`` is set to ``"swish"`` and
                ``in_node_nf`` to ``in_hidden_channels`` when absent.
            fragment_names (List[str]): one name per fragment.
            node_nfs (List[int]): per-fragment number of input node
                attributes, position components included; each must be
                larger than ``pos_dim``.
            edge_nf (int): number of input edge attributes.
            condition_nf (int): number of conditioning attributes that reduce
                the node embedding size. Defaults to 0.
            pos_dim (int): dimension of the position vector. Defaults to 3.
            update_pocket_coords (bool): whether positions of every node are
                updated. Defaults to True.
            condition_time (bool): whether one embedding channel is reserved
                for time. Defaults to True.
            edge_cutoff (Optional[float]): cutoff for building intra-fragment
                edges. Defaults to None.
            model (nn.Module): class of the equivariant model, instantiated
                as ``model(**model_config)``. Defaults to EGNN; None also
                selects EGNN.
            device (torch.device): stored as ``self.device``. Defaults to CUDA.
            enforce_same_encoding (Optional[List]): fragment indices whose
                encoder/decoder are replaced by fragment 0's. Defaults to None.
            source (Optional[Dict]): optional state dicts; ``"model"`` is
                loaded into the equivariant model and ``"encoders"`` /
                ``"decoders"`` into the fragment MLPs. Defaults to None.
        """
        super().__init__()
        assert len(node_nfs) == len(fragment_names)
        assert all(nf > pos_dim for nf in node_nfs)

        # Fill in model defaults (this mutates the caller's dict).
        if "act_fn" not in model_config:
            model_config["act_fn"] = "swish"
        if "in_node_nf" not in model_config:
            model_config["in_node_nf"] = model_config["in_hidden_channels"]

        self.model_config = model_config
        self.node_nfs = node_nfs
        self.edge_nf = edge_nf
        self.condition_nf = condition_nf
        self.fragment_names = fragment_names
        self.pos_dim = pos_dim
        self.update_pocket_coords = update_pocket_coords
        self.condition_time = condition_time
        self.edge_cutoff = edge_cutoff
        self.device = device

        model_cls = EGNN if model is None else model
        self.model = model_cls(**model_config)
        if source is not None and "model" in source:
            self.model.load_state_dict(source["model"])
        self.dist_dim = getattr(self.model, "dist_dim", 0)

        # The model's node input is the encoder output plus the optional
        # time channel and condition channels, so shrink the encoder width.
        embed_dim = model_config["in_node_nf"]
        if condition_time:
            embed_dim -= 1
        if condition_nf > 0:
            embed_dim -= condition_nf
        self.embed_dim = embed_dim

        if "in_edge_nf" in model_config:
            self.edge_embed_dim = model_config["in_edge_nf"]
        else:
            self.edge_embed_dim = 0
        assert self.embed_dim > 0

        self.build_encoders_decoders(enforce_same_encoding, source)

    def build_encoders_decoders(
        self,
        enfoce_name_encoding: Optional[List] = None,
        source: Optional[Dict] = None,
    ):
        r"""Build per-fragment node encoders/decoders and optional edge MLPs.

        Args:
            enfoce_name_encoding (Optional[List]): fragment indices that reuse
                the encoder/decoder of fragment 0.
            source (Optional[Dict]): when it has an ``"encoders"`` key, both
                ``source["encoders"]`` and ``source["decoders"]`` are loaded.
        """
        act_fn = self.model_config["act_fn"]

        def make_mlp(in_dim: int, out_dims: List[int]) -> MLP:
            # Every MLP here shares the activation and ends without one.
            return MLP(
                in_dim=in_dim,
                out_dims=out_dims,
                activation=act_fn,
                last_layer_no_activation=True,
            )

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for idx in range(len(self.fragment_names)):
            feat_dim = self.node_nfs[idx] - self.pos_dim
            self.encoders.append(make_mlp(feat_dim, [2 * feat_dim, self.embed_dim]))
            self.decoders.append(make_mlp(self.embed_dim, [2 * feat_dim, feat_dim]))

        if enfoce_name_encoding is not None:
            for idx in enfoce_name_encoding:
                self.encoders[idx] = self.encoders[0]
                self.decoders[idx] = self.decoders[0]
        if source is not None and "encoders" in source:
            self.encoders.load_state_dict(source["encoders"])
            self.decoders.load_state_dict(source["decoders"])

        if self.edge_embed_dim > 0:
            self.edge_encoder = make_mlp(
                self.edge_nf, [2 * self.edge_nf, self.edge_embed_dim]
            )
            self.edge_decoder = make_mlp(
                self.edge_embed_dim + self.dist_dim, [2 * self.edge_nf, self.edge_nf]
            )
        else:
            self.edge_encoder, self.edge_decoder = None, None

    def forward(self):
        """Subclasses implement the forward pass."""
        raise NotImplementedError

# --------------------------------------------------------------------------------
# /reactot/dynamics/confidence.py:
# --------------------------------------------------------------------------------
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor
from torch_scatter import scatter_mean

from reactot.model import EGNN
from reactot.model.core import GatedMLP
from reactot.utils import (
    get_subgraph_mask,
    get_n_frag_switch,
    get_mask_for_frag,
    get_edges_index,
)
from ._base import BaseDynamics


FEATURE_MAPPING = ["pos", "one_hot", "charge"]


class Confidence(BaseDynamics):
    def __init__(
        self,
        model_config: Dict,
        fragment_names: List[str],
        node_nfs: List[int],
        edge_nf: int,
        condition_nf: int = 0,
        pos_dim: int = 3,
        edge_cutoff: Optional[float] = None,
        model: nn.Module = EGNN,
        device: torch.device = torch.device("cuda"),
        enforce_same_encoding: Optional[List] = None,
        source: Optional[Dict] = None,
        **kwargs,
    ) -> None:
        r"""Confidence score for generated samples.

        Sets ``model_config["for_conf"] = True`` (in place) and always builds
        the backbone with ``update_pocket_coords=True`` and
        ``condition_time=True``. A ``GatedMLP`` readout maps mean-pooled
        graph features to one score per graph.

        Args:
            model_config (Dict): config for the equivariant model; must
                contain ``"hidden_channels"``.
            fragment_names (List[str]): list of names for fragments.
            node_nfs (List[int]): number of input node attributes per fragment.
            edge_nf (int): number of input edge attributes.
            condition_nf (int): number of attributes for conditional
                generation. Defaults to 0.
            pos_dim (int): dimension of the position vector. Defaults to 3.
            edge_cutoff (Optional[float]): cutoff for building intra-fragment
                edges. Defaults to None.
            model (nn.Module): equivariant model class. Defaults to EGNN.
            device (torch.device): device stored on the module. Defaults to CUDA.
            enforce_same_encoding (Optional[List]): fragment indices sharing
                fragment 0's encoder/decoder. Defaults to None.
            source (Optional[Dict]): optional state dicts to load. Defaults to None.
            **kwargs: accepted and ignored.
        """
        model_config.update({"for_conf": True})
        update_pocket_coords = True
        condition_time = True
        super().__init__(
            model_config,
            fragment_names,
            node_nfs,
            edge_nf,
            condition_nf,
            pos_dim,
            update_pocket_coords,
            condition_time,
            edge_cutoff,
            model,
            device,
            enforce_same_encoding,
            source=source,
        )

        hidden_channels = model_config["hidden_channels"]
        self.readout = GatedMLP(
            in_dim=hidden_channels,
            out_dims=[hidden_channels, hidden_channels, 1],
            activation="swish",
            bias=True,
            last_layer_no_activation=True,
        )

    def _forward(
        self,
        xh: List[Tensor],
        edge_index: Tensor,
        t: Tensor,
        conditions: Tensor,
        n_frag_switch: Tensor,
        combined_mask: Tensor,
        edge_attr: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Predict one confidence score per graph.

        Args:
            xh (List[Tensor]): per-fragment tensors whose first ``pos_dim``
                columns are positions and whose remaining columns are node
                features.
            edge_index (Tensor): edge index passed to the model and to
                ``get_subgraph_mask``.
            t (Tensor): time tensor. If 1-D it must hold a single value shared
                by all nodes; otherwise it is indexed per node with
                ``combined_mask``.
            conditions (Tensor): condition tensor, indexed per node with
                ``combined_mask`` when ``condition_nf > 0``.
            n_frag_switch (Tensor): [n_nodes], fragment index for each node.
            combined_mask (Tensor): [n_nodes], sample index for each node.
            edge_attr (Optional[Tensor]): edge attributes, encoded when an
                edge encoder exists. Defaults to None.

        Raises:
            NotImplementedError: if ``update_pocket_coords`` is False
                (the fragment-position-fixed mode is not implemented).

        Returns:
            Tensor: squeezed readout output, one unnormalized score per graph
            (the readout's last layer has no activation).
        """
        pos = torch.concat(
            [_xh[:, : self.pos_dim].clone() for _xh in xh],
            dim=0,
        )
        h = torch.concat(
            [
                self.encoders[ii](xh[ii][:, self.pos_dim :].clone())
                for ii in range(len(self.fragment_names))
            ],
            dim=0,
        )
        if self.edge_encoder is not None:
            edge_attr = self.edge_encoder(edge_attr)

        if self.condition_time:
            if len(t.size()) == 1:
                # t is the same for all elements in the batch.
                h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
            else:
                # t is different over the batch dimension.
                h_time = t[combined_mask]
            h = torch.cat([h, h_time], dim=1)

        if self.condition_nf > 0:
            h_condition = conditions[combined_mask]
            h = torch.cat([h, h_condition], dim=1)

        subgraph_mask = get_subgraph_mask(edge_index, n_frag_switch)
        if not self.update_pocket_coords:
            raise NotImplementedError  # no need to mask pos for inpainting mode.
        update_coords_mask = None

        node_features = self.model(
            h,
            pos,
            edge_index,
            edge_attr,
            node_mask=None,
            edge_mask=None,
            update_coords_mask=update_coords_mask,
            subgraph_mask=subgraph_mask[:, None],
        )  # (n_node, n_hidden)

        # Mean-pool node features into one vector per sample.
        graph_features = scatter_mean(
            node_features,
            index=combined_mask,
            dim=0,
        )  # (n_system, n_hidden)
        conf = self.readout(graph_features)
        return conf.squeeze()

    def forward(
        self,
        representations: List[Dict],
        conditions: Tensor,
    ):
        r"""Score a batch described by per-fragment representation dicts.

        Args:
            representations (List[Dict]): one dict per fragment holding
                ``"mask"`` (sample index per node), ``"size"`` (passed to
                ``get_n_frag_switch``) and the ``FEATURE_MAPPING`` keys
                ``"pos"``, ``"one_hot"``, ``"charge"``, which are concatenated
                along dim 1.
            conditions (Tensor): condition tensor forwarded to ``_forward``.

        Returns:
            Tensor: output of ``_forward`` evaluated at ``t = 0``.
        """
        combined_mask = torch.cat([repre["mask"] for repre in representations])
        edge_index = get_edges_index(combined_mask, remove_self_edge=True)
        fragments_nodes = [repre["size"] for repre in representations]
        n_frag_switch = get_n_frag_switch(fragments_nodes)

        xh = [
            torch.cat(
                [repre[feature_type] for feature_type in FEATURE_MAPPING],
                dim=1,
            )
            for repre in representations
        ]

        return self._forward(
            xh=xh,
            edge_index=edge_index,
            t=torch.tensor([0]),
            conditions=conditions,
            n_frag_switch=n_frag_switch,
            combined_mask=combined_mask,
            edge_attr=None,
        )

# --------------------------------------------------------------------------------
# /reactot/dynamics/potential.py:
# --------------------------------------------------------------------------------

from typing import Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor
from torch.autograd import grad
from torch_scatter import scatter_sum, scatter_mean
from torch_geometric.data import Data

from reactot.model import EGNN
from reactot.model.core import GatedMLP
from reactot.utils import (
    get_subgraph_mask,
    get_n_frag_switch,
    get_mask_for_frag,
    get_edges_index,
)
from ._base import BaseDynamics


FEATURE_MAPPING = ["pos", "one_hot", "charges"]


def remove_mean_batch(x, indices):
    """Center ``x`` per graph by subtracting each graph's mean.

    Args:
        x (Tensor): node-level values whose first dimension indexes nodes.
        indices (Tensor): graph index of each node (one entry per row of ``x``).

    Returns:
        Tensor: ``x`` minus the mean of the graph each row belongs to.
    """
    mean = scatter_mean(x, indices, dim=0)
    return x - mean[indices]


class Potential(BaseDynamics):
    def __init__(
        self,
        model_config: Dict,
        fragment_names: List[str],
        node_nfs: List[int],
        edge_nf: int,
        condition_nf: int = 0,
        pos_dim: int = 3,
        edge_cutoff: Optional[float] = None,
        model: nn.Module = EGNN,
        device: torch.device = torch.device("cuda"),
        enforce_same_encoding: Optional[List] = None,
        source: Optional[Dict] = None,
        timesteps: int = 5000,
        condition_time: bool = True,
        **kwargs,
    ) -> None:
        r"""Energy/force model built on the fragment dynamics backbone.

        Sets ``model_config["for_conf"] = False`` and ``model_config["ff"] = True``
        (in place) and always updates all coordinates. A ``GatedMLP`` readout
        maps each node's features to a scalar; these are summed per graph
        into an energy.

        Args:
            model_config (Dict): config for the equivariant model; must
                contain ``"hidden_channels"``.
            fragment_names (List[str]): list of names for fragments.
            node_nfs (List[int]): number of input node attributes per fragment.
            edge_nf (int): number of input edge attributes.
            condition_nf (int): number of attributes for conditional
                generation. Defaults to 0.
            pos_dim (int): dimension of the position vector. Defaults to 3.
            edge_cutoff (Optional[float]): cutoff for building intra-fragment
                edges. Defaults to None.
            model (nn.Module): equivariant model class. Defaults to EGNN.
            device (torch.device): device stored on the module. Defaults to CUDA.
            enforce_same_encoding (Optional[List]): fragment indices sharing
                fragment 0's encoder/decoder. Defaults to None.
            source (Optional[Dict]): optional state dicts to load. Defaults to None.
            timesteps (int): stored as ``self.timesteps``. Defaults to 5000.
            condition_time (bool): whether a time channel is appended to the
                node features. Defaults to True.
            **kwargs: accepted and ignored.
        """
        model_config.update({"for_conf": False, "ff": True})
        update_pocket_coords = True
        super().__init__(
            model_config,
            fragment_names,
            node_nfs,
            edge_nf,
            condition_nf,
            pos_dim,
            update_pocket_coords,
            condition_time,
            edge_cutoff,
            model,
            device,
            enforce_same_encoding,
            source=source,
        )

        hidden_channels = model_config["hidden_channels"]
        self.readout = GatedMLP(
            in_dim=hidden_channels,
            out_dims=[hidden_channels, hidden_channels, 1],
            activation="swish",
            bias=True,
            last_layer_no_activation=True,
        )
        self.timesteps = timesteps

    def _forward(
        self,
        xh: List[Tensor],
        edge_index: Tensor,
        t: Tensor,
        conditions: Tensor,
        n_frag_switch: Tensor,
        combined_mask: Tensor,
        edge_attr: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Predict energies from fragments whose positions are packed in ``xh``.

        Splits every fragment tensor into positions (first ``pos_dim``
        columns) and node features, then runs ``_forward_autograd``.

        Args:
            xh (List[Tensor]): per-fragment tensors of positions followed by
                node features.
            edge_index (Tensor): edge index passed to the model.
            t (Tensor): time tensor; see ``_forward_autograd``.
            conditions (Tensor): condition tensor; see ``_forward_autograd``.
            n_frag_switch (Tensor): [n_nodes], fragment index for each node.
            combined_mask (Tensor): [n_nodes], sample index for each node.
            edge_attr (Optional[Tensor]): edge attributes. Defaults to None.

        Returns:
            Tuple[Tensor, Tensor]: ``(ae, forces)`` as returned by
            ``_forward_autograd``.
        """
        pos = torch.concat(
            [_xh[:, : self.pos_dim] for _xh in xh],
            dim=0,
        )
        node_feats = [_xh[:, self.pos_dim :] for _xh in xh]
        return self._forward_autograd(
            h=node_feats,
            pos=pos,
            edge_index=edge_index,
            t=t,
            conditions=conditions,
            n_frag_switch=n_frag_switch,
            combined_mask=combined_mask,
            edge_attr=edge_attr,
        )

    def _prepare_graph(
        self,
        pyg_batch: Data,
        conditions: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""Build the graph inputs shared by ``forward`` and ``forward_autograd``.

        Side effect: replaces ``pyg_batch.pos`` with its per-graph centered
        version.

        Args:
            pyg_batch (Data): batch providing ``batch``, ``natoms``, ``ae``
                and ``pos``.
            conditions (Optional[Tensor]): condition tensor; when None a
                ``(n_graphs, 1)`` long zero tensor is used, with ``n_graphs``
                taken from ``pyg_batch.ae.size(0)``.

        Returns:
            Tuple of ``(combined_mask, edge_index, n_frag_switch, conditions)``,
            with ``conditions`` on the device of ``pyg_batch.batch``.
        """
        combined_mask = torch.cat([pyg_batch.batch])
        edge_index = get_edges_index(combined_mask, remove_self_edge=True)
        n_frag_switch = get_n_frag_switch([pyg_batch.natoms])

        device = pyg_batch.batch.device
        # Compare with None explicitly: truth-testing a multi-element tensor
        # raises, and a zero-valued scalar tensor must not be discarded.
        if conditions is None:
            conditions = torch.zeros(
                pyg_batch.ae.size(0), 1, dtype=torch.long, device=device
            )
        else:
            conditions = conditions.to(device)

        pyg_batch.pos = remove_mean_batch(pyg_batch.pos, pyg_batch.batch)
        return combined_mask, edge_index, n_frag_switch, conditions

    def forward(
        self,
        pyg_batch: Data,
        conditions: Optional[Tensor] = None,
    ):
        r"""Predict per-graph energies and the model's direct force output.

        Args:
            pyg_batch (Data): batch providing ``batch``, ``natoms``, ``ae``,
                ``pos``, ``one_hot`` and ``charges``; ``pos`` is centered
                per graph in place on the batch object.
            conditions (Optional[Tensor]): condition tensor. Defaults to zeros.

        Returns:
            Tuple[Tensor, Tensor]: ``(ae, forces)`` where ``forces`` is the
            second output of ``self.model`` (presumably direct force
            predictions — TODO confirm against the model).
        """
        combined_mask, edge_index, n_frag_switch, conditions = self._prepare_graph(
            pyg_batch, conditions
        )

        xh = [
            torch.cat(
                [pyg_batch.pos, pyg_batch.one_hot, pyg_batch.charges.view(-1, 1)],
                dim=1,
            )
        ]

        # The sampled time is not used (t = 0 is passed below); the draw is
        # kept so the global RNG stream stays as before.
        torch.randint(0, self.timesteps, size=(1,))

        return self._forward(
            xh=xh,
            edge_index=edge_index,
            t=torch.tensor([0.]),
            conditions=conditions,
            n_frag_switch=n_frag_switch,
            combined_mask=combined_mask,
            edge_attr=None,
        )

    def _forward_autograd(
        self,
        h: List[Tensor],
        pos: Tensor,
        edge_index: Tensor,
        t: Tensor,
        conditions: Tensor,
        n_frag_switch: Tensor,
        combined_mask: Tensor,
        edge_attr: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Predict per-graph energies from separate features and positions.

        Args:
            h (List[Tensor]): per-fragment node features (positions excluded).
            pos (Tensor): node positions of all fragments, concatenated.
            edge_index (Tensor): edge index passed to the model.
            t (Tensor): time tensor. If 1-D it must hold a single value shared
                by all nodes; otherwise it is indexed per node with
                ``combined_mask``.
            conditions (Tensor): condition tensor, indexed per node with
                ``combined_mask`` when ``condition_nf > 0``.
            n_frag_switch (Tensor): [n_nodes], fragment index for each node.
            combined_mask (Tensor): [n_nodes], sample index for each node.
            edge_attr (Optional[Tensor]): edge attributes, encoded when an
                edge encoder exists. Defaults to None.

        Raises:
            NotImplementedError: if ``update_pocket_coords`` is False
                (the fragment-position-fixed mode is not implemented).

        Returns:
            Tuple[Tensor, Tensor]: ``(ae, forces)``; ``ae`` is the squeezed
            per-graph sum of the per-node readout scalars, ``forces`` is the
            second output of ``self.model``.
        """
        h = torch.concat(
            [self.encoders[ii](h[ii]) for ii in range(len(self.fragment_names))],
            dim=0,
        )
        if self.edge_encoder is not None:
            edge_attr = self.edge_encoder(edge_attr)

        if self.condition_time:
            if len(t.size()) == 1:
                # t is the same for all elements in the batch.
                h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
            else:
                # t is different over the batch dimension.
                h_time = t[combined_mask]
            h = torch.cat([h, h_time], dim=1)

        if self.condition_nf > 0:
            h_condition = conditions[combined_mask]
            h = torch.cat([h, h_condition], dim=1)

        subgraph_mask = get_subgraph_mask(edge_index, n_frag_switch)
        if not self.update_pocket_coords:
            raise NotImplementedError  # no need to mask pos for inpainting mode.
        update_coords_mask = None

        node_features, forces = self.model(
            h,
            pos,
            edge_index,
            edge_attr,
            node_mask=None,
            edge_mask=None,
            update_coords_mask=update_coords_mask,
            subgraph_mask=subgraph_mask[:, None],
        )  # node_features: (n_node, n_hidden)

        atom_energies = self.readout(node_features)  # (n_node, 1)
        ae = scatter_sum(
            atom_energies,
            index=combined_mask,
            dim=0,
        )  # (n_system, 1)
        return ae.squeeze(), forces

    @torch.enable_grad()
    def forward_autograd(
        self,
        pyg_batch: Data,
        conditions: Optional[Tensor] = None,
    ):
        r"""Predict per-graph energies and forces as ``-d(sum E)/d(pos)``.

        Args:
            pyg_batch (Data): batch providing ``batch``, ``natoms``, ``ae``,
                ``pos``, ``one_hot`` and ``charges``; ``pos`` is centered per
                graph in place on the batch object and marked
                ``requires_grad``.
            conditions (Optional[Tensor]): condition tensor. Defaults to zeros.

        Returns:
            Tuple[Tensor, Tensor]: ``(ae, forces)``. The force graph is kept
            (``create_graph``) only in training mode.
        """
        combined_mask, edge_index, n_frag_switch, conditions = self._prepare_graph(
            pyg_batch, conditions
        )
        pyg_batch.pos.requires_grad_(True)

        h = [
            torch.cat(
                [pyg_batch.one_hot, pyg_batch.charges.view(-1, 1)],
                dim=1,
            ).float()
        ]

        # The sampled time is not used (t = 0 is passed below); the draw is
        # kept so the global RNG stream stays as before.
        torch.randint(0, self.timesteps, size=(1,))

        ae, _ = self._forward_autograd(
            h=h,
            pos=pyg_batch.pos,
            edge_index=edge_index,
            t=torch.tensor([0.]),
            conditions=conditions,
            n_frag_switch=n_frag_switch,
            combined_mask=combined_mask,
            edge_attr=None,
        )
        forces = -grad(
            torch.sum(ae),
            pyg_batch.pos,
            create_graph=self.training,
        )[0]
        return ae, forces

# --------------------------------------------------------------------------------
# /reactot/evaluate/evaluate_rmsd_vs_ediff.py:
# --------------------------------------------------------------------------------
from typing import List
import time
import os 4 | import numpy as np 5 | import time 6 | import os 7 | import numpy as np 8 | import torch 9 | import pickle 10 | import argparse 11 | from uuid import uuid4 12 | 13 | from torch.utils.data import DataLoader 14 | 15 | from reactot.trainer.pl_trainer import DDPMModule 16 | from reactot.dataset.transition1x import ProcessedTS1x 17 | from reactot.analyze.rmsd import batch_rmsd 18 | from reactot.analyze.geomopt import calc_deltaE, compute_efh 19 | from reactot.evaluate.utils import ( 20 | set_new_schedule, 21 | inplaint_batch, 22 | batch_ts_deltaE, 23 | ) 24 | from reactot.utils.sampling_tools import write_tmp_xyz 25 | 26 | EV2KCALMOL = 23.06 27 | AU2KCALMOL = 627.5 28 | 29 | 30 | def save_pickle(filename, rmsds, deltaEs): 31 | with open(f"results/{filename}", "wb") as fo: 32 | pickle.dump( 33 | { 34 | "rmsd": rmsds, 35 | "ts_deltaE": deltaEs, 36 | }, 37 | fo 38 | ) 39 | 40 | 41 | config = dict( 42 | timesteps=150, 43 | bz=16, 44 | resamplings=5, 45 | jump_length=5, 46 | repeats=1, 47 | max_batch=-1, 48 | shuffle=False, 49 | single_frag_only=False, 50 | ) 51 | 52 | 53 | filename = "" 54 | for k, v in config.items(): 55 | filename += f"rmsdvsEdiff_{k}-{v}_" 56 | filename += ".pkl" 57 | print(filename) 58 | 59 | print("loading ddpm trainer...") 60 | device = torch.device("cuda") 61 | tspath = "/anfhome/crduan/diff/TSDiffusion/reactot/trainer/checkpoint/TSDiffusion-TS1x" 62 | checkpoints = { 63 | "repro": f"{tspath}/Adam-AF-False-LRSNone-EENone-PT-TS1x-Edge-SinEmTrue-L9-Unit128-NF1.0-AGGmean-SubL1-rmHFalse-BZ64-Norm1.0_1.0_1.0-emaFalse-POTrue-Decay0-SFTrue/ddpm-epoch=4379-val-totloss=602.40.ckpt", 64 | "chiral": f"{tspath}/5edcbc9baced/ddpm-epoch=4159-val-totloss=585.78.ckpt", 65 | "leftnet_legacy": f"{tspath}/leftnet-78c7590798bc/ddpm-epoch=1059-val-totloss=648.90.ckpt", 66 | "leftnet_2074": f"{tspath}-All/leftnet-8-70b75beeaac1/ddpm-epoch=2074-val-totloss=531.18.ckpt", 67 | } 68 | ddpm_trainer = DDPMModule.load_from_checkpoint( 69 | 
checkpoint_path=checkpoints["leftnet_2074"], 70 | map_location=device, 71 | ) 72 | ddpm_trainer = set_new_schedule( 73 | ddpm_trainer, 74 | timesteps=config["timesteps"], 75 | noise_schedule="polynomial_2", 76 | ) 77 | 78 | print("loading dataset...") 79 | dataset = ProcessedTS1x( 80 | npz_path="../data/transition1x/valid.pkl", 81 | center=True, 82 | pad_fragments=0, 83 | device="cuda", 84 | zero_charge=False, 85 | remove_h=False, 86 | single_frag_only=config["single_frag_only"], 87 | swapping_react_prod=False, 88 | use_by_ind=True, 89 | ) 90 | loader = DataLoader( 91 | dataset, 92 | batch_size=config["bz"], 93 | shuffle=config["shuffle"], 94 | num_workers=0, 95 | collate_fn=dataset.collate_fn 96 | ) 97 | _id = uuid4() 98 | localpath = "tmp/" + str(_id) 99 | os.makedirs(localpath) 100 | 101 | print("evaluating...") 102 | rmsds, deltaEs = [], [] 103 | TSEs = [] 104 | for num_repeat in range(config["repeats"]): 105 | print("num_repeat: ", num_repeat) 106 | _rmsds, _deltaEs = [], [] 107 | for ii, batch in enumerate(loader): 108 | print("batch_idx: ", ii) 109 | time_start = time.time() 110 | if ii == config["max_batch"]: 111 | break 112 | out_samples, xh_fixed, fragments_nodes = inplaint_batch( 113 | batch, 114 | ddpm_trainer, 115 | resamplings=config["resamplings"], 116 | jump_length=config["jump_length"], 117 | ) 118 | write_tmp_xyz(fragments_nodes, out_samples, idx=[0, 1, 2], localpath=localpath) 119 | write_tmp_xyz(fragments_nodes, xh_fixed, idx=[1], prefix="sample", localpath=localpath) 120 | _rmsds += batch_rmsd( 121 | fragments_nodes, 122 | out_samples, 123 | xh_fixed, 124 | idx=1, 125 | threshold=0.5, 126 | ) 127 | print("time cost: ", time.time() - time_start) 128 | print("rmsds: ", [round(_x, 2) for _x in _rmsds], np.mean(_rmsds), np.std(_rmsds)) 129 | 130 | _deltaEs += batch_ts_deltaE(config["bz"], xc="wb97x", localpath=localpath) 131 | print("deltaEs: ", [round(_x, 2) for _x in _deltaEs], np.mean(np.abs(_deltaEs))) 132 | save_pickle(filename, _rmsds, 
_deltaEs) 133 | rmsds.append(_rmsds) 134 | deltaEs.append(_deltaEs) 135 | save_pickle(filename, rmsds, deltaEs) 136 | -------------------------------------------------------------------------------- /reactot/evaluate/evaluate_ts_w_rp.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import time 3 | import os 4 | import numpy as np 5 | import torch 6 | import pickle 7 | import argparse 8 | 9 | from torch.utils.data import DataLoader 10 | 11 | from reactot.trainer.pl_trainer import DDPMModule 12 | from reactot.dataset.transition1x import ProcessedTS1x 13 | from reactot.analyze.rmsd import batch_rmsd 14 | from reactot.evaluate.utils import ( 15 | set_new_schedule, 16 | inplaint_batch, 17 | ) 18 | 19 | EV2KCALMOL = 23.06 20 | AU2KCALMOL = 627.5 21 | 22 | 23 | parser = argparse.ArgumentParser(description="get training params") 24 | parser.add_argument( 25 | "--bz", dest="bz", default=64, type=int, help="batch size" 26 | ) 27 | parser.add_argument( 28 | "--timesteps", dest="timesteps", default=250, type=int, help="timesteps" 29 | ) 30 | parser.add_argument( 31 | "--resamplings", dest="resamplings", default=5, type=int, help="resamplings" 32 | ) 33 | parser.add_argument( 34 | "--jump_length", dest="jump_length", default=5, type=int, help="jump_length" 35 | ) 36 | parser.add_argument( 37 | "--repeats", dest="repeats", default=5, type=int, help="repeats" 38 | ) 39 | parser.add_argument( 40 | "--partition", dest="partition", default="valid", type=str, help="partition" 41 | ) 42 | parser.add_argument( 43 | "--single_frag_only", dest="single_frag_only", default=1, type=int, help="single_frag_only" 44 | ) 45 | parser.add_argument( 46 | "--model", dest="model", default="leftnet", type=str, help="model" 47 | ) 48 | parser.add_argument( 49 | "--power", dest="power", default="2", type=str, help="power" 50 | ) 51 | 52 | args = parser.parse_args() 53 | print("args: ", args) 54 | 55 | config = dict( 56 | 
model=args.model, 57 | partition=args.partition, 58 | timesteps=args.timesteps, 59 | bz=args.bz, 60 | resamplings=args.resamplings, 61 | jump_length=args.jump_length, 62 | repeats=args.repeats, 63 | max_batch=-1, 64 | shuffle=False, 65 | single_frag_only=args.single_frag_only, 66 | noise_schedule="polynomial_" + args.power, 67 | ) 68 | 69 | filename = "" 70 | for k, v in config.items(): 71 | filename += f"{k}-{v}_" 72 | filename += ".pkl" 73 | print(filename) 74 | 75 | print("loading ddpm trainer...") 76 | device = torch.device("cuda") 77 | tspath = "/anfhome/crduan/diff/TSDiffusion/reactot/trainer/checkpoint/TSDiffusion-TS1x" 78 | checkpoints = { 79 | "chiral": f"{tspath}/5edcbc9baced/ddpm-epoch=4159-val-totloss=585.78.ckpt", 80 | "leftnet_legacy": f"{tspath}/leftnet-78c7590798bc/ddpm-epoch=1059-val-totloss=648.90.ckpt", 81 | "leftnet4": f"{tspath}/leftnet-4-48f308df7ec4/ddpm-epoch=809-val-totloss=587.27.ckpt", 82 | "leftnet_all": f"{tspath}-All/leftnet-4-77ae3fd23222/ddpm-epoch=1619-val-totloss=605.85.ckpt", 83 | # "leftnet_final": f"{tspath}-All/leftnet-8-17cf1d7b9324/ddpm-epoch=1289-val-totloss=536.86.ckpt", 84 | "leftnet_final": f"{tspath}-All/leftnet-8-70b75beeaac1/ddpm-epoch=1274-val-totloss=519.81.ckpt", 85 | "leftnet_1654": f"{tspath}-All/leftnet-8-70b75beeaac1/ddpm-epoch=1654-val-totloss=540.84.ckpt", 86 | "leftnet_1884": f"{tspath}-All/leftnet-8-70b75beeaac1/ddpm-epoch=1884-val-totloss=549.61.ckpt", 87 | "leftnet_2074": f"{tspath}-All/leftnet-8-70b75beeaac1/ddpm-epoch=2074-val-totloss=531.18.ckpt", 88 | "leftnet_2304": f"{tspath}-All/leftnet-8-70b75beeaac1/ddpm-epoch=2304-val-totloss=524.65.ckpt", 89 | } 90 | ddpm_trainer = DDPMModule.load_from_checkpoint( 91 | checkpoint_path=checkpoints[config["model"]], 92 | map_location=device, 93 | ) 94 | ddpm_trainer = set_new_schedule( 95 | ddpm_trainer, 96 | timesteps=config["timesteps"], 97 | noise_schedule=config["noise_schedule"] 98 | ) 99 | 100 | print("loading dataset...") 101 | dataset = ProcessedTS1x( 102 
| npz_path=f"../data/transition1x/{args.partition}.pkl", 103 | center=True, 104 | pad_fragments=0, 105 | device="cuda", 106 | zero_charge=False, 107 | remove_h=False, 108 | single_frag_only=config["single_frag_only"], 109 | swapping_react_prod=False, 110 | use_by_ind=True, 111 | ) 112 | print("# of points:", len(dataset) ) 113 | loader = DataLoader( 114 | dataset, 115 | batch_size=config["bz"], 116 | shuffle=config["shuffle"], 117 | num_workers=0, 118 | collate_fn=dataset.collate_fn 119 | ) 120 | 121 | print("evaluating...") 122 | if os.path.isfile(f"results/{filename}"): 123 | d = pickle.load(open(f"results/{filename}", "rb")) 124 | rmsds = d["rmsd"] 125 | else: 126 | rmsds = [] 127 | TSEs = [] 128 | for num_repeat in range(config["repeats"]): 129 | print("num_repeat: ", num_repeat) 130 | _rmsds, _genEs = [], [] 131 | for ii, batch in enumerate(loader): 132 | print("batch_idx: ", ii) 133 | time_start = time.time() 134 | if ii == config["max_batch"]: 135 | break 136 | out_samples, xh_fixed, fragments_nodes = inplaint_batch( 137 | batch, 138 | ddpm_trainer, 139 | resamplings=config["resamplings"], 140 | jump_length=config["jump_length"], 141 | frag_fixed=[0, 2] 142 | ) 143 | # write_tmp_xyz(fragments_nodes, out_samples, idx=[0, 1, 2]) 144 | # write_tmp_xyz(fragments_nodes, xh_fixed, idx=[1], prefix="sample") 145 | _rmsds += batch_rmsd( 146 | fragments_nodes, 147 | out_samples, 148 | xh_fixed, 149 | idx=1, 150 | threshold=0.5, 151 | ) 152 | print("time cost: ", time.time() - time_start) 153 | print("rmsds: ", [round(_x, 2) for _x in _rmsds], np.mean(_rmsds), np.std(_rmsds)) 154 | rmsds.append(_rmsds) 155 | 156 | with open(f"results/{filename}", "wb") as fo: 157 | pickle.dump( 158 | { 159 | "rmsd": rmsds, 160 | }, 161 | fo 162 | ) 163 | -------------------------------------------------------------------------------- /reactot/evaluate/generate_confidence_sample.py: -------------------------------------------------------------------------------- 1 | from typing import 
List  # tail of `from typing import List`; that statement begins on the preceding line, outside this chunk
import time
import os
import numpy as np
import torch
import pickle
import argparse
from uuid import uuid4

from torch.utils.data import DataLoader

from reactot.trainer.pl_trainer import DDPMModule
from reactot.dataset.transition1x import ProcessedTS1x
from reactot.analyze.rmsd import batch_rmsd
from reactot.evaluate.utils import (
    set_new_schedule,
    inplaint_batch,
    samples_to_pos_charge,
)

EV2KCALMOL = 23.06
AU2KCALMOL = 627.5


def assemble_filename(config):
    """Build a unique ``.pkl`` filename that encodes every ``config`` entry.

    A short random uuid4 prefix keeps repeated runs with the same config from
    overwriting each other. The name is printed and returned.
    """
    _id = str(uuid4()).split("-")[0]
    filename = f"conf-uuid-{_id}-" + "".join(f"{k}-{v}_" for k, v in config.items()) + ".pkl"
    print(filename)
    return filename


parser = argparse.ArgumentParser(description="get training params")
parser.add_argument(
    "--bz", dest="bz", default=32, type=int, help="batch size"
)
parser.add_argument(
    "--timesteps", dest="timesteps", default=150, type=int, help="timesteps"
)
parser.add_argument(
    "--resamplings", dest="resamplings", default=2, type=int, help="resamplings"
)
parser.add_argument(
    "--jump_length", dest="jump_length", default=2, type=int, help="jump_length"
)
parser.add_argument(
    "--repeats", dest="repeats", default=1, type=int, help="repeats"
)
parser.add_argument(
    "--partition", dest="partition", default="valid", type=str, help="partition"
)
parser.add_argument(
    "--dataset", dest="dataset", default="transition1x", type=str, help="dataset"
)
parser.add_argument(
    "--single_frag_only", dest="single_frag_only", default=0, type=int, help="single_frag_only"
)
parser.add_argument(
    "--model", dest="model", default="leftnet_2074", type=str, help="model"
)
parser.add_argument(
    "--power", dest="power", default="2", type=str, help="power"
)
parser.add_argument(
    "--position_key", dest="position_key", default="positions", type=str, help="position_key"
)

args = parser.parse_args()
print("args: ", args)

config = dict(
    model=args.model,
    dataset=args.dataset,
    partition=args.partition,
    timesteps=args.timesteps,
    bz=args.bz,
    resamplings=args.resamplings,
    jump_length=args.jump_length,
    repeats=args.repeats,
    max_batch=-1,  # -1 never matches a batch index, i.e. no batch limit
    shuffle=True,
    single_frag_only=args.single_frag_only,
    noise_schedule="polynomial_" + args.power,
    position_key=args.position_key
)

print("loading ddpm trainer...")
device = torch.device("cuda")
# NOTE(review): machine-specific absolute checkpoint root.
tspath = "/home/ubuntu/efs/TSDiffusion/reactot/trainer/ckpt/TSDiffusion-TS1x-All"
checkpoints = {
    "leftnet_2074": f"{tspath}/leftnet-8-70b75beeaac1/ddpm-epoch=2074-val-totloss=531.18.ckpt",
    "egnn": f"{tspath}/egnn-1-7d0e388fa0fd/ddpm-epoch=759-val-totloss=616.42.ckpt",
    "leftnet_wo_oa": f"{tspath}/leftnet-10-da396de30744_wo_oa/ddpm-epoch=149-val-totloss=600.87.ckpt",
    "leftnet_wo_oa_aligned": f"{tspath}/leftnet-10-d13a2c2bace6_wo_oa_align/ddpm-epoch=779-val-totloss=747.10.ckpt",
    "leftnet_wo_oa_aligned_early": f"{tspath}/leftnet-10-d13a2c2bace6_wo_oa_align/ddpm-epoch=719-val-totloss=680.64.ckpt"
}
ddpm_trainer = DDPMModule.load_from_checkpoint(
    checkpoint_path=checkpoints[config["model"]],
    map_location=device,
)
ddpm_trainer = set_new_schedule(
    ddpm_trainer,
    timesteps=config["timesteps"],
    noise_schedule=config["noise_schedule"]
)

print("loading dataset...")
dataset = ProcessedTS1x(
    npz_path=f"../data/{args.dataset}/{args.partition}.pkl",
    center=True,
    pad_fragments=0,
    device="cuda",
    zero_charge=False,
    remove_h=False,
    single_frag_only=config["single_frag_only"],
    swapping_react_prod=False,
    use_by_ind=True,
    position_key=config["position_key"],
)
print("# of points:", len(dataset))
loader = DataLoader(
    dataset,
    batch_size=config["bz"],
    shuffle=config["shuffle"],
    num_workers=0,
    collate_fn=dataset.collate_fn
)

print("evaluating...")
species = ["reactant", "transition_state", "product"]
keys = ['num_atoms', 'charges', 'position']

# Create the output directory up front so the final pickle.dump cannot fail
# after the (long) sampling run has already finished.
os.makedirs("samples", exist_ok=True)

for num_repeat in range(config["repeats"]):
    print("num_repeat: ", num_repeat)
    filename = assemble_filename(config)

    data = {}
    for s in species:
        data[s] = {}
        for k in keys:
            data[s][k] = []
    for s in ["target", "rmsd"]:
        data[s] = []

    for ii, batch in enumerate(loader):
        print("batch_idx: ", ii)
        time_start = time.time()
        if ii == config["max_batch"]:
            break

        # TS gen: reactant (0) and product (2) fixed, transition state inpainted.
        out_samples, xh_fixed, fragments_nodes = inplaint_batch(
            batch,
            ddpm_trainer,
            resamplings=config["resamplings"],
            jump_length=config["jump_length"],
            frag_fixed=[0, 2]
        )
        pos, z, natoms = samples_to_pos_charge(out_samples, fragments_nodes)
        _rmsds = batch_rmsd(
            fragments_nodes,
            out_samples,
            xh_fixed,
            idx=1,
            threshold=0.5,
        )
        for s in species:
            data[s]["position"] += pos[s]
            data[s]["charges"] += z
            data[s]["num_atoms"] += natoms
        data["rmsd"] += _rmsds
        # Binary label: 1 when the generated TS is within 0.2 RMSD of the reference.
        data["target"] += [1 if _r < 0.2 else 0 for _r in _rmsds]
        print("time cost: ", time.time() - time_start)
        print("rmsds: ", [round(_x, 2) for _x in data["rmsd"]], np.mean(data["rmsd"]), np.median(data["rmsd"]))

        # # R gen
        # out_samples, xh_fixed, fragments_nodes = inplaint_batch(
        #     batch,
        #     ddpm_trainer,
        #     resamplings=config["resamplings"] * 2,
        #     jump_length=config["jump_length"] * 2,
        #     frag_fixed=[1, 2]
        # )
        # pos, z, natoms = samples_to_pos_charge(out_samples, fragments_nodes)
        # _rmsds = batch_rmsd(
        #     fragments_nodes,
        #     out_samples,
        #     xh_fixed,
        #     idx=0,
        #     threshold=0.5,
        # )
        # for s in species:
        #     data[s]["position"] += pos[s]
        #     data[s]["charges"] += z
        #     data[s]["num_atoms"] += natoms
        # data["rmsd"] += _rmsds
        # data["target"] += [1 if _r < 0.2 else 0 for _r in _rmsds]
        # print("time cost: ", time.time() - time_start)
        # print("rmsds: ", [round(_x, 2) for _x in _rmsds], np.mean(_rmsds), np.std(_rmsds))

    with open(f"samples/{filename}", "wb") as fo:
        pickle.dump(data, fo)
# --------------------------------------------------------------------------------
# /reactot/evaluate/generate_on_example_path.py:
# --------------------------------------------------------------------------------
from typing import List
import time
import os
import numpy as np
import torch
import pickle
import argparse
from uuid import uuid4

from torch.utils.data import DataLoader

from reactot.trainer.pl_trainer import DDPMModule
from reactot.dataset.transition1x import ProcessedTS1x
from reactot.analyze.rmsd import batch_rmsd
from reactot.evaluate.utils import (
    set_new_schedule,
    inplaint_batch,
    samples_to_pos_charge,
)
from reactot.utils.sampling_tools import (
    assemble_sample_inputs,
    write_single_xyz,
    write_tmp_xyz,
)
from reactot.diffusion._normalizer import FEATURE_MAPPING

EV2KCALMOL = 23.06
AU2KCALMOL = 627.5


def assemble_filename(config):
    """Build a unique ``.pkl`` filename that encodes every ``config`` entry.

    A short random uuid4 prefix keeps repeated runs with the same config from
    overwriting each other. The name is printed and returned.
    """
    _id = str(uuid4()).split("-")[0]
    filename = f"conf-uuid-{_id}-" + "".join(f"{k}-{v}_" for k, v in config.items()) + ".pkl"
    print(filename)
    return filename


parser = argparse.ArgumentParser(description="get training params")
parser.add_argument(
    "--bz", dest="bz", default=64, type=int, help="batch size"
)
parser.add_argument(
    "--timesteps", dest="timesteps", default=150, type=int, help="timesteps"
)
parser.add_argument(
    "--resamplings", dest="resamplings", default=15, type=int, help="resamplings"
)
parser.add_argument(
    "--jump_length", dest="jump_length", default=15, type=int, help="jump_length"
)
parser.add_argument(
    "--repeats", dest="repeats", default=2, type=int, help="repeats"
)
parser.add_argument(
    "--partition", dest="partition", default="valid", type=str, help="partition"
)
parser.add_argument(
    "--single_frag_only", dest="single_frag_only", default=0, type=int, help="single_frag_only"
)
parser.add_argument(
    "--model", dest="model", default="leftnet_2074", type=str, help="model"
)
parser.add_argument(
    "--power", dest="power", default="2", type=str, help="power"
)

args = parser.parse_args()
print("args: ", args)

# The trailing integer of the run name selects which dataset index to sample.
name = "sample_all_12"
chosen_idx = int(name.split("_")[-1])

os.makedirs(name, exist_ok=True)

config = dict(
    model=args.model,
    partition=args.partition,
    timesteps=args.timesteps,
    bz=args.bz,
    resamplings=args.resamplings,
    jump_length=args.jump_length,
    repeats=args.repeats,
    max_batch=-1,
    shuffle=False,
    single_frag_only=args.single_frag_only,
    noise_schedule="polynomial_" + args.power,
)

print("loading ddpm trainer...")
device = torch.device("cuda")
# NOTE(review): machine-specific absolute checkpoint root.
tspath = "/anfhome/crduan/diff/TSDiffusion/reactot/trainer/checkpoint/TSDiffusion-TS1x"
checkpoints = {
    "leftnet_2074": f"{tspath}-All/leftnet-8-70b75beeaac1/ddpm-epoch=2074-val-totloss=531.18.ckpt",
}
ddpm_trainer = DDPMModule.load_from_checkpoint(
    checkpoint_path=checkpoints[config["model"]],
    map_location=device,
)
ddpm_trainer = set_new_schedule(
    ddpm_trainer,
    timesteps=config["timesteps"],
    noise_schedule=config["noise_schedule"]
)

# Honor the parsed --partition / --single_frag_only flags (previously hard-coded
# to "valid" / False, which are the defaults, so default behavior is unchanged).
dataset = ProcessedTS1x(
    npz_path=f"../data/transition1x/{config['partition']}.pkl",
    center=True,
    pad_fragments=0,
    device="cuda",
    zero_charge=False,
    remove_h=False,
    single_frag_only=config["single_frag_only"],
    swapping_react_prod=False,
    use_by_ind=True,
    # confidence_model=True,
)
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=dataset.collate_fn
)

species = ["reactant", "transition_state", "product"]
keys = ['num_atoms', 'charges', 'positions']
examples_dir = "../data/transition1x/examples"
out_path = f"../data/transition1x/examples/{name}.pkl"
if not os.path.isfile(out_path):
    data = {s: {k: [] for k in keys} for s in species}
    for s in ["target", "rmsd", "single_fragment"]:
        data[s] = []
else:
    # Append to a previous run's results. pickle can execute arbitrary code:
    # only load files produced by this script.
    with open(out_path, "rb") as fi:
        data = pickle.load(fi)

print("sampling...")
n_samples = 128
ex_ind = 0

for batch_idx, batch in enumerate(loader):
    if batch_idx != chosen_idx:
        continue

    representations, _ = batch
    duplicates = n_samples
    # Tile every fragment representation `duplicates` times so one example is
    # sampled `duplicates` times in a single inpainting call.
    representations_duplicated = []
    for rep in representations:
        tmp = {}
        for k, v in rep.items():
            if k != "mask":
                tmp[k] = torch.cat([v] * duplicates)
            else:
                # Each copy gets its own sample id, repeated once per atom.
                tmp[k] = torch.arange(duplicates).repeat_interleave(rep["size"].item())
        representations_duplicated.append(tmp)

    xh_fixed = [
        torch.cat(
            [repre[feature_type] for feature_type in FEATURE_MAPPING],
            dim=1,
        )
        for repre in representations_duplicated
    ]
    n_samples = representations_duplicated[0]["size"].size(0)
    fragments_nodes = [
        repre["size"] for repre in representations_duplicated
    ]
    conditions = torch.tensor([[0] for _ in range(duplicates)], device=device)
    write_tmp_xyz(fragments_nodes, xh_fixed, idx=[0, 1, 2], prefix="sample", localpath=name, ex_ind=ex_ind)

    out_samples, out_masks = ddpm_trainer.ddpm.inpaint(
        n_samples=n_samples,
        fragments_nodes=fragments_nodes,
        conditions=conditions,
        return_frames=1,
        resamplings=config["resamplings"],
        jump_length=config["jump_length"],
        timesteps=None,
        xh_fixed=xh_fixed,
        frag_fixed=[0, 2],
    )

    out_samples = out_samples[0]
    write_tmp_xyz(fragments_nodes, out_samples, idx=[0, 1, 2], localpath=name, ex_ind=ex_ind)
    pos, z, natoms = samples_to_pos_charge(out_samples, fragments_nodes)
    for s in species:
        data[s]["positions"] += pos[s]
        data[s]["charges"] += z
        data[s]["num_atoms"] += natoms
    # Placeholder labels: these samples are unscored.
    data["rmsd"] += [-1] * n_samples
    data["target"] += [-1] * n_samples
    data["single_fragment"] += [0] * n_samples
    # Only one index is ever processed; stop instead of collating the rest.
    break

os.makedirs(examples_dir, exist_ok=True)
with open(out_path, "wb") as fo:
    pickle.dump(data, fo)
# --------------------------------------------------------------------------------
# /reactot/evaluate/run_confidence_sample.sh:
# --------------------------------------------------------------------------------
# export CUDA_VISIBLE_DEVICES=$1
# export timesteps=150
# export resamplings=15
# export jump_length=15
# export dataset="transition1x"
# export partition="train_addprop"
# export single_frag_only=0
# export model="leftnet_2074"
# export power="2"
# export position_key="positions"
#
# save_path=nohupout/conf-$model-$dataset-$partition-timesteps-$timesteps-resamplings-$resamplings-single_frag_only-$single_frag_only-power-$power.out
# echo "save path: " $save_path
#
# nohup python -u generate_confidence_sample.py \
#     --timesteps $timesteps \
#     --resamplings $resamplings \
#     --jump_length $jump_length \
#     --partition $partition \
#     --dataset $dataset \
#     --single_frag_only $single_frag_only \
#     --model $model \
#     --power $power \
#     --position_key $position_key \
#     > $save_path 2>&1 &
# --------------------------------------------------------------------------------
# /reactot/evaluate/run_eva_ts_e_rp.sh:
# --------------------------------------------------------------------------------
# export CUDA_VISIBLE_DEVICES=1
# export timesteps=150
# export resamplings=10
# export jump_length=10
# export partition="valid"
# export single_frag_only=0
# export model="leftnet_2304"
# export power="2.5"
#
# save_path=nohupout/$model-$partition-timesteps-$timesteps-resamplings-$resamplings-single_frag_only-$single_frag_only-power-$power.out
#
# nohup python -u evaluate_ts_w_rp.py \
#     --timesteps $timesteps \
#     --resamplings $resamplings \
#     --jump_length $jump_length \
#     --partition $partition \
#     --single_frag_only $single_frag_only \
#     --model $model \
#     --power $power \
#     > $save_path 2>&1 &
# --------------------------------------------------------------------------------
# /reactot/evaluate/utils.py:
# --------------------------------------------------------------------------------
from typing import List, Optional
import torch

from reactot.trainer.pl_trainer import DDPMModule
from reactot.diffusion._schedule import DiffSchedule, PredefinedNoiseSchedule
from reactot.diffusion._normalizer import FEATURE_MAPPING
from reactot.analyze.geomopt import calc_deltaE, compute_efh

EV2KCALMOL = 23.06
AU2KCALMOL = 627.5
device = torch.device("cuda")


def set_new_schedule(
    ddpm_trainer: DDPMModule,
    timesteps: int = 250,
    device: torch.device = torch.device("cuda"),
    noise_schedule: str = "polynomial_2"
) -> DDPMModule:
    """Replace the trainer's diffusion schedule with a predefined one.

    Builds a ``PredefinedNoiseSchedule`` with ``timesteps`` steps, wraps it in a
    ``DiffSchedule`` that reuses the model's ``norm_values``, sets
    ``ddpm.T = timesteps`` and moves the trainer to ``device``. The trainer is
    mutated in place and also returned.
    """
    precision: float = 1e-5

    gamma_module = PredefinedNoiseSchedule(
        noise_schedule=noise_schedule,
        timesteps=timesteps,
        precision=precision,
    )
    schedule = DiffSchedule(
        gamma_module=gamma_module,
        norm_values=ddpm_trainer.ddpm.norm_values
    )
    ddpm_trainer.ddpm.schedule = schedule
    ddpm_trainer.ddpm.T = timesteps
    return ddpm_trainer.to(device)


def inplaint_batch(
    batch: List,
    ddpm_trainer: DDPMModule,
    resamplings: int = 1,
    jump_length: int = 1,
    frag_fixed: Optional[List] = None,
):
    """Inpaint one collated batch with ``ddpm.inpaint``.

    Args:
        batch: ``(representations, conditions)`` as produced by the dataset's
            collate_fn.
        ddpm_trainer: trainer whose ``ddpm`` performs the inpainting.
        resamplings, jump_length: forwarded to ``ddpm.inpaint``.
        frag_fixed: indices of fragments held fixed; defaults to ``[0, 2]``.

    Returns:
        ``(out_samples[0], xh_fixed, fragments_nodes)``: the single returned
        frame, the concatenated fixed features per fragment, and the
        per-fragment ``size`` tensors.
    """
    if frag_fixed is None:
        frag_fixed = [0, 2]
    representations, conditions = batch
    xh_fixed = [
        torch.cat(
            [repre[feature_type] for feature_type in FEATURE_MAPPING],
            dim=1,
        )
        for repre in representations
    ]
    n_samples = representations[0]["size"].size(0)
    fragments_nodes = [
        repre["size"] for repre in representations
    ]
    out_samples, _ = ddpm_trainer.ddpm.inpaint(
        n_samples=n_samples,
        fragments_nodes=fragments_nodes,
        conditions=conditions,
        return_frames=1,
        resamplings=resamplings,
        jump_length=jump_length,
        timesteps=None,
        xh_fixed=xh_fixed,
        frag_fixed=frag_fixed,
    )
    return out_samples[0], xh_fixed, fragments_nodes


def batch_ts_deltaE(bz, xc="wb97x", localpath="tmp"):
    """Return ``calc_deltaE`` between reference and generated TS xyz files, in kcal/mol.

    Reads ``{localpath}/sample_{ii}_ts.xyz`` and ``{localpath}/gen_{ii}_ts.xyz``
    for ``ii`` in ``range(bz)``; ``calc_deltaE`` output is scaled by EV2KCALMOL.
    """
    deltaEs = []
    for ii in range(bz):
        deltaEs.append(
            calc_deltaE(
                f"{localpath}/sample_{ii}_ts.xyz",
                f"{localpath}/gen_{ii}_ts.xyz",
                xc=xc,
            ) * EV2KCALMOL
        )
        # NOTE(review): indentation was lost in the source dump; this separator
        # print is reconstructed inside the loop.
        print("----")
    return deltaEs


def batch_E(bz, prefix="gen", localpath="tmp"):
    """Return total energies (``mf.e_tot`` scaled by AU2KCALMOL) of ``bz`` TS xyz files.

    Reads ``{localpath}/{prefix}_{ii}_ts.xyz``. ``localpath`` defaults to
    ``"tmp"``, the directory previously hard-coded here.
    """
    Es = []
    for ii in range(bz):
        mf, _, _ = compute_efh(
            f"{localpath}/{prefix}_{ii}_ts.xyz",
            f=False, hess=False, return_metrics=False)
        Es.append(mf.e_tot * AU2KCALMOL)
    return Es


def samples_to_pos_charge(out_samples, fragments_nodes):
    """Split stacked reactant/TS/product frames into per-molecule numpy arrays.

    Args:
        out_samples: sequence whose items 0, 1, 2 are the reactant, TS and
            product tensors with all molecules stacked along dim 0.
        fragments_nodes: per-fragment atom counts; ``fragments_nodes[0]`` holds
            the per-molecule counts used for splitting.

    Returns:
        ``(pos, z, natoms)``: ``pos`` maps each species name to a list of the
        first three columns per molecule; ``z`` is the last column of the
        reactant frame cast to int; ``natoms`` is the list of atom counts.
    """
    # The same split points apply to all three frames; compute them once.
    split_idx = torch.cumsum(fragments_nodes[0], dim=0).to("cpu")[:-1]
    x_r = torch.tensor_split(out_samples[0], split_idx)
    x_ts = torch.tensor_split(out_samples[1], split_idx)
    x_p = torch.tensor_split(out_samples[2], split_idx)
    pos = {
        "reactant": [_x[:, :3].cpu().numpy() for _x in x_r],
        "transition_state": [_x[:, :3].cpu().numpy() for _x in x_ts],
        "product": [_x[:, :3].cpu().numpy() for _x in x_p],
    }
    z = [_x[:, -1].long().cpu().numpy() for _x in x_r]
    natoms = [f.cpu().item() for f in fragments_nodes[0]]
    return pos, z, natoms
# --------------------------------------------------------------------------------
# /reactot/model/__init__.py:
# --------------------------------------------------------------------------------
"""EGNN is mostly adapted from https://github.com/ehoogeboom/e3_diffusion_for_molecules."""
from .egnn import EGNN
from .core import MLP
from .util_funcs import coord2diff, move_by_com
from .leftnet import LEFTNet
# --------------------------------------------------------------------------------
# /reactot/model/__pycache__/__init__.cpython-310.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/model/__pycache__/__init__.cpython-310.pyc
# --------------------------------------------------------------------------------
# /reactot/model/__pycache__/block.cpython-310.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/model/__pycache__/block.cpython-310.pyc
# --------------------------------------------------------------------------------
# /reactot/model/__pycache__/core.cpython-310.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/model/__pycache__/core.cpython-310.pyc
# --------------------------------------------------------------------------------
# /reactot/model/__pycache__/egnn.cpython-310.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/model/__pycache__/egnn.cpython-310.pyc
# --------------------------------------------------------------------------------
# /reactot/model/__pycache__/leftnet.cpython-310.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/model/__pycache__/leftnet.cpython-310.pyc
# --------------------------------------------------------------------------------
# /reactot/model/__pycache__/util_funcs.cpython-310.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/model/__pycache__/util_funcs.cpython-310.pyc
# --------------------------------------------------------------------------------
# /reactot/model/core.py:
# --------------------------------------------------------------------------------
"""
Core layers provide basic operations, e.g., MLP
"""
from typing import List, Union

import torch
from torch import nn, Tensor, tensor


ACTIVATION_MAPPING = {
    "swish": nn.SiLU(),
    "silu": nn.SiLU(),
    "relu": nn.ReLU(),
    "sigmoid": nn.Sigmoid(),
}


class ZeroLayer(nn.Module):
    r"""A skeleton layer that returns zeros.

    ``forward`` returns the Python int ``0`` (not a tensor), which broadcasts
    when added to a tensor.
    """

    def forward(self, inputs: List[Tensor], **kwargs) -> Tensor:
        return 0


class ConcatLayer(nn.Module):
    r"""Concatenate a list of tensors along ``dim`` (stored as a buffer)."""

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.register_buffer("dim", tensor(dim))

    def forward(self, inputs: List[Tensor], **kwargs) -> Tensor:
        return torch.concat(inputs, dim=self.dim)


class OneLayerActivation(nn.Module):
    r"""One linear layer followed by an activation (Identity when ``None``)."""

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        activation: Union[str, None] = None,
    ) -> None:
        # `activation` previously had the typing object `Union[str, None]` as its
        # *default value*, so omitting it raised KeyError; None -> Identity now.
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=bias)
        self.activation = (
            ACTIVATION_MAPPING[activation] if activation is not None else nn.Identity()
        )

    def forward(self, input: Tensor) -> Tensor:
        return self.activation(self.linear(input))


class MLP(nn.Module):
    r"""Multi-layer perceptron.

    Args:
        in_dim: input feature size.
        out_dims: output size of each successive layer.
        bias: whether every linear layer has a bias.
        activation: a single activation name (or None) applied to every layer,
            or one entry per layer. Names must be keys of ACTIVATION_MAPPING.
        last_layer_no_activation: force the last layer's activation to None.
    """

    def __init__(
        self,
        in_dim: int,
        out_dims: list,
        bias: bool = True,
        activation: Union[list[Union[str, None]], str, None] = "swish",
        last_layer_no_activation: bool = False,
    ):
        super().__init__()
        input_dim = in_dim
        if isinstance(activation, str) or activation is None:
            activation = [activation] * len(out_dims)
        else:
            assert len(activation) == len(
                out_dims
            ), "activation and out_dims must have the same length"
            # Copy so the tweak below never mutates the caller's list.
            activation = list(activation)
        if last_layer_no_activation:
            activation[-1] = None
        for _activation in activation:
            assert (_activation is None) or (
                _activation in ACTIVATION_MAPPING
            ), f"activation {_activation} not avail."

        module_list = []
        for ii in range(len(out_dims)):
            module_list.append(
                OneLayerActivation(
                    in_dim=input_dim,
                    out_dim=out_dims[ii],
                    bias=bias,
                    activation=activation[ii],
                )
            )
            input_dim = out_dims[ii]
        self.mlp = nn.Sequential(*module_list)

    def forward(self, input: Tensor) -> Tensor:
        return self.mlp(input)


class GatedMLP(nn.Module):
    r"""
    Gated MLP implementation. It implements the following
    `out = MLP(x) * MLP_\\sigmoid(x)`

    The current implementation is slightly different from the tf version,
    where the last activation from an MLP is forced to be sigmoid.
    """

    def __init__(
        self,
        in_dim: int,
        out_dims: list,
        bias: bool = True,
        activation: Union[list[Union[str, None]], str, None] = "swish",
        gate_activation: str = "sigmoid",
        last_layer_no_activation: bool = False,
    ):
        super().__init__()
        self.mlp = MLP(
            in_dim,
            out_dims,
            bias,
            activation,
            last_layer_no_activation=last_layer_no_activation,
        )
        self.gmlp = MLP(
            in_dim,
            out_dims,
            bias,
            activation,
            last_layer_no_activation=last_layer_no_activation,
        )
        self.gate_activation = ACTIVATION_MAPPING[gate_activation]

    def forward(self, input: Tensor) -> Tensor:
        return self.mlp(input) * self.gate_activation(self.gmlp(input))
# --------------------------------------------------------------------------------
# /reactot/model/egnn.py:
# --------------------------------------------------------------------------------
"""EGNN model"""
from typing import Optional, Tuple

import torch
from torch import nn, Tensor

from .block import EquivariantBlock, SinusoidsEmbeddingNew
from .util_funcs import coord2diff, symmetrize_edge, get_ji_bond_index


class EGNN(nn.Module):
    def __init__(
        self,
        in_node_nf: int = 8,
        in_edge_nf: int = 2,
        hidden_nf: int = 256,
        edge_hidden_nf: int = 32,
        act_fn: str = "swish",
        n_layers: int = 3,
        attention: int = False,
        out_node_nf: Optional[int] = None,
        tanh: bool = False,
        coords_range: float = 15.0,
        norm_constant: float = 1.0,
        inv_sublayers: int = 2,
        sin_embedding: bool = False,
        normalization_factor: float = 100.0,
        aggregation_method: str = "sum",
        reflect_equiv: bool = True,
    ):
        r"""Stack of ``n_layers`` EquivariantBlocks over node features and positions.

        Args:
            in_node_nf (int): number of input node feature. Defaults to 8.
            in_edge_nf (int): number of input edge feature. Defaults to 2.
            hidden_nf (int): number of hidden units. Defaults to 256.
            edge_hidden_nf (int): stored as ``self.edge_hidden_nf``; not otherwise
                used in this class. Defaults to 32.
            act_fn (str): activation function. Defaults to "swish".
            n_layers (int): number of equivariant update block. Defaults to 3.
            attention (int): whether to use self attention. Defaults to False.
            out_node_nf (Optional[int]): number of output node features.
                Defaults to None to set the same as in_node_nf
            tanh (bool): passed through to each EquivariantBlock. Defaults to False.
            coords_range (float): range factor, only used in tanh = True.
                Defaults to 15.0.
            norm_constant (float): distance normalizing factor. Defaults to 1.0.
            inv_sublayers (int): number of GCL in an equivariant update block.
                Defaults to 2.
            sin_embedding (bool): whether to embed edge distances with
                ``SinusoidsEmbeddingNew``. Defaults to False.
            normalization_factor (float): normalization factor passed through to
                each EquivariantBlock. Defaults to 100.0.
            aggregation_method (str): aggregation options in scattering.
                Defaults to "sum".
            reflect_equiv (bool): whether to ignore reflection.
                Defaults to True.
        """
        super().__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.edge_hidden_nf = edge_hidden_nf
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range / n_layers)
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method
        self.reflect_equiv = reflect_equiv

        edge_feat_nf = in_edge_nf
        if sin_embedding:
            self.sin_embedding = SinusoidsEmbeddingNew()
            self.dist_dim = self.sin_embedding.dim
        else:
            self.sin_embedding = None
            self.dist_dim = 1

        # Edge input = distance features concatenated with the raw edge attributes.
        self.edge_feat_nf = edge_feat_nf + self.dist_dim

        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)

        self.edge_embedding = nn.Linear(
            self.edge_feat_nf, self.hidden_nf - self.dist_dim
        )
        self.edge_embedding_out = nn.Linear(
            self.hidden_nf - self.dist_dim, self.edge_feat_nf
        )
        for i in range(0, n_layers):
            self.add_module(
                "e_block_%d" % i,
                EquivariantBlock(
                    hidden_nf,
                    edge_feat_nf=edge_feat_nf,
                    act_fn=act_fn,
                    n_layers=inv_sublayers,
                    attention=attention,
                    tanh=tanh,
                    coords_range=coords_range,
                    norm_constant=norm_constant,
                    sin_embedding=self.sin_embedding,
                    normalization_factor=normalization_factor,
                    aggregation_method=aggregation_method,
                    reflect_equiv=reflect_equiv,
                ),
            )

    def forward(
        self,
        h: Tensor,
        pos: Tensor,
        edge_index: Tensor,
        edge_attr: Optional[Tensor] = None,
        node_mask: Optional[Tensor] = None,
        edge_mask: Optional[Tensor] = None,
        update_coords_mask: Optional[Tensor] = None,
        subgraph_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""

        Args:
            h (Tensor): [n_nodes, in_node_nf], node features (embedded to
                hidden_nf internally).
            pos (Tensor): [n_nodes, n_dim (3 in 3D space)], position tensor.
            edge_index (Tensor): [2, n_edge], edge index {ij}
            edge_attr (Optional[Tensor]): [n_edge, edge_feature_dim]. edge attributes.
                Defaults to None.
            node_mask (Optional[Tensor]): [n_node, 1], mask for node updates.
                Defaults to None.
            edge_mask (Optional[Tensor]): [n_edge, 1], mask for edge updates.
                Defaults to None.
            update_coords_mask (Optional[Tensor]): [n_node, 1], mask for position updates.
                Defaults to None.
            subgraph_mask (Optional[Tensor]): [n_edge, 1], mask for positions aggregations.
                The idea is keep subgraph (i.e., fragment) level equivariance.
                Defaults to None.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: updated h ([n_nodes, out_node_nf]),
                pos, and edge_attr ([n_edge, edge_feat_nf]).
        """
        # Edit Emiel: Remove velocity as input
        distances, _ = coord2diff(pos, edge_index)
        if subgraph_mask is not None:
            distances = distances * subgraph_mask

        if self.sin_embedding is not None:
            distances = self.sin_embedding(distances)
        if edge_attr is None or edge_attr.size(-1) == 0:
            edge_attr = distances
        else:
            edge_attr = torch.concat([distances, edge_attr], dim=-1)
        edge_attr = self.edge_embedding(edge_attr)
        h = self.embedding(h)
        # edge_index_ji = get_ji_bond_index(edge_index)
        # edge_attr = symmetrize_edge(edge_attr, edge_index_ji)

        for i in range(0, self.n_layers):
            h, pos, edge_attr = self._modules["e_block_%d" % i](
                h,
                pos,
                edge_index,
                edge_attr=edge_attr,
                node_mask=node_mask,
                edge_mask=edge_mask,
                update_coords_mask=update_coords_mask,
                subgraph_mask=subgraph_mask,
            )

        # edge_attr = symmetrize_edge(edge_attr, edge_index_ji)

        # Important, the bias of the last linear might be non-zero
        h = self.embedding_out(h)
        edge_attr = self.edge_embedding_out(edge_attr)

        if node_mask is not None:
            h = h * node_mask
        if edge_mask is not None:
            edge_attr = edge_attr * edge_mask
        return h, pos, edge_attr
# --------------------------------------------------------------------------------
# /reactot/pre_process.py:
# --------------------------------------------------------------------------------
from pymatgen.io.xyz import XYZ
import numpy as np
from pymatgen.analysis.molecule_matcher import KabschMatcher
from ase.io import read
import os


def xyz2pmg(xyzfile):
    # Converts an XYZ file to a pymatgen Molecule object
    xyz_converter = XYZ(mol=None)
    mol = xyz_converter.from_file(xyzfile).molecule
    return mol


def translate_molecule(mol):
    # Translates the molecule in place so that its centroid (the unweighted
    # mean of the site coordinates, not the mass-weighted center) is at the origin
    coordinates = np.array([[site.x, site.y, site.z] for site in mol.sites])
    avg_coordinates = np.mean(coordinates, axis=0)
    translated_coordinates = coordinates - avg_coordinates
    for i, site in enumerate(mol.sites):
        site.x, site.y, site.z = translated_coordinates[i]
    return mol


def write_xyz(mol, filename):
    # Writes a pymatgen Molecule object to an XYZ file
    num_atoms = len(mol.sites)
    comment = "have a nice day"
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(f"{num_atoms}\n")
        f.write(f"{comment}\n")
        for site in mol.sites:
            f.write(f"{site.specie} {site.x:.6f} {site.y:.6f} {site.z:.6f}\n")
    return filename


def _opt_path(xyz_path):
    # "<stem>-opt.xyz" next to the input. Unlike str.replace('.xyz', ...), this
    # never returns the input path itself, which would be overwritten and deleted.
    root, _ = os.path.splitext(xyz_path)
    return root + "-opt.xyz"


def pre_treatment(rxyz, pxyz):
    """
    Center the reactant and Kabsch-align the product onto it.

    Returns the two structures as read back by ``ase.io.read`` plus the RMSD
    reported by ``KabschMatcher.fit``. Temporary ``-opt.xyz`` files are always
    removed, even if writing or reading fails.
    """
    mol1 = xyz2pmg(rxyz)
    mol2 = xyz2pmg(pxyz)
    mol1_opt = translate_molecule(mol1)
    bfm = KabschMatcher(mol1_opt)
    mol2_opt, rmsd = bfm.fit(mol2)
    mol1_opt_path = _opt_path(rxyz)
    mol2_opt_path = _opt_path(pxyz)

    try:
        write_xyz(mol1_opt, mol1_opt_path)
        write_xyz(mol2_opt, mol2_opt_path)
        mol1_data = read(mol1_opt_path)
        mol2_data = read(mol2_opt_path)
    finally:
        for path in (mol1_opt_path, mol2_opt_path):
            if os.path.exists(path):
                os.remove(path)
    return mol1_data, mol2_data, rmsd
# --------------------------------------------------------------------------------
# /reactot/sampling/sample_datasets.py:
# --------------------------------------------------------------------------------
import torch

from pytorch_lightning import Trainer
from reactot.diffusion._node_dist import SingleDistributionNodes
from reactot.utils import bond_analyze


@torch.no_grad()
def sample_qm9(
    ddpm_trainer: Trainer,
    nodes_dist: SingleDistributionNodes,
    bz: int,
    n_samples: int,
    n_real: int = 1,
    n_fake: int = 2,
    device: torch.device = torch.device("cuda"),
):
    """Sample molecules with ``ddpm_trainer.ddpm.sample`` in batches of ``bz``.

    Runs ``int(n_samples / bz)`` batches, so a remainder of ``n_samples`` that
    does not fill a batch is not sampled. Each batch uses ``n_real`` fragments
    with sizes drawn from ``nodes_dist`` and ``n_fake`` single-node fragments.

    Returns:
        list of dicts with ``"pos"`` (first ``pos_dim`` columns of frame 0) and
        ``"atom"`` (argmax over the columns between ``pos_dim`` and the last one).
    """
    n_batch = int(n_samples / bz)
    mols = []
    pos_dim = ddpm_trainer.ddpm.pos_dim
    for _ in range(n_batch):
        fragments_nodes = [
            nodes_dist.sample(bz).to(device) for _ in range(n_real)
        ]
        fragments_nodes += [
            torch.ones(bz, device=device).long() for _ in range(n_fake)
        ]
        conditions = torch.zeros((bz, 1), device=device)

        out_samples, out_masks = ddpm_trainer.ddpm.sample(
            n_samples=bz,
            fragments_nodes=fragments_nodes,
            conditions=conditions,
            return_frames=1,
            timesteps=None,
        )
        # Start offsets of each molecule inside the stacked fragment-0 tensor.
        sample_idxs = torch.cat(
            [
                torch.tensor([0], device=device),
                torch.cumsum(fragments_nodes[0], dim=0)
            ]
        )
        for ii in range(bz):
            _start, _end = sample_idxs[ii], sample_idxs[ii + 1]
            mols.append({
                "pos": out_samples[0][0][_start: _end, :pos_dim].detach().cpu(),
                "atom": torch.argmax(
                    out_samples[0][0][_start: _end, pos_dim: -1].detach().cpu(),
                    dim=1,
                )
            })
    return mols
# --------------------------------------------------------------------------------
# /reactot/trainer/__pycache__/_metrics.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/trainer/__pycache__/_metrics.cpython-310.pyc
--------------------------------------------------------------------------------
/reactot/trainer/__pycache__/pl_trainer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/trainer/__pycache__/pl_trainer.cpython-310.pyc
--------------------------------------------------------------------------------
/reactot/trainer/_metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 | import numpy as np
3 | import torch
4 | 
5 | 
6 | def average_over_batch_metrics(
7 |     batch_metrics: List[Dict], allowed: Optional[List] = None
8 | ) -> Dict:
9 |     """Average per-batch metric dicts into a single epoch-level dict.
10 | 
11 |     NaN values are skipped in every batch, including the first one, so a
12 |     single NaN cannot poison the epoch average.
13 | 
14 |     Args:
15 |         batch_metrics: one ``{name: value}`` dict per batch; values must be
16 |             accepted by ``np.isnan`` (Python floats or CPU scalar tensors).
17 |         allowed: if non-empty, only these metric names are averaged.
18 |             ``None`` or an empty list keeps every metric.
19 | 
20 |     Returns:
21 |         ``{name: mean}`` over the non-NaN values seen for each name; a name
22 |         whose values were all NaN maps to ``nan``.
23 |     """
24 |     totals = {}
25 |     counts = {}
26 |     for out in batch_metrics:
27 |         for k, v in out.items():
28 |             if allowed and k not in allowed:
29 |                 continue
30 |             # Register the key on first sight so all-NaN metrics still appear,
31 |             # and keys missing from the first batch no longer raise KeyError.
32 |             totals.setdefault(k, 0.0)
33 |             counts.setdefault(k, 0)
34 |             if np.isnan(v):
35 |                 continue
36 |             totals[k] += v
37 |             counts[k] += 1
38 |     return {
39 |         k: totals[k] / counts[k] if counts[k] else float("nan")
40 |         for k in totals
41 |     }
42 | 
43 | 
44 | def pretty_print(epoch, metric_dict, prefix="Train"):
45 |     """Print ``metric_dict`` on one line, tagged with ``prefix`` and ``epoch``."""
46 |     out = f"{prefix} epoch {epoch} "
47 |     for k, v in metric_dict.items():
48 |         out += f"{k} {v:.2f} "
49 |     print(out)
50 | 
--------------------------------------------------------------------------------
/reactot/trainer/ema.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/07b_collections.callbacks.ema.ipynb (unless otherwise specified).
2 | 
3 | __all__ = ['EMACallback']
4 | 
5 | # Cell
6 | import logging
7 | from typing import Any, Dict
8 | from copy import deepcopy
9 | 
10 | import pytorch_lightning as pl
11 | import torch
12 | import torch.nn as nn
13 | from pytorch_lightning.callbacks import Callback
14 | from pytorch_lightning.utilities import rank_zero_only
15 | from timm.utils.model import get_state_dict, unwrap_model
16 | # from timm.utils.model_ema import ModelEmaV2
17 | 
18 | 
19 | _logger = logging.getLogger(__name__)
20 | 
21 | 
22 | class ModelEmaV2(nn.Module):
23 |     """ Model Exponential Moving Average V2
24 | 
25 |     Keep a moving average of everything in the model state_dict (parameters and buffers).
26 |     V2 of this module is simpler, it does not match params/buffers based on name but simply
27 |     iterates in order. It works with torchscript (JIT of full model).
28 | 
29 |     This is intended to allow functionality like
30 |     https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
31 | 
32 |     A smoothed version of the weights is necessary for some training schemes to perform well.
33 |     E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
34 |     RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
35 |     smoothing of weights to match results. Pay attention to the decay constant you are using
36 |     relative to your update count per epoch.
37 | 
38 |     To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
39 |     disable validation of the EMA weights. Validation will have to be done manually in a separate
40 |     process, or after the training stops converging.
41 | 
42 |     This class is sensitive where it is initialized in the sequence of model init,
43 |     GPU assignment and distributed training wrappers.
44 |     """
45 |     def __init__(self, model, decay=0.9999, device=None):
46 |         super().__init__()
47 |         # make a copy of the model for accumulating moving average of weights
48 |         self.module = deepcopy(model)
49 |         self.module.eval()
50 |         self.decay = decay
51 |         self.device = device  # perform ema on different device from model if set
52 |         if self.device is not None:
53 |             self.module.to(device=device)
54 | 
55 |     def _update(self, model, update_fn):
56 |         with torch.no_grad():
57 |             for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
58 |                 # if self.device is not None:
59 |                 #     model_v = model_v.to(device=self.device)
60 |                 # ghliu: always cast to the EMA tensor's device/dtype (replaces the check above).
61 |                 model_v = model_v.to(ema_v)
62 |                 ema_v.copy_(update_fn(ema_v, model_v))
63 | 
64 |     def update(self, model):
65 |         self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
66 | 
67 |     def set(self, model):
68 |         self._update(model, update_fn=lambda e, m: m)
69 | 
70 | # Cell
71 | class EMACallback(Callback):
72 |     """
73 |     Model Exponential Moving Average. Empirically it has been found that using the moving average
74 |     of the trained parameters of a deep network is better than using its trained parameters directly.
75 |     If `use_ema_weights`, then the ema parameters of the network is set after training end.
76 |     """
77 | 
78 |     def __init__(self, pl_module, decay=1e-4, use_ema_weights: bool = True):
79 |         self.decay = 1. - decay  # `decay` here is the weight of the new model; ModelEmaV2 wants the weight of the EMA
80 |         self.collected_params = []  # filled by `store`; keeps `restore` safe before the first store
81 |         self.use_ema_weights = use_ema_weights
82 |         self.ema = ModelEmaV2(pl_module, decay=self.decay, device=None)
83 | 
84 |     # def on_fit_start(self, trainer, pl_module):
85 |     #     "Initialize `ModelEmaV2` from timm to keep a copy of the moving average of the weights"
86 |     #     self.ema = ModelEmaV2(pl_module, decay=self.decay, device=None)
87 | 
88 |     def on_train_batch_end(
89 |         self, trainer, pl_module, outputs, batch, batch_idx,
90 |     ):
91 |         "Update the stored parameters using a moving average"
92 |         # Update currently maintained parameters.
93 |         self.ema.update(pl_module)
94 | 
95 |     def on_validation_epoch_start(self, trainer, pl_module):
96 |         "do validation using the stored parameters"
97 |         # save original parameters before replacing with EMA version
98 |         self.store(pl_module.parameters())
99 | 
100 |         # update the LightningModule with the EMA weights
101 |         # ~ Copy EMA parameters to LightningModule
102 |         self.copy_to(self.ema.module.parameters(), pl_module.parameters())
103 | 
104 |     def on_validation_end(self, trainer, pl_module):
105 |         "Restore original parameters to resume training later"
106 |         self.restore(pl_module.parameters())
107 | 
108 |     def on_train_end(self, trainer, pl_module):
109 |         # update the LightningModule with the EMA weights
110 |         if self.use_ema_weights:
111 |             self.copy_to(self.ema.module.parameters(), pl_module.parameters())
112 |             msg = "Model weights replaced with the EMA version."
113 |             _logger.info(msg)
114 | 
115 |     # def on_save_checkpoint(self, trainer, pl_module, checkpoint):
116 |     #     if self.ema is not None:
117 |     #         return {"state_dict_ema": get_state_dict(self.ema, unwrap_model)}
118 | 
119 |     def state_dict(self):
120 |         if self.ema is not None:
121 |             return {"state_dict_ema": get_state_dict(self.ema, unwrap_model)}
122 | 
123 |     # def on_load_checkpoint(self, checkpoint):
124 |     #     if self.ema is not None:
125 |     #         self.ema.module.load_state_dict(checkpoint["state_dict_ema"])
126 | 
127 |     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
128 |         if self.ema is not None:
129 |             self.ema.module.load_state_dict(state_dict["state_dict_ema"])
130 | 
131 |     def store(self, parameters):
132 |         "Save the current parameters for restoring later."
133 |         self.collected_params = [param.clone() for param in parameters]
134 | 
135 |     def restore(self, parameters):
136 |         """
137 |         Restore the parameters stored with the `store` method.
138 |         Useful to validate the model with EMA parameters without affecting the
139 |         original optimization process.
140 |         """
141 |         for c_param, param in zip(self.collected_params, parameters):
142 |             param.data.copy_(c_param.data)
143 | 
144 |     def copy_to(self, shadow_parameters, parameters):
145 |         "Copy current parameters into given collection of parameters."
146 |         for s_param, param in zip(shadow_parameters, parameters):
147 |             if param.requires_grad:
148 |                 param.data.copy_(s_param.data)
149 | 
--------------------------------------------------------------------------------
/reactot/trainer/potential_module.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Tuple
2 | 
3 | from pathlib import Path
4 | import torch
5 | from torch import nn
6 | 
7 | from torch_geometric.loader import DataLoader
8 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, StepLR
9 | from pytorch_lightning import LightningModule
10 | from torchmetrics import MeanAbsoluteError, MeanAbsolutePercentageError, CosineSimilarity
11 | from sklearn.metrics.pairwise import cosine_similarity
12 | 
13 | from reactot.dataset.ff_lmdb import LmdbDataset
14 | from reactot.dynamics import Potential
15 | from reactot.trainer._metrics import average_over_batch_metrics, pretty_print
16 | import reactot.utils.training_tools as utils
17 | 
18 | LR_SCHEDULER = {
19 |     "cos": CosineAnnealingWarmRestarts,
20 |     "step": StepLR,
21 | }
22 | 
23 | 
24 | class PotentialModule(LightningModule):  # fits `Potential` to forces; energy errors are only logged
25 |     def __init__(
26 |         self,
27 |         model_config: Dict,
28 |         optimizer_config: Dict,
29 |         training_config: Dict,
30 |         node_nfs: List[int] = [9] * 1,
31 |         edge_nf: int = 4,
32 |         condition_nf: int = 1,
33 |         fragment_names: List[str] = ["struct"],
34 |         pos_dim: int = 3,
35 |         edge_cutoff: Optional[float] = None,
36 |         model: Optional[nn.Module] = None,
37 |         enforce_same_encoding: Optional[List] = None,
38 |         source: Optional[Dict] = None,
39 |         use_autograd: bool = False,
40 |         timesteps: int = 5000,
41 |         condition_time: bool = True,
42 |     ) -> None:
43 |         super().__init__()
44 |         self.potential = Potential(
45 |             model_config=model_config,
46 |             node_nfs=node_nfs,
47 |             edge_nf=edge_nf,
48 |             condition_nf=condition_nf,
49 |             fragment_names=fragment_names,
50 |             pos_dim=pos_dim,
51 |             edge_cutoff=edge_cutoff,
52 |             model=model,
53 |             enforce_same_encoding=enforce_same_encoding,
54 |             source=source,
55 |             timesteps=timesteps,
56 |             condition_time=condition_time,
57 |         )
58 | 
59 |         self.model_config = model_config
60 |         self.optimizer_config = optimizer_config
61 |         self.training_config = training_config
62 |         self.n_fragments = len(fragment_names)
63 |         self.use_autograd = use_autograd
64 | 
65 |         self.clip_grad = training_config["clip_grad"]
66 |         if self.clip_grad:
67 |             self.gradnorm_queue = utils.Queue()
68 |             self.gradnorm_queue.add(3000)  # seed the history with a large initial norm
69 |         self.save_hyperparameters()
70 | 
71 |         self.loss_fn = nn.MSELoss()
72 |         self.MAEEval = MeanAbsoluteError()
73 |         self.MAPEEval = MeanAbsolutePercentageError()
74 |         self.cosineEval = CosineSimilarity(reduction="mean")
75 | 
76 |     def configure_optimizers(self):
77 |         optimizer = torch.optim.AdamW(
78 |             self.potential.parameters(),
79 |             **self.optimizer_config
80 |         )
81 |         if self.training_config["lr_schedule_type"] is not None:
82 |             scheduler_func = LR_SCHEDULER[self.training_config["lr_schedule_type"]]
83 |             scheduler = scheduler_func(
84 |                 optimizer=optimizer,
85 |                 **self.training_config["lr_schedule_config"]
86 |             )
87 |             return [optimizer], [scheduler]
88 |         return optimizer
89 | 
90 |     def setup(self, stage: Optional[str] = None):  # builds LMDB datasets for the "fit" or "test" stage
91 |         if stage == "fit":
92 |             self.train_dataset = LmdbDataset(
93 |                 Path(self.training_config["datadir"], "ff_train.lmdb"),  # must differ from the validation split below
94 |                 **self.training_config,
95 |             )
96 |             self.val_dataset = LmdbDataset(
97 |                 Path(self.training_config["datadir"], "ff_valid.lmdb"),
98 |                 **self.training_config,
99 |             )
100 |             print("# of training data: ", len(self.train_dataset))
101 |             print("# of validation data: ", len(self.val_dataset))
102 |         elif stage == "test":
103 |             self.test_dataset = LmdbDataset(
104 |                 Path(self.training_config["datadir"], "ff_test.lmdb"),
105 |                 **self.training_config,
106 |             )
107 |         else:
108 |             raise NotImplementedError
109 | 
110 |     def train_dataloader(self) -> DataLoader:
111 |         return DataLoader(
112 |             self.train_dataset,
113 |             batch_size=self.training_config["bz"],
114 |             shuffle=True,
115 |             num_workers=self.training_config["num_workers"],
116 |         )
117 | 
118 |     def val_dataloader(self) -> DataLoader:
119 |         return DataLoader(
120 |             self.val_dataset,
121 |             batch_size=self.training_config["bz"] * 3,
122 |             shuffle=False,
123 |             num_workers=self.training_config["num_workers"],
124 |         )
125 | 
126 |     def test_dataloader(self) -> DataLoader:
127 |         return DataLoader(
128 |             self.test_dataset,
129 |             batch_size=self.training_config["bz"],
130 |             shuffle=False,
131 |             num_workers=self.training_config["num_workers"],
132 |         )
133 | 
134 |     @torch.enable_grad()  # presumably so forward_autograd can differentiate during validation too
135 |     def compute_loss(self, batch):
136 |         if not self.use_autograd:
137 |             hat_ae, hat_forces = self.potential.forward(
138 |                 batch.to(self.device),
139 |             )
140 |         else:
141 |             hat_ae, hat_forces = self.potential.forward_autograd(
142 |                 batch.to(self.device),
143 |             )
144 |         hat_ae = hat_ae.to(self.device)
145 |         hat_forces = hat_forces.view(-1, ).to(self.device)
146 |         ae = batch.ae.to(self.device)
147 |         forces = batch.forces.view(-1, ).to(self.device)
148 | 
149 |         eloss = self.loss_fn(ae, hat_ae)
150 |         floss = self.loss_fn(forces, hat_forces)
151 |         info = {
152 |             "MAE_E": self.MAEEval(hat_ae, ae).item(),
153 |             "MAE_F": self.MAEEval(hat_forces, forces).item(),
154 |             "MAPE_E": self.MAPEEval(hat_ae, ae).item(),
155 |             "MAPE_F": self.MAPEEval(hat_forces, forces).item(),
156 |             "MAE_Fcos": (1 - self.cosineEval(hat_forces.detach().cpu(), forces.detach().cpu())).item(),
157 |             "Loss_E": eloss.item(),
158 |             "Loss_F": floss.item(),
159 |         }
160 | 
161 |         # loss = floss * 100 + eloss
162 |         loss = floss * 100  # energy loss is reported in `info` but not optimized
163 |         return loss, info
164 | 
165 |     def training_step(self, batch, batch_idx):
166 |         loss, info = self.compute_loss(batch)
167 |         self.log("train-totloss", loss, rank_zero_only=True)
168 | 
169 |         for k, v in info.items():
170 |             self.log(f"train-{k}", v, rank_zero_only=True)
171 |         return loss
172 | 
173 |     def _shared_eval(self, batch, batch_idx, prefix, *args):
174 |         loss, info = self.compute_loss(batch)
175 |         info["totloss"] = loss.item()
176 | 
177 |         info_prefix = {}
178 |         for k, v in info.items():
179 |             info_prefix[f"{prefix}-{k}"] = v
180 |         return info_prefix
181 | 
182 |     def validation_step(self, batch, batch_idx, *args):
183 |         return self._shared_eval(batch, batch_idx, "val", *args)
184 | 
185 |     def test_step(self, batch, batch_idx, *args):
186 |         return self._shared_eval(batch, batch_idx, "test", *args)
187 | 
188 |     def validation_epoch_end(self, val_step_outputs):
189 |         val_epoch_metrics = average_over_batch_metrics(val_step_outputs)
190 |         if self.trainer.is_global_zero:
191 |             pretty_print(self.current_epoch, val_epoch_metrics, prefix="val")
192 |         val_epoch_metrics.update({"epoch": self.current_epoch})
193 |         for k, v in val_epoch_metrics.items():
194 |             self.log(k, v, sync_dist=True)
195 | 
196 |     def configure_gradient_clipping(
197 |         self,
198 |         optimizer,
199 |         optimizer_idx,
200 |         gradient_clip_val,
201 |         gradient_clip_algorithm
202 |     ):
203 | 
204 |         if not self.clip_grad:
205 |             return
206 | 
207 |         # Allow gradient norm to be 150% of the recent mean + 3 * stdev of the recent history.
208 |         max_grad_norm = 1.5 * self.gradnorm_queue.mean() + \
209 |             3 * self.gradnorm_queue.std()
210 | 
211 |         # Get current grad_norm
212 |         params = [p for g in optimizer.param_groups for p in g['params']]
213 |         grad_norm = utils.get_grad_norm(params)
214 | 
215 |         # Lightning will handle the gradient clipping
216 |         self.clip_gradients(optimizer, gradient_clip_val=max_grad_norm,
217 |                             gradient_clip_algorithm='norm')
218 | 
219 |         if float(grad_norm) > max_grad_norm:
220 |             self.gradnorm_queue.add(float(max_grad_norm))
221 |         else:
222 |             self.gradnorm_queue.add(float(grad_norm))
223 | 
224 |         if float(grad_norm) > max_grad_norm:
225 |             print(f'Clipped gradient with value {grad_norm:.1f} '
226 |                   f'while allowed {max_grad_norm:.1f}')
--------------------------------------------------------------------------------
/reactot/trainer/train_rpsb_ts1x.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 | from uuid import uuid4
3 | import os
4 | import shutil
5 | import torch
6 | 
7 | from reactot.trainer.pl_trainer import SBModule, DDPMModule
8 | from pytorch_lightning import Trainer, seed_everything
9 | from pytorch_lightning.callbacks.progress import TQDMProgressBar
10 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
11 | from pytorch_lightning.loggers import WandbLogger
12 | from pytorch_lightning.strategies.ddp import DDPStrategy
13 | 
14 | from reactot.trainer.ema import EMACallback
15 | from reactot.model import LEFTNet
16 | 
17 | from ipdb import set_trace as debug
18 | import colored_traceback.always
19 | 
20 | 
21 | class OPT:
22 |     def __init__(
23 |         self,
24 |         solver,
25 |         method,
26 |     ):
27 |         self.solver = solver
28 |         self.method = method
29 |         self.atol = 1e-2
30 |         self.rtol = 1e-2
31 | 
32 | opt = OPT(solver="ddpm", method="midpoint")
33 | 
34 | model_type = "leftnet"
35 | version = "ts_guess_NEBCI-xtb-ema"
36 | project = "RPSB-FT-Schedule"
37 | # ---LEFTNet dynamics---
38 | leftnet_config = dict(
39 |     pos_require_grad=False,
40 |     cutoff=10.0,
41 |     num_layers=6,
42 |     hidden_channels=196,
43 |     num_radial=96,
44 |     in_hidden_channels=8,
45 |     reflect_equiv=True,
46 |     legacy=True,
47 |     update=True,
48 |     pos_grad=False,
49 |     single_layer_output=True,
50 |     object_aware=True,
51 | )
52 | 
53 | if model_type == "leftnet":
54 |     model_config = leftnet_config
55 |     model = LEFTNet
56 | else:
57 |     raise KeyError("model type not implemented.")
58 | 
59 | optimizer_config = dict(
60 |     lr=1e-4,
61 |     betas=[0.9, 0.999],
62 |     weight_decay=0,
63 |     amsgrad=True,
64 | )
65 | 
66 | T_0 = 10
67 | T_mult = 1
68 | training_config = dict(
69 |     datadir="reactot/data/transition1x/",
70 |     remove_h=False,
71 |     bz=14,
72 |     num_workers=0,
73 |     clip_grad=True,
74 |     gradient_clip_val=None,
75 |     ema=True,
76 |     ema_decay=0.999,
77 |     swapping_react_prod=False,
78 |     append_frag=False,
79 |     use_by_ind=True,
80 |     reflection=False,
81 |     single_frag_only=False,
82 |     # react_type="xTB-IRC",
83 |     # position_key="xtb_positions",
84 |     only_ts=False,
85 |     lr_schedule_type=None,
86 |     lr_schedule_config=dict(
87 |         gamma=0.8,
88 |         step_size=10,
89 |     ),  # step
90 |     # lr_schedule_config=dict(
91 |     #     T_0=T_0,
92 |     #     T_mult=T_mult,
93 |     #     eta_min=0,
94 |     # ),  # cos
95 |     use_sampler=True,
96 |     sampler_config=dict(
97 |         max_num=2800,  # This is for 16GB GPU; Scale linearly with memory
98 |         mode="node^2",
99 |         shuffle=True,
100 |         ddp=False,
101 |     )
102 | )
103 | training_data_frac = 1.0 if not training_config["reflection"] else 0.5
104 | 
105 | 
106 | node_nfs: List[int] = [9] * 3  # 3 (pos) + 5 (cat) + 1 (charge)
107 | edge_nf: int = 0  # edge type
108 | condition_nf: int = 1
109 | fragment_names: List[str] = ["R", "TS", "P"]
110 | pos_dim: int = 3
111 | update_pocket_coords: bool = True
112 | condition_time: bool = True
113 | edge_cutoff: Optional[float] = None
114 | loss_type = "l2"
115 | pos_only = True
116 | process_type = "TS1x"
117 | enforce_same_encoding = None
118 | scales = [1., 2., 1.]
119 | fixed_idx = [0, 2]
120 | eval_epochs = 1
121 | save_epochs = 1
122 | 
123 | # ----Normalizer---
124 | norm_values: Tuple = (1., 1., 1.)
125 | norm_biases: Tuple = (0., 0., 0.)
126 | 
127 | # ---Schedule---
128 | timesteps: int = 3000
129 | beta_max: float = 0.3
130 | power: float = 0.5
131 | inv_power: float = 1
132 | precision: float = 1e-5  # not used
133 | noise_schedule: str = "cosine"  # not used
134 | 
135 | # ---SB---
136 | mapping: str = "R+P->TS"
137 | mapping_initial: str = "RP"  # RP for (r+p)/2, GUESS for guessing
138 | nfe: int = 25
139 | ot_ode: bool = True
140 | sigma: float = 0.
141 | ts_guess: Optional[str] = None  # e.g. "ts_guess_NEBCI-xtb", "ts_guess_sbv1", "ts_guess_linear"
142 | 
143 | norms = "_".join([str(x) for x in norm_values])
144 | run_name = f"{model_type}-{version}-" + str(uuid4()).split("-")[-1]
145 | 
146 | ## === Fine tuning from a FF ---
147 | # tspath = "/home/ubuntu/efs/reactot/reactot/trainer/ckpt"
148 | # checkpoint_path=f"{tspath}/TSDiffusion-TS1x-All/leftnet-8-70b75beeaac1/ddpm-epoch=2074-val-totloss=531.18.ckpt" # All diffuse/denoise
149 | # checkpoint_path=f"{tspath}/leftnet-10-d13a2c2bace6_wo_oa_align/ddpm-epoch=719-val-totloss=680.64.ckpt"
150 | # checkpoint_path=f"{tspath}/TSDiffusion-TS1x-All/RGD1xtb-pretrained-leftnet-0-7962cf1208dc/ddpm-epoch=1279-val-error_t_1=0.237.ckpt"
151 | # checkpoint_path = "/home/ubuntu/efs/reactot/reactot/trainer/checkpoint/TSDiff/leftnet-xtb-from-dftckpt-cd01d85c5152/ddpm-epoch=189-val-totloss=619.05.ckpt"
152 | # checkpoint_path = "/home/ubuntu/efs/reactot/reactot/trainer/checkpoint/RPSB-FT-Schedule/leftnet-xtb-c79fcfe0518d/sb-epoch=349-val_ep_scaled_err=0.0483.ckpt"
153 | checkpoint_path = None
154 | use_pretrain: bool = False
155 | 
156 | source = None
157 | if use_pretrain:
158 |     ddpm_trainer = DDPMModule.load_from_checkpoint(
159 |         checkpoint_path=checkpoint_path,
160 |         map_location="cpu",
161 |     )
162 |     source = {
163 |         "model": ddpm_trainer.ddpm.dynamics.model.state_dict(),
164 |         "encoders": ddpm_trainer.ddpm.dynamics.encoders.state_dict(),
165 |         "decoders": ddpm_trainer.ddpm.dynamics.decoders.state_dict(),
166 |     }
167 | training_config.update(
168 |     {
169 |         "checkpoint_path": checkpoint_path,
170 |         "use_pretrain": use_pretrain,
171 |     }
172 | )
173 | 
174 | seed_everything(42, workers=True)
175 | ddpm = SBModule(
176 |     model_config,
177 |     optimizer_config,
178 |     training_config,
179 |     node_nfs,
180 |     edge_nf,
181 |     condition_nf,
182 |     fragment_names,
183 |     pos_dim,
184 |     update_pocket_coords,
185 |     condition_time,
186 |     edge_cutoff,
187 |     norm_values,
188 |     norm_biases,
189 |     noise_schedule,
190 |     timesteps,
191 |     precision,
192 |     loss_type,
193 |     pos_only,
194 |     process_type,
195 |     model,
196 |     enforce_same_encoding,
197 |     scales,
198 |     source=source,
199 |     fixed_idx=fixed_idx,
200 |     eval_epochs=eval_epochs,
201 |     mapping=mapping,
202 |     mapping_initial=mapping_initial,
203 |     nfe=nfe,
204 |     beta_max=beta_max,
205 |     ot_ode=ot_ode,
206 |     power=power,
207 |     inv_power=inv_power,
208 |     sigma=sigma,
209 |     ts_guess=ts_guess,
210 | )
211 | ddpm.ddpm.opt = opt  # hack: attach the OPT solver settings for the new optimizer
212 | 
213 | config = model_config.copy()
214 | config.update(optimizer_config)
215 | config.update(training_config)
216 | trainer = None
217 | if trainer is None or (isinstance(trainer, Trainer) and trainer.is_global_zero):
218 |     wandb_logger = WandbLogger(
219 |         project=project,
220 |         log_model=False,
221 |         name=run_name,
222 |     )
223 |     try:  # Avoid errors for creating wandb instances multiple times
224 |         wandb_logger.experiment.config.update(config)
225 |         wandb_logger.watch(
226 |             ddpm.ddpm.dynamics, log="all", log_freq=100, log_graph=False
227 |         )
228 |     except Exception:  # best effort: wandb bookkeeping must not stop training (Ctrl-C still works)
229 |         pass
230 | 
231 | ckpt_path = f"checkpoint/{project}/{wandb_logger.experiment.name}"
232 | earlystopping = EarlyStopping(
233 |     monitor="val_ep_scaled_err",
234 |     patience=2000,
235 |     verbose=True,
236 |     log_rank_zero_only=True,
237 | )
238 | checkpoint_callback = ModelCheckpoint(
239 |     monitor="val_ep_scaled_err",
240 |     dirpath=ckpt_path,
241 |     filename="sb-{epoch:03d}-{val_ep_scaled_err:.4f}",
242 |     every_n_epochs=save_epochs,
243 |     save_top_k=-1,
244 | )
245 | lr_monitor = LearningRateMonitor(logging_interval='step')
246 | callbacks = [earlystopping, checkpoint_callback, TQDMProgressBar(), lr_monitor]
247 | 
248 | strategy = None
249 | devices = [0]
250 | strategy = DDPStrategy(find_unused_parameters=True)
251 | if strategy is not None:
252 |     devices = list(range(torch.cuda.device_count()))
253 |     if len(devices) == 1:
254 |         strategy = None
255 | 
256 | if training_config["ema"]:
257 |     callbacks.append(
258 |         EMACallback(
259 |             pl_module=ddpm,
260 |             decay=1. - training_config["ema_decay"])  # EMACallback takes 1 - decay; 0.999 here means EMA weight 0.999
261 | 
262 |     )
263 | 
264 | print("config: ", config)
265 | trainer = Trainer(
266 |     max_epochs=3000,
267 |     accelerator="gpu",
268 |     deterministic=False,
269 |     devices=devices,
270 |     strategy=strategy,
271 |     log_every_n_steps=20,
272 |     callbacks=callbacks,
273 |     profiler=None,
274 |     logger=wandb_logger,
275 |     accumulate_grad_batches=1,
276 |     gradient_clip_val=training_config["gradient_clip_val"],
277 |     limit_train_batches=200,
278 |     limit_val_batches=20,
279 |     replace_sampler_ddp=False,
280 |     # resume_from_checkpoint=checkpoint_path,
281 |     # max_time="00:10:00:00",
282 | )
283 | 
284 | trainer.fit(ddpm)
285 | 
--------------------------------------------------------------------------------
/reactot/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from ._graph_tools import (
2 |     get_edges_index,
3 |     get_subgraph_mask,
4 |     get_n_frag_switch,
5 |     get_mask_for_frag,
6 | )
7 | 
--------------------------------------------------------------------------------
/reactot/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/reactot/utils/__pycache__/_graph_tools.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/utils/__pycache__/_graph_tools.cpython-310.pyc
--------------------------------------------------------------------------------
/reactot/utils/__pycache__/bond_analyze.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/utils/__pycache__/bond_analyze.cpython-310.pyc
--------------------------------------------------------------------------------
/reactot/utils/__pycache__/sampling_tools.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/utils/__pycache__/sampling_tools.cpython-310.pyc
--------------------------------------------------------------------------------
/reactot/utils/__pycache__/training_tools.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepprinciple/react-ot/8f03066d84f81fb4a94062e3f6390912aa5027da/reactot/utils/__pycache__/training_tools.cpython-310.pyc
--------------------------------------------------------------------------------
/reactot/utils/_graph_tools.py:
--------------------------------------------------------------------------------
1 | """Utility functions for graphs."""
2 | from typing import List, Optional
3 | 
4 | import numpy as np
5 | import torch
6 | from torch import Tensor
7 | 
8 | 
9 | def get_edges_index(
10 |     combined_mask: Tensor,
11 |     pos: Optional[Tensor] = None,
12 |     edge_cutoff: Optional[float] = None,
13 |     remove_self_edge: bool = False,
14 | ) -> Tensor:
15 |     r"""Connect every pair of nodes that share a value in ``combined_mask``.
16 | 
17 |     Args:
18 |         combined_mask (Tensor): Combined mask for all fragments.
19 |             Edges are built for nodes with the same indexes in the mask.
20 |         pos (Optional[Tensor]): 3D coordinates of nodes; required when
21 |             ``edge_cutoff`` is given. Defaults to None.
22 |         edge_cutoff (Optional[float]): cutoff for building edges within a fragment.
23 |             Defaults to None.
24 |         remove_self_edge (bool): whether to remove self-connecting edge (i.e., ii).
25 |             Defaults to False.
26 | 
27 |     Returns:
28 |         Tensor: [2, n_edges], i for node index.
29 | 
30 |     Raises:
31 |         ValueError: if ``edge_cutoff`` is set but ``pos`` is None.
32 |     """
33 |     # TODO: cache batches for each example in self._edges_dict[n_nodes]
34 |     adj = combined_mask[:, None] == combined_mask[None, :]
35 |     if edge_cutoff is not None:
36 |         if pos is None:
37 |             raise ValueError("`pos` is required when `edge_cutoff` is set.")
38 |         adj = adj & (torch.cdist(pos, pos) <= edge_cutoff)
39 |     if remove_self_edge:
40 |         adj = adj.fill_diagonal_(False)
41 |     edges = torch.stack(torch.where(adj), dim=0)
42 |     return edges
43 | 
44 | 
45 | def get_subgraph_mask(edge_index: Tensor, n_frag_switch: Tensor) -> Tensor:
46 |     r"""Filter out edges that have inter-fragment connections.
47 | 
48 |     Example:
49 |         edge_index: [
50 |             [0, 0, 1, 1, 2, 2],
51 |             [1, 2, 0, 2, 0, 1],
52 |         ]
53 |         n_frag_switch: [0, 0, 1]
54 |         -> [1, 0, 1, 0, 0, 0]
55 | 
56 |     Args:
57 |         edge_index (Tensor): e_ij
58 |         n_frag_switch (Tensor): fragment that a node belongs to
59 | 
60 |     Returns:
61 |         Tensor: [n_edge], 1 for inner- and 0 for inter-fragment edge
62 |     """
63 |     in_same_frag = n_frag_switch[edge_index[0]] == n_frag_switch[edge_index[1]]
64 |     # Convert the boolean mask directly rather than scattering into a CPU
65 |     # buffer with indices that may live on the GPU.
66 |     return in_same_frag.long().to(edge_index.device)
67 | 
68 | 
69 | def get_n_frag_switch(natm_list: List[Tensor]) -> Tensor:
70 |     r"""Get the type of fragment to which each node belongs.
71 | 
72 |     Example: [tensor([1, 1]), tensor([2, 1])] -> [0, 0, 1, 1, 1]
73 | 
74 |     Args:
75 |         natm_list (List[Tensor]): one [Tensor([number of atoms per small fragment])]
76 |             per fragment type; all tensors must have the same length.
77 | 
78 |     Returns:
79 |         Tensor: [n_nodes], type of fragment each node belongs to
80 | 
81 |     Raises:
82 |         ValueError: if ``natm_list`` is empty or its tensors differ in length.
83 |     """
84 |     if len(natm_list) == 0:
85 |         raise ValueError("natm_list must contain at least one tensor.")
86 |     lengths = {natm.shape[0] for natm in natm_list}
87 |     if len(lengths) != 1:
88 |         raise ValueError(
89 |             "All tensors in natm_list must have the same length, "
90 |             f"got lengths {sorted(lengths)}."
91 |         )
92 |     device = natm_list[0].device
93 |     # Fragment type i is repeated once per node it contains.
94 |     repeats = torch.stack([natm.sum().to(device) for natm in natm_list])
95 |     return torch.repeat_interleave(
96 |         torch.arange(len(natm_list), device=device), repeats
97 |     )
98 | 
99 | 
100 | def get_mask_for_frag(natm: Tensor) -> Tensor:
101 |     r"""Get fragment index for each node.
102 | 
103 |     Example: Tensor([2, 0, 3]) -> [0, 0, 2, 2, 2]
104 | 
105 |     Args:
106 |         natm (Tensor): number of nodes per small fragment
107 | 
108 |     Returns:
109 |         Tensor: [n_node], the natural index of fragment a node belongs to
110 |     """
111 |     return torch.repeat_interleave(
112 |         torch.arange(natm.size(0), device=natm.device), natm
113 |     ).to(natm.device)
114 | 
115 | 
116 | def get_inner_edge_index(subgraph_mask: Tensor):
117 |     r"""Stack the index tensors returned by ``torch.where(subgraph_mask)``.
118 | 
119 |     NOTE(review): for a 1-D ``subgraph_mask`` (as produced by
120 |     ``get_subgraph_mask``) this yields a [1, n_inner] tensor of edge
121 |     positions, not a [2, n] edge index -- confirm callers expect that.
122 |     """
123 |     return torch.stack(torch.where(subgraph_mask), dim=0)
124 | 
--------------------------------------------------------------------------------
/reactot/utils/bond_analyze.py:
--------------------------------------------------------------------------------
-------------------------------------------------------------------------------- /reactot/utils/bond_analyze.py: -------------------------------------------------------------------------------- 1 | # Bond lengths from: 2 | # http://www.wiredchemist.com/chemistry/data/bond_energies_lengths.html 3 | # And: 4 | # http://chemistry-reference.com/tables/Bond%20Lengths%20and%20Enthalpies.pdf 5 | bonds1 = {'H': {'H': 74, 'C': 109, 'N': 101, 'O': 96, 'F': 92, 6 | 'B': 119, 'Si': 148, 'P': 144, 'As': 152, 'S': 134, 7 | 'Cl': 127, 'Br': 141, 'I': 161}, 8 | 'C': {'H': 109, 'C': 154, 'N': 147, 'O': 143, 'F': 135, 9 | 'Si': 185, 'P': 184, 'S': 182, 'Cl': 177, 'Br': 194, 10 | 'I': 214}, 11 | 'N': {'H': 101, 'C': 147, 'N': 145, 'O': 140, 'F': 136, 12 | 'Cl': 175, 'Br': 214, 'S': 168, 'I': 222, 'P': 177}, 13 | 'O': {'H': 96, 'C': 143, 'N': 140, 'O': 148, 'F': 142, 14 | 'Br': 172, 'S': 151, 'P': 163, 'Si': 163, 'Cl': 164, 15 | 'I': 194}, 16 | 'F': {'H': 92, 'C': 135, 'N': 136, 'O': 142, 'F': 142, 17 | 'S': 158, 'Si': 160, 'Cl': 166, 'Br': 178, 'P': 156, 18 | 'I': 187}, 19 | 'B': {'H': 119, 'Cl': 175}, 20 | 'Si': {'Si': 233, 'H': 148, 'C': 185, 'O': 163, 'S': 200, 21 | 'F': 160, 'Cl': 202, 'Br': 215, 'I': 243 }, 22 | 'Cl': {'Cl': 199, 'H': 127, 'C': 177, 'N': 175, 'O': 164, 23 | 'P': 203, 'S': 207, 'B': 175, 'Si': 202, 'F': 166, 24 | 'Br': 214}, 25 | 'S': {'H': 134, 'C': 182, 'N': 168, 'O': 151, 'S': 204, 26 | 'F': 158, 'Cl': 207, 'Br': 225, 'Si': 200, 'P': 210, 27 | 'I': 234}, 28 | 'Br': {'Br': 228, 'H': 141, 'C': 194, 'O': 172, 'N': 214, 29 | 'Si': 215, 'S': 225, 'F': 178, 'Cl': 214, 'P': 222}, 30 | 'P': {'P': 221, 'H': 144, 'C': 184, 'O': 163, 'Cl': 203, 31 | 'S': 210, 'F': 156, 'N': 177, 'Br': 222}, 32 | 'I': {'H': 161, 'C': 214, 'Si': 243, 'N': 222, 'O': 194, 33 | 'S': 234, 'F': 187, 'I': 266}, 34 | 'As': {'H': 152} 35 | } 36 | 37 | bonds2 = {'C': {'C': 134, 'N': 129, 'O': 120, 'S': 160}, 38 | 'N': {'C': 129, 'N': 125, 'O': 121}, 39 | 'O': {'C': 120, 'N': 121, 'O': 121, 'P': 
150}, 40 | 'P': {'O': 150, 'S': 186}, 41 | 'S': {'P': 186}} 42 | 43 | 44 | bonds3 = {'C': {'C': 120, 'N': 116, 'O': 113}, 45 | 'N': {'C': 116, 'N': 110}, 46 | 'O': {'C': 113}} 47 | 48 | 49 | def print_table(bonds_dict): 50 | letters = ['H', 'C', 'O', 'N', 'P', 'S', 'F', 'Si', 'Cl', 'Br', 'I'] 51 | 52 | new_letters = [] 53 | for key in (letters + list(bonds_dict.keys())): 54 | if key in bonds_dict.keys(): 55 | if key not in new_letters: 56 | new_letters.append(key) 57 | 58 | letters = new_letters 59 | 60 | for j, y in enumerate(letters): 61 | if j == 0: 62 | for x in letters: 63 | print(f'{x} & ', end='') 64 | print() 65 | for i, x in enumerate(letters): 66 | if i == 0: 67 | print(f'{y} & ', end='') 68 | if x in bonds_dict[y]: 69 | print(f'{bonds_dict[y][x]} & ', end='') 70 | else: 71 | print('- & ', end='') 72 | print() 73 | 74 | 75 | # print_table(bonds3) 76 | 77 | 78 | def check_consistency_bond_dictionaries(): 79 | for bonds_dict in [bonds1, bonds2, bonds3]: 80 | for atom1 in bonds1: 81 | for atom2 in bonds_dict[atom1]: 82 | bond = bonds_dict[atom1][atom2] 83 | try: 84 | bond_check = bonds_dict[atom2][atom1] 85 | except KeyError: 86 | raise ValueError('Not in dict ' + str((atom1, atom2))) 87 | 88 | assert bond == bond_check, ( 89 | f'{bond} != {bond_check} for {atom1}, {atom2}') 90 | 91 | 92 | stdv = {'H': 5, 'C': 1, 'N': 1, 'O': 2, 'F': 3} 93 | margin1, margin2, margin3 = 10, 5, 3 94 | 95 | allowed_bonds = {'H': 1, 'C': 4, 'N': 3, 'O': 2, 'F': 1, 'B': 3, 'Al': 3, 96 | 'Si': 4, 'P': [3, 5], 97 | 'S': 4, 'Cl': 1, 'As': 3, 'Br': 1, 'I': 1, 'Hg': [1, 2], 98 | 'Bi': [3, 5]} 99 | 100 | 101 | def get_bond_order(atom1, atom2, distance, check_exists=False): 102 | distance = 100 * distance # We change the metric 103 | 104 | # Check exists for large molecules where some atom pairs do not have a 105 | # typical bond length. 
106 | if check_exists: 107 | if atom1 not in bonds1: 108 | return 0 109 | if atom2 not in bonds1[atom1]: 110 | return 0 111 | 112 | # margin1, margin2 and margin3 have been tuned to maximize the stability of 113 | # the QM9 true samples. 114 | if distance < bonds1[atom1][atom2] + margin1: 115 | 116 | # Check if atoms in bonds2 dictionary. 117 | if atom1 in bonds2 and atom2 in bonds2[atom1]: 118 | thr_bond2 = bonds2[atom1][atom2] + margin2 119 | if distance < thr_bond2: 120 | if atom1 in bonds3 and atom2 in bonds3[atom1]: 121 | thr_bond3 = bonds3[atom1][atom2] + margin3 122 | if distance < thr_bond3: 123 | return 3 # Triple 124 | return 2 # Double 125 | return 1 # Single 126 | return 0 # No bond 127 | 128 | 129 | def single_bond_only(threshold, length, margin1=5): 130 | if length < threshold + margin1: 131 | return 1 132 | return 0 133 | -------------------------------------------------------------------------------- /reactot/utils/examples/H2O_dissociated.xyz: -------------------------------------------------------------------------------- 1 | 3 2 | 3 | O 0 0 0 4 | H 0.8 0 0 5 | H 10 0 0 -------------------------------------------------------------------------------- /reactot/utils/examples/acetate.xyz: -------------------------------------------------------------------------------- 1 | 7 2 | charge=-1= 3 | C -4.71686 0.89919 0.05714 4 | C -3.24898 0.98400 -0.22830 5 | H -5.04167 1.74384 0.67862 6 | H -5.01710 -0.02205 0.56344 7 | H -5.21076 0.96874 -0.91208 8 | O -2.65909 2.05702 -0.34025 9 | O -2.63413 -0.18702 -0.48679 10 | -------------------------------------------------------------------------------- /reactot/utils/examples/chiral_stereo_test.xyz: -------------------------------------------------------------------------------- 1 | 15 2 | Energy: 10.5637353 3 | C -5.48821 0.02982 -0.00852 4 | C -4.15445 -0.12323 -0.04208 5 | C -3.48273 -1.46491 0.04697 6 | F -3.88123 -2.11120 1.17935 7 | C -1.96681 -1.36452 0.07853 8 | H -3.78257 -2.08264 -0.80658 9 | C 
-6.18988 1.34568 -0.08727 10 | H -6.12260 -0.84989 0.08936 11 | H -3.51606 0.75189 -0.13305 12 | H -5.49066 2.18549 -0.14705 13 | H -6.81679 1.48581 0.79859 14 | H -6.83374 1.37210 -0.97169 15 | H -1.62796 -0.78043 0.94086 16 | H -1.57677 -0.90140 -0.83351 17 | H -1.52787 -2.36296 0.17627 18 | -------------------------------------------------------------------------------- /reactot/utils/examples/ethane.xyz: -------------------------------------------------------------------------------- 1 | 8 2 | charge=0= 3 | C -4.58735 0.92696 0.00000 4 | C -3.11050 0.92696 0.00000 5 | H -4.93786 1.78883 0.58064 6 | H -4.93786 -0.00682 0.45608 7 | H -4.93786 0.99888 -1.03672 8 | H -2.75999 0.85505 1.03672 9 | H -2.75998 1.86075 -0.45608 10 | H -2.75998 0.06509 -0.58064 11 | -------------------------------------------------------------------------------- /reactot/utils/examples/ethane_radical.xyz: -------------------------------------------------------------------------------- 1 | 6 2 | charge=0= 3 | C -4.58735 0.92696 0.00000 4 | C -3.11050 0.92696 0.00000 5 | H -4.93786 1.78883 0.58064 6 | H -4.93786 -0.00682 0.45608 7 | H -4.93786 0.99888 -1.03672 8 | H -2.75999 0.85505 1.03672 9 | -------------------------------------------------------------------------------- /reactot/utils/examples/propylbenzene.xyz: -------------------------------------------------------------------------------- 1 | 20 2 | 3 | C -2.08081073 1.27759366 0.52999704 4 | C -1.36085808 0.01534835 0.13171776 5 | C 0.12921265 -0.00145767 -0.01251015 6 | C 0.89390756 1.16259960 0.22072207 7 | C 2.28529729 1.14285208 0.08499036 8 | C 2.93783862 -0.03314066 -0.28435514 9 | C 2.20046595 -1.19345389 -0.51916959 10 | C 0.80878206 -1.18180595 -0.38553184 11 | C -2.17184071 -1.22963114 -0.11838690 12 | H -1.72431086 1.61348849 1.52614588 13 | H -3.17848660 1.12721396 0.59360457 14 | H -1.88832766 2.07143028 -0.22166901 15 | H 0.42742526 2.09446201 0.50865072 16 | H 2.85855884 2.04284076 0.26700529 17 | H 4.01510494 
-0.04529350 -0.38861905 18 | H 2.70792713 -2.10563507 -0.80577565 19 | H 0.27503723 -2.10238605 -0.57663639 20 | H -1.85660650 -2.02918702 0.58415512 21 | H -2.02061491 -1.57122523 -1.16369726 22 | H -3.25770147 -1.05461302 0.02936218 23 | -------------------------------------------------------------------------------- /reactot/utils/sampling_tools.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | import numpy as np 4 | from reactot.utils import bond_analyze 5 | 6 | 7 | def write_xyz(mol, dataset_info, xyzfile="tmp.xyz"): 8 | atom_decoder = dataset_info['atom_decoder'] 9 | n_atom = mol["atom"].size(0) 10 | with open(xyzfile, "w") as fo: 11 | fo.write(str(n_atom) + "\n\n") 12 | for ii in range(n_atom): 13 | pos = mol["pos"][ii].cpu().numpy() 14 | ele = atom_decoder[mol["atom"][ii]] 15 | _x = " ".join([str(__x) for __x in pos]) 16 | fo.write(f"{ele} {_x}\n") 17 | 18 | 19 | def check_stability(mol, dataset_info, debug=False): 20 | positions, atom_type = mol["pos"], mol["atom"] 21 | assert len(positions.shape) == 2 22 | assert positions.shape[1] == 3 23 | atom_decoder = dataset_info['atom_decoder'] 24 | x = positions[:, 0] 25 | y = positions[:, 1] 26 | z = positions[:, 2] 27 | 28 | nr_bonds = np.zeros(len(x), dtype='int') 29 | 30 | for i in range(len(x)): 31 | for j in range(i + 1, len(x)): 32 | p1 = np.array([x[i], y[i], z[i]]) 33 | p2 = np.array([x[j], y[j], z[j]]) 34 | dist = np.sqrt(np.sum((p1 - p2) ** 2)) 35 | atom1, atom2 = atom_decoder[atom_type[i]], atom_decoder[atom_type[j]] 36 | pair = sorted([atom_type[i], atom_type[j]]) 37 | if dataset_info['name'] == 'qm9': 38 | order = bond_analyze.get_bond_order(atom1, atom2, dist) 39 | else: 40 | raise KeyError("only qm9 is allowed!") 41 | # if i == 3 or j == 3: 42 | # print(i, j, dist, order) 43 | nr_bonds[i] += order 44 | nr_bonds[j] += order 45 | nr_stable_bonds = 0 46 | for atom_type_i, nr_bonds_i in zip(atom_type, nr_bonds): 47 | 
possible_bonds = bond_analyze.allowed_bonds[atom_decoder[atom_type_i]] 48 | # print(atom_decoder[atom_type_i], nr_bonds_i) 49 | if type(possible_bonds) == int: 50 | is_stable = possible_bonds >= nr_bonds_i 51 | else: 52 | is_stable = nr_bonds_i in possible_bonds 53 | if not is_stable and debug: 54 | print("Invalid bonds for molecule %s with %d bonds" % (atom_decoder[atom_type_i], nr_bonds_i)) 55 | nr_stable_bonds += int(is_stable) 56 | 57 | molecule_stable = nr_stable_bonds == len(x) 58 | return int(molecule_stable), nr_stable_bonds, len(x) 59 | 60 | 61 | def assemble_sample_inputs( 62 | atoms: List, 63 | device: torch.device = torch.device("cuda"), 64 | n_samples: int = 1, 65 | frag_type: bool = False, 66 | ): 67 | empty_site = torch.tensor([[1, 0, 0, 0, 0, 1]], device=device) 68 | if not frag_type: 69 | decoders = [ 70 | { 71 | "H": [1, 0, 0, 0, 0, 1], 72 | "C": [0, 1, 0, 0, 0, 6], 73 | "N": [0, 0, 1, 0, 0, 7], 74 | "O": [0, 0, 0, 1, 0, 8], 75 | "F": [0, 0, 0, 0, 1, 9] 76 | } 77 | ] * 2 78 | else: 79 | decoders = [ 80 | { 81 | "H": [1, 0, 0, 0, 0, 1, 0], 82 | "C": [0, 1, 0, 0, 0, 6, 0], 83 | "N": [0, 0, 1, 0, 0, 7, 0], 84 | "O": [0, 0, 0, 1, 0, 8, 0], 85 | "F": [0, 0, 0, 0, 1, 9, 0] 86 | }, 87 | { 88 | "H": [1, 0, 0, 0, 0, 1, 1], 89 | "C": [0, 1, 0, 0, 0, 6, 1], 90 | "N": [0, 0, 1, 0, 0, 7, 1], 91 | "O": [0, 0, 0, 1, 0, 8, 1], 92 | "F": [0, 0, 0, 0, 1, 9, 1] 93 | } 94 | ] 95 | 96 | h0 = [ 97 | torch.cat( 98 | [ 99 | torch.tensor( 100 | [decoders[ii % 2][atom] for atom in atoms], 101 | device=device 102 | ) 103 | for _ in range(n_samples) 104 | ] 105 | ) for ii in range(3) 106 | ] 107 | return h0 108 | 109 | 110 | def write_single_xyz(xyzfile, natoms, out): 111 | C2A = { 112 | 1: "H", 113 | 2: "He", 114 | 3: "Li", 115 | 4: "Be", 116 | 6: "C", 117 | 7: "N", 118 | 8: "O", 119 | 9: "F", 120 | 3: "Li", 121 | 12: "Mg", 122 | 14: "Si", 123 | 15: "P", 124 | 16: "S", 125 | 17: "Cl", 126 | 25: "Mn", 127 | 26: "Fe", 128 | 27: "Co", 129 | 28: "Ni", 130 | 29: "Cu", 131 | 30: 
"Zn", 132 | 47: "Ag", 133 | 48: "Cd", 134 | 40: "Zr", 135 | 72: "Hf", 136 | } 137 | with open(xyzfile, "w") as fo: 138 | fo.write(str(natoms) + "\n\n") 139 | for ele in out: 140 | x = ele[:3].cpu().numpy() 141 | _a = C2A[ele[-1].long().item()] 142 | _x = " ".join([str(__x) for __x in x]) 143 | fo.write(f"{_a} {_x}\n") 144 | 145 | 146 | def write_tmp_xyz(fragments_nodes, out_samples, idx=[0], prefix="gen", localpath="tmp", ex_ind=0): 147 | TYPEMAP = { 148 | 0: "react", 149 | 1: "ts", 150 | 2: "prod", 151 | 152 | } 153 | for ii in idx: 154 | st = TYPEMAP[ii] 155 | start_ind, end_ind = 0, 0 156 | for jj, natoms in enumerate(fragments_nodes[0]): 157 | _jj = jj + ex_ind 158 | 159 | # xyzfile = f"{localpath}/{_jj}.xyz" 160 | xyzfile = f"{localpath}/{prefix}_{_jj}_{st}.xyz" 161 | 162 | end_ind += natoms.item() 163 | write_single_xyz( 164 | xyzfile, 165 | natoms.item(), 166 | out=out_samples[ii][start_ind: end_ind], 167 | ) 168 | start_ind = end_ind 169 | -------------------------------------------------------------------------------- /reactot/utils/training_tools.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Iterable 2 | import numpy as np 3 | import torch 4 | 5 | 6 | class Queue(): 7 | def __init__(self, max_len=50): 8 | self.items = [] 9 | self.max_len = max_len 10 | 11 | def __len__(self): 12 | return len(self.items) 13 | 14 | def add(self, item): 15 | self.items.insert(0, item) 16 | if len(self) > self.max_len: 17 | self.items.pop() 18 | 19 | def mean(self): 20 | return np.mean(self.items) 21 | 22 | def std(self): 23 | return np.std(self.items) 24 | 25 | 26 | ##### 27 | 28 | 29 | def get_grad_norm( 30 | parameters: Union[torch.Tensor, Iterable[torch.Tensor]], 31 | norm_type: float = 2.0 32 | ) -> torch.Tensor: 33 | """ 34 | Adapted from: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_ 35 | """ 36 | 37 | if isinstance(parameters, torch.Tensor): 38 | parameters = 
[parameters] 39 | parameters = [p for p in parameters if p.grad is not None] 40 | 41 | norm_type = float(norm_type) 42 | 43 | if len(parameters) == 0: 44 | return torch.tensor(0.) 45 | 46 | device = parameters[0].grad.device 47 | 48 | total_norm = torch.norm(torch.stack( 49 | [torch.norm(p.grad.detach(), norm_type).to(device) for p in 50 | parameters]), norm_type) 51 | 52 | return total_norm 53 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tos==2.8.1 2 | ase==3.23.0 3 | pandas==2.2.3 4 | numpy==1.26.4 5 | ipdb==0.13.13 6 | pymatgen==2024.11.13 7 | requests 8 | kafka-python==2.0.2 -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Helper file to handle all configs 2 | 3 | [coverage:run] 4 | # .coveragerc to control coverage.py and pytest-cov 5 | omit = 6 | # Omit the tests 7 | */tests/* 8 | # Omit generated versioneer 9 | mof_diffusion/_version.py 10 | 11 | [yapf] 12 | # YAPF, in .style.yapf files this shows up as "[style]" header 13 | COLUMN_LIMIT = 119 14 | INDENT_WIDTH = 4 15 | USE_TABS = False 16 | 17 | [flake8] 18 | # Flake8, PyFlakes, etc 19 | max-line-length = 119 20 | 21 | [aliases] 22 | test = pytest 23 | --------------------------------------------------------------------------------