├── .github
│   └── workflows
│       └── codeql.yml
├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── SECURITY.md
├── docs
│   ├── ai2bmd.md
│   └── fragmentation.md
├── examples
│   ├── abd.pdb
│   ├── chig.pdb
│   ├── chig_preprocessed
│   │   ├── chig-preeq-nowat.pdb
│   │   └── chig-preeq.pdb
│   ├── trpcage.pdb
│   └── ww.pdb
├── scripts
│   └── ai2bmd
└── src
    ├── AIMD
    │   ├── __init__.py
    │   ├── arguments.py
    │   ├── fragment.py
    │   ├── preprocess.py
    │   ├── protein.py
    │   └── simulator.py
    ├── Calculators
    │   ├── __init__.py
    │   ├── async_utils.py
    │   ├── bonded.py
    │   ├── calculator.py
    │   ├── combiner.py
    │   ├── device_strategy.py
    │   ├── fragment.py
    │   ├── nonbonded.py
    │   ├── pme.py
    │   ├── qmmm.py
    │   ├── tinker_async.py
    │   └── visnet_calculator.py
    ├── Fragmentation
    │   ├── __init__.py
    │   ├── basefrag.py
    │   ├── bondlen
    │   │   ├── AA.npz
    │   │   ├── AN.npz
    │   │   ├── ANAN.npz
    │   │   ├── CC.npz
    │   │   ├── DD.npz
    │   │   ├── EE.npz
    │   │   ├── FF.npz
    │   │   ├── GG.npz
    │   │   ├── HH.npz
    │   │   ├── HID.npz
    │   │   ├── II.npz
    │   │   ├── KK.npz
    │   │   ├── LL.npz
    │   │   ├── MM.npz
    │   │   ├── NN.npz
    │   │   ├── PP.npz
    │   │   ├── QQ.npz
    │   │   ├── RR.npz
    │   │   ├── SS.npz
    │   │   ├── TT.npz
    │   │   ├── VV.npz
    │   │   ├── WW.npz
    │   │   └── YY.npz
    │   ├── distancefrag.py
    │   ├── hydrogen
    │   │   ├── __init__.py
    │   │   ├── ctable.py
    │   │   ├── energies.py
    │   │   └── topology.py
    │   ├── mm.in
    │   └── prmtop
    │       ├── AA.prmtop
    │       ├── AN.prmtop
    │       ├── ANAN.prmtop
    │       ├── CC.prmtop
    │       ├── CYX.prmtop
    │       ├── DD.prmtop
    │       ├── EE.prmtop
    │       ├── FF.prmtop
    │       ├── GG.prmtop
    │       ├── HH.prmtop
    │       ├── HID.prmtop
    │       ├── II.prmtop
    │       ├── KK.prmtop
    │       ├── LL.prmtop
    │       ├── MM.prmtop
    │       ├── NN.prmtop
    │       ├── PP.prmtop
    │       ├── QQ.prmtop
    │       ├── RR.prmtop
    │       ├── SS.prmtop
    │       ├── TT.prmtop
    │       ├── VV.prmtop
    │       ├── WW.prmtop
    │       └── YY.prmtop
    ├── ViSNet
    │   ├── __init__.py
    │   ├── checkpoints
    │   │   ├── visnet-uni-2ef43f29ec78fa5fef0b3de832bfada9.ckpt
    │   │   └── visnet-uni-de11d1421ccda37ffab07d7403c8f5bb.ckpt
    │   └── model
    │       ├── __init__.py
    │       ├── output_modules.py
    │       ├── priors.py
    │       ├── utils.py
    │       ├── visnet.py
    │       └── visnet_block.py
    ├── main.py
    └── utils
        ├── amoebabio18.prm
        ├── pdb.py
        ├── reference.py
        ├── seq_dict.pkl
        ├── signals.py
        ├── system.py
        ├── traj2dcd.py
        └── utils.py

/.github/workflows/codeql.yml:
--------------------------------------------------------------------------------
1 | # For most projects, this workflow file will not need changing; you simply need
2 | # to commit it to your repository.
3 | #
4 | # You may wish to alter this file to override the set of languages analyzed,
5 | # or to provide custom queries or build logic.
6 | #
7 | # ******** NOTE ********
8 | # We have attempted to detect the languages in your repository. Please check
9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL"
13 | 
14 | on:
15 |   push:
16 |     branches: [ "main" ]
17 |   pull_request:
18 |     # The branches below must be a subset of the branches above
19 |     branches: [ "main" ]
20 |   schedule:
21 |     - cron: '24 14 * * 4'
22 | 
23 | jobs:
24 |   analyze:
25 |     name: Analyze
26 |     runs-on: ubuntu-latest
27 |     permissions:
28 |       actions: read
29 |       contents: read
30 |       security-events: write
31 | 
32 |     strategy:
33 |       fail-fast: false
34 |       matrix:
35 |         language: ['python']
36 |         # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
37 |         # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
38 | 
39 |     steps:
40 |     - name: Checkout repository
41 |       uses: actions/checkout@v3
42 | 
43 |     # Initializes the CodeQL tools for scanning.
44 |     - name: Initialize CodeQL
45 |       uses: github/codeql-action/init@v2
46 |       with:
47 |         languages: ${{ matrix.language }}
48 |         # If you wish to specify custom queries, you can do so here or in a config file.
49 |         # By default, queries listed here will override any specified in a config file.
50 |         # Prefix the list here with "+" to use these queries and those in the config file.
51 | 
52 |         # For details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
53 |         # queries: security-extended,security-and-quality
54 | 
55 | 
56 |     # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java).
57 |     # If this step fails, then you should remove it and run the build manually (see below)
58 |     - name: Autobuild
59 |       uses: github/codeql-action/autobuild@v2
60 | 
61 |     # ℹ️ Command-line programs to run using the OS shell.
62 |     # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
63 | 
64 |     # If the Autobuild fails above, remove it and uncomment the following three lines,
65 |     # then modify them (or add more) to build your code if your project requires it; please refer to the EXAMPLE below for guidance.
66 | 
67 |     # - run: |
68 |     #     echo "Run, Build Application using script"
69 |     #     ./location_of_script_within_repo/buildscript.sh
70 | 
71 |     - name: Perform CodeQL Analysis
72 |       uses: github/codeql-action/analyze@v2
73 |       with:
74 |         category: "/language:${{matrix.language}}"
75 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Aa][Rr][Mm]/ 27 | [Aa][Rr][Mm]64/ 28 | bld/ 29 | [Bb]in/ 30 | [Oo]bj/ 31 | [Ll]og/ 32 | [Ll]ogs/ 33 | 34 | # Visual Studio 2015/2017 cache/options directory 35 | .vs/ 36 | # Uncomment if you have tasks that create the project's static files in wwwroot 37 | #wwwroot/ 38 | 39 | # Visual Studio 2017 auto generated files 40 | Generated\ Files/ 41 | 42 | # MSTest test Results 43 | [Tt]est[Rr]esult*/ 44 | [Bb]uild[Ll]og.* 45 | 46 | # NUnit 47 | *.VisualState.xml 48 | TestResult.xml 49 | nunit-*.xml 50 | 51 | # Build Results of an ATL Project 52 | [Dd]ebugPS/ 53 | [Rr]eleasePS/ 54 | dlldata.c 55 | 56 | # Benchmark Results 57 | BenchmarkDotNet.Artifacts/ 58 | 59 | # .NET Core 60 | project.lock.json 61 | project.fragment.lock.json 62 | artifacts/ 63 | 64 | # StyleCop 65 | StyleCopReport.xml 66 | 67 | # Files built by Visual Studio 68 | *_i.c 69 | *_p.c 70 | *_h.h 71 | *.ilk 72 | *.meta 73 | *.obj 74 | *.iobj 75 | *.pch 76 | *.pdb 77 | *.ipdb 78 | *.pgc 79 | *.pgd 80 | *.rsp 81 | *.sbr 82 | *.tlb 83 | *.tli 84 | *.tlh 85 | *.tmp 86 | *.tmp_proj 87 | *_wpftmp.csproj 88 | *.log 89 | *.vspscc 90 | *.vssscc 91 | .builds 92 | *.pidb 93 | *.svclog 94 | *.scc 95 | 96 | # Chutzpah Test files 97 | _Chutzpah* 98 | 99 | # Visual C++ cache files 100 | ipch/ 101 | *.aps 102 | *.ncb 103 | *.opendb 104 | *.opensdf 105 | *.sdf 106 | *.cachefile 107 | *.VC.db 108 | *.VC.VC.opendb 109 | 110 | # Visual Studio profiler 111 | *.psess 112 | *.vsp 113 | *.vspx 114 | *.sap 115 | 116 | # Visual Studio Trace Files 117 | *.e2e 118 | 119 | # TFS 2012 Local Workspace 120 | $tf/ 121 | 122 | # Guidance Automation Toolkit 123 | *.gpState 124 | 125 | # ReSharper is a .NET coding add-in 126 | _ReSharper*/ 127 | *.[Rr]e[Ss]harper 128 | *.DotSettings.user 129 | 130 | # TeamCity is a build add-in 131 | _TeamCity* 132 | 133 | # DotCover is a Code Coverage Tool 134 | *.dotCover 135 | 136 | # AxoCover is a Code Coverage Tool 137 | .axoCover/* 138 | !.axoCover/settings.json 139 | 140 | # Visual Studio code coverage results 141 | *.coverage 142 | *.coveragexml 143 | 144 | # NCrunch 145 | _NCrunch_* 146 | .*crunch*.local.xml 147 | nCrunchTemp_* 148 | 149 | # MightyMoose 150 | *.mm.* 151 | AutoTest.Net/ 152 | 153 | # Web workbench (sass) 154 | .sass-cache/ 155 | 156 | # Installshield output folder 157 | [Ee]xpress/ 158 | 159 | # DocProject is a documentation generator add-in 160 | DocProject/buildhelp/ 161 | DocProject/Help/*.HxT 162 | DocProject/Help/*.HxC 163 | DocProject/Help/*.hhc 164 | DocProject/Help/*.hhk 165 | DocProject/Help/*.hhp 166 | DocProject/Help/Html2 167 | DocProject/Help/html 168 | 169 | # Click-Once directory 170 | publish/ 171 | 172 | # Publish Web Output 173 | *.[Pp]ublish.xml 174 | *.azurePubxml 175 | # Note: Comment the next line if you want to checkin your web deploy settings, 176 | # but database connection strings (with potential passwords) will be unencrypted 177 | *.pubxml 178 | *.publishproj 179 | 180 | # Microsoft Azure Web App publish settings. 
Comment the next line if you want to 181 | # checkin your Azure Web App publish settings, but sensitive information contained 182 | # in these scripts will be unencrypted 183 | PublishScripts/ 184 | 185 | # NuGet Packages 186 | *.nupkg 187 | # NuGet Symbol Packages 188 | *.snupkg 189 | # The packages folder can be ignored because of Package Restore 190 | **/[Pp]ackages/* 191 | # except build/, which is used as an MSBuild target. 192 | !**/[Pp]ackages/build/ 193 | # Uncomment if necessary however generally it will be regenerated when needed 194 | #!**/[Pp]ackages/repositories.config 195 | # NuGet v3's project.json files produces more ignorable files 196 | *.nuget.props 197 | *.nuget.targets 198 | 199 | # Microsoft Azure Build Output 200 | csx/ 201 | *.build.csdef 202 | 203 | # Microsoft Azure Emulator 204 | ecf/ 205 | rcf/ 206 | 207 | # Windows Store app package directories and files 208 | AppPackages/ 209 | BundleArtifacts/ 210 | Package.StoreAssociation.xml 211 | _pkginfo.txt 212 | *.appx 213 | *.appxbundle 214 | *.appxupload 215 | 216 | # Visual Studio cache files 217 | # files ending in .cache can be ignored 218 | *.[Cc]ache 219 | # but keep track of directories ending in .cache 220 | !?*.[Cc]ache/ 221 | 222 | # Others 223 | ClientBin/ 224 | ~$* 225 | *~ 226 | *.dbmdl 227 | *.dbproj.schemaview 228 | *.jfm 229 | *.pfx 230 | *.publishsettings 231 | orleans.codegen.cs 232 | 233 | # Including strong name files can present a security risk 234 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 235 | #*.snk 236 | 237 | # Since there are multiple workflows, uncomment next line to ignore bower_components 238 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 239 | #bower_components/ 240 | 241 | # RIA/Silverlight projects 242 | Generated_Code/ 243 | 244 | # Backup & report files from converting an old project file 245 | # to a newer Visual Studio version. Backup files are not needed, 246 | # because we have git ;-) 247 | _UpgradeReport_Files/ 248 | Backup*/ 249 | UpgradeLog*.XML 250 | UpgradeLog*.htm 251 | ServiceFabricBackup/ 252 | *.rptproj.bak 253 | 254 | # SQL Server files 255 | *.mdf 256 | *.ldf 257 | *.ndf 258 | 259 | # Business Intelligence projects 260 | *.rdl.data 261 | *.bim.layout 262 | *.bim_*.settings 263 | *.rptproj.rsuser 264 | *- [Bb]ackup.rdl 265 | *- [Bb]ackup ([0-9]).rdl 266 | *- [Bb]ackup ([0-9][0-9]).rdl 267 | 268 | # Microsoft Fakes 269 | FakesAssemblies/ 270 | 271 | # GhostDoc plugin setting file 272 | *.GhostDoc.xml 273 | 274 | # Node.js Tools for Visual Studio 275 | .ntvs_analysis.dat 276 | node_modules/ 277 | 278 | # Visual Studio 6 build log 279 | *.plg 280 | 281 | # Visual Studio 6 workspace options file 282 | *.opt 283 | 284 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 
285 | *.vbw 286 | 287 | # Visual Studio LightSwitch build output 288 | **/*.HTMLClient/GeneratedArtifacts 289 | **/*.DesktopClient/GeneratedArtifacts 290 | **/*.DesktopClient/ModelManifest.xml 291 | **/*.Server/GeneratedArtifacts 292 | **/*.Server/ModelManifest.xml 293 | _Pvt_Extensions 294 | 295 | # Paket dependency manager 296 | .paket/paket.exe 297 | paket-files/ 298 | 299 | # FAKE - F# Make 300 | .fake/ 301 | 302 | # CodeRush personal settings 303 | .cr/personal 304 | 305 | # Python Tools for Visual Studio (PTVS) 306 | __pycache__/ 307 | *.pyc 308 | 309 | # Cake - Uncomment if you are using it 310 | # tools/** 311 | # !tools/packages.config 312 | 313 | # Tabs Studio 314 | *.tss 315 | 316 | # Telerik's JustMock configuration file 317 | *.jmconfig 318 | 319 | # BizTalk build output 320 | *.btp.cs 321 | *.btm.cs 322 | *.odx.cs 323 | *.xsd.cs 324 | 325 | # OpenCover UI analysis results 326 | OpenCover/ 327 | 328 | # Azure Stream Analytics local run output 329 | ASALocalRun/ 330 | 331 | # MSBuild Binary and Structured Log 332 | *.binlog 333 | 334 | # NVidia Nsight GPU debugger configuration file 335 | *.nvuser 336 | 337 | # MFractors (Xamarin productivity tool) working folder 338 | .mfractor/ 339 | 340 | # Local History for Visual Studio 341 | .localhistory/ 342 | 343 | # BeatPulse healthcheck temp database 344 | healthchecksdb 345 | 346 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 347 | MigrationBackup/ 348 | 349 | # Ionide (cross platform F# VS Code tools) working folder 350 | .ionide/ 351 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | ## [1.1.0] - 2025-02-18 6 | ### Added 7 | - new checkpoint with support for cystine residues 8 | 9 | ### Fixed 10 | - fix UB bug that arises with consecutive cystine residues 11 | 12 | ## [1.0.0] - 2025-02-07 13 | ### Added 14 | - option to exclude solvent in output (`--[no-]write-solvent`) 15 | - check number of residues in input file; fail gracefully with helpful message 16 | 17 | ### Changed 18 | - avoid usage of deprecated function (move socket file for async servers) 19 | - sanitise logging; `--verbose` controls all output 20 | 21 | ## [0.1.0] - 2024-11-07 22 | ### Added 23 | - initial public release 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## Security
4 | 
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6 | 
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8 | 
9 | ## Reporting Security Issues
10 | 
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 | 
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14 | 
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16 | 
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 | 
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 | 
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 | 
29 | This information will help us triage your report more quickly.
30 | 
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32 | 
33 | ## Preferred Languages
34 | 
35 | We prefer all communications to be in English.
36 | 
37 | ## Policy
38 | 
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 | 
41 | 
42 | 
--------------------------------------------------------------------------------
/docs/ai2bmd.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/docs/ai2bmd.md
--------------------------------------------------------------------------------
/docs/fragmentation.md:
--------------------------------------------------------------------------------
1 | # Fragmentation
2 | 
3 | In AI2BMD, the input protein is fragmented into a number of dipeptides and
4 | ACE-NMEs, which are fed to the ViSNet model. In general, each dipeptide
5 | fragment consists of the backbone of the previous amino acid residue, all atoms
6 | of the current amino acid residue, and the backbone of the next amino acid
7 | residue, while each ACE-NME fragment consists of the acetyl group of an amino
8 | acid and the N-methyl amide group of the previous amino acid.
9 | 
10 | ## Input
11 | 
12 | The fragmentation process takes a protein as input, in the form of a PDB file.
13 | This file contains one record per atom. The columns are, from left to right,
14 | record type, atom index, atom name, residue type, residue index, Cartesian
15 | coordinates, occupancy, temperature factor, and atom type.
16 | 
17 | ```
18 | ATOM      1  H1  ACE     1      10.845   8.614   5.964  1.00  0.00           H
19 | ATOM      2  CH3 ACE     1      10.143   9.373   5.620  1.00  0.00           C
20 | ATOM      3  H2  ACE     1       9.425   9.446   6.437  1.00  0.00           H
21 | ATOM      4  H3  ACE     1       9.643   9.085   4.695  1.00  0.00           H
22 | ATOM      5  C   ACE     1      10.805  10.740   5.408  1.00  0.00           C
23 | ATOM      6  O   ACE     1      10.682  11.417   4.442  1.00  0.00           O
24 | ATOM      7  N   TYR     2      11.363  11.214   6.507  1.00  0.00           N
25 | ATOM      8  H   TYR     2      11.345  10.675   7.361  1.00  0.00           H
26 | ATOM      9  CA  TYR     2      11.963  12.513   6.704  1.00  0.00           C
27 | ATOM     10  HA  TYR     2      12.444  12.615   5.731  1.00  0.00           H
28 | ATOM     11  CB  TYR     2      10.909  13.631   6.882  1.00  0.00           C
29 | ATOM     12  HB2 TYR     2      10.230  13.593   6.030  1.00  0.00           H
30 | ATOM     13  HB3 TYR     2      10.264  13.332   7.708  1.00  0.00           H
31 | ATOM     14  CG  TYR     2      11.383  15.023   7.089  1.00  0.00           C
32 | ATOM     15  CD1 TYR     2      11.729  15.440   8.404  1.00  0.00           C
33 | ATOM     16  HD1 TYR     2      11.527  14.718   9.182  1.00  0.00           H
34 | ATOM     17  CE1 TYR     2      12.033  16.805   8.578  1.00  0.00           C
35 | ATOM     18  HE1 TYR     2      12.380  17.146   9.543  1.00  0.00           H
36 | ATOM     19  CZ  TYR     2      12.136  17.704   7.498  1.00  0.00           C
37 | ATOM     20  OH  TYR     2      12.556  19.002   7.619  1.00  0.00           O
38 | ATOM     21  HH  TYR     2      12.437  19.327   8.514  1.00  0.00           H
39 | ATOM     22  CE2 TYR     2      11.872  17.214   6.251  1.00  0.00           C
40 | ATOM     23  HE2 TYR     2      12.034  17.842   5.387  1.00  0.00           H
41 | ATOM     24  CD2 TYR     2      11.440  15.923   6.033  1.00  0.00           C
42 | ATOM     25  HD2 TYR     2      11.349  15.559   5.021  1.00  0.00           H
43 | ATOM     26  C   TYR     2      13.070  12.496   7.702  1.00  0.00           C
44 | ATOM     27  O   TYR     2      12.959  11.951   8.788  1.00  0.00           O
45 | ...
46 | ```
47 | 
48 | ## Extracting indices for each fragment
49 | 
50 | The first step is to extract the indices of the atoms that belong to each
51 | dipeptide and ACE-NME fragment (implementation in
52 | `src/Fragmentation/basefrag.py:get_fragments_index`). We loop through the list
53 | of atoms in order, and insert the atom into the list of dipeptides and ACE-NMEs
54 | as appropriate. For example, the `CA` and `HA` atoms belong in the dipeptide
55 | and both ACE/NME fragments, while the `C` and `O` atoms belong only to the
56 | dipeptide and ACE (acetyl group) part of the fragment. The side-chain of each
57 | amino acid is present in only one of the dipeptides. Care is taken to calculate
58 | the correct dipeptide/ACE/NME index that each atom belongs to, and that edge
59 | cases (first/last dipeptides/ACE-NMEs) are handled properly.
60 | 
61 | The following diagram illustrates the division of atoms.
62 | 
63 | ```
64 |     H   O   H   R   O   H   R   O   H   R   O   H   R   O
65 |     |   "   |   |   "   |   |   "   |   |   "   |   |   "
66 | H - C - C - N - C - C - N - C - C - N - C - C - N - C - C - ...
67 |     |           |           |           |           |
68 |     H           H           H           H           H
69 | |_________| |_________| |_________| |_________| |_________|
70 |     ACE         TYR         TYR         ASP         PRO
71 | 
72 | |------- 1st dipeptide -------|
73 |             |----- 2nd dipeptide -----|
74 |                         |----- 3rd dipeptide -----|
75 |                                     |----- 4th dipeptide -----|
76 | 
77 |       |-------------|
78 |            1st    |-------------|
79 |          ACE-NME       2nd    |-------------|
80 |                      ACE-NME       3rd
81 |                                  ACE-NME
82 | ```
83 | 
84 | ## Capping dipeptides with hydrogens
85 | 
86 | The dipeptide and ACE-NME fragments are extracted directly from the full amino
87 | acid chain, and thus have dangling bonds at the edges, which is inconsistent
88 | with their original conditions. Hence, additional hydrogen atoms are added
89 | where necessary to avoid artificial interactions resulting from the incomplete
90 | covalent bonds.
91 | 
92 | As a first approximation, the hydrogens are added along the same direction as
93 | the original bond, with the bond length modified to be the sum of the average
94 | radii of the two atoms. The coordinates of the added hydrogen atoms are later
95 | optimised to relax the system. Since the ACE-NME fragments overlap entirely
96 | with the dipeptide fragments, the optimisation is performed only for the
97 | dipeptide fragments.
98 | 
99 | ```
100 |       R   O   H   R   O   H   R   O   H   R   O   H   R   O
101 |       |   "   |   |   "   |   |   "   |   |   "   |   |   "
102 | ... - C - C - N - C - C - N - C - C - N - C - C - N - C - C - ...
103 |       |           |           |           |           |
104 |       H           H           H           H           H
105 |               |------- dipeptide -------|
106 | 
107 |     *H   O   H   R   O   H   H*
108 |      |   "   |   |   "   |   |
109 |     *H - C - C - N - C - C - N - C - H*
110 |      |           |           |
111 |      H           H           H
112 | ```
113 | 
114 | In the example above, the dipeptide fragment has four broken bonds: two on each
115 | of the alpha-carbons of the first and last residues, resulting from the removal
116 | of the side-chains and the neighbouring amide and carboxyl groups.
117 | 
118 | The number of added hydrogen atoms depends on the specific amino acid residues
119 | that make up the fragments. For example, glycine (GLY) already has a single
120 | hydrogen atom as its side-chain, and thus does not require a hydrogen atom to
121 | be added.
122 | 
123 | ```
124 | CA C O HA N CA C O H HA CB CG OD1 OD2 HB2 HB3 N CA HA | [H*] [H*] [H*] [H*] [H*]
125 | 
126 | CA C [H*] [H*] O HA N CA C O H HA CB CG OD1 OD2 HB2 HB3 N [H*] CA HA [H*] [H*]
127 | ```
128 | 
129 | ## Reordering atoms
130 | 
131 | The atoms in the dipeptide fragments have to be reordered to match the format
132 | that AMBER expects. This reordering of indices is implemented using a
133 | precomputed lookup table (`src/utils/seq_dict.pkl`) based on the three residue
134 | types that make up the dipeptide fragment.
135 | 
136 | ```
137 | CA C [H*] [H*] O HA N CA C O H HA CB CG OD1 OD2 HB2 HB3 N [H*] CA HA [H*] [H*]
138 | 
139 | HA CA [H*] [H*] C O N H CA HA CB HB2 HB3 CG OD1 OD2 C O N [H*] CA HA [H*] [H*]
140 | ```
141 | 
142 | In our implementation, the insertion of hydrogen atoms and reordering of the
143 | dipeptide is performed in a single step, as sketched below.
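To make that combined step concrete, here is a minimal sketch of the capping rule and the table-driven reordering. It is illustrative only: the helper names (`cap_position`, `build_dipeptide`), the radius table, and the argument layout are hypothetical, not the actual API of `src/Fragmentation/hydrogen/` or `basefrag.py`.

```python
import numpy as np

# Illustrative average atomic radii in Å (assumed values, not the
# constants shipped with the repository).
RADII = {"C": 0.77, "N": 0.70, "H": 0.32}

def cap_position(heavy_pos, removed_pos, heavy_elem):
    """Place a capping hydrogen along the broken-bond direction, at a
    distance equal to the sum of the two average atomic radii."""
    direction = removed_pos - heavy_pos
    direction /= np.linalg.norm(direction)
    return heavy_pos + (RADII[heavy_elem] + RADII["H"]) * direction

def build_dipeptide(positions, elements, keep_idx, broken_bonds, order):
    """Cap and reorder one dipeptide fragment in a single pass.

    keep_idx     -- indices of the atoms extracted from the full chain
    broken_bonds -- (kept_atom, removed_neighbour) index pairs
    order        -- permutation from the precomputed lookup table
                    (cf. src/utils/seq_dict.pkl)
    """
    pos = [positions[i] for i in keep_idx]
    for kept, removed in broken_bonds:  # append the H* caps
        pos.append(cap_position(positions[kept], positions[removed],
                                elements[kept]))
    return np.asarray(pos)[order]  # apply the AMBER atom order
```

The listing below shows the effect of this combined step on a concrete fragment: the first line is the raw extraction order with the capping hydrogens appended at the end, and the second line is the final capped and reordered fragment.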
144 | 145 | ``` 146 | CA C O HA N CA C O H HA CB CG OD1 OD2 HB2 HB3 N CA HA | [H*] [H*] [H*] [H*] [H*] 147 | 148 | HA CA [H*] [H*] C O N H CA HA CB HB2 HB3 CG OD1 OD2 C O N [H*] CA HA [H*] [H*] 149 | ``` 150 | 151 | The ACE-NME fragments can then be obtained by extracting the first and last six 152 | atoms of the appropriate dipeptides, e.g. the 1st ACE-NME fragment consists of 153 | the first six atoms of the 2nd dipeptide and the last six atoms of the 1st 154 | dipeptide. 155 | -------------------------------------------------------------------------------- /src/AIMD/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/AIMD/__init__.py -------------------------------------------------------------------------------- /src/AIMD/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | 6 | from Calculators.device_strategy import DeviceStrategy 7 | from utils.utils import src_dir 8 | 9 | 10 | _args = None 11 | 12 | 13 | def get(): 14 | if not _args: 15 | raise Exception("Arguments are not initialized. Call initialize() first.") 16 | return _args 17 | 18 | 19 | def init(argv=None): 20 | """Initializes the argument registry. If no argv is supplied (default), 21 | parses arguments from process command line. 22 | The initialization result will be kept in the module-level member `_args`, 23 | so that the settings can be retrieved from other modules with get(). 24 | """ 25 | global _args 26 | 27 | _src_dir = src_dir() 28 | parser = argparse.ArgumentParser(description="DL Molecular Simulation.") 29 | parser.add_argument( 30 | "--base-dir", 31 | type=str, 32 | default=os.getcwd(), 33 | help="A directory for running simulation", 34 | ) 35 | parser.add_argument( 36 | "--log-dir", 37 | type=str, 38 | default=None, 39 | help="A directory for saving results", 40 | ) 41 | parser.add_argument( 42 | "--ckpt-path", 43 | type=str, 44 | default=os.path.join(_src_dir, "ViSNet/checkpoints"), 45 | help="A directory including well-trained pytorch models", 46 | ) 47 | parser.add_argument( 48 | "--ckpt-type", 49 | type=str, 50 | default="2ef43f29ec78fa5fef0b3de832bfada9", 51 | choices=[ 52 | "2ef43f29ec78fa5fef0b3de832bfada9", 53 | "de11d1421ccda37ffab07d7403c8f5bb", 54 | ], 55 | help="Checkpoint type, which is the md5sum of the model checkpoint file", 56 | ) 57 | parser.add_argument( 58 | "--prot-file", 59 | type=str, 60 | required=True, 61 | help="Protein file for simulation", 62 | ) 63 | parser.add_argument( 64 | "--temp-k", 65 | type=int, 66 | default=300, 67 | help="Simulation temperature in Kelvin", 68 | ) 69 | parser.add_argument( 70 | "--timestep", 71 | type=float, 72 | default=1, 73 | help="TimeStep (fs) for simulation", 74 | ) 75 | parser.add_argument( 76 | "--sim-steps", 77 | type=int, 78 | default=1000, 79 | help="Simulation steps for simulation", 80 | ) 81 | parser.add_argument( 82 | "--preeq-steps", 83 | type=int, 84 | default=2000, 85 | help="Pre-equilibration simulation steps for each constraint", 86 | ) 87 | parser.add_argument( 88 | "--max-cyc", 89 | type=int, 90 | default=100, 91 | help="Maximum energy minimization cycles in preprocessing", 92 | ) 93 | parser.add_argument( 94 | "--constraints", 95 | action=argparse.BooleanOptionalAction, 96 | default=False, 97 | help="Constrain hydrogen bonds", 98 | ) 99 | parser.add_argument( 100 | "--solvent", 101 | 
action=argparse.BooleanOptionalAction, 102 | default=True, 103 | help="Use solvent or not", 104 | ) 105 | parser.add_argument( 106 | "--write-solvent", 107 | action=argparse.BooleanOptionalAction, 108 | default=True, 109 | help="Write coordinates of solvent atoms in output", 110 | ) 111 | parser.add_argument( 112 | "--preprocess-method", 113 | type=str, 114 | default="FF19SB", 115 | choices=["FF19SB", "AMOEBA"], 116 | help="Method to use for preprocessing the protein", 117 | ) 118 | parser.add_argument( 119 | "--mm-method", 120 | type=str, 121 | default="tinker-GPU", 122 | choices=["tinker", "tinker-GPU"], 123 | help="MM calculator for the nonbonded energy", 124 | ) 125 | parser.add_argument( 126 | "--mode", 127 | type=str, 128 | default="fragment", 129 | choices=["fragment", "visnet"], 130 | help="""Mode for performing calculations. 131 | fragment=Perform fragmentation (>1 amino acids in chain). 132 | visnet=Feed the input directly to ViSNet. 133 | """, 134 | ) 135 | parser.add_argument( 136 | "--fragment-longrange-calc", 137 | type=str, 138 | default="mm", 139 | choices=["mm", "pme"], 140 | help="Long-range interactions calculator for fragments; required for 'fragment' mode.", 141 | ) 142 | parser.add_argument( 143 | "--seed", 144 | type=int, 145 | default=0, 146 | help="Random seed for simulation", 147 | ) 148 | parser.add_argument( 149 | "--restart", 150 | action=argparse.BooleanOptionalAction, 151 | default=False, 152 | help="Restart the simulation", 153 | ) 154 | parser.add_argument( 155 | "--build-frames", 156 | action=argparse.BooleanOptionalAction, 157 | default=False, 158 | help="Build xyz frames from the trajectory after simulation", 159 | ) 160 | parser.add_argument( 161 | "--record-per-steps", 162 | type=int, 163 | default=100, 164 | help="Interval for writing out frame data", 165 | ) 166 | parser.add_argument( 167 | "--device-strategy", 168 | type=str, 169 | default="small-molecule", 170 | choices=["excess-compute", "small-molecule", "large-molecule"], 171 | help="""The compute device allocation strategy. 172 | excess-compute=Assume compute resources are more than sufficient for 173 | ViSNet inference. Reserves last GPU for solvent/non-bonded 174 | computation. 175 | small-molecule=Maximise resources for ViSNet. 176 | large-molecule=Maximise resources for ViSNet, while also maximising 177 | concurrency and usage of GPUs for computation. 178 | """, 179 | ) 180 | parser.add_argument( 181 | "--work-strategy", 182 | type=str, 183 | default="combined", 184 | choices=["combined"], 185 | help="""The work allocation strategy. 186 | combined=Distribute work evenly amongst all fragments. 187 | """, 188 | ) 189 | parser.add_argument( 190 | "--chunk-size", 191 | type=int, 192 | default=9999, 193 | help="""Define the maximum chunk size (in units of atoms) for 194 | ACE-NME/dipeptide fragments. The data will be split and processed 195 | according to these sizes. 
196 | """, 197 | ) 198 | parser.add_argument( 199 | "-v", 200 | "--verbose", 201 | action='count', 202 | default=0, 203 | help="""Verbosity level""" 204 | ) 205 | 206 | _args = parser.parse_args(argv) 207 | _args.prot_name = os.path.basename(_args.prot_file)[:-4] 208 | if _args.log_dir is None: 209 | _args.log_dir = os.path.join(_args.base_dir, f"Logs-{_args.prot_name}") 210 | os.makedirs(_args.log_dir, exist_ok=True) 211 | _args.base_dir = os.path.abspath(_args.base_dir) 212 | _args.log_dir = os.path.abspath(_args.log_dir) 213 | _args.ckpt_path = os.path.abspath(_args.ckpt_path) 214 | _args.prot_file = os.path.abspath(_args.prot_file) 215 | _args.utils_dir = os.path.join(_src_dir, "utils") 216 | 217 | strategy_feedback = DeviceStrategy.initialize( 218 | _args.device_strategy, 219 | _args.work_strategy, 220 | _args.mm_method, 221 | torch.cuda.device_count(), 222 | _args.chunk_size, 223 | ) 224 | _args.mm_method = strategy_feedback['mm-method'] 225 | 226 | return _args 227 | 228 | -------------------------------------------------------------------------------- /src/AIMD/fragment.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import numpy as np 4 | from ase import Atoms 5 | 6 | 7 | class FragmentData: 8 | def __init__(self, z: np.ndarray, pos: np.ndarray, start: np.ndarray, end: np.ndarray, batch: np.ndarray): 9 | self.z = z 10 | self.pos = pos 11 | self.start = start 12 | self.end = end 13 | self.batch = batch 14 | 15 | def __getitem__(self, f_idx): 16 | # f_idx: index of fragment [0, __len__) 17 | if isinstance(f_idx, int): 18 | f_idx = slice(f_idx, f_idx + 1) 19 | 20 | # a_idx: index of atoms [0, end[-1]) 21 | a_idx = slice(self.start[f_idx][0], self.end[f_idx][-1]) 22 | 23 | return FragmentData( 24 | self.z[a_idx], 25 | self.pos[a_idx], 26 | self.start[f_idx] - self.start[f_idx.start], 27 | self.end[f_idx] - self.start[f_idx.start], 28 | self.batch[a_idx] - self.batch[a_idx.start], 29 | ) 30 | 31 | @lru_cache 32 | def scalar_split(self): 33 | valid = np.flatnonzero(self.end - self.start) 34 | split = np.zeros(len(self), dtype=int) 35 | split[0::2] = 1 36 | split = split[valid] 37 | 38 | return split == 1, split == 0 39 | 40 | @lru_cache 41 | def vector_split(self): 42 | split = np.zeros(self.end[-1], dtype=int) 43 | np.add.at(split, self.start[0::2], 1) 44 | np.add.at(split, self.start[1::2], -1) 45 | split = np.cumsum(split, axis=0) 46 | 47 | return split == 1, split == 0 48 | 49 | def __len__(self): 50 | return len(self.start) 51 | 52 | def get_atoms(self, idx: int) -> Atoms: 53 | start, end = self.start[idx], self.end[idx] 54 | 55 | return Atoms(numbers=self.z[start:end], positions=self.pos[start:end]) 56 | 57 | 58 | class FragmentInfo: 59 | @classmethod 60 | def split(cls, total): 61 | return (total + 1) // 2, total // 2 62 | -------------------------------------------------------------------------------- /src/AIMD/protein.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numbers 3 | import os 4 | from itertools import product 5 | from typing import Optional 6 | 7 | import numpy as np 8 | import torch 9 | from ase import Atoms 10 | from ase.atom import Atom 11 | from openmm import NonbondedForce 12 | from openmm.app import ForceField, NoCutoff, PDBFile 13 | 14 | 15 | class Protein(Atoms): 16 | def __init__( 17 | self, 18 | atoms: Atoms, 19 | pdb4params: Optional[str] = None, 20 | charges: Optional[np.ndarray] = None, 21 | sigmas: 
Optional[np.ndarray] = None, 22 | epsilons: Optional[np.ndarray] = None, 23 | ): 24 | # * Initialize the Atoms properties 25 | self.__dict__.update(atoms.__dict__) 26 | 27 | assert pdb4params is not None, "pdb4params is not specified." 28 | 29 | self.nowater_PDB_path = pdb4params 30 | self.charges = charges 31 | self.sigmas = sigmas 32 | self.epsilons = epsilons 33 | # skip_check_state: if True, skips the atom comparison and calculation in Calculator.get_property(...) 34 | # can be overridden with utils.SkipCheckState context manager. 35 | self.skip_check_state = False 36 | 37 | if self.charges is None or self.sigmas is None or self.epsilons is None: 38 | self.generate_nonbonded_params() 39 | 40 | def __getitem__(self, i): 41 | r"""Return a subset of the atoms. 42 | 43 | i -- scalar integer, list of integers, or slice object 44 | describing which atoms to return. 45 | 46 | If i is a scalar, return an Atom object. If i is a list or a 47 | slice, return an Atoms object with the same cell, pbc, and 48 | other associated info as the original Atoms object. The 49 | indices of the constraints will be shuffled so that they match 50 | the indexing in the subset returned. 51 | 52 | """ 53 | 54 | if isinstance(i, numbers.Integral): 55 | natoms = len(self) 56 | if i < -natoms or i >= natoms: 57 | raise IndexError("Index out of range.") 58 | 59 | return Atom(atoms=self, index=i) 60 | elif not isinstance(i, slice): 61 | i = np.array(i) 62 | # if i is a mask 63 | if i.dtype == bool: 64 | if len(i) != len(self): 65 | raise IndexError( 66 | "Length of mask {} must equal " 67 | "number of atoms {}".format(len(i), len(self)) 68 | ) 69 | i = np.arange(len(self))[i] 70 | 71 | import copy 72 | 73 | conadd = [] 74 | # Constraints need to be deepcopied, but only the relevant ones. 75 | for con in copy.deepcopy(self.constraints): 76 | try: 77 | con.index_shuffle(self, i) 78 | except (IndexError, NotImplementedError): 79 | pass 80 | else: 81 | conadd.append(con) 82 | 83 | atoms = Atoms( 84 | cell=self.cell, 85 | pbc=self.pbc, 86 | info=self.info, 87 | celldisp=self._celldisp.copy(), 88 | ) 89 | # TODO: Do we need to shuffle indices in adsorbate_info too? 90 | 91 | atoms.arrays = {} 92 | for name, a in self.arrays.items(): 93 | atoms.arrays[name] = a[i].copy() 94 | 95 | atoms.constraints = conadd 96 | copy_object = self.__class__( 97 | atoms, 98 | pdb4params=self.nowater_PDB_path, 99 | charges=self.charges, 100 | sigmas=self.sigmas, 101 | epsilons=self.epsilons, 102 | ) 103 | return copy_object 104 | 105 | def copy(self): 106 | r""" 107 | Return a copy. 108 | """ 109 | atoms = Atoms( 110 | cell=self.cell, 111 | pbc=self.pbc, 112 | info=self.info, 113 | celldisp=self._celldisp.copy(), 114 | ) 115 | 116 | atoms.arrays = {} 117 | for name, a in self.arrays.items(): 118 | atoms.arrays[name] = a.copy() 119 | atoms.constraints = copy.deepcopy(self.constraints) 120 | copy_object = self.__class__( 121 | atoms, 122 | pdb4params=self.nowater_PDB_path, 123 | charges=self.charges, 124 | sigmas=self.sigmas, 125 | epsilons=self.epsilons, 126 | ) 127 | return copy_object 128 | 129 | @property 130 | def num_atoms(self): 131 | return len(self) 132 | 133 | def initial_mm_adjmatrix(self): 134 | r""" 135 | Mask the previous atoms or the atoms in the same dipeptides, 136 | and return the adjacency matrix. 
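
        Pairs listed in ``self.exclude_pair`` are filtered out, and the
        surviving directed pairs are returned transposed, i.e. as an
        ``edge_index`` tensor of shape ``(2, E)``.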
137 | """ 138 | 139 | # Get the total number of nodes (atoms) 140 | num_nodes = len(self.positions) 141 | 142 | # Generate all possible pairs of nodes 143 | pairs = [ 144 | (i, j) for i, j in product(range(num_nodes), repeat=2) if i != j 145 | ] 146 | # Remove the pairs that are in the same dipeptide 147 | edge_index = torch.tensor( 148 | [p for p in pairs if p not in self.exclude_pair], 149 | dtype=torch.long, 150 | ).t() 151 | return edge_index 152 | 153 | def generate_nonbonded_params(self): 154 | r""" 155 | Use OpenMM for generating non-bonded parameters. 156 | """ 157 | pdb = PDBFile(self.nowater_PDB_path) 158 | forcefield = ForceField("amber14-all.xml") 159 | system = forcefield.createSystem( 160 | pdb.topology, nonbondedMethod=NoCutoff 161 | ) 162 | nonbonded = [ 163 | f for f in system.getForces() if isinstance(f, NonbondedForce) 164 | ][0] 165 | charge_ls = [] 166 | sigma_ls = [] 167 | epsilon_ls = [] 168 | for i in range(system.getNumParticles()): 169 | charge, sigma, epsilon = nonbonded.getParticleParameters(i) 170 | charge_ls.append(charge._value) 171 | sigma_ls.append(sigma._value) 172 | epsilon_ls.append(epsilon._value) 173 | self.charges = np.array(charge_ls, dtype=np.float32) 174 | self.sigmas = np.array(sigma_ls, dtype=np.float32) 175 | self.epsilons = np.array(epsilon_ls, dtype=np.float32) 176 | -------------------------------------------------------------------------------- /src/AIMD/simulator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from abc import ABC, abstractmethod 4 | 5 | import numpy as np 6 | from ase import units 7 | from ase.calculators.calculator import Calculator 8 | from ase.constraints import Hookean 9 | from ase.io import write 10 | from ase.io.trajectory import Trajectory 11 | from ase.md.langevin import Langevin 12 | from ase.md.md import MolecularDynamics 13 | from ase.md.nvtberendsen import NVTBerendsen 14 | from ase.md.velocitydistribution import MaxwellBoltzmannDistribution 15 | 16 | from AIMD import arguments 17 | from AIMD.protein import Protein 18 | from Calculators.device_strategy import DeviceStrategy 19 | from Calculators.fragment import FragmentCalculator 20 | from Calculators.qmmm import AsyncQMMM 21 | from Calculators.tinker_async import TinkerAsyncCalculator, TinkerRuntimeError 22 | from Calculators.visnet_calculator import ViSNetCalculator 23 | from utils.pdb import read_protein 24 | from utils.system import get_physical_core_count 25 | from utils.utils import ( 26 | MDObserver, 27 | PDBAnalyzer, 28 | RNGPool, 29 | TemperatureRunawayError, 30 | WorkQueue, 31 | ) 32 | 33 | 34 | class BaseSimulator(ABC): 35 | def __init__( 36 | self, prot: Protein, log_path: str, preeq_steps: int = 200, temp_k: int = 300 37 | ) -> None: 38 | self.prot = prot 39 | self.log_path = log_path 40 | self.simulation_save_path = os.path.join(log_path, "SimulationResults") 41 | os.makedirs(self.simulation_save_path, exist_ok=True) 42 | self.nowat_pdb = self.prot.nowater_PDB_path 43 | self.preeq_steps = preeq_steps 44 | self.prot.set_pbc(True) 45 | self.temp_k = temp_k 46 | 47 | def get_qm_idx(self): 48 | return list(range(len(read_protein(self.nowat_pdb)))) 49 | 50 | def need_fragmentation(self): 51 | return isinstance(self.qmcalc, FragmentCalculator) 52 | 53 | def initialize_fragcalc(self): 54 | if self.need_fragmentation(): 55 | self.qmcalc.bonded_calculator.fragment_method.fragment(self.qmatoms) 56 | self.qmcalc.nonbonded_calculator.set_parameters(self.qmatoms) 57 | 58 | start, end = 
self.qmatoms.fragments_start, self.qmatoms.fragments_end
59 |         else:
60 |             start, end = [0], [len(self.qmatoms)]
61 | 
62 |         # set work partitions based on dipeptides/ACE-NMEs
63 |         DeviceStrategy.set_work_partitions(start, end)
64 | 
65 |     def set_calculator(self, **kwargs) -> None:
66 |         os.chdir(self.simulation_save_path)
67 |         self.make_calculator(**kwargs)
68 |         self.initialize_fragcalc()
69 | 
70 |     @abstractmethod
71 |     def make_calculator(self, **kwargs) -> Calculator:
72 |         pass
73 | 
74 |     def make_fragment_calculator(self, is_root_calc: bool, **kwargs) -> FragmentCalculator:
75 |         mode = arguments.get().mode
76 |         if mode == "fragment":
77 |             return FragmentCalculator(is_root_calc=is_root_calc, **kwargs)
78 |         if mode == "visnet":
79 |             return ViSNetCalculator(is_root_calc=is_root_calc, **kwargs)
80 | 
81 |     def simulate(
82 |         self, prot_name: str, simulation_steps: int, time_step: float,
83 |         record_per_steps: int, hydrogen_constraints: bool,
84 |         seed: int, restart: bool, build_frames: bool
85 |     ):
86 |         restart_traj_path = os.path.join(self.log_path, f"{prot_name}-traj.traj")
87 | 
88 |         if restart:
89 |             with Trajectory(restart_traj_path) as ori_traj:
90 |                 restart_frame_count = len(ori_traj)
91 |                 restart_last_frame = ori_traj[-1]
92 |                 self.prot.set_positions(restart_last_frame.get_positions())
93 |                 self.prot.set_velocities(restart_last_frame.get_velocities())
94 |         else:
95 |             restart_frame_count = 0
96 |             MaxwellBoltzmannDistribution(self.prot, temperature_K=self.temp_k, rng=np.random.RandomState(seed))
97 | 
98 |         '''
99 |         MolDyn = NVTBerendsen(
100 |             self.prot,
101 |             timestep=time_step * units.fs,
102 |             temperature=self.temp_k,
103 |             taut=0.01 * 1000 * units.fs,
104 |         )
105 |         '''
106 | 
107 |         # initialize rng pool
108 |         rng_pool = RNGPool(seed=seed, shape=(len(self.prot), 3), count=2)
109 | 
110 |         MolDyn = Langevin(
111 |             self.prot,
112 |             timestep=time_step * units.fs,
113 |             temperature_K=self.temp_k,
114 |             friction=0.001 / units.fs,
115 |             rng=rng_pool,
116 |         )
117 | 
118 |         if restart:
119 |             moldyn_traj_filename = os.path.join(self.log_path, f"{prot_name}-traj-restart.traj")
120 |         else:
121 |             moldyn_traj_filename = os.path.join(self.log_path, f"{prot_name}-traj.traj")
122 | 
123 |         moldyn_traj = Trajectory(moldyn_traj_filename, "w", self.prot)
124 | 
125 |         observer = MDObserver(
126 |             a=self.prot,
127 |             q=self.qmatoms,
128 |             md=MolDyn,
129 |             traj=moldyn_traj,
130 |             rng=rng_pool,
131 |             step_offset=restart_frame_count,
132 |             temp_k=self.temp_k,
133 |         )
134 |         MolDyn.attach(observer.save_traj_copy, interval=record_per_steps)
135 |         MolDyn.attach(observer.write_traj, interval=record_per_steps)
136 |         MolDyn.attach(observer.printenergy, interval=record_per_steps)
137 |         MolDyn.attach(observer.fill_rng_pool, interval=1)
138 | 
139 |         if (not restart) and (self.preeq_steps != 0):
140 |             init_constraint = self.prot.constraints.copy()
141 |             indices_to_constrain = self.get_qm_idx()
142 |             restraints = [10, 5, 1, 0.5, 0.1]
143 |             print("Start pre-equilibration")
144 |             for restraint in restraints:
145 |                 print(
146 |                     f"Pre-equilibration with {restraint} kcal/mol/Å² for {self.preeq_steps} steps"
147 |                 )
148 |                 constraints = []
149 |                 ref_positions = self.prot.positions
150 |                 for idx in indices_to_constrain:
151 |                     pos = ref_positions[idx]
152 |                     kcalmol2ev = (units.kcal / units.mol) / units.eV
153 |                     constraint = Hookean(a1=idx, a2=pos, k=restraint * kcalmol2ev, rt=0)
154 |                     constraints.append(constraint)
155 |                 self.prot.constraints.extend(constraints)
156 | 
157 |                 try:
158 |                     MolDyn.run(self.preeq_steps)
159 |                 except TemperatureRunawayError:
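                    # raised when the thermostat check detects a diverging
                    # temperature (see utils.utils); the run cannot be
                    # salvaged, so report and abort.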
160 | print("Thermostat detects a temperature runaway condition, cannot proceed.") 161 | exit(-1) 162 | except TinkerRuntimeError: 163 | print("Solvent dynamic component Tinker terminated abnormally, cannot proceed.") 164 | exit(-1) 165 | self.prot.constraints = init_constraint.copy() 166 | print("Pre-equilibration finished!") 167 | 168 | if hydrogen_constraints is True: 169 | pdb_analyzer = PDBAnalyzer(self.nowat_pdb) 170 | hydrogen_bonds = pdb_analyzer.find_bonded_atoms("H") 171 | hydrogen_constraints = [] 172 | 173 | for pair in hydrogen_bonds: 174 | # * Hookean constraints 175 | hydrogen_constraint = Hookean( 176 | a1=pair[0], a2=pair[1], k=pair[3], rt=pair[2] 177 | ) 178 | hydrogen_constraints.append(hydrogen_constraint) 179 | 180 | self.prot.constraints.extend(hydrogen_constraints) 181 | 182 | if restart: 183 | print(f"Re-start simulation for {simulation_steps} steps") 184 | else: 185 | print(f"Start simulation for {simulation_steps} steps") 186 | 187 | try: 188 | MolDyn.run(simulation_steps) 189 | except TinkerRuntimeError: 190 | print("Solvent dynamic component Tinker terminated abnormally, cannot proceed.") 191 | exit(-1) 192 | except TemperatureRunawayError: 193 | print("Thermostat detects a temperature runaway condition, cannot proceed.") 194 | exit(-1) 195 | 196 | print("Simulation finished!") 197 | WorkQueue.finalise() 198 | moldyn_traj.close() 199 | 200 | if build_frames and not restart: 201 | self.build_frames_from_traj(prot_name, record_per_steps, MolDyn.nsteps) 202 | 203 | shutil.rmtree(os.path.join(self.log_path, "SimulationResults")) 204 | 205 | def build_frames_from_traj(self, prot_name, record_per_steps, nsteps): 206 | print("Building frames from trajectory...") 207 | simutraj = Trajectory(os.path.join(self.log_path, f"{prot_name}-traj.traj")) 208 | opt_traj_filename = os.path.join(self.log_path, f"{prot_name}-traj.xyz") 209 | os.makedirs(os.path.join(self.log_path, "frames"), exist_ok=True) 210 | # break the trajectory into frames (xyz), then append them all to opt_traj_filename 211 | for i in range(0, nsteps, record_per_steps): 212 | atoms = simutraj[i] 213 | frame_filename = os.path.join(self.log_path, "frames", f"structure{i:0>5}.xyz") 214 | write(frame_filename, atoms) 215 | with open(frame_filename) as finframe: 216 | inframe = finframe.read() 217 | with open(opt_traj_filename, "a") as fopt_traj: 218 | fopt_traj.write(inframe) 219 | 220 | os.chdir(self.log_path) 221 | os.makedirs("results", exist_ok=True) 222 | os.system(f"cp {prot_name}-traj.xyz results") 223 | print("Done building frames from trajectory.") 224 | 225 | 226 | class SolventSimulator(BaseSimulator): 227 | def __init__( 228 | self, 229 | prot: Protein, 230 | log_path: str, 231 | preeq_steps: int, 232 | temp_k: int, 233 | utils_dir: str, 234 | pdb_file: str, 235 | nowat_pdb_file: str, 236 | mmcalc_type: str, 237 | preprocess_method: str, 238 | dev_strategy: str, 239 | ) -> None: 240 | super().__init__(prot, log_path, preeq_steps, temp_k) 241 | 242 | self.utils_dir = utils_dir 243 | self.pdb_file = pdb_file 244 | self.nowat_pdb_file = nowat_pdb_file 245 | self.mmcalc_type = mmcalc_type 246 | self.preprocess_method = preprocess_method 247 | self.dev_strategy = dev_strategy 248 | 249 | def make_mm_calculator(self): 250 | devices = DeviceStrategy.get_solvent_devices() 251 | 252 | if self.mmcalc_type in ['tinker', 'tinker-GPU']: 253 | mm_calc = TinkerAsyncCalculator( 254 | pdb_file=self.pdb_file, 255 | utils_dir=self.utils_dir, 256 | devices=devices, 257 | ) 258 | else: 259 | raise ValueError(f"Unknown mm 
calculator: {self.mmcalc_type}") 260 | return mm_calc 261 | 262 | def make_mm_qmregion_calculator(self): 263 | devices = DeviceStrategy.get_solvent_devices() 264 | if self.mmcalc_type in ['tinker', 'tinker-GPU']: 265 | mm_qmregion_calc = TinkerAsyncCalculator( 266 | pdb_file=self.nowat_pdb_file, 267 | utils_dir=self.utils_dir, 268 | devices=devices, 269 | ) 270 | else: 271 | raise ValueError(f"Unknown mm calculator: {self.mmcalc_type}") 272 | return mm_qmregion_calc 273 | 274 | def make_calculator(self, **kwargs): 275 | self.prot.calc = AsyncQMMM( 276 | selection=self.get_qm_idx(), 277 | qmcalc=self.make_fragment_calculator(is_root_calc=False, **kwargs), 278 | mmcalc1=self.make_mm_qmregion_calculator(), 279 | mmcalc2=self.make_mm_calculator(), 280 | ) 281 | 282 | self.prot.calc.initialize_qm(self.prot) 283 | 284 | self.qmcalc = self.prot.calc.qmcalc 285 | self.qmatoms = self.prot.calc.qmatoms 286 | 287 | if isinstance(self.prot.calc.mmcalc1, TinkerAsyncCalculator): 288 | self.prot.calc.mmcalc1.atoms = self.qmatoms 289 | self.prot.calc.mmcalc1._start_tinker() 290 | if isinstance(self.prot.calc.mmcalc2, TinkerAsyncCalculator): 291 | self.prot.calc.mmcalc2.atoms = self.prot 292 | self.prot.calc.mmcalc2._start_tinker() 293 | 294 | 295 | class NoSolventSimulator(BaseSimulator): 296 | def __init__( 297 | self, 298 | prot: Protein, 299 | log_path: str, 300 | preeq_steps: int, 301 | temp_k: int, 302 | **kwargs 303 | ) -> None: 304 | super().__init__(prot, log_path, preeq_steps, temp_k) 305 | 306 | self.prot = self.prot[self.get_qm_idx()] 307 | 308 | 309 | def make_calculator(self, **kwargs): 310 | self.prot.calc = self.make_fragment_calculator(is_root_calc=True, **kwargs) 311 | 312 | self.qmcalc = self.prot.calc 313 | self.qmatoms = self.prot 314 | -------------------------------------------------------------------------------- /src/Calculators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Calculators/__init__.py -------------------------------------------------------------------------------- /src/Calculators/async_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import socket 4 | import tempfile 5 | from typing import Union 6 | from select import select 7 | 8 | import numpy as np 9 | 10 | 11 | class AsyncUtilError(ConnectionError): 12 | def __init__(self, *args: object): 13 | super().__init__(args) 14 | 15 | 16 | class SocketOps: 17 | def __init__(self): 18 | self.connection: socket.socket = None 19 | self.header = self.makebuf([1], 'int32') 20 | self.recvbuf = bytearray() 21 | 22 | def recv(self, buf, sz=None): 23 | # python memoryview still remembers original shape. flatten it. 
24 | buf = memoryview(buf).cast('b') 25 | if sz is None: 26 | sz = len(buf) 27 | while True: 28 | actual = self.connection.recv_into(buf, sz) 29 | if actual == sz: 30 | break 31 | if actual == 0: 32 | raise AsyncUtilError("nothing received") 33 | sz -= actual 34 | buf = buf[actual:] 35 | buf.release() 36 | 37 | def send(self, buf): 38 | buf = memoryview(buf).cast('b') 39 | sz = len(buf) 40 | if sz <= 0: 41 | return 42 | try: 43 | while True: 44 | actual = self.connection.send(buf) 45 | if actual == sz: 46 | break 47 | if actual <= 0: 48 | raise AsyncUtilError("send failure") 49 | sz -= actual 50 | buf = buf[actual:] 51 | finally: 52 | buf.release() 53 | 54 | def send_object(self, obj): 55 | buf = pickle.dumps(obj) 56 | self.header[0] = len(buf) 57 | self.send(self.header) 58 | self.send(buf) 59 | 60 | def recv_object(self): 61 | self.recv(self.header) 62 | n = self.header[0] 63 | if len(self.recvbuf) < n: 64 | self.recvbuf = self.recvbuf.zfill(n) 65 | self.recv(self.recvbuf, n) 66 | return pickle.loads(self.recvbuf[:n]) 67 | 68 | def makebuf(self, shape, type: Union[str, int]): 69 | if type == 4: 70 | return np.empty(shape=shape, dtype='float32') 71 | elif type == 8: 72 | return np.empty(shape=shape, dtype='float64') 73 | elif isinstance(type, str): 74 | return np.empty(shape=shape, dtype=type) 75 | else: 76 | raise Exception(f"unrecognized type {type}") 77 | 78 | 79 | class AsyncServer(SocketOps): 80 | def __init__(self, type: str): 81 | super().__init__() 82 | self.type = type 83 | self.socket_dir = tempfile.mkdtemp(prefix=f"ai2bmd-{type}-") 84 | self.socket_path = os.path.join(self.socket_dir, "socket") 85 | self.server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) 86 | self.server_socket.bind(self.socket_path) 87 | self.server_socket.listen() 88 | self.connection = None 89 | self.socket_client_address = None 90 | 91 | def accept(self): 92 | self.connection, self.socket_client_address = self.server_socket.accept() 93 | 94 | def close(self): 95 | self.connection.close() 96 | self.server_socket.close() 97 | 98 | def wait_for_data(self, timeout): 99 | rl, _, _ = select([self.connection], [], [], timeout) 100 | return len(rl) == 1 101 | 102 | 103 | 104 | class AsyncClient(SocketOps): 105 | def __init__(self, socket_path): 106 | super().__init__() 107 | self.socket_path = socket_path 108 | self.connection = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) 109 | self.connection.connect(socket_path) 110 | -------------------------------------------------------------------------------- /src/Calculators/bonded.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections import deque 3 | from concurrent.futures import Future, ThreadPoolExecutor 4 | from os import path as osp 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from AIMD import arguments 10 | from AIMD.fragment import FragmentData 11 | from AIMD.protein import Protein 12 | from Calculators.combiner import DipeptideBondedCombiner 13 | from Calculators.device_strategy import DeviceStrategy 14 | from Calculators.visnet_calculator import ViSNetModelLike, get_visnet_model 15 | from Fragmentation import DistanceFragment 16 | from utils.utils import numpy_to_torch 17 | 18 | 19 | class DLBondedCalculator: 20 | r""" 21 | DLBondedCalculator is a dipeptide bonded calculator based on 22 | DL calculations supported by ViSNet. 
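
    The protein is split into dipeptide and ACE-NME fragments, fragment
    batches are evaluated in parallel by one ViSNet model replica per
    assigned device, and the per-fragment energies and forces are then
    recombined by the ``DipeptideBondedCombiner``.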
23 | """ 24 | 25 | def __init__( 26 | self, 27 | ckpt_path: str, 28 | ckpt_type: str, 29 | **kwargs, 30 | ) -> None: 31 | self.models: list[ViSNetModelLike] = [] 32 | self.ckpt_path = ckpt_path 33 | self.ckpt_type = ckpt_type 34 | 35 | # * set fragment method and combiner 36 | self.fragment_method = DistanceFragment() 37 | self.combiner = DipeptideBondedCombiner() 38 | 39 | print("Loading models...") 40 | model_path = osp.join(self.ckpt_path, f"visnet-uni-{self.ckpt_type}.ckpt") 41 | self.models = [ 42 | get_visnet_model(model_path, device) 43 | for device in DeviceStrategy.get_bonded_devices() 44 | ] 45 | 46 | def _inference_impl( 47 | self, data: list[FragmentData], model: ViSNetModelLike 48 | ) -> tuple[list[np.ndarray], list[np.ndarray]]: 49 | return zip(*[model.dl_potential_loader(unit) for unit in data]) 50 | 51 | def calculate( 52 | self, fragments: FragmentData 53 | ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 54 | r""" 55 | Calculate the energy and forces of the dipeptide. 56 | The basic function of DLBondedCalculator. 57 | 58 | Parameters: 59 | ----------- 60 | fragments: FragmentData 61 | combined dipeptide and ACE-NME fragments. 62 | """ 63 | 64 | # retrieve devices and work assignment 65 | devices = DeviceStrategy.get_bonded_devices() 66 | work = DeviceStrategy.get_work_partitions() 67 | 68 | # work execution 69 | n_devices = len(devices) 70 | partitions = [[] for _ in range(n_devices)] 71 | for idx, start, end in work: 72 | partitions[idx].append(fragments[start:end]) 73 | 74 | futures: list[Future] = [] 75 | with ThreadPoolExecutor(n_devices) as executor: 76 | for data, model in zip(partitions, self.models): 77 | futures.append(executor.submit(self._inference_impl, data, model)) 78 | 79 | # collect results 80 | energy, forces = [ 81 | np.concatenate(list(itertools.chain(*item))) 82 | for item in zip(*[f.result() for f in futures]) 83 | ] 84 | 85 | # convert numpy arrays to torch tensors 86 | device = DeviceStrategy.get_default_device() 87 | 88 | energy = numpy_to_torch(energy, device=device) 89 | forces = numpy_to_torch(forces, device=device) 90 | 91 | # split results between dipeptides/ACE-NMEs 92 | dipeptides_energy, ACE_NMEs_energy = (energy[s] for s in fragments.scalar_split()) 93 | dipeptides_forces, ACE_NMEs_forces = (forces[s] for s in fragments.vector_split()) 94 | 95 | return ( 96 | dipeptides_energy, 97 | dipeptides_forces, 98 | ACE_NMEs_energy, 99 | ACE_NMEs_forces, 100 | ) 101 | 102 | def __call__(self, prot: Protein) -> tuple[np.ndarray, np.ndarray]: 103 | fragments = self.fragment_method.get_fragments(prot) 104 | ( 105 | dipeptides_energies, 106 | dipeptides_forces, 107 | ACE_NMEs_energies, 108 | ACE_NMEs_forces, 109 | ) = self.calculate(fragments) 110 | 111 | energy = self.combiner.energy_combine( 112 | dipeptides_energies, 113 | ACE_NMEs_energies, 114 | ) 115 | forces = self.combiner.forces_combine( 116 | len(prot), 117 | dipeptides_forces, 118 | ACE_NMEs_forces, 119 | prot.select_index, 120 | prot.origin_index, 121 | ) 122 | 123 | return energy, forces 124 | -------------------------------------------------------------------------------- /src/Calculators/calculator.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from ase import Atoms 4 | from ase.calculators.calculator import Calculator, compare_atoms 5 | 6 | from AIMD.protein import Protein 7 | 8 | 9 | def check_state(self: Calculator, atoms: Union[Atoms, Protein], tol=1e-15): 10 | r""" 11 | Check for any system changes 
since last calculation. 12 | - Skips the check and assumes no changes (and therefore no recalculation), if 13 | 'atoms.skip_check_state' is true. 14 | """ 15 | 16 | if getattr(atoms, "skip_check_state", False): 17 | return [] 18 | else: 19 | return compare_atoms(self.atoms, atoms, tol=tol, excluded_properties=set(self.ignored_changes)) 20 | 21 | 22 | def patch_check_state(): 23 | Calculator.check_state = check_state 24 | -------------------------------------------------------------------------------- /src/Calculators/combiner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch_scatter 4 | 5 | 6 | class DipeptideBondedCombiner: 7 | r""" 8 | Combine the energies and forces of the dipeptides and ACE-NMEs. 9 | """ 10 | 11 | @staticmethod 12 | def energy_combine( 13 | dipeptides_energies: torch.Tensor, 14 | acenmes_energies: torch.Tensor, 15 | ) -> np.ndarray: 16 | r""" 17 | Combine the energies of the dipeptides and ACE-NMEs. 18 | """ 19 | energy = torch.sum(dipeptides_energies) - torch.sum(acenmes_energies) 20 | 21 | return energy.detach().cpu().numpy() 22 | 23 | @staticmethod 24 | def forces_combine( 25 | num_atoms: int, 26 | dipeptides_forces: torch.Tensor, 27 | acenmes_forces: torch.Tensor, 28 | select_index: torch.Tensor, 29 | origin_index: torch.Tensor, 30 | ) -> np.ndarray: 31 | r""" 32 | Combine the forces of the dipeptides and ACE-NMEs, removing the extra 33 | forces that correspond to the added hydrogens, and assigning the 34 | selected forces to their original indices in the protein. 35 | """ 36 | 37 | # negate ACE-NME forces 38 | forces = torch.cat([dipeptides_forces, -acenmes_forces])[select_index] 39 | forces = torch_scatter.scatter(forces, origin_index, dim=0, dim_size=num_atoms, reduce="sum") 40 | 41 | return forces.detach().cpu().numpy() 42 | 43 | 44 | class DipeptideCombiner: 45 | @staticmethod 46 | def energy_combine( 47 | bonded_energy: np.float32, nonbonded_energy: np.float32 48 | ) -> np.float32: 49 | return bonded_energy + nonbonded_energy 50 | 51 | @staticmethod 52 | def forces_combine( 53 | bonded_forces: np.ndarray, nonbonded_forces: np.ndarray 54 | ) -> np.ndarray: 55 | return bonded_forces + nonbonded_forces 56 | -------------------------------------------------------------------------------- /src/Calculators/device_strategy.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import logging 3 | import math 4 | import subprocess 5 | 6 | import torch 7 | 8 | from AIMD.fragment import FragmentInfo 9 | from utils.system import get_physical_core_count 10 | 11 | 12 | class DeviceStrategy: 13 | """Computation resources are required by: 14 | - Preprocess 15 | - Simulator (Solvent/Non-solvent) 16 | - FragmentCalculator 17 | - Bonded calculator 18 | - Non-bonded calculator 19 | - Solvent-related calculators 20 | - MM calculator 21 | - MM/QM region calculator 22 | """ 23 | 24 | @classmethod 25 | def _check_device(cls, device: str): 26 | if device == 'cpu': 27 | return 28 | elif device.startswith('cuda'): 29 | tup = device.split(':') 30 | assert len(tup) == 2, "invalid device syntax" 31 | n = int(tup[1]) 32 | assert n >= 0 and n < cls._gpu_count, "invalid device index" 33 | return 34 | else: 35 | raise Exception("Unrecognized device") 36 | 37 | 38 | @classmethod 39 | def get_preprocess_device(cls): 40 | return cls._preprocess_device 41 | 42 | 43 | @classmethod 44 | def get_bonded_devices(cls): 45 | if len(cls._bonded_devices) < 1: 46 | raise
Exception("No compute resources for bonded calculation") 47 | for dev in cls._bonded_devices: 48 | cls._check_device(dev) 49 | return cls._bonded_devices 50 | 51 | 52 | @classmethod 53 | def get_non_bonded_device(cls): 54 | cls._check_device(cls._non_bonded_device) 55 | return cls._non_bonded_device 56 | 57 | 58 | @classmethod 59 | def get_solvent_devices(cls): 60 | if len(cls._solvent_devices) < 1: 61 | raise Exception("No compute resources for solvent calculation") 62 | for dev in cls._solvent_devices: 63 | cls._check_device(dev) 64 | return cls._solvent_devices 65 | 66 | 67 | @classmethod 68 | def get_optimiser_device(cls): 69 | return cls._optimiser_device 70 | 71 | 72 | @classmethod 73 | def get_default_device(cls): 74 | cls._check_device(cls._default_device) 75 | return cls._default_device 76 | 77 | 78 | @classmethod 79 | def fragment_strategy(cls): 80 | return cls._fragment_strategy 81 | 82 | 83 | @classmethod 84 | def _set_combined_work_partitions(cls, devices: list[int], start: list[int], end: list[int]): 85 | """ 86 | Work partition strategy, where the combined work for 87 | ACE-NMEs/dipeptides is split evenly. 88 | """ 89 | partitions = [] 90 | 91 | n_blocks = len(devices) 92 | a_end = len(start) 93 | 94 | chunk = cls._chunk_size 95 | 96 | # divide work into blocks, in units of atoms 97 | b_prev = 0 98 | for i in range(n_blocks): 99 | block = (end[-1] - start[b_prev]) // (n_blocks - i) 100 | b_end = bisect.bisect(start, block + start[b_prev]) 101 | b_idx = b_end - 1 102 | 103 | block_end = block + start[b_prev] 104 | if (block_end - start[b_idx]) < (end[b_idx] - block_end): 105 | b_end = b_end - 1 106 | 107 | b_end = min(b_end, a_end) 108 | 109 | # divide work into chunks 110 | c_prev = b_prev 111 | while c_prev != b_end: 112 | c_end = bisect.bisect(start, chunk + start[c_prev]) 113 | c_idx = c_end - 1 114 | 115 | chunk_end = chunk + start[c_prev] 116 | if (chunk_end - start[c_idx]) < (end[c_idx] - chunk_end): 117 | c_end = c_end - 1 118 | 119 | c_end = min(c_end, b_end) 120 | 121 | partitions.append((i, c_prev, c_end)) 122 | 123 | c_prev = c_end 124 | 125 | b_prev = b_end 126 | 127 | cls._work_partitions = partitions 128 | 129 | 130 | @classmethod 131 | def set_work_partitions(cls, start: list[int], end: list[int]): 132 | bonded = cls._bonded_devices 133 | 134 | cls._set_combined_work_partitions(bonded, start, end) 135 | 136 | 137 | @classmethod 138 | def get_work_partitions(cls): 139 | return cls._work_partitions 140 | 141 | 142 | @classmethod 143 | def initialize(cls, dev_strategy: str, work_strategy: str, mm_method: str, gpu_count: int, chunk_size: int): 144 | cls._gpu_count = gpu_count 145 | cls._dev_strategy = dev_strategy 146 | cls._work_strategy = work_strategy 147 | cls._chunk_size = chunk_size 148 | 149 | cls._fragment_strategy = False 150 | 151 | except_last = range(gpu_count - 1) 152 | all_gpus = range(gpu_count) 153 | last_gpu = gpu_count - 1 154 | 155 | # device slots 156 | preprocess = "cpu" 157 | bonded = [] 158 | non_bonded = "cpu" if gpu_count == 0 else f"cuda:{last_gpu}" 159 | solvent = [] 160 | default = "cpu" if gpu_count == 0 else "cuda:0" 161 | optimiser = "cpu" 162 | 163 | if mm_method == "tinker-GPU" and gpu_count > 0: 164 | preprocess = f"cuda:{last_gpu}" 165 | else: 166 | preprocess = "cpu" 167 | 168 | print(f"DeviceStrategy: setting strategy to [{dev_strategy} / {work_strategy}]") 169 | 170 | if work_strategy == "combined" and chunk_size == 0: 171 | raise ValueError(f"chunk-size: {chunk_size} must be non-zero for 'combined' work strategy") 172 | 173 | 
# see: AIMD/arguments.py: --device-strategy for docs 174 | if dev_strategy == 'excess-compute': 175 | if gpu_count == 0: 176 | bonded = ["cpu", "cpu"] 177 | elif gpu_count == 1: 178 | bonded = ["cuda:0"] 179 | else: 180 | bonded = [f"cuda:{i}" for i in except_last] 181 | 182 | if gpu_count > 0: 183 | solvent = [f"cuda:{last_gpu}"] 184 | 185 | elif dev_strategy == 'small-molecule': 186 | if gpu_count == 0: 187 | bonded = ["cpu", "cpu"] 188 | else: 189 | bonded = [f"cuda:{i}" for i in all_gpus] 190 | 191 | if gpu_count > 1: 192 | solvent = ["cuda:1", "cuda:0"] 193 | elif gpu_count > 0: 194 | solvent = ["cuda:0"] 195 | 196 | elif dev_strategy == 'large-molecule': 197 | if gpu_count == 0: 198 | bonded = ["cpu", "cpu"] 199 | else: 200 | bonded = [f"cuda:{i}" for i in all_gpus] 201 | 202 | if gpu_count > 3: 203 | solvent = ["cuda:2", "cuda:1"] 204 | elif gpu_count > 2: 205 | solvent = ["cuda:1"] 206 | elif gpu_count > 0: 207 | solvent = ["cuda:0"] 208 | 209 | if gpu_count > 2: 210 | optimiser = "cuda:0" 211 | 212 | else: 213 | raise Exception("Unknown compute strategy") 214 | 215 | if dev_strategy == 'large-molecule': 216 | # run bonded/non-bonded calculations concurrently 217 | cls._fragment_strategy = True 218 | 219 | if mm_method == "tinker-GPU": 220 | if len(solvent) == 0: 221 | logging.error("tinker-GPU is specified, but there's no GPU. Reverting back to CPU.") 222 | solvent = ["cpu"] 223 | mm_method = "tinker" 224 | else: 225 | solvent = ["cpu"] 226 | 227 | cls._bonded_devices = bonded 228 | cls._non_bonded_device = non_bonded 229 | cls._solvent_devices = solvent 230 | cls._optimiser_device = optimiser 231 | cls._default_device = default 232 | cls._preprocess_device = preprocess 233 | 234 | cls._work_partitions = [] 235 | 236 | # On some machines, libtorch.so->libgomp.so excessively consumes CPU resources, saturating all cores 237 | # and bringing down performance, due to heavy synchronization. 238 | # We need to manually tell torch to start fewer CPU threads in this case. 239 | libgomp_bug_cpu_blacklist = [ 240 | "Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz" 241 | ] 242 | lscpu = subprocess.Popen("lscpu", stdout=subprocess.PIPE) 243 | if lscpu.wait(): 244 | raise RuntimeError("lscpu") 245 | output = lscpu.stdout.read().decode().splitlines() 246 | cpu_model = next(x for x in output if "Model name:" in x).split(':')[1].strip() 247 | # blacklist is weak against unknown CPU models, especially in Azure VMs (they upgrade all the time). 248 | # So currently we don't rely on this check.
249 | bad_cpu = cpu_model in libgomp_bug_cpu_blacklist 250 | # TODO check kernel and libgomp versions 251 | 252 | # setting num threads to 1 will only hurt performance if the model is running on CPU 253 | model_on_cpu = "cpu" in bonded 254 | if not model_on_cpu: 255 | torch.set_num_threads(1) 256 | else: 257 | # split the CPU resources to models 258 | total_threads = get_physical_core_count() 259 | total_models = sum([1 for x in bonded if x == 'cpu']) 260 | # this shouldn't happen but going defensive anyways 261 | total_models = max(1, total_models) 262 | torch_threads = max(1, total_threads // total_models) 263 | torch.set_num_threads(torch_threads) 264 | 265 | return { 'mm-method': mm_method } 266 | -------------------------------------------------------------------------------- /src/Calculators/fragment.py: -------------------------------------------------------------------------------- 1 | from ase.calculators.calculator import Calculator 2 | 3 | from Calculators.device_strategy import DeviceStrategy 4 | from Calculators.bonded import DLBondedCalculator 5 | from Calculators.combiner import DipeptideCombiner 6 | from Calculators.nonbonded import MMNonBondedCalculator 7 | from Calculators.pme import PMENonBondedCalculator 8 | from utils.utils import WorkQueue, execution_wrapper 9 | 10 | 11 | nonbonded_calcs = { 12 | 'mm': MMNonBondedCalculator, 13 | 'pme': PMENonBondedCalculator, 14 | } 15 | 16 | class FragmentCalculator(Calculator): 17 | r""" 18 | FragmentCalculator is a universal calculator for dipeptide fragments, 19 | including Bonded Calculator, QM/DL Calculator, and MM (Non-Bonded) Calculator. 20 | 21 | Parameters: 22 | ----------- 23 | is_root_calc: bool 24 | Whether this instance is the root calculator. nbcalc_type ('mm' or 'pme') selects the non-bonded calculator. 25 | """ 26 | 27 | implemented_properties = ["energy", "forces"] 28 | 29 | def __init__(self, is_root_calc, nbcalc_type, **kwargs): 30 | super().__init__(**kwargs) 31 | 32 | self.is_root_calc = is_root_calc 33 | 34 | if self.is_root_calc is True: 35 | self.work_queue = WorkQueue() 36 | 37 | # * set bonded calculator 38 | self.bonded_calculator = DLBondedCalculator(**kwargs) 39 | 40 | # * set non-bonded calculator 41 | device = DeviceStrategy.get_non_bonded_device() 42 | self.nonbonded_calculator = nonbonded_calcs[nbcalc_type](device=device) 43 | 44 | # * set fragment strategy 45 | self.concurrent = DeviceStrategy.fragment_strategy() 46 | 47 | # * set combiner 48 | self.combiner = DipeptideCombiner() 49 | 50 | def calculate(self, atoms, properties, system_changes): 51 | if self.is_root_calc: 52 | Calculator.calculate(self, atoms, properties, system_changes) 53 | self.work_queue.drain() 54 | 55 | f_args = [ 56 | (self.bonded_calculator, atoms), 57 | (self.nonbonded_calculator, atoms), 58 | ] 59 | 60 | energy, forces = zip(*execution_wrapper(f_args, self.concurrent)) 61 | 62 | energy = self.combiner.energy_combine(*energy) 63 | forces = self.combiner.forces_combine(*forces) 64 | 65 | self.results = { 66 | "energy": energy, 67 | "forces": forces, 68 | } 69 | -------------------------------------------------------------------------------- /src/Calculators/nonbonded.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from ase.units import C, _eps0, kJ, mol, nm, pi 4 | from torch_scatter import scatter_add 5 | 6 | from AIMD.protein import Protein 7 | 8 | 9 | class MMNonBondedCalculator: 10 | r""" 11 | MMNonBondedCalculator is a non-bonded calculator based on MM calculations.
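In outline (a restatement of the implementation below, not additional
behaviour), pairwise parameters follow Lorentz-Berthelot combining rules,
and the standard pairwise energy forms are summed over the precomputed
non-bonded pair list (src, dst):

    sigma_ij  = (sigma_i + sigma_j) / 2
    eps_ij    = sqrt(eps_i * eps_j)
    E_LJ      = 4 * eps_ij * ((sigma_ij / r)**12 - (sigma_ij / r)**6)
    E_Coulomb = k * q_i * q_j / r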
12 | """ 13 | 14 | def __init__(self, device="cpu") -> None: 15 | super().__init__() 16 | self.device = device 17 | self.k = 1 / (4 * pi * _eps0) * 10e6 * mol * C ** (-2) 18 | self.sigmas = None 19 | self.epsilons = None 20 | self.charges = None 21 | self.src = None 22 | self.dst = None 23 | 24 | def set_parameters(self, prot: Protein) -> None: 25 | self.sigmas = torch.tensor(prot.sigmas, dtype=torch.float, device=self.device) 26 | self.epsilons = torch.tensor(prot.epsilons, dtype=torch.float, device=self.device) 27 | self.charges = torch.tensor(prot.charges, dtype=torch.float, device=self.device) 28 | 29 | src, dst = prot.initial_mm_adjmatrix() 30 | self.src = src.to(self.device) 31 | self.dst = dst.to(self.device) 32 | 33 | def __call__(self, prot: Protein) -> tuple[np.float32, np.ndarray]: 34 | r""" 35 | Using non-bonded atom pairs to calculate the non-bonded energy and 36 | force. Non-bonded forces are calculated by calculating the gradient of 37 | the non-bonded energy with respect to the atom positions. 38 | """ 39 | pos = torch.tensor(prot.get_positions(), dtype=torch.float, device=self.device) 40 | 41 | vec = pos[self.dst] - pos[self.src] 42 | d2 = (vec**2).sum(-1) 43 | d = torch.sqrt(d2) 44 | 45 | # LJ 46 | sigmaij = 0.5 * (self.sigmas[self.src] + self.sigmas[self.dst]) * nm 47 | epsij = torch.sqrt(self.epsilons[self.src] * self.epsilons[self.dst]) 48 | c6 = (sigmaij**2 / d2) ** 3 49 | c12 = c6**2 50 | energy_lj = 4 * epsij * (c12 - c6) 51 | force_lj = (24 * epsij * (2 * c12 - c6) / d2).unsqueeze(-1) * vec 52 | 53 | # Coulomb 54 | energy_coulomb = self.k * self.charges[self.src] * self.charges[self.dst] / d 55 | force_coulomb = (energy_coulomb / d2).unsqueeze(-1) * vec 56 | 57 | # Combine 58 | energy = energy_lj.sum() + energy_coulomb.sum() 59 | force = force_lj + force_coulomb 60 | force = scatter_add(force, self.dst, dim=0, dim_size=len(prot)) 61 | energy = energy.cpu().item() * (kJ / mol) / 2 62 | force = force.cpu().numpy() * (kJ / mol) 63 | return energy, force 64 | -------------------------------------------------------------------------------- /src/Calculators/pme.py: -------------------------------------------------------------------------------- 1 | import helpmelib 2 | import numpy as np 3 | import torch 4 | from ase.units import C, _eps0, kJ, mol, nm, pi 5 | from torch.special import erf, erfc 6 | from torch_geometric.nn import radius_graph 7 | from torch_scatter import scatter_add 8 | 9 | from AIMD.protein import Protein 10 | 11 | 12 | def setup( 13 | pme, rPower, kappa, splineOrder, dimA, dimB, dimC, scaleFactor, numThreads 14 | ): 15 | """ 16 | setup initializes this object for a PME calculation using only threading. 17 | This may be called repeatedly without compromising performance. 18 | :param pme: the PME instance. 19 | :param rPower: the exponent of the (inverse) distance kernel (e.g. 1 for Coulomb, 6 for attractive 20 | dispersion). 21 | :param kappa: the attenuation parameter in units inverse of those used to specify coordinates. 22 | :param splineOrder: the order of B-spline; must be at least (2 + max. multipole order + deriv. level needed). 23 | :param dimA: the dimension of the FFT grid along the A axis. 24 | :param dimB: the dimension of the FFT grid along the B axis. 25 | :param dimC: the dimension of the FFT grid along the C axis. 26 | :param scaleFactor: a scale factor to be applied to all computed energies and derivatives thereof (e.g. the 27 | 1 / [4 pi epslion0] for Coulomb calculations). 
:param numThreads: the maximum number of threads to use for each MPI instance; if set to 0 all available threads 29 | are used. 30 | :return: None 31 | """ 32 | pme.setup( 33 | rPower, kappa, splineOrder, dimA, dimB, dimC, scaleFactor, numThreads 34 | ) 35 | 36 | 37 | def set_lattice_vectors( 38 | pme, a, b, c, alpha, beta, gamma, latticeType=helpmelib.PMEInstanceF.LatticeType.XAligned 39 | ): 40 | """ 41 | Sets the unit cell lattice vectors, with units consistent with those used to specify coordinates. 42 | :param pme: the PME instance. 43 | :param a: the A lattice parameter in units consistent with the coordinates. 44 | :param b: the B lattice parameter in units consistent with the coordinates. 45 | :param c: the C lattice parameter in units consistent with the coordinates. 46 | :param alpha: the alpha lattice parameter in degrees. 47 | :param beta: the beta lattice parameter in degrees. 48 | :param gamma: the gamma lattice parameter in degrees. 49 | :param latticeType: how to arrange the lattice vectors. Options are 50 | ShapeMatrix: enforce a symmetric representation of the lattice vectors [c.f. S. Nosé and M. L. Klein, 51 | Mol. Phys. 50 1055 (1983)] particularly appendix C. 52 | XAligned: make the A vector coincide with the X axis, the B vector fall in the XY plane, and the C vector 53 | take the appropriate alignment to completely define the system. 54 | :return: None 55 | """ 56 | pme.set_lattice_vectors(a, b, c, alpha, beta, gamma, latticeType) 57 | 58 | 59 | def compute_E_rec(pme, mat, parameterAngMom, parameters, coordinates): 60 | """ 61 | Runs a PME reciprocal space calculation, computing energies and forces. 62 | :param pme: the PME instance. 63 | :param mat: the matrix type (either MatrixD or MatrixF). 64 | :param parameterAngMom: the angular momentum of the parameters (0 for charges, C6 coefficients, 2 for quadrupoles, etc.). 65 | :param parameters: the list of parameters associated with each atom (charges, C6 66 | coefficients, multipoles, etc...). For a parameter with angular momentum L, a matrix of dimension nAtoms x nL 67 | is expected, where nL = (L+1)*(L+2)*(L+3)/6 and the fast running index nL has the ordering 68 | 0 X Y Z XX XY YY XZ YZ ZZ XXX XXY XYY YYY XXZ XYZ YYZ XZZ YZZ ZZZ ... 69 | :param coordinates: the cartesian coordinates, ordered in memory as {x1,y1,z1,x2,y2,z2,....xN,yN,zN}. 70 | :return: the reciprocal space energy.
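Example (an illustrative sketch mirroring the usage in
PMENonBondedCalculator below; `charges` is assumed to be an
(n_atoms, 1) float32 array, `coords` an (n_atoms, 3) float32 array,
and the box and grid sizes are placeholders):

    pme = helpmelib.PMEInstanceF()
    setup(pme, 1, 0.3, 4, 32, 32, 32, 1, 0)   # Coulomb kernel, kappa=0.3
    set_lattice_vectors(pme, 32.0, 32.0, 32.0, 90, 90, 90)
    e_rec = compute_E_rec(pme, helpmelib.MatrixF, 0, charges, coords)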
71 | """ 72 | return pme.compute_E_rec( 73 | parameterAngMom, mat(parameters), mat(coordinates) 74 | ) 75 | 76 | 77 | def electro_dir(coord, charge, edge_index, beta=1.0): 78 | if edge_index[0].size(0) == 0: 79 | return 0 80 | src, dst = edge_index 81 | dist = torch.norm(coord[dst] - coord[src], dim=1) 82 | discount = erfc(beta * dist) 83 | return 0.5 * (charge[src] * charge[dst] * discount / dist).sum() 84 | 85 | 86 | def electro_slf(charge, beta=1.0): 87 | return -beta / np.sqrt(np.pi) * (charge**2).sum() 88 | 89 | 90 | def electro_adj(coord, charge, exclude_index, beta=1.0): 91 | if exclude_index[0].size(0) == 0: 92 | return 0 93 | src, dst = exclude_index 94 | dist = torch.norm(coord[dst] - coord[src], dim=1) 95 | discount = erf(beta * dist) 96 | return -0.5 * (charge[src] * charge[dst] * discount / dist).sum() 97 | 98 | 99 | def electro_neutral(pme, charge, beta): 100 | # return -np.pi / 2 / beta ** 2 / volume * (charge.sum() ** 2) 101 | q_tot = charge.sum() 102 | rec_sp = compute_E_rec( 103 | pme, 104 | helpmelib.MatrixF, 105 | 0, 106 | q_tot.reshape(1, 1), 107 | np.zeros((1, 3), dtype=np.float32), 108 | ) 109 | self_sp = -beta * q_tot**2 / np.sqrt(np.pi) 110 | return -rec_sp - self_sp 111 | 112 | 113 | class PMENonBondedCalculator: 114 | def __init__(self, beta=0.3, cutoff=9.0, grid_spacing=1.0, device="cpu") -> None: 115 | super().__init__() 116 | self.beta = beta 117 | self.cutoff = cutoff 118 | self.grid_spacing = grid_spacing 119 | self.device = device 120 | self.k = 1 / (4 * pi * _eps0) * 10e6 * mol * C ** (-2) 121 | self.pme = helpmelib.PMEInstanceF() 122 | 123 | self.sigmas = None 124 | self.epsilons = None 125 | self.charges = None 126 | self.charges_cpu = None 127 | self.src_ex = None 128 | self.dst_ex = None 129 | self.coulomb_slf = 0.0 130 | self.coulomb_neutral = 0.0 131 | 132 | def set_parameters(self, prot: Protein) -> None: 133 | # Initialize parameters 134 | sigmas = torch.tensor(prot.sigmas, dtype=torch.float) 135 | epsilons = torch.tensor(prot.epsilons, dtype=torch.float) 136 | charges = torch.tensor(prot.charges, dtype=torch.float) 137 | self.sigmas = sigmas.to(self.device) 138 | self.epsilons = epsilons.to(self.device) 139 | self.charges = charges.to(self.device) 140 | self.charges_cpu = prot.charges[:, None] 141 | src_ex, dst_ex = prot.exclude_index 142 | self.src_ex = src_ex.to(self.device) 143 | self.dst_ex = dst_ex.to(self.device) 144 | 145 | # Initialize PME 146 | cx, cy, cz = prot.cell.diagonal() 147 | dimA = int(cx / self.grid_spacing) 148 | dimB = int(cy / self.grid_spacing) 149 | dimC = int(cz / self.grid_spacing) 150 | setup(self.pme, 1, self.beta, 4, dimA, dimB, dimC, 1, 0) 151 | set_lattice_vectors(self.pme, cx, cy, cz, 90, 90, 90) 152 | self.coulomb_slf = electro_slf(prot.charges, self.beta) 153 | self.coulomb_neutral = electro_neutral(self.pme, prot.charges, self.beta) 154 | 155 | def __call__(self, prot: Protein) -> tuple[np.float32, np.ndarray]: 156 | r""" 157 | Using non-bonded atom pairs to calculate 158 | the non-bonded energy and force. 159 | Non-bonded forces are calculated by grading the non-bonded energy 160 | with respect of the atom positions. 
161 | """ 162 | pos_cpu = torch.tensor(prot.get_positions(), dtype=torch.float) 163 | pos = pos_cpu.to(self.device) 164 | src, dst = radius_graph( 165 | pos, self.cutoff, max_num_neighbors=len(pos), loop=False 166 | ) 167 | vec = pos[dst] - pos[src] 168 | d2 = (vec**2).sum(-1) 169 | d = torch.sqrt(d2) 170 | vec_ex = pos[self.dst_ex] - pos[self.src_ex] 171 | d2_ex = (vec_ex**2).sum(-1) 172 | d_ex = torch.sqrt(d2_ex) 173 | 174 | # LJ 175 | sigmaij = 0.5 * (self.sigmas[src] + self.sigmas[dst]) * nm 176 | epsij = torch.sqrt(self.epsilons[src] * self.epsilons[dst]) 177 | c6 = (sigmaij**2 / d2) ** 3 178 | c12 = c6**2 179 | energy_lj = 4 * epsij * (c12 - c6) / 2 180 | force_lj = (24 * epsij * (2 * c12 - c6) / d2).unsqueeze(-1) * vec 181 | 182 | # LJ exclude 183 | sigmaij_ex = 0.5 * (self.sigmas[self.src_ex] + self.sigmas[self.dst_ex]) * nm 184 | epsij_ex = torch.sqrt(self.epsilons[self.src_ex] * self.epsilons[self.dst_ex]) 185 | c6_ex = (sigmaij_ex**2 / d2_ex) ** 3 186 | c12_ex = c6_ex**2 187 | energy_lj_ex = 4 * epsij_ex * (c12_ex - c6_ex) / 2 188 | force_lj_ex = (24 * epsij_ex * (2 * c12_ex - c6_ex) / d2_ex).unsqueeze(-1) * vec_ex 189 | 190 | # Coulomb 191 | energy_coulomb_rec = compute_E_rec(self.pme, helpmelib.MatrixF, 0, self.charges_cpu, pos_cpu) 192 | energy_coulomb_dir = electro_dir(pos, self.charges, (src, dst), self.beta) 193 | energy_coulomb_adj = electro_adj(pos, self.charges, prot.exclude_index, self.beta) 194 | energy_coulomb_ex = electro_dir(pos, self.charges, (self.src_ex, self.dst_ex), self.beta) 195 | energy_coulomb = ( 196 | energy_coulomb_rec 197 | + energy_coulomb_dir 198 | + energy_coulomb_adj 199 | + self.coulomb_slf 200 | + self.coulomb_neutral 201 | - energy_coulomb_ex 202 | ) * self.k 203 | force_coulomb = (self.k * self.charges[src] * self.charges[dst] / d / d2).unsqueeze(-1) * vec 204 | force_coulomb_ex = ( 205 | self.k * self.charges[self.src_ex] * self.charges[self.dst_ex] / d_ex / d2_ex 206 | ).unsqueeze(-1) * vec_ex 207 | 208 | # Combine 209 | energy = energy_lj.sum() - energy_lj_ex.sum() + energy_coulomb 210 | force = scatter_add(force_lj + force_coulomb, dst, dim=0, dim_size=len(prot)) 211 | force_ex = scatter_add(force_lj_ex + force_coulomb_ex, self.dst_ex, dim=0, dim_size=len(prot)) 212 | energy = energy.item() * (kJ / mol) 213 | force = (force - force_ex).cpu().numpy() * (kJ / mol) 214 | return energy, force 215 | -------------------------------------------------------------------------------- /src/Calculators/qmmm.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor 2 | 3 | from ase.calculators.calculator import Calculator 4 | from ase.calculators.qmmm import SimpleQMMM 5 | 6 | from AIMD.protein import Protein 7 | from Calculators.device_strategy import DeviceStrategy 8 | from utils.utils import WorkQueue, execution_wrapper 9 | 10 | 11 | class AsyncQMMM(SimpleQMMM): 12 | 13 | qmcalc: Calculator 14 | mmcalc1: Calculator 15 | mmcalc2: Calculator 16 | qmatoms: Protein 17 | 18 | def __init__(self, *args, **kwargs): 19 | super().__init__(*args, **kwargs) 20 | 21 | self.work_queue = WorkQueue() 22 | self.concurrent = len(DeviceStrategy.get_solvent_devices()) > 1 23 | 24 | def get_qmcalc_results(self, properties, system_changes): 25 | # force the evaluation of results 26 | self.qmcalc.calculate(self.qmatoms, properties, system_changes) 27 | return ( 28 | self.qmcalc.results["energy"], 29 | self.qmcalc.results["forces"], 30 | ) 31 | 32 | def get_mmcalc1_results(self, properties, 
system_changes): 33 | # force the evaluation of results 34 | self.mmcalc1.calculate(self.qmatoms, properties, system_changes) 35 | return ( 36 | self.mmcalc1.results["energy"], 37 | self.mmcalc1.results["forces"], 38 | ) 39 | 40 | def get_mmcalc2_results(self, properties, system_changes): 41 | # force the evaluation of results 42 | self.mmcalc2.calculate(self.atoms, properties, system_changes) 43 | return ( 44 | self.mmcalc2.results["energy"], 45 | self.mmcalc2.results["forces"], 46 | ) 47 | 48 | def calculate(self, atoms, properties, system_changes): 49 | Calculator.calculate(self, atoms, properties, system_changes) 50 | 51 | self.qmatoms.positions = atoms.positions[self.selection] 52 | if self.vacuum: 53 | self.qmatoms.positions += self.center - self.qmatoms.positions.mean(axis=0) 54 | 55 | f_args = [ 56 | (self.get_mmcalc1_results, properties, system_changes), 57 | (self.get_mmcalc2_results, properties, system_changes), 58 | ] 59 | 60 | # Start some new threads to calculate the energies and forces 61 | with ThreadPoolExecutor(3) as executor: 62 | qm = executor.submit(self.get_qmcalc_results, properties, system_changes) 63 | mm = executor.submit(execution_wrapper, f_args, self.concurrent) 64 | 65 | self.work_queue.drain() 66 | 67 | qm_e, qm_f = qm.result() 68 | mm_1, mm_2 = mm.result() 69 | 70 | mm1_e, mm1_f = mm_1 71 | mm2_e, mm2_f = mm_2 72 | 73 | energy = mm2_e + qm_e - mm1_e 74 | forces = mm2_f 75 | 76 | if self.vacuum: 77 | qm_f -= qm_f.mean(axis=0) 78 | 79 | forces[self.selection] += qm_f - mm1_f 80 | 81 | self.results["energy"] = energy 82 | self.results["forces"] = forces 83 | -------------------------------------------------------------------------------- /src/Calculators/tinker_async.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import os 3 | import subprocess 4 | from logging import getLogger 5 | import time 6 | 7 | import numpy as np 8 | from ase import Atoms 9 | from ase.calculators.calculator import Calculator 10 | from ase.units import kcal, mol 11 | 12 | from AIMD import arguments 13 | from AIMD.preprocess import run_command 14 | from Calculators.async_utils import AsyncServer 15 | 16 | 17 | _tinker_instance_id = 0 18 | 19 | 20 | class TinkerRuntimeError(RuntimeError): 21 | def __init__(self, *args: object): 22 | super().__init__(args) 23 | 24 | 25 | class TinkerAsyncCalculator(Calculator): 26 | implemented_properties = ["energy", "forces"] 27 | 28 | def __init__(self, pdb_file, utils_dir, devices: list[str], **kwargs): 29 | super().__init__(**kwargs) 30 | self.pdb_file = os.path.abspath(pdb_file) 31 | self.utils_dir = utils_dir 32 | self.prot_name = os.path.basename(pdb_file)[:-4] 33 | self.devices = devices 34 | self._tinker_proc = None 35 | self.atoms: Atoms 36 | 37 | global _tinker_instance_id 38 | self.instance_id = _tinker_instance_id 39 | _tinker_instance_id += 1 40 | 41 | self.server = AsyncServer("tinker") 42 | self.logger = getLogger(f"Tinker-Proxy-{self.instance_id}") 43 | 44 | if any(map(lambda x: x.startswith('cuda'), devices)): 45 | self.command_dir = '/usr/local/gpu-m' 46 | else: 47 | self.command_dir = '/usr/local/cpu-m' 48 | atexit.register(self._shutdown) 49 | 50 | def _start_tinker(self): 51 | self.logger.debug("Initializing tinker...") 52 | self._write_key(self.atoms) 53 | self.logger.debug('Key file written!') 54 | if not os.path.exists(f'{self.prot_name}.xyz'): 55 | self._generate_xyz_template() 56 | self.logger.debug('XYZ template generated!') 57 | 58 | # bind devices 59 | envs = 
os.environ.copy() 60 | gpus = [] 61 | for device in self.devices: 62 | if device.startswith("cuda"): 63 | _, nr = device.split(':') 64 | gpus.append(nr) 65 | if len(gpus) > 1: 66 | assert self.instance_id < len(gpus) 67 | envs["CUDA_VISIBLE_DEVICES"] = gpus[self.instance_id] 68 | elif len(gpus) == 1: 69 | envs["CUDA_VISIBLE_DEVICES"] = gpus[0] 70 | 71 | outfd = None if arguments.get().verbose >= 3 else subprocess.DEVNULL 72 | self._tinker_proc = subprocess.Popen( 73 | f"{self.command_dir}/tinker9 ai2bmd {self.prot_name} -k {self.prot_name} << _EOF\n" 74 | f"{self.server.socket_path}\n" # unix socket for IPC 75 | f"_EOF", 76 | shell=True, 77 | env=envs, 78 | stdout=outfd, 79 | stderr=outfd, 80 | ) 81 | self.logger.debug('Waiting for Tinker to start...') 82 | self.server.accept() 83 | sz = np.empty(shape=[2], dtype='int32') 84 | self.server.recv(sz) 85 | 86 | self.sz_real = sz[0] 87 | self.sz_energy = sz[1] 88 | self.sz_double = 8 # hard-coded 89 | 90 | self.n_atoms = len(self.atoms) 91 | self._tinker_istep = 1 92 | self._tinker_sbuf = self.server.makebuf([1], 'int32') 93 | self._tinker_ebuf = self.server.makebuf([1], self.sz_energy) 94 | self._tinker_xbuf = self.server.makebuf([self.n_atoms], self.sz_real) 95 | self._tinker_ybuf = self.server.makebuf([self.n_atoms], self.sz_real) 96 | self._tinker_zbuf = self.server.makebuf([self.n_atoms], self.sz_real) 97 | self._tinker_gxbuf = self.server.makebuf([self.n_atoms], self.sz_double) 98 | self._tinker_gybuf = self.server.makebuf([self.n_atoms], self.sz_double) 99 | self._tinker_gzbuf = self.server.makebuf([self.n_atoms], self.sz_double) 100 | 101 | self.logger.debug(f'Tinker connected! tinker::real size={self.sz_real} tinker::energy_prec size={self.sz_energy}') 102 | 103 | def _shutdown(self): 104 | if self.server is None or self._tinker_proc is None: 105 | return 106 | self.logger.debug("Shutting down tinker...") 107 | # send the shutdown command, wait for 500ms, and if the program still does not exit, kill it 108 | try: 109 | self._tinker_sbuf[0] = -1 110 | self.server.send(self._tinker_sbuf) 111 | self._tinker_proc.wait(timeout=0.5) 112 | except BrokenPipeError: 113 | self._tinker_proc.kill() 114 | except subprocess.TimeoutExpired: 115 | self._tinker_proc.kill() 116 | self.server.close() 117 | self.logger.debug("Tinker shutdown complete.") 118 | 119 | def _restart_tinker(self): 120 | self._tinker_proc.kill() 121 | self._tinker_proc.wait() 122 | self.server.connection.close() 123 | time.sleep(2.0) 124 | self._start_tinker() 125 | time.sleep(2.0) 126 | 127 | def _write_key(self, atoms): 128 | with open(f'{self.prot_name}.key', 'w') as f: 129 | f.write(f""" 130 | parameters {self.utils_dir}/amoebabio18.prm 131 | neighbor-list 132 | a-axis {atoms.cell[0, 0]} 133 | b-axis {atoms.cell[1, 1]} 134 | c-axis {atoms.cell[2, 2]} 135 | cutoff 12 136 | vdw-cutoff 12 137 | integrator stochastic 138 | friction 1.0 139 | save-force 140 | ewald 141 | ewald-cutoff 7.0 142 | fft-package FFTW 143 | polarization mutual 144 | polar-eps 0.01 145 | """) 146 | 147 | def _generate_xyz_template(self): 148 | run_command( 149 | f"cp {self.pdb_file} . 
&& " 150 | f"sed -i '/TER/d' {self.prot_name}.pdb && " 151 | f"{self.command_dir}/pdbxyz8 {self.prot_name}", 152 | self.directory 153 | ) 154 | 155 | def _sync_tinker(self): 156 | self._tinker_sbuf[0] = self._tinker_istep 157 | self.server.send(self._tinker_sbuf) 158 | self.server.recv(self._tinker_sbuf) 159 | if self._tinker_sbuf[0] != self._tinker_istep: 160 | raise Exception("tinker_async: status decynchronized") 161 | self._tinker_istep += 1 162 | 163 | def _write_xyz(self): 164 | pos = self.atoms.get_positions() 165 | 166 | # convert to tinker::real 167 | self._tinker_xbuf[:] = pos[:,0] 168 | self._tinker_ybuf[:] = pos[:,1] 169 | self._tinker_zbuf[:] = pos[:,2] 170 | self.server.send(self._tinker_xbuf) 171 | self.server.send(self._tinker_ybuf) 172 | self.server.send(self._tinker_zbuf) 173 | 174 | def _read_result(self): 175 | self.server.recv(self._tinker_ebuf) 176 | self.server.recv(self._tinker_gxbuf) 177 | self.server.recv(self._tinker_gybuf) 178 | self.server.recv(self._tinker_gzbuf) 179 | energy = self._tinker_ebuf[0] * (kcal / mol) 180 | grad = np.stack([self._tinker_gxbuf, self._tinker_gybuf, self._tinker_gzbuf], axis=-1) 181 | return energy, -grad * (kcal / mol) 182 | 183 | def calculate(self, atoms, properties, system_changes): 184 | # Calculator.calculate(self, atoms, properties, system_changes) 185 | for retry in range(100): 186 | try: 187 | self._sync_tinker() 188 | self._write_xyz() 189 | if not self.server.wait_for_data(3.0): 190 | self.logger.debug(f"Tinker instance {self.instance_id} took too long to respond. Restarting...") 191 | self._restart_tinker() 192 | continue 193 | energy, forces = self._read_result() 194 | self.results["energy"] = energy 195 | self.results["forces"] = forces 196 | return 197 | except Exception as err: 198 | print(err) 199 | continue 200 | raise TinkerRuntimeError("tinker crashed") 201 | -------------------------------------------------------------------------------- /src/Calculators/visnet_calculator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import atexit 4 | import os 5 | import subprocess 6 | from logging import getLogger 7 | from os import path as osp 8 | from typing import Union 9 | 10 | import numpy as np 11 | import torch 12 | from ase.calculators.calculator import Calculator 13 | 14 | from AIMD import arguments 15 | from AIMD.fragment import FragmentData 16 | from Calculators.device_strategy import DeviceStrategy 17 | from Calculators.async_utils import AsyncServer, AsyncClient 18 | from ViSNet.model.visnet import load_model 19 | from utils.utils import numpy_to_torch 20 | 21 | 22 | class ViSNetModel: 23 | r""" 24 | Calculate the energy and forces of the system 25 | using deep learning model, i.e., ViSNet. 26 | 27 | Parameters: 28 | ----------- 29 | model: 30 | Deep learning model. 31 | device: cpu | cuda 32 | Device to use for calculation. 
33 | """ 34 | implemented_properties = ["energy", "forces"] 35 | 36 | def __init__(self, model, device="cpu"): 37 | self.model = model 38 | self.model.eval() 39 | self.device = device 40 | self.stream = ( 41 | torch.cuda.Stream(device=device) 42 | if device.startswith('cuda') 43 | else None 44 | ) 45 | self.model.to(self.device) 46 | 47 | def collate(self, frag: FragmentData): 48 | z = numpy_to_torch(frag.z, self.device) 49 | pos = numpy_to_torch(frag.pos, self.device) 50 | batch = numpy_to_torch(frag.batch, self.device) 51 | 52 | return dict(z=z, pos=pos, batch=batch) 53 | 54 | def dl_potential_loader(self, frag_data: FragmentData): 55 | # note: this context is a no-op if self.stream is None 56 | with torch.cuda.stream(self.stream): 57 | with torch.set_grad_enabled(True): 58 | e, f = self.model(self.collate(frag_data)) 59 | 60 | e = e.detach().cpu().reshape(-1, 1).numpy() 61 | f = f.detach().cpu().reshape(-1, 3).numpy() 62 | 63 | return e, f 64 | 65 | @classmethod 66 | def from_file(cls, **kwargs): 67 | if "model_path" not in kwargs: 68 | raise ValueError("model_path must be provided") 69 | 70 | model_path = kwargs["model_path"] 71 | device = kwargs.get("device", "cpu") 72 | 73 | model = load_model(model_path) 74 | out = cls(model, device=device) 75 | return out 76 | 77 | 78 | class ViSNetAsyncModel: 79 | """A proxy object that spawns a subprocess, loads a model, and serve inference requests.""" 80 | 81 | def __init__(self, model_path: str, device: str): 82 | self.model_path = model_path 83 | self.device = device 84 | self.server = AsyncServer("ViSNet") 85 | self.logger = getLogger("ViSNet-Proxy") 86 | envs = os.environ.copy() 87 | envs["PYTHONPATH"] = f"{osp.abspath(osp.join(osp.dirname(__file__), '..'))}:{envs['PYTHONPATH']}" 88 | outfd = None if arguments.get().verbose >= 3 else subprocess.DEVNULL 89 | # use __file__ as process so that viztracer-patched subprocess doesn't track us 90 | # this file should have chmod +x 91 | self.proc = subprocess.Popen( 92 | [ 93 | __file__, 94 | "--model-path", model_path, 95 | "--device", device, 96 | "--socket-path", self.server.socket_path, 97 | ], 98 | shell=False, 99 | env=envs, 100 | stdout=outfd, 101 | stderr=outfd, 102 | ) 103 | self.logger.debug(f'Waiting for worker ({device}) to start...') 104 | self.server.accept() 105 | self.logger.debug(f'Worker ({device}) started.') 106 | atexit.register(self._shutdown) 107 | 108 | def dl_potential_loader(self, data: FragmentData): 109 | self.server.send_object(data) 110 | return self.server.recv_object() 111 | 112 | def _shutdown(self): 113 | self.logger.debug(f"Shutting down worker ({self.device})...") 114 | if self.proc and self.proc.poll() is None: 115 | self.proc.kill() 116 | if self.server: 117 | self.server.close() 118 | self.logger.debug(f"Worker ({self.device}) shutdown complete.") 119 | 120 | 121 | class ViSNetCalculator(Calculator): 122 | r""" 123 | Feed the input through a ViSNet model, without fragmentation 124 | """ 125 | 126 | implemented_properties = ["energy", "forces"] 127 | 128 | def __init__(self, ckpt_path: str, ckpt_type: str, 129 | is_root_calc=True, **kwargs): 130 | super().__init__(**kwargs) 131 | self.ckpt_path = ckpt_path 132 | self.ckpt_type = ckpt_type 133 | self.is_root_calc = is_root_calc 134 | model_path = osp.join(self.ckpt_path, f"visnet-uni-{self.ckpt_type}.ckpt") 135 | self.device = DeviceStrategy.get_bonded_devices()[0] 136 | self.model = get_visnet_model(model_path, self.device) 137 | 138 | def calculate(self, atoms, properties, system_changes): 139 | if 
self.is_root_calc: 140 | Calculator.calculate(self, atoms, properties, system_changes) 141 | 142 | data = FragmentData( 143 | atoms.numbers, 144 | atoms.positions.astype(np.float32), 145 | np.array([0], dtype=int), 146 | np.array([len(atoms)], dtype=int), 147 | np.zeros((len(atoms),), dtype=int), 148 | ) 149 | 150 | e, f = self.model.dl_potential_loader(data) 151 | 152 | self.results = { 153 | "energy": e, 154 | "forces": f, 155 | } 156 | 157 | 158 | if __name__ == "__main__": 159 | parser = argparse.ArgumentParser("ViSNet proxy") 160 | parser.add_argument("--model-path", type=str, required=True) 161 | parser.add_argument("--device", type=str, required=True) 162 | parser.add_argument("--socket-path", type=str, required=True) 163 | args = parser.parse_args() 164 | 165 | kwargs = { 166 | 'model_path': args.model_path, 167 | 'device': args.device, 168 | } 169 | calculator = ViSNetModel.from_file(**kwargs) 170 | client = AsyncClient(args.socket_path) 171 | # start serving 172 | try: 173 | while True: 174 | data: FragmentData = client.recv_object() 175 | output = calculator.dl_potential_loader(data) 176 | client.send_object(output) 177 | except Exception: 178 | exit(0) 179 | 180 | ViSNetModelLike = Union[ViSNetModel, ViSNetAsyncModel] 181 | _local_calc: dict[str, ViSNetModel] = {} 182 | 183 | 184 | def get_visnet_model(model_path: str, device: str): 185 | # allow up to 1 copy of GPU model to run in the master process 186 | device_sig = device 187 | if device_sig.startswith('cuda'): 188 | device_sig = 'cuda' 189 | signature = f"{device_sig}-{model_path}" 190 | if signature in _local_calc: # exists in master 191 | if device == 'cpu': 192 | # work around CPU model on worker proxy problem: always reuse local 193 | return _local_calc[signature] 194 | else: 195 | # do not reuse local, but create a proxy 196 | return ViSNetAsyncModel(model_path, device) 197 | else: # doesn't exist in master, create one 198 | kwargs = { 199 | 'model_path': model_path, 200 | 'device': device, 201 | } 202 | calc = ViSNetModel.from_file(**kwargs) 203 | _local_calc[signature] = calc 204 | return calc 205 | -------------------------------------------------------------------------------- /src/Fragmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .distancefrag import DistanceFragment 2 | 3 | __all__ = ["DistanceFragment"] 4 | -------------------------------------------------------------------------------- /src/Fragmentation/basefrag.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from ase import Atoms 4 | 5 | from AIMD import arguments 6 | from AIMD.fragment import FragmentData 7 | from AIMD.protein import Protein 8 | 9 | 10 | class BaseFragment(ABC): 11 | r""" 12 | Basic fragment object. 13 | """ 14 | 15 | def __init__(self) -> None: 16 | pass 17 | 18 | @abstractmethod 19 | def fragment(self, *args, **kwargs): 20 | r""" 21 | The basic fragment method. Fragment the protein into specific 22 | sub-structures (e.g. to generate dipeptides and ACE-NMEs). 23 | Please override this function in a subclass according to your specific 24 | fragmentation method. 25 | """ 26 | pass 27 | 28 | 29 | class DipeptideFragment(BaseFragment): 30 | r""" 31 | A subclass of BaseFragment to generate dipeptides and ACE-NMEs. All the 32 | dipeptides and ACE-NMEs are generated by the same method, i.e., @method 33 | get_fragments_index, but with different postprocessing in @method fragment.
34 | The @method fragment must invoke @method get_fragments_index. 35 | """ 36 | 37 | def __init__(self) -> None: 38 | super().__init__() 39 | 40 | @abstractmethod 41 | def get_fragments(self, prot: Protein) -> FragmentData: 42 | pass 43 | 44 | @staticmethod 45 | def get_fragments_index(prot: Atoms) -> tuple[list[int], list[int]]: 46 | r""" 47 | Fragment the protein into dipeptides/ACE-NMEs and return their indices 48 | in the original protein. 49 | 50 | Parameters: 51 | ----------- 52 | prot: Atoms 53 | The protein to be fragmented. 54 | 55 | Returns: 56 | -------- 57 | dipeptides: list[list[int]] 58 | The index of dipeptides in the protein. 59 | acenmes: list[list[int]] 60 | The index of ACE-NMEs in the protein. 61 | 62 | """ 63 | 64 | # prot.arrays['residuenumbers'] starts from 1 65 | num_residue = max(prot.arrays["residuenumbers"]) 66 | 67 | assert ( 68 | len(set(prot.arrays["residuenumbers"])) == num_residue 69 | ), "residue numbers are not continuous" 70 | 71 | # * One protein with $N$ residues can be fragmented into $N-2$ 72 | # dipeptides and $N-3$ ACE-NMEs 73 | num_dipeptides = num_residue - 2 74 | num_acenmes = num_residue - 3 75 | 76 | if num_dipeptides < 2: 77 | raise NotImplementedError( 78 | "Attempting to fragment protein with 3 or fewer residues " 79 | "(including ACE/NME caps). Please check that the input file " 80 | "is prepared according to the instructions, or pass " 81 | "'--mode visnet' to run simulations on the input as an entire " 82 | "unit, i.e., without fragmentation." 83 | ) 84 | 85 | dipeptides: list[list[int]] = [[] for _ in range(num_dipeptides)] 86 | sidechains: list[list[int]] = [[] for _ in range(num_dipeptides)] 87 | acenmes: list[list[int]] = [[] for _ in range(num_acenmes)] 88 | 89 | residuenames = prot.arrays["residuenames"] 90 | residuenumbers = prot.arrays["residuenumbers"] 91 | atomtypes = prot.arrays["atomtypes"] 92 | 93 | for index in range(len(prot)): 94 | dipeptide_index = residuenumbers[index] - 2 95 | ace_index = residuenumbers[index] - 2 96 | nme_index = residuenumbers[index] - 3 97 | 98 | # Copy all atoms of the leading ACE cap into the first dipeptide. 99 | if str(residuenames[index]).strip() == "ACE": 100 | dipeptides[0].append(index) 101 | 102 | # Copy all atoms of the trailing NME cap into the last dipeptide. 103 | elif str(residuenames[index]).strip() == "NME": 104 | dipeptides[-1].append(index) 105 | 106 | elif atomtypes[index] == "CA" or atomtypes[index][:2] == "HA": 107 | # CA, HA belong in the dipeptide and both ACE/NME 108 | 109 | # Copy the residue's C-alpha atom into the previous dipeptide. 110 | if dipeptide_index > 0: 111 | dipeptides[dipeptide_index - 1].append(index) 112 | # Copy the residue's C-alpha atom into its own dipeptide. 113 | dipeptides[dipeptide_index].append(index) 114 | # Copy the residue's C-alpha atom into the next dipeptide. 115 | if dipeptide_index < num_dipeptides - 1: 116 | dipeptides[dipeptide_index + 1].append(index) 117 | 118 | if ace_index >= 0 and ace_index <= num_acenmes - 1: 119 | acenmes[ace_index].append(index) 120 | if nme_index >= 0 and nme_index <= num_acenmes - 1: 121 | acenmes[nme_index].append(index) 122 | 123 | elif atomtypes[index] == "C" or atomtypes[index] == "O": 124 | # C, O belong in the dipeptide and the ACE 125 | 126 | # Copy the residue's C and O atoms into its own dipeptide. 127 | dipeptides[dipeptide_index].append(index) 128 | # Copy the residue's C and O atoms into the next dipeptide. 129 | if dipeptide_index < num_dipeptides - 1: 130 | dipeptides[dipeptide_index + 1].append(index) 131 | 132 | if ace_index >= 0 and ace_index <= num_acenmes - 1: 133 | acenmes[ace_index].append(index) 134 | 135 | elif atomtypes[index] == "N" or atomtypes[index] == "H": 136 | # N, H belong in the dipeptide and
the NME 137 | 138 | # Copy the residue's N and H atoms into the previous dipeptide. 139 | if dipeptide_index > 0: 140 | dipeptides[dipeptide_index - 1].append(index) 141 | # Copy the residue's N and H atoms into its own dipeptide. 142 | dipeptides[dipeptide_index].append(index) 143 | 144 | if nme_index >= 0 and nme_index <= num_acenmes - 1: 145 | acenmes[nme_index].append(index) 146 | 147 | else: 148 | # Copy side-chain atoms into the residue's own dipeptide. 149 | sidechains[dipeptide_index].append(index) 150 | 151 | # tinker: insert sidechain into backbone, just before the second 'N' 152 | for idx, unit in enumerate(dipeptides): 153 | nitrogens = [i for i, index in enumerate(unit) if atomtypes[index] == 'N'] 154 | assert len(nitrogens) == 2, "number of nitrogen atoms in dipeptide != 2" 155 | 156 | unit[nitrogens[1]:nitrogens[1]] = sidechains[idx] 157 | 158 | # print atom types of fragments 159 | if arguments.get().verbose >= 1: 160 | print(" [i] dipeptide fragments:") 161 | for idx, unit in enumerate(dipeptides): 162 | print(f"{idx:>8} | {' '.join([atomtypes[i] for i in unit])}") 163 | print(" [i] ACE-NME fragments:") 164 | for idx, unit in enumerate(acenmes): 165 | print(f"{idx:>8} | {' '.join([atomtypes[i] for i in unit])}") 166 | 167 | return dipeptides, acenmes 168 | -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/AA.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/AA.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/AN.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/AN.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/ANAN.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/ANAN.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/CC.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/CC.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/DD.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/DD.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/EE.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/EE.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/FF.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/FF.npz --------------------------------------------------------------------------------
/src/Fragmentation/bondlen/GG.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/GG.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/HH.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/HH.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/HID.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/HID.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/II.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/II.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/KK.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/KK.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/LL.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/LL.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/MM.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/MM.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/NN.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/NN.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/PP.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/PP.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/QQ.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/QQ.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/RR.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/RR.npz -------------------------------------------------------------------------------- 
/src/Fragmentation/bondlen/SS.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/SS.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/TT.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/TT.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/VV.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/VV.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/WW.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/WW.npz -------------------------------------------------------------------------------- /src/Fragmentation/bondlen/YY.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/Fragmentation/bondlen/YY.npz -------------------------------------------------------------------------------- /src/Fragmentation/hydrogen/__init__.py: -------------------------------------------------------------------------------- 1 | from .ctable import CTable 2 | from .energies import HydrogenOptimizer 3 | from .topology import ProteinData, ProteinDataBatch 4 | 5 | __all__ = ["CTable", "HydrogenOptimizer", "ProteinData", "ProteinDataBatch"] 6 | -------------------------------------------------------------------------------- /src/Fragmentation/hydrogen/ctable.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class CTable: 6 | """ 7 | The coefficient table as defined in the AMBER topology file `.prmtop`. 8 | The attributes share the same name as defined in the AMBER manual. 9 | See https://ambermd.org/FileFormats.php for more details. 
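Example (an illustrative sketch; the .prmtop path refers to one of the
bundled topology files, and `h_idx` is assumed to be a LongTensor of
hydrogen atom indices):

    ctable = CTable.from_prmtop("src/Fragmentation/prmtop/AA.prmtop")
    bonds = ctable.filter_bonds(h_idx)    # shape (3, n_bonds)
    angles = ctable.filter_angles(h_idx)  # shape (4, n_angles)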
10 | """ 11 | 12 | def __init__(self): 13 | # Scalar values defined in `%FLAG POINTERS` 14 | self.natom = None 15 | self.ntypes = None 16 | self.numbnd = None 17 | self.numang = None 18 | self.nptra = None 19 | 20 | # 1D Arrays 21 | self.charge = None 22 | self.atomic_number = None 23 | self.atom_type_idx = None 24 | self.number_excluded_atoms = None 25 | self.nonbonded_parm_index = None 26 | self.bond_force_constant = None 27 | self.bond_equil_value = None 28 | self.angle_force_constant = None 29 | self.angle_equil_value = None 30 | self.dihedral_force_constant = None 31 | self.dihedral_periodicity = None 32 | self.dihedral_phase = None 33 | self.scee_scale_factor = None 34 | self.scnb_scale_factor = None 35 | self.lennard_jones_acoef = None 36 | self.lennard_jones_bcoef = None 37 | 38 | # 2D Arrays 39 | self.bonds_inc_hydrogen = None 40 | self.angles_inc_hydrogen = None 41 | self.dihedrals_inc_hydrogen = None 42 | self.excluded_atoms_list = None 43 | 44 | def print_stats(self): 45 | for name, value in self.__dict__.items(): 46 | if isinstance(value, torch.Tensor): 47 | print(f"{name}: {value.size()}") 48 | else: 49 | print(f"{name}: {value}") 50 | 51 | def to(self, device): 52 | for name, value in self.__dict__.items(): 53 | if isinstance(value, torch.Tensor): 54 | setattr(self, name, value.to(device)) 55 | return self 56 | 57 | @classmethod 58 | def from_prmtop(cls, filename): 59 | """ 60 | Read the coefficient table from the AMBER topology file `.prmtop`. 61 | :param filename: file name 62 | :return: a `CTable` instance 63 | """ 64 | torch_dtype = {float: torch.float, int: torch.long} 65 | 66 | def _read_flag(flag, dtype=float, shape=None, as_tensor=True): 67 | """ 68 | Read the values of a flag. A Flag is defined in a 69 | single line starting with `%FLAG` 70 | :param flag: flag string 71 | :param dtype: value type 72 | :param shape: value shape. If None is provided, 73 | the values are read as an 1D array. 
74 | :param as_tensor: whether to return a tensor
75 | :return: a list or a tensor
76 | """
77 | nonlocal idx
78 | try:
79 | while not lines[idx].startswith(flag):
80 | idx += 1
81 | except IndexError:
82 | raise ValueError(f"Cannot find flag {flag!r}")
83 | 
84 | idx += 2
85 | values = []
86 | while idx < len(lines) and not lines[idx].startswith("%"):
87 | values.extend(map(dtype, lines[idx].split()))
88 | idx += 1
89 | if not as_tensor:
90 | return values
91 | values = torch.tensor(values, dtype=torch_dtype[dtype])
92 | if shape is not None:
93 | values = values.view(shape)
94 | return values
95 | 
96 | ctable = cls()
97 | with open(filename, "r") as f:
98 | lines = [line.strip() for line in f.readlines()]
99 | idx = 0
100 | 
101 | scalars = _read_flag("%FLAG POINTERS", int, as_tensor=False)
102 | ctable.natom = scalars[0]
103 | ctable.ntypes = scalars[1]
104 | ctable.numbnd = scalars[15]
105 | ctable.numang = scalars[16]
106 | ctable.nptra = scalars[17]
107 | 
108 | # All indices are stored as 1-based in the prmtop file
109 | ctable.charge = _read_flag("%FLAG CHARGE", float)
110 | ctable.atomic_number = _read_flag("%FLAG ATOMIC_NUMBER", int)
111 | ctable.atom_type_idx = _read_flag("%FLAG ATOM_TYPE_INDEX", int) - 1
112 | ctable.number_excluded_atoms = _read_flag(
113 | "%FLAG NUMBER_EXCLUDED_ATOMS", int
114 | )
115 | ctable.nonbonded_parm_index = (
116 | _read_flag("%FLAG NONBONDED_PARM_INDEX", int) - 1
117 | )
118 | ctable.bond_force_constant = _read_flag(
119 | "%FLAG BOND_FORCE_CONSTANT", float
120 | )
121 | ctable.bond_equil_value = _read_flag("%FLAG BOND_EQUIL_VALUE", float)
122 | ctable.angle_force_constant = _read_flag(
123 | "%FLAG ANGLE_FORCE_CONSTANT", float
124 | )
125 | ctable.angle_equil_value = _read_flag("%FLAG ANGLE_EQUIL_VALUE", float)
126 | ctable.dihedral_force_constant = _read_flag(
127 | "%FLAG DIHEDRAL_FORCE_CONSTANT", float
128 | )
129 | ctable.dihedral_periodicity = _read_flag(
130 | "%FLAG DIHEDRAL_PERIODICITY", float
131 | )
132 | ctable.dihedral_phase = _read_flag("%FLAG DIHEDRAL_PHASE", float)
133 | ctable.scee_scale_factor = _read_flag("%FLAG SCEE_SCALE_FACTOR", float)
134 | ctable.scnb_scale_factor = _read_flag("%FLAG SCNB_SCALE_FACTOR", float)
135 | ctable.lennard_jones_acoef = _read_flag(
136 | "%FLAG LENNARD_JONES_ACOEF", float
137 | )
138 | ctable.lennard_jones_bcoef = _read_flag(
139 | "%FLAG LENNARD_JONES_BCOEF", float
140 | )
141 | 
142 | # The bond, angle, and dihedral arrays store each atom index as 3 * (0-based index); only the trailing parameter index is 1-based
143 | ctable.bonds_inc_hydrogen = torch.div(
144 | _read_flag("%FLAG BONDS_INC_HYDROGEN", int, (-1, 3)),
145 | torch.LongTensor([[3, 3, 1]]),
146 | rounding_mode="floor",
147 | )
148 | ctable.angles_inc_hydrogen = torch.div(
149 | _read_flag("%FLAG ANGLES_INC_HYDROGEN", int, (-1, 4)),
150 | torch.LongTensor([[3, 3, 3, 1]]),
151 | rounding_mode="floor",
152 | )
153 | ctable.dihedrals_inc_hydrogen = torch.div(
154 | _read_flag("%FLAG DIHEDRALS_INC_HYDROGEN", int, (-1, 5)),
155 | torch.LongTensor([[3, 3, 3, 3, 1]]),
156 | rounding_mode="floor",
157 | )
158 | # Convert the parameter indices (last column) to 0-based
159 | ctable.bonds_inc_hydrogen[:, -1] -= 1
160 | ctable.angles_inc_hydrogen[:, -1] -= 1
161 | ctable.dihedrals_inc_hydrogen[:, -1] -= 1
162 | # For unknown reasons, there may be zeros in `excluded_atoms_list`
163 | ctable.excluded_atoms_list = (
164 | _read_flag("%FLAG EXCLUDED_ATOMS_LIST", int) - 1
165 | )
166 | return ctable
167 | 
168 | def filter_bonds(self, atom_idx):
169 | """
170 | Filter the bonds that include the given hydrogen atoms.
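(Note, added for clarity: each column of the returned tensor holds ``(atom_i, atom_j, bond_type_idx)``, all 0-based after the conversions performed in `from_prmtop`.)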
171 | :param atom_idx: hydrogen atom indices to be considered 172 | :return: `bonds_inc_hydrogen` of shape (3, n_bonds) 173 | """ 174 | mask = torch.isin(self.bonds_inc_hydrogen[:, :2], atom_idx).any(dim=-1) 175 | return self.bonds_inc_hydrogen[mask].t() 176 | 177 | def filter_angles(self, atom_idx): 178 | """ 179 | Filter the angles that include the given hydrogen atoms. 180 | :param atom_idx: hydrogen atom indices to be considered 181 | :return: `angles_inc_hydrogen` of shape (4, n_angles) 182 | """ 183 | mask = torch.isin(self.angles_inc_hydrogen[:, :3], atom_idx).any( 184 | dim=-1 185 | ) 186 | return self.angles_inc_hydrogen[mask].t() 187 | 188 | def filter_dihedrals(self, atom_idx): 189 | """ 190 | Filter the dihedrals that include the given hydrogen atoms. 191 | :param atom_idx: hydrogen atom indices to be considered 192 | :return: `dihedrals_inc_hydrogen` of shape (5, n_dihedrals) 193 | """ 194 | mask = torch.isin(self.dihedrals_inc_hydrogen[:, :4], atom_idx).any( 195 | dim=-1 196 | ) 197 | mask &= (self.dihedrals_inc_hydrogen[:, 2:4] >= 0).all(dim=-1) 198 | return self.dihedrals_inc_hydrogen[mask].t() 199 | 200 | def gen_nonbonded_pair(self, atom_idx): 201 | """ 202 | Generate the non-bonded pair of the given hydrogen atoms. 203 | :param atom_idx: hydrogen atom indices to be considered 204 | :return: edge_index of shape (2, n_edges) 205 | """ 206 | excluded_atoms_ptr = F.pad( 207 | torch.cumsum(self.number_excluded_atoms, dim=0), 208 | (1, 0), 209 | "constant", 210 | value=0, 211 | ) 212 | 213 | all_edge_index = set( 214 | [ 215 | (i, j) 216 | for i in range(self.natom) 217 | for j in range(i + 1, self.natom) 218 | if i in atom_idx or j in atom_idx 219 | ] 220 | ) 221 | for atom_idx_src in range(self.natom): 222 | start, end = ( 223 | excluded_atoms_ptr[atom_idx_src], 224 | excluded_atoms_ptr[atom_idx_src + 1], 225 | ) 226 | exc_idx = self.excluded_atoms_list[start:end].tolist() 227 | for atom_idx_dst in exc_idx: 228 | # The excluded atom idx only contains atoms with larger indices 229 | all_edge_index.discard((atom_idx_src, atom_idx_dst)) 230 | edge_index = torch.LongTensor(list(all_edge_index)).t() 231 | return edge_index 232 | 233 | def generate_lj_idx(self, atom_idx_src, atom_idx_dst): 234 | """ 235 | Generate the Lennard-Jones index of the given hydrogen atoms. 
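(Note, added for clarity: per the Amber prmtop convention, the coefficient slot for an atom pair is looked up as ``nonbonded_parm_index[ntypes * type_i + type_j]``; both index tensors are already 0-based here.)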
236 | :param atom_idx_src: the source atom indices
237 | :param atom_idx_dst: the destination atom indices
238 | :return: indices of shape (n_edges,)
239 | """
240 | parm_idx = (
241 | self.ntypes * self.atom_type_idx[atom_idx_src]
242 | + self.atom_type_idx[atom_idx_dst]
243 | )
244 | return self.nonbonded_parm_index[parm_idx]
245 | 
246 | def __repr__(self):
247 | return (
248 | f"{self.__class__.__name__}({self.natom} atoms,"
249 | f" {self.ntypes} types, "
250 | f"{self.numbnd} bonds, {self.numang} angles,"
251 | f" {self.nptra} dihedrals)"
252 | )
253 | 
--------------------------------------------------------------------------------
/src/Fragmentation/hydrogen/energies.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch_scatter import scatter_add
4 | 
5 | from Fragmentation.hydrogen.topology import ProteinData
6 | 
7 | 
8 | @torch.jit.script
9 | def _get_bond_energy_jit(coord, bond_idx_src, bond_idx_dst, bond_force_constant, bond_idx, bond_equil_value, bond_batch):
10 | dist = torch.norm(coord[bond_idx_src] - coord[bond_idx_dst], dim=-1)
11 | energy = bond_force_constant[bond_idx] * (dist - bond_equil_value[bond_idx]).square()
12 | return scatter_add(energy, bond_batch, dim=0) * 0.5
13 | 
14 | 
15 | @torch.jit.script
16 | def _get_angle_energy_jit(coord, angles_i, angles_j, angles_k, angle_force_constant, angle_idx, angle_equil_value, angle_batch):
17 | v0 = coord[angles_i] - coord[angles_j]
18 | v1 = coord[angles_k] - coord[angles_j]
19 | y = torch.norm(torch.cross(v0, v1, dim=-1), dim=-1)
20 | x = torch.einsum('ij,ij->i', v0, v1)
21 | angle = torch.atan2(y, x)
22 | 
23 | energy = angle_force_constant[angle_idx] * (angle - angle_equil_value[angle_idx]).square()
24 | return scatter_add(energy, angle_batch, dim=0) * 0.5
25 | 
26 | 
27 | @torch.jit.script
28 | def _get_dihedral_energy_jit(coord, dih_i, dih_j, dih_k, dih_l, dih_force_constant, dih_idx, dih_periodicity, dih_phase, dih_batch):
29 | p0, p1 = coord[dih_i], coord[dih_j]
30 | p2, p3 = coord[dih_k], coord[dih_l]
31 | v0, v1, v2 = p1 - p2, p1 - p0, p3 - p2
32 | n1 = F.normalize(torch.cross(v1, v0, dim=-1), dim=-1)
33 | n2 = F.normalize(torch.cross(v0, v2, dim=-1), dim=-1)
34 | m1 = torch.cross(n1, F.normalize(v0, dim=-1), dim=-1)
35 | x = torch.sum(n1 * n2, dim=-1)
36 | y = torch.sum(m1 * n2, dim=-1)
37 | dihedral = torch.atan2(y, x)
38 | # TODO: There is a discrepancy between the Amber manual and the PDF.
39 | # Which one is correct?
40 | energy = dih_force_constant[dih_idx] * (
41 | 1 + torch.cos(dih_periodicity[dih_idx] * dihedral - dih_phase[dih_idx])
42 | )
43 | return scatter_add(energy, dih_batch, dim=0) * 0.5
44 | 
45 | 
46 | @torch.jit.script
47 | def _get_vdw_energy_jit(lj_acoef, lj_bcoef, lj_idx, nonbonded_batch, scnb_scale_factor, dist):
48 | # TODO: Amber discounts the 1-4 VDW interactions by `scnb_scale_factor`.
49 | # Implement the discounting.
50 | r6 = torch.pow(dist, 6)
51 | r12 = torch.square(r6)
52 | energy = lj_acoef[lj_idx] / r12 - lj_bcoef[lj_idx] / r6
53 | return scatter_add(energy, nonbonded_batch, dim=0) / scnb_scale_factor
54 | 
55 | 
56 | @torch.jit.script
57 | def _get_elec_energy_jit(charge, nonbonded_idx_src, nonbonded_idx_dst, nonbonded_batch, scee_scale_factor, dist):
58 | # TODO: Amber discounts the 1-4 electrostatic
59 | # interactions by `scee_scale_factor`. Implement the discounting.
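# Note (added): Amber prmtop CHARGE values come pre-multiplied by 18.2223
# (the square root of the Coulomb constant, ~332.05 kcal*A/mol/e^2), so the
# bare q_i * q_j / r expression below already yields kcal/mol without an
# explicit 1/(4*pi*eps0) factor.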
60 | energy = charge[nonbonded_idx_src] * charge[nonbonded_idx_dst] / dist
61 | return scatter_add(energy, nonbonded_batch, dim=0) / scee_scale_factor
62 | 
63 | 
64 | class HydrogenOptimizer:
65 | """
66 | Hydrogen coordinate optimizer.
67 | It contains an energy calculator and an optimizer.
68 | The Amber potential energy includes the bond, angle,
69 | and dihedral terms.
70 | Non-bonded energies include the van der Waals energy
71 | and the electrostatic energy.
72 | Torch's `LBFGS` implementation is used as the optimizer.
73 | """
74 | 
75 | def __init__(
76 | self, max_iter=5, scnb_scale_factor=1.2, scee_scale_factor=2.0
77 | ):
78 | self.max_iter = max_iter
79 | self.scnb_scale_factor = scnb_scale_factor
80 | self.scee_scale_factor = scee_scale_factor
81 | 
82 | def get_bond_energy(self, batch: ProteinData):
83 | r"""
84 | Bond energy is calculated according to the equation
85 | 
86 | .. math::
87 | E_{bond} = \frac{1}{2} k (r - r_{eq})^2
88 | 
89 | :param batch: batch of `ProteinData`
90 | :return: graph-level bond energy
91 | """
92 | return _get_bond_energy_jit(
93 | batch.pos,
94 | batch.bonds_atom_idx_src,
95 | batch.bonds_atom_idx_dst,
96 | batch.bond_force_constant,
97 | batch.bond_idx,
98 | batch.bond_equil_value,
99 | batch.bond_batch,
100 | )
101 | 
102 | def get_angle_energy(self, batch: ProteinData):
103 | r"""
104 | Angle energy is calculated according to the equation
105 | 
106 | .. math::
107 | E_{angle} = \frac{1}{2} k_\theta (\theta - \theta_{eq})^2
108 | 
109 | :param batch: batch of `ProteinData`
110 | :return: graph-level angle energy
111 | """
112 | return _get_angle_energy_jit(
113 | batch.pos,
114 | batch.angles_atom_idx_i,
115 | batch.angles_atom_idx_j,
116 | batch.angles_atom_idx_k,
117 | batch.angle_force_constant,
118 | batch.angle_idx,
119 | batch.angle_equil_value,
120 | batch.angle_batch,
121 | )
122 | 
123 | def get_dihedral_energy(self, batch: ProteinData):
124 | r"""
125 | Dihedral (torsion) energy is calculated according to the equation
126 | 
127 | .. math::
128 | E_{dihedral} = \frac{1}{2} k_{tor} (1 + \cos(n \phi - \psi))
129 | 
130 | :param batch: batch of `ProteinData`
131 | :return: graph-level dihedral energy
132 | """
133 | 
134 | return _get_dihedral_energy_jit(
135 | batch.pos,
136 | batch.dihedrals_atom_idx_i,
137 | batch.dihedrals_atom_idx_j,
138 | batch.dihedrals_atom_idx_k,
139 | batch.dihedrals_atom_idx_l,
140 | batch.dihedral_force_constant,
141 | batch.dihedral_idx,
142 | batch.dihedral_periodicity,
143 | batch.dihedral_phase,
144 | batch.dihedral_batch,
145 | )
146 | 
147 | def get_vdw_energy(self, batch: ProteinData, dist):
148 | r"""
149 | Lennard-Jones potential for van der Waals interactions.
150 | 
151 | .. math::
152 | E_{vdw} = \frac{a_{ij}}{r_{ij}^{12}} - \frac{b_{ij}}{r_{ij}^{6}}
153 | 
154 | :param batch: batch of `ProteinData`
155 | :param dist: distance between atoms
156 | :return: graph-level van der Waals energy
157 | """
158 | 
159 | return _get_vdw_energy_jit(
160 | batch.lennard_jones_acoef,
161 | batch.lennard_jones_bcoef,
162 | batch.lj_idx,
163 | batch.nonbonded_batch,
164 | self.scnb_scale_factor,
165 | dist,
166 | )
167 | 
168 | def get_elec_energy(self, batch: ProteinData, dist):
169 | r"""
170 | Coulomb potential for electrostatic interactions.
171 | 
172 | .. 
math:: 173 | E_{elec} = \frac{q_i q_j}{4 \pi \epsilon_0 r_{ij}} 174 | 175 | :param batch: batch of `ProteinData` 176 | :param dist: distance between atoms 177 | :return: graph-level electrostatic energy 178 | """ 179 | return _get_elec_energy_jit( 180 | batch.charge, 181 | batch.nonbonded_atom_idx_src, 182 | batch.nonbonded_atom_idx_dst, 183 | batch.nonbonded_batch, 184 | self.scee_scale_factor, 185 | dist, 186 | ) 187 | 188 | def cal_potential_energy(self, batch: ProteinData): 189 | """ 190 | Calculate the potential energy of the protein batch. 191 | :param batch: batch of `ProteinData` 192 | :return: a tensor of graph-level energy terms 193 | """ 194 | dist = torch.norm( 195 | batch.pos[batch.nonbonded_atom_idx_src] 196 | - batch.pos[batch.nonbonded_atom_idx_dst], 197 | dim=-1, 198 | ) 199 | energies = torch.stack( 200 | [ 201 | self.get_bond_energy(batch), 202 | self.get_angle_energy(batch), 203 | self.get_dihedral_energy(batch), 204 | self.get_vdw_energy(batch, dist), 205 | self.get_elec_energy(batch, dist), 206 | ], 207 | dim=-1, 208 | ) 209 | return energies 210 | 211 | def optimize_hydrogen(self, batch: ProteinData): 212 | """ 213 | Optimize the given hydrogen atoms. Note that this function modifies 214 | the coordinates in-place. 215 | :param batch: batch of `ProteinData` 216 | :return: optimized batch 217 | """ 218 | def closure(): 219 | optimizer.zero_grad() 220 | positions = torch.cat([atom_pos, other_pos]) 221 | batch.pos = positions[sort_idx] 222 | energy = self.cal_potential_energy(batch).sum() 223 | energy.backward() 224 | return energy 225 | 226 | device = batch.pos.device 227 | 228 | atom_pos = torch.nn.Parameter(batch.pos[batch.atom_idx]) 229 | other_pos = batch.pos[batch.other_idx] 230 | all_idx = torch.cat([batch.atom_idx, batch.other_idx]) 231 | sort_idx = torch.zeros_like(all_idx) 232 | sort_idx[all_idx] = torch.arange(len(all_idx), device=device) 233 | optimizer = torch.optim.LBFGS( 234 | [atom_pos], 235 | lr=0.1, 236 | max_iter=self.max_iter, 237 | tolerance_grad=0.1, 238 | tolerance_change=0.01, 239 | ) 240 | optimizer.step(closure) 241 | batch.pos.detach_() 242 | return batch 243 | -------------------------------------------------------------------------------- /src/Fragmentation/hydrogen/topology.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch_geometric.data import Batch, Data, separate 5 | 6 | from .ctable import CTable 7 | 8 | 9 | class ProteinData(Data): 10 | """ 11 | Class representing a protein structure as a graph. 12 | This class modifies the `__inc__` method so that 13 | the graph attributes can be correctly batched. 14 | Note that `pyg` requires a data class to have a 15 | default constructor to be batchable. 16 | Therefore, all parameters passed to the constructor must be optional. 17 | """ 18 | 19 | def __init__( 20 | self, atom_idx=None, other_idx=None, pos=None, ctable: Optional[CTable] = None 21 | ): 22 | """ 23 | Initialize a `ProteinData` object representing a protein graph. 
24 | :param atom_idx: the hydrogen atom indices to be considered
:param other_idx: indices of the remaining atoms, which are kept fixed during optimization
25 | :param pos: the coordinates of the whole protein
26 | :param ctable: a `CTable` object containing the protein information
27 | """
28 | 
29 | def attr(name):
30 | """Get an attribute from CTable or return None if ctable is None."""
31 | return None if ctable is None else getattr(ctable, name)
32 | 
33 | def method(name, return_cnt):
34 | """Get a method call result from CTable or return a
35 | list of None's if ctable is None."""
36 | return (
37 | [None] * return_cnt
38 | if ctable is None
39 | else getattr(ctable, name)(atom_idx)
40 | )
41 | 
42 | def zeros_like(name):
43 | """Get a tensor of zeros or return None if ctable is None."""
44 | return (
45 | None
46 | if ctable is None
47 | else torch.zeros_like(getattr(self, name))
48 | )
49 | 
50 | super().__init__(pos=pos, z=attr("atomic_number"))
51 | self.atom_idx = atom_idx
52 | self.other_idx = other_idx
53 | 
54 | # The following scalar attributes are only used
55 | # for batching data. Make them private.
56 | self._natom = attr("natom")
57 | self._ntypes = attr("ntypes")
58 | self._numbnd = attr("numbnd")
59 | self._numang = attr("numang")
60 | self._nptra = attr("nptra")
61 | 
62 | # The following attributes can simply be concatenated.
63 | self.charge = attr("charge")
64 | self.bond_force_constant = attr("bond_force_constant")
65 | self.bond_equil_value = attr("bond_equil_value")
66 | self.angle_force_constant = attr("angle_force_constant")
67 | self.angle_equil_value = attr("angle_equil_value")
68 | self.dihedral_force_constant = attr("dihedral_force_constant")
69 | self.dihedral_periodicity = attr("dihedral_periodicity")
70 | self.dihedral_phase = attr("dihedral_phase")
71 | self.lennard_jones_acoef = attr("lennard_jones_acoef")
72 | self.lennard_jones_bcoef = attr("lennard_jones_bcoef")
73 | 
74 | # The following indices need to be filtered by atom_idx.
75 | # They need to be increased during batching.
76 | (
77 | self.bonds_atom_idx_src,
78 | self.bonds_atom_idx_dst,
79 | self.bond_idx,
80 | ) = method("filter_bonds", 3)
81 | (
82 | self.angles_atom_idx_i,
83 | self.angles_atom_idx_j,
84 | self.angles_atom_idx_k,
85 | self.angle_idx,
86 | ) = method("filter_angles", 4)
87 | (
88 | self.dihedrals_atom_idx_i,
89 | self.dihedrals_atom_idx_j,
90 | self.dihedrals_atom_idx_k,
91 | self.dihedrals_atom_idx_l,
92 | self.dihedral_idx,
93 | ) = method("filter_dihedrals", 5)
94 | self.nonbonded_atom_idx_src, self.nonbonded_atom_idx_dst = method(
95 | "gen_nonbonded_pair", 2
96 | )
97 | self.lj_idx = (
98 | None
99 | if ctable is None
100 | else ctable.generate_lj_idx(
101 | self.nonbonded_atom_idx_src, self.nonbonded_atom_idx_dst
102 | )
103 | )
104 | 
105 | # The following attributes contain the batch indices.
106 | # They need to be increased by 1 during batching.
107 | self.bond_batch = zeros_like("bond_idx")
108 | self.angle_batch = zeros_like("angle_idx")
109 | self.dihedral_batch = zeros_like("dihedral_idx")
110 | self.nonbonded_batch = zeros_like("lj_idx")
111 | 
112 | def __inc__(self, key, value, *args, **kwargs):
113 | if "atom_idx" in key:
114 | # All atom indices need to be increased by the number
115 | # of nodes during batching.
116 | return self._natom
117 | elif "other_idx" in key:
118 | return self._natom
119 | elif "_batch" in key:
120 | # All batch indices need to be increased by 1 during batching.
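# (Worked example, added: batching two 12-atom proteins shifts the second
# graph's *_batch entries from 0 to 1, its *_atom_idx entries by natom = 12,
# and bond_idx / angle_idx / dihedral_idx / lj_idx by the per-graph table
# sizes returned in the branches below.)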
121 | return 1 122 | elif key == "bond_idx": 123 | return self._numbnd 124 | elif key == "angle_idx": 125 | return self._numang 126 | elif key == "dihedral_idx": 127 | return self._nptra 128 | elif key == "lj_idx": 129 | return self._ntypes * (self._ntypes + 1) // 2 130 | return super().__inc__(key, value, *args, **kwargs) 131 | 132 | 133 | class ProteinDataBatch(Batch, ProteinData): 134 | """ 135 | Class representing a batch of protein structures as a graph. 136 | This class modifies the `get_example` method so that 137 | only the specified attributes are unbatched to save memory. 138 | """ 139 | 140 | def get_example(self, idx, keys=("pos", "z")): 141 | """ 142 | Get the `idx`-th example from the batch. Only the specified 143 | attributes are unbatched. 144 | :param idx: data index 145 | :param keys: the attributes to be unbatched. If None, all 146 | attributes are unbatched. 147 | :return: a `ProteinData` object 148 | """ 149 | if not hasattr(self, "_slice_dict"): 150 | raise RuntimeError( 151 | ( 152 | "Cannot reconstruct 'Data' object from 'Batch' because " 153 | "'Batch' was not created via 'Batch.from_data_list()'" 154 | ) 155 | ) 156 | 157 | slice_dict = ( 158 | self._slice_dict 159 | if keys is None 160 | else {key: self._slice_dict[key] for key in keys} 161 | ) 162 | data = separate.separate( 163 | cls=self.__class__.__bases__[-1], 164 | batch=self, 165 | idx=idx, 166 | slice_dict=slice_dict, 167 | inc_dict=self._inc_dict, 168 | decrement=True, 169 | ) 170 | 171 | return data 172 | 173 | def to_data_list(self, keys=("pos", "z")): 174 | """ 175 | Convert a batch to a list of `ProteinData` objects. 176 | :param keys: the attributes to be unbatched. If None, 177 | all attributes are unbatched. 178 | :return: a list of `ProteinData` objects 179 | """ 180 | return [self.get_example(i, keys) for i in range(self.num_graphs)] 181 | -------------------------------------------------------------------------------- /src/Fragmentation/mm.in: -------------------------------------------------------------------------------- 1 | zero step md to get energy and force 2 | &cntrl 3 | imin=0, nstlim=0, ntx=1 !0 step md 4 | cut=10, ntb=0, !non-periodic 5 | ntpr=1,ntwf=1,ntwe=1,ntwx=1 ! 
(output frequencies) 6 | &end 7 | END 8 | -------------------------------------------------------------------------------- /src/Fragmentation/prmtop/ANAN.prmtop: -------------------------------------------------------------------------------- 1 | %VERSION VERSION_STAMP = V0001.000 DATE = 06/21/22 18:55:05 2 | %FLAG TITLE 3 | %FORMAT(20a4) 4 | ACE 5 | %FLAG POINTERS 6 | %FORMAT(10I8) 7 | 12 7 7 4 14 4 22 3 0 0 8 | 46 2 4 4 3 7 10 7 7 0 9 | 0 0 0 0 0 0 0 0 6 0 10 | 0 11 | %FLAG ATOM_NAME 12 | %FORMAT(20a4) 13 | H1 CH3 H2 H3 C O N H CH3 HH31HH32HH33 14 | %FLAG CHARGE 15 | %FORMAT(5E16.8) 16 | 2.04636429E+00 -6.67300626E+00 2.04636429E+00 2.04636429E+00 1.08823576E+01 17 | -1.03484442E+01 -7.57501011E+00 4.95464337E+00 -2.71512270E+00 1.77849648E+00 18 | 1.77849648E+00 1.77849648E+00 19 | %FLAG ATOMIC_NUMBER 20 | %FORMAT(10I8) 21 | 1 6 1 1 6 8 7 1 6 1 22 | 1 1 23 | %FLAG MASS 24 | %FORMAT(5E16.8) 25 | 1.00800000E+00 1.20100000E+01 1.00800000E+00 1.00800000E+00 1.20100000E+01 26 | 1.60000000E+01 1.40100000E+01 1.00800000E+00 1.20100000E+01 1.00800000E+00 27 | 1.00800000E+00 1.00800000E+00 28 | %FLAG ATOM_TYPE_INDEX 29 | %FORMAT(10I8) 30 | 1 2 1 1 3 4 5 6 2 7 31 | 7 7 32 | %FLAG NUMBER_EXCLUDED_ATOMS 33 | %FORMAT(10I8) 34 | 6 7 4 3 7 3 5 4 3 2 35 | 1 1 36 | %FLAG NONBONDED_PARM_INDEX 37 | %FORMAT(10I8) 38 | 1 2 4 7 11 16 22 2 3 5 39 | 8 12 17 23 4 5 6 9 13 18 40 | 24 7 8 9 10 14 19 25 11 12 41 | 13 14 15 20 26 16 17 18 19 20 42 | 21 27 22 23 24 25 26 27 28 43 | %FLAG RESIDUE_LABEL 44 | %FORMAT(20a4) 45 | ACE NME 46 | %FLAG RESIDUE_POINTER 47 | %FORMAT(10I8) 48 | 1 7 49 | %FLAG BOND_FORCE_CONSTANT 50 | %FORMAT(5E16.8) 51 | 5.70000000E+02 4.90000000E+02 3.40000000E+02 3.17000000E+02 3.40000000E+02 52 | 4.34000000E+02 3.37000000E+02 53 | %FLAG BOND_EQUIL_VALUE 54 | %FORMAT(5E16.8) 55 | 1.22900000E+00 1.33500000E+00 1.09000000E+00 1.52200000E+00 1.09000000E+00 56 | 1.01000000E+00 1.44900000E+00 57 | %FLAG ANGLE_FORCE_CONSTANT 58 | %FORMAT(5E16.8) 59 | 8.00000000E+01 5.00000000E+01 5.00000000E+01 5.00000000E+01 3.50000000E+01 60 | 8.00000000E+01 7.00000000E+01 3.50000000E+01 5.00000000E+01 5.00000000E+01 61 | %FLAG ANGLE_EQUIL_VALUE 62 | %FORMAT(5E16.8) 63 | 2.14501057E+00 2.09439600E+00 2.12755727E+00 1.91113635E+00 1.91113635E+00 64 | 2.10137732E+00 2.03505478E+00 1.91113635E+00 2.06018753E+00 1.91113635E+00 65 | %FLAG DIHEDRAL_FORCE_CONSTANT 66 | %FORMAT(5E16.8) 67 | 2.00000000E+00 2.50000000E+00 0.00000000E+00 8.00000000E-01 8.00000000E-02 68 | 1.05000000E+01 1.10000000E+00 69 | %FLAG DIHEDRAL_PERIODICITY 70 | %FORMAT(5E16.8) 71 | 1.00000000E+00 2.00000000E+00 2.00000000E+00 1.00000000E+00 3.00000000E+00 72 | 2.00000000E+00 2.00000000E+00 73 | %FLAG DIHEDRAL_PHASE 74 | %FORMAT(5E16.8) 75 | 0.00000000E+00 3.14159400E+00 0.00000000E+00 0.00000000E+00 3.14159400E+00 76 | 3.14159400E+00 3.14159400E+00 77 | %FLAG SCEE_SCALE_FACTOR 78 | %FORMAT(5E16.8) 79 | 1.20000000E+00 1.20000000E+00 1.20000000E+00 1.20000000E+00 1.20000000E+00 80 | 0.00000000E+00 0.00000000E+00 81 | %FLAG SCNB_SCALE_FACTOR 82 | %FORMAT(5E16.8) 83 | 2.00000000E+00 2.00000000E+00 2.00000000E+00 2.00000000E+00 2.00000000E+00 84 | 0.00000000E+00 0.00000000E+00 85 | %FLAG SOLTY 86 | %FORMAT(5E16.8) 87 | 0.00000000E+00 0.00000000E+00 0.00000000E+00 0.00000000E+00 0.00000000E+00 88 | 0.00000000E+00 0.00000000E+00 89 | %FLAG LENNARD_JONES_ACOEF 90 | %FORMAT(5E16.8) 91 | 7.51607703E+03 9.71708117E+04 1.04308023E+06 8.61541883E+04 9.24822270E+05 92 | 8.19971662E+05 5.44261042E+04 6.47841731E+05 5.74393458E+05 3.79876399E+05 93 | 
8.96776989E+04 9.95480466E+05 8.82619071E+05 6.06829342E+05 9.44293233E+05 94 | 1.07193646E+02 2.56678134E+03 2.27577561E+03 1.02595236E+03 2.12601181E+03 95 | 1.39982777E-01 4.98586848E+03 6.78771368E+04 6.01816484E+04 3.69471530E+04 96 | 6.20665997E+04 5.94667300E+01 3.25969625E+03 97 | %FLAG LENNARD_JONES_BCOEF 98 | %FORMAT(5E16.8) 99 | 2.17257828E+01 1.26919150E+02 6.75612247E+02 1.12529845E+02 5.99015525E+02 100 | 5.31102864E+02 1.11805549E+02 6.26720080E+02 5.55666448E+02 5.64885984E+02 101 | 1.36131731E+02 7.36907417E+02 6.53361429E+02 6.77220874E+02 8.01323529E+02 102 | 2.59456373E+00 2.06278363E+01 1.82891803E+01 1.53505284E+01 2.09604198E+01 103 | 9.37598976E-02 1.76949863E+01 1.06076943E+02 9.40505980E+01 9.21192136E+01 104 | 1.13252061E+02 1.93248820E+00 1.43076527E+01 105 | %FLAG BONDS_INC_HYDROGEN 106 | %FORMAT(10I8) 107 | 3 6 3 3 9 3 0 3 3 24 108 | 27 5 24 30 5 24 33 5 18 21 109 | 6 110 | %FLAG BONDS_WITHOUT_HYDROGEN 111 | %FORMAT(10I8) 112 | 12 15 1 12 18 2 3 12 4 18 113 | 24 7 114 | %FLAG ANGLES_INC_HYDROGEN 115 | %FORMAT(10I8) 116 | 12 18 21 2 9 3 12 4 6 3 117 | 9 5 6 3 12 4 0 3 6 5 118 | 0 3 9 5 0 3 12 4 30 24 119 | 33 8 27 24 30 8 27 24 33 8 120 | 21 18 24 9 18 24 27 10 18 24 121 | 30 10 18 24 33 10 122 | %FLAG ANGLES_WITHOUT_HYDROGEN 123 | %FORMAT(10I8) 124 | 15 12 18 1 12 18 24 3 3 12 125 | 15 6 3 12 18 7 126 | %FLAG DIHEDRALS_INC_HYDROGEN 127 | %FORMAT(10I8) 128 | 15 12 18 21 1 15 12 -18 21 2 129 | 12 18 24 27 3 12 18 24 30 3 130 | 12 18 24 33 3 9 3 12 15 4 131 | 9 3 -12 15 3 9 3 -12 15 5 132 | 9 3 12 18 3 6 3 12 15 4 133 | 6 3 -12 15 3 6 3 -12 15 5 134 | 6 3 12 18 3 3 12 18 21 2 135 | 0 3 12 15 4 0 3 -12 15 3 136 | 0 3 -12 15 5 0 3 12 18 3 137 | 21 18 24 27 3 21 18 24 30 3 138 | 21 18 24 33 3 12 24 -18 -21 7 139 | %FLAG DIHEDRALS_WITHOUT_HYDROGEN 140 | %FORMAT(10I8) 141 | 15 12 18 24 2 3 12 18 24 2 142 | 3 18 -12 -15 6 143 | %FLAG EXCLUDED_ATOMS_LIST 144 | %FORMAT(10I8) 145 | 2 3 4 5 6 7 3 4 5 6 146 | 7 8 9 4 5 6 7 5 6 7 147 | 6 7 8 9 10 11 12 7 8 9 148 | 8 9 10 11 12 9 10 11 12 10 149 | 11 12 11 12 12 0 150 | %FLAG HBOND_ACOEF 151 | %FORMAT(5E16.8) 152 | 153 | %FLAG HBOND_BCOEF 154 | %FORMAT(5E16.8) 155 | 156 | %FLAG HBCUT 157 | %FORMAT(5E16.8) 158 | 159 | %FLAG AMBER_ATOM_TYPE 160 | %FORMAT(20a4) 161 | HC CT HC HC C O N H CT H1 H1 H1 162 | %FLAG TREE_CHAIN_CLASSIFICATION 163 | %FORMAT(20a4) 164 | M M E E M E M E M E E E 165 | %FLAG JOIN_ARRAY 166 | %FORMAT(10I8) 167 | 0 0 0 0 0 0 0 0 0 0 168 | 0 0 169 | %FLAG IROTAT 170 | %FORMAT(10I8) 171 | 0 0 0 0 0 0 0 0 0 0 172 | 0 0 173 | %FLAG RADIUS_SET 174 | %FORMAT(1a80) 175 | H(N)-modified Bondi radii (mbondi2) 176 | %FLAG RADII 177 | %FORMAT(5E16.8) 178 | 1.20000000E+00 1.70000000E+00 1.20000000E+00 1.20000000E+00 1.70000000E+00 179 | 1.50000000E+00 1.55000000E+00 1.30000000E+00 1.70000000E+00 1.20000000E+00 180 | 1.20000000E+00 1.20000000E+00 181 | %FLAG SCREEN 182 | %FORMAT(5E16.8) 183 | 8.50000000E-01 7.20000000E-01 8.50000000E-01 8.50000000E-01 7.20000000E-01 184 | 8.50000000E-01 7.90000000E-01 8.50000000E-01 7.20000000E-01 8.50000000E-01 185 | 8.50000000E-01 8.50000000E-01 186 | %FLAG IPOL 187 | %FORMAT(1I8) 188 | 0 189 | %FLAG CMAP_COUNT 190 | %FORMAT(2I8) 191 | 0 0 192 | %FLAG CMAP_RESOLUTION 193 | %FORMAT(20I4) 194 | 195 | %FLAG CMAP_INDEX 196 | %FORMAT(6I8) 197 | -------------------------------------------------------------------------------- /src/ViSNet/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/ViSNet/__init__.py -------------------------------------------------------------------------------- /src/ViSNet/checkpoints/visnet-uni-2ef43f29ec78fa5fef0b3de832bfada9.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/ViSNet/checkpoints/visnet-uni-2ef43f29ec78fa5fef0b3de832bfada9.ckpt -------------------------------------------------------------------------------- /src/ViSNet/checkpoints/visnet-uni-de11d1421ccda37ffab07d7403c8f5bb.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/ViSNet/checkpoints/visnet-uni-de11d1421ccda37ffab07d7403c8f5bb.ckpt -------------------------------------------------------------------------------- /src/ViSNet/model/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["ViSNetBlock", "EquivariantScalar"] 2 | -------------------------------------------------------------------------------- /src/ViSNet/model/output_modules.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .utils import act_class_mapping 7 | 8 | 9 | class GatedEquivariantBlock(nn.Module): 10 | """ 11 | Gated Equivariant Block as defined in Schütt et al. (2021): 12 | Equivariant message passing for the prediction of 13 | tensorial properties and molecular spectra 14 | """ 15 | 16 | def __init__( 17 | self, 18 | hidden_channels, 19 | out_channels, 20 | intermediate_channels=None, 21 | activation="silu", 22 | scalar_activation=False, 23 | ): 24 | super(GatedEquivariantBlock, self).__init__() 25 | self.out_channels = out_channels 26 | 27 | if intermediate_channels is None: 28 | intermediate_channels = hidden_channels 29 | 30 | self.vec1_proj = nn.Linear( 31 | hidden_channels, hidden_channels, bias=False 32 | ) 33 | self.vec2_proj = nn.Linear(hidden_channels, out_channels, bias=False) 34 | 35 | act_class = act_class_mapping[activation] 36 | self.update_net = nn.Sequential( 37 | nn.Linear(hidden_channels * 2, intermediate_channels), 38 | act_class(), 39 | nn.Linear(intermediate_channels, out_channels * 2), 40 | ) 41 | 42 | self.act = act_class() if scalar_activation else None 43 | 44 | def reset_parameters(self): 45 | nn.init.xavier_uniform_(self.vec1_proj.weight) 46 | nn.init.xavier_uniform_(self.vec2_proj.weight) 47 | nn.init.xavier_uniform_(self.update_net[0].weight) 48 | self.update_net[0].bias.data.fill_(0) 49 | nn.init.xavier_uniform_(self.update_net[2].weight) 50 | self.update_net[2].bias.data.fill_(0) 51 | 52 | def forward(self, x, v): 53 | vec1 = torch.norm(self.vec1_proj(v), dim=-2) 54 | vec2 = self.vec2_proj(v) 55 | 56 | x = torch.cat([x, vec1], dim=-1) 57 | x, v = torch.split(self.update_net(x), self.out_channels, dim=-1) 58 | v = v.unsqueeze(1) * vec2 59 | 60 | if self.act is not None: 61 | x = self.act(x) 62 | return x, v 63 | 64 | 65 | class OutputModel(nn.Module, metaclass=ABCMeta): 66 | def __init__(self, allow_prior_model): 67 | super(OutputModel, self).__init__() 68 | self.allow_prior_model = allow_prior_model 69 | 70 | def reset_parameters(self): 71 | pass 72 | 73 | @abstractmethod 74 | def pre_reduce(self, x, v, z, pos, batch): 75 | return 
76 | 77 | def post_reduce(self, x): 78 | return x 79 | 80 | 81 | class Scalar(OutputModel): 82 | def __init__( 83 | self, hidden_channels, activation="silu", allow_prior_model=True 84 | ): 85 | super(Scalar, self).__init__(allow_prior_model=allow_prior_model) 86 | act_class = act_class_mapping[activation] 87 | self.output_network = nn.Sequential( 88 | nn.Linear(hidden_channels, hidden_channels // 2), 89 | act_class(), 90 | nn.Linear(hidden_channels // 2, 1), 91 | ) 92 | 93 | self.reset_parameters() 94 | 95 | def reset_parameters(self): 96 | nn.init.xavier_uniform_(self.output_network[0].weight) 97 | self.output_network[0].bias.data.fill_(0) 98 | nn.init.xavier_uniform_(self.output_network[2].weight) 99 | self.output_network[2].bias.data.fill_(0) 100 | 101 | def pre_reduce(self, x, v, z, pos, batch): 102 | # include v in output to make sure all parameters have a gradient 103 | return self.output_network(x) 104 | 105 | 106 | class EquivariantScalar(OutputModel): 107 | def __init__( 108 | self, hidden_channels, activation="silu", allow_prior_model=True 109 | ): 110 | super(EquivariantScalar, self).__init__( 111 | allow_prior_model=allow_prior_model 112 | ) 113 | self.output_network = nn.ModuleList( 114 | [ 115 | GatedEquivariantBlock( 116 | hidden_channels, 117 | hidden_channels // 2, 118 | activation=activation, 119 | scalar_activation=True, 120 | ), 121 | GatedEquivariantBlock( 122 | hidden_channels // 2, 123 | 1, 124 | activation=activation, 125 | scalar_activation=False, 126 | ), 127 | ] 128 | ) 129 | 130 | self.reset_parameters() 131 | 132 | def reset_parameters(self): 133 | for layer in self.output_network: 134 | layer.reset_parameters() 135 | 136 | def pre_reduce(self, x, v, z, pos, batch): 137 | for layer in self.output_network: 138 | x, v = layer(x, v) 139 | # include v in output to make sure all parameters have a gradient 140 | return x + v.sum() * 0 141 | -------------------------------------------------------------------------------- /src/ViSNet/model/priors.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | from pytorch_lightning.utilities import rank_zero_warn 6 | 7 | __all__ = ["Atomref"] 8 | 9 | 10 | class BasePrior(nn.Module, metaclass=ABCMeta): 11 | """ 12 | Base class for prior models. 13 | Derive this class to make custom prior models, 14 | which take some arguments and a dataset as input. 15 | As an example, have a look at the `torchmdnet.priors.Atomref` prior. 16 | """ 17 | 18 | def __init__(self): 19 | super(BasePrior, self).__init__() 20 | 21 | @abstractmethod 22 | def get_init_args(self): 23 | """ 24 | A function that returns all required arguments to construct a prior object. 25 | The values should be returned inside a dict 26 | with the keys being the arguments' names. 27 | All values should also be saveable in a .yaml file 28 | as this is used to reconstruct the 29 | prior model from a checkpoint file. 30 | """ 31 | return 32 | 33 | @abstractmethod 34 | def forward(self, x, z): 35 | """ 36 | Forward method of the prior model. 37 | 38 | Args: 39 | x (torch.Tensor): scalar atomwise predictions from the model. 40 | z (torch.Tensor): atom types of all atoms. 41 | 42 | Returns: 43 | torch.Tensor: updated scalar atomwise predictions 44 | """ 45 | return 46 | 47 | 48 | class Atomref(BasePrior): 49 | """ 50 | Atomref prior model. 
51 | When using this in combination with some dataset, the dataset class must implement 52 | the function `get_atomref`, which returns the atomic reference values as a tensor. 53 | """ 54 | 55 | def __init__(self, max_z=None, dataset=None): 56 | super(Atomref, self).__init__() 57 | if max_z is None and dataset is None: 58 | raise ValueError( 59 | "Can't instantiate Atomref prior, all arguments are None." 60 | ) 61 | if dataset is None: 62 | atomref = torch.zeros(max_z, 1) 63 | else: 64 | atomref = dataset.get_atomref() 65 | if atomref is None: 66 | rank_zero_warn( 67 | "The atomref returned by the dataset is None," 68 | " defaulting to zeros with max. " 69 | "atomic number 99. Maybe atomref is not defined" 70 | " for the current target." 71 | ) 72 | atomref = torch.zeros(100, 1) 73 | 74 | if atomref.ndim == 1: 75 | atomref = atomref.view(-1, 1) 76 | self.register_buffer("initial_atomref", atomref) 77 | self.atomref = nn.Embedding(len(atomref), 1) 78 | self.atomref.weight.data.copy_(atomref) 79 | 80 | def reset_parameters(self): 81 | self.atomref.weight.data.copy_(self.initial_atomref) 82 | 83 | def get_init_args(self): 84 | return dict(max_z=self.initial_atomref.size(0)) 85 | 86 | def forward(self, x, z): 87 | return x + self.atomref(z) 88 | -------------------------------------------------------------------------------- /src/ViSNet/model/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch_cluster import radius_graph 7 | from torch_geometric.nn import MessagePassing 8 | from torch_sparse import SparseTensor 9 | 10 | class CosineCutoff(nn.Module): 11 | def __init__(self, cutoff): 12 | super(CosineCutoff, self).__init__() 13 | 14 | self.cutoff = cutoff 15 | 16 | def forward(self, distances): 17 | cutoffs = 0.5 * (torch.cos(distances * math.pi / self.cutoff) + 1.0) 18 | cutoffs = cutoffs * (distances < self.cutoff).float() 19 | return cutoffs 20 | 21 | 22 | class ExpNormalSmearing(nn.Module): 23 | def __init__(self, cutoff=5.0, num_rbf=50, trainable=True): 24 | super(ExpNormalSmearing, self).__init__() 25 | self.cutoff = cutoff 26 | self.num_rbf = num_rbf 27 | self.trainable = trainable 28 | 29 | self.cutoff_fn = CosineCutoff(cutoff) 30 | self.alpha = 5.0 / cutoff 31 | 32 | means, betas = self._initial_params() 33 | if trainable: 34 | self.register_parameter("means", nn.Parameter(means)) 35 | self.register_parameter("betas", nn.Parameter(betas)) 36 | else: 37 | self.register_buffer("means", means) 38 | self.register_buffer("betas", betas) 39 | 40 | def _initial_params(self): 41 | start_value = torch.exp(torch.scalar_tensor(-self.cutoff)) 42 | means = torch.linspace(start_value, 1, self.num_rbf) 43 | betas = torch.tensor( 44 | [(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf 45 | ) 46 | return means, betas 47 | 48 | def reset_parameters(self): 49 | means, betas = self._initial_params() 50 | self.means.data.copy_(means) 51 | self.betas.data.copy_(betas) 52 | 53 | def forward(self, dist): 54 | dist = dist.unsqueeze(-1) 55 | return self.cutoff_fn(dist) * torch.exp( 56 | -self.betas * (torch.exp(self.alpha * (-dist)) - self.means) ** 2 57 | ) 58 | 59 | 60 | class GaussianSmearing(nn.Module): 61 | def __init__(self, cutoff=5.0, num_rbf=50, trainable=True): 62 | super(GaussianSmearing, self).__init__() 63 | self.cutoff = cutoff 64 | self.num_rbf = num_rbf 65 | self.trainable = trainable 66 | 67 | offset, coeff = self._initial_params() 
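# (Added note: `offset` holds num_rbf Gaussian centers spaced uniformly on
# [0, cutoff] and coeff = -0.5 / spacing**2, so forward() expands a distance
# d into exp(-(d - offset_k)**2 / (2 * spacing**2)) for each basis k.)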
68 | if trainable: 69 | self.register_parameter("coeff", nn.Parameter(coeff)) 70 | self.register_parameter("offset", nn.Parameter(offset)) 71 | else: 72 | self.register_buffer("coeff", coeff) 73 | self.register_buffer("offset", offset) 74 | 75 | def _initial_params(self): 76 | offset = torch.linspace(0, self.cutoff, self.num_rbf) 77 | coeff = -0.5 / (offset[1] - offset[0]) ** 2 78 | return offset, coeff 79 | 80 | def reset_parameters(self): 81 | offset, coeff = self._initial_params() 82 | self.offset.data.copy_(offset) 83 | self.coeff.data.copy_(coeff) 84 | 85 | def forward(self, dist): 86 | dist = dist.unsqueeze(-1) - self.offset 87 | return torch.exp(self.coeff * torch.pow(dist, 2)) 88 | 89 | 90 | rbf_class_mapping = {"gauss": GaussianSmearing, "expnorm": ExpNormalSmearing} 91 | 92 | 93 | class ShiftedSoftplus(nn.Module): 94 | def __init__(self): 95 | super(ShiftedSoftplus, self).__init__() 96 | self.shift = torch.log(torch.tensor(2.0)).item() 97 | 98 | def forward(self, x): 99 | return F.softplus(x) - self.shift 100 | 101 | 102 | class Swish(nn.Module): 103 | def __init__(self): 104 | super(Swish, self).__init__() 105 | 106 | def forward(self, x): 107 | return x * torch.sigmoid(x) 108 | 109 | 110 | act_class_mapping = { 111 | "ssp": ShiftedSoftplus, 112 | "silu": nn.SiLU, 113 | "tanh": nn.Tanh, 114 | "sigmoid": nn.Sigmoid, 115 | "swish": Swish, 116 | } 117 | 118 | 119 | class Sphere(nn.Module): 120 | def __init__(self, lmax=2): 121 | super(Sphere, self).__init__() 122 | self.l = lmax 123 | 124 | def forward(self, edge_vec): 125 | edge_sh = self._spherical_harmonics( 126 | self.l, edge_vec[..., 0], edge_vec[..., 1], edge_vec[..., 2] 127 | ) 128 | return edge_sh 129 | 130 | @staticmethod 131 | def _spherical_harmonics( 132 | lmax: int, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor 133 | ) -> torch.Tensor: 134 | sh_1_0, sh_1_1, sh_1_2 = x, y, z 135 | 136 | if lmax == 1: 137 | return torch.stack([sh_1_0, sh_1_1, sh_1_2], dim=-1) 138 | 139 | sh_2_0 = math.sqrt(3.0) * x * z 140 | sh_2_1 = math.sqrt(3.0) * x * y 141 | y2 = y.pow(2) 142 | x2z2 = x.pow(2) + z.pow(2) 143 | sh_2_2 = y2 - 0.5 * x2z2 144 | sh_2_3 = math.sqrt(3.0) * y * z 145 | sh_2_4 = math.sqrt(3.0) / 2.0 * (z.pow(2) - x.pow(2)) 146 | 147 | if lmax == 2: 148 | return torch.stack( 149 | [ 150 | sh_1_0, 151 | sh_1_1, 152 | sh_1_2, 153 | sh_2_0, 154 | sh_2_1, 155 | sh_2_2, 156 | sh_2_3, 157 | sh_2_4, 158 | ], 159 | dim=-1, 160 | ) 161 | else: 162 | raise ValueError(f"lmax={lmax}") 163 | 164 | 165 | class VecLayerNorm(nn.Module): 166 | def __init__(self, hidden_channels, trainable, norm_type="max_min"): 167 | super(VecLayerNorm, self).__init__() 168 | 169 | self.hidden_channels = hidden_channels 170 | self.eps = 1e-12 171 | 172 | weight = torch.ones(self.hidden_channels) 173 | if trainable: 174 | self.register_parameter("weight", nn.Parameter(weight)) 175 | else: 176 | self.register_buffer("weight", weight) 177 | 178 | self.norm_type = norm_type 179 | 180 | self.reset_parameters() 181 | 182 | def reset_parameters(self): 183 | weight = torch.ones(self.hidden_channels) 184 | self.weight.data.copy_(weight) 185 | 186 | def none_norm(self, vec): 187 | return vec 188 | 189 | def rms_norm(self, vec): 190 | # vec: (num_atoms, 3 or 5, hidden_channels) 191 | dist = torch.norm(vec, dim=1) 192 | 193 | if (dist == 0).all(): 194 | return torch.zeros_like(vec) 195 | 196 | dist = dist.clamp(min=self.eps) 197 | dist = torch.sqrt(torch.mean(dist**2, dim=-1)) 198 | return vec / F.relu(dist).unsqueeze(-1).unsqueeze(-1) 199 | 200 | def 
max_min_norm(self, vec):
201 | # vec: (num_atoms, 3 or 5, hidden_channels)
202 | dist = torch.norm(vec, dim=1, keepdim=True)
203 | 
204 | if (dist == 0).all():
205 | return torch.zeros_like(vec)
206 | 
207 | dist = dist.clamp(min=self.eps)
208 | direct = vec / dist
209 | 
210 | max_val, _ = torch.max(dist, dim=-1)
211 | min_val, _ = torch.min(dist, dim=-1)
212 | delta = (max_val - min_val).view(-1)
213 | delta = torch.where(delta == 0, torch.ones_like(delta), delta)
214 | dist = (dist - min_val.view(-1, 1, 1)) / delta.view(-1, 1, 1)
215 | 
216 | return F.relu(dist) * direct
217 | 
218 | def forward(self, vec):
219 | # vec: (num_atoms, 3 or 8, hidden_channels)
220 | if vec.shape[1] == 3:
221 | if self.norm_type == "rms":
222 | vec = self.rms_norm(vec)
223 | elif self.norm_type == "max_min":
224 | vec = self.max_min_norm(vec)
225 | else:
226 | vec = self.none_norm(vec)
227 | 
228 | return vec * self.weight.unsqueeze(0).unsqueeze(0)
229 | elif vec.shape[1] == 8:
230 | vec1, vec2 = torch.split(vec, [3, 5], dim=1)
231 | 
232 | if self.norm_type == "rms":
233 | vec1 = self.rms_norm(vec1)
234 | elif self.norm_type == "max_min":
235 | vec1 = self.max_min_norm(vec1)
236 | else:
237 | vec1 = self.none_norm(vec1)
238 | 
239 | if self.norm_type == "rms":
240 | vec2 = self.rms_norm(vec2)
241 | elif self.norm_type == "max_min":
242 | vec2 = self.max_min_norm(vec2)
243 | else:
244 | vec2 = self.none_norm(vec2)
245 | 
246 | vec = torch.cat([vec1, vec2], dim=1)
247 | return vec * self.weight.unsqueeze(0).unsqueeze(0)
248 | else:
249 | raise ValueError("VecLayerNorm only supports 3 or 8 channels")
250 | 
251 | 
252 | class Distance(nn.Module):
253 | def __init__(self, cutoff, max_num_neighbors=32, loop=True):
254 | super(Distance, self).__init__()
255 | self.cutoff = cutoff
256 | self.max_num_neighbors = max_num_neighbors
257 | self.loop = loop
258 | 
259 | def forward(self, pos, batch):
260 | edge_index = radius_graph(
261 | pos,
262 | r=self.cutoff,
263 | batch=batch,
264 | loop=self.loop,
265 | max_num_neighbors=self.max_num_neighbors,
266 | )
267 | edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
268 | 
269 | if self.loop:
270 | mask = edge_index[0] != edge_index[1]
271 | edge_weight = torch.zeros(edge_vec.size(0), device=edge_vec.device)
272 | edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1)
273 | else:
274 | edge_weight = torch.norm(edge_vec, dim=-1)
275 | 
276 | return edge_index, edge_weight, edge_vec
277 | 
278 | 
279 | class NeighborEmbedding(MessagePassing):
280 | def __init__(self, hidden_channels, num_rbf, cutoff, max_z):
281 | super(NeighborEmbedding, self).__init__(aggr="add")
282 | self.embedding = nn.Embedding(max_z, hidden_channels)
283 | self.distance_proj = nn.Linear(num_rbf, hidden_channels)
284 | self.combine = nn.Linear(hidden_channels * 2, hidden_channels)
285 | self.cutoff = CosineCutoff(cutoff)
286 | 
287 | self.reset_parameters()
288 | 
289 | def reset_parameters(self):
290 | self.embedding.reset_parameters()
291 | nn.init.xavier_uniform_(self.distance_proj.weight)
292 | nn.init.xavier_uniform_(self.combine.weight)
293 | self.distance_proj.bias.data.fill_(0)
294 | self.combine.bias.data.fill_(0)
295 | 
296 | def forward(self, z, x, edge_index, edge_weight, edge_attr):
297 | # remove self loops
298 | mask = edge_index[0] != edge_index[1]
299 | if not mask.all():
300 | edge_index = edge_index[:, mask]
301 | edge_weight = edge_weight[mask]
302 | edge_attr = edge_attr[mask]
303 | 
304 | C = self.cutoff(edge_weight)
305 | W = self.distance_proj(edge_attr) * C.view(-1, 1)
306 | 
307 | x_neighbors = 
self.embedding(z) 308 | 309 | # propagate_type: (x: Tensor, W: Tensor) 310 | x_neighbors = self.propagate( 311 | edge_index, x=x_neighbors, W=W, size=None 312 | ) 313 | x_neighbors = self.combine(torch.cat([x, x_neighbors], dim=1)) 314 | return x_neighbors 315 | 316 | def message(self, x_j, W): 317 | return x_j * W 318 | 319 | 320 | class EdgeEmbedding(MessagePassing): 321 | def __init__(self, num_rbf, hidden_channels): 322 | super(EdgeEmbedding, self).__init__(aggr=None) 323 | self.edge_proj = nn.Linear(num_rbf, hidden_channels) 324 | 325 | self.reset_parameters() 326 | 327 | def reset_parameters(self): 328 | nn.init.xavier_uniform_(self.edge_proj.weight) 329 | self.edge_proj.bias.data.fill_(0) 330 | 331 | def forward(self, edge_index: torch.Tensor, edge_attr, x): 332 | # propagate_type: (x: Tensor, edge_attr: Tensor) 333 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None) 334 | return out 335 | 336 | def message(self, x_i, x_j, edge_attr): 337 | return (x_i + x_j) * self.edge_proj(edge_attr) 338 | 339 | def aggregate(self, features, index): 340 | # no aggregate 341 | return features 342 | -------------------------------------------------------------------------------- /src/ViSNet/model/visnet.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | from pytorch_lightning.utilities import rank_zero_warn 7 | from torch import Tensor 8 | from torch.autograd import grad 9 | from torch_scatter import scatter 10 | 11 | from . import output_modules, priors 12 | 13 | 14 | def create_model(args, prior_model=None, mean=None, std=None): 15 | visnet_args = dict( 16 | lmax=args["lmax"], 17 | vecnorm_type=args["vecnorm_type"], 18 | trainable_vecnorm=args["trainable_vecnorm"], 19 | num_heads=args["num_heads"], 20 | num_layers=args["num_layers"], 21 | hidden_channels=args["embedding_dimension"], 22 | num_rbf=args["num_rbf"], 23 | rbf_type=args["rbf_type"], 24 | trainable_rbf=args["trainable_rbf"], 25 | activation=args["activation"], 26 | attn_activation=args["attn_activation"], 27 | max_z=args["max_z"], 28 | cutoff=args["cutoff"], 29 | max_num_neighbors=args["max_num_neighbors"], 30 | ) 31 | 32 | # representation network 33 | if args["model"] == "ViSNetBlock": 34 | from .visnet_block import ViSNetBlock 35 | 36 | representation_model = ViSNetBlock(**visnet_args) 37 | else: 38 | raise ValueError(f"Unknown model {args['model']}.") 39 | 40 | # prior model 41 | if args["prior_model"] and prior_model is None: 42 | assert "prior_args" in args, ( 43 | f"Requested prior model {args['prior_model']} but the " 44 | f'arguments are lacking the key "prior_args".' 45 | ) 46 | assert hasattr(priors, args["prior_model"]), ( 47 | f'Unknown prior model {args["prior_model"]}. ' 48 | f'Available models are {", ".join(priors.__all__)}' 49 | ) 50 | # instantiate prior model if it was not passed to create_model 51 | # (i.e. 
when loading a model) 52 | prior_model = getattr(priors, args["prior_model"])( 53 | **args["prior_args"] 54 | ) 55 | 56 | # create output network 57 | output_model = getattr( 58 | output_modules, "Equivariant" + args["output_model"] 59 | )(args["embedding_dimension"], args["activation"]) 60 | 61 | model = ViSNet( 62 | representation_model, 63 | output_model, 64 | prior_model=prior_model, 65 | reduce_op=args["reduce_op"], 66 | mean=mean, 67 | std=std, 68 | derivative=args["derivative"], 69 | ) 70 | return model 71 | 72 | 73 | def load_model(filepath, args=None, device="cpu", **kwargs): 74 | ckpt = torch.load(filepath, map_location="cpu") 75 | if args is None: 76 | args = ckpt["hyper_parameters"] 77 | 78 | for key, value in kwargs.items(): 79 | if key not in args: 80 | rank_zero_warn(f"Unknown hyperparameter: {key}={value}") 81 | args[key] = value 82 | 83 | model = create_model(args) 84 | state_dict = { 85 | re.sub(r"^model\.", "", k): v for k, v in ckpt["state_dict"].items() 86 | } 87 | model.load_state_dict(state_dict) 88 | 89 | for p in model.parameters(): 90 | p.requires_grad=False 91 | 92 | model = torch.jit.script(model) 93 | return model.to(device) 94 | 95 | 96 | class ViSNet(nn.Module): 97 | def __init__( 98 | self, 99 | representation_model, 100 | output_model, 101 | prior_model=None, 102 | reduce_op="add", 103 | mean=None, 104 | std=None, 105 | derivative=False, 106 | ): 107 | super(ViSNet, self).__init__() 108 | self.representation_model = representation_model 109 | self.output_model = output_model 110 | 111 | self.prior_model = prior_model 112 | if not output_model.allow_prior_model and prior_model is not None: 113 | self.prior_model = None 114 | rank_zero_warn( 115 | "Prior model was given but the output model does " 116 | "not allow prior models. Dropping the prior model." 117 | ) 118 | 119 | self.reduce_op = reduce_op 120 | self.derivative = derivative 121 | 122 | mean = torch.scalar_tensor(0) if mean is None else mean 123 | self.register_buffer("mean", mean) 124 | std = torch.scalar_tensor(1) if std is None else std 125 | self.register_buffer("std", std) 126 | 127 | self.reset_parameters() 128 | 129 | def reset_parameters(self): 130 | self.representation_model.reset_parameters() 131 | self.output_model.reset_parameters() 132 | if self.prior_model is not None: 133 | self.prior_model.reset_parameters() 134 | 135 | def forward(self, data: dict[str, Tensor]) -> Tuple[Tensor, Optional[Tensor]]: 136 | if self.derivative: 137 | data['pos'].requires_grad_(True) 138 | 139 | x, v = self.representation_model(data) 140 | x = self.output_model.pre_reduce(x, v, data['z'], data['pos'], data['batch']) 141 | x = x * self.std 142 | 143 | if self.prior_model is not None: 144 | x = self.prior_model(x, data['z']) 145 | 146 | out = scatter(x, data['batch'], dim=0, reduce=self.reduce_op) 147 | out = self.output_model.post_reduce(out) 148 | 149 | out = out + self.mean 150 | 151 | # compute gradients with respect to coordinates 152 | if self.derivative: 153 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(out)] 154 | dy = grad( 155 | [out], 156 | [data['pos']], 157 | grad_outputs=grad_outputs, 158 | create_graph=False, 159 | retain_graph=False, 160 | )[0] 161 | if dy is None: 162 | raise RuntimeError( 163 | "Autograd returned None for the force prediction." 
164 | ) 165 | return out, -dy 166 | return out, None 167 | -------------------------------------------------------------------------------- /src/ViSNet/model/visnet_block.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from torch_geometric.nn import MessagePassing 7 | from torch_scatter import scatter 8 | 9 | from .utils import ( 10 | CosineCutoff, 11 | Distance, 12 | EdgeEmbedding, 13 | NeighborEmbedding, 14 | Sphere, 15 | VecLayerNorm, 16 | act_class_mapping, 17 | rbf_class_mapping, 18 | ) 19 | 20 | 21 | class ViSNetBlock(nn.Module): 22 | def __init__( 23 | self, 24 | lmax: int=2, 25 | vecnorm_type="none", 26 | trainable_vecnorm=False, 27 | num_heads=8, 28 | num_layers=9, 29 | hidden_channels=256, 30 | num_rbf=32, 31 | rbf_type="expnorm", 32 | trainable_rbf=False, 33 | activation="silu", 34 | attn_activation="silu", 35 | max_z=100, 36 | cutoff=5.0, 37 | max_num_neighbors=32, 38 | ): 39 | super(ViSNetBlock, self).__init__() 40 | self.lmax = lmax 41 | self.vecnorm_type = vecnorm_type 42 | self.trainable_vecnorm = trainable_vecnorm 43 | self.num_heads = num_heads 44 | self.num_layers = num_layers 45 | self.hidden_channels = hidden_channels 46 | self.num_rbf = num_rbf 47 | self.rbf_type = rbf_type 48 | self.trainable_rbf = trainable_rbf 49 | self.activation = activation 50 | self.attn_activation = attn_activation 51 | self.max_z = max_z 52 | self.cutoff = cutoff 53 | self.max_num_neighbors = max_num_neighbors 54 | 55 | self.embedding = nn.Embedding(max_z, hidden_channels) 56 | self.distance = Distance( 57 | cutoff, max_num_neighbors=max_num_neighbors, loop=True 58 | ) 59 | self.sphere = Sphere(lmax=lmax) 60 | self.distance_expansion = rbf_class_mapping[rbf_type]( 61 | cutoff, num_rbf, trainable_rbf 62 | ) 63 | self.neighbor_embedding = NeighborEmbedding(hidden_channels, num_rbf, cutoff, max_z).jittable() 64 | self.edge_embedding = EdgeEmbedding( 65 | num_rbf, hidden_channels 66 | ).jittable() 67 | 68 | self.vis_mp_layers = nn.ModuleList() 69 | vis_mp_kwargs = dict( 70 | num_heads=num_heads, 71 | hidden_channels=hidden_channels, 72 | activation=activation, 73 | attn_activation=attn_activation, 74 | cutoff=cutoff, 75 | vecnorm_type=vecnorm_type, 76 | trainable_vecnorm=trainable_vecnorm, 77 | ) 78 | for _ in range(num_layers - 1): 79 | layer = ViS_MP(last_layer=False, **vis_mp_kwargs).jittable() 80 | self.vis_mp_layers.append(layer) 81 | self.vis_mp_layers.append( 82 | ViS_MP(last_layer=True, **vis_mp_kwargs).jittable() 83 | ) 84 | 85 | self.out_norm = nn.LayerNorm(hidden_channels) 86 | self.vec_out_norm = VecLayerNorm( 87 | hidden_channels, 88 | trainable=trainable_vecnorm, 89 | norm_type=vecnorm_type, 90 | ) 91 | self.reset_parameters() 92 | 93 | def reset_parameters(self): 94 | self.embedding.reset_parameters() 95 | self.distance_expansion.reset_parameters() 96 | self.neighbor_embedding.reset_parameters() 97 | self.edge_embedding.reset_parameters() 98 | for layer in self.vis_mp_layers: 99 | layer.reset_parameters() 100 | self.out_norm.reset_parameters() 101 | self.vec_out_norm.reset_parameters() 102 | 103 | def forward(self, data: dict[str, Tensor]) -> Tuple[Tensor, Tensor]: 104 | # z, pos, batch = data['z'], data['pos'], data['batch'] 105 | z: Tensor = data['z'] 106 | pos: Tensor = data['pos'] 107 | batch: Tensor = data['batch'] 108 | 109 | # Embedding Layers 110 | x = self.embedding(z) 111 | edge_index, edge_weight, edge_vec = 
self.distance(pos, batch) 112 | edge_attr = self.distance_expansion(edge_weight) 113 | mask = edge_index[0] != edge_index[1] 114 | edge_vec[mask] = edge_vec[mask] / torch.norm( 115 | edge_vec[mask], dim=1 116 | ).unsqueeze(1) 117 | edge_vec = self.sphere(edge_vec) 118 | x = self.neighbor_embedding(z, x, edge_index, edge_weight, edge_attr) 119 | vec = torch.zeros( 120 | x.size(0), int((self.lmax + 1) ** 2) - 1, x.size(1), device=x.device 121 | ) 122 | edge_attr: Tensor = self.edge_embedding(edge_index, edge_attr, x) 123 | 124 | # ViS-MP Layers 125 | for attn in self.vis_mp_layers[:-1]: 126 | dx, dvec, dedge_attr = attn.forward( 127 | x, vec, edge_index, edge_weight, edge_attr, edge_vec 128 | ) 129 | x = x + dx 130 | vec = vec + dvec 131 | edge_attr = edge_attr + dedge_attr 132 | 133 | dx, dvec, _ = self.vis_mp_layers[-1]( 134 | x, vec, edge_index, edge_weight, edge_attr, edge_vec 135 | ) 136 | x = x + dx 137 | vec = vec + dvec 138 | 139 | x = self.out_norm(x) 140 | vec = self.vec_out_norm(vec) 141 | 142 | return x, vec 143 | 144 | 145 | class ViS_MP(MessagePassing): 146 | def __init__( 147 | self, 148 | num_heads, 149 | hidden_channels, 150 | activation, 151 | attn_activation, 152 | cutoff, 153 | vecnorm_type, 154 | trainable_vecnorm, 155 | last_layer=False, 156 | ): 157 | super(ViS_MP, self).__init__(aggr="add", node_dim=0) 158 | assert hidden_channels % num_heads == 0, ( 159 | f"The number of hidden channels ({hidden_channels}) " 160 | f"must be evenly divisible by the number of " 161 | f"attention heads ({num_heads})" 162 | ) 163 | 164 | self.num_heads = num_heads 165 | self.hidden_channels = hidden_channels 166 | self.head_dim = hidden_channels // num_heads 167 | self.last_layer = last_layer 168 | 169 | self.layernorm = nn.LayerNorm(hidden_channels) 170 | self.vec_layernorm = VecLayerNorm( 171 | hidden_channels, 172 | trainable=trainable_vecnorm, 173 | norm_type=vecnorm_type, 174 | ) 175 | 176 | self.act = act_class_mapping[activation]() 177 | self.attn_activation = act_class_mapping[attn_activation]() 178 | 179 | self.cutoff = CosineCutoff(cutoff) 180 | 181 | self.vec_proj = nn.Linear( 182 | hidden_channels, hidden_channels * 3, bias=False 183 | ) 184 | 185 | self.q_proj = nn.Linear(hidden_channels, hidden_channels) 186 | self.k_proj = nn.Linear(hidden_channels, hidden_channels) 187 | self.v_proj = nn.Linear(hidden_channels, hidden_channels) 188 | self.dk_proj = nn.Linear(hidden_channels, hidden_channels) 189 | self.dv_proj = nn.Linear(hidden_channels, hidden_channels) 190 | 191 | self.s_proj = nn.Linear(hidden_channels, hidden_channels * 2) 192 | if not self.last_layer: 193 | self.f_proj = nn.Linear(hidden_channels, hidden_channels) 194 | self.w_src_proj = nn.Linear(hidden_channels, hidden_channels, bias=False) 195 | self.w_trg_proj = nn.Linear(hidden_channels, hidden_channels, bias=False) 196 | else: 197 | # XXX just to make tracer happy 198 | self.f_proj = nn.Identity() 199 | self.w_src_proj = nn.Identity() 200 | self.w_trg_proj = nn.Identity() 201 | 202 | self.o_proj = nn.Linear(hidden_channels, hidden_channels * 3) 203 | 204 | self.reset_parameters() 205 | 206 | @staticmethod 207 | def vector_rejection(vec, d_ij): 208 | vec_proj = (vec * d_ij.unsqueeze(2)).sum(dim=1, keepdim=True) 209 | return vec - vec_proj * d_ij.unsqueeze(2) 210 | 211 | def reset_parameters(self): 212 | self.layernorm.reset_parameters() 213 | self.vec_layernorm.reset_parameters() 214 | nn.init.xavier_uniform_(self.q_proj.weight) 215 | self.q_proj.bias.data.fill_(0) 216 | 
nn.init.xavier_uniform_(self.k_proj.weight) 217 | self.k_proj.bias.data.fill_(0) 218 | nn.init.xavier_uniform_(self.v_proj.weight) 219 | self.v_proj.bias.data.fill_(0) 220 | nn.init.xavier_uniform_(self.o_proj.weight) 221 | self.o_proj.bias.data.fill_(0) 222 | nn.init.xavier_uniform_(self.s_proj.weight) 223 | self.s_proj.bias.data.fill_(0) 224 | 225 | if not self.last_layer: 226 | nn.init.xavier_uniform_(self.f_proj.weight) 227 | self.f_proj.bias.data.fill_(0) 228 | nn.init.xavier_uniform_(self.w_src_proj.weight) 229 | nn.init.xavier_uniform_(self.w_trg_proj.weight) 230 | 231 | nn.init.xavier_uniform_(self.vec_proj.weight) 232 | nn.init.xavier_uniform_(self.dk_proj.weight) 233 | self.dk_proj.bias.data.fill_(0) 234 | nn.init.xavier_uniform_(self.dv_proj.weight) 235 | self.dv_proj.bias.data.fill_(0) 236 | 237 | def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij): 238 | x = self.layernorm(x) 239 | vec = self.vec_layernorm(vec) 240 | 241 | q = self.q_proj(x).reshape(-1, self.num_heads, self.head_dim) 242 | k = self.k_proj(x).reshape(-1, self.num_heads, self.head_dim) 243 | v = self.v_proj(x).reshape(-1, self.num_heads, self.head_dim) 244 | dk = self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim) 245 | dv = self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim) 246 | 247 | vec1, vec2, vec3 = torch.split( 248 | self.vec_proj(vec), self.hidden_channels, dim=-1 249 | ) 250 | vec_dot = (vec1 * vec2).sum(dim=1) 251 | 252 | # propagate_type: (q: Tensor, k: Tensor, v: Tensor, dk: Tensor, dv: Tensor, vec: Tensor, r_ij: Tensor, d_ij: Tensor) 253 | x, vec_out = self.propagate( 254 | edge_index, 255 | q=q, 256 | k=k, 257 | v=v, 258 | dk=dk, 259 | dv=dv, 260 | vec=vec, 261 | r_ij=r_ij, 262 | d_ij=d_ij, 263 | size=None, 264 | ) 265 | 266 | # edge_updater_type: (vec: Tensor, d_ij: Tensor, f_ij: Tensor) 267 | df_ij = self.edge_updater( 268 | edge_index, vec=vec, d_ij=d_ij, f_ij=f_ij 269 | ) 270 | 271 | o1, o2, o3 = torch.split(self.o_proj(x), self.hidden_channels, dim=1) 272 | dx = vec_dot * o2 + o3 273 | dvec = vec3 * o1.unsqueeze(1) + vec_out 274 | return dx, dvec, df_ij 275 | 276 | def message(self, q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij): 277 | attn = (q_i * k_j * dk).sum(dim=-1) 278 | attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1) 279 | 280 | v_j = v_j * dv 281 | v_j = (v_j * attn.unsqueeze(2)).view(-1, self.hidden_channels) 282 | 283 | s1, s2 = torch.split( 284 | self.act(self.s_proj(v_j)), self.hidden_channels, dim=1 285 | ) 286 | vec_j = vec_j * s1.unsqueeze(1) + s2.unsqueeze(1) * d_ij.unsqueeze(2) 287 | 288 | return v_j, vec_j 289 | 290 | def edge_update(self, vec_i, vec_j, d_ij, f_ij): 291 | w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij) 292 | w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij) 293 | w_dot = (w1 * w2).sum(dim=1) 294 | df_ij = self.act(self.f_proj(f_ij)) * w_dot 295 | return df_ij 296 | 297 | def aggregate( 298 | self, 299 | features: Tuple[torch.Tensor, torch.Tensor], 300 | index: torch.Tensor, 301 | ptr: Optional[torch.Tensor], 302 | dim_size: Optional[int], 303 | ) -> Tuple[torch.Tensor, torch.Tensor]: 304 | x, vec = features 305 | x = scatter(x, index, dim=self.node_dim, dim_size=dim_size) 306 | vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size) 307 | return x, vec 308 | 309 | def update( 310 | self, inputs: Tuple[torch.Tensor, torch.Tensor] 311 | ) -> Tuple[torch.Tensor, torch.Tensor]: 312 | return inputs 313 | 
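A minimal usage sketch for the block above (an editorial illustration, not part of the repository). It assumes torch plus the torch_geometric/torch_scatter stack imported at the top of this file, with a torch_geometric version that still provides MessagePassing.jittable(), and that src/ is on sys.path; the toy water geometry is arbitrary.

import torch
from ViSNet.model.visnet_block import ViSNetBlock

block = ViSNetBlock(num_layers=2, hidden_channels=64, num_rbf=16)
data = {
    "z": torch.tensor([8, 1, 1]),             # water: O, H, H
    "pos": torch.tensor([[0.00, 0.00, 0.00],
                         [0.96, 0.00, 0.00],
                         [-0.24, 0.93, 0.00]]),
    "batch": torch.zeros(3, dtype=torch.long),
}
x, vec = block(data)
# x: (3, 64) per-atom scalar features; vec: (3, 8, 64), since (lmax+1)**2 - 1 = 8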
-------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | 5 | import torch.multiprocessing as mp 6 | 7 | from AIMD import arguments 8 | from AIMD.preprocess import Preprocess 9 | from AIMD.protein import Protein 10 | from AIMD.simulator import SolventSimulator, NoSolventSimulator 11 | from Calculators.calculator import patch_check_state 12 | from utils.pdb import fix_atomic_numbers, read_protein 13 | from utils.system import redir_output 14 | 15 | if __name__ == "__main__": 16 | mp.set_sharing_strategy("file_system") 17 | mp.set_start_method("spawn") 18 | 19 | args = arguments.init() 20 | if args.verbose >= 2: 21 | logging.basicConfig(level=logging.DEBUG) 22 | elif args.verbose >= 1: 23 | logging.basicConfig(level=logging.INFO) 24 | else: 25 | logging.basicConfig(level=logging.ERROR) 26 | 27 | logfile = os.path.join(args.log_dir, f"main-{time.strftime('%Y%m%d-%H%M%S')}.log") 28 | redir_output(logfile) 29 | 30 | preeq = Preprocess( 31 | prot_path=args.prot_file, 32 | utils_dir=args.utils_dir, 33 | command_save_path=os.path.join(args.log_dir, "PreprocessBackup"), 34 | preprocess_method=args.preprocess_method, 35 | log_dir=args.log_dir, 36 | temp_k=args.temp_k, 37 | ) 38 | preeq_pdb, preeq_nowat_pdb = preeq.run_preprocess() 39 | 40 | os.chdir(args.base_dir) 41 | 42 | patch_check_state() 43 | 44 | prot = Protein(read_protein(preeq_pdb), pdb4params=preeq_nowat_pdb) 45 | fix_atomic_numbers(preeq_pdb, prot) 46 | 47 | simulator_class = SolventSimulator if args.solvent else NoSolventSimulator 48 | simulator = simulator_class( 49 | prot=prot, 50 | log_path=args.log_dir, 51 | preeq_steps=args.preeq_steps, 52 | temp_k=args.temp_k, 53 | utils_dir=args.utils_dir, 54 | pdb_file=preeq_pdb, 55 | nowat_pdb_file=preeq_nowat_pdb, 56 | mmcalc_type=args.mm_method, 57 | preprocess_method=args.preprocess_method, 58 | dev_strategy=args.device_strategy, 59 | ) 60 | 61 | simulator.set_calculator( 62 | ckpt_path=args.ckpt_path, 63 | ckpt_type=args.ckpt_type, 64 | nbcalc_type=args.fragment_longrange_calc, 65 | ) 66 | 67 | simulator.simulate( 68 | prot_name=args.prot_name, 69 | simulation_steps=args.sim_steps, 70 | time_step=args.timestep, 71 | record_per_steps=args.record_per_steps, 72 | hydrogen_constraints=args.constraints, 73 | seed=args.seed, 74 | restart=args.restart, 75 | build_frames=args.build_frames, 76 | ) 77 | -------------------------------------------------------------------------------- /src/utils/pdb.py: -------------------------------------------------------------------------------- 1 | import math 2 | import re 3 | 4 | import numpy as np 5 | from ase import Atoms 6 | from ase.data import chemical_symbols 7 | from ase.io import read 8 | 9 | 10 | def read_protein(fpath: str) -> Atoms: 11 | r""" 12 | Convert .pdb file to ase Atoms object. 13 | """ 14 | assert fpath.endswith(".pdb"), "Error: The file format is not PDB!" 15 | atoms = read(fpath) 16 | if len(atoms) == 0: 17 | raise ValueError("Error: The PDB file is empty!") 18 | 19 | return atoms 20 | 21 | 22 | def fix_atomic_numbers(fpath: str, atoms: Atoms): 23 | r""" 24 | Fix atomic numbers. ase.io.read interprets atom names incorrectly when the 25 | atoms are denoted as part of a protein residue. 
26 | """ 27 | symbols = [] 28 | with open(fpath, 'r') as f: 29 | for l in f.readlines(): 30 | if l.startswith('ATOM') or l.startswith('HETATOM'): 31 | symbol = l[12:14].strip() 32 | if symbol.startswith('H'): 33 | symbol = 'H' 34 | symbols.append(symbol) 35 | 36 | _numbers = atoms.get_atomic_numbers() 37 | _numbers[:len(symbols)] = [chemical_symbols.index(sym) for sym in symbols] 38 | 39 | atoms.set_atomic_numbers(_numbers) 40 | 41 | 42 | def reorder_atoms(fpath: str): 43 | r""" 44 | Reorder atoms in .pdb output from tinker. 45 | """ 46 | assert fpath.endswith(".pdb"), "Error: The file format is not PDB!" 47 | with open(fpath, 'r') as f: 48 | lines = f.readlines() 49 | 50 | output = [] 51 | sidechain = [] 52 | 53 | res_id = None 54 | res_count = 0 55 | h_found = False 56 | 57 | for l in lines: 58 | cols = l.split() 59 | 60 | if len(cols) < 8 or cols[0] != 'ATOM': 61 | output.append(l) 62 | continue 63 | 64 | if cols[4] != res_id: 65 | res_count = 0 66 | 67 | # check if sidechain atoms should be written 68 | if cols[2] == 'H' or cols[2] == 'HA': 69 | res_count = 0 70 | h_found = True 71 | elif h_found is True: 72 | # write sidechain atoms right after H/HA 73 | output.extend(sidechain) 74 | sidechain = [] 75 | h_found = False 76 | 77 | # enumerate N/CA/C/O atoms 78 | if res_count == 0 and cols[2] == 'N': 79 | res_count += 1 80 | elif res_count == 1 and cols[2] == 'CA': 81 | res_count += 1 82 | elif res_count == 2 and cols[2] == 'C': 83 | res_count += 1 84 | elif res_count == 3 and cols[2] == 'O': 85 | res_count += 1 86 | # save rows between N/CA/C/O and H/HA 87 | elif res_count == 4: 88 | sidechain.append(l) 89 | res_id = cols[4] 90 | continue 91 | 92 | # save line to output 93 | output.append(l) 94 | 95 | # update current residue id 96 | res_id = cols[4] 97 | 98 | with open(fpath, 'w') as f: 99 | for l in output: 100 | f.write(l) 101 | 102 | 103 | def standardise_pdb(fpath: str): 104 | r""" 105 | Check and rewrite residue numbers in .pdb output from tinker. Wraps residue 106 | numbers > 9999 to 0 so that ase can process the .pdb file correctly. 107 | """ 108 | assert fpath.endswith(".pdb"), "Error: The file format is not PDB!" 
109 |     with open(fpath, 'r') as f:
110 |         for line in f:
111 |             if not line.startswith('ATOM'):
112 |                 continue
113 | 
114 |             try:
115 |                 # ase.io.proteindatabank: extract the residue number
116 |                 res_idx = int(line[22:26].split()[0])
117 |             except IndexError:
118 |                 break
119 |             else:
120 |                 return  # the first ATOM record parses cleanly; assume numbering is valid
121 | 
122 |     output = []
123 |     with open(fpath, 'r') as f:
124 |         for line in f:
125 |             if line.startswith('ATOM') or line.startswith('HETATM'):
126 |                 # wrap the residue number on 10000
127 |                 res_idx = int(line[6:].split()[3])
128 |                 end = re.search(r'\b({})\b'.format(res_idx), line[22:]).end() + 22
129 |                 output.append(line[:22] + f"{res_idx % 10000: >4}" + line[end:])
130 |             else:
131 |                 output.append(line)
132 | 
133 |     with open(fpath, 'w') as f:
134 |         for l in output:
135 |             f.write(l)
136 | 
137 | 
138 | # Translate pdb coordinates so that the center of mass sits at the origin
139 | def translate_coord_pdb(inpfile: str, outfile: str):
140 |     atomic_masses = {
141 |         'H': 1.008,
142 |         'C': 12.011,
143 |         'N': 14.007,
144 |         'O': 15.999,
145 |         'F': 18.998,
146 |         'P': 30.974,
147 |         'S': 32.06,
148 |         'NA': 22.990,
149 |         'CL': 35.453,
150 |     }
151 |     atoms = []
152 |     masses = []
153 |     with open(inpfile, 'r') as file:
154 |         for line in file:
155 |             if line.startswith("ATOM") or line.startswith("HETATM"):
156 |                 x = float(line[30:38].strip())
157 |                 y = float(line[38:46].strip())
158 |                 z = float(line[46:54].strip())
159 |                 atom_type = line[76:78].strip()
160 |                 mass = atomic_masses.get(atom_type, 0)
161 |                 atoms.append([x, y, z])
162 |                 masses.append(mass)
163 |     atoms = np.array(atoms)
164 |     masses = np.array(masses)
165 | 
166 |     # get mass center
167 |     total_mass = np.sum(masses)
168 |     mass_center = np.sum(atoms * masses[:, np.newaxis], axis=0) / total_mass
169 | 
170 |     # translate atoms
171 |     atoms -= mass_center
172 | 
173 |     # write pdb
174 |     with open(inpfile, 'r') as file:
175 |         original_lines = file.readlines()
176 | 
177 |     # get pbc box lengths: twice the maximum extent plus 20 of padding
178 |     pbc_x = math.ceil(np.max(np.abs(atoms[:, 0])) * 2 + 20)
179 |     pbc_y = math.ceil(np.max(np.abs(atoms[:, 1])) * 2 + 20)
180 |     pbc_z = math.ceil(np.max(np.abs(atoms[:, 2])) * 2 + 20)
181 | 
182 |     with open(outfile, 'w') as file:
183 |         atom_index = 0
184 |         # file.write(f"HEADER    {pbc_x:.3f} {pbc_y:.3f} {pbc_z:.3f}\n")
185 |         for line in original_lines[1:]:
186 |             if line.startswith("ATOM") or line.startswith("HETATM"):
187 |                 x, y, z = atoms[atom_index]
188 |                 file.write(f"{line[:30]}{x:8.3f}{y:8.3f}{z:8.3f}{line[54:]}")
189 |                 atom_index += 1
190 |             else:
191 |                 file.write(line)
192 | 
193 |     return pbc_x, pbc_y, pbc_z
194 | 
195 | 
196 | def reorder_coord_amber2tinker(fpath: str):
197 |     r"""
198 |     Reorder atoms in a .pdb file from Amber atom ordering to Tinker atom ordering.
199 |     """
200 |     assert fpath.endswith(".pdb"), "Error: The file format is not PDB!"
201 | 202 | reorder_dict = { 203 | 'ACE': [1, 4, 5, 0, 2, 3], 204 | 'ALA': [0, 2, 8, 9, 1, 3, 4, 5, 6, 7], 205 | 'ARG': [0, 2, 22, 23, 1, 3, 4, 7, 10, 13, 15, 16, 19, 5, 6, 8, 9, 11, 12, 14, 17, 18, 20, 21], 206 | 'ASN': [0, 2, 12, 13, 1, 3, 4, 7, 8, 9, 5, 6, 10, 11], 207 | 'ASP': [0, 2, 10, 11, 1, 3, 4, 7, 8, 9, 5, 6], 208 | 'CYS': [0, 2, 9, 10, 1, 3, 4, 7, 5, 6, 8], 209 | 'CYX': [0, 2, 8, 9, 1, 3, 4, 7, 5, 6], 210 | 'GLN': [0, 2, 15, 16, 1, 3, 4, 7, 10, 11, 12, 5, 6, 8, 9, 13, 14], 211 | 'GLU': [0, 2, 13, 14, 1, 3, 4, 7, 10, 11, 12, 5, 6, 8, 9], 212 | 'GLY': [0, 2, 5, 6, 1, 3, 4], 213 | 'HIE': [0, 2, 15, 16, 1, 3, 4, 7, 8, 13, 9, 11, 5, 6, 14, 10, 12], 214 | 'ILE': [0, 2, 17, 18, 1, 3, 4, 10, 6, 13, 5, 11, 12, 7, 8, 9, 14, 15, 16], 215 | 'LEU': [0, 2, 17, 18, 1, 3, 4, 7, 9, 13, 5, 6, 8, 10, 11, 12, 14, 15, 16], 216 | 'LYS': [0, 2, 20, 21, 1, 3, 4, 7, 10, 13, 16, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 19], 217 | 'MET': [0, 2, 15, 16, 1, 3, 4, 7, 10, 11, 5, 6, 8, 9, 12, 13, 14], 218 | 'PHE': [0, 2, 18, 19, 1, 3, 4, 7, 8, 16, 10, 14, 12, 5, 6, 9, 17, 11, 15, 13], 219 | 'PRO': [0, 10, 12, 13, 11, 7, 4, 1, 8, 9, 5, 6, 2, 3], 220 | 'SER': [0, 2, 9, 10, 1, 3, 4, 7, 5, 6, 8], 221 | 'THR': [0, 2, 12, 13, 1, 3, 4, 10, 6, 5, 11, 7, 8, 9], 222 | 'TRP': [0, 2, 22, 23, 1, 3, 4, 7, 8, 21, 10, 12, 19, 13, 17, 15, 5, 6, 9, 11, 20, 14, 18, 16], 223 | 'TYR': [0, 2, 19, 20, 1, 3, 4, 7, 8, 17, 10, 15, 12, 13, 5, 6, 9, 18, 11, 16, 14], 224 | 'VAL': [0, 2, 14, 15, 1, 3, 4, 6, 10, 5, 7, 8, 9, 11, 12, 13], 225 | 'NME': [0, 2, 1, 3, 4, 5], 226 | } 227 | 228 | output = [] 229 | 230 | amino_acids = [] 231 | residue_names = [''] 232 | atom_start = False 233 | 234 | with open(fpath, 'r') as f: 235 | lines = f.readlines() 236 | 237 | res_idx = None 238 | atoms = [] 239 | 240 | for l in lines: 241 | cols = l.split() 242 | 243 | if not atom_start and not l.startswith('ATOM') and not l.startswith('HETATM'): 244 | output.append(l) 245 | continue 246 | 247 | atom_start = True 248 | 249 | if atom_start and not l.startswith('ATOM') and not l.startswith('HETATM'): 250 | continue 251 | 252 | if cols[4] != res_idx: 253 | res_idx = cols[4] 254 | 255 | residue_names.append(cols[3]) 256 | amino_acids.append(atoms) 257 | atoms = [] 258 | 259 | atoms.append(l) 260 | 261 | amino_acids.append(atoms) 262 | 263 | for residue, atoms in zip(residue_names[1:], amino_acids[1:]): 264 | if residue not in reorder_dict: 265 | output.extend(atoms) 266 | else: 267 | reordered_atoms = [atoms[i] for i in reorder_dict[residue]] 268 | output.extend(reordered_atoms) 269 | 270 | with open(fpath, 'w') as f: 271 | for l in output: 272 | f.write(l) 273 | -------------------------------------------------------------------------------- /src/utils/seq_dict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/AI2BMD/734b4de28340e4f5c0cfaa715fd92e66fa50e593/src/utils/seq_dict.pkl -------------------------------------------------------------------------------- /src/utils/signals.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import re 4 | import signal 5 | import sys 6 | import threading 7 | import traceback 8 | import types 9 | # from concurrent.futures import ThreadPoolExecutor 10 | from io import StringIO 11 | # from multiprocessing import Lock, Process 12 | from os import getpid 13 | # from time import sleep 14 | from typing import Callable, Optional 15 | 16 | import psutil 17 | 18 | # LOCK = Lock() 19 | 20 | 
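# Usage sketch (editorial illustration, not part of the original module):
# call register_print_stack_on_sigusr2(pass_through=True) once in the main
# process and with pass_through=False in each worker, then run
#     kill -USR2 <pid>
# from a shell to dump every thread's stack to stacktraces-<pid>.log when
# diagnosing a hang.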
21 | def _signal_tree(pid: int, sig: signal.Signals) -> None:
22 |     parent = psutil.Process(pid)
23 |     children = parent.children(recursive=True)
24 |     print("==== Found {} child processes for PID {}: {} ====".format(len(children), pid, children))
25 |     for p in children:
26 |         try:
27 |             p.send_signal(sig)
28 |         except psutil.NoSuchProcess:
29 |             pass
30 | 
31 | 
32 | class TeeWriter:
33 |     def __init__(self, filename: str) -> None:
34 |         self._filename = filename
35 |         self._stream = sys.stdout
36 | 
37 |     def __enter__(self) -> "TeeWriter":
38 |         self._file = open(self._filename, "a")
39 |         return self
40 | 
41 |     def write(self, data: str) -> None:
42 |         self._file.write(data)
43 |         self._stream.write(data)
44 | 
45 |     def flush(self) -> None:
46 |         self._file.flush()
47 |         self._stream.flush()
48 | 
49 |     def __exit__(self, *args, **kwargs):  # type: ignore
50 |         self._file.close()
51 | 
52 | 
53 | def _create_handler(
54 |     pass_through: bool, match: Optional[str] = None
55 | ) -> Callable[[int, Optional[types.FrameType]], None]:
56 |     """
57 |     If pass_through is True, forward the signal to all child processes as well.
58 |     """
59 | 
60 |     filename = f"stacktraces-{getpid()}.log"
61 | 
62 |     if "AMLT_OUTPUT_DIR" in os.environ:
63 |         filename = os.path.join(os.environ["AMLT_OUTPUT_DIR"], filename)
64 | 
65 |     def handler(signal_number: int, frame: Optional[types.FrameType]) -> None:
66 |         with TeeWriter(filename) as file:
67 |             print(
68 |                 f"---- Process {getpid()} received signal {signal_number}: {datetime.datetime.now().isoformat()} ----",
69 |                 file=file,
70 |             )
71 |             threads = list(threading.enumerate())
72 |             tracebacks = {}
73 |             for th in threads:
74 |                 buf = StringIO()
75 |                 assert th.ident
76 |                 traceback.print_stack(sys._current_frames()[th.ident], file=buf)
77 |                 tracebacks[str(th)] = buf.getvalue()
78 |             print(f"------ Found {len(tracebacks)} threads in process {getpid()}.", file=file)
79 |             for info, tb in tracebacks.items():
80 |                 if match and not re.search(match, tb):
81 |                     print(f"Skipping {info} because it does not match {match}", file=file)
82 |                     continue
83 |                 print("---------", info, file=file)
84 |                 print(tb, file=file)
85 | 
86 |             if pass_through:
87 |                 print("Signaling children of process", getpid(), file=file)
88 |                 _signal_tree(getpid(), signal.SIGUSR2)
89 | 
90 |     return handler
91 | 
92 | 
93 | def register_print_stack_on_sigusr2(pass_through: bool = True, match: Optional[str] = None) -> None:
94 |     """
95 |     To investigate hanging processes, call this function once in your main process
96 |     with pass_through=True, and in every subprocess with pass_through=False.
97 | 
98 |     If `match` is specified, only print stack traces matching this pattern.
99 |     """
100 |     for sig in [signal.SIGUSR2]:
101 |         # forward `match` so the filter documented above actually applies
102 |         signal.signal(sig, _create_handler(pass_through, match))
103 | 
-------------------------------------------------------------------------------- /src/utils/system.py: --------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import sys
4 | 
5 | 
6 | _tee = None
7 | 
8 | def redir_output(logfile: str):
9 |     global _tee
10 |     if _tee:
11 |         raise Exception("Standard output/error already redirected")
12 |     _tee = subprocess.Popen(["tee", logfile], stdin=subprocess.PIPE)
13 |     # Point our stdout/stderr (and those of any child processes we spawn)
14 |     # at tee's stdin, so output is both written to the logfile and echoed
15 |     os.dup2(_tee.stdin.fileno(), sys.stdout.fileno())
16 |     os.dup2(_tee.stdin.fileno(), sys.stderr.fileno())
17 | 
18 | 
19 | def which_python():
20 |     """returns the full path to the python interpreter,
21 |     e.g.
/opt/conda/bin/python""" 22 | which = subprocess.Popen(["which", "python"], stdout=subprocess.PIPE) 23 | if which.wait(): 24 | raise RuntimeError("which") 25 | return which.stdout.read().decode() 26 | 27 | 28 | def get_physical_core_count() -> int: 29 | """Obtain the number of physical cores on the local machine. 30 | Logical processors (i.e. hyper-threading cores) will only be counted as ONE processor. 31 | For example, for a machine with 2 hyperthreads per CPU, and `nproc` reports 32, 32 | this routine will return 16. 33 | 34 | The return value is suitable for starting numerical-intensive MPI tasks, to avoid 35 | over-subscribing the float processing units (FPUs), which are usually shared among 36 | the logical cores. 37 | """ 38 | 39 | lscpu = subprocess.Popen("lscpu", stdout=subprocess.PIPE) 40 | if lscpu.wait(): 41 | raise RuntimeError("lscpu") 42 | output = lscpu.stdout.read().decode().splitlines() 43 | n_cores_per_socket = next(x for x in output if "Core(s) per socket:" in x).split(':')[1].strip() 44 | n_sockets = next(x for x in output if "Socket(s):" in x).split(':')[1].strip() 45 | return int(n_cores_per_socket) * int(n_sockets) 46 | -------------------------------------------------------------------------------- /src/utils/traj2dcd.py: -------------------------------------------------------------------------------- 1 | import MDAnalysis as mda 2 | from ase.io import read 3 | from MDAnalysis.coordinates.DCD import DCDWriter 4 | import os 5 | import argparse 6 | 7 | def traj2dcd(trajpath,output_name): 8 | if not os.path.exists(trajpath): 9 | return 10 | atoms = read(trajpath,index='::1') 11 | u = mda.Universe(pdbtop, atoms[0][:prot_atom_num].get_positions()) 12 | with DCDWriter(output_name, n_atoms=u.atoms.n_atoms) as W: 13 | for idx,frame in enumerate(atoms): 14 | if idx % stride == 0: 15 | u.atoms.positions = frame[:prot_atom_num].get_positions() 16 | W.write(u) 17 | 18 | parser = argparse.ArgumentParser(description="Traj to DCD") 19 | parser.add_argument('--input', type=str) 20 | parser.add_argument('--output', type=str) 21 | parser.add_argument('--pdb', type=str) 22 | parser.add_argument('--num-atoms',type=int) 23 | parser.add_argument('--stride', type=int) 24 | 25 | args = parser.parse_args() 26 | trajpath = args.input 27 | output_name = args.output 28 | pdbtop = args.pdb 29 | prot_atom_num = args.num_atoms 30 | stride = args.stride 31 | 32 | traj2dcd(trajpath,output_name) 33 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from collections import deque 4 | from concurrent.futures import ThreadPoolExecutor 5 | from functools import partial, wraps 6 | from typing import Any, Callable, List 7 | 8 | import numpy as np 9 | import torch 10 | from ase import Atoms 11 | from ase.io.trajectory import TrajectoryWriter 12 | from ase.md.md import MolecularDynamics 13 | 14 | from AIMD import arguments 15 | 16 | 17 | def record_time(func): 18 | def wrapper(*args, **kwargs): 19 | start = time.time() 20 | result = func(*args, **kwargs) 21 | end = time.time() 22 | print(f"{func.__name__} takes {end - start} seconds") 23 | return result 24 | 25 | return wrapper 26 | 27 | 28 | class RNGPool: 29 | def __init__(self, seed, shape, count): 30 | self.rng = np.random.default_rng(seed) 31 | 32 | self.pool = deque() 33 | self.shape = shape 34 | self.count = count 35 | 36 | self.fill() 37 | 38 | def fill(self): 39 | while 
len(self.pool) < self.count:
40 |             self.pool.append(self.rng.standard_normal(self.shape))
41 | 
42 |     def drain(self):
43 |         return self.pool.popleft()
44 | 
45 |     def standard_normal(self, size):
46 |         if size == self.shape and len(self.pool):
47 |             return self.drain()
48 |         else:
49 |             return self.rng.standard_normal(size)
50 | 
51 | 
52 | class SkipCheckState:
53 |     """Temporarily disables atoms.check_state so that the calculator skips comparing the Atoms against its cached state."""
54 |     def __init__(self, atoms):
55 |         self.skip_check_state = getattr(atoms, 'skip_check_state', False)
56 |         self.atoms = atoms
57 |         atoms.skip_check_state = True
58 | 
59 |     def __enter__(self, *_):
60 |         pass
61 | 
62 |     def __exit__(self, *_):
63 |         self.atoms.skip_check_state = self.skip_check_state
64 | 
65 | 
66 | _workqueue_instance = None
67 | 
68 | class WorkQueue:
69 |     def __init__(self):
70 |         global _workqueue_instance
71 |         self.work: deque[Callable] = deque()
72 |         if _workqueue_instance is not None:
73 |             raise RuntimeError("There should be only one WorkQueue instance.")
74 |         _workqueue_instance = self
75 | 
76 |     def __bool__(self):
77 |         return True  # an instance is always truthy, even when the queue is empty
78 | 
79 |     def __len__(self):
80 |         return len(self.work)
81 | 
82 |     def submit(self, action):
83 |         return self.work.append(action)
84 | 
85 |     def drain(self):
86 |         while len(self.work):
87 |             self.work.popleft()()
88 | 
89 |     @classmethod
90 |     def finalise(cls):
91 |         if _workqueue_instance:
92 |             _workqueue_instance.drain()
93 | 
94 | 
95 | def delay_work(f):
96 |     @wraps(f)
97 |     def wrapper(*args, **kwargs):
98 |         if _workqueue_instance:
99 |             _workqueue_instance.submit(partial(f, *args, **kwargs))
100 |         else:
101 |             # root calc isn't workqueue-aware;
102 |             # do not delay the work because nobody will drain it
103 |             f(*args, **kwargs)
104 | 
105 |     return wrapper
106 | 
107 | 
108 | class TemperatureRunawayError(RuntimeError):
109 |     def __init__(self, temp_k: int, *args: object):
110 |         self.temp = temp_k
111 |         super().__init__(*args)
112 | 
113 | 
114 | class MDObserver:
115 |     """
116 |     An observer class, offering functions that can be attached to an ASE
117 |     MolecularDynamics object and notified at every step.
118 | """ 119 | 120 | def __init__(self, a: Atoms, q: Atoms, md: MolecularDynamics, traj: TrajectoryWriter, rng: RNGPool, step_offset: int, temp_k: int): 121 | self.a = a 122 | self.q = q 123 | self.md = md 124 | self.traj = traj 125 | self.rng = rng 126 | self.step_offset = step_offset 127 | self.copy = None 128 | self.temp_k = temp_k 129 | 130 | self.atoms = a if arguments.get().write_solvent else q 131 | 132 | def get_md_step(self): 133 | return self.step_offset + self.md.nsteps 134 | 135 | def save_traj_copy(self): 136 | self.copy = self.atoms.copy() 137 | 138 | @delay_work 139 | def write_traj(self): 140 | with SkipCheckState(self.copy): 141 | self.traj.write(self.copy) 142 | 143 | def printenergy(self): 144 | """ 145 | Function to print the potential, kinetic and total energy 146 | """ 147 | # per atom need / len(a) 148 | 149 | with SkipCheckState(self.a): 150 | cur_time = time.perf_counter() 151 | epot = self.a.get_potential_energy().item() 152 | ekin = self.a.get_kinetic_energy().item() 153 | temperature = self.a.get_temperature() 154 | if temperature > 1.5 * self.temp_k: 155 | raise TemperatureRunawayError(temperature, "temperature runaway") 156 | print(f"Step {self.get_md_step():d}: " 157 | f"Epot = {epot:.3f}eV " 158 | f"Ekin = {ekin:.3f}eV " 159 | f"Etot = {epot+ekin:.3f}eV") 160 | 161 | @delay_work 162 | def fill_rng_pool(self): 163 | """ 164 | Function to fill the RNG pool 165 | """ 166 | self.rng.fill() 167 | 168 | 169 | class PDBAnalyzer: 170 | def __init__(self, filename): 171 | self.filename = filename 172 | self.covalent_radii = { 173 | "H": 0.31, 174 | "C": 0.76, 175 | "N": 0.71, 176 | "O": 0.66, 177 | "P": 1.07, 178 | "S": 1.05, 179 | } 180 | self.atoms = self.parse_pdb() 181 | 182 | def parse_pdb(self): 183 | """Parse a PDB file and return a list of atoms and their coordinates.""" 184 | atoms = [] 185 | with open(self.filename, "r") as f: 186 | for line in f: 187 | if line.startswith("ATOM"): 188 | atom_name = line[12:16].strip() 189 | x = float(line[30:38]) 190 | y = float(line[38:46]) 191 | z = float(line[46:54]) 192 | atoms.append((atom_name, np.array([x, y, z]))) 193 | return atoms 194 | 195 | def compute_distance(self, atom1, atom2): 196 | """Compute the Euclidean distance between two atoms.""" 197 | _, (_, pos1) = atom1 198 | _, (_, pos2) = atom2 199 | return np.linalg.norm(pos1 - pos2) 200 | 201 | def find_bonded_atoms(self, target_atom_name): 202 | """Find atoms that are bonded to a target atom type based on distance.""" 203 | covalent_radius = self.covalent_radii.get(target_atom_name, 0) 204 | indexed_atoms = [ 205 | (i, atom) 206 | for i, atom in enumerate(self.atoms) 207 | if atom[0].startswith(target_atom_name) 208 | ] 209 | bonded_atoms = [] 210 | for i1, atom1 in indexed_atoms: 211 | for i2, atom2 in enumerate(self.atoms): 212 | if i1 != i2: 213 | distance = self.compute_distance((i1, atom1), (i2, atom2)) 214 | atom2_radius = self.covalent_radii.get(atom2[0][0], 0) 215 | idea_length = covalent_radius + atom2_radius 216 | if distance <= idea_length + 0.2: 217 | bonded_atoms.append((i1, i2, idea_length + 0.2, 15)) 218 | assert len(indexed_atoms) == len( 219 | bonded_atoms 220 | ), "Hydrogen constraint: hydrogen covalent bonds != hydrogen num" 221 | return bonded_atoms 222 | 223 | 224 | def src_dir(): 225 | return os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 226 | 227 | 228 | # helpers for common numpy -> torch operations 229 | def numpy_to_torch(x: np.array, device: str): 230 | return torch.from_numpy(x).to(device) 231 | 232 | def 
numpy_list_to_torch(x: list[np.array], device: str): 233 | return torch.from_numpy(np.concatenate(x)).to(device) 234 | 235 | 236 | # wrapper for serial/parallel execution 237 | def execution_wrapper(f_args: list[Any], concurrent: bool): 238 | if concurrent is True: 239 | with ThreadPoolExecutor(len(f_args)) as executor: 240 | futures = [executor.submit(f, *args) for f, *args in f_args] 241 | 242 | return [f.result() for f in futures] 243 | else: 244 | return [f(*args) for f, *args in f_args] 245 | --------------------------------------------------------------------------------
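As a closing illustration, a short sketch of how execution_wrapper from src/utils/utils.py can be driven (the add/square tasks are hypothetical, purely for illustration; assumes src/ is on sys.path). Each entry in f_args packs a callable together with its positional arguments; concurrent=True fans the calls out over a thread pool with one worker per task, while concurrent=False runs them serially in order.

from utils.utils import execution_wrapper

def add(a, b):
    return a + b

def square(x):
    return x * x

tasks = [(add, 1, 2), (square, 5)]                 # [(callable, *args), ...]
print(execution_wrapper(tasks, concurrent=True))   # -> [3, 25]
print(execution_wrapper(tasks, concurrent=False))  # same results, serially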