├── .github └── workflows │ └── lint-and-test.yml ├── .gitignore ├── .gitmodules ├── .pylintrc ├── Dockerfile ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── __main__.py ├── dashboard.py ├── data ├── mixture_model_pfam_100.txt ├── mixture_model_pfam_50.txt ├── mixture_model_pfam_500.txt ├── preprocessed │ └── .gitignore └── raw │ ├── .gitignore │ ├── protein_net_testfile.txt │ └── single_protein.txt ├── demo.mpg ├── examplemodelrun.png ├── experiments ├── __init__.py ├── example │ ├── __init__.py │ └── models.py ├── my_model │ └── __init__.py ├── rrn │ ├── __init__.py │ └── models.py └── tmhmm3 │ ├── __init__.py │ ├── tm_models.py │ └── tm_util.py ├── git-hooks └── pre-commit ├── op_cli.py ├── openprotein.py ├── output ├── .gitignore ├── models │ ├── .gitignore │ └── 2019-01-30_00_38_46-TRAIN-LR0_01-MB1.model └── predictions │ └── .gitignore ├── prediction.py ├── preprocessing.py ├── preprocessing_cli.py ├── tests ├── data │ └── raw │ │ └── testfile.txt ├── onnx_export.py ├── onnx_export_tmhmm3.py ├── output │ └── .gitignore └── test_onnx_export.py ├── training.py └── util.py /.github/workflows/lint-and-test.yml: -------------------------------------------------------------------------------- 1 | name: lint and test 2 | 3 | on: push 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 3.7 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: 3.7 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install pipenv 20 | pipenv install 21 | - name: Lint with pylint 22 | run: | 23 | find . -type f -name "*.py" | xargs pipenv run pylint --rcfile=.pylintrc 24 | - name: Test with pytest 25 | run: | 26 | export PYTHONPATH=. 27 | pipenv run python -m pytest -s 28 | 29 | - name: Build and push docker image to Docker Hub if on master branch 30 | run: >- 31 | if [ ${GITHUB_REF#refs/heads/} = 'master' ]; then 32 | docker build -t biolib/openprotein . && 33 | echo "$DOCKERHUB_PASS" | docker login -u "$DOCKERHUB_USERNAME" --password-stdin && 34 | docker push biolib/openprotein:latest ; 35 | fi 36 | env: 37 | DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} 38 | DOCKERHUB_PASS: ${{ secrets.DOCKERHUB_PASS }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .DS_Store 3 | __pycache__ 4 | .vscode/settings.json 5 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pytorchcrf"] 2 | path = pytorchcrf 3 | url = https://github.com/JeppeHallgren/pytorch-crf.git 4 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # This configuration file provides Pylint instructions on which warning codes to ignore. 2 | # Ideally, we wouldn't ignore any warnings, and over time the codes below should be removed from here. 3 | # However, currently, Pylint has a few known bugs related to Pytorch and Numpy that causes some negligible warnings. 4 | # Besides that, some warnings are ignored at this point because fixing them would require significant code changes. 
5 | 6 | 7 | [MASTER] 8 | 9 | # Missing doc string warnings are disabled so that we can provide these at a later time and pass Pylint now. 10 | # C0115: disable warnings on missing-doc-string for classes. 11 | # C0116: disable warnings on missing-doc-string for methods. 12 | 13 | # W0511: disable warnings on 'TODO' 14 | # W0105: disable warnings on 'String statement has no effect (pointless-string-statement)' 15 | # W0108: disable warnings on 'Lambda may not be necessary (unnecessary-lambda)'. 16 | # W0201: disable warnings on 'defined outside __init__ (attribute-defined-outside-init)'. 17 | # W0212: disable warnings on 'Access to a protected member x of a client class (protected-access)'. Allows access to CRF._compute_log_alpha(). 18 | # W0221: disable warnings on 'Parameters differ from overridden 'forward' method (arguments-differ)'. 19 | # W0231: disable warnings on '__init__ method from base class 'x' is not called (super-init-not-called)'. 20 | # W0603: disable warnings on use of the keyword 'global'. Security concerns related to this are negligible. 21 | 22 | # Some stylistic warnings below are ignored for now, since fixing them requires rewriting some of the logic in OpenProtein. 23 | # R0801: disable warnings on 'Similar lines in 2 files'. 24 | # R0902: disable warnings on 'Too many instance attributes (too-many-instance-attributes)'. 25 | # R0912: disable warnings on 'Too many branches (x/y) (too-many-branches)'. 26 | # R0913: disable warnings on 'Too many arguments (x/y) (too-many-arguments)'. 27 | # R0914: disable warnings on 'Too many local variables (too-many-locals)'. 28 | # R0915: disable warnings on 'Too many statements (x/y) (too-many-statements)'. 29 | # R1702: disable warnings on 'Too many nested blocks (6/5) (too-many-nested-blocks)'. 30 | # R1705: disable warnings on 'Unnecessary "else" after "return" (no-else-return)'. 31 | # R1710: disable warnings on 'Either all return statements in a function should return an expression, or none of them should. (inconsistent-return-statements)'. 32 | # R1721: disable warnings on 'Unnecessary use of a comprehension (unnecessary-comprehension)'. 33 | 34 | # Import errors like the ones below are ignored because they are caused by a Pylint bug and are not directly related to the actual code. 35 | # E0401: disable warnings on 'Unable to import X (import-error)'. 36 | # E0611: disable warnings on 'No name 'x' in module 'y' (no-name-in-module)'. 37 | # E1101: disable warnings on 'Class 'x' has no 'y' member (no-member)'. 38 | # E1102: disable warnings on 'x is not callable (not-callable)'. 39 | # E1121: disable warnings on 'Too many positional arguments for function call (too-many-function-args)'. 40 | 41 | disable=C0115,C0116,W0105,W0108,W0201,W0212,W0221,W0231,W0511,W0603,R0902,R0912,R0913,R0914,R0915,R1702,R1705,E0401,R0801,R1710,R1721,E1101,E1102,E0611,E1121 42 | 43 | [TYPECHECK] 44 | 45 | # List of members which are set dynamically and missed by the pylint inference 46 | # system, and so shouldn't trigger E1101 when accessed. Python regular 47 | # expressions are accepted. 48 | generated-members=numpy.*,torch.* 49 | 50 | [FORMAT] 51 | 52 | # Temporary solution to suppress warnings on single-letter names. Each could be changed to a longer name. 53 | # Names i, k, l are OK in some for loops. Longer names might disrupt readability. 
54 | good-names=dRMSD_list,RMSD_list,ax,TraceS,RMSD,ba,X,Y,R,S,E0,bc,n,m,s,i,k,l,x,y,z,ax,ay,az,bx,by,bz 55 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM phusion/baseimage:0.11 2 | WORKDIR /openprotein 3 | 4 | # install dependencies 5 | RUN apt-get update -y && apt-get install -y --no-install-recommends \ 6 | ca-certificates \ 7 | clang \ 8 | cmake \ 9 | curl \ 10 | git \ 11 | libc6-dev \ 12 | make \ 13 | python3 \ 14 | python3-pip \ 15 | python3-setuptools \ 16 | python3-dev \ 17 | build-essential \ 18 | default-jre \ 19 | autoconf \ 20 | autogen \ 21 | libtool \ 22 | shtool \ 23 | autopoint \ 24 | software-properties-common 25 | 26 | RUN python3 -m pip install wheel 27 | RUN python3 -m pip install pipenv 28 | 29 | COPY . /openprotein 30 | 31 | ENV PIP_NO_CACHE_DIR=1 32 | RUN pipenv install 33 | 34 | RUN echo "pipenv shell" >> /root/.bashrc 35 | 36 | ENTRYPOINT ["/bin/bash"] 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019-2020 BioLib Inc 4 | Copyright (c) 2018 Jeppe Hallgren 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | 8 | [packages] 9 | torch = "==1.5.0" 10 | torchvision = "==0.6.0" 11 | flask = "==1.1.1" 12 | flask-cors = "==3.0.8" 13 | h5py = "==2.10.0" 14 | peptidebuilder = "==1.0.4" 15 | biopython = "==1.68" 16 | requests = "==2.22.0" 17 | pylint = "==2.5.2" 18 | pytest = "==5.4.1" 19 | pytest-ordering = "==0.6" 20 | 21 | [requires] 22 | python_version = "3" 23 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "3c87df3dcb18078c268fa2c76b12bd03fc90e941b8d904a19d875f7d2a66950d" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "astroid": { 20 | "hashes": [ 21 | "sha256:2f4078c2a41bf377eea06d71c9d2ba4eb8f6b1af2135bec27bbbb7d8f12bb703", 22 | "sha256:bc58d83eb610252fd8de6363e39d4f1d0619c894b0ed24603b881c02e64c7386" 23 | ], 24 | "version": "==2.4.2" 25 | }, 26 | "attrs": { 27 | "hashes": [ 28 | "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", 29 | "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72" 30 | ], 31 | "version": "==19.3.0" 32 | }, 33 | "biopython": { 34 | "hashes": [ 35 | "sha256:d1dc09d1ddc8e90833f507cf09f80fa9ee1537d319058d1c44fe9c09be3d0c1f" 36 | ], 37 | "index": "pypi", 38 | "version": "==1.68" 39 | }, 40 | "certifi": { 41 | "hashes": [ 42 | "sha256:5ad7e9a056d25ffa5082862e36f119f7f7cec6457fa07ee2f8c339814b80c9b1", 43 | "sha256:9cd41137dc19af6a5e03b630eefe7d1f458d964d406342dd3edf625839b944cc" 44 | ], 45 | "version": "==2020.4.5.2" 46 | }, 47 | "chardet": { 48 | "hashes": [ 49 | "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae", 50 | "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691" 51 | ], 52 | "version": "==3.0.4" 53 | }, 54 | "click": { 55 | "hashes": [ 56 | "sha256:d2b5255c7c6349bc1bd1e59e08cd12acbbd63ce649f2588755783aa94dfb6b1a", 57 | "sha256:dacca89f4bfadd5de3d7489b7c8a566eee0d3676333fbb50030263894c38c0dc" 58 | ], 59 | "version": "==7.1.2" 60 | }, 61 | "flask": { 62 | "hashes": [ 63 | "sha256:13f9f196f330c7c2c5d7a5cf91af894110ca0215ac051b5844701f2bfd934d52", 64 | "sha256:45eb5a6fd193d6cf7e0cf5d8a5b31f83d5faae0293695626f539a823e93b13f6" 65 | ], 66 | "index": "pypi", 67 | "version": "==1.1.1" 68 | }, 69 | "flask-cors": { 70 | "hashes": [ 71 | "sha256:72170423eb4612f0847318afff8c247b38bd516b7737adfc10d1c2cdbb382d16", 72 | "sha256:f4d97201660e6bbcff2d89d082b5b6d31abee04b1b3003ee073a6fd25ad1d69a" 73 | ], 74 | "index": "pypi", 75 | "version": "==3.0.8" 76 | }, 77 | "future": { 78 | "hashes": [ 79 | "sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d" 80 | ], 81 | "version": "==0.18.2" 82 | }, 83 | "h5py": { 84 | "hashes": [ 85 | "sha256:063947eaed5f271679ed4ffa36bb96f57bc14f44dd4336a827d9a02702e6ce6b", 86 | "sha256:13c87efa24768a5e24e360a40e0bc4c49bcb7ce1bb13a3a7f9902cec302ccd36", 87 | "sha256:16ead3c57141101e3296ebeed79c9c143c32bdd0e82a61a2fc67e8e6d493e9d1", 88 | 
"sha256:3dad1730b6470fad853ef56d755d06bb916ee68a3d8272b3bab0c1ddf83bb99e", 89 | "sha256:51ae56894c6c93159086ffa2c94b5b3388c0400548ab26555c143e7cfa05b8e5", 90 | "sha256:54817b696e87eb9e403e42643305f142cd8b940fe9b3b490bbf98c3b8a894cf4", 91 | "sha256:549ad124df27c056b2e255ea1c44d30fb7a17d17676d03096ad5cd85edb32dc1", 92 | "sha256:64f74da4a1dd0d2042e7d04cf8294e04ddad686f8eba9bb79e517ae582f6668d", 93 | "sha256:6998be619c695910cb0effe5eb15d3a511d3d1a5d217d4bd0bebad1151ec2262", 94 | "sha256:6ef7ab1089e3ef53ca099038f3c0a94d03e3560e6aff0e9d6c64c55fb13fc681", 95 | "sha256:769e141512b54dee14ec76ed354fcacfc7d97fea5a7646b709f7400cf1838630", 96 | "sha256:79b23f47c6524d61f899254f5cd5e486e19868f1823298bc0c29d345c2447172", 97 | "sha256:7be5754a159236e95bd196419485343e2b5875e806fe68919e087b6351f40a70", 98 | "sha256:84412798925dc870ffd7107f045d7659e60f5d46d1c70c700375248bf6bf512d", 99 | "sha256:86868dc07b9cc8cb7627372a2e6636cdc7a53b7e2854ad020c9e9d8a4d3fd0f5", 100 | "sha256:8bb1d2de101f39743f91512a9750fb6c351c032e5cd3204b4487383e34da7f75", 101 | "sha256:a5f82cd4938ff8761d9760af3274acf55afc3c91c649c50ab18fcff5510a14a5", 102 | "sha256:aac4b57097ac29089f179bbc2a6e14102dd210618e94d77ee4831c65f82f17c0", 103 | "sha256:bffbc48331b4a801d2f4b7dac8a72609f0b10e6e516e5c480a3e3241e091c878", 104 | "sha256:c0d4b04bbf96c47b6d360cd06939e72def512b20a18a8547fa4af810258355d5", 105 | "sha256:c54a2c0dd4957776ace7f95879d81582298c5daf89e77fb8bee7378f132951de", 106 | "sha256:cbf28ae4b5af0f05aa6e7551cee304f1d317dbed1eb7ac1d827cee2f1ef97a99", 107 | "sha256:d35f7a3a6cefec82bfdad2785e78359a0e6a5fbb3f605dd5623ce88082ccd681", 108 | "sha256:d3c59549f90a891691991c17f8e58c8544060fdf3ccdea267100fa5f561ff62f", 109 | "sha256:d7ae7a0576b06cb8e8a1c265a8bc4b73d05fdee6429bffc9a26a6eb531e79d72", 110 | "sha256:ecf4d0b56ee394a0984de15bceeb97cbe1fe485f1ac205121293fc44dcf3f31f", 111 | "sha256:f0e25bb91e7a02efccb50aba6591d3fe2c725479e34769802fcdd4076abfa917", 112 | "sha256:f23951a53d18398ef1344c186fb04b26163ca6ce449ebd23404b153fd111ded9", 113 | "sha256:ff7d241f866b718e4584fa95f520cb19405220c501bd3a53ee11871ba5166ea2" 114 | ], 115 | "index": "pypi", 116 | "version": "==2.10.0" 117 | }, 118 | "idna": { 119 | "hashes": [ 120 | "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407", 121 | "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c" 122 | ], 123 | "version": "==2.8" 124 | }, 125 | "importlib-metadata": { 126 | "hashes": [ 127 | "sha256:0505dd08068cfec00f53a74a0ad927676d7757da81b7436a6eefe4c7cf75c545", 128 | "sha256:15ec6c0fd909e893e3a08b3a7c76ecb149122fb14b7efe1199ddd4c7c57ea958" 129 | ], 130 | "markers": "python_version < '3.8'", 131 | "version": "==1.6.1" 132 | }, 133 | "isort": { 134 | "hashes": [ 135 | "sha256:54da7e92468955c4fceacd0c86bd0ec997b0e1ee80d97f67c35a78b719dccab1", 136 | "sha256:6e811fcb295968434526407adb8796944f1988c5b65e8139058f2014cbe100fd" 137 | ], 138 | "version": "==4.3.21" 139 | }, 140 | "itsdangerous": { 141 | "hashes": [ 142 | "sha256:321b033d07f2a4136d3ec762eac9f16a10ccd60f53c0c91af90217ace7ba1f19", 143 | "sha256:b12271b2047cb23eeb98c8b5622e2e5c5e9abd9784a153e9d8ef9cb4dd09d749" 144 | ], 145 | "version": "==1.1.0" 146 | }, 147 | "jinja2": { 148 | "hashes": [ 149 | "sha256:89aab215427ef59c34ad58735269eb58b1a5808103067f7bb9d5836c651b3bb0", 150 | "sha256:f0a4641d3cf955324a89c04f3d94663aa4d638abe8f733ecd3582848e1c37035" 151 | ], 152 | "version": "==2.11.2" 153 | }, 154 | "lazy-object-proxy": { 155 | "hashes": [ 156 | "sha256:0c4b206227a8097f05c4dbdd323c50edf81f15db3b8dc064d08c62d37e1a504d", 
157 | "sha256:194d092e6f246b906e8f70884e620e459fc54db3259e60cf69a4d66c3fda3449", 158 | "sha256:1be7e4c9f96948003609aa6c974ae59830a6baecc5376c25c92d7d697e684c08", 159 | "sha256:4677f594e474c91da97f489fea5b7daa17b5517190899cf213697e48d3902f5a", 160 | "sha256:48dab84ebd4831077b150572aec802f303117c8cc5c871e182447281ebf3ac50", 161 | "sha256:5541cada25cd173702dbd99f8e22434105456314462326f06dba3e180f203dfd", 162 | "sha256:59f79fef100b09564bc2df42ea2d8d21a64fdcda64979c0fa3db7bdaabaf6239", 163 | "sha256:8d859b89baf8ef7f8bc6b00aa20316483d67f0b1cbf422f5b4dc56701c8f2ffb", 164 | "sha256:9254f4358b9b541e3441b007a0ea0764b9d056afdeafc1a5569eee1cc6c1b9ea", 165 | "sha256:9651375199045a358eb6741df3e02a651e0330be090b3bc79f6d0de31a80ec3e", 166 | "sha256:97bb5884f6f1cdce0099f86b907aa41c970c3c672ac8b9c8352789e103cf3156", 167 | "sha256:9b15f3f4c0f35727d3a0fba4b770b3c4ebbb1fa907dbcc046a1d2799f3edd142", 168 | "sha256:a2238e9d1bb71a56cd710611a1614d1194dc10a175c1e08d75e1a7bcc250d442", 169 | "sha256:a6ae12d08c0bf9909ce12385803a543bfe99b95fe01e752536a60af2b7797c62", 170 | "sha256:ca0a928a3ddbc5725be2dd1cf895ec0a254798915fb3a36af0964a0a4149e3db", 171 | "sha256:cb2c7c57005a6804ab66f106ceb8482da55f5314b7fcb06551db1edae4ad1531", 172 | "sha256:d74bb8693bf9cf75ac3b47a54d716bbb1a92648d5f781fc799347cfc95952383", 173 | "sha256:d945239a5639b3ff35b70a88c5f2f491913eb94871780ebfabb2568bd58afc5a", 174 | "sha256:eba7011090323c1dadf18b3b689845fd96a61ba0a1dfbd7f24b921398affc357", 175 | "sha256:efa1909120ce98bbb3777e8b6f92237f5d5c8ea6758efea36a473e1d38f7d3e4", 176 | "sha256:f3900e8a5de27447acbf900b4750b0ddfd7ec1ea7fbaf11dfa911141bc522af0" 177 | ], 178 | "version": "==1.4.3" 179 | }, 180 | "markupsafe": { 181 | "hashes": [ 182 | "sha256:00bc623926325b26bb9605ae9eae8a215691f33cae5df11ca5424f06f2d1f473", 183 | "sha256:09027a7803a62ca78792ad89403b1b7a73a01c8cb65909cd876f7fcebd79b161", 184 | "sha256:09c4b7f37d6c648cb13f9230d847adf22f8171b1ccc4d5682398e77f40309235", 185 | "sha256:1027c282dad077d0bae18be6794e6b6b8c91d58ed8a8d89a89d59693b9131db5", 186 | "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42", 187 | "sha256:24982cc2533820871eba85ba648cd53d8623687ff11cbb805be4ff7b4c971aff", 188 | "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b", 189 | "sha256:43a55c2930bbc139570ac2452adf3d70cdbb3cfe5912c71cdce1c2c6bbd9c5d1", 190 | "sha256:46c99d2de99945ec5cb54f23c8cd5689f6d7177305ebff350a58ce5f8de1669e", 191 | "sha256:500d4957e52ddc3351cabf489e79c91c17f6e0899158447047588650b5e69183", 192 | "sha256:535f6fc4d397c1563d08b88e485c3496cf5784e927af890fb3c3aac7f933ec66", 193 | "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b", 194 | "sha256:62fe6c95e3ec8a7fad637b7f3d372c15ec1caa01ab47926cfdf7a75b40e0eac1", 195 | "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15", 196 | "sha256:6dd73240d2af64df90aa7c4e7481e23825ea70af4b4922f8ede5b9e35f78a3b1", 197 | "sha256:717ba8fe3ae9cc0006d7c451f0bb265ee07739daf76355d06366154ee68d221e", 198 | "sha256:79855e1c5b8da654cf486b830bd42c06e8780cea587384cf6545b7d9ac013a0b", 199 | "sha256:7c1699dfe0cf8ff607dbdcc1e9b9af1755371f92a68f706051cc8c37d447c905", 200 | "sha256:88e5fcfb52ee7b911e8bb6d6aa2fd21fbecc674eadd44118a9cc3863f938e735", 201 | "sha256:8defac2f2ccd6805ebf65f5eeb132adcf2ab57aa11fdf4c0dd5169a004710e7d", 202 | "sha256:98c7086708b163d425c67c7a91bad6e466bb99d797aa64f965e9d25c12111a5e", 203 | "sha256:9add70b36c5666a2ed02b43b335fe19002ee5235efd4b8a89bfcf9005bebac0d", 204 | 
"sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c", 205 | "sha256:ade5e387d2ad0d7ebf59146cc00c8044acbd863725f887353a10df825fc8ae21", 206 | "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2", 207 | "sha256:b1282f8c00509d99fef04d8ba936b156d419be841854fe901d8ae224c59f0be5", 208 | "sha256:b2051432115498d3562c084a49bba65d97cf251f5a331c64a12ee7e04dacc51b", 209 | "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6", 210 | "sha256:c8716a48d94b06bb3b2524c2b77e055fb313aeb4ea620c8dd03a105574ba704f", 211 | "sha256:cd5df75523866410809ca100dc9681e301e3c27567cf498077e8551b6d20e42f", 212 | "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2", 213 | "sha256:e249096428b3ae81b08327a63a485ad0878de3fb939049038579ac0ef61e17e7", 214 | "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be" 215 | ], 216 | "version": "==1.1.1" 217 | }, 218 | "mccabe": { 219 | "hashes": [ 220 | "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", 221 | "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f" 222 | ], 223 | "version": "==0.6.1" 224 | }, 225 | "more-itertools": { 226 | "hashes": [ 227 | "sha256:558bb897a2232f5e4f8e2399089e35aecb746e1f9191b6584a151647e89267be", 228 | "sha256:7818f596b1e87be009031c7653d01acc46ed422e6656b394b0f765ce66ed4982" 229 | ], 230 | "version": "==8.3.0" 231 | }, 232 | "numpy": { 233 | "hashes": [ 234 | "sha256:0172304e7d8d40e9e49553901903dc5f5a49a703363ed756796f5808a06fc233", 235 | "sha256:34e96e9dae65c4839bd80012023aadd6ee2ccb73ce7fdf3074c62f301e63120b", 236 | "sha256:3676abe3d621fc467c4c1469ee11e395c82b2d6b5463a9454e37fe9da07cd0d7", 237 | "sha256:3dd6823d3e04b5f223e3e265b4a1eae15f104f4366edd409e5a5e413a98f911f", 238 | "sha256:4064f53d4cce69e9ac613256dc2162e56f20a4e2d2086b1956dd2fcf77b7fac5", 239 | "sha256:4674f7d27a6c1c52a4d1aa5f0881f1eff840d2206989bae6acb1c7668c02ebfb", 240 | "sha256:7d42ab8cedd175b5ebcb39b5208b25ba104842489ed59fbb29356f671ac93583", 241 | "sha256:965df25449305092b23d5145b9bdaeb0149b6e41a77a7d728b1644b3c99277c1", 242 | "sha256:9c9d6531bc1886454f44aa8f809268bc481295cf9740827254f53c30104f074a", 243 | "sha256:a78e438db8ec26d5d9d0e584b27ef25c7afa5a182d1bf4d05e313d2d6d515271", 244 | "sha256:a7acefddf994af1aeba05bbbafe4ba983a187079f125146dc5859e6d817df824", 245 | "sha256:a87f59508c2b7ceb8631c20630118cc546f1f815e034193dc72390db038a5cb3", 246 | "sha256:ac792b385d81151bae2a5a8adb2b88261ceb4976dbfaaad9ce3a200e036753dc", 247 | "sha256:b03b2c0badeb606d1232e5f78852c102c0a7989d3a534b3129e7856a52f3d161", 248 | "sha256:b39321f1a74d1f9183bf1638a745b4fd6fe80efbb1f6b32b932a588b4bc7695f", 249 | "sha256:cae14a01a159b1ed91a324722d746523ec757357260c6804d11d6147a9e53e3f", 250 | "sha256:cd49930af1d1e49a812d987c2620ee63965b619257bd76eaaa95870ca08837cf", 251 | "sha256:e15b382603c58f24265c9c931c9a45eebf44fe2e6b4eaedbb0d025ab3255228b", 252 | "sha256:e91d31b34fc7c2c8f756b4e902f901f856ae53a93399368d9a0dc7be17ed2ca0", 253 | "sha256:ef627986941b5edd1ed74ba89ca43196ed197f1a206a3f18cc9faf2fb84fd675", 254 | "sha256:f718a7949d1c4f622ff548c572e0c03440b49b9531ff00e4ed5738b459f011e8" 255 | ], 256 | "version": "==1.18.5" 257 | }, 258 | "packaging": { 259 | "hashes": [ 260 | "sha256:4357f74f47b9c12db93624a82154e9b120fa8293699949152b22065d556079f8", 261 | "sha256:998416ba6962ae7fbd6596850b80e17859a5753ba17c32284f67bfff33784181" 262 | ], 263 | "version": "==20.4" 264 | }, 265 | "peptidebuilder": { 266 | "hashes": [ 267 | 
"sha256:0b333ea83e32eaaa2ad99271dc933d681704bb7a92bf3ef73805a173f27051ad", 268 | "sha256:568620881e70b54df5abf5b557db71a63cd7aa499b2efb421b64ec9ea67a68cd" 269 | ], 270 | "index": "pypi", 271 | "version": "==1.0.4" 272 | }, 273 | "pillow": { 274 | "hashes": [ 275 | "sha256:04766c4930c174b46fd72d450674612ab44cca977ebbcc2dde722c6933290107", 276 | "sha256:0e2a3bceb0fd4e0cb17192ae506d5f082b309ffe5fc370a5667959c9b2f85fa3", 277 | "sha256:0f01e63c34f0e1e2580cc0b24e86a5ccbbfa8830909a52ee17624c4193224cd9", 278 | "sha256:12e4bad6bddd8546a2f9771485c7e3d2b546b458ae8ff79621214119ac244523", 279 | "sha256:1f694e28c169655c50bb89a3fa07f3b854d71eb47f50783621de813979ba87f3", 280 | "sha256:3d25dd8d688f7318dca6d8cd4f962a360ee40346c15893ae3b95c061cdbc4079", 281 | "sha256:4b02b9c27fad2054932e89f39703646d0c543f21d3cc5b8e05434215121c28cd", 282 | "sha256:9744350687459234867cbebfe9df8f35ef9e1538f3e729adbd8fde0761adb705", 283 | "sha256:a0b49960110bc6ff5fead46013bcb8825d101026d466f3a4de3476defe0fb0dd", 284 | "sha256:ae2b270f9a0b8822b98655cb3a59cdb1bd54a34807c6c56b76dd2e786c3b7db3", 285 | "sha256:b37bb3bd35edf53125b0ff257822afa6962649995cbdfde2791ddb62b239f891", 286 | "sha256:b532bcc2f008e96fd9241177ec580829dee817b090532f43e54074ecffdcd97f", 287 | "sha256:b67a6c47ed963c709ed24566daa3f95a18f07d3831334da570c71da53d97d088", 288 | "sha256:b943e71c2065ade6fef223358e56c167fc6ce31c50bc7a02dd5c17ee4338e8ac", 289 | "sha256:ccc9ad2460eb5bee5642eaf75a0438d7f8887d484490d5117b98edd7f33118b7", 290 | "sha256:d23e2aa9b969cf9c26edfb4b56307792b8b374202810bd949effd1c6e11ebd6d", 291 | "sha256:eaa83729eab9c60884f362ada982d3a06beaa6cc8b084cf9f76cae7739481dfa", 292 | "sha256:ee94fce8d003ac9fd206496f2707efe9eadcb278d94c271f129ab36aa7181344", 293 | "sha256:f455efb7a98557412dc6f8e463c1faf1f1911ec2432059fa3e582b6000fc90e2", 294 | "sha256:f46e0e024346e1474083c729d50de909974237c72daca05393ee32389dabe457", 295 | "sha256:f54be399340aa602066adb63a86a6a5d4f395adfdd9da2b9a0162ea808c7b276", 296 | "sha256:f784aad988f12c80aacfa5b381ec21fd3f38f851720f652b9f33facc5101cf4d" 297 | ], 298 | "version": "==7.1.2" 299 | }, 300 | "pluggy": { 301 | "hashes": [ 302 | "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0", 303 | "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d" 304 | ], 305 | "version": "==0.13.1" 306 | }, 307 | "py": { 308 | "hashes": [ 309 | "sha256:5e27081401262157467ad6e7f851b7aa402c5852dbcb3dae06768434de5752aa", 310 | "sha256:c20fdd83a5dbc0af9efd622bee9a5564e278f6380fffcacc43ba6f43db2813b0" 311 | ], 312 | "version": "==1.8.1" 313 | }, 314 | "pylint": { 315 | "hashes": [ 316 | "sha256:b95e31850f3af163c2283ed40432f053acbc8fc6eba6a069cb518d9dbf71848c", 317 | "sha256:dd506acce0427e9e08fb87274bcaa953d38b50a58207170dbf5b36cf3e16957b" 318 | ], 319 | "index": "pypi", 320 | "version": "==2.5.2" 321 | }, 322 | "pyparsing": { 323 | "hashes": [ 324 | "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1", 325 | "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b" 326 | ], 327 | "version": "==2.4.7" 328 | }, 329 | "pytest": { 330 | "hashes": [ 331 | "sha256:0e5b30f5cb04e887b91b1ee519fa3d89049595f428c1db76e73bd7f17b09b172", 332 | "sha256:84dde37075b8805f3d1f392cc47e38a0e59518fb46a431cfdaf7cf1ce805f970" 333 | ], 334 | "index": "pypi", 335 | "version": "==5.4.1" 336 | }, 337 | "pytest-ordering": { 338 | "hashes": [ 339 | "sha256:27fba3fc265f5d0f8597e7557885662c1bdc1969497cd58aff6ed21c3b617de2", 340 | "sha256:3f314a178dbeb6777509548727dc69edf22d6d9a2867bf2d310ab85c403380b6", 341 | 
"sha256:561ad653626bb171da78e682f6d39ac33bb13b3e272d406cd555adb6b006bda6" 342 | ], 343 | "index": "pypi", 344 | "version": "==0.6" 345 | }, 346 | "requests": { 347 | "hashes": [ 348 | "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4", 349 | "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31" 350 | ], 351 | "index": "pypi", 352 | "version": "==2.22.0" 353 | }, 354 | "six": { 355 | "hashes": [ 356 | "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259", 357 | "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced" 358 | ], 359 | "version": "==1.15.0" 360 | }, 361 | "toml": { 362 | "hashes": [ 363 | "sha256:926b612be1e5ce0634a2ca03470f95169cf16f939018233a670519cb4ac58b0f", 364 | "sha256:bda89d5935c2eac546d648028b9901107a595863cb36bae0c73ac804a9b4ce88" 365 | ], 366 | "version": "==0.10.1" 367 | }, 368 | "torch": { 369 | "hashes": [ 370 | "sha256:3cc72d36eaeda96488e3a29373f739b887338952417b3e1620871063bf5d14d2", 371 | "sha256:402951484443bb49b5bc2129414ac6c644c07b8378e79922cf3645fd08cbfdc9", 372 | "sha256:6fcfe5deaf0788bbe8639869d3c752ff5fe1bdedce11c7ed2d44379b1fbe6d6c", 373 | "sha256:7f3d6af2d7e2576b9640aa684f0c18a773efffe8b37f9056272287345c1dcba5", 374 | "sha256:865d4bec21542647e0822e8b753e05d67eee874974a3937273f710edd99a7516", 375 | "sha256:931b79aed9aba50bf314214be6efaaf7972ea9539a3d63f82622bc5860a1fd81", 376 | "sha256:cb4412c6b00117ab5e014d07dac45b87f1e918e31fbb849e7e39f1f9140fff59", 377 | "sha256:dfaac4c5d27ac80705956743c34fb1ab5fb37e1646a6c8e45f05f7e739f6ea7c", 378 | "sha256:ecdc2ea4011e3ec04937b6b9e803ab671c3ac04e81b1df20354e01453e508b2f" 379 | ], 380 | "index": "pypi", 381 | "version": "==1.5.0" 382 | }, 383 | "torchvision": { 384 | "hashes": [ 385 | "sha256:0ea04a7e0f64599c158d36da01afd0cb3bc49033d2a145be4eb80c17c4c0482b", 386 | "sha256:0fa9e4a8381e5e04d0da0acd93f1429347053497ec343fe6d625b1b7fb2ce36e", 387 | "sha256:691d68f3726b7392fe37db7184aef8a6b6f7cf6ff38fae769b287b3d6e1eb69a", 388 | "sha256:6eb4e0d7dc61030447b98d412162f222a95d848b3b0e484a81282c057af6dd25", 389 | "sha256:8992f10a7860e0991766a788b546d5f11e3e7465e87a72eb9c78675dd2616400", 390 | "sha256:a9b08435fdadd89520a78f5a54d196c05878d1a15e37f760d43f72f10bae308f", 391 | "sha256:ea39bed9e9497a67c5f66e37d3d5a663a0284868ae8616de81f65c66d9ad802b", 392 | "sha256:f43dae3b348afa5778439913ba1f3f176362ffc9e684ef01dc54dae7cf1b82e4" 393 | ], 394 | "index": "pypi", 395 | "version": "==0.6.0" 396 | }, 397 | "typed-ast": { 398 | "hashes": [ 399 | "sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355", 400 | "sha256:0c2c07682d61a629b68433afb159376e24e5b2fd4641d35424e462169c0a7919", 401 | "sha256:249862707802d40f7f29f6e1aad8d84b5aa9e44552d2cc17384b209f091276aa", 402 | "sha256:24995c843eb0ad11a4527b026b4dde3da70e1f2d8806c99b7b4a7cf491612652", 403 | "sha256:269151951236b0f9a6f04015a9004084a5ab0d5f19b57de779f908621e7d8b75", 404 | "sha256:4083861b0aa07990b619bd7ddc365eb7fa4b817e99cf5f8d9cf21a42780f6e01", 405 | "sha256:498b0f36cc7054c1fead3d7fc59d2150f4d5c6c56ba7fb150c013fbc683a8d2d", 406 | "sha256:4e3e5da80ccbebfff202a67bf900d081906c358ccc3d5e3c8aea42fdfdfd51c1", 407 | "sha256:6daac9731f172c2a22ade6ed0c00197ee7cc1221aa84cfdf9c31defeb059a907", 408 | "sha256:715ff2f2df46121071622063fc7543d9b1fd19ebfc4f5c8895af64a77a8c852c", 409 | "sha256:73d785a950fc82dd2a25897d525d003f6378d1cb23ab305578394694202a58c3", 410 | "sha256:8c8aaad94455178e3187ab22c8b01a3837f8ee50e09cf31f1ba129eb293ec30b", 411 | 
"sha256:8ce678dbaf790dbdb3eba24056d5364fb45944f33553dd5869b7580cdbb83614", 412 | "sha256:aaee9905aee35ba5905cfb3c62f3e83b3bec7b39413f0a7f19be4e547ea01ebb", 413 | "sha256:bcd3b13b56ea479b3650b82cabd6b5343a625b0ced5429e4ccad28a8973f301b", 414 | "sha256:c9e348e02e4d2b4a8b2eedb48210430658df6951fa484e59de33ff773fbd4b41", 415 | "sha256:d205b1b46085271b4e15f670058ce182bd1199e56b317bf2ec004b6a44f911f6", 416 | "sha256:d43943ef777f9a1c42bf4e552ba23ac77a6351de620aa9acf64ad54933ad4d34", 417 | "sha256:d5d33e9e7af3b34a40dc05f498939f0ebf187f07c385fd58d591c533ad8562fe", 418 | "sha256:fc0fea399acb12edbf8a628ba8d2312f583bdbdb3335635db062fa98cf71fca4", 419 | "sha256:fe460b922ec15dd205595c9b5b99e2f056fd98ae8f9f56b888e7a17dc2b757e7" 420 | ], 421 | "markers": "implementation_name == 'cpython' and python_version < '3.8'", 422 | "version": "==1.4.1" 423 | }, 424 | "urllib3": { 425 | "hashes": [ 426 | "sha256:3018294ebefce6572a474f0604c2021e33b3fd8006ecd11d62107a5d2a963527", 427 | "sha256:88206b0eb87e6d677d424843ac5209e3fb9d0190d0ee169599165ec25e9d9115" 428 | ], 429 | "version": "==1.25.9" 430 | }, 431 | "wcwidth": { 432 | "hashes": [ 433 | "sha256:79375666b9954d4a1a10739315816324c3e73110af9d0e102d906fdb0aec009f", 434 | "sha256:8c6b5b6ee1360b842645f336d9e5d68c55817c26d3050f46b235ef2bc650e48f" 435 | ], 436 | "version": "==0.2.4" 437 | }, 438 | "werkzeug": { 439 | "hashes": [ 440 | "sha256:2de2a5db0baeae7b2d2664949077c2ac63fbd16d98da0ff71837f7d1dea3fd43", 441 | "sha256:6c80b1e5ad3665290ea39320b91e1be1e0d5f60652b964a3070216de83d2e47c" 442 | ], 443 | "version": "==1.0.1" 444 | }, 445 | "wrapt": { 446 | "hashes": [ 447 | "sha256:b62ffa81fb85f4332a4f609cab4ac40709470da05643a082ec1eb88e6d9b97d7" 448 | ], 449 | "version": "==1.12.1" 450 | }, 451 | "zipp": { 452 | "hashes": [ 453 | "sha256:aa36550ff0c0b7ef7fa639055d797116ee891440eac1a56f378e2d3179e0320b", 454 | "sha256:c599e4d75c98f6798c509911d08a22e6c021d074469042177c8c86fb92eefd96" 455 | ], 456 | "version": "==3.1.0" 457 | } 458 | }, 459 | "develop": {} 460 | } 461 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenProtein 2 | 3 | A PyTorch framework for tertiary protein structure prediction. 4 | 5 | ![Alt text](examplemodelrun.png?raw=true "OpenProtein") 6 | 7 | ## Getting started 8 | To run this project, you will need `pipenv`: https://pipenv-fork.readthedocs.io/en/latest/install.html 9 | 10 | After you have installed `pipenv`, simply git clone this repository, install dependencies using `pipenv install` and then type `pipenv run python __main__.py` in the terminal to run the sample experiment: 11 | 12 | ``` 13 | $ pipenv run python __main__.py 14 | ------------------------ 15 | --- OpenProtein v0.1 --- 16 | ------------------------ 17 | Live plot deactivated, see output folder for plot. 18 | Starting pre-processing of raw data... 19 | Preprocessed file for testing.txt already exists. 20 | force_pre_processing_overwrite flag set to True, overwriting old file... 21 | Processing raw data file testing.txt 22 | Wrote output to 81 proteins to data/preprocessed/testing.txt.hdf5 23 | Completed pre-processing. 24 | 2018-09-27 19:27:34: Train loss: -781787.696391812 25 | 2018-09-27 19:27:35: Loss time: 1.8300042152404785 Grad time: 0.5147676467895508 26 | ... 27 | ``` 28 | 29 | You can view a live dashboard of the model's performance by navigating to https://biolib.com/openprotein. 
You can customize this dashboard by forking https://github.com/biolib/openprotein-dashboard. 30 | 31 | ## Developing a Predictive Model 32 | See `models.py` for examples of how to create your own model. 33 | 34 | To run pylint on every commit, run `git config core.hooksPath git-hooks`. 35 | 36 | ## Using a Predictive Model 37 | See `prediction.py` for examples of how to use pre-trained models. 38 | 39 | ## Memory Usage 40 | OpenProtein includes a preprocessing tool (`preprocessing.py`) which will transform the standard ProteinNet format into a hdf5 file and save it in `data/preprocessed/`. This is done in a memory-efficient way (line-by-line). 41 | 42 | The OpenProtein PyTorch data loader is memory optimized too - when reading the hdf5 file it will only load the samples needed for each minibatch into memory. 43 | 44 | ## License 45 | Please see the LICENSE file in the root directory. 46 | -------------------------------------------------------------------------------- /__main__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 5 | """ 6 | 7 | from op_cli import main 8 | 9 | print("------------------------") 10 | print("--- OpenProtein v0.1 ---") 11 | print("------------------------") 12 | 13 | main() 14 | -------------------------------------------------------------------------------- /dashboard.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 5 | """ 6 | 7 | import logging 8 | import threading 9 | from flask import Flask, request, jsonify 10 | from flask_cors import CORS, cross_origin 11 | 12 | APP = Flask(__name__) 13 | CORS = CORS(APP) 14 | DATA = None 15 | 16 | 17 | @APP.route('/graph', methods=['POST']) 18 | def update_graph(): 19 | global DATA 20 | DATA = request.json 21 | return jsonify({"result": "OK"}) 22 | 23 | 24 | @APP.route('/graph', methods=['GET']) 25 | @cross_origin() 26 | def get_graph(): 27 | return jsonify(DATA) 28 | 29 | class GraphWebServer(threading.Thread): 30 | def __init__(self): 31 | threading.Thread.__init__(self) 32 | 33 | def run(self): 34 | logging.basicConfig(filename="output/app.log", level=logging.DEBUG) 35 | APP.run(debug=False, host='0.0.0.0') 36 | 37 | def start_dashboard_server(): 38 | flask_thread = GraphWebServer() 39 | flask_thread.start() 40 | -------------------------------------------------------------------------------- /data/mixture_model_pfam_100.txt: -------------------------------------------------------------------------------- 1 | # weight mu1 mu2 k1 k2 k3 lognormc env_K env1_mu env2_mu env1_k env2_k env1_w env2_w 2 | 0.0083 -1.5084 0.0035 113.0698 136.2962 29.5309 217.1773 -0.0010 0.0000 -0.0000 96.6693 96.6693 0.5000 0.5000 3 | 0.0103 -2.1032 2.7417 63.4853 72.4552 13.0657 120.7424 -0.0021 0.0000 -0.0000 56.2145 56.2145 0.5000 0.5000 4 | 0.0325 -1.0201 -0.7034 143.3961 103.7088 10.2042 234.0299 -0.0009 0.0000 -0.0000 92.7690 92.7690 0.5000 0.5000 5 | 0.0031 1.0451 0.5031 245.8758 146.8340 71.5110 318.5327 0.0054 0.0000 -0.0000 46.8872 46.8872 0.5000 0.5000 6 | 0.0145 -1.3898 2.4876 75.2006 52.6907 12.9034 112.9555 -0.0013 0.0000 -0.0000 37.2914 37.2914 0.5000 0.5000 7 | 0.0028 1.3261 -2.7816 19.0705 8.7561 2.2965 25.0697 -0.0051 0.0000 -0.0000 6.2548 6.2548 0.5000 0.5000 8 | 0.0150 -1.8618 
2.0216 47.0896 52.7199 5.6637 92.2091 -0.0029 0.0000 -0.0000 46.3698 46.3698 0.5000 0.5000 9 | 0.0029 -3.1025 2.9592 44.8640 22.3103 11.5941 54.6787 0.0219 0.0000 -0.0000 7.3833 7.3833 0.5000 0.5000 10 | 0.0092 -1.4527 -0.7120 44.5370 54.8900 6.4190 91.1011 -0.0032 0.0000 -0.0000 47.5035 47.5035 0.5000 0.5000 11 | 0.0040 1.6503 -0.1326 15.2696 15.8404 -2.1562 32.2534 -0.0073 0.0000 -0.0000 17.6811 17.6811 0.5000 0.5000 12 | 0.0055 -1.4771 1.2693 98.3477 22.3851 5.8549 113.0687 -0.0010 0.0000 -0.0000 16.2061 16.2061 0.5000 0.5000 13 | 0.0067 -2.1198 -0.0866 59.5754 35.6706 15.0278 78.7766 0.0049 0.0000 -0.0000 16.0794 16.0794 0.5000 0.5000 14 | 0.0127 -2.5019 2.6424 57.1511 56.5996 8.4147 103.3124 -0.0024 0.0000 -0.0000 46.8538 46.8538 0.5000 0.5000 15 | 0.0297 -1.2004 -0.5517 118.3727 77.5187 3.2753 189.9308 -0.0011 0.0000 -0.0000 74.1656 74.1656 0.5000 0.5000 16 | 0.0169 -1.1309 2.2775 46.9944 90.2105 8.7825 126.2553 -0.0032 0.0000 -0.0000 79.5704 79.5704 0.5000 0.5000 17 | 0.0009 -3.0105 2.2217 2.9907 6.8700 0.4790 9.9166 0.0004 0.0000 -0.0000 6.4381 6.4381 0.5000 0.5000 18 | 0.0048 -1.4850 3.1239 38.4659 31.1834 -3.4273 71.2862 -0.0030 0.0000 -0.0000 34.2963 34.2963 0.5000 0.5000 19 | 0.0036 -2.2236 1.6250 43.3245 39.5820 12.8662 68.6363 0.0010 0.0000 -0.0000 21.8214 21.8214 0.5000 0.5000 20 | 0.0236 -1.0021 -0.8349 105.2489 116.7845 12.2951 206.9956 -0.0013 0.0000 -0.0000 102.9483 102.9483 0.5000 0.5000 21 | 0.0550 -1.1251 -0.7031 290.2342 250.7588 77.0140 460.6433 0.0006 0.0000 -0.0000 146.3586 146.3586 0.5000 0.5000 22 | 0.0058 -2.4115 2.1777 56.4320 56.5672 5.5372 105.3789 -0.0024 0.0000 -0.0000 50.4949 50.4949 0.5000 0.5000 23 | 0.0130 -1.1287 2.6035 83.8111 86.6203 7.5655 160.3593 -0.0016 0.0000 -0.0000 78.3646 78.3646 0.5000 0.5000 24 | 0.0114 -2.3615 2.3817 104.7394 98.4171 25.3429 175.3778 -0.0008 0.0000 -0.0000 65.3033 65.3033 0.5000 0.5000 25 | 0.0070 -2.0380 0.3959 44.4472 38.6648 12.9876 68.7268 0.0014 0.0000 -0.0000 20.8518 20.8518 0.5000 0.5000 26 | 0.0084 -2.7088 2.9612 36.4545 46.4075 5.7506 75.4034 -0.0039 0.0000 -0.0000 39.7097 39.7097 0.5000 0.5000 27 | 0.0140 -1.9888 2.3072 49.8199 78.3050 -1.8603 127.6630 -0.0024 0.0000 -0.0000 80.0814 80.0814 0.5000 0.5000 28 | 0.0025 -2.0519 -0.9696 11.6243 18.3270 -1.0758 30.1316 -0.0102 0.0000 -0.0000 19.2747 19.2747 0.5000 0.5000 29 | 0.0019 2.7413 -3.0481 6.7088 7.6149 -1.7641 15.7923 0.0007 0.0000 -0.0000 8.9382 8.9382 0.5000 0.5000 30 | 0.0251 -1.1380 -0.3421 263.9096 199.5376 84.6411 375.8852 0.0036 0.0000 -0.0000 75.8466 75.8466 0.5000 0.5000 31 | 0.0069 1.4452 0.1218 43.3758 52.9241 8.6310 85.8664 -0.0032 0.0000 -0.0000 42.3429 42.3429 0.5000 0.5000 32 | 0.0035 1.7447 2.9094 8.2297 3.1953 -0.5270 12.1118 0.0006 0.0000 -0.0000 3.6652 3.6652 0.5000 0.5000 33 | 0.0149 -1.5230 2.1985 40.6114 37.8142 -11.5962 87.9646 -0.0022 0.0000 -0.0000 46.7644 46.7644 0.5000 0.5000 34 | 0.0131 -1.5547 1.9275 45.0501 36.1598 7.2396 72.3391 -0.0027 0.0000 -0.0000 27.6824 27.6824 0.5000 0.5000 35 | 0.0039 -1.9500 -0.4775 73.2013 61.0354 26.5357 106.1039 0.0146 0.0000 -0.0000 20.7086 20.7086 0.5000 0.5000 36 | 0.0017 1.0155 1.0994 27.8096 2.7893 -2.5048 32.4556 -0.0022 0.0000 -0.0000 5.0674 5.0674 0.5000 0.5000 37 | 0.0011 -2.3617 -2.7851 11.8248 13.4945 1.7691 23.0353 -0.0123 0.0000 -0.0000 11.5346 11.5346 0.5000 0.5000 38 | 0.0077 -1.6551 -0.4082 70.1827 91.1573 22.5148 136.7000 -0.0011 0.0000 -0.0000 58.5447 58.5447 0.5000 0.5000 39 | 0.0546 -1.1072 -0.7491 280.1487 207.3076 46.6014 437.4562 0.0001 0.0000 -0.0000 151.5696 151.5696 0.5000 
0.5000 40 | 0.0071 -1.3718 2.9291 57.2998 52.6904 -1.8349 109.6288 -0.0021 0.0000 -0.0000 54.4538 54.4538 0.5000 0.5000 41 | 0.0061 -1.9557 0.1804 68.5315 44.6488 9.3437 101.8774 -0.0017 0.0000 -0.0000 33.9511 33.9511 0.5000 0.5000 42 | 0.0008 -2.6361 0.9743 5.6999 4.6168 -2.2369 12.4861 0.0023 0.0000 -0.0000 6.1428 6.1428 0.5000 0.5000 43 | 0.0099 -2.6704 2.4740 66.9637 54.7911 14.5996 105.2221 -0.0012 0.0000 -0.0000 36.3874 36.3874 0.5000 0.5000 44 | 0.0635 -1.0629 -0.7708 459.9234 460.9074 149.9729 767.0893 0.0007 0.0000 -0.0000 239.0639 239.0639 0.5000 0.5000 45 | 0.0091 -1.5990 2.7036 37.5996 41.0640 -1.4739 78.2746 -0.0032 0.0000 -0.0000 42.4648 42.4648 0.5000 0.5000 46 | 0.0007 1.2913 -0.9779 59.8127 23.4495 1.9236 79.6218 -0.0021 0.0000 -0.0000 21.4807 21.4807 0.5000 0.5000 47 | 0.0115 -2.3501 2.7865 67.0045 80.5047 9.7918 135.4177 -0.0021 0.0000 -0.0000 69.1535 69.1535 0.5000 0.5000 48 | 0.0029 -1.6731 -2.6263 13.1022 8.0936 4.4739 17.0897 0.0936 0.0000 -0.0000 2.5660 2.5660 0.5000 0.5000 49 | 0.0019 -2.4677 -0.5719 1.3807 5.5571 -1.1059 8.6184 0.0006 0.0000 -0.0000 6.0325 6.0325 0.5000 0.5000 50 | 0.0179 -2.1597 2.1876 57.1548 90.0871 6.5944 138.3201 -0.0025 0.0000 -0.0000 82.7129 82.7129 0.5000 0.5000 51 | 0.0111 -1.4245 -0.2467 62.3329 88.2098 11.1098 137.1490 -0.0023 0.0000 -0.0000 74.8454 74.8454 0.5000 0.5000 52 | 0.0052 1.2112 0.4549 98.5076 106.1632 38.0549 164.4974 0.0048 0.0000 -0.0000 45.3563 45.3563 0.5000 0.5000 53 | 0.0374 -1.0843 -0.5893 155.6683 85.0449 8.0760 229.8113 -0.0008 0.0000 -0.0000 76.5591 76.5591 0.5000 0.5000 54 | 0.0034 -2.2707 1.2621 78.2582 24.5112 -1.3030 102.1037 -0.0016 0.0000 -0.0000 25.7852 25.7852 0.5000 0.5000 55 | 0.0078 -1.0567 2.3380 18.5956 12.3176 2.0129 28.1959 -0.0070 0.0000 -0.0000 10.1429 10.1429 0.5000 0.5000 56 | 0.0022 0.9917 -2.2380 70.4637 45.4666 1.3888 112.3737 -0.0018 0.0000 -0.0000 44.0605 44.0605 0.5000 0.5000 57 | 0.0026 1.9539 -2.2436 5.5663 7.2303 3.1451 10.6249 0.0931 0.0000 -0.0000 3.0407 3.0407 0.5000 0.5000 58 | 0.0090 -1.3140 2.1086 53.2903 58.2448 -0.3940 109.7444 -0.0024 0.0000 -0.0000 58.6324 58.6324 0.5000 0.5000 59 | 0.0129 -1.1136 -0.9065 32.6148 42.3651 7.1637 66.2897 -0.0043 0.0000 -0.0000 33.4136 33.4136 0.5000 0.5000 60 | 0.0146 -2.1883 2.3981 79.2874 86.2513 10.4271 152.6826 -0.0017 0.0000 -0.0000 74.3462 74.3462 0.5000 0.5000 61 | 0.0015 1.8984 -0.2899 8.3426 9.4556 1.4133 16.2571 0.0014 0.0000 -0.0000 7.9049 7.9049 0.5000 0.5000 62 | 0.0078 -1.7666 -0.1282 35.8765 58.7680 0.1929 92.4727 -0.0036 0.0000 -0.0000 58.5768 58.5768 0.5000 0.5000 63 | 0.0143 -1.3454 -0.1713 75.7398 84.2130 13.0944 144.5176 -0.0018 0.0000 -0.0000 68.5362 68.5362 0.5000 0.5000 64 | 0.0095 -1.6620 2.4029 81.1256 66.7943 8.3303 137.2608 -0.0016 0.0000 -0.0000 57.5851 57.5851 0.5000 0.5000 65 | 0.0057 -1.9324 1.7809 31.7435 37.8451 -2.4658 70.2873 -0.0037 0.0000 -0.0000 40.1017 40.1017 0.5000 0.5000 66 | 0.0050 -0.9139 -0.8446 26.0279 31.5322 6.1360 50.1938 -0.0051 0.0000 -0.0000 23.7690 23.7690 0.5000 0.5000 67 | 0.0064 -1.4522 0.2424 35.1147 26.0843 14.2485 46.4945 0.1123 0.0000 -0.0000 5.0481 5.0481 0.5000 0.5000 68 | 0.0094 -2.7351 2.8046 74.7883 86.3670 14.0125 144.8116 -0.0018 0.0000 -0.0000 69.3001 69.3001 0.5000 0.5000 69 | 0.0048 -1.6172 0.1324 71.6738 98.9238 11.2979 156.8657 -0.0020 0.0000 -0.0000 85.6389 85.6389 0.5000 0.5000 70 | 0.0019 -2.4293 0.4471 22.9164 6.5262 -3.0611 31.6228 -0.0041 0.0000 -0.0000 9.1897 9.1897 0.5000 0.5000 71 | 0.0152 -1.2346 2.7599 87.3437 70.8930 4.6627 151.1130 -0.0015 0.0000 -0.0000 65.9994 
65.9994 0.5000 0.5000 72 | 0.0015 2.6107 -2.8692 4.9263 20.1099 1.8991 23.0453 -0.0496 0.0000 -0.0000 17.5882 17.5882 0.5000 0.5000 73 | 0.0051 -2.2712 3.0191 29.1842 46.7763 -4.0752 78.1688 -0.0038 0.0000 -0.0000 50.3021 50.3021 0.5000 0.5000 74 | 0.0022 1.3049 -2.8962 97.3863 56.8729 35.9048 117.3728 0.1996 0.0000 -0.0000 5.1504 5.1504 0.5000 0.5000 75 | 0.0044 -1.5503 1.4001 20.4511 11.4079 3.1866 28.0826 -0.0043 0.0000 -0.0000 7.7927 7.7927 0.5000 0.5000 76 | 0.0045 1.0180 0.7965 45.5699 44.2131 17.1832 71.3297 0.0119 0.0000 -0.0000 17.8263 17.8263 0.5000 0.5000 77 | 0.0028 1.2279 0.2163 173.2279 194.8744 74.7785 290.7754 0.0091 0.0000 -0.0000 65.3036 65.3036 0.5000 0.5000 78 | 0.0141 -1.1516 2.4553 92.6129 79.3735 10.0279 159.4815 -0.0014 0.0000 -0.0000 68.2069 68.2069 0.5000 0.5000 79 | 0.0098 -1.0342 -0.5157 54.1554 57.5211 12.8086 96.9933 -0.0021 0.0000 -0.0000 41.0266 41.0266 0.5000 0.5000 80 | 0.0074 -1.8457 2.8349 43.2388 23.6742 -1.6559 66.8980 -0.0028 0.0000 -0.0000 25.2524 25.2524 0.5000 0.5000 81 | 0.0010 0.9741 1.2603 3.7274 7.2870 2.4472 9.7426 0.0431 0.0000 -0.0000 4.1150 4.1150 0.5000 0.5000 82 | 0.0040 -2.6215 2.0300 43.2344 28.8242 9.1341 61.5757 -0.0007 0.0000 -0.0000 17.5212 17.5212 0.5000 0.5000 83 | 0.0115 -1.9230 2.5082 61.2638 59.9530 2.4821 116.5151 -0.0021 0.0000 -0.0000 57.3891 57.3891 0.5000 0.5000 84 | 0.0142 -1.3396 -0.4535 65.4361 50.7956 -5.4455 119.3778 -0.0017 0.0000 -0.0000 55.7905 55.7905 0.5000 0.5000 85 | 0.0027 -1.4766 -1.3553 1.3263 8.0342 -1.0635 10.8448 0.0003 0.0000 -0.0000 8.4825 8.4825 0.5000 0.5000 86 | 0.0045 -1.6882 -0.7859 46.0438 35.6209 13.1062 67.2160 0.0031 0.0000 -0.0000 17.8545 17.8545 0.5000 0.5000 87 | 0.0025 -1.8498 -2.9531 13.3274 14.4568 -2.6149 29.4641 -0.0079 0.0000 -0.0000 16.5827 16.5827 0.5000 0.5000 88 | 0.0125 -0.9510 -0.7423 105.0493 93.7378 28.3983 168.0543 -0.0001 0.0000 -0.0000 55.2485 55.2485 0.5000 0.5000 89 | 0.0034 1.5460 -0.1349 186.5616 215.4222 87.9684 311.5724 0.0205 0.0000 -0.0000 52.4659 52.4659 0.5000 0.5000 90 | 0.0125 -0.9906 2.4144 183.8523 119.7054 39.5024 261.2873 0.0000 0.0000 -0.0000 69.6920 69.6920 0.5000 0.5000 91 | 0.0082 -1.7608 0.2056 77.2921 48.0191 8.5475 114.6662 -0.0016 0.0000 -0.0000 38.4967 38.4967 0.5000 0.5000 92 | 0.0017 -2.1282 -0.7268 70.4133 55.3215 26.8481 97.5221 0.0355 0.0000 -0.0000 13.9059 13.9059 0.5000 0.5000 93 | 0.0142 -2.2477 2.6106 74.3894 96.6865 20.1329 148.6676 -0.0016 0.0000 -0.0000 69.4359 69.4359 0.5000 0.5000 94 | 0.0085 -1.6379 -0.0378 112.4489 99.2228 33.2829 176.0633 0.0008 0.0000 -0.0000 52.5162 52.5162 0.5000 0.5000 95 | 0.0030 -2.6111 -3.0312 16.0573 38.6669 0.0357 53.3237 -0.0081 0.0000 -0.0000 38.6322 38.6322 0.5000 0.5000 96 | 0.0021 1.4776 0.4043 13.3189 12.9586 4.4526 21.6296 0.0029 0.0000 -0.0000 6.9581 6.9581 0.5000 0.5000 97 | 0.0029 -2.5638 1.3147 38.0852 25.9641 10.2454 52.7277 0.0050 0.0000 -0.0000 12.4663 12.4663 0.5000 0.5000 98 | 0.0162 -1.8352 2.2163 50.8087 76.7135 0.5871 124.6527 -0.0025 0.0000 -0.0000 76.1255 76.1255 0.5000 0.5000 99 | 0.0279 -1.2006 -0.7635 120.9288 129.9431 30.3205 217.8897 -0.0008 0.0000 -0.0000 89.7978 89.7978 0.5000 0.5000 100 | 0.0086 -2.1408 1.9781 27.5005 48.7649 -9.0599 83.3612 -0.0034 0.0000 -0.0000 55.4977 55.4977 0.5000 0.5000 101 | 0.0085 0.9538 0.7604 148.9607 102.4387 37.8890 211.0194 0.0009 0.0000 -0.0000 52.0776 52.0776 0.5000 0.5000 102 | -------------------------------------------------------------------------------- /data/mixture_model_pfam_50.txt: 
-------------------------------------------------------------------------------- 1 | # weight mu1 mu2 k1 k2 k3 lognormc env_K env1_mu env2_mu env1_k env2_k env1_w env2_w 2 | 0.0089 -1.4793 1.2925 59.1112 12.5706 1.4416 68.8608 -0.0021 0.0000 -0.0000 11.1075 11.1075 0.5000 0.5000 3 | 0.1306 -1.0933 -0.7343 392.0321 310.4931 122.3227 576.7950 0.0016 0.0000 -0.0000 133.4432 133.4432 0.5000 0.5000 4 | 0.0210 -1.8969 2.5849 40.3007 29.7811 3.6540 64.8477 -0.0033 0.0000 -0.0000 25.8257 25.8257 0.5000 0.5000 5 | 0.0135 1.2898 0.3302 70.2749 49.9771 26.4274 92.6223 0.0589 0.0000 -0.0000 10.0252 10.0252 0.5000 0.5000 6 | 0.0012 1.3114 1.9595 0.9351 4.5001 -0.7742 7.1334 0.0004 0.0000 -0.0000 4.7761 4.7761 0.5000 0.5000 7 | 0.0152 -1.1632 2.3411 24.5792 33.6777 5.3108 51.6656 -0.0059 0.0000 -0.0000 27.1214 27.1214 0.5000 0.5000 8 | 0.0145 -1.7442 0.1278 61.5967 60.0216 17.6393 102.1419 -0.0003 0.0000 -0.0000 35.7677 35.7677 0.5000 0.5000 9 | 0.0140 -2.7318 2.8397 52.8449 54.6613 7.4815 98.0456 -0.0026 0.0000 -0.0000 46.0591 46.0591 0.5000 0.5000 10 | 0.0391 -2.0189 2.2420 37.5288 49.0794 1.4824 83.2465 -0.0035 0.0000 -0.0000 47.5581 47.5581 0.5000 0.5000 11 | 0.0018 0.9891 -2.2559 118.9105 61.4887 10.2889 167.6489 -0.0010 0.0000 -0.0000 50.2883 50.2883 0.5000 0.5000 12 | 0.0244 -1.2942 -0.2693 66.3745 81.3625 10.8622 134.5948 -0.0021 0.0000 -0.0000 68.5132 68.5132 0.5000 0.5000 13 | 0.0202 -1.3481 2.6536 25.4948 44.6410 2.5788 65.9706 -0.0055 0.0000 -0.0000 41.8387 41.8387 0.5000 0.5000 14 | 0.0118 -2.4003 2.1115 13.5648 7.3561 -4.6899 24.8264 0.0015 0.0000 -0.0000 10.7747 10.7747 0.5000 0.5000 15 | 0.0245 -2.3506 2.7424 41.3987 48.5439 0.2158 87.7724 -0.0031 0.0000 -0.0000 48.3296 48.3296 0.5000 0.5000 16 | 0.0071 -1.7724 -0.9567 2.8823 5.5963 0.8887 8.4242 0.0031 0.0000 -0.0000 4.7337 4.7337 0.5000 0.5000 17 | 0.0059 1.4065 0.0198 178.3234 170.4160 85.5593 261.3892 0.1393 0.0000 -0.0000 15.4286 15.4286 0.5000 0.5000 18 | 0.0167 -1.4914 2.0128 28.6372 20.9077 -6.5769 54.5544 -0.0033 0.0000 -0.0000 26.1953 26.1953 0.5000 0.5000 19 | 0.0154 -1.5691 -0.3904 66.3775 49.7785 19.9408 94.5942 0.0052 0.0000 -0.0000 21.9678 21.9678 0.5000 0.5000 20 | 0.0117 0.9721 0.7102 163.7319 93.4395 45.5604 209.3455 0.0068 0.0000 -0.0000 31.1140 31.1140 0.5000 0.5000 21 | 0.0069 2.5866 -2.9507 4.4926 6.2577 1.8936 9.6242 0.0218 0.0000 -0.0000 3.9690 3.9690 0.5000 0.5000 22 | 0.0079 -1.7367 1.9374 7.1566 14.9221 2.2834 19.6374 0.0027 0.0000 -0.0000 12.0103 12.0103 0.5000 0.5000 23 | 0.0119 -2.0156 0.3451 17.0390 19.3480 2.2900 33.1948 -0.0083 0.0000 -0.0000 16.8053 16.8053 0.5000 0.5000 24 | 0.0189 -1.4266 2.8822 28.7219 18.7743 5.5721 40.9595 -0.0023 0.0000 -0.0000 12.0941 12.0941 0.5000 0.5000 25 | 0.0475 -1.0352 -0.7812 48.6386 62.8086 8.5283 100.9349 -0.0029 0.0000 -0.0000 52.6196 52.6196 0.5000 0.5000 26 | 0.0221 -1.1577 2.6839 143.5208 97.3161 48.8277 189.9642 0.0179 0.0000 -0.0000 24.8715 24.8715 0.5000 0.5000 27 | 0.0236 -2.3100 2.3404 71.4305 51.7989 14.9551 106.3502 -0.0008 0.0000 -0.0000 33.1454 33.1454 0.5000 0.5000 28 | 0.0571 -1.0809 -0.4977 246.4654 126.5239 68.3392 302.1486 0.0096 0.0000 -0.0000 32.9848 32.9848 0.5000 0.5000 29 | 0.0057 -1.8359 -0.8206 10.9159 13.2109 -1.3238 24.7239 0.0001 0.0000 -0.0000 14.3458 14.3458 0.5000 0.5000 30 | 0.0084 -2.2173 3.0276 15.5945 31.2403 5.9900 40.0187 -0.0096 0.0000 -0.0000 22.2175 22.2175 0.5000 0.5000 31 | 0.0082 -2.2795 1.3435 19.5213 9.1535 -4.6873 32.3439 -0.0044 0.0000 -0.0000 12.8783 12.8783 0.5000 0.5000 32 | 0.0255 -1.2998 2.2864 33.1515 38.8386 -1.1666 
71.3897 -0.0037 0.0000 -0.0000 39.9495 39.9495 0.5000 0.5000 33 | 0.0440 -1.2283 -0.5872 87.6355 70.5764 8.1014 147.7026 -0.0015 0.0000 -0.0000 61.7136 61.7136 0.5000 0.5000 34 | 0.0074 1.3848 -2.9308 23.9698 13.1639 7.6156 29.3977 0.0810 0.0000 -0.0000 3.3653 3.3653 0.5000 0.5000 35 | 0.0771 -1.0899 -0.7676 155.6101 212.2671 29.9761 334.7397 0.0001 0.0000 -0.0000 175.3172 175.3172 0.5000 0.5000 36 | 0.0227 -1.5013 -0.0572 80.3587 84.5734 25.1282 137.6982 0.0001 0.0000 -0.0000 48.5906 48.5906 0.5000 0.5000 37 | 0.0147 -2.0146 1.9301 27.4589 34.4796 4.6040 55.9324 -0.0052 0.0000 -0.0000 29.0909 29.0909 0.5000 0.5000 38 | 0.0122 -1.9725 -0.2267 32.6342 28.7058 5.6556 54.3387 -0.0039 0.0000 -0.0000 22.0296 22.0296 0.5000 0.5000 39 | 0.0188 -2.6318 2.5314 39.4815 26.2423 4.3893 59.8712 -0.0032 0.0000 -0.0000 21.3898 21.3898 0.5000 0.5000 40 | 0.0307 -1.7156 2.2033 36.1142 18.9796 -6.8534 60.3096 -0.0026 0.0000 -0.0000 24.6887 24.6887 0.5000 0.5000 41 | 0.0039 1.5348 -0.4550 7.5924 0.9976 -0.8818 10.0730 0.0076 0.0000 -0.0000 1.7627 1.7627 0.5000 0.5000 42 | 0.0062 -1.8619 -2.9406 7.6081 4.1640 1.2194 11.0085 0.0084 0.0000 -0.0000 2.8827 2.8827 0.5000 0.5000 43 | 0.0199 -2.1655 2.5265 55.8630 35.8449 5.2238 84.6632 -0.0023 0.0000 -0.0000 30.1501 30.1501 0.5000 0.5000 44 | 0.0203 -1.3181 -0.8256 27.1376 47.9499 13.5030 60.5553 0.0078 0.0000 -0.0000 23.0746 23.0746 0.5000 0.5000 45 | 0.0078 -2.7342 3.0986 10.1677 22.5140 4.4347 27.8666 -0.0160 0.0000 -0.0000 15.6063 15.6063 0.5000 0.5000 46 | 0.0065 1.6196 0.0563 19.5279 23.0190 6.6961 35.1276 -0.0014 0.0000 -0.0000 13.4946 13.4946 0.5000 0.5000 47 | 0.0035 0.9739 0.9503 19.4267 14.3572 2.1307 30.8394 -0.0069 0.0000 -0.0000 12.0465 12.0465 0.5000 0.5000 48 | 0.0110 -1.6810 0.1043 15.6285 23.5217 4.4044 33.9543 -0.0093 0.0000 -0.0000 17.7478 17.7478 0.5000 0.5000 49 | 0.0225 -1.0607 2.4078 107.0936 88.2002 27.5468 165.4291 -0.0000 0.0000 -0.0000 51.5131 51.5131 0.5000 0.5000 50 | 0.0052 -2.1287 0.7050 4.8079 2.9125 0.7643 7.8248 0.0078 0.0000 -0.0000 2.1639 2.1639 0.5000 0.5000 51 | 0.0242 -1.0554 -0.6145 57.8401 40.0773 -7.2532 103.0057 -0.0019 0.0000 -0.0000 46.4794 46.4794 0.5000 0.5000 52 | -------------------------------------------------------------------------------- /data/preprocessed/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /data/raw/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !sample.txt 4 | -------------------------------------------------------------------------------- /demo.mpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biolib/openprotein/3f474d3b1c00af0f06d88bf1ad78f2c34763341d/demo.mpg -------------------------------------------------------------------------------- /examplemodelrun.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biolib/openprotein/3f474d3b1c00af0f06d88bf1ad78f2c34763341d/examplemodelrun.png -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biolib/openprotein/3f474d3b1c00af0f06d88bf1ad78f2c34763341d/experiments/__init__.py 
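Note on the mixture-model data files above: each `data/mixture_model_pfam_*.txt` file is a whitespace-separated table whose `#`-prefixed header names the 14 columns (weight, mu1, mu2, k1, k2, k3, lognormc, env_K, env1_mu, env2_mu, env1_k, env2_k, env1_w, env2_w). `SoftToAngle` in `experiments/example/models.py` (shown further below) reads columns 1 and 2 of these files as the phi and psi mixture components. The snippet below is a minimal, hedged sketch of how such a file can be inspected with NumPy; the interpretation of columns beyond what `models.py` itself uses follows only from the header row and is otherwise an assumption.

```python
import numpy as np

# genfromtxt's default comments="#" skips the header row, exactly as
# experiments/example/models.py relies on when it reads these files.
components = np.genfromtxt("data/mixture_model_pfam_100.txt")

print(components.shape)      # expected (100, 14) for the file shown above

weights = components[:, 0]   # column 0: component weight (per the header)
phi_mu = components[:, 1]    # column 1: mu1 -- used as phi components by SoftToAngle
psi_mu = components[:, 2]    # column 2: mu2 -- used as psi components by SoftToAngle
```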
-------------------------------------------------------------------------------- /experiments/example/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 5 | """ 6 | 7 | from preprocessing import process_raw_data 8 | 9 | from experiments.example.models import * 10 | from training import train_model 11 | from util import contruct_dataloader_from_disk 12 | 13 | 14 | def run_experiment(parser, use_gpu): 15 | # parse experiment specific command line arguments 16 | parser.add_argument('--learning-rate', dest='learning_rate', type=float, 17 | default=0.01, help='Learning rate to use during training.') 18 | parser.add_argument('--min-updates', dest='minimum_updates', type=int, 19 | default=1000, help='Minimum number of minibatch iterations.') 20 | parser.add_argument('--minibatch-size', dest='minibatch_size', type=int, 21 | default=1, help='Size of each minibatch.') 22 | args, _unknown = parser.parse_known_args() 23 | 24 | # pre-process data 25 | process_raw_data(use_gpu, force_pre_processing_overwrite=False) 26 | 27 | # run experiment 28 | training_file = "data/preprocessed/single_protein.txt.hdf5" 29 | validation_file = "data/preprocessed/single_protein.txt.hdf5" 30 | 31 | model = ExampleModel(21, args.minibatch_size, use_gpu=use_gpu) # embed size = 21 32 | 33 | train_loader = contruct_dataloader_from_disk(training_file, args.minibatch_size) 34 | validation_loader = contruct_dataloader_from_disk(validation_file, args.minibatch_size) 35 | 36 | train_model_path = train_model(data_set_identifier="TRAIN", 37 | model=model, 38 | train_loader=train_loader, 39 | validation_loader=validation_loader, 40 | learning_rate=args.learning_rate, 41 | minibatch_size=args.minibatch_size, 42 | eval_interval=args.eval_interval, 43 | hide_ui=args.hide_ui, 44 | use_gpu=use_gpu, 45 | minimum_updates=args.minimum_updates) 46 | 47 | print("Completed training, trained model stored at:") 48 | print(train_model_path) 49 | -------------------------------------------------------------------------------- /experiments/example/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 
5 | """ 6 | 7 | import math 8 | import torch 9 | import torch.autograd as autograd 10 | import torch.nn as nn 11 | import numpy as np 12 | from torch.nn.utils.rnn import pack_padded_sequence 13 | 14 | import openprotein 15 | from util import get_backbone_positions_from_angles, compute_atan2 16 | 17 | # seed random generator for reproducibility 18 | torch.manual_seed(1) 19 | 20 | # sample model borrowed from 21 | # https://github.com/lblaabjerg/Master/blob/master/Models%20and%20processed%20data/ProteinNet_LSTM_500.py 22 | class ExampleModel(openprotein.BaseModel): 23 | def __init__(self, embedding_size, minibatch_size, use_gpu): 24 | super(ExampleModel, self).__init__(use_gpu, embedding_size) 25 | 26 | self.hidden_size = 25 27 | self.num_lstm_layers = 2 28 | self.mixture_size = 500 29 | self.bi_lstm = nn.LSTM(self.get_embedding_size(), self.hidden_size, 30 | num_layers=self.num_lstm_layers, 31 | bidirectional=True, bias=True) 32 | self.hidden_to_labels = nn.Linear(self.hidden_size * 2, 33 | self.mixture_size, bias=True) # * 2 for bidirectional 34 | self.init_hidden(minibatch_size) 35 | self.softmax_to_angle = SoftToAngle(self.mixture_size) 36 | self.batch_norm = nn.BatchNorm1d(self.mixture_size) 37 | 38 | def init_hidden(self, minibatch_size): 39 | # number of layers (* 2 since bidirectional), minibatch_size, hidden size 40 | initial_hidden_state = torch.zeros(self.num_lstm_layers * 2, 41 | minibatch_size, self.hidden_size) 42 | initial_cell_state = torch.zeros(self.num_lstm_layers * 2, 43 | minibatch_size, self.hidden_size) 44 | if self.use_gpu: 45 | initial_hidden_state = initial_hidden_state.cuda() 46 | initial_cell_state = initial_cell_state.cuda() 47 | self.hidden_layer = (autograd.Variable(initial_hidden_state), 48 | autograd.Variable(initial_cell_state)) 49 | 50 | def _get_network_emissions(self, original_aa_string): 51 | padded_input_sequences = self.embed(original_aa_string) 52 | minibatch_size = len(original_aa_string) 53 | batch_sizes = list([v.size(0) for v in original_aa_string]) 54 | packed_sequences = pack_padded_sequence(padded_input_sequences, batch_sizes) 55 | 56 | self.init_hidden(minibatch_size) 57 | (data, bi_lstm_batches, _, _), self.hidden_layer = self.bi_lstm( 58 | packed_sequences, self.hidden_layer) 59 | emissions_padded, batch_sizes = torch.nn.utils.rnn.pad_packed_sequence( 60 | torch.nn.utils.rnn.PackedSequence(self.hidden_to_labels(data), bi_lstm_batches)) 61 | emissions = emissions_padded.transpose(0, 1)\ 62 | .transpose(1, 2) # minibatch_size, self.mixture_size, -1 63 | emissions = self.batch_norm(emissions) 64 | emissions = emissions.transpose(1, 2) # (minibatch_size, -1, self.mixture_size) 65 | probabilities = torch.softmax(emissions, 2) 66 | output_angles = self.softmax_to_angle(probabilities)\ 67 | .transpose(0, 1) # max size, minibatch size, 3 (angles) 68 | backbone_atoms_padded, _ = \ 69 | get_backbone_positions_from_angles(output_angles, 70 | batch_sizes, 71 | self.use_gpu) 72 | return output_angles, backbone_atoms_padded, batch_sizes 73 | 74 | 75 | class SoftToAngle(nn.Module): 76 | def __init__(self, mixture_size): 77 | super(SoftToAngle, self).__init__() 78 | # Omega Initializer 79 | omega_components1 = np.random.uniform(0, 1, int(mixture_size*0.1)) # set omega 90/10 pos/neg 80 | omega_components2 = np.random.uniform(2, math.pi, int(mixture_size*0.9)) 81 | omega_components = np.concatenate((omega_components1, omega_components2)) 82 | np.random.shuffle(omega_components) 83 | 84 | phi_components = np.genfromtxt("data/mixture_model_pfam_" 85 | + 
str(mixture_size) + ".txt")[:, 1] 86 | psi_components = np.genfromtxt("data/mixture_model_pfam_" 87 | + str(mixture_size) + ".txt")[:, 2] 88 | 89 | self.phi_components = nn.Parameter(torch.from_numpy(phi_components) 90 | .contiguous().view(-1, 1).float()) 91 | self.psi_components = nn.Parameter(torch.from_numpy(psi_components) 92 | .contiguous().view(-1, 1).float()) 93 | self.omega_components = nn.Parameter(torch.from_numpy(omega_components) 94 | .view(-1, 1).float()) 95 | 96 | def forward(self, x): 97 | phi_input_sin = torch.matmul(x, torch.sin(self.phi_components)) 98 | phi_input_cos = torch.matmul(x, torch.cos(self.phi_components)) 99 | psi_input_sin = torch.matmul(x, torch.sin(self.psi_components)) 100 | psi_input_cos = torch.matmul(x, torch.cos(self.psi_components)) 101 | omega_input_sin = torch.matmul(x, torch.sin(self.omega_components)) 102 | omega_input_cos = torch.matmul(x, torch.cos(self.omega_components)) 103 | 104 | phi = compute_atan2(phi_input_sin, phi_input_cos) 105 | psi = compute_atan2(psi_input_sin, psi_input_cos) 106 | omega = compute_atan2(omega_input_sin, omega_input_cos) 107 | 108 | return torch.cat((phi, psi, omega), 2) 109 | -------------------------------------------------------------------------------- /experiments/my_model/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 5 | """ 6 | import math 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import openprotein 12 | from preprocessing import process_raw_data 13 | from training import train_model 14 | 15 | from util import get_backbone_positions_from_angles, contruct_dataloader_from_disk 16 | 17 | ANGLE_ARR = torch.tensor([[-120, 140, -370], [0, 120, -150], [25, -120, 150]]).float() 18 | 19 | def run_experiment(parser, use_gpu): 20 | # parse experiment specific command line arguments 21 | parser.add_argument('--learning-rate', dest='learning_rate', type=float, 22 | default=0.01, help='Learning rate to use during training.') 23 | 24 | parser.add_argument('--input-file', dest='input_file', type=str, 25 | default='data/preprocessed/protein_net_testfile.txt.hdf5') 26 | 27 | args, _unknown = parser.parse_known_args() 28 | 29 | # pre-process data 30 | process_raw_data(use_gpu, force_pre_processing_overwrite=False) 31 | 32 | # run experiment 33 | training_file = args.input_file 34 | validation_file = args.input_file 35 | 36 | model = MyModel(21, use_gpu=use_gpu) # embed size = 21 37 | 38 | train_loader = contruct_dataloader_from_disk(training_file, args.minibatch_size) 39 | validation_loader = contruct_dataloader_from_disk(validation_file, args.minibatch_size) 40 | 41 | train_model_path = train_model(data_set_identifier="TRAIN", 42 | model=model, 43 | train_loader=train_loader, 44 | validation_loader=validation_loader, 45 | learning_rate=args.learning_rate, 46 | minibatch_size=args.minibatch_size, 47 | eval_interval=args.eval_interval, 48 | hide_ui=args.hide_ui, 49 | use_gpu=use_gpu, 50 | minimum_updates=args.minimum_updates) 51 | 52 | print("Completed training, trained model stored at:") 53 | print(train_model_path) 54 | 55 | class MyModel(openprotein.BaseModel): 56 | def __init__(self, embedding_size, use_gpu): 57 | super(MyModel, self).__init__(use_gpu, embedding_size) 58 | self.use_gpu = use_gpu 59 | self.number_angles = 3 60 | self.input_to_angles = nn.Linear(embedding_size, self.number_angles) 61 | 62 | 
def _get_network_emissions(self, original_aa_string): 63 | batch_sizes = list([a.size() for a in original_aa_string]) 64 | 65 | embedded_input = self.embed(original_aa_string) 66 | emissions_padded = self.input_to_angles(embedded_input) 67 | 68 | probabilities = torch.softmax(emissions_padded.transpose(0, 1), 2) 69 | 70 | output_angles = torch.matmul(probabilities, ANGLE_ARR).transpose(0, 1) 71 | 72 | return output_angles, [], batch_sizes 73 | -------------------------------------------------------------------------------- /experiments/rrn/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 5 | """ 6 | 7 | from experiments.rrn.models import RrnModel 8 | from experiments.example.models import * 9 | 10 | from preprocessing import process_raw_data 11 | 12 | from training import train_model 13 | from util import contruct_dataloader_from_disk 14 | 15 | def run_experiment(parser, use_gpu): 16 | # parse experiment specific command line arguments 17 | parser.add_argument('--learning-rate', dest='learning_rate', type=float, 18 | default=0.01, help='Learning rate to use during training.') 19 | 20 | parser.add_argument('--input-file', dest='input_file', type=str, 21 | default='data/preprocessed/protein_net_testfile.txt.hdf5') 22 | 23 | args, _unknown = parser.parse_known_args() 24 | 25 | # pre-process data 26 | process_raw_data(use_gpu, force_pre_processing_overwrite=False) 27 | 28 | # run experiment 29 | training_file = args.input_file 30 | validation_file = args.input_file 31 | 32 | model = RrnModel(21, use_gpu=use_gpu) # embed size = 21 33 | 34 | train_loader = contruct_dataloader_from_disk(training_file, args.minibatch_size) 35 | validation_loader = contruct_dataloader_from_disk(validation_file, args.minibatch_size) 36 | 37 | train_model_path = train_model(data_set_identifier="TRAIN", 38 | model=model, 39 | train_loader=train_loader, 40 | validation_loader=validation_loader, 41 | learning_rate=args.learning_rate, 42 | minibatch_size=args.minibatch_size, 43 | eval_interval=args.eval_interval, 44 | hide_ui=args.hide_ui, 45 | use_gpu=use_gpu, 46 | minimum_updates=args.minimum_updates) 47 | 48 | print("Completed training, trained model stored at:") 49 | print(train_model_path) 50 | -------------------------------------------------------------------------------- /experiments/rrn/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 
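
Defines RrnModel, which iteratively refines predicted backbone atom positions by passing
messages between residue pairs and is trained by minimising the average dRMSD against the
experimental coordinates.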
5 | """ 6 | import time 7 | 8 | import torch 9 | import torch.nn as nn 10 | import openprotein 11 | 12 | from util import initial_pos_from_aa_string, \ 13 | pass_messages, write_out, calc_avg_drmsd_over_minibatch 14 | 15 | class RrnModel(openprotein.BaseModel): 16 | def __init__(self, embedding_size, use_gpu): 17 | super(RrnModel, self).__init__(use_gpu, embedding_size) 18 | self.recurrent_steps = 2 19 | self.hidden_size = 50 20 | self.msg_output_size = 50 21 | self.output_size = 9 # 3 dimensions * 3 coordinates for each aa 22 | self.f_to_hid = nn.Linear((embedding_size * 2 + 9), self.hidden_size, bias=True) 23 | self.hid_to_pos = nn.Linear(self.hidden_size, self.msg_output_size, bias=True) 24 | # (last state + orginal state) 25 | self.linear_transform = nn.Linear(embedding_size + 9 + self.msg_output_size, 9, bias=True) 26 | self.use_gpu = use_gpu 27 | 28 | def apply_message_function(self, aa_features): 29 | # aa_features: msg_count * 2 * feature_count 30 | aa_features_transformed = torch.cat( 31 | ( 32 | aa_features[:, 0, 0:21], 33 | aa_features[:, 1, 0:21], 34 | aa_features[:, 0, 21:30] - aa_features[:, 1, 21:30] 35 | ), dim=1) 36 | return self.hid_to_pos(self.f_to_hid(aa_features_transformed)) # msg_count * outputsize 37 | 38 | def _get_network_emissions(self, original_aa_string): 39 | backbone_atoms_padded, batch_sizes_backbone = \ 40 | initial_pos_from_aa_string(original_aa_string, self.use_gpu) 41 | embedding_padded = self.embed(original_aa_string) 42 | 43 | if self.use_gpu: 44 | backbone_atoms_padded = backbone_atoms_padded.cuda() 45 | 46 | for _ in range(self.recurrent_steps): 47 | combined_features = torch.cat( 48 | (embedding_padded, backbone_atoms_padded), 49 | dim=2 50 | ).transpose(0, 1) 51 | 52 | features_transformed = [] 53 | 54 | for aa_features in combined_features.split(1, dim=0): 55 | msg = pass_messages(aa_features.squeeze(0), 56 | self.apply_message_function, 57 | self.use_gpu) # aa_count * output size 58 | features_transformed.append(self.linear_transform( 59 | torch.cat((aa_features.squeeze(0), msg), dim=1))) 60 | 61 | backbone_atoms_padded_clone = torch.stack(features_transformed).transpose(0, 1) 62 | 63 | backbone_atoms_padded = backbone_atoms_padded_clone 64 | 65 | return [], backbone_atoms_padded, batch_sizes_backbone 66 | 67 | def compute_loss(self, minibatch): 68 | (original_aa_string, actual_coords_list, _) = minibatch 69 | 70 | _, backbone_atoms_padded, batch_sizes = \ 71 | self._get_network_emissions(original_aa_string) 72 | actual_coords_list_padded = torch.nn.utils.rnn.pad_sequence(actual_coords_list) 73 | if self.use_gpu: 74 | actual_coords_list_padded = actual_coords_list_padded.cuda() 75 | start = time.time() 76 | if isinstance(batch_sizes[0], int): 77 | batch_sizes = torch.tensor(batch_sizes) 78 | 79 | drmsd_avg = calc_avg_drmsd_over_minibatch(backbone_atoms_padded, 80 | actual_coords_list_padded, 81 | batch_sizes) 82 | write_out("drmsd calculation time:", time.time() - start) 83 | if self.use_gpu: 84 | drmsd_avg = drmsd_avg.cuda() 85 | 86 | return drmsd_avg 87 | -------------------------------------------------------------------------------- /experiments/tmhmm3/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 
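
Runs the TMHMM3 experiment: a five-fold cross-validation loop that trains a protein-type
predictor and a topology predictor for each partition (unless pre-trained models are given),
evaluates them on the held-out data and aggregates the resulting confusion matrices.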
5 | """ 6 | import os 7 | import pickle 8 | import hashlib 9 | from training import train_model 10 | from util import load_model_from_disk, set_experiment_id, write_prediction_data_to_disk 11 | from .tm_models import * 12 | from .tm_util import * 13 | 14 | 15 | def run_experiment(parser, use_gpu): 16 | parser.add_argument('--minibatch-size-validation', 17 | dest='minibatch_size_validation', 18 | type=int, 19 | default=8, 20 | help='Size of each minibatch during evaluation.') 21 | parser.add_argument('--hidden-size', 22 | dest='hidden_size', 23 | type=int, 24 | default=64, 25 | help='Hidden size.') 26 | parser.add_argument('--learning-rate', 27 | dest='learning_rate', 28 | type=float, 29 | default=0.0002, 30 | help='Learning rate to use during training.') 31 | parser.add_argument('--cv-partition', 32 | dest='cv_partition', 33 | type=int, 34 | default=0, 35 | help='Run a particular cross validation rotation.') 36 | parser.add_argument('--input-data', 37 | dest='input_data', 38 | type=str, 39 | default='data/raw/TMHMM3.train.3line.latest', 40 | help='Path of input data file.') 41 | parser.add_argument('--pre-trained-model-paths', 42 | dest='pre_trained_model_paths', 43 | type=str, 44 | default=None, 45 | help='Paths of pre-trained models.') 46 | parser.add_argument('--profile-path', dest='profile_path', 47 | type=str, default="", 48 | help='Profiles to use for embedding.') 49 | args, _unknown = parser.parse_known_args() 50 | 51 | result_matrices = np.zeros((5, 5), dtype=np.int64) 52 | 53 | if args.profile_path != "": 54 | embedding = "PROFILE" 55 | else: 56 | embedding = "BLOSUM62" 57 | use_marg_prob = False 58 | all_prediction_data = [] 59 | 60 | for cv_partition in [0, 1, 2, 3, 4]: 61 | # prepare data sets 62 | train_set, val_set, test_set = load_data_from_disk(filename=args.input_data, 63 | partition_rotation=cv_partition) 64 | 65 | # topology data set 66 | train_set_topology = list(filter(lambda x: x[3] == 0 or x[3] == 1, train_set)) 67 | val_set_topology = list(filter(lambda x: x[3] == 0 or x[3] == 1, val_set)) 68 | test_set_topology = list(filter(lambda x: x[3] == 0 or x[3] == 1, test_set)) 69 | 70 | if not args.silent: 71 | print("Loaded ", 72 | len(train_set), "training,", 73 | len(val_set), "validation and", 74 | len(test_set), "test samples") 75 | 76 | print("Processing data...") 77 | pre_processed_path = "data/preprocessed/preprocessed_data_" + str( 78 | hashlib.sha256(args.input_data.encode()).hexdigest())[:8] + "_cv" \ 79 | + str(cv_partition) + ".pickle" 80 | if not os.path.isfile(pre_processed_path): 81 | input_data_processed = list([TMDataset.from_disk(set, use_gpu) for set in 82 | [train_set, val_set, test_set, 83 | train_set_topology, val_set_topology, 84 | test_set_topology]]) 85 | pickle.dump(input_data_processed, open(pre_processed_path, "wb")) 86 | input_data_processed = pickle.load(open(pre_processed_path, "rb")) 87 | train_preprocessed_set = input_data_processed[0] 88 | validation_preprocessed_set = input_data_processed[1] 89 | test_preprocessed_set = input_data_processed[2] 90 | train_preprocessed_set_topology = input_data_processed[3] 91 | validation_preprocessed_set_topology = input_data_processed[4] 92 | _test_preprocessed_set_topology = input_data_processed[5] 93 | 94 | print("Completed preprocessing of data...") 95 | 96 | train_loader = tm_contruct_dataloader_from_disk(train_preprocessed_set, 97 | args.minibatch_size, 98 | balance_classes=True) 99 | validation_loader = tm_contruct_dataloader_from_disk(validation_preprocessed_set, 100 | 
args.minibatch_size_validation, 101 | balance_classes=True) 102 | test_loader = tm_contruct_dataloader_from_disk( 103 | test_preprocessed_set if args.evaluate_on_test else validation_preprocessed_set, 104 | args.minibatch_size_validation) 105 | 106 | train_loader_topology = \ 107 | tm_contruct_dataloader_from_disk(train_preprocessed_set_topology, 108 | args.minibatch_size) 109 | validation_loader_topology = \ 110 | tm_contruct_dataloader_from_disk(validation_preprocessed_set_topology, 111 | args 112 | .minibatch_size_validation) 113 | 114 | type_predictor_model_path = None 115 | 116 | if args.pre_trained_model_paths is None: 117 | for (experiment_id, train_data, validation_data) in [ 118 | ("TRAIN_TYPE_CV" + str(cv_partition) 119 | + "-HS" + str(args.hidden_size) + "-F" + str(args.input_data.split(".")[-2]) 120 | + "-P" + str(args.profile_path.split("_")[-1]), train_loader, 121 | validation_loader), 122 | ("TRAIN_TOPOLOGY_CV" + str(cv_partition) 123 | + "-HS" + str(args.hidden_size) + "-F" + str(args.input_data.split(".")[-2]) 124 | + "-P" + str(args.profile_path.split("_")[-1]), 125 | train_loader_topology, validation_loader_topology)]: 126 | 127 | type_predictor = None 128 | if type_predictor_model_path is not None: 129 | type_predictor = load_model_from_disk(type_predictor_model_path, 130 | force_cpu=False) 131 | model = load_model_from_disk(type_predictor_model_path, 132 | force_cpu=False) 133 | model.type_classifier = type_predictor 134 | model.type_01loss_values = [] 135 | model.topology_01loss_values = [] 136 | else: 137 | model = TMHMM3( 138 | embedding, 139 | args.hidden_size, 140 | use_gpu, 141 | use_marg_prob, 142 | type_predictor, 143 | args.profile_path) 144 | 145 | model_path = train_model(data_set_identifier=experiment_id, 146 | model=model, 147 | train_loader=train_data, 148 | validation_loader=validation_data, 149 | learning_rate=args.learning_rate, 150 | minibatch_size=args.minibatch_size, 151 | eval_interval=args.eval_interval, 152 | hide_ui=args.hide_ui, 153 | use_gpu=use_gpu, 154 | minimum_updates=args.minimum_updates) 155 | 156 | # let the GC collect the model 157 | del model 158 | 159 | write_out(model_path) 160 | 161 | # if we just trained a type predictor, save it for later 162 | if "TRAIN_TYPE" in experiment_id: 163 | type_predictor_model_path = model_path 164 | else: 165 | # use the pre-trained model 166 | model_path = args.pre_trained_model_paths.split(",")[cv_partition] 167 | 168 | # test model 169 | write_out("Testing model...") 170 | model = load_model_from_disk(model_path, force_cpu=False) 171 | _loss, json_data, prediction_data = model.evaluate_model(test_loader) 172 | 173 | all_prediction_data.append(post_process_prediction_data(prediction_data)) 174 | result_matrix = np.array(json_data['confusion_matrix']) 175 | result_matrices += result_matrix 176 | write_out(result_matrix) 177 | 178 | set_experiment_id( 179 | "TEST-" + "-HS" + str(args.hidden_size) + "-F" 180 | + str(args.input_data.split(".")[-2]), 181 | args.learning_rate, 182 | args.minibatch_size) 183 | write_out(result_matrices) 184 | write_prediction_data_to_disk("\n".join(all_prediction_data)) 185 | -------------------------------------------------------------------------------- /experiments/tmhmm3/tm_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 
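
Defines the TMHMM3 model: a bidirectional LSTM whose per-residue emissions are expanded and
decoded by a CRF with a constrained transition structure, predicting membrane topology,
signal peptides and protein type.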
5 | """ 6 | 7 | import sys 8 | import glob 9 | import pickle 10 | from typing import Tuple 11 | 12 | import numpy as np 13 | import torch 14 | import torch.autograd as autograd 15 | import torch.nn as nn 16 | import openprotein 17 | from experiments.tmhmm3.tm_util import batch_sizes_to_mask, \ 18 | initialize_crf_parameters, decode_numpy 19 | from experiments.tmhmm3.tm_util import is_topologies_equal 20 | from experiments.tmhmm3.tm_util import original_labels_to_fasta 21 | from pytorchcrf.torchcrf import CRF 22 | from util import write_out, get_experiment_id 23 | 24 | # seed random generator for reproducibility 25 | torch.manual_seed(1) 26 | 27 | 28 | class TMHMM3(openprotein.BaseModel): 29 | def __init__(self, 30 | embedding, 31 | hidden_size, 32 | use_gpu, 33 | use_marg_prob, 34 | type_predictor_model, 35 | profile_path): 36 | super(TMHMM3, self).__init__(embedding, use_gpu) 37 | 38 | # initialize model variables 39 | num_labels = 5 40 | self.max_signal_length = 67 41 | num_tags = 5 + 2 * 40 + self.max_signal_length 42 | self.hidden_size = hidden_size 43 | self.use_gpu = use_gpu 44 | self.use_marg_prob = use_marg_prob 45 | self.embedding = embedding 46 | self.profile_path = profile_path 47 | self.bi_lstm = nn.LSTM(self.get_embedding_size(), 48 | self.hidden_size, 49 | num_layers=1, 50 | bidirectional=True) 51 | self.hidden_to_labels = nn.Linear(self.hidden_size * 2, num_labels) # * 2 for bidirectional 52 | self.hidden_layer = None 53 | crf_start_mask = torch.ones(num_tags, dtype=torch.uint8) == 1 54 | crf_end_mask = torch.ones(num_tags, dtype=torch.uint8) == 1 55 | 56 | allowed_transitions = [ 57 | (3, 3), (4, 4), 58 | (3, 5), (4, 45)] 59 | for i in range(5, 45 - 1): 60 | allowed_transitions.append((i, i + 1)) 61 | if 8 < i < 43: 62 | allowed_transitions.append((8, i)) 63 | allowed_transitions.append((44, 4)) 64 | for i in range(45, 85 - 1): 65 | allowed_transitions.append((i, i + 1)) 66 | if 48 < i < 83: 67 | allowed_transitions.append((48, i)) 68 | allowed_transitions.append((84, 3)) 69 | for i in range(85, 151): 70 | allowed_transitions.append((i, i + 1)) 71 | allowed_transitions.append((2, i)) 72 | allowed_transitions.append((2, 151)) 73 | allowed_transitions.append((2, 4)) 74 | allowed_transitions.append((151, 4)) 75 | 76 | crf_start_mask[2] = 0 77 | crf_start_mask[3] = 0 78 | crf_start_mask[4] = 0 79 | crf_end_mask[3] = 0 80 | crf_end_mask[4] = 0 81 | 82 | self.allowed_transitions = allowed_transitions 83 | self.crf_model = CRF(num_tags) 84 | self.type_classifier = type_predictor_model 85 | self.type_tm_classier = None 86 | self.type_sp_classier = None 87 | crf_transitions_mask = torch.ones((num_tags, num_tags), dtype=torch.uint8) == 1 88 | 89 | self.type_01loss_values = [] 90 | self.topology_01loss_values = [] 91 | 92 | # if on GPU, move state to GPU memory 93 | if self.use_gpu: 94 | self.crf_model = self.crf_model.cuda() 95 | self.bi_lstm = self.bi_lstm.cuda() 96 | self.hidden_to_labels = self.hidden_to_labels.cuda() 97 | crf_transitions_mask = crf_transitions_mask.cuda() 98 | crf_start_mask = crf_start_mask.cuda() 99 | crf_end_mask = crf_end_mask.cuda() 100 | 101 | # compute mask matrix from allow transitions list 102 | for i in range(num_tags): 103 | for k in range(num_tags): 104 | if (i, k) in self.allowed_transitions: 105 | crf_transitions_mask[i][k] = 0 106 | 107 | # generate masked transition parameters 108 | crf_start_transitions, crf_end_transitions, crf_transitions = \ 109 | generate_masked_crf_transitions( 110 | self.crf_model, (crf_start_mask, crf_transitions_mask, 
crf_end_mask) 111 | ) 112 | 113 | # initialize CRF 114 | initialize_crf_parameters(self.crf_model, 115 | start_transitions=crf_start_transitions, 116 | end_transitions=crf_end_transitions, 117 | transitions=crf_transitions) 118 | 119 | def get_embedding_size(self): 120 | if self.embedding == "BLOSUM62": 121 | return 24 # bloom matrix has size 24 122 | elif self.embedding == "PROFILE": 123 | return 51 # protein profiles have size 51 124 | 125 | def flatten_parameters(self): 126 | self.bi_lstm.flatten_parameters() 127 | 128 | def encode_amino_acid(self, letter): 129 | if self.embedding == "BLOSUM62": 130 | # blosum encoding 131 | if not globals().get('blosum_encoder'): 132 | blosum = \ 133 | """4,-1,-2,-2,0,-1,-1,0,-2,-1,-1,-1,-1,-2,-1,1,0,-3,-2,0,-2,-1,0,-4 134 | -1,5,0,-2,-3,1,0,-2,0,-3,-2,2,-1,-3,-2,-1,-1,-3,-2,-3,-1,0,-1,-4 135 | -2,0,6,1,-3,0,0,0,1,-3,-3,0,-2,-3,-2,1,0,-4,-2,-3,3,0,-1,-4 136 | -2,-2,1,6,-3,0,2,-1,-1,-3,-4,-1,-3,-3,-1,0,-1,-4,-3,-3,4,1,-1,-4 137 | 0,-3,-3,-3,9,-3,-4,-3,-3,-1,-1,-3,-1,-2,-3,-1,-1,-2,-2,-1,-3,-3,-2,-4 138 | -1,1,0,0,-3,5,2,-2,0,-3,-2,1,0,-3,-1,0,-1,-2,-1,-2,0,3,-1,-4 139 | -1,0,0,2,-4,2,5,-2,0,-3,-3,1,-2,-3,-1,0,-1,-3,-2,-2,1,4,-1,-4 140 | 0,-2,0,-1,-3,-2,-2,6,-2,-4,-4,-2,-3,-3,-2,0,-2,-2,-3,-3,-1,-2,-1,-4 141 | -2,0,1,-1,-3,0,0,-2,8,-3,-3,-1,-2,-1,-2,-1,-2,-2,2,-3,0,0,-1,-4 142 | -1,-3,-3,-3,-1,-3,-3,-4,-3,4,2,-3,1,0,-3,-2,-1,-3,-1,3,-3,-3,-1,-4 143 | -1,-2,-3,-4,-1,-2,-3,-4,-3,2,4,-2,2,0,-3,-2,-1,-2,-1,1,-4,-3,-1,-4 144 | -1,2,0,-1,-3,1,1,-2,-1,-3,-2,5,-1,-3,-1,0,-1,-3,-2,-2,0,1,-1,-4 145 | -1,-1,-2,-3,-1,0,-2,-3,-2,1,2,-1,5,0,-2,-1,-1,-1,-1,1,-3,-1,-1,-4 146 | -2,-3,-3,-3,-2,-3,-3,-3,-1,0,0,-3,0,6,-4,-2,-2,1,3,-1,-3,-3,-1,-4 147 | -1,-2,-2,-1,-3,-1,-1,-2,-2,-3,-3,-1,-2,-4,7,-1,-1,-4,-3,-2,-2,-1,-2,-4 148 | 1,-1,1,0,-1,0,0,0,-1,-2,-2,0,-1,-2,-1,4,1,-3,-2,-2,0,0,0,-4 149 | 0,-1,0,-1,-1,-1,-1,-2,-2,-1,-1,-1,-1,-2,-1,1,5,-2,-2,0,-1,-1,0,-4 150 | -3,-3,-4,-4,-2,-2,-3,-2,-2,-3,-2,-3,-1,1,-4,-3,-2,11,2,-3,-4,-3,-2,-4 151 | -2,-2,-2,-3,-2,-1,-2,-3,2,-1,-1,-2,-1,3,-3,-2,-2,2,7,-1,-3,-2,-1,-4 152 | 0,-3,-3,-3,-1,-2,-2,-3,-3,3,1,-2,1,-1,-2,-2,0,-3,-1,4,-3,-2,-1,-4 153 | -2,-1,3,4,-3,0,1,-1,0,-3,-4,0,-3,-3,-2,0,-1,-4,-3,-3,4,1,-1,-4 154 | -1,0,0,1,-3,3,4,-2,0,-3,-3,1,-1,-3,-1,0,-1,-3,-2,-2,1,4,-1,-4 155 | 0,-1,-1,-1,-2,-1,-1,-1,-1,-1,-1,-1,-1,-1,-2,0,0,-2,-1,-1,-1,-1,-1,-4 156 | -4,-4,-4,-4,-4,-4,-4,-4,-4,-4,-4,-4,-4,-4,-4,-4,-4,-4,-4,-4,-4,-4,-4,1""" \ 157 | .replace('\n', ',') 158 | blosum_matrix = np.fromstring(blosum, sep=",").reshape(24, 24) 159 | blosum_key = "A,R,N,D,C,Q,E,G,H,I,L,K,M,F,P,S,T,W,Y,V,B,Z,X,U".split(",") 160 | key_map = {} 161 | for idx, value in enumerate(blosum_key): 162 | key_map[value] = list([int(v) for v in blosum_matrix[idx].astype('int')]) 163 | globals().__setitem__("blosum_encoder", key_map) 164 | return globals().get('blosum_encoder')[letter] 165 | elif self.embedding == "ONEHOT": 166 | # one hot encoding 167 | one_hot_key = "A,R,N,D,C,Q,E,G,H,I,L,K,M,F,P,S,T,W,Y,V,B,Z,X,U".split(",") 168 | arr = [] 169 | for idx, k in enumerate(one_hot_key): 170 | if k == letter: 171 | arr.append(1) 172 | else: 173 | arr.append(0) 174 | return arr 175 | elif self.embedding == "PYTORCH": 176 | key_id = "A,R,N,D,C,Q,E,G,H,I,L,K,M,F,P,S,T,W,Y,V,B,Z,X,U".split(",") 177 | for idx, k in enumerate(key_id): 178 | if k == letter: 179 | return idx 180 | 181 | def embed(self, prot_aa_list): 182 | embed_list = [] 183 | for aa_list in prot_aa_list: 184 | if self.embedding == "PYTORCH": 185 | tensor = torch.LongTensor(tensor) 186 | elif self.embedding == 
"PROFILE": 187 | if not globals().get('profile_encoder'): 188 | print("Load profiles...") 189 | files = glob.glob(self.profile_path.strip("/") + "/*") 190 | profile_dict = {} 191 | for profile_file in files: 192 | profile = pickle.load(open(profile_file, "rb")).popitem()[1] 193 | profile_dict[profile["seq"]] = torch.from_numpy(profile["profile"]).float() 194 | globals().__setitem__("profile_encoder", profile_dict) 195 | print("Loaded profiles") 196 | tensor = globals().get('profile_encoder')[aa_list] 197 | else: 198 | tensor = list([self.encode_amino_acid(aa) for aa in aa_list]) 199 | tensor = torch.FloatTensor(tensor) 200 | if self.use_gpu: 201 | tensor = tensor.cuda() 202 | embed_list.append(tensor) 203 | return embed_list 204 | 205 | def init_hidden(self, minibatch_size): 206 | # number of layers (* 2 since bidirectional), minibatch_size, hidden size 207 | initial_hidden_state = torch.zeros(1 * 2, minibatch_size, self.hidden_size) 208 | initial_cell_state = torch.zeros(1 * 2, minibatch_size, self.hidden_size) 209 | if self.use_gpu: 210 | initial_hidden_state = initial_hidden_state.cuda() 211 | initial_cell_state = initial_cell_state.cuda() 212 | self.hidden_layer = (autograd.Variable(initial_hidden_state), 213 | autograd.Variable(initial_cell_state)) 214 | 215 | def _get_network_emissions(self, pad_seq_embed: torch.Tensor) -> torch.Tensor: 216 | minibatch_size = pad_seq_embed.size(1) 217 | self.init_hidden(minibatch_size) 218 | bi_lstm_out, self.hidden_layer = self.bi_lstm(pad_seq_embed, self.hidden_layer) 219 | emissions = self.hidden_to_labels(bi_lstm_out) 220 | inout_select = torch.zeros(1, dtype=torch.long) 221 | outin_select = torch.ones(1, dtype=torch.long) 222 | signal_select = torch.ones(1, dtype=torch.long) * 2 223 | if emissions.is_cuda: 224 | inout_select = inout_select.cuda() 225 | outin_select = outin_select.cuda() 226 | signal_select = signal_select.cuda() 227 | inout = torch.index_select(emissions, 2, inout_select) 228 | outin = torch.index_select(emissions, 2, outin_select) 229 | signal = torch.index_select(emissions, 2, signal_select) 230 | emissions = torch.cat((emissions, 231 | inout.expand(-1, minibatch_size, 40), 232 | outin.expand(-1, minibatch_size, 40), 233 | signal.expand(-1, minibatch_size, self.max_signal_length)), 2) 234 | return emissions 235 | 236 | def compute_loss(self, training_minibatch): 237 | _, labels_list, remapped_labels_list_crf_hmm, _, _, _, _, original_aa_string, \ 238 | _original_label_string = training_minibatch 239 | minibatch_size = len(labels_list) 240 | labels_to_use = remapped_labels_list_crf_hmm 241 | input_sequences = [x for x in self.embed(original_aa_string)] 242 | input_sequences_padded = torch.nn.utils.rnn.pad_sequence(input_sequences) 243 | batch_sizes = torch.IntTensor(list([x.size(0) for x in input_sequences])) 244 | if input_sequences_padded.is_cuda: 245 | batch_sizes = batch_sizes.cuda() 246 | 247 | actual_labels = torch.nn.utils.rnn.pad_sequence([l for l in labels_to_use]) 248 | emissions = self._get_network_emissions(input_sequences_padded) 249 | 250 | mask = batch_sizes_to_mask(batch_sizes) 251 | loss = -1 * self.crf_model(emissions, actual_labels, mask=mask) / minibatch_size 252 | if float(loss) > 100000: # if loss is this large, an invalid tx must have been found 253 | for idx, batch_size in enumerate(batch_sizes): 254 | last_label = None 255 | for i in range(batch_size): 256 | label = int(actual_labels[i][idx]) 257 | write_out(str(label) + ",", end='') 258 | if last_label is not None and (last_label, label) \ 259 | not in 
self.allowed_transitions: 260 | write_out("Error: invalid transition found") 261 | write_out((last_label, label)) 262 | sys.exit(1) 263 | last_label = label 264 | write_out(" ") 265 | return loss 266 | 267 | def forward(self, input_sequences_padded) -> Tuple[torch.Tensor, torch.Tensor]: 268 | if input_sequences_padded.is_cuda or input_sequences_padded.is_cuda: 269 | input_sequences_padded = input_sequences_padded.cuda() 270 | emissions = self._get_network_emissions(input_sequences_padded) 271 | 272 | return emissions, \ 273 | self.crf_model.start_transitions, \ 274 | self.crf_model.transitions, \ 275 | self.crf_model.end_transitions 276 | 277 | def evaluate_model(self, data_loader): 278 | validation_loss_tracker = [] 279 | validation_type_loss_tracker = [] 280 | validation_topology_loss_tracker = [] 281 | confusion_matrix = np.zeros((5, 5), dtype=np.int64) 282 | protein_names = [] 283 | protein_aa_strings = [] 284 | protein_label_actual = [] 285 | protein_label_prediction = [] 286 | for _, minibatch in enumerate(data_loader, 0): 287 | validation_loss_tracker.append(self.compute_loss(minibatch).detach()) 288 | 289 | _, _, _, _, prot_type_list, prot_topology_list, \ 290 | prot_name_list, original_aa_string, original_label_string = minibatch 291 | input_sequences = [x for x in self.embed(original_aa_string)] 292 | input_sequences_padded = torch.nn.utils.rnn.pad_sequence(input_sequences) 293 | batch_sizes = torch.IntTensor(list([x.size(0) for x in input_sequences])) 294 | 295 | emmisions, start_transitions, transitions, end_transitions = \ 296 | self(input_sequences_padded) 297 | predicted_labels, predicted_types, predicted_topologies = \ 298 | decode_numpy(emmisions.detach().cpu().numpy(), 299 | batch_sizes.detach().cpu().numpy(), 300 | start_transitions.detach().cpu().numpy(), 301 | transitions.detach().cpu().numpy(), 302 | end_transitions.detach().cpu().numpy()) 303 | 304 | protein_names.extend(prot_name_list) 305 | protein_aa_strings.extend(original_aa_string) 306 | protein_label_actual.extend(original_label_string) 307 | 308 | # if we're using an external type predictor 309 | if self.type_classifier is not None: 310 | emmisions, start_transitions, transitions, end_transitions = \ 311 | self.type_classifier(input_sequences_padded) 312 | 313 | predicted_labels_type_classifer, \ 314 | predicted_types_type_classifier, predicted_topologies_type_classifier = \ 315 | decode_numpy(emmisions.detach().cpu().numpy(), 316 | batch_sizes.detach().cpu().numpy(), 317 | start_transitions.detach().cpu().numpy(), 318 | transitions.detach().cpu().numpy(), 319 | end_transitions.detach().cpu().numpy()) 320 | 321 | for idx, actual_type in enumerate(prot_type_list): 322 | 323 | predicted_type = predicted_types[idx] 324 | predicted_topology = predicted_topologies[idx] 325 | predicted_labels_for_protein = predicted_labels[idx] 326 | 327 | if self.type_classifier is not None: 328 | if predicted_type != predicted_types_type_classifier[idx]: 329 | # we must always use the type predicted by the type predictor if available 330 | predicted_type = predicted_types_type_classifier[idx] 331 | predicted_topology = predicted_topologies_type_classifier[idx] 332 | predicted_labels_for_protein = predicted_labels_type_classifer[idx] 333 | 334 | prediction_topology_match = is_topologies_equal(prot_topology_list[idx], 335 | predicted_topology, 5) 336 | 337 | if actual_type == predicted_type: 338 | validation_type_loss_tracker.append(0) 339 | # if we guessed the type right for SP+GLOB or GLOB, 340 | # count the topology as correct 
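# (confusion_matrix rows are the actual protein type: 0=TM, 1=SP+TM, 2=SP, 3=GLOBULAR;
# column 4 counts predictions whose topology also matched, while mismatches are
# recorded under the predicted type instead)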
341 | if actual_type == 2 or actual_type == 3 or prediction_topology_match: 342 | validation_topology_loss_tracker.append(0) 343 | confusion_matrix[actual_type][4] += 1 344 | else: 345 | validation_topology_loss_tracker.append(1) 346 | confusion_matrix[actual_type][predicted_type] += 1 347 | 348 | # if the type was correctly guessed to be 2 or 3 by the type classifier, 349 | # use its topology prediction 350 | if (actual_type in (2, 3)) and self.type_classifier is not None: 351 | protein_label_prediction.append(predicted_labels_type_classifer[idx]) 352 | else: 353 | protein_label_prediction.append(predicted_labels_for_protein) 354 | else: 355 | confusion_matrix[actual_type][int(predicted_type.item())] += 1 356 | validation_type_loss_tracker.append(1) 357 | validation_topology_loss_tracker.append(1) 358 | protein_label_prediction.append(predicted_labels_for_protein) 359 | 360 | write_out(confusion_matrix) 361 | _loss = float(torch.stack(validation_loss_tracker).mean()) 362 | 363 | type_loss = float(torch.FloatTensor(validation_type_loss_tracker).mean().detach()) 364 | topology_loss = float(torch.FloatTensor(validation_topology_loss_tracker).mean().detach()) 365 | 366 | self.type_01loss_values.append(type_loss) 367 | self.topology_01loss_values.append(topology_loss) 368 | 369 | if get_experiment_id() is not None and "TYPE" in get_experiment_id(): 370 | # optimize for type 371 | validation_loss = type_loss 372 | else: 373 | # optimize for topology 374 | validation_loss = topology_loss 375 | 376 | data = {} 377 | data['type_01loss_values'] = self.type_01loss_values 378 | data['topology_01loss_values'] = self.topology_01loss_values 379 | data['confusion_matrix'] = confusion_matrix.tolist() 380 | 381 | return validation_loss, data, ( 382 | protein_names, protein_aa_strings, protein_label_actual, protein_label_prediction) 383 | 384 | 385 | def post_process_prediction_data(prediction_data): 386 | data = [] 387 | for (name, aa_string, actual, prediction) in zip(*prediction_data): 388 | data.append("\n".join([">" + name, 389 | aa_string, 390 | actual, 391 | original_labels_to_fasta(prediction)])) 392 | return "\n".join(data) 393 | 394 | 395 | def logsumexp(data, dim): 396 | return data.max(dim)[0] + torch.log(torch.sum( 397 | torch.exp(data - data.max(dim)[0].unsqueeze(dim)), dim)) 398 | 399 | 400 | def generate_masked_crf_transitions(crf_model, transition_mask): 401 | start_transitions_mask, transitions_mask, end_transition_mask = transition_mask 402 | start_transitions = crf_model.start_transitions.data.clone() 403 | end_transitions = crf_model.end_transitions.data.clone() 404 | transitions = crf_model.transitions.data.clone() 405 | if start_transitions_mask is not None: 406 | start_transitions.masked_fill_(start_transitions_mask, -100000000) 407 | if end_transition_mask is not None: 408 | end_transitions.masked_fill_(end_transition_mask, -100000000) 409 | if transitions_mask is not None: 410 | transitions.masked_fill_(transitions_mask, -100000000) 411 | return start_transitions, end_transitions, transitions 412 | -------------------------------------------------------------------------------- /experiments/tmhmm3/tm_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 
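
Utility code for the TMHMM3 experiment: the TMDataset class and batch samplers, conversions
between label encodings and topologies, cross-validation partitioning of the raw data, and a
NumPy implementation of Viterbi decoding for the CRF emissions.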
5 | """ 6 | 7 | import math 8 | import random 9 | from typing import List 10 | 11 | import torch 12 | from torch.utils.data.dataset import Dataset 13 | import numpy as np 14 | 15 | from pytorchcrf.torchcrf import CRF 16 | from util import write_out 17 | 18 | 19 | class TMDataset(Dataset): 20 | def __init__(self, 21 | aa_list, 22 | label_list, 23 | remapped_labels_list_crf_hmm, 24 | remapped_labels_list_crf_marg, 25 | type_list, 26 | topology_list, 27 | prot_name_list, 28 | original_aa_string_list, 29 | original_label_string): 30 | assert len(aa_list) == len(label_list) 31 | assert len(aa_list) == len(type_list) 32 | assert len(aa_list) == len(topology_list) 33 | self.aa_list = aa_list 34 | self.label_list = label_list 35 | self.remapped_labels_list_crf_hmm = remapped_labels_list_crf_hmm 36 | self.remapped_labels_list_crf_marg = remapped_labels_list_crf_marg 37 | self.type_list = type_list 38 | self.topology_list = topology_list 39 | self.prot_name_list = prot_name_list 40 | self.original_aa_string_list = original_aa_string_list 41 | self.original_label_string = original_label_string 42 | 43 | @staticmethod 44 | def from_disk(dataset, use_gpu): 45 | print("Constructing data set from disk...") 46 | aa_list = [] 47 | labels_list = [] 48 | remapped_labels_list_crf_hmm = [] 49 | remapped_labels_list_crf_marg = [] 50 | prot_type_list = [] 51 | prot_topology_list_all = [] 52 | prot_aa_list_all = [] 53 | prot_labels_list_all = [] 54 | prot_name_list = [] 55 | # sort according to length of aa sequence 56 | dataset.sort(key=lambda x: len(x[1]), reverse=True) 57 | for prot_name, prot_aa_list, prot_original_label_list, type_id, _cluster_id in dataset: 58 | prot_name_list.append(prot_name) 59 | prot_aa_list_all.append(prot_aa_list) 60 | prot_labels_list_all.append(prot_original_label_list) 61 | aa_tmp_list_tensor = [] 62 | labels = None 63 | remapped_labels_crf_hmm = None 64 | last_non_membrane_position = None 65 | if prot_original_label_list is not None: 66 | labels = [] 67 | for topology_label in prot_original_label_list: 68 | if topology_label == "L": 69 | topology_label = "I" 70 | if topology_label == "I": 71 | last_non_membrane_position = "I" 72 | labels.append(3) 73 | elif topology_label == "O": 74 | last_non_membrane_position = "O" 75 | labels.append(4) 76 | elif topology_label == "S": 77 | last_non_membrane_position = "S" 78 | labels.append(2) 79 | elif topology_label == "M": 80 | if last_non_membrane_position == "I": 81 | labels.append(0) 82 | elif last_non_membrane_position == "O": 83 | labels.append(1) 84 | else: 85 | print("Error: unexpected label found in last_non_membrane_position:", 86 | topology_label) 87 | else: 88 | print("Error: unexpected label found:", topology_label, "for protein", 89 | prot_name) 90 | labels = torch.LongTensor(labels) 91 | remapped_labels_crf_hmm = [] 92 | topology = label_list_to_topology(labels) 93 | # given topology, now calculate remapped labels 94 | for idx, (pos, l) in enumerate(topology): 95 | if l == 0: # I -> O 96 | membrane_length = topology[idx + 1][0] - pos 97 | mm_beginning = 4 98 | for i in range(mm_beginning): 99 | remapped_labels_crf_hmm.append(5 + i) 100 | for i in range(40 - (membrane_length - mm_beginning), 40): 101 | remapped_labels_crf_hmm.append(5 + i) 102 | elif l == 1: # O -> I 103 | membrane_length = topology[idx + 1][0] - pos 104 | mm_beginning = 4 105 | for i in range(mm_beginning): 106 | remapped_labels_crf_hmm.append(45 + i) 107 | for i in range(40 - (membrane_length - mm_beginning), 40): 108 | remapped_labels_crf_hmm.append(45 + i) 
109 | elif l == 2: # S 110 | signal_length = topology[idx + 1][0] - pos 111 | remapped_labels_crf_hmm.append(2) 112 | for i in range(signal_length - 1): 113 | remapped_labels_crf_hmm.append(152 - ((signal_length - 1) - i)) 114 | if remapped_labels_crf_hmm[-1] == 85: 115 | print("Too long signal peptide region found", prot_name) 116 | else: 117 | if idx == (len(topology) - 1): 118 | for i in range(len(labels) - pos): 119 | remapped_labels_crf_hmm.append(l) 120 | else: 121 | for i in range(topology[idx + 1][0] - pos): 122 | remapped_labels_crf_hmm.append(l) 123 | remapped_labels_crf_hmm = torch.LongTensor(remapped_labels_crf_hmm) 124 | 125 | remapped_labels_crf_marg = list([l + (type_id * 5) for l in labels]) 126 | remapped_labels_crf_marg = torch.LongTensor(remapped_labels_crf_marg) 127 | 128 | # check that protein was properly parsed 129 | assert remapped_labels_crf_hmm.size() == labels.size() 130 | assert remapped_labels_crf_marg.size() == labels.size() 131 | 132 | if use_gpu: 133 | if labels is not None: 134 | labels = labels.cuda() 135 | remapped_labels_crf_hmm = remapped_labels_crf_hmm.cuda() 136 | remapped_labels_crf_marg = remapped_labels_crf_marg.cuda() 137 | aa_list.append(aa_tmp_list_tensor) 138 | labels_list.append(labels) 139 | remapped_labels_list_crf_hmm.append(remapped_labels_crf_hmm) 140 | remapped_labels_list_crf_marg.append(remapped_labels_crf_marg) 141 | prot_type_list.append(type_id) 142 | prot_topology_list_all.append(label_list_to_topology(labels)) 143 | return TMDataset(aa_list, labels_list, remapped_labels_list_crf_hmm, 144 | remapped_labels_list_crf_marg, 145 | prot_type_list, prot_topology_list_all, prot_name_list, 146 | prot_aa_list_all, prot_labels_list_all) 147 | 148 | def __getitem__(self, index): 149 | return self.aa_list[index], \ 150 | self.label_list[index], \ 151 | self.remapped_labels_list_crf_hmm[index], \ 152 | self.remapped_labels_list_crf_marg[index], \ 153 | self.type_list[index], \ 154 | self.topology_list[index], \ 155 | self.prot_name_list[index], \ 156 | self.original_aa_string_list[index], \ 157 | self.original_label_string[index] 158 | 159 | def __len__(self): 160 | return len(self.aa_list) 161 | 162 | 163 | def merge_samples_to_minibatch(samples): 164 | samples_list = [] 165 | for sample in samples: 166 | samples_list.append(sample) 167 | # sort according to length of aa sequence 168 | samples_list.sort(key=lambda x: len(x[7]), reverse=True) 169 | aa_list, labels_list, remapped_labels_list_crf_hmm, \ 170 | remapped_labels_list_crf_marg, prot_type_list, prot_topology_list, \ 171 | prot_name, original_aa_string, original_label_string = zip( 172 | *samples_list) 173 | write_out(prot_type_list) 174 | return aa_list, labels_list, remapped_labels_list_crf_hmm, remapped_labels_list_crf_marg, \ 175 | prot_type_list, prot_topology_list, prot_name, original_aa_string, original_label_string 176 | 177 | def tm_contruct_dataloader_from_disk(tm_dataset, minibatch_size, balance_classes=False): 178 | if balance_classes: 179 | batch_sampler = RandomBatchClassBalancedSequentialSampler(tm_dataset, minibatch_size) 180 | else: 181 | batch_sampler = RandomBatchSequentialSampler(tm_dataset, minibatch_size) 182 | return torch.utils.data.DataLoader(tm_dataset, 183 | batch_sampler=batch_sampler, 184 | collate_fn=merge_samples_to_minibatch) 185 | 186 | 187 | class RandomBatchClassBalancedSequentialSampler(torch.utils.data.sampler.Sampler): 188 | 189 | def __init__(self, dataset, batch_size): 190 | self.sampler = torch.utils.data.sampler.SequentialSampler(dataset) 191 | 
self.batch_size = batch_size 192 | self.dataset = dataset 193 | 194 | def __iter__(self): 195 | data_class_map = {} 196 | data_class_map[0] = [] 197 | data_class_map[1] = [] 198 | data_class_map[2] = [] 199 | data_class_map[3] = [] 200 | 201 | for idx in self.sampler: 202 | data_class_map[self.dataset[idx][4]].append(idx) 203 | 204 | num_each_class = int(self.batch_size / 4) 205 | 206 | max_class_size = max( 207 | [len(data_class_map[0]), len(data_class_map[1]), 208 | len(data_class_map[2]), len(data_class_map[3])]) 209 | 210 | batch_num = int(max_class_size / num_each_class) 211 | if max_class_size % num_each_class != 0: 212 | batch_num += 1 213 | 214 | batch_relative_offset = (1.0 / float(batch_num)) / 2.0 215 | batches = [] 216 | for _ in range(batch_num): 217 | batch = [] 218 | for _class_id, data_rows in data_class_map.items(): 219 | int_offset = int(batch_relative_offset * len(data_rows)) 220 | batch.extend(sample_at_index(data_rows, int_offset, num_each_class)) 221 | batch_relative_offset += 1.0 / float(batch_num) 222 | batches.append(batch) 223 | 224 | random.shuffle(batches) 225 | 226 | for batch in batches: 227 | write_out("Using minibatch from RandomBatchClassBalancedSequentialSampler") 228 | yield batch 229 | 230 | def __len__(self): 231 | length = 0 232 | for _ in self.sampler: 233 | length += 1 234 | return length 235 | 236 | 237 | class RandomBatchSequentialSampler(torch.utils.data.sampler.Sampler): 238 | 239 | def __init__(self, dataset, batch_size): 240 | self.sampler = torch.utils.data.sampler.SequentialSampler(dataset) 241 | self.batch_size = batch_size 242 | 243 | def __iter__(self): 244 | data = [] 245 | for idx in self.sampler: 246 | data.append(idx) 247 | 248 | batch_num = int(len(data) / self.batch_size) 249 | if len(data) % self.batch_size != 0: 250 | batch_num += 1 251 | 252 | batch_order = list(range(batch_num)) 253 | random.shuffle(batch_order) 254 | 255 | batch = [] 256 | for batch_id in batch_order: 257 | write_out("Accessing minibatch #" + str(batch_id)) 258 | for i in range(self.batch_size): 259 | if i + (batch_id * self.batch_size) < len(data): 260 | batch.append(data[i + (batch_id * self.batch_size)]) 261 | yield batch 262 | batch = [] 263 | 264 | def __len__(self): 265 | length = 0 266 | for _ in self.sampler: 267 | length += 1 268 | return length 269 | 270 | 271 | def sample_at_index(rows, offset, sample_num): 272 | assert sample_num < len(rows) 273 | sample_half = int(sample_num / 2) 274 | if offset - sample_half <= 0: 275 | # sample start has to be 0 276 | samples = rows[:sample_num] 277 | elif offset + sample_half + (sample_num % 2) > len(rows): 278 | # sample end has to be an end 279 | samples = rows[-(sample_num + 1):-1] 280 | else: 281 | samples = rows[offset - sample_half:offset + sample_half + (sample_num % 2)] 282 | assert len(samples) == sample_num 283 | return samples 284 | 285 | def label_list_to_topology(labels): 286 | if isinstance(labels, np.ndarray): 287 | top_list = [] 288 | last_label = None 289 | for idx, label in enumerate(labels): 290 | if last_label is None or last_label != label: 291 | top_list.append((idx, label)) 292 | last_label = label 293 | return top_list 294 | 295 | if isinstance(labels, list): 296 | labels = torch.LongTensor(labels) 297 | 298 | if isinstance(labels, torch.LongTensor): 299 | zero_tensor = torch.LongTensor([0]) 300 | if labels.is_cuda: 301 | zero_tensor = zero_tensor.cuda() 302 | 303 | unique, count = torch.unique_consecutive(labels, return_counts=True) 304 | top_list = [torch.cat((zero_tensor, labels[0]))] 
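# walk the consecutive runs computed above and append a (start_position, label) pair
# for every segment after the first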
305 | prev_count = 0 306 | i = 0 307 | for _ in unique.split(1): 308 | if i == 0: 309 | i += 1 310 | continue 311 | prev_count += count[i - 1] 312 | top_list.append(torch.cat((prev_count.view(1), unique[i].view(1)))) 313 | i += 1 314 | return top_list 315 | 316 | 317 | 318 | def remapped_labels_hmm_to_orginal_labels(labels): 319 | 320 | if isinstance(labels, np.ndarray): 321 | zeros = np.zeros(labels.shape, dtype=np.long) 322 | ones = np.ones(labels.shape, dtype=np.long) 323 | twos = np.ones(labels.shape, dtype=np.long) * 2 324 | 325 | 326 | labels = np.where((labels >= 5) & (labels < 45), zeros, labels) 327 | labels = np.where((labels >= 45) & (labels < 85), ones, labels) 328 | labels = np.where(labels >= 85, twos, labels) 329 | 330 | return labels 331 | 332 | if isinstance(labels, list): 333 | labels = torch.LongTensor(labels) 334 | 335 | if isinstance(labels, torch.LongTensor): 336 | 337 | torch_zeros = torch.zeros(labels.size(), dtype=torch.long) 338 | torch_ones = torch.ones(labels.size(), dtype=torch.long) 339 | torch_twos = torch.ones(labels.size(), dtype=torch.long) * 2 340 | 341 | if labels.is_cuda: 342 | labels = labels.cuda() 343 | torch_zeros = labels.cuda() 344 | torch_ones = labels.cuda() 345 | torch_twos = labels.cuda() 346 | 347 | labels = torch.where((labels >= 5) & (labels < 45), torch_zeros, labels) 348 | labels = torch.where((labels >= 45) & (labels < 85), torch_ones, labels) 349 | labels = torch.where(labels >= 85, torch_twos, labels) 350 | 351 | return labels 352 | 353 | def batch_sizes_to_mask(batch_sizes: torch.Tensor) -> torch.Tensor: 354 | arange = torch.arange(batch_sizes[0], dtype=torch.int32) 355 | if batch_sizes.is_cuda: 356 | arange = arange.cuda() 357 | res = (arange.expand(batch_sizes.size(0), batch_sizes[0]) 358 | < batch_sizes.unsqueeze(1)).transpose(0, 1) 359 | return res 360 | 361 | def original_labels_to_fasta(label_list): 362 | sequence = "" 363 | for label in label_list: 364 | if label == 0: 365 | sequence = sequence + "M" 366 | if label == 1: 367 | sequence = sequence + "M" 368 | if label == 2: 369 | sequence = sequence + "S" 370 | if label == 3: 371 | sequence = sequence + "I" 372 | if label == 4: 373 | sequence = sequence + "O" 374 | if label == 5: 375 | sequence = sequence + "-" 376 | return sequence 377 | 378 | 379 | def get_predicted_type_from_labels(labels): 380 | if isinstance(labels, np.ndarray): 381 | zero = np.zeros(1, dtype=np.long) 382 | 383 | contains_0 = (labels == 0).sum() > 0 384 | contains_1 = (labels == 1).sum() > 0 385 | contains_2 = np.where((labels == 2).sum() > 0, zero + 1, zero) 386 | 387 | is_tm = np.where(contains_0 | contains_1, zero + 1, zero) 388 | 389 | return is_tm * contains_2 \ 390 | + ((is_tm - 1) * (is_tm - 1)) * (3 - contains_2) 391 | 392 | if isinstance(labels, torch.LongTensor): 393 | torch_zero = torch.zeros(1) 394 | 395 | if labels.is_cuda: 396 | torch_zero = torch_zero.cuda() 397 | 398 | contains_0 = (labels == 0).int().sum() > 0 399 | contains_1 = (labels == 1).int().sum() > 0 400 | contains_2 = torch.where((labels == 2).int().sum() > 0, torch_zero + 1, torch_zero) 401 | 402 | is_tm = torch.where(contains_0 | contains_1, torch_zero + 1, torch_zero) 403 | 404 | return is_tm * contains_2 \ 405 | + ((is_tm - 1) * (is_tm - 1)) * (3 - contains_2) 406 | 407 | 408 | 409 | 410 | def is_topologies_equal(topology_a, topology_b, minimum_seqment_overlap=5): 411 | if len(topology_a) != len(topology_b): 412 | return False 413 | for idx, (_position_a, label_a) in enumerate(topology_a): 414 | if label_a != 
topology_b[idx][1]: 415 | return False 416 | if label_a in (0, 1): 417 | overlap_segment_start = max(topology_a[idx][0], topology_b[idx][0]) 418 | overlap_segment_end = min(topology_a[idx + 1][0], topology_b[idx + 1][0]) 419 | if overlap_segment_end - overlap_segment_start < minimum_seqment_overlap: 420 | return False 421 | return True 422 | 423 | 424 | def parse_3line_format(lines): 425 | i = 0 426 | prot_list = [] 427 | while i < len(lines): 428 | if lines[i].strip() == "": 429 | i += 1 430 | continue 431 | prot_name_comment = lines[i] 432 | type_string = None 433 | cluster_id = None 434 | if prot_name_comment.__contains__(">"): 435 | i += 1 436 | prot_name = prot_name_comment.split("|")[0].split(">")[1] 437 | type_string = prot_name_comment.split("|")[1] 438 | cluster_id = int(prot_name_comment.split("|")[2]) 439 | else: 440 | # assume this is data 441 | prot_name = "> Unknown Protein Name" 442 | prot_aa_list = lines[i].upper() 443 | i += 1 444 | if len(prot_aa_list) > 6000: 445 | print("Discarding protein", prot_name, "as length larger than 6000:", 446 | len(prot_aa_list)) 447 | if i < len(lines) and not lines[i].__contains__(">"): 448 | i += 1 449 | else: 450 | if i < len(lines) and not lines[i].__contains__(">"): 451 | prot_topology_list = lines[i].upper() 452 | i += 1 453 | if prot_topology_list.__contains__("S"): 454 | if prot_topology_list.__contains__("M"): 455 | type_id = 1 456 | assert type_string == "SP+TM" 457 | else: 458 | type_id = 2 459 | assert type_string == "SP" 460 | else: 461 | if prot_topology_list.__contains__("M"): 462 | type_id = 0 463 | assert type_string == "TM" 464 | else: 465 | type_id = 3 466 | assert type_string == "GLOBULAR" 467 | else: 468 | type_id = None 469 | prot_topology_list = None 470 | prot_list.append((prot_name, prot_aa_list, prot_topology_list, 471 | type_id, cluster_id)) 472 | 473 | return prot_list 474 | 475 | 476 | def parse_datafile_from_disk(file): 477 | lines = list([line.strip() for line in open(file)]) 478 | return parse_3line_format(lines) 479 | 480 | 481 | def calculate_partitions(partitions_count, cluster_partitions, types): 482 | partition_distribution = torch.ones((partitions_count, 483 | len(torch.unique(types))), 484 | dtype=torch.long) 485 | partition_assignments = torch.zeros(cluster_partitions.shape[0], 486 | dtype=torch.long) 487 | 488 | for i in torch.unique(cluster_partitions): 489 | cluster_positions = (cluster_partitions == i).nonzero() 490 | cluster_types = types[cluster_positions] 491 | unique_types_in_cluster, type_count = torch.unique(cluster_types, return_counts=True) 492 | tmp_distribution = partition_distribution.clone() 493 | tmp_distribution[:, unique_types_in_cluster] += type_count 494 | relative_distribution = partition_distribution.double() / tmp_distribution.double() 495 | min_relative_distribution_group = torch.argmin(torch.sum(relative_distribution, dim=1)) 496 | partition_distribution[min_relative_distribution_group, 497 | unique_types_in_cluster] += type_count 498 | partition_assignments[cluster_positions] = min_relative_distribution_group 499 | 500 | write_out("Loaded data into the following partitions") 501 | write_out("[[ TM SP+TM SP Glob]") 502 | write_out(partition_distribution - torch.ones(partition_distribution.shape, 503 | dtype=torch.long)) 504 | return partition_assignments 505 | 506 | 507 | def load_data_from_disk(filename, partition_rotation=0): 508 | print("Loading data from disk...") 509 | data = parse_datafile_from_disk(filename) 510 | data_unzipped = list(zip(*data)) 511 | partitions = 
calculate_partitions( 512 | cluster_partitions=torch.LongTensor(np.array(data_unzipped[4])), 513 | types=torch.LongTensor(np.array(data_unzipped[3])), 514 | partitions_count=5) 515 | train_set = [] 516 | val_set = [] 517 | test_set = [] 518 | for idx, sample in enumerate(data): 519 | partition = int(partitions[idx]) # in range 0-4 520 | rotated = (partition + partition_rotation) % 5 521 | if int(rotated) <= 2: 522 | train_set.append(sample) 523 | elif int(rotated) == 3: 524 | val_set.append(sample) 525 | else: 526 | test_set.append(sample) 527 | 528 | print("Data splited as:", 529 | len(train_set), "train set", 530 | len(val_set), "validation set", 531 | len(test_set), "test set") 532 | return train_set, val_set, test_set 533 | 534 | 535 | def normalize_confusion_matrix(confusion_matrix): 536 | confusion_matrix = confusion_matrix.astype(np.float64) 537 | for i in range(4): 538 | accumulator = int(confusion_matrix[i].sum()) 539 | if accumulator != 0: 540 | confusion_matrix[4][i] /= accumulator * 0.01 # 0.01 to convert to percentage 541 | for k in range(5): 542 | if accumulator != 0: 543 | confusion_matrix[i][k] /= accumulator * 0.01 # 0.01 to convert to percentage 544 | else: 545 | confusion_matrix[i][k] = math.nan 546 | return confusion_matrix.round(2) 547 | 548 | def decode(emissions, batch_sizes, start_transitions, transitions, end_transitions): 549 | mask = batch_sizes_to_mask(batch_sizes) 550 | 551 | if emissions.is_cuda: 552 | mask = mask.cuda() 553 | crf_model = CRF(int(start_transitions.size(0))) 554 | initialize_crf_parameters(crf_model, 555 | start_transitions=start_transitions, 556 | transitions=transitions, 557 | end_transitions=end_transitions) 558 | labels_predicted = [] 559 | for l in crf_model.decode(emissions, mask=mask): 560 | val = torch.tensor(l).unsqueeze(1) 561 | if emissions.is_cuda: 562 | val = val.cuda() 563 | labels_predicted.append(val) 564 | 565 | 566 | predicted_labels = [] 567 | for l in labels_predicted: 568 | predicted_labels.append(remapped_labels_hmm_to_orginal_labels(l)) 569 | 570 | predicted_types_list = [] 571 | for p_label in predicted_labels: 572 | predicted_types_list.append(get_predicted_type_from_labels(p_label)) 573 | predicted_types = torch.cat(predicted_types_list) 574 | 575 | 576 | 577 | if emissions.is_cuda: 578 | predicted_types = predicted_types.cuda() 579 | 580 | # if all O's, change to all I's (by convention) 581 | torch_zero = torch.zeros(1, dtype=torch.long) 582 | if emissions.is_cuda: 583 | torch_zero = torch_zero.cuda() 584 | for idx, labels in enumerate(predicted_labels): 585 | predicted_labels[idx] = \ 586 | labels - torch.where(torch.eq(labels, 4).min() == 1, torch_zero + 1, torch_zero) 587 | 588 | return predicted_labels, predicted_types, list(map(label_list_to_topology, predicted_labels)) 589 | 590 | 591 | def decode_numpy(emissions, batch_sizes, start_transitions, transitions, end_transitions): 592 | labels_predicted = [] 593 | for l in numpy_viterbi_decode(emissions, 594 | batch_sizes=batch_sizes, 595 | start_transitions=start_transitions, 596 | transitions=transitions, 597 | end_transitions=end_transitions): 598 | val = np.expand_dims(np.array(l), 1) 599 | labels_predicted.append(val) 600 | 601 | 602 | predicted_labels = [] 603 | for l in labels_predicted: 604 | predicted_labels.append(remapped_labels_hmm_to_orginal_labels(l)) 605 | 606 | predicted_types_list = [] 607 | for p_label in predicted_labels: 608 | predicted_types_list.append(get_predicted_type_from_labels(p_label)) 609 | predicted_types = 
np.array(predicted_types_list).squeeze(axis=1) 610 | 611 | # if all O's, change to all I's (by convention) 612 | zero = np.zeros(1, dtype=np.long) 613 | 614 | for idx, labels in enumerate(predicted_labels): 615 | predicted_labels[idx] = \ 616 | labels - np.where((labels == 4).min() == 1, zero + 1, zero) 617 | 618 | return predicted_labels, \ 619 | predicted_types, \ 620 | list(map(label_list_to_topology, predicted_labels)) 621 | 622 | def numpy_viterbi_decode(emissions, 623 | batch_sizes, 624 | start_transitions, 625 | transitions, 626 | end_transitions) -> List[List[int]]: 627 | # emissions: (seq_length, batch_size, num_tags) 628 | # mask: (seq_length, batch_size) 629 | assert len(emissions.shape) == 3 630 | #assert emissions.shape[:2] == mask.shape 631 | #assert emissions.size(2) == self.num_tags 632 | #assert mask[0].all() 633 | 634 | seq_length = emissions.shape[0] 635 | batch_size = emissions.shape[1] 636 | 637 | # Start transition and first emission 638 | # shape: (batch_size, num_tags) 639 | score = start_transitions + emissions[0] 640 | history = [] 641 | 642 | # score is a tensor of size (batch_size, num_tags) where for every batch, 643 | # value at column j stores the score of the best tag sequence so far that ends 644 | # with tag j 645 | # history saves where the best tags candidate transitioned from; this is used 646 | # when we trace back the best tag sequence 647 | 648 | # Viterbi algorithm recursive case: we compute the score of the best tag sequence 649 | # for every possible next tag 650 | 651 | l = [] 652 | for i in batch_sizes: 653 | l.append(np.array([1] * i + [0] * (seq_length - i))) 654 | mask = np.array(l).T 655 | 656 | 657 | for i in range(1, seq_length): 658 | 659 | # Broadcast viterbi score for every possible next tag 660 | # shape: (batch_size, num_tags, 1) 661 | broadcast_score = np.expand_dims(score, 2) 662 | 663 | # Broadcast emission score for every possible current tag 664 | # shape: (batch_size, 1, num_tags) 665 | broadcast_emission = np.expand_dims(emissions[i], 1) 666 | 667 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where 668 | # for each sample, entry at row i and column j stores the score of the best 669 | # tag sequence so far that ends with transitioning from tag i to tag j and emitting 670 | # shape: (batch_size, num_tags, num_tags) 671 | next_score = broadcast_score + transitions + broadcast_emission 672 | 673 | # Find the maximum score over all possible current tag 674 | # shape: (batch_size, num_tags) 675 | indices = next_score.argmax(axis=1) 676 | next_score = next_score.max(axis=1) 677 | 678 | 679 | # Set score to the next score if this timestep is valid (mask == 1) 680 | # and save the index that produces the next score 681 | # shape: (batch_size, num_tags) 682 | score = np.where(np.expand_dims(mask[i], 1), next_score, score) # pylint: disable=E1136 683 | history.append(indices) 684 | 685 | 686 | # End transition score 687 | # shape: (batch_size, num_tags) 688 | score += end_transitions 689 | 690 | # Now, compute the best path for each sample 691 | 692 | # shape: (batch_size,) 693 | seq_ends = batch_sizes - 1 694 | best_tags_list = [] 695 | 696 | for idx in range(batch_size): 697 | # Find the tag which maximizes the score at the last timestep; this is our best tag 698 | # for the last timestep 699 | best_last_tag = score[idx].argmax(axis=0) 700 | best_tags = [best_last_tag.item()] 701 | 702 | # We trace back where the best last tag comes from, append that to our best tag 703 | # sequence, and trace it back again, and 
so on 704 | for hist in reversed(history[:seq_ends[idx]]): 705 | best_last_tag = hist[idx][best_tags[-1]] 706 | best_tags.append(best_last_tag.item()) 707 | 708 | # Reverse the order because we start from the last timestep 709 | best_tags.reverse() 710 | best_tags_list.append(best_tags) 711 | 712 | return best_tags_list 713 | 714 | def initialize_crf_parameters(crf_model, 715 | start_transitions=None, 716 | end_transitions=None, 717 | transitions=None) -> None: 718 | """Initialize the transition parameters. 719 | 720 | The parameters will be initialized randomly from a uniform distribution 721 | between -0.1 and 0.1, unless given explicitly as an argument. 722 | """ 723 | if start_transitions is None: 724 | torch.nn.init.uniform(crf_model.start_transitions, -0.1, 0.1) 725 | else: 726 | crf_model.start_transitions.data = start_transitions 727 | if end_transitions is None: 728 | torch.nn.init.uniform(crf_model.end_transitions, -0.1, 0.1) 729 | else: 730 | crf_model.end_transitions.data = end_transitions 731 | if transitions is None: 732 | torch.nn.init.uniform(crf_model.transitions, -0.1, 0.1) 733 | else: 734 | crf_model.transitions.data = transitions 735 | -------------------------------------------------------------------------------- /git-hooks/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # 3 | # An example hook script to verify what is about to be committed. 4 | # Called by "git commit" with no arguments. The hook should 5 | # exit with non-zero status after issuing an appropriate message if 6 | # it wants to stop the commit. 7 | # 8 | # To enable this hook, rename this file to "pre-commit". 9 | 10 | if git rev-parse --verify HEAD >/dev/null 2>&1 11 | then 12 | against=HEAD 13 | else 14 | # Initial commit: diff against an empty tree object 15 | against=$(git hash-object -t tree /dev/null) 16 | fi 17 | 18 | # If you want to allow non-ASCII filenames set this variable to true. 19 | allownonascii=$(git config --bool hooks.allownonascii) 20 | 21 | # Redirect output to stderr. 22 | exec 1>&2 23 | 24 | # Cross platform projects tend to avoid non-ASCII filenames; prevent 25 | # them from being added to the repository. We exploit the fact that the 26 | # printable range starts at the space character and ends with tilde. 27 | if [ "$allownonascii" != "true" ] && 28 | # Note that the use of brackets around a tr range is ok here, (it's 29 | # even required, for portability to Solaris 10's /usr/bin/tr), since 30 | # the square bracket bytes happen to fall in the designated range. 31 | test $(git diff --cached --name-only --diff-filter=A -z $against | 32 | LC_ALL=C tr -d '[ -~]\0' | wc -c) != 0 33 | then 34 | cat <<\EOF 35 | Error: Attempt to add a non-ASCII file name. 36 | 37 | This can cause problems if you want to work with people on other platforms. 38 | 39 | To be portable it is advisable to rename the file. 40 | 41 | If you know what you are doing you can disable this check using: 42 | 43 | git config hooks.allownonascii true 44 | EOF 45 | exit 1 46 | fi 47 | 48 | echo "Running pylint" 49 | 50 | find . -type f -name "*.py" | grep -v pytorchcrf | xargs pipenv run pylint --rcfile=.pylintrc 51 | 52 | if [ $? -eq 0 ] 53 | then 54 | echo "pylint succeeded" 55 | else 56 | echo "pylint failed" >&2 57 | exit 1 58 | fi 59 | 60 | # If there are whitespace errors, print the offending file names and fail. 
61 | exec git diff-index --check --cached $against -- 62 | -------------------------------------------------------------------------------- /op_cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 5 | """ 6 | 7 | import argparse 8 | import importlib 9 | import sys 10 | import torch 11 | from dashboard import start_dashboard_server 12 | 13 | from util import write_out 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser( 17 | description="OpenProtein version 0.1", 18 | conflict_handler='resolve') 19 | parser.add_argument('--silent', dest='silent', action='store_true', 20 | help='Do not print verbose debug statements.') 21 | parser.add_argument('--hide-ui', dest='hide_ui', action='store_true', 22 | default=False, help='Hide loss graph and ' 23 | 'visualization UI while training goes on.') 24 | parser.add_argument('--evaluate-on-test', dest='evaluate_on_test', action='store_true', 25 | default=False, help='Run model on test data.') 26 | parser.add_argument('--use-gpu', dest='use_gpu', action='store_true', 27 | default=False, help='Use GPU.') 28 | parser.add_argument('--eval-interval', dest='eval_interval', type=int, 29 | default=10, help='Evaluate model on validation set every n minibatches.') 30 | parser.add_argument('--min-updates', dest='minimum_updates', type=int, 31 | default=2000, help='Minimum number of minibatch iterations.') 32 | parser.add_argument('--minibatch-size', dest='minibatch_size', type=int, 33 | default=16, help='Size of each minibatch.') 34 | parser.add_argument('--experiment-id', dest='experiment_id', type=str, 35 | default="example", help='Which experiment to run.') 36 | args, _ = parser.parse_known_args() 37 | 38 | if args.hide_ui: 39 | write_out("Live plot deactivated, see output folder for plot.") 40 | 41 | use_gpu = args.use_gpu 42 | 43 | if use_gpu and not torch.cuda.is_available(): 44 | write_out("Error: --use-gpu was set, but no GPU is available.") 45 | sys.exit(1) 46 | 47 | if not args.hide_ui: 48 | # start web server 49 | start_dashboard_server() 50 | 51 | experiment = importlib.import_module("experiments." + args.experiment_id) 52 | experiment.run_experiment(parser, use_gpu) 53 | -------------------------------------------------------------------------------- /openprotein.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory.
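The CLI above resolves --experiment-id by importing experiments.<id> and calling its run_experiment(parser, use_gpu) entry point. As a rough sketch, only the entry-point signature is taken from op_cli.py above; the package name, the extra flag and the function body are illustrative assumptions:

# Hypothetical experiments/my_model/__init__.py
from util import write_out

def run_experiment(parser, use_gpu):
    # Each experiment may register its own flags on the shared parser.
    parser.add_argument('--learning-rate', dest='learning_rate', type=float, default=0.01)
    args, _ = parser.parse_known_args()
    write_out("Running my_model with learning rate", args.learning_rate, "use_gpu =", use_gpu)
    # ...build data loaders and a model here, then hand them to training.train_model(...)

Such an experiment would then be selected with the existing flag, e.g. --experiment-id my_model.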
5 | """ 6 | import math 7 | import time 8 | import torch 9 | import torch.nn as nn 10 | from util import calculate_dihedral_angles_over_minibatch, calc_angular_difference, \ 11 | write_out, calculate_dihedral_angles, \ 12 | get_structure_from_angles, write_to_pdb, calc_rmsd,\ 13 | calc_drmsd, get_backbone_positions_from_angles 14 | 15 | 16 | class BaseModel(nn.Module): 17 | def __init__(self, use_gpu, embedding_size): 18 | super(BaseModel, self).__init__() 19 | 20 | # initialize model variables 21 | self.use_gpu = use_gpu 22 | self.embedding_size = embedding_size 23 | self.historical_rmsd_avg_values = list() 24 | self.historical_drmsd_avg_values = list() 25 | 26 | def get_embedding_size(self): 27 | return self.embedding_size 28 | 29 | def embed(self, original_aa_string): 30 | max_len = max([s.size(0) for s in original_aa_string]) 31 | seqs = [] 32 | for tensor in original_aa_string: 33 | padding_to_add = torch.zeros(max_len-tensor.size(0)).int() 34 | seqs.append(torch.cat((tensor, padding_to_add))) 35 | 36 | data = torch.stack(seqs).transpose(0, 1) 37 | 38 | # one-hot encoding 39 | start_compute_embed = time.time() 40 | arange_tensor = torch.arange(21).int().repeat( 41 | len(original_aa_string), 1 42 | ).unsqueeze(0).repeat(max_len, 1, 1) 43 | data_tensor = data.unsqueeze(2).repeat(1, 1, 21) 44 | embed_tensor = (arange_tensor == data_tensor).float() 45 | 46 | if self.use_gpu: 47 | embed_tensor = embed_tensor.cuda() 48 | 49 | end = time.time() 50 | write_out("Embed time:", end - start_compute_embed) 51 | 52 | return embed_tensor 53 | 54 | def compute_loss(self, minibatch): 55 | (original_aa_string, actual_coords_list, _) = minibatch 56 | 57 | emissions, _backbone_atoms_padded, _batch_sizes = \ 58 | self._get_network_emissions(original_aa_string) 59 | actual_coords_list_padded = torch.nn.utils.rnn.pad_sequence(actual_coords_list) 60 | if self.use_gpu: 61 | actual_coords_list_padded = actual_coords_list_padded.cuda() 62 | start = time.time() 63 | if isinstance(_batch_sizes[0], int): 64 | _batch_sizes = torch.tensor(_batch_sizes) 65 | emissions_actual, _ = \ 66 | calculate_dihedral_angles_over_minibatch(actual_coords_list_padded, 67 | _batch_sizes, 68 | self.use_gpu) 69 | # drmsd_avg = calc_avg_drmsd_over_minibatch(backbone_atoms_padded, 70 | # actual_coords_list_padded, 71 | # batch_sizes) 72 | write_out("Angle calculation time:", time.time() - start) 73 | if self.use_gpu: 74 | emissions_actual = emissions_actual.cuda() 75 | # drmsd_avg = drmsd_avg.cuda() 76 | angular_loss = calc_angular_difference(emissions, emissions_actual) 77 | 78 | return angular_loss # + drmsd_avg 79 | 80 | def forward(self, original_aa_string): 81 | return self._get_network_emissions(original_aa_string) 82 | 83 | def evaluate_model(self, data_loader): 84 | loss = 0 85 | data_total = [] 86 | dRMSD_list = [] 87 | RMSD_list = [] 88 | for _, data in enumerate(data_loader, 0): 89 | primary_sequence, tertiary_positions, _mask = data 90 | start = time.time() 91 | predicted_angles, backbone_atoms, batch_sizes = self(primary_sequence) 92 | write_out("Apply model to validation minibatch:", time.time() - start) 93 | 94 | if predicted_angles == []: 95 | # model didn't provide angles, so we'll compute them here 96 | output_angles, _ = calculate_dihedral_angles_over_minibatch(backbone_atoms, 97 | batch_sizes, 98 | self.use_gpu) 99 | else: 100 | output_angles = predicted_angles 101 | 102 | cpu_predicted_angles = output_angles.transpose(0, 1).cpu().detach() 103 | if backbone_atoms == []: 104 | # model didn't provide backbone atoms, we 
need to compute that 105 | output_positions, _ = \ 106 | get_backbone_positions_from_angles(predicted_angles, 107 | batch_sizes, 108 | self.use_gpu) 109 | else: 110 | output_positions = backbone_atoms 111 | 112 | cpu_predicted_backbone_atoms = output_positions.transpose(0, 1).cpu().detach() 113 | 114 | minibatch_data = list(zip(primary_sequence, 115 | tertiary_positions, 116 | cpu_predicted_angles, 117 | cpu_predicted_backbone_atoms)) 118 | data_total.extend(minibatch_data) 119 | start = time.time() 120 | for primary_sequence, tertiary_positions, _predicted_pos, predicted_backbone_atoms\ 121 | in minibatch_data: 122 | actual_coords = tertiary_positions.transpose(0, 1).contiguous().view(-1, 3) 123 | 124 | predicted_coords = predicted_backbone_atoms[:len(primary_sequence)]\ 125 | .transpose(0, 1).contiguous().view(-1, 3).detach() 126 | rmsd = calc_rmsd(predicted_coords, actual_coords) 127 | drmsd = calc_drmsd(predicted_coords, actual_coords) 128 | RMSD_list.append(rmsd) 129 | dRMSD_list.append(drmsd) 130 | error = rmsd 131 | loss += error 132 | 133 | end = time.time() 134 | write_out("Calculate validation loss for minibatch took:", end - start) 135 | loss /= data_loader.dataset.__len__() 136 | self.historical_rmsd_avg_values.append(float(torch.Tensor(RMSD_list).mean())) 137 | self.historical_drmsd_avg_values.append(float(torch.Tensor(dRMSD_list).mean())) 138 | 139 | prim = data_total[0][0] 140 | pos = data_total[0][1] 141 | pos_pred = data_total[0][3] 142 | if self.use_gpu: 143 | pos = pos.cuda() 144 | pos_pred = pos_pred.cuda() 145 | angles = calculate_dihedral_angles(pos, self.use_gpu) 146 | angles_pred = calculate_dihedral_angles(pos_pred, self.use_gpu) 147 | write_to_pdb(get_structure_from_angles(prim, angles), "test") 148 | write_to_pdb(get_structure_from_angles(prim, angles_pred), "test_pred") 149 | 150 | data = {} 151 | data["pdb_data_pred"] = open("output/protein_test_pred.pdb", "r").read() 152 | data["pdb_data_true"] = open("output/protein_test.pdb", "r").read() 153 | data["phi_actual"] = list([math.degrees(float(v)) for v in angles[1:, 1]]) 154 | data["psi_actual"] = list([math.degrees(float(v)) for v in angles[:-1, 2]]) 155 | data["phi_predicted"] = list([math.degrees(float(v)) for v in angles_pred[1:, 1]]) 156 | data["psi_predicted"] = list([math.degrees(float(v)) for v in angles_pred[:-1, 2]]) 157 | data["rmsd_avg"] = self.historical_rmsd_avg_values 158 | data["drmsd_avg"] = self.historical_drmsd_avg_values 159 | 160 | prediction_data = None 161 | 162 | return (loss, data, prediction_data) 163 | -------------------------------------------------------------------------------- /output/.gitignore: -------------------------------------------------------------------------------- 1 | *.* 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /output/models/.gitignore: -------------------------------------------------------------------------------- 1 | *.* 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /output/models/2019-01-30_00_38_46-TRAIN-LR0_01-MB1.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biolib/openprotein/3f474d3b1c00af0f06d88bf1ad78f2c34763341d/output/models/2019-01-30_00_38_46-TRAIN-LR0_01-MB1.model -------------------------------------------------------------------------------- /output/predictions/.gitignore: -------------------------------------------------------------------------------- 1 | *.* 2 
| !.gitignore 3 | -------------------------------------------------------------------------------- /prediction.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 5 | """ 6 | 7 | import argparse 8 | import glob 9 | import os 10 | import torch 11 | import torch.onnx 12 | 13 | 14 | from util import encode_primary_string, get_structure_from_angles, write_to_pdb, \ 15 | calculate_dihedral_angles_over_minibatch 16 | 17 | 18 | def prediction(): 19 | 20 | list_of_files = glob.glob('output/models/*') 21 | default_model_path = max(list_of_files, key=os.path.getctime) 22 | 23 | parser = argparse.ArgumentParser( 24 | description="OpenProtein - Prediction CLI" 25 | ) 26 | parser.add_argument('--input_sequence', dest='input_sequence') 27 | parser.add_argument('--model_path', dest='model_path', default=default_model_path) 28 | parser.add_argument('--use_gpu', dest='use_gpu', action='store_true', default=False) 29 | 30 | args, _ = parser.parse_known_args() 31 | 32 | print("Using model:", args.model_path) 33 | 34 | model = torch.load(args.model_path) 35 | 36 | input_sequences = [args.input_sequence] 37 | 38 | input_sequences_encoded = list(torch.IntTensor(encode_primary_string(aa)) 39 | for aa in input_sequences) 40 | 41 | predicted_dihedral_angles, predicted_backbone_atoms, batch_sizes = \ 42 | model(input_sequences_encoded) 43 | 44 | if predicted_dihedral_angles == []: 45 | predicted_dihedral_angles, _ = calculate_dihedral_angles_over_minibatch( 46 | predicted_backbone_atoms, 47 | batch_sizes, 48 | args.use_gpu) 49 | write_to_pdb( 50 | get_structure_from_angles(input_sequences_encoded[0], predicted_dihedral_angles[:, 0]), 51 | "prediction" 52 | ) 53 | 54 | print("Wrote prediction to output/protein_prediction.pdb") 55 | 56 | 57 | prediction() 58 | -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory.
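prediction.py above and the preprocessing module that starts here are normally driven through their CLIs, but the same functions can be called directly. A minimal sketch, assuming a raw file named single_protein.txt has been placed under data/raw/ (the resulting .hdf5 name follows the naming used by process_raw_data below; the minibatch size is an arbitrary example):

from preprocessing import process_raw_data
from util import contruct_dataloader_from_disk

# Convert every raw file under data/raw/ into a padded HDF5 file under data/preprocessed/
process_raw_data(False, raw_data_path="data/raw/*", force_pre_processing_overwrite=False)

# Each minibatch is a (primary sequences, tertiary coordinates, masks) triple,
# sorted by sequence length by merge_samples_to_minibatch in util.py.
loader = contruct_dataloader_from_disk("data/preprocessed/single_protein.txt.hdf5", 8)
for primary, tertiary, mask in loader:
    print(len(primary), "sequences in this minibatch")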
5 | """ 6 | 7 | import glob 8 | import os.path 9 | import os 10 | import platform 11 | import re 12 | import numpy as np 13 | import h5py 14 | import torch 15 | 16 | from util import calculate_dihedral_angles_over_minibatch, \ 17 | get_backbone_positions_from_angles, encode_primary_string, write_out 18 | 19 | 20 | MAX_SEQUENCE_LENGTH = 2000 21 | 22 | def process_raw_data(use_gpu, raw_data_path="data/raw/*", force_pre_processing_overwrite=True): 23 | write_out("Starting pre-processing of raw data...") 24 | input_files = glob.glob(raw_data_path) 25 | write_out(input_files) 26 | input_files_filtered = filter_input_files(input_files) 27 | for file_path in input_files_filtered: 28 | if platform.system() == 'Windows': 29 | filename = file_path.split('\\')[-1] 30 | else: 31 | filename = file_path.split('/')[-1] 32 | preprocessed_file_name = "data/preprocessed/" + filename + ".hdf5" 33 | 34 | # check if we should remove any previously processed files 35 | if os.path.isfile(preprocessed_file_name): 36 | write_out("Preprocessed file for " + filename + " already exists.") 37 | if force_pre_processing_overwrite: 38 | write_out("force_pre_processing_overwrite flag set to True, " 39 | "overwriting old file...") 40 | os.remove(preprocessed_file_name) 41 | else: 42 | write_out("Skipping pre-processing for this file...") 43 | 44 | if not os.path.isfile(preprocessed_file_name): 45 | process_file(file_path, preprocessed_file_name, use_gpu) 46 | write_out("Completed pre-processing.") 47 | 48 | 49 | def read_protein_from_file(file_pointer): 50 | """The algorithm Defining Secondary Structure of Proteins (DSSP) uses information on e.g. the 51 | position of atoms and the hydrogen bonds of the molecule to determine the secondary structure 52 | (helices, sheets...). 53 | """ 54 | dict_ = {} 55 | _dssp_dict = {'L': 0, 'H': 1, 'B': 2, 'E': 3, 'G': 4, 'I': 5, 'T': 6, 'S': 7} 56 | _mask_dict = {'-': 0, '+': 1} 57 | 58 | while True: 59 | next_line = file_pointer.readline() 60 | if next_line == '[ID]\n': 61 | id_ = file_pointer.readline()[:-1] 62 | dict_.update({'id': id_}) 63 | elif next_line == '[PRIMARY]\n': 64 | primary = encode_primary_string(file_pointer.readline()[:-1]) 65 | dict_.update({'primary': primary}) 66 | elif next_line == '[EVOLUTIONARY]\n': 67 | evolutionary = [] 68 | for _residue in range(21): 69 | evolutionary.append(\ 70 | [float(step) for step in file_pointer.readline().split()]) 71 | dict_.update({'evolutionary': evolutionary}) 72 | elif next_line == '[SECONDARY]\n': 73 | secondary = list([_dssp_dict[dssp] for dssp in file_pointer.readline()[:-1]]) 74 | dict_.update({'secondary': secondary}) 75 | elif next_line == '[TERTIARY]\n': 76 | tertiary = [] 77 | # 3 dimension 78 | for _axis in range(3): 79 | tertiary.append(\ 80 | [float(coord) for coord in file_pointer.readline().split()]) 81 | dict_.update({'tertiary': tertiary}) 82 | elif next_line == '[MASK]\n': 83 | mask = list([_mask_dict[aa] for aa in file_pointer.readline()[:-1]]) 84 | dict_.update({'mask': mask}) 85 | mask_str = ''.join(map(str, mask)) 86 | 87 | write_out("-------------") 88 | # Check for missing AA coordinates 89 | missing_internal_aa = False 90 | sequence_end = len(mask) # for now, assume no C-terminal truncation needed 91 | write_out("Reading the protein " + id_) 92 | if re.search(r'1+0+1+', mask_str) is not None: # indicates missing coordinates 93 | missing_internal_aa = True 94 | write_out("One or more internal coordinates missing. 
Protein is discarded.") 95 | elif re.search(r'^0*$', mask_str) is not None: # indicates no coordinates at all 96 | missing_internal_aa = True 97 | write_out("One or more internal coordinates missing. It will be discarded.") 98 | else: 99 | if mask[0] == 0: 100 | write_out("Missing coordinates in the N-terminal end. Truncating protein.") 101 | # investigate when the sequence with coordinates start and finish 102 | sequence_start = re.search(r'1', mask_str).start() 103 | if re.search(r'10', mask_str) is not None: # missing coords in the C-term end 104 | sequence_end = re.search(r'10', mask_str).start() + 1 105 | write_out("Missing coordinates in the C-term end. Truncating protein.") 106 | write_out("Analyzing amino acids", sequence_start + 1, "-", sequence_end) 107 | 108 | # split lists in dict to have the seq with coords 109 | # separated from what should not be analysed 110 | if 'secondary' in dict_: 111 | dict_.update({'secondary': secondary[sequence_start:sequence_end]}) 112 | dict_.update({'primary': primary[sequence_start:sequence_end]}) 113 | dict_.update({'mask': mask[sequence_start:sequence_end]}) 114 | for elem in range(len(dict_['evolutionary'])): 115 | dict_['evolutionary'][elem] = \ 116 | dict_['evolutionary'][elem][sequence_start:sequence_end] 117 | for elem in range(len(dict_['tertiary'])): 118 | dict_['tertiary'][elem] = \ 119 | dict_['tertiary'][elem][sequence_start * 3:sequence_end * 3] 120 | 121 | elif next_line == '\n': 122 | return dict_, missing_internal_aa 123 | elif next_line == '': 124 | if dict_: 125 | return dict_, missing_internal_aa 126 | else: 127 | return None, False 128 | 129 | def process_file(input_file, output_file, use_gpu): 130 | write_out("Processing raw data file", input_file) 131 | # create output file 132 | file = h5py.File(output_file, 'w') 133 | current_buffer_size = 1 134 | current_buffer_allocation = 0 135 | dset1 = file.create_dataset('primary', (current_buffer_size, MAX_SEQUENCE_LENGTH), 136 | maxshape=(None, MAX_SEQUENCE_LENGTH), dtype='int32') 137 | dset2 = file.create_dataset('tertiary', (current_buffer_size, MAX_SEQUENCE_LENGTH, 9), 138 | maxshape=(None, MAX_SEQUENCE_LENGTH, 9), dtype='float') 139 | dset3 = file.create_dataset('mask', (current_buffer_size, MAX_SEQUENCE_LENGTH), 140 | maxshape=(None, MAX_SEQUENCE_LENGTH), 141 | dtype='uint8') 142 | 143 | input_file_pointer = open(input_file, "r") 144 | 145 | while True: 146 | # while there's more proteins to process 147 | next_protein, missing_aa = read_protein_from_file(input_file_pointer) 148 | if next_protein is None: # no more proteins to process 149 | break 150 | 151 | sequence_length = len(next_protein['primary']) 152 | 153 | if sequence_length > MAX_SEQUENCE_LENGTH: 154 | write_out("Dropping protein as length too long:", sequence_length) 155 | continue 156 | if missing_aa is True: 157 | continue 158 | if current_buffer_allocation >= current_buffer_size: 159 | current_buffer_size = current_buffer_size + 1 160 | dset1.resize((current_buffer_size, MAX_SEQUENCE_LENGTH)) 161 | dset2.resize((current_buffer_size, MAX_SEQUENCE_LENGTH, 9)) 162 | dset3.resize((current_buffer_size, MAX_SEQUENCE_LENGTH)) 163 | 164 | primary_padded = np.zeros(MAX_SEQUENCE_LENGTH) 165 | tertiary_padded = np.zeros((9, MAX_SEQUENCE_LENGTH)) 166 | mask_padded = np.zeros(MAX_SEQUENCE_LENGTH) 167 | 168 | # masking and padding here happens so that the stored dataset is of the same size. 169 | # when the data is loaded in this padding is removed again. 
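# For intuition (illustrative numbers, not part of the pipeline): with sequence_length = 3
# and a padded width of 5, primary_padded below ends up as e.g. [7, 2, 9, 0, 0] and
# mask_padded as [1, 1, 1, 0, 0]. The tertiary data arrives as three axis rows of
# 3 * sequence_length values (three backbone atoms per residue) and is re-packed below
# into nine coordinates per residue before padding. torch.masked_select with this mask,
# both a few lines further down and again at load time in util.H5PytorchDataset,
# strips the zero padding, so it never reaches the model.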
170 | primary_padded[:sequence_length] = next_protein['primary'] 171 | t_transposed = np.ravel(np.array(next_protein['tertiary']).T) 172 | t_reshaped = np.reshape(t_transposed, (sequence_length, 9)).T 173 | tertiary_padded[:, :sequence_length] = t_reshaped 174 | mask_padded[:sequence_length] = next_protein['mask'] 175 | mask = torch.Tensor(mask_padded).type(dtype=torch.bool) 176 | prim = torch.masked_select(torch.Tensor(primary_padded)\ 177 | .type(dtype=torch.long), mask) 178 | pos = torch.masked_select(torch.Tensor(tertiary_padded), mask)\ 179 | .view(9, -1).transpose(0, 1).unsqueeze(1) 180 | pos_angstrom = pos / 100 181 | 182 | if use_gpu: 183 | pos_angstrom = pos_angstrom.cuda() 184 | 185 | # map to angles and back to tertiary 186 | angles, batch_sizes = calculate_dihedral_angles_over_minibatch(pos_angstrom, 187 | torch.tensor([len(prim)]), 188 | use_gpu=use_gpu) 189 | 190 | tertiary, _ = get_backbone_positions_from_angles(angles, 191 | batch_sizes, 192 | use_gpu=use_gpu) 193 | tertiary = tertiary.squeeze(1) 194 | 195 | # create variables to store padded sequences in 196 | primary_padded = np.zeros(MAX_SEQUENCE_LENGTH) 197 | tertiary_padded = np.zeros((MAX_SEQUENCE_LENGTH, 9)) 198 | mask_padded = np.zeros(MAX_SEQUENCE_LENGTH) 199 | 200 | # store padded sequences 201 | length_after_mask_removed = len(prim) 202 | primary_padded[:length_after_mask_removed] = prim.data.cpu().numpy() 203 | tertiary_padded[:length_after_mask_removed, :] = tertiary.data.cpu().numpy() 204 | mask_padded[:length_after_mask_removed] = np.ones(length_after_mask_removed) 205 | 206 | # save padded sequences on disk 207 | dset1[current_buffer_allocation] = primary_padded 208 | dset2[current_buffer_allocation] = tertiary_padded 209 | dset3[current_buffer_allocation] = mask_padded 210 | current_buffer_allocation += 1 211 | if current_buffer_allocation == 0: 212 | write_out("Preprocessing was selected but no proteins in the input file " 213 | "were accepted. Please check your input.") 214 | os._exit(1) 215 | write_out("Wrote output to", current_buffer_allocation, "proteins to", output_file) 216 | 217 | 218 | def filter_input_files(input_files): 219 | disallowed_file_endings = (".gitignore", ".DS_Store") 220 | return list(filter(lambda x: not x.endswith(disallowed_file_endings), input_files)) 221 | -------------------------------------------------------------------------------- /preprocessing_cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 
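The HDF5 file written by process_file above stores three fixed-width datasets. A small sketch of inspecting one such file (the file name is an example; h5py is already imported by this module):

import h5py
import torch

with h5py.File("data/preprocessed/single_protein.txt.hdf5", "r") as handle:
    print(handle['primary'].shape)   # (num_proteins, MAX_SEQUENCE_LENGTH), int32 residue ids
    print(handle['tertiary'].shape)  # (num_proteins, MAX_SEQUENCE_LENGTH, 9), backbone coordinates
    print(handle['mask'].shape)      # (num_proteins, MAX_SEQUENCE_LENGTH), uint8, 1 = real residue
    mask = torch.Tensor(handle['mask'][0, :]).type(dtype=torch.bool)
    print("first protein length:", int(mask.sum()))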
5 | """ 6 | import argparse 7 | import torch 8 | from preprocessing import process_raw_data 9 | from util import write_out 10 | 11 | print("------------------------") 12 | print("--- OpenProtein v0.1 ---") 13 | print("------------------------") 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser(description="OpenProtein version 0.1") 18 | parser.add_argument('--no_force_pre_processing_overwrite', 19 | dest='no_force_pre_processing_overwrite', 20 | action='store_false', 21 | help='Do not overwrite existing preprocessed files', default=True) 22 | args, _unknown = parser.parse_known_args() 23 | 24 | use_gpu = False 25 | if torch.cuda.is_available(): 26 | write_out("CUDA is available, using GPU") 27 | use_gpu = True 28 | 29 | process_raw_data(use_gpu, force_pre_processing_overwrite=args.no_force_pre_processing_overwrite) 30 | 31 | 32 | main() 33 | -------------------------------------------------------------------------------- /tests/onnx_export.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 5 | """ 6 | 7 | import glob 8 | import os 9 | import torch 10 | import torch.onnx 11 | 12 | from util import encode_primary_string 13 | 14 | def onnx_from_model(model, input_str, path): 15 | """Export to onnx""" 16 | torch.onnx.export(model, input_str, path, opset_version=10, verbose=True) 17 | 18 | def predict(): 19 | list_of_files = glob.glob('output/models/*') # * means all if need specific format then *.csv 20 | model_path = max(list_of_files, key=os.path.getctime) 21 | 22 | print("Generating ONNX from model:", model_path) 23 | model = torch.load(model_path) 24 | 25 | input_sequences = [ 26 | "SRSLVISTINQISEDSKEFYFTLDNGKTMFPSNSQAWGGEKFENGQRAFVIFNELEQPVNGYDYNIQVRDITKVLTKEIVTMDDEE" \ 27 | "NTEEKIGDDKINATYMWISKDKKYLTIEFQYYSTHSEDKKHFLNLVINNKDNTDDEYINLEFRHNSERDSPDHLGEGYVSFKLDKI" \ 28 | "EEQIEGKKGLNIRVRTLYDGIKNYKVQFP"] 29 | 30 | input_sequences_encoded = list(torch.IntTensor(encode_primary_string(aa)) 31 | for aa in input_sequences) 32 | 33 | print("Exporting to ONNX...") 34 | 35 | output_path = "./tests/output/openprotein.onnx" 36 | onnx_from_model(model, input_sequences_encoded, output_path) 37 | 38 | print("Wrote ONNX to", output_path) 39 | 40 | predict() 41 | -------------------------------------------------------------------------------- /tests/onnx_export_tmhmm3.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory.
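Once tests/onnx_export.py above has produced the graph, it can be sanity-checked without executing it. A short sketch, assuming the onnx package is installed (it is not imported anywhere above, so treat it as an optional extra):

import onnx

exported = onnx.load("./tests/output/openprotein.onnx")
onnx.checker.check_model(exported)  # raises if the graph is structurally invalid
print([inp.name for inp in exported.graph.input])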
5 | """ 6 | 7 | import glob 8 | import os 9 | import torch 10 | import torch.onnx 11 | import numpy as np 12 | 13 | from experiments.tmhmm3 import decode, decode_numpy 14 | from util import load_model_from_disk 15 | 16 | 17 | def onnx_from_model(model, input_str, path): 18 | """Export to onnx""" 19 | torch.onnx.export(model, input_str, path, 20 | enable_onnx_checker=True, opset_version=10, verbose=True, 21 | input_names=['embedded_sequences', 'mask'], # the model's input names 22 | output_names=['emissions', 23 | 'crf_start_transitions', 24 | 'crf_transitions', 25 | 'crf_end_transitions'], # the model's output names 26 | dynamic_axes={ 27 | 'mask': {0: 'batch_size'}, 28 | 'embedded_sequences': {0: 'max_seq_length', 1: 'batch_size'}, 29 | 'emissions': {0: 'max_seq_length', 1: 'batch_size'}, 30 | } 31 | ) 32 | 33 | def predict(): 34 | list_of_files = glob.glob('output/models/*') # * means all if need specific format then *.csv 35 | model_path = max(list_of_files, key=os.path.getctime) 36 | 37 | print("Generating ONNX from model:", model_path) 38 | model = load_model_from_disk(model_path, force_cpu=True) 39 | 40 | input_sequences = [ 41 | "AAAAAAA", "AAA"] 42 | 43 | input_sequences_embedded = [x for x in model.embed(input_sequences)] 44 | 45 | input_sequences_padded = torch.nn.utils.rnn.pad_sequence(input_sequences_embedded) 46 | 47 | batch_sizes_list = [] 48 | for x in input_sequences: 49 | batch_sizes_list.append(len(x)) 50 | 51 | batch_sizes = torch.IntTensor(batch_sizes_list) 52 | 53 | emmissions, start_transitions, transitions, end_transitions = model(input_sequences_padded) 54 | predicted_labels, predicted_types, predicted_topologies = decode(emmissions, 55 | batch_sizes, 56 | start_transitions, 57 | transitions, 58 | end_transitions) 59 | predicted_labels_2, predicted_types_2, predicted_topologies_2 = \ 60 | decode_numpy(emmissions.detach().numpy(), 61 | batch_sizes.detach().numpy(), 62 | start_transitions.detach().numpy(), 63 | transitions.detach().numpy(), 64 | end_transitions.detach().numpy()) 65 | for idx, val in enumerate(predicted_labels): 66 | assert np.array_equal(val.detach().numpy(), predicted_labels_2[idx]) 67 | assert np.array_equal(predicted_types.detach().numpy(), predicted_types_2) 68 | for idx, val in enumerate(predicted_topologies): 69 | for idx2, val2 in enumerate(val): 70 | assert np.array_equal(val2.detach().numpy(), predicted_topologies_2[idx][idx2]) 71 | 72 | print("Exporting to ONNX...") 73 | 74 | output_path = "./tests/output/tmhmm3.onnx" 75 | 76 | onnx_from_model(model, (input_sequences_padded), output_path) 77 | 78 | print("Wrote ONNX to", output_path) 79 | 80 | predict() 81 | -------------------------------------------------------------------------------- /tests/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /tests/test_onnx_export.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 
5 | """ 6 | 7 | import os 8 | import subprocess 9 | import sys 10 | 11 | from op_cli import main 12 | from preprocessing import process_raw_data 13 | 14 | def test(): 15 | 16 | process_raw_data(False, raw_data_path="tests/data/raw/*", 17 | force_pre_processing_overwrite=True) 18 | 19 | # find original and transformed coordinates 20 | """origcoords = pos.numpy() 21 | origcoords = np.resize(origcoords, (len(origcoords) * 3, 3)) 22 | write_pdb("origcoords.pdb", protein_id_to_str(prim.tolist()), origcoords) 23 | transf = tertiary.numpy() 24 | transf = np.resize(transf, (len(transf) * 3, 3)) 25 | write_pdb("transf.pdb", protein_id_to_str(prim.tolist()), transf) 26 | sup = SVDSuperimposer() 27 | sup.set(transf, origcoords) 28 | sup.run() 29 | # rotation and transformation for the superimposer 30 | #rot, tran = sup.get_rotran() 31 | #print(rot, tran) 32 | rms = sup.get_rms() 33 | print("RMS", rms) 34 | # The segment below finds the structure of the orignal coordinates and the transformed 35 | encoded = prim.tolist() 36 | pos_angles = calculate_dihedral_angles(torch.squeeze(pos), use_gpu) 37 | ter_angles = calculate_dihedral_angles(tertiary, use_gpu) 38 | pos_struc = get_structure_from_angles(encoded, pos_angles) 39 | ter_struc = get_structure_from_angles(encoded, ter_angles) 40 | write_to_pdb(pos_struc, "transformed") 41 | write_to_pdb(ter_struc, "original")""" 42 | 43 | 44 | sys.argv = ["__main__.py", "--min-updates", "1", "--eval-interval", "1", 45 | "--experiment-id", "rrn", "--hide-ui", 46 | "--file", "data/preprocessed/testfile.txt.hdf5"] 47 | main() 48 | 49 | path_to_onnx_file = './tests/output/openprotein.onnx' 50 | if os.path.exists(path_to_onnx_file): 51 | os.remove(path_to_onnx_file) 52 | sub_process = subprocess.Popen(["pipenv", "run", "python", "./tests/onnx_export.py"]) 53 | stdout, stderr = sub_process.communicate() 54 | print(stdout, stderr) 55 | assert sub_process.returncode == 0 56 | assert os.path.exists(path_to_onnx_file) 57 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 
5 | """ 6 | 7 | import time 8 | import numpy as np 9 | import requests 10 | import torch.optim as optim 11 | from util import set_experiment_id, write_out, write_model_to_disk, write_result_summary 12 | 13 | 14 | 15 | def train_model(data_set_identifier, model, train_loader, validation_loader, 16 | learning_rate, minibatch_size=64, eval_interval=50, hide_ui=False, 17 | use_gpu=False, minimum_updates=1000): 18 | set_experiment_id(data_set_identifier, learning_rate, minibatch_size) 19 | 20 | validation_dataset_size = validation_loader.dataset.__len__() 21 | 22 | if use_gpu: 23 | model = model.cuda() 24 | 25 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 26 | 27 | sample_num = list() 28 | train_loss_values = list() 29 | validation_loss_values = list() 30 | 31 | best_model_loss = 1e20 32 | best_model_minibatch_time = None 33 | best_model_path = None 34 | _best_json_data = None 35 | stopping_condition_met = False 36 | minibatches_proccesed = 0 37 | 38 | while not stopping_condition_met: 39 | optimizer.zero_grad() 40 | model.zero_grad() 41 | loss_tracker = np.zeros(0) 42 | for _minibatch_id, training_minibatch in enumerate(train_loader, 0): 43 | minibatches_proccesed += 1 44 | start_compute_loss = time.time() 45 | loss = model.compute_loss(training_minibatch) 46 | write_out("Train loss:", float(loss)) 47 | start_compute_grad = time.time() 48 | loss.backward() 49 | loss_tracker = np.append(loss_tracker, float(loss)) 50 | end = time.time() 51 | write_out("Loss time:", start_compute_grad - start_compute_loss, "Grad time:", 52 | end - start_compute_grad) 53 | optimizer.step() 54 | optimizer.zero_grad() 55 | model.zero_grad() 56 | 57 | # for every eval_interval samples, plot performance on the validation set 58 | if minibatches_proccesed % eval_interval == 0: 59 | 60 | write_out("Testing model on validation set...") 61 | 62 | train_loss = float(loss_tracker.mean()) 63 | loss_tracker = np.zeros(0) 64 | validation_loss, json_data, _ = model.evaluate_model(validation_loader) 65 | 66 | if validation_loss < best_model_loss: 67 | best_model_loss = validation_loss 68 | best_model_minibatch_time = minibatches_proccesed 69 | best_model_path = write_model_to_disk(model) 70 | _best_json_data = json_data 71 | 72 | write_out("Validation loss:", validation_loss, "Train loss:", train_loss) 73 | write_out("Best model so far (validation loss): ", best_model_loss, "at time", 74 | best_model_minibatch_time) 75 | write_out("Best model stored at " + best_model_path) 76 | write_out("Minibatches processed:", minibatches_proccesed) 77 | sample_num.append(minibatches_proccesed) 78 | train_loss_values.append(train_loss) 79 | validation_loss_values.append(validation_loss) 80 | 81 | json_data["validation_dataset_size"] = validation_dataset_size 82 | json_data["sample_num"] = sample_num 83 | json_data["train_loss_values"] = train_loss_values 84 | json_data["validation_loss_values"] = validation_loss_values 85 | 86 | if not hide_ui: 87 | write_out("Updating monitoring service:", str(json_data) 88 | if len(str(json_data)) < 50 else str(json_data)[:50]+"...") 89 | res = requests.post('http://localhost:5000/graph', json=json_data) 90 | if res.ok: 91 | write_out("Received response from monitoring service:", res.json()) 92 | 93 | if minibatches_proccesed > minimum_updates and minibatches_proccesed \ 94 | >= best_model_minibatch_time + minimum_updates: 95 | stopping_condition_met = True 96 | break 97 | write_result_summary(best_model_loss) 98 | # write_result_summary(json.dumps(_best_json_data)) 99 | return 
best_model_path 100 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the OpenProtein project. 3 | 4 | For license information, please see the LICENSE file in the root directory. 5 | """ 6 | 7 | import collections 8 | import math 9 | import os 10 | from datetime import datetime 11 | import torch 12 | import torch.utils.data 13 | import torch.nn.functional as F 14 | import h5py 15 | import PeptideBuilder 16 | import Bio.PDB 17 | from Bio.Data.IUPACData import protein_letters_1to3 18 | import numpy as np 19 | from torch.nn.utils.rnn import pad_sequence 20 | 21 | AA_ID_DICT = {'A': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'I': 8, 'K': 9, 22 | 'L': 10, 'M': 11, 'N': 12, 'P': 13, 'Q': 14, 'R': 15, 'S': 16, 'T': 17, 23 | 'V': 18, 'W': 19, 'Y': 20} 24 | 25 | PI_TENSOR = torch.tensor([3.141592]) 26 | 27 | def contruct_dataloader_from_disk(filename, minibatch_size): 28 | return torch.utils.data.DataLoader(H5PytorchDataset(filename), 29 | batch_size=minibatch_size, 30 | shuffle=True, 31 | collate_fn=merge_samples_to_minibatch) 32 | 33 | 34 | class H5PytorchDataset(torch.utils.data.Dataset): 35 | def __init__(self, filename): 36 | super(H5PytorchDataset, self).__init__() 37 | 38 | self.h5pyfile = h5py.File(filename, 'r') 39 | self.num_proteins, self.max_sequence_len = self.h5pyfile['primary'].shape 40 | 41 | def __getitem__(self, index): 42 | mask = torch.Tensor(self.h5pyfile['mask'][index, :]).type(dtype=torch.bool) 43 | prim = torch.masked_select( 44 | torch.Tensor(self.h5pyfile['primary'][index, :]).type(dtype=torch.int), 45 | mask) 46 | tertiary = torch.Tensor(self.h5pyfile['tertiary'][index][:int(mask.sum())])# max length x 9 47 | return prim, tertiary, mask 48 | 49 | def __len__(self): 50 | return self.num_proteins 51 | 52 | 53 | def merge_samples_to_minibatch(samples): 54 | samples_list = [] 55 | for sample in samples: 56 | samples_list.append(sample) 57 | # sort according to length of aa sequence 58 | samples_list.sort(key=lambda x: len(x[0]), reverse=True) 59 | return zip(*samples_list) 60 | 61 | def set_experiment_id(data_set_identifier, learning_rate, minibatch_size): 62 | output_string = datetime.now().strftime('%Y-%m-%d_%H_%M_%S') 63 | output_string += "-" + str(os.getpid()) 64 | output_string += "-" + data_set_identifier 65 | output_string += "-LR" + str(learning_rate).replace(".", "_") 66 | output_string += "-MB" + str(minibatch_size) 67 | globals().__setitem__("experiment_id", output_string) 68 | 69 | 70 | def get_experiment_id(): 71 | return globals().get("experiment_id") 72 | 73 | 74 | def write_out(*args, end='\n'): 75 | output_string = datetime.now().strftime('%Y-%m-%d %H:%M:%S') \ 76 | + ": " + str.join(" ", [str(a) for a in args]) + end 77 | if globals().get("experiment_id") is not None: 78 | with open("output/" + globals().get("experiment_id") + ".txt", "a+") as output_file: 79 | output_file.write(output_string) 80 | output_file.flush() 81 | print(output_string, end="") 82 | 83 | 84 | def write_model_to_disk(model): 85 | path = "output/models/" + globals().get("experiment_id") + ".model" 86 | torch.save(model, path) 87 | return path 88 | 89 | 90 | def write_prediction_data_to_disk(prediction_data): 91 | filepath = "output/predictions/" + globals().get("experiment_id") + ".txt" 92 | output_file = open(filepath, 'w') 93 | output_file.write(prediction_data) 94 | output_file.close() 95 | 96 | 97 | def draw_plot(fig, 
plt, validation_dataset_size, sample_num, train_loss_values, 98 | validation_loss_values): 99 | def draw_with_vars(): 100 | ax = fig.gca() 101 | ax2 = ax.twinx() 102 | plt.grid(True) 103 | plt.title("Training progress (" + str(validation_dataset_size) 104 | + " samples in validation set)") 105 | train_loss_plot, = ax.plot(sample_num, train_loss_values) 106 | ax.set_ylabel('Train Negative log likelihood') 107 | ax.yaxis.labelpad = 0 108 | validation_loss_plot, = ax2.plot(sample_num, validation_loss_values, color='black') 109 | ax2.set_ylabel('Validation loss') 110 | ax2.set_ylim(bottom=0) 111 | plt.legend([train_loss_plot, validation_loss_plot], 112 | ['Train loss on last batch', 'Validation loss']) 113 | ax.set_xlabel('Minibatches processed (=network updates)', color='black') 114 | 115 | return draw_with_vars 116 | 117 | 118 | def draw_ramachandran_plot(fig, plt, phi, psi): 119 | def draw_with_vars(): 120 | ax = fig.gca() 121 | plt.grid(True) 122 | plt.title("Ramachandran plot") 123 | train_loss_plot, = ax.plot(phi, psi) 124 | ax.set_ylabel('Psi') 125 | ax.yaxis.labelpad = 0 126 | plt.legend([train_loss_plot], 127 | ['Phi psi']) 128 | ax.set_xlabel('Phi', color='black') 129 | 130 | return draw_with_vars 131 | 132 | 133 | def write_result_summary(accuracy): 134 | output_string = globals().get("experiment_id") + ": " + str(accuracy) + "\n" 135 | with open("output/result_summary.txt", "a+") as output_file: 136 | output_file.write(output_string) 137 | output_file.flush() 138 | print(output_string, end="") 139 | 140 | 141 | def calculate_dihedral_angles_over_minibatch(atomic_coords_padded, batch_sizes, use_gpu): 142 | angles = [] 143 | batch_sizes = torch.LongTensor(batch_sizes) 144 | atomic_coords = atomic_coords_padded.transpose(0, 1) 145 | 146 | for idx, coordinate in enumerate(atomic_coords.split(1, dim=0)): 147 | angles_from_coords = torch.index_select( 148 | coordinate.squeeze(0), 149 | 0, 150 | torch.arange(int(batch_sizes[idx].item())) 151 | ) 152 | angles.append(calculate_dihedral_angles(angles_from_coords, use_gpu)) 153 | 154 | return torch.nn.utils.rnn.pad_sequence(angles), batch_sizes 155 | 156 | def protein_id_to_str(protein_id_list): 157 | _aa_dict_inverse = {v: k for k, v in AA_ID_DICT.items()} 158 | aa_list = [] 159 | for protein_id in protein_id_list: 160 | aa_symbol = _aa_dict_inverse[protein_id.item()] 161 | aa_list.append(aa_symbol) 162 | return aa_list 163 | 164 | 165 | def calculate_dihedral_angles(atomic_coords, use_gpu): 166 | #assert atomic_coords.shape[1] == 9 167 | atomic_coords = atomic_coords.contiguous().view(-1, 3) 168 | 169 | zero_tensor = torch.zeros(1) 170 | if use_gpu: 171 | zero_tensor = zero_tensor.cuda() 172 | 173 | 174 | 175 | angles = torch.cat((zero_tensor, 176 | zero_tensor, 177 | compute_dihedral_list(atomic_coords), 178 | zero_tensor)).view(-1, 3) 179 | return angles 180 | 181 | def compute_cross(tensor_a, tensor_b, dim): 182 | 183 | result = [] 184 | 185 | x = torch.zeros(1).long() 186 | y = torch.ones(1).long() 187 | z = torch.ones(1).long() * 2 188 | 189 | ax = torch.index_select(tensor_a, dim, x).squeeze(dim) 190 | ay = torch.index_select(tensor_a, dim, y).squeeze(dim) 191 | az = torch.index_select(tensor_a, dim, z).squeeze(dim) 192 | 193 | bx = torch.index_select(tensor_b, dim, x).squeeze(dim) 194 | by = torch.index_select(tensor_b, dim, y).squeeze(dim) 195 | bz = torch.index_select(tensor_b, dim, z).squeeze(dim) 196 | 197 | result.append(ay * bz - az * by) 198 | result.append(az * bx - ax * bz) 199 | result.append(ax * by - ay * bx) 200 | 201 | 
result = torch.stack(result, dim=dim) 202 | 203 | return result 204 | 205 | 206 | def compute_atan2(y_coord, x_coord): 207 | # TODO: figure out of eps is needed here 208 | eps = 10 ** (-4) 209 | ans = torch.atan(y_coord / (x_coord + eps)) # x > 0 210 | ans = torch.where((y_coord >= 0) & (x_coord < 0), ans + PI_TENSOR, ans) 211 | ans = torch.where((y_coord < 0) & (x_coord < 0), ans - PI_TENSOR, ans) 212 | ans = torch.where((y_coord > 0) & (x_coord == 0), PI_TENSOR / 2, ans) 213 | ans = torch.where((y_coord < 0) & (x_coord == 0), -PI_TENSOR / 2, ans) 214 | return ans 215 | 216 | 217 | def compute_dihedral_list(atomic_coords): 218 | # atomic_coords is -1 x 3 219 | ba = atomic_coords[1:] - atomic_coords[:-1] 220 | ba_normalized = ba / ba.norm(dim=1).unsqueeze(1) 221 | ba_neg = -1 * ba_normalized 222 | 223 | n1_vec = compute_cross(ba_normalized[:-2], ba_neg[1:-1], dim=1) 224 | n2_vec = compute_cross(ba_neg[1:-1], ba_normalized[2:], dim=1) 225 | 226 | n1_vec_normalized = n1_vec / n1_vec.norm(dim=1).unsqueeze(1) 227 | n2_vec_normalized = n2_vec / n2_vec.norm(dim=1).unsqueeze(1) 228 | 229 | m1_vec = compute_cross(n1_vec_normalized, ba_neg[1:-1], dim=1) 230 | 231 | x_value = torch.sum(n1_vec_normalized * n2_vec_normalized, dim=1) 232 | y_value = torch.sum(m1_vec * n2_vec_normalized, dim=1) 233 | return compute_atan2(y_value, x_value) 234 | 235 | 236 | def write_pdb(file_name, aa_sequence, residue_coords): 237 | residue_names = list([protein_letters_1to3[l].upper() for l in aa_sequence]) 238 | num_atoms = len(residue_coords) 239 | backbone_names = num_atoms * ["N", "CA", "C"] 240 | 241 | assert num_atoms == len(aa_sequence) * 3 242 | file = open(file_name, 'w') 243 | 244 | for i in range(num_atoms): 245 | atom_coordinates = list([str(l) for l in np.round(residue_coords[i], 3)]) 246 | residue_position = int(i / 3) 247 | atom_id = str(i + 1) 248 | file.write(f"""\ 249 | ATOM \ 250 | {atom_id.rjust(5)} \ 251 | {backbone_names[i].rjust(4)} \ 252 | {residue_names[residue_position - 1].rjust(3)} \ 253 | A\ 254 | {str(residue_position).rjust(4)} \ 255 | {atom_coordinates[0].rjust(8)}\ 256 | {atom_coordinates[1].rjust(8)}\ 257 | {atom_coordinates[2].rjust(8)}\ 258 | \n""") 259 | file.close() 260 | 261 | def get_structure_from_angles(aa_list_encoded, angles): 262 | aa_list = protein_id_to_str(aa_list_encoded) 263 | omega_list = angles[1:, 0] 264 | phi_list = angles[1:, 1] 265 | psi_list = angles[:-1, 2] 266 | assert len(aa_list) == len(phi_list) + 1 == len(psi_list) + 1 == len(omega_list) + 1 267 | structure = PeptideBuilder.make_structure(aa_list, 268 | list(map(lambda x: math.degrees(x), phi_list)), 269 | list(map(lambda x: math.degrees(x), psi_list)), 270 | list(map(lambda x: math.degrees(x), omega_list))) 271 | return structure 272 | 273 | 274 | def write_to_pdb(structure, prot_id): 275 | out = Bio.PDB.PDBIO() 276 | out.set_structure(structure) 277 | out.save("output/protein_" + str(prot_id) + ".pdb") 278 | 279 | 280 | def calc_pairwise_distances(chain_a, chain_b, use_gpu): 281 | distance_matrix = torch.Tensor(chain_a.size()[0], chain_b.size()[0]).type(torch.float) 282 | # add small epsilon to avoid boundary issues 283 | epsilon = 10 ** (-4) * torch.ones(chain_a.size(0), chain_b.size(0)) 284 | if use_gpu: 285 | distance_matrix = distance_matrix.cuda() 286 | epsilon = epsilon.cuda() 287 | 288 | for idx, row in enumerate(chain_a.split(1)): 289 | distance_matrix[idx] = torch.sum((row.expand_as(chain_b) - chain_b) ** 2, 1).view(1, -1) 290 | 291 | return torch.sqrt(distance_matrix + epsilon) 292 | 293 | 294 | 
def calc_drmsd(chain_a, chain_b, use_gpu=False): 295 | assert len(chain_a) == len(chain_b) 296 | distance_matrix_a = calc_pairwise_distances(chain_a, chain_a, use_gpu) 297 | distance_matrix_b = calc_pairwise_distances(chain_b, chain_b, use_gpu) 298 | return torch.norm(distance_matrix_a - distance_matrix_b, 2) \ 299 | / math.sqrt((len(chain_a) * (len(chain_a) - 1))) 300 | 301 | 302 | # method for translating a point cloud to its center of mass 303 | def transpose_atoms_to_center_of_mass(atoms_matrix): 304 | # calculate com by summing x, y and z respectively 305 | # and dividing by the number of points 306 | center_of_mass = np.matrix([[atoms_matrix[0, :].sum() / atoms_matrix.shape[1]], 307 | [atoms_matrix[1, :].sum() / atoms_matrix.shape[1]], 308 | [atoms_matrix[2, :].sum() / atoms_matrix.shape[1]]]) 309 | # translate points to com and return 310 | return atoms_matrix - center_of_mass 311 | 312 | 313 | def calc_rmsd(chain_a, chain_b): 314 | # move to center of mass 315 | chain_a_value = chain_a.cpu().numpy().transpose() 316 | chain_b_value = chain_b.cpu().numpy().transpose() 317 | X = transpose_atoms_to_center_of_mass(chain_a_value) 318 | Y = transpose_atoms_to_center_of_mass(chain_b_value) 319 | 320 | R = Y * X.transpose() 321 | # extract the singular values 322 | _, S, _ = np.linalg.svd(R) 323 | # compute RMSD using the formular 324 | E0 = sum(list(np.linalg.norm(x) ** 2 for x in X.transpose()) 325 | + list(np.linalg.norm(x) ** 2 for x in Y.transpose())) 326 | TraceS = sum(S) 327 | RMSD = np.sqrt((1 / len(X.transpose())) * (E0 - 2 * TraceS)) 328 | return RMSD 329 | 330 | 331 | def calc_angular_difference(values_1, values_2): 332 | values_1 = values_1.transpose(0, 1).contiguous() 333 | values_2 = values_2.transpose(0, 1).contiguous() 334 | acc = 0 335 | for idx, _ in enumerate(values_1): 336 | assert values_1[idx].shape[1] == 3 337 | assert values_2[idx].shape[1] == 3 338 | a1_element = values_1[idx].view(-1, 1) 339 | a2_element = values_2[idx].view(-1, 1) 340 | acc += torch.sqrt(torch.mean( 341 | torch.min(torch.abs(a2_element - a1_element), 342 | 2 * math.pi - torch.abs(a2_element - a1_element) 343 | ) ** 2)) 344 | return acc / values_1.shape[0] 345 | 346 | 347 | def structures_to_backbone_atoms_padded(structures): 348 | backbone_atoms_list = [] 349 | for structure in structures: 350 | backbone_atoms_list.append(structure_to_backbone_atoms(structure)) 351 | backbone_atoms_padded, batch_sizes_backbone = torch.nn.utils.rnn.pad_packed_sequence( 352 | torch.nn.utils.rnn.pack_sequence(backbone_atoms_list)) 353 | return backbone_atoms_padded, batch_sizes_backbone 354 | 355 | 356 | def structure_to_backbone_atoms(structure): 357 | predicted_coords = [] 358 | for res in structure.get_residues(): 359 | predicted_coords.append(torch.Tensor(res["N"].get_coord())) 360 | predicted_coords.append(torch.Tensor(res["CA"].get_coord())) 361 | predicted_coords.append(torch.Tensor(res["C"].get_coord())) 362 | return torch.stack(predicted_coords).view(-1, 9) 363 | 364 | NUM_FRAGMENTS = torch.tensor(6) 365 | def get_backbone_positions_from_angles(angular_emissions, batch_sizes, use_gpu): 366 | # angular_emissions -1 x minibatch size x 3 (omega, phi, psi) 367 | points = dihedral_to_point(angular_emissions, use_gpu) 368 | coordinates = point_to_coordinate( 369 | points, 370 | use_gpu, 371 | num_fragments=NUM_FRAGMENTS) / 100 # divide by 100 to angstrom unit 372 | return coordinates.transpose(0, 1).contiguous()\ 373 | .view(len(batch_sizes), -1, 9).transpose(0, 1), batch_sizes 374 | 375 | 376 | def 
calc_avg_drmsd_over_minibatch(backbone_atoms_padded, actual_coords_padded, batch_sizes): 377 | backbone_atoms_list = list( 378 | [backbone_atoms_padded[:batch_sizes[i], i] for i in range(int(backbone_atoms_padded 379 | .size(1)))]) 380 | actual_coords_list = list( 381 | [actual_coords_padded[:batch_sizes[i], i] for i in range(int(actual_coords_padded 382 | .size(1)))]) 383 | drmsd_avg = 0 384 | for idx, backbone_atoms in enumerate(backbone_atoms_list): 385 | actual_coords = actual_coords_list[idx].transpose(0, 1).contiguous().view(-1, 3) 386 | drmsd_avg += calc_drmsd(backbone_atoms.transpose(0, 1).contiguous().view(-1, 3), 387 | actual_coords) 388 | return drmsd_avg / len(backbone_atoms_list) 389 | 390 | 391 | def encode_primary_string(primary): 392 | return list([AA_ID_DICT[aa] for aa in primary]) 393 | 394 | 395 | def initial_pos_from_aa_string(batch_aa_string, use_gpu): 396 | arr_of_angles = [] 397 | batch_sizes = [] 398 | for aa_string in batch_aa_string: 399 | length_of_protein = aa_string.size(0) 400 | angles = torch.stack([-120*torch.ones(length_of_protein), 401 | 140*torch.ones(length_of_protein), 402 | -370*torch.ones(length_of_protein)]).transpose(0, 1) 403 | arr_of_angles.append(angles) 404 | batch_sizes.append(length_of_protein) 405 | 406 | padded = pad_sequence(arr_of_angles).transpose(0, 1) 407 | return get_backbone_positions_from_angles(padded, batch_sizes, use_gpu) 408 | 409 | def pass_messages(aa_features, message_transformation, use_gpu): 410 | # aa_features (#aa, #features) - each row represents the amino acid type 411 | # (embedding) and the positions of the backbone atoms 412 | # message_transformation: (-1 * 2 * feature_size) -> (-1 * output message size) 413 | feature_size = aa_features.size(1) 414 | aa_count = aa_features.size(0) 415 | 416 | arange2d = torch.arange(aa_count).repeat(aa_count).view((aa_count, aa_count)) 417 | 418 | diagonal_matrix = (arange2d == arange2d.transpose(0, 1)).int() 419 | 420 | eye = diagonal_matrix.view(-1).expand(2, feature_size, -1)\ 421 | .transpose(1, 2).transpose(0, 1) 422 | 423 | eye_inverted = torch.ones(eye.size(), dtype=torch.uint8) - eye 424 | if use_gpu: 425 | eye_inverted = eye_inverted.cuda() 426 | features_repeated = aa_features.repeat((aa_count, 1)).view((aa_count, aa_count, feature_size)) 427 | # (aa_count^2 - aa_count) x 2 x aa_features (all pairs except for reflexive connections) 428 | aa_messages = torch.stack((features_repeated.transpose(0, 1), features_repeated))\ 429 | .transpose(0, 1).transpose(1, 2).view(-1, 2, feature_size) 430 | 431 | eye_inverted_location = eye_inverted.view(-1).nonzero().squeeze(1) 432 | 433 | aa_msg_pairs = aa_messages\ 434 | .reshape(-1).gather(0, eye_inverted_location).view(-1, 2, feature_size) 435 | 436 | transformed = message_transformation(aa_msg_pairs).view(aa_count, aa_count - 1, -1) 437 | transformed_sum = transformed.sum(dim=1) # aa_count x output message size 438 | return transformed_sum 439 | 440 | 441 | def load_model_from_disk(path, force_cpu=True): 442 | if force_cpu: 443 | # load model with map_location set to storage (main mem) 444 | model = torch.load(path, map_location=lambda storage, loc: storage) 445 | # flattern parameters in memory 446 | model.flatten_parameters() 447 | # update internal state accordingly 448 | model.use_gpu = False 449 | else: 450 | # load model using default map_location 451 | model = torch.load(path) 452 | model.flatten_parameters() 453 | return model 454 | 455 | # Constants 456 | NUM_DIMENSIONS = 3 457 | NUM_DIHEDRALS = 3 458 | BOND_LENGTHS = 
455 | # Constants
456 | NUM_DIMENSIONS = 3
457 | NUM_DIHEDRALS = 3
458 | BOND_LENGTHS = torch.tensor([145.801, 152.326, 132.868], dtype=torch.float32)
459 | BOND_ANGLES = torch.tensor([2.124, 1.941, 2.028], dtype=torch.float32)
460 | 
461 | 
462 | def dihedral_to_point(dihedral, use_gpu, bond_lengths=BOND_LENGTHS,
463 |                       bond_angles=BOND_ANGLES):
464 |     """
465 |     Takes triplets of dihedral angles (phi, psi, omega) and returns 3D points
466 |     ready for use in reconstruction of coordinates. Bond lengths and angles
467 |     are based on idealized averages.
468 |     :param dihedral: [NUM_STEPS, BATCH_SIZE, NUM_DIHEDRALS]
469 |     :return: Tensor containing points of the protein's backbone atoms.
470 |     Shape [NUM_STEPS x NUM_DIHEDRALS, BATCH_SIZE, NUM_DIMENSIONS]
471 |     """
472 |     num_steps = dihedral.shape[0]
473 |     batch_size = dihedral.shape[1]
474 | 
475 |     r_cos_theta = bond_lengths * torch.cos(PI_TENSOR - bond_angles)
476 |     r_sin_theta = bond_lengths * torch.sin(PI_TENSOR - bond_angles)
477 | 
478 |     if use_gpu:
479 |         r_cos_theta = r_cos_theta.cuda()
480 |         r_sin_theta = r_sin_theta.cuda()
481 | 
482 |     point_x = r_cos_theta.view(1, 1, -1).repeat(num_steps, batch_size, 1)
483 |     point_y = torch.cos(dihedral) * r_sin_theta
484 |     point_z = torch.sin(dihedral) * r_sin_theta
485 | 
486 |     point = torch.stack([point_x, point_y, point_z])
487 |     point_perm = point.permute(1, 3, 2, 0)
488 |     point_final = point_perm.contiguous().view(num_steps * NUM_DIHEDRALS,
489 |                                                 batch_size,
490 |                                                 NUM_DIMENSIONS)
491 |     return point_final
492 | 
493 | PNERF_INIT_MATRIX = [torch.tensor([-torch.sqrt(torch.tensor([1.0 / 2.0])),
494 |                                    torch.sqrt(torch.tensor([3.0 / 2.0])), 0]),
495 |                      torch.tensor([-torch.sqrt(torch.tensor([2.0])), 0, 0]),
496 |                      torch.tensor([0, 0, 0])]
497 | 
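# Shape sketch for dihedral_to_point (hypothetical sizes, matching the docstring above):
#
#     dihedrals = torch.zeros(5, 2, NUM_DIHEDRALS)  # [NUM_STEPS, BATCH_SIZE, NUM_DIHEDRALS]
#     pts = dihedral_to_point(dihedrals, use_gpu=False)
#     assert pts.shape == (5 * NUM_DIHEDRALS, 2, NUM_DIMENSIONS)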
498 | def point_to_coordinate(points, use_gpu, num_fragments):
499 |     """
500 |     Takes points from dihedral_to_point and sequentially converts them into
501 |     coordinates of a 3D structure.
502 | 
503 |     Reconstruction is done in parallel by independently reconstructing
504 |     num_fragments fragments and then reconstituting the chain at the end in reverse order.
505 |     The core reconstruction algorithm is NeRF, based on
506 |     DOI: 10.1002/jcc.20237 by Parsons et al. 2005.
507 |     The parallelized version is described in
508 |     https://www.biorxiv.org/content/early/2018/08/06/385450.
509 |     :param points: Tensor containing points as returned by `dihedral_to_point`.
510 |     Shape [NUM_STEPS x NUM_DIHEDRALS, BATCH_SIZE, NUM_DIMENSIONS]
511 |     :param num_fragments: Number of fragments in which the sequence is split
512 |     to perform parallel computation.
513 |     :return: Tensor containing correctly transformed atom coordinates.
514 |     Shape [NUM_STEPS x NUM_DIHEDRALS, BATCH_SIZE, NUM_DIMENSIONS]
515 |     """
516 | 
517 |     # Total number of angles to place (NUM_STEPS x NUM_DIHEDRALS); ensure it is a tensor
518 |     total_num_angles = points.size(0)
519 |     if isinstance(total_num_angles, int):
520 |         total_num_angles = torch.tensor(total_num_angles)
521 | 
522 |     # Initial three coordinates (specifically chosen to eliminate need for
523 |     # extraneous matmul)
524 |     Triplet = collections.namedtuple('Triplet', 'a, b, c')
525 |     batch_size = points.shape[1]
526 | 
527 |     init_coords = []
528 |     for row in PNERF_INIT_MATRIX:
529 |         row_tensor = row\
530 |             .repeat([num_fragments * batch_size, 1])\
531 |             .view(num_fragments, batch_size, NUM_DIMENSIONS)
532 |         if use_gpu:
533 |             row_tensor = row_tensor.cuda()
534 |         init_coords.append(row_tensor)
535 | 
536 |     init_coords = Triplet(*init_coords)  # NUM_DIHEDRALS x [NUM_FRAGS, BATCH_SIZE, NUM_DIMENSIONS]
537 | 
538 |     # Pad points to yield equal-sized fragments
539 |     # (NUM_FRAGS x FRAG_SIZE) - (NUM_STEPS x NUM_DIHEDRALS)
540 |     padding = torch.fmod(num_fragments - (total_num_angles % num_fragments), num_fragments)
541 | 
542 |     # [NUM_FRAGS x FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS]
543 |     padding_tensor = torch.zeros((padding, points.size(1), points.size(2)))
544 |     points = torch.cat((points, padding_tensor))
545 | 
546 |     points = points.view(num_fragments, -1, batch_size,
547 |                          NUM_DIMENSIONS)  # [NUM_FRAGS, FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS]
548 |     points = points.permute(1, 0, 2, 3)  # [FRAG_SIZE, NUM_FRAGS, BATCH_SIZE, NUM_DIMENSIONS]
549 | 
550 |     # Extension function used for single atom reconstruction and whole fragment
551 |     # alignment
552 |     def extend(prev_three_coords, point, multi_m):
553 |         """
554 |         Aligns an atom or an entire fragment depending on value of `multi_m`
555 |         with the preceding three atoms.
556 |         :param prev_three_coords: Named tuple storing the last three atom
557 |         coordinates ("a", "b", "c") where "c" is the current end of the
558 |         structure (i.e. closest to the atom/fragment that will be added now).
559 |         Shape NUM_DIHEDRALS x [NUM_FRAGS/0, BATCH_SIZE, NUM_DIMENSIONS].
560 |         First rank depends on value of `multi_m`.
561 |         :param point: Point describing the atom that is added to the structure.
562 |         Shape [NUM_FRAGS/FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS]
563 |         First rank depends on value of `multi_m`.
564 |         :param multi_m: If True, a single atom is added to the chain for
565 |         multiple fragments in parallel. If False, a single fragment is added.
566 |         Note the different parameter dimensions.
567 |         :return: Coordinates of the atom/fragment.
568 |         """
569 |         bc = F.normalize(prev_three_coords.c - prev_three_coords.b, dim=-1)
570 |         n = F.normalize(compute_cross(prev_three_coords.b - prev_three_coords.a,
571 |                                       bc, dim=2 if multi_m else 1), dim=-1)
572 |         if multi_m:  # multiple fragments, one atom at a time
573 |             m = torch.stack([bc, compute_cross(n, bc, dim=2), n]).permute(1, 2, 3, 0)
574 |         else:  # single fragment, reconstructed entirely at once.
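            # A single 3x3 frame per batch element is built from the last three
            # aligned atoms and repeated across every row of `point`, so the entire
            # passed-in segment is rotated and translated onto the chain in one matmul.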
575 |             s = point.shape + (3,)
576 |             m = torch.stack([bc, compute_cross(n, bc, dim=1), n]).permute(1, 2, 0)
577 |             m = m.repeat(s[0], 1, 1).view(s)
578 |         coord = torch.squeeze(torch.matmul(m, point.unsqueeze(3)),
579 |                               dim=3) + prev_three_coords.c
580 |         return coord
581 | 
582 |     # Loop over FRAG_SIZE in NUM_FRAGS parallel fragments, sequentially
583 |     # generating the coordinates for each fragment across all batches
584 |     coords_list = []
585 |     prev_three_coords = init_coords
586 | 
587 |     for point in points.split(1, dim=0):  # Iterate over FRAG_SIZE
588 |         coord = extend(prev_three_coords, point.squeeze(0), True)
589 |         coords_list.append(coord)
590 |         prev_three_coords = Triplet(prev_three_coords.b,
591 |                                     prev_three_coords.c,
592 |                                     coord)
593 | 
594 |     coords_pretrans = torch.stack(coords_list).permute(1, 0, 2, 3)
595 | 
596 |     # Loop backwards over NUM_FRAGS to align the individual fragments. For each
597 |     # next fragment, we transform the fragments we have already iterated over
598 |     # (coords_trans) to be aligned with the next fragment
599 |     coords_trans = coords_pretrans[-1]
600 |     for idx in torch.arange(end=-1, start=coords_pretrans.shape[0] - 2, step=-1).split(1, dim=0):
601 |         # Transform the fragments that we have already iterated over to be
602 |         # aligned with the next fragment `coords_trans`
603 |         transformed_coords = extend(Triplet(*[di.index_select(0, idx).squeeze(0)
604 |                                               for di in prev_three_coords]),
605 |                                     coords_trans, False)
606 |         coords_trans = torch.cat(
607 |             [coords_pretrans.index_select(0, idx).squeeze(0), transformed_coords], 0)
608 | 
609 |     coords_to_pad = torch.index_select(coords_trans, 0, torch.arange(total_num_angles - 1))
610 | 
611 |     coords = F.pad(coords_to_pad, (0, 0, 0, 0, 1, 0))
612 | 
613 |     return coords
614 | 
--------------------------------------------------------------------------------
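Taken together, dihedral_to_point and point_to_coordinate form the angle-to-structure pipeline used by get_backbone_positions_from_angles. A minimal sketch, with hypothetical sizes and mirroring the NUM_FRAGMENTS call in that function; the output shape follows the docstrings:

    points = dihedral_to_point(torch.zeros(10, 1, 3), use_gpu=False)  # (30, 1, 3)
    coords = point_to_coordinate(points, use_gpu=False,
                                 num_fragments=torch.tensor(6))
    # coords: [NUM_STEPS x NUM_DIHEDRALS, BATCH_SIZE, NUM_DIMENSIONS] == (30, 1, 3)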