├── .gitignore
├── CODEOWNERS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING-ARCHIVED.md
├── LICENSE
├── README.md
├── SECURITY.md
├── attribution.txt
├── data
│   ├── binding_sites
│   │   ├── binding_site_train.lmdb
│   │   │   ├── data.mdb
│   │   │   └── lock.mdb
│   │   └── binding_site_valid.lmdb
│   │       ├── data.mdb
│   │       └── lock.mdb
│   └── protein_modifications
│       ├── protein_modification_train.lmdb
│       │   ├── data.mdb
│       │   └── lock.mdb
│       └── protein_modification_valid.lmdb
│           ├── data.mdb
│           └── lock.mdb
├── images
│   ├── vis3d_binding_sites.png
│   └── vis3d_contact_map.png
├── notebooks
│   └── provis.ipynb
├── protein_attention
│   ├── __init__.py
│   ├── attention_analysis
│   │   ├── __init__.py
│   │   ├── background.py
│   │   ├── compute_edge_features.py
│   │   ├── features.py
│   │   ├── report_aa_correlations.py
│   │   ├── report_edge_features.py
│   │   ├── report_edge_features_combined.py
│   │   ├── report_top_heads.py
│   │   └── scripts
│   │       ├── compute_all_features_prot_albert.sh
│   │       ├── compute_all_features_prot_bert.sh
│   │       ├── compute_all_features_prot_bert_bfd.sh
│   │       ├── compute_all_features_prot_xlnet.sh
│   │       ├── compute_all_features_tape_bert.sh
│   │       ├── report_all_features_prot_albert.sh
│   │       ├── report_all_features_prot_bert.sh
│   │       ├── report_all_features_prot_bert_bfd.sh
│   │       ├── report_all_features_prot_xlnet.sh
│   │       └── report_all_features_tape_bert.sh
│   ├── datasets.py
│   ├── probing
│   │   ├── __init__.py
│   │   ├── metrics.py
│   │   ├── models.py
│   │   ├── probe.py
│   │   ├── report.py
│   │   └── scripts
│   │       ├── probe_contact.sh
│   │       ├── probe_contact_attention.sh
│   │       ├── probe_sites.sh
│   │       ├── probe_ss4_0.sh
│   │       ├── probe_ss4_1.sh
│   │       └── probe_ss4_2.sh
│   └── utils.py
├── reports
│   ├── attention_analysis
│   │   ├── blosum
│   │   │   ├── edge_features_aa_prot_bert
│   │   │   │   ├── aa_corr_to.pdf
│   │   │   │   ├── args.json
│   │   │   │   └── blosum62.pdf
│   │   │   ├── edge_features_aa_prot_bert_bfd
│   │   │   │   ├── aa_corr_to.pdf
│   │   │   │   ├── args.json
│   │   │   │   └── blosum62.pdf
│   │   │   ├── edge_features_aa_prot_xlnet
│   │   │   │   ├── aa_corr_to.pdf
│   │   │   │   ├── args.json
│   │   │   │   └── blosum62.pdf
│   │   │   └── edge_features_aa_tape_bert
│   │   │       ├── aa_corr_to.pdf
│   │   │       ├── args.json
│   │   │       └── blosum62.pdf
│   │   ├── edge_features_aa_prot_albert
│   │   │   ├── aa_to_A.pdf
│   │   │   ├── aa_to_C.pdf
│   │   │   ├── aa_to_D.pdf
│   │   │   ├── aa_to_E.pdf
│   │   │   ├── aa_to_F.pdf
│   │   │   ├── aa_to_G.pdf
│   │   │   ├── aa_to_H.pdf
│   │   │   ├── aa_to_I.pdf
│   │   │   ├── aa_to_K.pdf
│   │   │   ├── aa_to_L.pdf
│   │   │   ├── aa_to_M.pdf
│   │   │   ├── aa_to_N.pdf
│   │   │   ├── aa_to_P.pdf
│   │   │   ├── aa_to_Q.pdf
│   │   │   ├── aa_to_R.pdf
│   │   │   ├── aa_to_S.pdf
│   │   │   ├── aa_to_T.pdf
│   │   │   ├── aa_to_V.pdf
│   │   │   ├── aa_to_W.pdf
│   │   │   ├── aa_to_Y.pdf
│   │   │   └── args.json
│   │   ├── edge_features_aa_prot_bert
│   │   │   ├── aa_to_A.pdf
│   │   │   ├── aa_to_C.pdf
│   │   │   ├── aa_to_D.pdf
│   │   │   ├── aa_to_E.pdf
│   │   │   ├── aa_to_F.pdf
│   │   │   ├── aa_to_G.pdf
│   │   │   ├── aa_to_H.pdf
│   │   │   ├── aa_to_I.pdf
│   │   │   ├── aa_to_K.pdf
│   │   │   ├── aa_to_L.pdf
│   │   │   ├── aa_to_M.pdf
│   │   │   ├── aa_to_N.pdf
│   │   │   ├── aa_to_P.pdf
│   │   │   ├── aa_to_Q.pdf
│   │   │   ├── aa_to_R.pdf
│   │   │   ├── aa_to_S.pdf
│   │   │   ├── aa_to_T.pdf
│   │   │   ├── aa_to_V.pdf
│   │   │   ├── aa_to_W.pdf
│   │   │   ├── aa_to_Y.pdf
│   │   │   └── args.json
│   │   ├── edge_features_aa_prot_bert_bfd
│   │   │   ├── aa_to_A.pdf
│   │   │   ├── aa_to_C.pdf
│   │   │   ├── aa_to_D.pdf
│   │   │   ├── aa_to_E.pdf
│   │   │   ├── aa_to_F.pdf
│   │   │   ├── aa_to_G.pdf
│   │   │   ├── aa_to_H.pdf
│   │   │   ├── aa_to_I.pdf
│   │   │   ├── aa_to_K.pdf
│   │   │   ├── aa_to_L.pdf
│   │   │   ├── aa_to_M.pdf
│   │   │   ├── aa_to_N.pdf
│   │   │   ├── aa_to_P.pdf
│   │   │   ├── aa_to_Q.pdf
│   │   │   ├── aa_to_R.pdf
│   │   │   ├── aa_to_S.pdf
│   │   │   ├── aa_to_T.pdf
│   │   │   ├── aa_to_V.pdf
│   │   │   ├── aa_to_W.pdf
│   │   │   ├── aa_to_Y.pdf
│   │   │   └── args.json
│   │   ├── edge_features_aa_prot_xlnet
│   │   │   ├── aa_to_A.pdf
│   │   │   ├── aa_to_C.pdf
│   │   │   ├── aa_to_D.pdf
│   │   │   ├── aa_to_E.pdf
│   │   │   ├── aa_to_F.pdf
│   │   │   ├── aa_to_G.pdf
│   │   │   ├── aa_to_H.pdf
│   │   │   ├── aa_to_I.pdf
│   │   │   ├── aa_to_K.pdf
│   │   │   ├── aa_to_L.pdf
│   │   │   ├── aa_to_M.pdf
│   │   │   ├── aa_to_N.pdf
│   │   │   ├── aa_to_P.pdf
│   │   │   ├── aa_to_Q.pdf
│   │   │   ├── aa_to_R.pdf
│   │   │   ├── aa_to_S.pdf
│   │   │   ├── aa_to_T.pdf
│   │   │   ├── aa_to_V.pdf
│   │   │   ├── aa_to_W.pdf
│   │   │   ├── aa_to_Y.pdf
│   │   │   └── args.json
│   │   ├── edge_features_aa_tape_bert
│   │   │   ├── aa_to_A.pdf
│   │   │   ├── aa_to_C.pdf
│   │   │   ├── aa_to_D.pdf
│   │   │   ├── aa_to_E.pdf
│   │   │   ├── aa_to_F.pdf
│   │   │   ├── aa_to_G.pdf
│   │   │   ├── aa_to_H.pdf
│   │   │   ├── aa_to_I.pdf
│   │   │   ├── aa_to_K.pdf
│   │   │   ├── aa_to_L.pdf
│   │   │   ├── aa_to_M.pdf
│   │   │   ├── aa_to_N.pdf
│   │   │   ├── aa_to_P.pdf
│   │   │   ├── aa_to_Q.pdf
│   │   │   ├── aa_to_R.pdf
│   │   │   ├── aa_to_S.pdf
│   │   │   ├── aa_to_T.pdf
│   │   │   ├── aa_to_V.pdf
│   │   │   ├── aa_to_W.pdf
│   │   │   ├── aa_to_Y.pdf
│   │   │   └── args.json
│   │   ├── edge_features_combined_tape_bert
│   │   │   ├── args_edge_features_contact_tape_bert.json
│   │   │   ├── args_edge_features_sec_tape_bert.json
│   │   │   ├── args_edge_features_sites_tape_bert.json
│   │   │   └── combined_features.pdf
│   │   ├── edge_features_contact_prot_albert
│   │   │   ├── args.json
│   │   │   └── contact_map.pdf
│   │   ├── edge_features_contact_prot_albert_topheads
│   │   │   ├── args.json
│   │   │   └── contact_map.pdf
│   │   ├── edge_features_contact_prot_bert
│   │   │   ├── args.json
│   │   │   └── contact_map.pdf
│   │   ├── edge_features_contact_prot_bert_bfd
│   │   │   ├── args.json
│   │   │   └── contact_map.pdf
│   │   ├── edge_features_contact_prot_bert_bfd_topheads
│   │   │   ├── args.json
│   │   │   └── contact_map.pdf
│   │   ├── edge_features_contact_prot_bert_topheads
│   │   │   ├── args.json
│   │   │   └── contact_map.pdf
│   │   ├── edge_features_contact_prot_xlnet
│   │   │   ├── args.json
│   │   │   └── contact_map.pdf
│   │   ├── edge_features_contact_prot_xlnet_topheads
│   │   │   ├── args.json
│   │   │   └── contact_map.pdf
│   │   ├── edge_features_contact_tape_bert
│   │   │   ├── args.json
│   │   │   └── contact_map.pdf
│   │   ├── edge_features_contact_tape_bert_topheads
│   │   │   ├── args.json
│   │   │   └── contact_map.pdf
│   │   ├── edge_features_modifications_prot_albert
│   │   │   ├── args.json
│   │   │   └── protein_modification_to.pdf
│   │   ├── edge_features_modifications_prot_albert_topheads
│   │   │   ├── args.json
│   │   │   └── protein_modification_to.pdf
│   │   ├── edge_features_modifications_prot_bert
│   │   │   ├── args.json
│   │   │   └── protein_modification_to.pdf
│   │   ├── edge_features_modifications_prot_bert_bfd
│   │   │   ├── args.json
│   │   │   └── protein_modification_to.pdf
│   │   ├── edge_features_modifications_prot_bert_bfd_topheads
│   │   │   ├── args.json
│   │   │   └── protein_modification_to.pdf
│   │   ├── edge_features_modifications_prot_bert_topheads
│   │   │   ├── args.json
│   │   │   └── protein_modification_to.pdf
│   │   ├── edge_features_modifications_prot_xlnet
│   │   │   ├── args.json
│   │   │   └── protein_modification_to.pdf
│   │   ├── edge_features_modifications_prot_xlnet_topheads
│   │   │   ├── args.json
│   │   │   └── protein_modification_to.pdf
│   │   ├── edge_features_modifications_tape_bert
│   │   │   ├── args.json
│   │   │   └── protein_modification_to.pdf
│   │   ├── edge_features_modifications_tape_bert_topheads
│   │   │   ├── args.json
│   │   │   └── protein_modification_to.pdf
│   │   ├── edge_features_sec_prot_albert
│   │   │   ├── args.json
│   │   │   ├── sec_struct_to_0.pdf
│   │   │   ├── sec_struct_to_1.pdf
│   │   │   ├── sec_struct_to_2.pdf
│   │   │   └── sec_struct_to_3.pdf
│   │   ├── edge_features_sec_prot_bert
│   │   │   ├── args.json
│   │   │   ├── sec_struct_to_0.pdf
│   │   │   ├── sec_struct_to_1.pdf
│   │   │   ├── sec_struct_to_2.pdf
│   │   │   └── sec_struct_to_3.pdf
│   │   ├── edge_features_sec_prot_bert_bfd
│   │   │   ├── args.json
│   │   │   ├── sec_struct_to_0.pdf
│   │   │   ├── sec_struct_to_1.pdf
│   │   │   ├── sec_struct_to_2.pdf
│   │   │   └── sec_struct_to_3.pdf
│   │   ├── edge_features_sec_prot_xlnet
│   │   │   ├── args.json
│   │   │   ├── sec_struct_to_0.pdf
│   │   │   ├── sec_struct_to_1.pdf
│   │   │   ├── sec_struct_to_2.pdf
│   │   │   └── sec_struct_to_3.pdf
│   │   ├── edge_features_sec_tape_bert
│   │   │   ├── args.json
│   │   │   ├── sec_struct_to_0.pdf
│   │   │   ├── sec_struct_to_1.pdf
│   │   │   ├── sec_struct_to_2.pdf
│   │   │   └── sec_struct_to_3.pdf
│   │   ├── edge_features_sites_prot_albert
│   │   │   ├── args.json
│   │   │   └── binding_site_to.pdf
│   │   ├── edge_features_sites_prot_albert_topheads
│   │   │   ├── args.json
│   │   │   └── binding_site_to.pdf
│   │   ├── edge_features_sites_prot_bert
│   │   │   ├── args.json
│   │   │   └── binding_site_to.pdf
│   │   ├── edge_features_sites_prot_bert_bfd
│   │   │   ├── args.json
│   │   │   └── binding_site_to.pdf
│   │   ├── edge_features_sites_prot_bert_bfd_topheads
│   │   │   ├── args.json
│   │   │   └── binding_site_to.pdf
│   │   ├── edge_features_sites_prot_bert_topheads
│   │   │   ├── args.json
│   │   │   └── binding_site_to.pdf
│   │   ├── edge_features_sites_prot_xlnet
│   │   │   ├── args.json
│   │   │   └── binding_site_to.pdf
│   │   ├── edge_features_sites_prot_xlnet_topheads
│   │   │   ├── args.json
│   │   │   └── binding_site_to.pdf
│   │   ├── edge_features_sites_tape_bert
│   │   │   ├── args.json
│   │   │   └── binding_site_to.pdf
│   │   └── edge_features_sites_tape_bert_topheads
│   │       ├── args.json
│   │       └── binding_site_to.pdf
│   └── probing
│       └── multichart_layer_probing.pdf
├── requirements.txt
└── setup.py
/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # pytype static type analyzer 138 | .pytype/ 139 | 140 | # Cython debug symbols 141 | cython_debug/ -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. 2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics. 
32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 
100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ 106 | -------------------------------------------------------------------------------- /CONTRIBUTING-ARCHIVED.md: -------------------------------------------------------------------------------- 1 | # ARCHIVED 2 | 3 | This project is `Archived` and is no longer actively maintained; 4 | We are not accepting contributions or Pull Requests. 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, Salesforce.com, Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE 13 | 14 | See attribution.txt for additional licensing information. 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BERTology Meets Biology: Interpreting Attention in Protein Language Models 2 | 3 | This repository is the official implementation of [BERTology Meets Biology: Interpreting Attention in Protein Language Models](https://arxiv.org/abs/2006.15222). 
4 | 5 | ## Table of Contents 6 | 7 | - [ProVis Attention Visualizer](#provis-attention-visualizer) 8 | * [Installation](#installation) 9 | * [Execution](#execution) 10 | - [Experiments](#experiments) 11 | * [Installation](#installation-1) 12 | * [Datasets](#datasets) 13 | * [Attention Analysis](#attention-analysis) 14 | + [Tape BERT Model](#tape-bert-model) 15 | + [ProtTrans Models](#prottrans-models) 16 | * [Probing Analysis](#probing-analysis) 17 | + [Training](#training) 18 | + [Reports](#reports) 19 | - [License](#license) 20 | - [Acknowledgments](#acknowledgments) 21 | - [Citation](#citation) 22 | 23 | ## ProVis Attention Visualizer 24 | 25 | This section provides instructions for generating visualizations of attention projected onto 3D protein structure. 26 | 27 | ![Image](images/vis3d_binding_sites.png?raw=true) ![Image](images/vis3d_contact_map.png?raw=true) 28 | 29 | ### Installation 30 | **General requirements**: 31 | * Python >= 3.7 32 | 33 | ``` 34 | pip install biopython==1.77 35 | pip install tape-proteins==0.5 36 | pip install jupyterlab==3.0.14 37 | pip install nglview 38 | jupyter-nbextension enable nglview --py --sys-prefix 39 | ``` 40 | 41 | If you run into problems installing nglview, please refer to their 42 | [installation instructions](https://github.com/arose/nglview#released-version) for additional installation details 43 | and options. 44 | 45 | 46 | ### Execution 47 | 48 | ``` 49 | cd <project_root>/notebooks 50 | jupyter notebook provis.ipynb 51 | ``` 52 | 53 | If you get an error running the notebook, you may need to execute the notebook as follows: 54 | 55 | ``` 56 | jupyter notebook --NotebookApp.iopub_data_rate_limit=10000000 57 | ``` 58 | See nglview [installation instructions](https://github.com/arose/nglview#released-version) for more details. 59 | 60 | You may edit the notebook to choose other proteins, attention heads, etc. The visualization tool is based on the 61 | excellent [nglview](https://github.com/arose/nglview) library. 62 | 63 | --- 64 | 65 | ## Experiments 66 | 67 | This section describes how to reproduce the experiments in the paper. 68 | 69 | ### Installation 70 | 71 | ```setup 72 | cd <project_root> 73 | python setup.py develop 74 | ``` 75 | 76 | To download additional required datasets from [TAPE](https://github.com/songlab-cal/tape): 77 | 78 | ```setup 79 | cd <project_root>/data 80 | wget http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/secondary_structure.tar.gz 81 | tar -xvf secondary_structure.tar.gz && rm secondary_structure.tar.gz 82 | wget http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/proteinnet.tar.gz 83 | tar -xvf proteinnet.tar.gz && rm proteinnet.tar.gz 84 | ``` 85 | 86 | ### Attention Analysis 87 | 88 | The following steps will reproduce the attention analysis experiments and generate the reports currently found in 89 | <project_root>/reports/attention_analysis. This includes all experiments besides the probing experiments 90 | (see [Probing Analysis](#probing-analysis)). 91 | 92 | Before performing these steps, navigate to the appropriate directory: 93 | ``` 94 | cd <project_root>/protein_attention/attention_analysis 95 | ``` 96 | 97 | #### Tape BERT Model 98 | 99 | The following executes the attention analysis (may run for several hours): 100 | ``` 101 | sh scripts/compute_all_features_tape_bert.sh 102 | ``` 103 | The above script creates a set of extract files in <project_root>/data/cache corresponding to the various properties 104 | being analyzed. You may edit the script files to remove properties that you are not interested in. If you wish to run the 105 | analysis without a GPU, you must specify the `--no_cuda` flag.
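For orientation, the analysis itself is implemented in `compute_edge_features.py` (included under protein_attention/attention_analysis in this repository), which the compute scripts presumably invoke once per feature/dataset combination. A hypothetical invocation, sketched from that script's own argument parser — the experiment name and parameter values below are illustrative, not necessarily those used in the `.sh` scripts:

```
python compute_edge_features.py \
    --exp-name edge_features_contact_tape_bert \
    --dataset proteinnet \
    --features contact_map \
    --num-sequences 5000 \
    --max-seq-len 512 \
    --min-attn 0.1 \
    --shuffle
```

The `--exp-name` value determines the name of the extract file written to the cache directory; the reporting scripts use the same name to locate it.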
106 | 107 | The following generates reports based on the files created in the previous step: 108 | ``` 109 | sh scripts/report_all_features_tape_bert.sh 110 | ``` 111 | If you removed steps from the analysis script above, you will need to update the reporting script accordingly. 112 | 113 | 114 | #### ProtTrans Models 115 | 116 | To generate reports for the ProtTrans models, follow the same instructions as for the TAPE BERT 117 | model above, but substitute the following commands:
118 | 119 | **ProtBert:**
120 | ``` 121 | sh scripts/compute_all_features_prot_bert.sh 122 | sh scripts/report_all_features_prot_bert.sh 123 | ``` 124 | 125 | **ProtBertBFD:**
126 | ``` 127 | sh scripts/compute_all_features_prot_bert_bfd.sh 128 | sh scripts/report_all_features_prot_bert_bfd.sh 129 | ``` 130 | 131 | **ProtAlbert:**
132 | ``` 133 | sh scripts/compute_all_features_prot_albert.sh 134 | sh scripts/report_all_features_prot_albert.sh 135 | ``` 136 | 137 | **ProtXLNet:**
138 | ``` 139 | sh scripts/compute_all_features_prot_xlnet.sh 140 | sh scripts/report_all_features_prot_xlnet.sh 141 | ``` 142 | 143 | ### Probing Analysis 144 | 145 | The following steps will recreate the figures from the probing analysis, currently found in <project_root>/reports/probing. 146 | 147 | Navigate to the directory: 148 | ``` 149 | cd <project_root>/protein_attention/probing 150 | ``` 151 | 152 | #### Training 153 | Train diagnostic classifiers. Each script will write out an extract file with evaluation results. Note: each of these scripts may run for several hours. 154 | ``` 155 | sh scripts/probe_ss4_0.sh 156 | sh scripts/probe_ss4_1.sh 157 | sh scripts/probe_ss4_2.sh 158 | sh scripts/probe_sites.sh 159 | sh scripts/probe_contact.sh 160 | ``` 161 | #### Reports 162 | ``` 163 | python report.py 164 | ``` 165 | 166 | ## License 167 | 168 | This project is licensed under the BSD 3-Clause License - see the [LICENSE](LICENSE) file for details. 169 | 170 | ## Acknowledgments 171 | 172 | This project incorporates code from the following repo: 173 | * https://github.com/songlab-cal/tape 174 | 175 | ## Citation 176 | 177 | When referencing this repository, please cite [this paper](https://arxiv.org/abs/2006.15222). 178 | 179 | ``` 180 | @misc{vig2020bertology, 181 | title={BERTology Meets Biology: Interpreting Attention in Protein Language Models}, 182 | author={Jesse Vig and Ali Madani and Lav R. Varshney and Caiming Xiong and Richard Socher and Nazneen Fatema Rajani}, 183 | year={2020}, 184 | eprint={2006.15222}, 185 | archivePrefix={arXiv}, 186 | primaryClass={cs.CL}, 187 | url={https://arxiv.org/abs/2006.15222} 188 | } 189 | ``` 190 | 191 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) 4 | as soon as it is discovered. This library limits its runtime dependencies in 5 | order to reduce the total cost of ownership as much as possible, but all consumers 6 | should remain vigilant and have their security stakeholders review all third-party 7 | products (3PP) like this one and their dependencies. -------------------------------------------------------------------------------- /attribution.txt: -------------------------------------------------------------------------------- 1 | For portions of source code based on TAPE (protein_attention/datasets.py, protein_attention/probing/probe.py): 2 | 3 | BSD 3-Clause License 4 | 5 | Copyright (c) 2018, Regents of the University of California 6 | All rights reserved. 7 | 8 | Redistribution and use in source and binary forms, with or without 9 | modification, are permitted provided that the following conditions are met: 10 | 11 | * Redistributions of source code must retain the above copyright notice, this 12 | list of conditions and the following disclaimer. 13 | 14 | * Redistributions in binary form must reproduce the above copyright notice, 15 | this list of conditions and the following disclaimer in the documentation 16 | and/or other materials provided with the distribution. 17 | 18 | * Neither the name of the copyright holder nor the names of its 19 | contributors may be used to endorse or promote products derived from 20 | this software without specific prior written permission.
21 | 22 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 23 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 24 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 25 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 26 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 27 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 28 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 29 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 30 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /data/binding_sites/binding_site_train.lmdb/data.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/data/binding_sites/binding_site_train.lmdb/data.mdb -------------------------------------------------------------------------------- /data/binding_sites/binding_site_train.lmdb/lock.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/data/binding_sites/binding_site_train.lmdb/lock.mdb -------------------------------------------------------------------------------- /data/binding_sites/binding_site_valid.lmdb/data.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/data/binding_sites/binding_site_valid.lmdb/data.mdb -------------------------------------------------------------------------------- /data/binding_sites/binding_site_valid.lmdb/lock.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/data/binding_sites/binding_site_valid.lmdb/lock.mdb -------------------------------------------------------------------------------- /data/protein_modifications/protein_modification_train.lmdb/data.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/data/protein_modifications/protein_modification_train.lmdb/data.mdb -------------------------------------------------------------------------------- /data/protein_modifications/protein_modification_train.lmdb/lock.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/data/protein_modifications/protein_modification_train.lmdb/lock.mdb -------------------------------------------------------------------------------- /data/protein_modifications/protein_modification_valid.lmdb/data.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/data/protein_modifications/protein_modification_valid.lmdb/data.mdb -------------------------------------------------------------------------------- 
/data/protein_modifications/protein_modification_valid.lmdb/lock.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/data/protein_modifications/protein_modification_valid.lmdb/lock.mdb -------------------------------------------------------------------------------- /images/vis3d_binding_sites.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/images/vis3d_binding_sites.png -------------------------------------------------------------------------------- /images/vis3d_contact_map.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/images/vis3d_contact_map.png -------------------------------------------------------------------------------- /notebooks/provis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%% md\n" 8 | } 9 | }, 10 | "source": [ 11 | "# ProVis: Attention Visualizer for Proteins" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": { 18 | "pycharm": { 19 | "is_executing": false, 20 | "name": "#%%\n" 21 | } 22 | }, 23 | "outputs": [ 24 | { 25 | "data": { 26 | "application/vnd.jupyter.widget-view+json": { 27 | "model_id": "937c070d141c45b3b59c631e43621d9e", 28 | "version_major": 2, 29 | "version_minor": 0 30 | }, 31 | "text/plain": [] 32 | }, 33 | "metadata": {}, 34 | "output_type": "display_data" 35 | } 36 | ], 37 | "source": [ 38 | "import io\n", 39 | "import urllib\n", 40 | "\n", 41 | "import torch\n", 42 | "from Bio.Data import SCOPData\n", 43 | "from Bio.PDB import PDBParser, PPBuilder\n", 44 | "from tape import TAPETokenizer, ProteinBertModel\n", 45 | "import nglview\n", 46 | "\n", 47 | "attn_color = [0.937, .522, 0.212]" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 2, 53 | "metadata": { 54 | "pycharm": { 55 | "name": "#%%\n" 56 | } 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "def get_structure(pdb_id):\n", 61 | " resource = urllib.request.urlopen(f'https://files.rcsb.org/download/{pdb_id}.pdb')\n", 62 | " content = resource.read().decode('utf8')\n", 63 | " handle = io.StringIO(content)\n", 64 | " parser = PDBParser(QUIET=True)\n", 65 | " return parser.get_structure(pdb_id, handle)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "metadata": { 72 | "pycharm": { 73 | "name": "#%%\n" 74 | } 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "def get_attn_data(chain, layer, head, min_attn, start_index=0, end_index=None, max_seq_len=1024):\n", 79 | "\n", 80 | " tokens = []\n", 81 | " coords = []\n", 82 | " for res in chain:\n", 83 | " t = SCOPData.protein_letters_3to1.get(res.get_resname(), \"X\")\n", 84 | " tokens += t\n", 85 | " if t == 'X':\n", 86 | " coord = None\n", 87 | " else:\n", 88 | " coord = res['CA'].coord.tolist()\n", 89 | " coords.append(coord) \n", 90 | " last_non_x = None\n", 91 | " for i in reversed(range(len(tokens))):\n", 92 | " if tokens[i] != 'X':\n", 93 | " last_non_x = i\n", 94 | " break\n", 95 | " assert last_non_x is not None\n", 96 | " tokens = tokens[:last_non_x + 1]\n", 97 | " coords = coords[:last_non_x + 1] \n", 98 | " \n", 99 | " 
tokenizer = TAPETokenizer()\n", 100 | " model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)\n", 101 | "\n", 102 | " if max_seq_len:\n", 103 | " tokens = tokens[:max_seq_len - 2] # Account for SEP, CLS tokens (added in next step)\n", 104 | " token_idxs = tokenizer.encode(tokens).tolist()\n", 105 | " if max_seq_len:\n", 106 | " assert len(token_idxs) == min(len(tokens) + 2, max_seq_len)\n", 107 | " else:\n", 108 | " assert len(token_idxs) == len(tokens) + 2\n", 109 | "\n", 110 | " inputs = torch.tensor(token_idxs).unsqueeze(0)\n", 111 | " with torch.no_grad():\n", 112 | " attns = model(inputs)[-1]\n", 113 | " # Remove attention from (first) and (last) token\n", 114 | " attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]\n", 115 | " attns = torch.stack([attn.squeeze(0) for attn in attns])\n", 116 | " attn = attns[layer, head]\n", 117 | " if end_index is None:\n", 118 | " end_index = len(tokens)\n", 119 | " attn_data = []\n", 120 | " for i in range(start_index, end_index):\n", 121 | " for j in range(i, end_index):\n", 122 | " # Currently non-directional: shows max of two attns\n", 123 | " a = max(attn[i, j].item(), attn[j, i].item())\n", 124 | " if a is not None and a >= min_attn:\n", 125 | " attn_data.append((a, coords[i], coords[j]))\n", 126 | " return attn_data" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "### Visualize head 7-1 (targets binding sites)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 4, 139 | "metadata": { 140 | "pycharm": { 141 | "is_executing": false, 142 | "name": "#%%\n" 143 | } 144 | }, 145 | "outputs": [ 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "Loading chain A\n", 151 | "Loading chain B\n", 152 | "Loading chain C\n" 153 | ] 154 | }, 155 | { 156 | "data": { 157 | "application/vnd.jupyter.widget-view+json": { 158 | "model_id": "b706257cd6c44ab5a7e85da6476f6c8a", 159 | "version_major": 2, 160 | "version_minor": 0 161 | }, 162 | "text/plain": [ 163 | "NGLWidget()" 164 | ] 165 | }, 166 | "metadata": {}, 167 | "output_type": "display_data" 168 | } 169 | ], 170 | "source": [ 171 | "# Example for head 7-1 (targets binding sites)\n", 172 | "pdb_id = '7HVP'\n", 173 | "chain_ids = None # All chains\n", 174 | "layer = 7\n", 175 | "head = 1\n", 176 | "min_attn = 0.1\n", 177 | "attn_scale = .9\n", 178 | "\n", 179 | "layer_zero_indexed = layer - 1\n", 180 | "head_zero_indexed = head - 1\n", 181 | "\n", 182 | "structure = get_structure(pdb_id)\n", 183 | "view = nglview.show_biopython(structure)\n", 184 | "view.stage.set_parameters(**{\n", 185 | " \"backgroundColor\": \"black\",\n", 186 | " \"fogNear\": 50, \"fogFar\": 100,\n", 187 | "})\n", 188 | "\n", 189 | "models = list(structure.get_models())\n", 190 | "if len(models) > 1:\n", 191 | " print('Warning:', len(models), 'models. 
Using first one')\n", 192 | "prot_model = models[0]\n", 193 | "\n", 194 | "if chain_ids is None:\n", 195 | " chain_ids = [chain.id for chain in prot_model]\n", 196 | "for chain_id in chain_ids: \n", 197 | " print('Loading chain', chain_id)\n", 198 | " chain = prot_model[chain_id] \n", 199 | " attn_data = get_attn_data(chain, layer_zero_indexed, head_zero_indexed, min_attn)\n", 200 | " for att, coords_from, coords_to in attn_data:\n", 201 | " view.shape.add_cylinder(coords_from, coords_to, attn_color, att * attn_scale) \n", 202 | " \n", 203 | "view" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": {}, 209 | "source": [ 210 | "### Visualize head 12-4 (targets contact maps)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 5, 216 | "metadata": {}, 217 | "outputs": [ 218 | { 219 | "name": "stdout", 220 | "output_type": "stream", 221 | "text": [ 222 | "Warning: 20 models. Using first one\n", 223 | "Loading chain A\n" 224 | ] 225 | }, 226 | { 227 | "data": { 228 | "application/vnd.jupyter.widget-view+json": { 229 | "model_id": "c7451b75445446f9af10e6c1ec107f9a", 230 | "version_major": 2, 231 | "version_minor": 0 232 | }, 233 | "text/plain": [ 234 | "NGLWidget()" 235 | ] 236 | }, 237 | "metadata": {}, 238 | "output_type": "display_data" 239 | } 240 | ], 241 | "source": [ 242 | "# Example for head 12-4 (targets contact maps)\n", 243 | "pdb_id = '2KC7'\n", 244 | "chain_ids = None # All chains\n", 245 | "layer = 12\n", 246 | "head = 4\n", 247 | "min_attn = 0.2\n", 248 | "attn_scale = .5\n", 249 | "\n", 250 | "layer_zero_indexed = layer - 1\n", 251 | "head_zero_indexed = head - 1\n", 252 | "\n", 253 | "structure = get_structure(pdb_id)\n", 254 | "view2 = nglview.show_biopython(structure)\n", 255 | "view2.stage.set_parameters(**{\n", 256 | " \"backgroundColor\": \"black\",\n", 257 | " \"fogNear\": 50, \"fogFar\": 100,\n", 258 | "})\n", 259 | "\n", 260 | "models = list(structure.get_models())\n", 261 | "if len(models) > 1:\n", 262 | " print('Warning:', len(models), 'models. 
Using first one')\n", 263 | "prot_model = models[0]\n", 264 | "\n", 265 | "if chain_ids is None:\n", 266 | " chain_ids = [chain.id for chain in prot_model]\n", 267 | "for chain_id in chain_ids: \n", 268 | " print('Loading chain', chain_id)\n", 269 | " chain = prot_model[chain_id] \n", 270 | " attn_data = get_attn_data(chain, layer_zero_indexed, head_zero_indexed, min_attn)\n", 271 | " for att, coords_from, coords_to in attn_data:\n", 272 | " view2.shape.add_cylinder(coords_from, coords_to, attn_color, att * attn_scale) \n", 273 | " \n", 274 | "view2\n", 275 | "\n", 276 | "# To save: view2.download_image(filename=\"testing.png\")" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "\n", 286 | "\n" 287 | ] 288 | } 289 | ], 290 | "metadata": { 291 | "kernelspec": { 292 | "display_name": "Python 3", 293 | "language": "python", 294 | "name": "python3" 295 | }, 296 | "language_info": { 297 | "codemirror_mode": { 298 | "name": "ipython", 299 | "version": 3 300 | }, 301 | "file_extension": ".py", 302 | "mimetype": "text/x-python", 303 | "name": "python", 304 | "nbconvert_exporter": "python", 305 | "pygments_lexer": "ipython3", 306 | "version": "3.8.8" 307 | }, 308 | "pycharm": { 309 | "stem_cell": { 310 | "cell_type": "raw", 311 | "source": [], 312 | "metadata": { 313 | "collapsed": false 314 | } 315 | } 316 | } 317 | }, 318 | "nbformat": 4, 319 | "nbformat_minor": 1 320 | } -------------------------------------------------------------------------------- /protein_attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/protein_attention/__init__.py -------------------------------------------------------------------------------- /protein_attention/attention_analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/protein_attention/attention_analysis/__init__.py -------------------------------------------------------------------------------- /protein_attention/attention_analysis/background.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | import tqdm 4 | from tape.datasets import ProteinnetDataset 5 | 6 | from protein_attention.datasets import BindingSiteDataset, ProteinModificationDataset 7 | from protein_attention.utils import get_data_path 8 | 9 | 10 | def binding_site_distribution(max_len): 11 | d = BindingSiteDataset(get_data_path(), 'train') 12 | c = Counter() 13 | for row in tqdm.tqdm(d): 14 | site_indic = row[-1] 15 | c.update(site_indic[:max_len]) 16 | 17 | return c 18 | 19 | 20 | def protein_modification_distribution(max_len): 21 | d = ProteinModificationDataset(get_data_path(), 'train') 22 | c = Counter() 23 | for row in tqdm.tqdm(d): 24 | mod_indic = row[-1] 25 | c.update(mod_indic[:max_len]) 26 | return c 27 | 28 | 29 | def contact_map_distribution(max_len): 30 | d = ProteinnetDataset(get_data_path(), 'train') 31 | c = Counter() 32 | for row in tqdm.tqdm(d): 33 | contact_map = row[-2][:max_len, :max_len].flatten() 34 | c.update(contact_map) 35 | return c 36 | -------------------------------------------------------------------------------- /protein_attention/attention_analysis/compute_edge_features.py: 
-------------------------------------------------------------------------------- 1 | """Compute aggregate statistics of attention edge features over a dataset 2 | 3 | Copyright (c) 2020, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | import re 9 | from collections import defaultdict 10 | 11 | import numpy as np 12 | import torch 13 | from tqdm import tqdm 14 | 15 | 16 | def compute_mean_attention(model, 17 | n_layers, 18 | n_heads, 19 | items, 20 | features, 21 | tokenizer, 22 | model_name, 23 | model_version, 24 | cuda=True, 25 | max_seq_len=None, 26 | min_attn=0): 27 | model.eval() 28 | 29 | with torch.no_grad(): 30 | 31 | # Dictionary that maps feature_name to array of shape (n_layers, n_heads), containing 32 | # weighted sum of feature values for each layer/head over all examples 33 | feature_to_weighted_sum = defaultdict(lambda: torch.zeros((n_layers, n_heads), dtype=torch.double)) 34 | 35 | # Sum of attention weights in each layer/head over all examples 36 | weight_total = torch.zeros((n_layers, n_heads), dtype=torch.double) 37 | 38 | for item in tqdm(items): 39 | # Get attention weights, shape is (num_layers, num_heads, seq_len, seq_len) 40 | attns = get_attention(model, 41 | item, 42 | tokenizer, 43 | model_name, 44 | model_version, 45 | cuda, 46 | max_seq_len) 47 | if attns is None: 48 | print('Skipping sequence: model did not return attention') 49 | continue 50 | 51 | # Update total attention weights per head. Sum over from_index (dim 2), to_index (dim 3) 52 | mask = attns >= min_attn 53 | weight_total += mask.long().sum((2, 3)) 54 | 55 | # Update weighted sum of feature values per head 56 | seq_len = attns.size(2) 57 | for to_index in range(seq_len): 58 | for from_index in range(seq_len): 59 | for feature in features: 60 | # Compute feature values 61 | feature_dict = feature.get_values(item, from_index, to_index) 62 | for feature_name, value in feature_dict.items(): 63 | # Update weighted sum of feature values across layers and heads 64 | mask = attns[:, :, from_index, to_index] >= min_attn 65 | feature_to_weighted_sum[feature_name] += mask * value 66 | 67 | return feature_to_weighted_sum, weight_total 68 | 69 | 70 | def get_attention(model, 71 | item, 72 | tokenizer, 73 | model_name, 74 | model_version, 75 | cuda, 76 | max_seq_len): 77 | tokens = item['primary'] 78 | if model_name == 'bert': 79 | if max_seq_len: 80 | tokens = tokens[:max_seq_len - 2] # Account for SEP, CLS tokens (added in next step) 81 | if model_version in ('prot_bert', 'prot_bert_bfd', 'prot_albert'): 82 | formatted_tokens = ' '.join(list(tokens)) 83 | formatted_tokens = re.sub(r"[UZOB]", "X", formatted_tokens) 84 | token_idxs = tokenizer.encode(formatted_tokens) 85 | else: 86 | token_idxs = tokenizer.encode(tokens) 87 | if isinstance(token_idxs, np.ndarray): 88 | token_idxs = token_idxs.tolist() 89 | if max_seq_len: 90 | assert len(token_idxs) == min(len(tokens) + 2, max_seq_len), (tokens, token_idxs, max_seq_len) 91 | else: 92 | assert len(token_idxs) == len(tokens) + 2 93 | elif model_name == 'xlnet': 94 | if max_seq_len: 95 | tokens = tokens[:max_seq_len - 2] # Account for SEP, CLS tokens (added in next step) 96 | formatted_tokens = ' '.join(list(tokens)) 97 | formatted_tokens = re.sub(r"[UZOB]", "X", formatted_tokens) 98 | token_idxs = tokenizer.encode(formatted_tokens) 99 | if isinstance(token_idxs, np.ndarray): 100 | token_idxs =
token_idxs.tolist() 101 | if max_seq_len: 102 | # Skip rare sequence with this issue 103 | if len(token_idxs) != min(len(tokens) + 2, max_seq_len): 104 | print('Warning: the length of the sequence changed through tokenization, skipping') 105 | return None 106 | else: 107 | assert len(token_idxs) == len(tokens) + 2 108 | else: 109 | raise ValueError 110 | 111 | inputs = torch.tensor(token_idxs).unsqueeze(0) 112 | if cuda: 113 | inputs = inputs.cuda() 114 | attns = model(inputs)[-1] 115 | 116 | if model_name == 'bert': 117 | # Remove attention from <CLS> (first) and <SEP> (last) token 118 | attns = [attn[:, :, 1:-1, 1:-1] for attn in attns] 119 | elif model_name == 'xlnet': 120 | # Remove attention from <CLS> (last) and <SEP> (second to last) token 121 | attns = [attn[:, :, :-2, :-2] for attn in attns] 122 | else: 123 | raise NotImplementedError 124 | 125 | if 'contact_map' in item: 126 | assert (item['contact_map'].shape == attns[0][0, 0].shape) or (attns[0][0, 0].shape[0] == max_seq_len - 2), \ 127 | (item['id'], item['contact_map'].shape, attns[0][0, 0].shape) 128 | if 'site_indic' in item: 129 | assert (item['site_indic'].shape == attns[0][0, 0, 0].shape) or (attns[0][0, 0].shape[0] == max_seq_len - 2), \ 130 | item['id'] 131 | if 'modification_indic' in item: 132 | assert (item['modification_indic'].shape == attns[0][0, 0, 0].shape) or ( 133 | attns[0][0, 0].shape[0] == max_seq_len - 2), \ 134 | item['id'] 135 | 136 | attns = torch.stack([attn.squeeze(0) for attn in attns]) 137 | return attns.cpu() 138 | 139 | 140 | def convert_item(dataset_name, x, data, model_name, features): 141 | item = {} 142 | try: 143 | item['id'] = data['id'] 144 | except ValueError: 145 | item['id'] = data['id'].decode('utf8') 146 | 147 | item['primary'] = data['primary'] 148 | if dataset_name == 'proteinnet': 149 | if 'contact_map' in features: 150 | token_ids, input_mask, contact_map, protein_length = x 151 | item['contact_map'] = contact_map 152 | elif dataset_name == 'secondary': 153 | if 'ss4' in features: 154 | ss8_blank_index = 7 155 | ss4_blank_index = 3 156 | item['secondary'] = [ss4_blank_index if ss8 == ss8_blank_index else ss3 for ss3, ss8 in \ 157 | zip(data['ss3'], data['ss8'])] 158 | elif dataset_name == 'binding_sites': 159 | if 'binding_sites' in features: 160 | token_ids, input_mask, site_indic = x 161 | item['site_indic'] = site_indic 162 | elif dataset_name == 'protein_modifications': 163 | if 'protein_modifications' in features: 164 | token_ids, input_mask, modification_indic = x 165 | item['modification_indic'] = modification_indic 166 | else: 167 | raise ValueError 168 | 169 | if model_name == 'bert': 170 | # Remove label values from <CLS> (first) and <SEP> (last) token 171 | if 'site_indic' in item: 172 | item['site_indic'] = item['site_indic'][1:-1] 173 | if 'modification_indic' in item: 174 | item['modification_indic'] = item['modification_indic'][1:-1] 175 | elif model_name == 'xlnet': 176 | # Remove label values from <CLS> (last) and <SEP> (second to last) token 177 | if 'site_indic' in item: 178 | item['site_indic'] = item['site_indic'][:-2] 179 | if 'modification_indic' in item: 180 | item['modification_indic'] = item['modification_indic'][:-2] 181 | else: 182 | raise NotImplementedError 183 | 184 | return item 185 | 186 | 187 | if __name__ == "__main__": 188 | import pickle 189 | import pathlib 190 | 191 | from transformers import BertModel, AutoTokenizer, XLNetModel, XLNetTokenizer, AlbertModel, AlbertTokenizer 192 | from tape import TAPETokenizer, ProteinBertModel 193 | from tape.datasets import ProteinnetDataset,
SecondaryStructureDataset 194 | 195 | from protein_attention.datasets import BindingSiteDataset, ProteinModificationDataset 196 | from protein_attention.utils import get_cache_path, get_data_path 197 | from protein_attention.attention_analysis.features import AminoAcidFeature, SecStructFeature, BindingSiteFeature, \ 198 | ContactMapFeature, ProteinModificationFeature 199 | 200 | import argparse 201 | 202 | parser = argparse.ArgumentParser() 203 | parser.add_argument('--exp-name', required=True, help='Name of experiment. Used to create unique filename.') 204 | parser.add_argument('--features', nargs='+', required=True, help='list of features') 205 | parser.add_argument('--dataset', required=True, help='Dataset id') 206 | parser.add_argument('--num-sequences', type=int, required=True, help='Number of sequences to analyze') 207 | parser.add_argument('--model', default='bert', help='Name of model.') 208 | parser.add_argument('--model-version', help='Name of model version.') 209 | parser.add_argument('--model_dir', help='Optional directory where pretrained model is located') 210 | parser.add_argument('--shuffle', action='store_true', help='Whether to randomly shuffle data') 211 | parser.add_argument('--max-seq-len', type=int, required=True, help='Max sequence length') 212 | parser.add_argument('--seed', type=int, default=123, help='PyTorch seed') 213 | parser.add_argument('--min-attn', type=float, default=0, help='min attention value for inclusion in analysis') 214 | parser.add_argument('--no_cuda', action='store_true', help='CPU only') 215 | args = parser.parse_args() 216 | print(args) 217 | 218 | if args.model_version and args.model_dir: 219 | raise ValueError('Cannot specify both model version and directory') 220 | 221 | if args.num_sequences is not None and not args.shuffle: 222 | print('WARNING: You are using a subset of sequences and you are not shuffling the data.
This may result ' 223 | 'in a skewed sample.') 224 | cuda = not args.no_cuda 225 | 226 | torch.manual_seed(args.seed) 227 | 228 | if args.dataset == 'proteinnet': 229 | dataset = ProteinnetDataset(get_data_path(), 'train') 230 | elif args.dataset == 'secondary': 231 | dataset = SecondaryStructureDataset(get_data_path(), 'train') 232 | elif args.dataset == 'binding_sites': 233 | dataset = BindingSiteDataset(get_data_path(), 'train') 234 | elif args.dataset == 'protein_modifications': 235 | dataset = ProteinModificationDataset(get_data_path(), 'train') 236 | else: 237 | raise ValueError(f"Invalid dataset id: {args.dataset}") 238 | 239 | if not args.num_sequences: 240 | raise NotImplementedError 241 | 242 | if args.model == 'bert': 243 | if args.model_dir: 244 | model_version = args.model_dir 245 | else: 246 | model_version = args.model_version or 'bert-base' 247 | if model_version == 'prot_bert_bfd': 248 | model = BertModel.from_pretrained("Rostlab/prot_bert_bfd", output_attentions=True) 249 | tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False) 250 | elif model_version == 'prot_bert': 251 | model = BertModel.from_pretrained("Rostlab/prot_bert", output_attentions=True) 252 | tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False) 253 | elif model_version == 'prot_albert': 254 | model = AlbertModel.from_pretrained("Rostlab/prot_albert", output_attentions=True) 255 | tokenizer = AlbertTokenizer.from_pretrained("Rostlab/prot_albert", do_lower_case=False) 256 | else: 257 | model = ProteinBertModel.from_pretrained(model_version, output_attentions=True) 258 | tokenizer = TAPETokenizer() 259 | num_layers = model.config.num_hidden_layers 260 | num_heads = model.config.num_attention_heads 261 | elif args.model == 'xlnet': 262 | model_version = args.model_version 263 | if model_version == 'prot_xlnet': 264 | model = XLNetModel.from_pretrained("Rostlab/prot_xlnet", output_attentions=True) 265 | tokenizer = XLNetTokenizer.from_pretrained("Rostlab/prot_xlnet", do_lower_case=False) 266 | else: 267 | raise ValueError('Invalid model version') 268 | num_layers = model.config.n_layer 269 | num_heads = model.config.n_head 270 | else: 271 | raise ValueError(f"Invalid model: {args.model}") 272 | 273 | print('Layers:', num_layers) 274 | print('Heads:', num_heads) 275 | if cuda: 276 | model.to('cuda') 277 | 278 | if args.shuffle: 279 | random_indices = torch.randperm(len(dataset))[:args.num_sequences].tolist() 280 | items = [] 281 | print('Loading dataset') 282 | for i in tqdm(random_indices): 283 | item = convert_item(args.dataset, dataset[i], dataset.data[i], args.model, args.features) 284 | items.append(item) 285 | else: 286 | raise NotImplementedError 287 | 288 | features = [] 289 | for feature_name in args.features: 290 | if feature_name == 'aa': 291 | features.append(AminoAcidFeature()) 292 | elif feature_name == 'ss4': 293 | features.append(SecStructFeature()) 294 | elif feature_name == 'binding_sites': 295 | features.append(BindingSiteFeature()) 296 | elif feature_name == 'protein_modifications': 297 | features.append(ProteinModificationFeature()) 298 | elif feature_name == 'contact_map': 299 | features.append(ContactMapFeature()) 300 | else: 301 | raise ValueError(f"Invalid feature name: {feature_name}") 302 | 303 | feature_to_weighted_sum, weight_total = compute_mean_attention( 304 | model, 305 | num_layers, 306 | num_heads, 307 | items, 308 | features, 309 | tokenizer, 310 | args.model, 311 | model_version, 312 | cuda, 313 | 
max_seq_len=args.max_seq_len, 314 | min_attn=args.min_attn) 315 | 316 | cache_dir = get_cache_path() 317 | pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True) 318 | path = cache_dir / f'{args.exp_name}.pickle' 319 | pickle.dump((args, dict(feature_to_weighted_sum), weight_total), open(path, 'wb')) 320 | print('Wrote to', path) 321 | -------------------------------------------------------------------------------- /protein_attention/attention_analysis/features.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from abc import ABC, abstractmethod 9 | 10 | 11 | class EdgeFeature(ABC): 12 | 13 | @abstractmethod 14 | def get_values(self, item, from_index, to_index): 15 | pass 16 | 17 | 18 | class SecStructFeature(EdgeFeature): 19 | 20 | def __init__(self, include_from=False, include_to=True): 21 | self.include_from = include_from 22 | self.include_to = include_to 23 | assert include_from or include_to 24 | 25 | def get_values(self, seq, from_index, to_index): 26 | feature_values = {} 27 | if self.include_from: 28 | from_secstruct = seq['secondary'][from_index] 29 | feature_name = f'sec_struct_from_{from_secstruct}' 30 | feature_values[feature_name] = 1 31 | if self.include_to: 32 | to_secstruct = seq['secondary'][to_index] 33 | feature_name = f'sec_struct_to_{to_secstruct}' 34 | feature_values[feature_name] = 1 35 | 36 | return feature_values 37 | 38 | 39 | class AminoAcidFeature(EdgeFeature): 40 | 41 | def __init__(self, include_from=False, include_to=True): 42 | self.include_from = include_from 43 | self.include_to = include_to 44 | assert include_from or include_to 45 | 46 | def get_values(self, item, from_index, to_index): 47 | feature_values = {} 48 | if self.include_from: 49 | feature_name = f'aa_from_{item["primary"][from_index]}' 50 | feature_values[feature_name] = 1 51 | if self.include_to: 52 | feature_name = f'aa_to_{item["primary"][to_index]}' 53 | feature_values[feature_name] = 1 54 | return feature_values 55 | 56 | 57 | class BindingSiteFeature(EdgeFeature): 58 | 59 | def get_values(self, item, from_index, to_index, dense=False): 60 | if item['site_indic'][to_index] == 1: 61 | return {'binding_site_to': 1} 62 | else: 63 | if dense: 64 | return {'binding_site_to': 0} 65 | else: 66 | return {} 67 | 68 | 69 | class ProteinModificationFeature(EdgeFeature): 70 | 71 | def get_values(self, item, from_index, to_index, dense=False): 72 | if item['modification_indic'][to_index] == 1: 73 | return {'protein_modification_to': 1} 74 | else: 75 | if dense: 76 | return {'protein_modification_to': 0} 77 | else: 78 | return {} 79 | 80 | 81 | class ContactMapFeature(EdgeFeature): 82 | def get_values(self, item, from_index, to_index, dense=False): 83 | contact_map = item['contact_map'] 84 | contact1 = contact_map[from_index, to_index] 85 | contact2 = contact_map[to_index, from_index] 86 | assert contact1 == contact2 87 | if contact1 == 1: 88 | return {'contact_map': 1} 89 | else: 90 | if dense: 91 | return {'contact_map': 0} 92 | else: 93 | return {} 94 | -------------------------------------------------------------------------------- /protein_attention/attention_analysis/report_aa_correlations.py: -------------------------------------------------------------------------------- 1 | """Report pairwise correlations of 
attention to specific amino acids, and compare to blosum""" 2 | 3 | import json 4 | import pickle 5 | import re 6 | 7 | import seaborn as sns 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | from scipy.stats import pearsonr 11 | from Bio.SubsMat.MatrixInfo import blosum62 12 | 13 | from protein_attention.utils import get_reports_path, get_cache_path 14 | 15 | sns.set() 16 | np.random.seed(0) 17 | 18 | 19 | def to_filename(s): 20 | return "".join(x if (x.isalnum() or x in "._-") else '_' for x in s) 21 | 22 | 23 | def create_figures(feature_to_weighted_sums, weight_totals, min_total, report_dir, filetype): 24 | aa_blosum = set() 25 | for aa1, aa2 in blosum62.keys(): 26 | aa_blosum.add(aa1) 27 | aa_blosum.add(aa2) 28 | 29 | include_mask = weight_totals >= min_total 30 | 31 | p = re.compile(r'aa_to_([A-Z])$') 32 | aa_to_features = {} 33 | for feature_name, weighted_sums in feature_to_weighted_sums.items(): 34 | m = p.match(feature_name) 35 | if m: 36 | aa = m[1] 37 | mean_by_heads = np.where(include_mask, weighted_sums / weight_totals, -1) 38 | feature_vector = mean_by_heads.flatten() 39 | feature_vector = feature_vector[feature_vector != -1] 40 | aa_to_features[aa] = feature_vector 41 | 42 | aas = sorted(aa_to_features.keys()) 43 | aas_set = set(aas) 44 | print('Excluding following AAs not in feature set', aa_blosum - aas_set) 45 | print('Excluding following AAs not in blosum62', aas_set - aa_blosum) 46 | aa_list = sorted(list(aas_set & aa_blosum)) 47 | n_aa = len(aa_list) 48 | corr = np.zeros((n_aa, n_aa)) 49 | for i, aa1 in enumerate(aa_list): 50 | vector1 = aa_to_features[aa1] 51 | for j, aa2 in enumerate(aa_list): 52 | if i == j: 53 | corr[i, j] = None 54 | else: 55 | vector2 = aa_to_features[aa2] 56 | corr[i, j], _ = pearsonr(vector1, vector2) 57 | 58 | cmap = 'Blues' 59 | ax = sns.heatmap(corr, cmap=cmap, vmin=-0.5) 60 | ax.set_xticklabels(aa_list) 61 | ax.set_yticklabels(aa_list) 62 | plt.savefig(report_dir / f'aa_corr_to.pdf', format=filetype) 63 | plt.close() 64 | 65 | blosum = np.zeros((n_aa, n_aa)) 66 | for i, aa1 in enumerate(aa_list): 67 | for j, aa2 in enumerate(aa_list): 68 | if i == j: 69 | blosum[i, j] = None 70 | else: 71 | if blosum62.get((aa1, aa2)) is not None: 72 | blosum[i, j] = blosum62.get((aa1, aa2)) 73 | else: 74 | blosum[i, j] = blosum62.get((aa2, aa1)) 75 | 76 | ax = sns.heatmap(blosum, cmap=cmap, vmin=-4, vmax=4) 77 | ax.set_xticklabels(aa_list) 78 | ax.set_yticklabels(aa_list) 79 | plt.savefig(report_dir / f'blosum62.pdf', 80 | format=filetype) 81 | plt.close() 82 | 83 | corr_scores = [] 84 | blos_scores = [] 85 | for i in range(n_aa): 86 | for j in range(i): 87 | corr_scores.append(corr[i, j]) 88 | blos_scores.append(blosum[i, j]) 89 | print('Pearson Correlation between feature corr and blosum', 90 | pearsonr(corr_scores, blos_scores)[0]) 91 | 92 | 93 | if __name__ == "__main__": 94 | 95 | import argparse 96 | 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument('exp_name', help='Name of experiment') 99 | args = parser.parse_args() 100 | 101 | min_total = 100 102 | filetype = 'pdf' 103 | 104 | cache_path = get_cache_path() / f'{args.exp_name}.pickle' 105 | args, feature_to_weighted_sums, weight_totals = pickle.load(open(cache_path, "rb")) 106 | print(args) 107 | print(weight_totals) 108 | 109 | report_dir = get_reports_path() / 'attention_analysis/blosum' / args.exp_name 110 | report_dir.mkdir(parents=True, exist_ok=True) 111 | 112 | create_figures(feature_to_weighted_sums, weight_totals, min_total, report_dir, filetype) 113 | 114 | 
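    # At this point create_figures has written both heatmaps (aa_corr_to.pdf and
    # blosum62.pdf) and printed the Pearson correlation between their off-diagonal
    # entries, i.e. how closely the similarity of per-head attention profiles for two
    # amino acids tracks their substitutability under BLOSUM62.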
with open(report_dir / 'args.json', 'w') as f: 115 | json.dump(vars(args), f) 116 | -------------------------------------------------------------------------------- /protein_attention/attention_analysis/report_edge_features.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import pathlib 10 | 11 | import numpy as np 12 | 13 | np.random.seed(0) 14 | import seaborn as sns 15 | 16 | sns.set() 17 | import pickle 18 | 19 | from protein_attention.utils import get_reports_path, get_cache_path 20 | from matplotlib.ticker import FuncFormatter 21 | 22 | import numpy as np 23 | import seaborn as sns 24 | from matplotlib import pyplot as plt 25 | from matplotlib.colors import LinearSegmentedColormap 26 | from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable 27 | from mpl_toolkits.axes_grid1.colorbar import colorbar 28 | 29 | secondary_names = { 30 | '0': 'Helix', 31 | '1': 'Beta', 32 | '2': 'Turn/Bend', 33 | '3': 'Blank' 34 | } 35 | 36 | 37 | def to_filename(s, extension): 38 | return "".join(x if (x.isalnum() or x in "._-") else '_' for x in s) + "." + extension 39 | 40 | 41 | def create_figure(feature_name, weighted_sum, weight_total, report_dir, min_total, filetype): 42 | assert filetype in ('png', 'pdf') 43 | 44 | mean_by_head = weighted_sum / weight_total 45 | exclude_mask = np.array(weight_total) < min_total 46 | 47 | masked_mean_by_head = np.ma.masked_array(mean_by_head, mask=exclude_mask) 48 | layer_max = masked_mean_by_head.max(-1) 49 | 50 | n_layers, n_heads = mean_by_head.shape 51 | if n_layers == 12 and n_heads == 12: 52 | plt.figure(figsize=(3, 2.2)) 53 | ax1 = plt.subplot2grid((100, 85), (0, 0), colspan=65, rowspan=99) # Heatmap 54 | ax2 = plt.subplot2grid((100, 85), (12, 70), colspan=15, rowspan=75) # Barchart 55 | elif n_layers == 30 and n_heads == 16: 56 | plt.figure(figsize=(3, 2.2)) 57 | ax1 = plt.subplot2grid((100, 85), (0, 5), colspan=55, rowspan=96) 58 | ax2 = plt.subplot2grid((100, 85), (0, 62), colspan=17, rowspan=97) 59 | elif n_layers == 12 and n_heads == 64: 60 | plt.figure(figsize=(8.5, 2.2)) 61 | ax1 = plt.subplot2grid((100, 160), (0, 5), colspan=135, rowspan=96) 62 | ax2 = plt.subplot2grid((100, 160), (22, 144), colspan=10, rowspan=53) 63 | plt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.01, hspace=0.01) 64 | else: 65 | raise NotImplementedError 66 | 67 | xtick_labels = [str(i) if i % 2 == 0 else '' for i in range(1, n_heads + 1)] 68 | ytick_labels = [str(i) if i % 2 == 0 else '' for i in range(1, n_layers + 1)] 69 | heatmap = sns.heatmap((mean_by_head * 100).tolist(), center=0.0, ax=ax1, 70 | square=True, cbar=False, linewidth=0.1, linecolor='#D0D0D0', 71 | cmap=LinearSegmentedColormap.from_list('rg', ["#F14100", "white", "#3D4FC4"], N=256), 72 | mask=exclude_mask, 73 | xticklabels=xtick_labels, 74 | yticklabels=ytick_labels) 75 | for _, spine in heatmap.spines.items(): 76 | spine.set_visible(True) 77 | spine.set_edgecolor('#D0D0D0') 78 | spine.set_linewidth(0.1) 79 | plt.setp(heatmap.get_yticklabels(), fontsize=7) 80 | plt.setp(heatmap.get_xticklabels(), fontsize=7) 81 | heatmap.tick_params(axis='x', pad=1, length=2) 82 | heatmap.tick_params(axis='y', pad=.5, length=2) 83 | heatmap.yaxis.labelpad = 3 84 | heatmap.invert_yaxis() 85 | 
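    # Heads masked out above (weight_total < min_total) are left undrawn by seaborn, so
    # the facecolor set below renders them as light gray rather than as zero attention.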
heatmap.set_facecolor('#E7E6E6') 86 | # split axes of heatmap to put colorbar 87 | ax_divider = make_axes_locatable(ax1) 88 | if n_layers == 12 and n_heads == 12: 89 | cax = ax_divider.append_axes('left', size='7%', pad='33%') 90 | elif n_layers == 30 and n_heads == 16: 91 | cax = ax_divider.append_axes('left', size='7%', pad='45%') 92 | elif n_layers == 12 and n_heads == 64: 93 | cax = ax_divider.append_axes('left', size='1.5%', pad='7%') 94 | else: 95 | raise NotImplementedError 96 | # # make colorbar for heatmap. 97 | # # Heatmap returns an axes obj but you need to get a mappable obj (get_children) 98 | cbar = colorbar(ax1.get_children()[0], cax=cax, orientation='vertical', format='%.0f%%') 99 | cax.yaxis.set_ticks_position('left') 100 | cbar.solids.set_edgecolor("face") 101 | cbar.ax.tick_params(labelsize=7, length=4, pad=2) 102 | ax1.set_title('% Attention', size=9) 103 | ax1.set_xlabel('Head', size=8) 104 | ax1.set_ylabel('Layer', size=8) 105 | for _, spine in ax1.spines.items(): 106 | spine.set_visible(True) 107 | ax2.set_title('Max', size=9) 108 | bp = sns.barplot(x=layer_max * 100, ax=ax2, y=list(range(layer_max.shape[0])), color="#3D4FC4", orient="h", 109 | edgecolor="none") 110 | formatter = FuncFormatter(lambda y, pos: '0' if (y == 0) else "%d%%" % (y)) 111 | ax2.xaxis.set_major_formatter(formatter) 112 | plt.setp(bp.get_xticklabels(), fontsize=7) 113 | bp.tick_params(axis='x', pad=1, length=3) 114 | ax2.invert_yaxis() 115 | ax2.set_yticklabels([]) 116 | ax2.spines['top'].set_visible(False) 117 | ax2.spines['right'].set_visible(False) 118 | ax2.spines['left'].set_visible(False) 119 | ax2.xaxis.set_ticks_position('bottom') 120 | ax2.axvline(0, linewidth=.85, color='black') 121 | fname = report_dir / to_filename(feature_name, filetype) 122 | print('Saving', fname) 123 | plt.savefig(fname, format=filetype) 124 | plt.close() 125 | 126 | 127 | if __name__ == "__main__": 128 | 129 | import argparse 130 | 131 | parser = argparse.ArgumentParser() 132 | parser.add_argument('exp_name', help='Name of experiment') 133 | args = parser.parse_args() 134 | print(args) 135 | min_total = 100 136 | filetype = 'pdf' 137 | 138 | cache_path = get_cache_path() / f'{args.exp_name}.pickle' 139 | report_dir = get_reports_path() / 'attention_analysis' / args.exp_name 140 | pathlib.Path(report_dir).mkdir(parents=True, exist_ok=True) 141 | 142 | args, feature_to_weighted_sum, weight_total = pickle.load(open(cache_path, "rb")) 143 | with open(report_dir / 'args.json', 'w') as f: 144 | json.dump(vars(args), f) 145 | print(args) 146 | for feature_name, weighted_sum in feature_to_weighted_sum.items(): 147 | create_figure(feature_name, weighted_sum, weight_total, report_dir, min_total=min_total, filetype=filetype) 148 | -------------------------------------------------------------------------------- /protein_attention/attention_analysis/report_edge_features_combined.py: -------------------------------------------------------------------------------- 1 | """Create combined plot from multiple features""" 2 | 3 | import json 4 | import pathlib 5 | import pickle 6 | import re 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import seaborn as sns 11 | from matplotlib.ticker import FuncFormatter 12 | 13 | from protein_attention.utils import get_reports_path, get_cache_path 14 | 15 | sns.set() 16 | 17 | sns.set_context("paper") 18 | 19 | ss4_names = { 20 | '0': 'Helix', 21 | '1': 'Strand', 22 | '2': 'Turn/Bend' 23 | } 24 | 25 | aa_to_pattern = re.compile(r'res_to_([A-Z])$') 26 | secondary_to_pattern 
= re.compile(r'sec_struct_to_([A-Z0-3\s])$') 27 | contact_map_pattern = re.compile(r'contact_map') 28 | binding_site_pattern = re.compile(r'binding_site_to') 29 | 30 | if __name__ == "__main__": 31 | 32 | import argparse 33 | 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--model', required=False, help='Names of experiments') 36 | args = parser.parse_args() 37 | 38 | exp_name_suffix = args.model or '' 39 | 40 | min_total = 100 41 | filetype = 'pdf' 42 | 43 | report_dir = get_reports_path() / ('attention_analysis/edge_features_combined' + \ 44 | (f'_{exp_name_suffix}' if exp_name_suffix else '')) 45 | pathlib.Path(report_dir).mkdir(parents=True, exist_ok=True) 46 | 47 | feature_data = [] 48 | include_features = [contact_map_pattern, secondary_to_pattern, binding_site_pattern] 49 | for exp_name_prefix in ['edge_features_sec', 'edge_features_contact', 'edge_features_sites']: 50 | exp = exp_name_prefix + (f'_{exp_name_suffix}' if exp_name_suffix else '') 51 | cache_path = get_cache_path() / f'{exp}.pickle' 52 | args, feature_to_weighted_sum, weight_total = pickle.load(open(cache_path, "rb")) 53 | with open(report_dir / f'args_{exp}.json', 'w') as f: 54 | json.dump(vars(args), f) 55 | for feature, weighted_sum in feature_to_weighted_sum.items(): 56 | for p in include_features: 57 | m = p.match(feature) 58 | desc = None 59 | if m: 60 | if p == contact_map_pattern: 61 | desc = 'Contact' 62 | elif p == binding_site_pattern: 63 | desc = 'Binding Site' 64 | elif p == secondary_to_pattern: 65 | sec = m.group(1) 66 | desc = ss4_names.get(sec) 67 | else: 68 | raise ValueError 69 | break 70 | if not desc: 71 | continue 72 | mean_by_head = weighted_sum / weight_total 73 | exclude_mask = np.array(weight_total) < min_total 74 | masked_mean_by_head = np.ma.masked_array(mean_by_head, mask=exclude_mask) 75 | layer_macro = masked_mean_by_head.mean(-1) 76 | layer_macro *= 100 # Convert to percentage 77 | n_layers = len(layer_macro) 78 | # assert n_layers == 12 79 | normalized = layer_macro / layer_macro.sum() 80 | assert np.allclose(normalized.sum(), 1) 81 | mean_center = sum(i * normalized[i] for i in range(n_layers)) 82 | feature_data.append((mean_center, feature, desc, layer_macro)) 83 | 84 | # Sort aggregated data by center of gravity 85 | feature_data.sort() 86 | 87 | # Create combined plot 88 | figsize = (3, 5) 89 | plt.figure(figsize=figsize) 90 | fig, ax = plt.subplots(len(feature_data), figsize=figsize, sharex=True, gridspec_kw={'wspace': 0, 'hspace': .17}) 91 | for i, (center, feature, desc, layer_macro) in enumerate(feature_data): 92 | ax[i].plot(list(range(n_layers)), 93 | layer_macro) 94 | ax[i].axvline(x=center, color='red', linestyle='dashed', linewidth=1) 95 | ax[i].tick_params(labelsize=6) 96 | ax[i].set_ylabel(desc, fontsize=8) 97 | ax[i].set_ylim(top=1.03 * max(layer_macro), bottom=0) 98 | ax[i].yaxis.tick_right() 99 | formatter = FuncFormatter(lambda y, pos: "%d%%" % (y)) 100 | ax[i].yaxis.set_major_formatter(formatter) 101 | ax[i].grid(True, axis='x', color='#F3F2F3', lw=1.2) 102 | ax[i].grid(True, axis='y', color='#F3F2F3', lw=1.2) 103 | 104 | plt.xticks(range(n_layers), range(1, n_layers + 1)) 105 | plt.xlabel('Layer', fontsize=8) 106 | fname = report_dir / (f'combined_features.{filetype}') 107 | print('Saving', fname) 108 | plt.savefig(fname, format=filetype, bbox_inches='tight') 109 | plt.close() 110 | -------------------------------------------------------------------------------- /protein_attention/attention_analysis/report_top_heads.py: 
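A note on report_edge_features_combined.py above: the dashed vertical line drawn for each feature marks its attention "center of gravity" across layers, i.e. the mean layer index weighted by each layer's share of attention to that feature. A minimal standalone sketch of that computation (array contents are illustrative):

import numpy as np

def center_of_gravity(layer_macro: np.ndarray) -> float:
    """Mean layer index weighted by each layer's share of attention to a feature."""
    normalized = layer_macro / layer_macro.sum()  # shares summing to 1
    return float(np.dot(np.arange(len(normalized)), normalized))

# Attention concentrated in deeper layers pulls the center toward the end:
print(center_of_gravity(np.array([1.0, 2.0, 4.0, 8.0])))  # ~2.27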
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import pathlib 10 | 11 | import numpy as np 12 | from statsmodels.stats.proportion import proportion_confint, proportions_ztest 13 | 14 | np.random.seed(0) 15 | import seaborn as sns 16 | 17 | sns.set() 18 | import pickle 19 | 20 | from protein_attention.utils import get_reports_path, get_cache_path 21 | from protein_attention.attention_analysis.background import binding_site_distribution,\ 22 | contact_map_distribution, protein_modification_distribution 23 | import numpy as np 24 | import seaborn as sns 25 | from matplotlib import pyplot as plt 26 | 27 | sns.set_context("paper") 28 | sns.set_style("white") 29 | 30 | 31 | def to_filename(s, extension): 32 | return "".join(x if (x.isalnum() or x in "._-") else '_' for x in s) + "." + extension 33 | 34 | 35 | def create_figure(feature_name, weighted_sum, weight_total, report_dir, min_total, filetype, max_seq_len, 36 | use_bonferroni=False, k=10): 37 | assert filetype in ('png', 'pdf') 38 | 39 | mean_by_head = weighted_sum / weight_total 40 | n_layers, n_heads = mean_by_head.shape 41 | 42 | scored_heads = [] 43 | for i in range(n_layers): 44 | for j in range(n_heads): 45 | if weight_total[i, j] > min_total: 46 | scored_heads.append((mean_by_head[i, j], i, j)) 47 | top_heads = sorted(scored_heads, reverse=True)[:k] 48 | 49 | 50 | if feature_name == 'binding_site_to': 51 | counts = binding_site_distribution(max_seq_len) 52 | num_pos_background = counts[1] 53 | num_neg_background = counts[0] 54 | elif feature_name == 'protein_modification_to': 55 | counts = protein_modification_distribution(max_seq_len) 56 | num_pos_background = counts[1] 57 | num_neg_background = counts[0] 58 | elif feature_name == 'contact_map': 59 | counts = contact_map_distribution(max_seq_len) 60 | num_pos_background = counts[1] 61 | num_neg_background = counts[0] 62 | else: 63 | raise NotImplementedError 64 | num_total_background = num_pos_background + num_neg_background 65 | background_pct = num_pos_background / (num_total_background) * 100 66 | 67 | scores = [] 68 | conf_ints = [] 69 | labels = [] 70 | 71 | for score, i, j in top_heads: 72 | scores.append(score.item() * 100) 73 | labels.append(f'{i + 1}-{j + 1}') 74 | num_pos = int(weighted_sum[i, j]) 75 | num_total = int(weight_total[i, j]) 76 | if use_bonferroni: 77 | print('m=', len(scored_heads)) 78 | start, end = proportion_confint(num_pos, num_total, alpha=0.05 / len(scored_heads)) 79 | else: 80 | start, end = proportion_confint(num_pos, num_total, alpha=0.05) 81 | conf_int = (end - start) / 2 * 100 82 | conf_ints.append(conf_int) 83 | print(i, j) 84 | print('background', num_pos_background, num_total_background, num_pos_background / num_total_background) 85 | print('attn', num_pos, num_total, num_pos / num_total, start, end) 86 | p_value = proportions_ztest([num_pos_background, num_pos], [num_total_background, num_total]) 87 | print('p_value', f'{p_value[-1]:.25f}') 88 | 89 | print(list(enumerate(zip(labels, [f'{s:.3f}' for s in scores])))) 90 | 91 | figsize = (3.7, 1.8) 92 | plt.figure(figsize=figsize) 93 | plt.bar(range(k), scores, yerr=conf_ints, capsize=3, edgecolor="none") 94 | x = np.arange(k) 95 | plt.xticks(x, labels, fontsize=7, rotation=45) 96 | plt.xlabel(f'Top 
heads', fontsize=8, labelpad=6) 97 | plt.ylabel('Attention %', fontsize=8, labelpad=6) 98 | plt.yticks(fontsize=7) 99 | ax = plt.gca() 100 | ax.tick_params(axis='both', which='major', pad=-2) 101 | ax.axhline(background_pct, linestyle='dashed', color='#FF7F00', linewidth=2, alpha=0.8) 102 | plt.tight_layout() 103 | ax.grid(True, axis='y', color='#F3F2F3', lw=1.2) 104 | fname = report_dir / to_filename(feature_name, filetype) 105 | print('Saving', fname) 106 | plt.savefig(fname, format=filetype) 107 | plt.close() 108 | 109 | 110 | if __name__ == "__main__": 111 | 112 | import argparse 113 | 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('exp_name', help='Name of experiment') 116 | parser.add_argument('--min_total', type=int, default=100) 117 | args = parser.parse_args() 118 | 119 | use_bonferroni = True 120 | if use_bonferroni: 121 | print('Using bonferroni') 122 | min_total = args.min_total 123 | filetype = 'pdf' 124 | print(args.exp_name) 125 | cache_path = get_cache_path() / f'{args.exp_name}.pickle' 126 | report_dir = get_reports_path() / 'attention_analysis' / f'{args.exp_name}_topheads' 127 | pathlib.Path(report_dir).mkdir(parents=True, exist_ok=True) 128 | 129 | cache_args, feature_to_weighted_sum, weight_total = pickle.load(open(cache_path, "rb")) 130 | with open(report_dir / 'args.json', 'w') as f: 131 | json.dump(vars(cache_args), f) 132 | for feature_name, weighted_sum in feature_to_weighted_sum.items(): 133 | create_figure(feature_name, weighted_sum, weight_total, report_dir, min_total=min_total, filetype=filetype, 134 | use_bonferroni=use_bonferroni, max_seq_len=cache_args.max_seq_len) 135 | -------------------------------------------------------------------------------- /protein_attention/attention_analysis/scripts/compute_all_features_prot_albert.sh: -------------------------------------------------------------------------------- 1 | python compute_edge_features.py \ 2 | --exp-name edge_features_contact_prot_albert \ 3 | --model-version prot_albert \ 4 | --features contact_map \ 5 | --dataset proteinnet \ 6 | --num-sequences 5000 \ 7 | --max-seq-len 512 \ 8 | --min-attn .3 \ 9 | --shuffle && 10 | python compute_edge_features.py \ 11 | --exp-name edge_features_modifications_prot_albert \ 12 | --model-version prot_albert \ 13 | --features protein_modifications \ 14 | --dataset protein_modifications \ 15 | --num-sequences 5000 \ 16 | --max-seq-len 512 \ 17 | --min-attn .3 \ 18 | --shuffle && 19 | python compute_edge_features.py \ 20 | --exp-name edge_features_sec_prot_albert \ 21 | --model-version prot_albert \ 22 | --features ss4 \ 23 | --dataset secondary \ 24 | --num-sequences 5000 \ 25 | --max-seq-len 512 \ 26 | --min-attn .3 \ 27 | --shuffle && 28 | python compute_edge_features.py \ 29 | --exp-name edge_features_sites_prot_albert \ 30 | --model-version prot_albert \ 31 | --features binding_sites \ 32 | --dataset binding_sites \ 33 | --num-sequences 5000 \ 34 | --max-seq-len 512 \ 35 | --min-attn .3 \ 36 | --shuffle && 37 | python compute_edge_features.py \ 38 | --exp-name edge_features_aa_prot_albert \ 39 | --model-version prot_albert \ 40 | --features aa \ 41 | --dataset proteinnet \ 42 | --num-sequences 5000 \ 43 | --max-seq-len 512 \ 44 | --min-attn .3 \ 45 | --shuffle -------------------------------------------------------------------------------- /protein_attention/attention_analysis/scripts/compute_all_features_prot_bert.sh: -------------------------------------------------------------------------------- 1 | python compute_edge_features.py \ 2 | 
--exp-name edge_features_contact_prot_bert \ 3 | --model-version prot_bert \ 4 | --features contact_map \ 5 | --dataset proteinnet \ 6 | --num-sequences 5000 \ 7 | --max-seq-len 512 \ 8 | --min-attn .3 \ 9 | --shuffle && 10 | python compute_edge_features.py \ 11 | --exp-name edge_features_modifications_prot_bert \ 12 | --model-version prot_bert \ 13 | --features protein_modifications \ 14 | --dataset protein_modifications \ 15 | --num-sequences 5000 \ 16 | --max-seq-len 512 \ 17 | --min-attn .3 \ 18 | --shuffle && 19 | python compute_edge_features.py \ 20 | --exp-name edge_features_sec_prot_bert \ 21 | --model-version prot_bert \ 22 | --features ss4 \ 23 | --dataset secondary \ 24 | --num-sequences 5000 \ 25 | --max-seq-len 512 \ 26 | --min-attn .3 \ 27 | --shuffle && 28 | python compute_edge_features.py \ 29 | --exp-name edge_features_sites_prot_bert \ 30 | --model-version prot_bert \ 31 | --features binding_sites \ 32 | --dataset binding_sites \ 33 | --num-sequences 5000 \ 34 | --max-seq-len 512 \ 35 | --min-attn .3 \ 36 | --shuffle && 37 | python compute_edge_features.py \ 38 | --exp-name edge_features_aa_prot_bert \ 39 | --model-version prot_bert \ 40 | --features aa \ 41 | --dataset proteinnet \ 42 | --num-sequences 5000 \ 43 | --max-seq-len 512 \ 44 | --min-attn .3 \ 45 | --shuffle -------------------------------------------------------------------------------- /protein_attention/attention_analysis/scripts/compute_all_features_prot_bert_bfd.sh: -------------------------------------------------------------------------------- 1 | python compute_edge_features.py \ 2 | --exp-name edge_features_contact_prot_bert_bfd \ 3 | --model-version prot_bert_bfd \ 4 | --features contact_map \ 5 | --dataset proteinnet \ 6 | --num-sequences 5000 \ 7 | --max-seq-len 512 \ 8 | --min-attn .3 \ 9 | --shuffle && 10 | python compute_edge_features.py \ 11 | --exp-name edge_features_modifications_prot_bert_bfd \ 12 | --model-version prot_bert_bfd \ 13 | --features protein_modifications \ 14 | --dataset protein_modifications \ 15 | --num-sequences 5000 \ 16 | --max-seq-len 512 \ 17 | --min-attn .3 \ 18 | --shuffle && 19 | python compute_edge_features.py \ 20 | --exp-name edge_features_sec_prot_bert_bfd \ 21 | --model-version prot_bert_bfd \ 22 | --features ss4 \ 23 | --dataset secondary \ 24 | --num-sequences 5000 \ 25 | --max-seq-len 512 \ 26 | --min-attn .3 \ 27 | --shuffle && 28 | python compute_edge_features.py \ 29 | --exp-name edge_features_sites_prot_bert_bfd \ 30 | --model-version prot_bert_bfd \ 31 | --features binding_sites \ 32 | --dataset binding_sites \ 33 | --num-sequences 5000 \ 34 | --max-seq-len 512 \ 35 | --min-attn .3 \ 36 | --shuffle && 37 | python compute_edge_features.py \ 38 | --exp-name edge_features_aa_prot_bert_bfd \ 39 | --model-version prot_bert_bfd \ 40 | --features aa \ 41 | --dataset proteinnet \ 42 | --num-sequences 5000 \ 43 | --max-seq-len 512 \ 44 | --min-attn .3 \ 45 | --shuffle -------------------------------------------------------------------------------- /protein_attention/attention_analysis/scripts/compute_all_features_prot_xlnet.sh: -------------------------------------------------------------------------------- 1 | python compute_edge_features.py \ 2 | --exp-name edge_features_contact_prot_xlnet \ 3 | --model xlnet \ 4 | --model-version prot_xlnet \ 5 | --features contact_map \ 6 | --dataset proteinnet \ 7 | --num-sequences 5000 \ 8 | --max-seq-len 512 \ 9 | --min-attn .3 \ 10 | --shuffle && 11 | python compute_edge_features.py \ 12 | --exp-name 
edge_features_modifications_prot_xlnet \ 13 | --model xlnet \ 14 | --model-version prot_xlnet \ 15 | --features protein_modifications \ 16 | --dataset protein_modifications \ 17 | --num-sequences 5000 \ 18 | --max-seq-len 512 \ 19 | --min-attn .3 \ 20 | --shuffle && 21 | python compute_edge_features.py \ 22 | --exp-name edge_features_sec_prot_xlnet \ 23 | --model xlnet \ 24 | --model-version prot_xlnet \ 25 | --features ss4 \ 26 | --dataset secondary \ 27 | --num-sequences 5000 \ 28 | --max-seq-len 512 \ 29 | --min-attn .3 \ 30 | --shuffle && 31 | python compute_edge_features.py \ 32 | --exp-name edge_features_sites_prot_xlnet \ 33 | --model xlnet \ 34 | --model-version prot_xlnet \ 35 | --features binding_sites \ 36 | --dataset binding_sites \ 37 | --num-sequences 5000 \ 38 | --max-seq-len 512 \ 39 | --min-attn .3 \ 40 | --shuffle && 41 | python compute_edge_features.py \ 42 | --exp-name edge_features_aa_prot_xlnet \ 43 | --model xlnet \ 44 | --model-version prot_xlnet \ 45 | --features aa \ 46 | --dataset proteinnet \ 47 | --num-sequences 5000 \ 48 | --max-seq-len 512 \ 49 | --min-attn .3 \ 50 | --shuffle -------------------------------------------------------------------------------- /protein_attention/attention_analysis/scripts/compute_all_features_tape_bert.sh: -------------------------------------------------------------------------------- 1 | python compute_edge_features.py \ 2 | --exp-name edge_features_contact_tape_bert \ 3 | --features contact_map \ 4 | --dataset proteinnet \ 5 | --num-sequences 5000 \ 6 | --max-seq-len 512 \ 7 | --min-attn .3 \ 8 | --shuffle && 9 | python compute_edge_features.py \ 10 | --exp-name edge_features_modifications_tape_bert \ 11 | --features protein_modifications \ 12 | --dataset protein_modifications \ 13 | --num-sequences 5000 \ 14 | --max-seq-len 512 \ 15 | --min-attn .3 \ 16 | --shuffle && 17 | python compute_edge_features.py \ 18 | --exp-name edge_features_sec_tape_bert \ 19 | --features ss4 \ 20 | --dataset secondary \ 21 | --num-sequences 5000 \ 22 | --max-seq-len 512 \ 23 | --min-attn .3 \ 24 | --shuffle && 25 | python compute_edge_features.py \ 26 | --exp-name edge_features_sites_tape_bert \ 27 | --features binding_sites \ 28 | --dataset binding_sites \ 29 | --num-sequences 5000 \ 30 | --max-seq-len 512 \ 31 | --min-attn .3 \ 32 | --shuffle && 33 | python compute_edge_features.py \ 34 | --exp-name edge_features_aa_tape_bert \ 35 | --features aa \ 36 | --dataset proteinnet \ 37 | --num-sequences 5000 \ 38 | --max-seq-len 512 \ 39 | --min-attn .3 \ 40 | --shuffle -------------------------------------------------------------------------------- /protein_attention/attention_analysis/scripts/report_all_features_prot_albert.sh: -------------------------------------------------------------------------------- 1 | python report_edge_features.py edge_features_contact_prot_albert 2 | python report_edge_features.py edge_features_modifications_prot_albert 3 | python report_edge_features.py edge_features_sec_prot_albert 4 | python report_edge_features.py edge_features_sites_prot_albert 5 | python report_edge_features.py edge_features_aa_prot_albert 6 | python report_aa_correlations.py edge_features_aa_prot_albert 7 | python report_top_heads.py edge_features_contact_prot_albert 8 | python report_top_heads.py edge_features_sites_prot_albert 9 | python report_top_heads.py edge_features_modifications_prot_albert 10 | -------------------------------------------------------------------------------- 
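The report scripts above and below assume the matching compute_all_features_* script has already run: each report_*.py loads its inputs from a pickle written to the cache directory under the experiment name. A minimal sketch of reading one of those caches directly (experiment and feature names taken from the scripts above):

import pickle
from protein_attention.utils import get_cache_path

cache_path = get_cache_path() / 'edge_features_contact_prot_bert.pickle'
with open(cache_path, 'rb') as f:
    args, feature_to_weighted_sum, weight_total = pickle.load(f)

# Proportion of high-confidence attention mass that falls on contact edges, per head:
mean_by_head = feature_to_weighted_sum['contact_map'] / weight_total  # (n_layers, n_heads)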
/protein_attention/attention_analysis/scripts/report_all_features_prot_bert.sh: -------------------------------------------------------------------------------- 1 | python report_edge_features.py edge_features_contact_prot_bert 2 | python report_edge_features.py edge_features_modifications_prot_bert 3 | python report_edge_features.py edge_features_sec_prot_bert 4 | python report_edge_features.py edge_features_sites_prot_bert 5 | python report_edge_features.py edge_features_aa_prot_bert 6 | python report_aa_correlations.py edge_features_aa_prot_bert 7 | python report_top_heads.py edge_features_contact_prot_bert 8 | python report_top_heads.py edge_features_sites_prot_bert 9 | python report_top_heads.py edge_features_modifications_prot_bert 10 | 11 | -------------------------------------------------------------------------------- /protein_attention/attention_analysis/scripts/report_all_features_prot_bert_bfd.sh: -------------------------------------------------------------------------------- 1 | python report_edge_features.py edge_features_contact_prot_bert_bfd 2 | python report_edge_features.py edge_features_modifications_prot_bert_bfd 3 | python report_edge_features.py edge_features_sec_prot_bert_bfd 4 | python report_edge_features.py edge_features_sites_prot_bert_bfd 5 | python report_edge_features.py edge_features_aa_prot_bert_bfd 6 | python report_aa_correlations.py edge_features_aa_prot_bert_bfd 7 | python report_top_heads.py edge_features_contact_prot_bert_bfd 8 | python report_top_heads.py edge_features_sites_prot_bert_bfd 9 | python report_top_heads.py edge_features_modifications_prot_bert_bfd 10 | 11 | 12 | -------------------------------------------------------------------------------- /protein_attention/attention_analysis/scripts/report_all_features_prot_xlnet.sh: -------------------------------------------------------------------------------- 1 | python report_edge_features.py edge_features_contact_prot_xlnet 2 | python report_edge_features.py edge_features_modifications_prot_xlnet 3 | python report_edge_features.py edge_features_sec_prot_xlnet 4 | python report_edge_features.py edge_features_sites_prot_xlnet 5 | python report_edge_features.py edge_features_aa_prot_xlnet 6 | python report_aa_correlations.py edge_features_aa_prot_xlnet 7 | python report_top_heads.py edge_features_contact_prot_xlnet 8 | python report_top_heads.py edge_features_sites_prot_xlnet 9 | python report_top_heads.py edge_features_modifications_prot_xlnet 10 | -------------------------------------------------------------------------------- /protein_attention/attention_analysis/scripts/report_all_features_tape_bert.sh: -------------------------------------------------------------------------------- 1 | python report_edge_features.py edge_features_contact_tape_bert 2 | python report_edge_features.py edge_features_modifications_tape_bert 3 | python report_edge_features.py edge_features_sec_tape_bert 4 | python report_edge_features.py edge_features_sites_tape_bert 5 | python report_edge_features.py edge_features_aa_tape_bert 6 | python report_aa_correlations.py edge_features_aa_tape_bert 7 | python report_edge_features_combined.py --model tape_bert 8 | python report_top_heads.py edge_features_contact_tape_bert 9 | python report_top_heads.py edge_features_sites_tape_bert 10 | python report_top_heads.py edge_features_modifications_tape_bert 11 | -------------------------------------------------------------------------------- /protein_attention/datasets.py: 
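One convention shared by the token-level datasets in this file: label arrays are padded with -1 wherever the tokenizer inserts special tokens (one -1 at each end for BERT/ALBERT/TAPE tokenizers, two trailing -1s for XLNet), and collate_fn pads batches with -1 as well; the probing losses later skip these positions via ignore_index=-1. A minimal sketch of the BERT-style case (label values are illustrative):

import numpy as np

labels = np.asarray([0, 1, 0, 0, 1], np.int64)  # one binary label per residue
labels = np.pad(labels, (1, 1), 'constant', constant_values=-1)
# -> [-1, 0, 1, 0, 0, 1, -1]: the -1s align with the [CLS]/[SEP] positions and are
# ignored by CrossEntropyLoss(ignore_index=-1) in the probing models.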
-------------------------------------------------------------------------------- 1 | """Extensions to dataset classes from TAPE Repository: https://github.com/songlab-cal/tape 2 | Date Change 3 | ---------- --------------------- 4 | 05/01/2020 Added Binding Site dataset class 5 | Added one-vs-all Secondary structure dataset class 6 | 7 | """ 8 | 9 | from pathlib import Path 10 | from typing import Union, List, Tuple, Sequence, Dict, Any 11 | 12 | import numpy as np 13 | import torch 14 | from scipy.spatial.distance import pdist, squareform 15 | from tape.datasets import dataset_factory 16 | from tape.tokenizers import TAPETokenizer 17 | from torch.utils.data import Dataset 18 | from transformers import BertTokenizer, XLNetTokenizer, AlbertTokenizer 19 | 20 | ss8_cds = ['G', 'H', 'I', 'B', 'E', 'S', 'T', ' '] 21 | ss8_to_idx = {cd: i for i, cd in enumerate(ss8_cds)} 22 | 23 | ss8_blank_index = 7 24 | ss4_blank_index = 3 25 | 26 | 27 | class SecondaryStructureOneVsAllDataset(Dataset): 28 | 29 | def __init__(self, 30 | data_path: Union[str, Path], 31 | split: str, 32 | label_scheme: str, 33 | label: str, 34 | tokenizer: Union[str, TAPETokenizer, BertTokenizer, XLNetTokenizer, AlbertTokenizer] = 'iupac', 35 | in_memory: bool = False, 36 | max_seqlen: int = 512): 37 | 38 | if label_scheme != 'ss8' and label_scheme != 'ss4': 39 | raise NotImplementedError 40 | 41 | if split not in ('train', 'valid', 'casp12', 'ts115', 'cb513'): 42 | raise ValueError(f"Unrecognized split: {split}. Must be one of " 43 | f"['train', 'valid', 'casp12', " 44 | f"'ts115', 'cb513']") 45 | 46 | if isinstance(tokenizer, str): 47 | tokenizer = TAPETokenizer(vocab=tokenizer) 48 | self.tokenizer = tokenizer 49 | 50 | data_path = Path(data_path) 51 | data_file = f'secondary_structure/secondary_structure_{split}.lmdb' 52 | self.data = dataset_factory(data_path / data_file, in_memory) 53 | if label_scheme == 'ss8': 54 | self.label = ss8_to_idx[label] 55 | elif label_scheme == 'ss4': 56 | self.label = label 57 | else: 58 | raise NotImplementedError 59 | self.label_scheme = label_scheme 60 | self.max_seqlen = max_seqlen 61 | 62 | 63 | def __len__(self) -> int: 64 | return len(self.data) 65 | 66 | def __getitem__(self, index: int): 67 | item = self.data[index] 68 | sequence = item['primary'] 69 | if self.max_seqlen: 70 | sequence = sequence[:self.max_seqlen] 71 | if isinstance(self.tokenizer, BertTokenizer) or isinstance(self.tokenizer, XLNetTokenizer) or \ 72 | isinstance(self.tokenizer, AlbertTokenizer): 73 | token_ids = np.array(self.tokenizer.encode(list(sequence)), np.int64) 74 | elif isinstance(self.tokenizer, TAPETokenizer): 75 | token_ids = self.tokenizer.encode(sequence) 76 | else: 77 | raise NotImplementedError 78 | input_mask = np.ones_like(token_ids) 79 | 80 | if self.label_scheme == 'ss4': 81 | # ss8 code 7 is for blank label. 
3 is used to represent blank in ss4 82 | ss_labels = [ss4_blank_index if ss8 == ss8_blank_index else ss3 for ss3, ss8 in 83 | zip(item['ss3'], item['ss8'])] 84 | else: 85 | ss_labels = item['ss8'] 86 | if self.max_seqlen: 87 | ss_labels = ss_labels[:self.max_seqlen] 88 | labels = np.asarray([label == self.label for label in ss_labels], np.int64) 89 | if isinstance(self.tokenizer, XLNetTokenizer): 90 | # pad with two -1s at end because of sep/cls tokens 91 | labels = np.pad(labels, (0, 2), 'constant', constant_values=-1) 92 | elif isinstance(self.tokenizer, BertTokenizer) or isinstance(self.tokenizer, TAPETokenizer) or \ 93 | isinstance(self.tokenizer, AlbertTokenizer): 94 | # pad with -1s at beginning and end because of cls/sep tokens 95 | labels = np.pad(labels, (1, 1), 'constant', constant_values=-1) 96 | else: 97 | raise NotImplementedError 98 | 99 | return token_ids, input_mask, labels 100 | 101 | def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]: 102 | input_ids, input_mask, ss_label = tuple(zip(*batch)) 103 | input_ids = torch.from_numpy(pad_sequences(input_ids, 0)) 104 | input_mask = torch.from_numpy(pad_sequences(input_mask, 0)) 105 | ss_label = torch.from_numpy(pad_sequences(ss_label, -1)) 106 | 107 | output = {'input_ids': input_ids, 108 | 'input_mask': input_mask, 109 | 'targets': ss_label} 110 | 111 | return output 112 | 113 | 114 | class BindingSiteDataset(Dataset): 115 | 116 | def __init__(self, 117 | data_path: Union[str, Path], 118 | split: str, 119 | tokenizer: Union[str, TAPETokenizer, BertTokenizer] = 'iupac', 120 | in_memory: bool = False, 121 | max_seqlen: int = 512): 122 | 123 | allowed_splits = ('train', 'valid') 124 | if split not in allowed_splits: 125 | raise ValueError(f"Unrecognized split: {split}. 
Must be one of: {', '.join(allowed_splits)}") 126 | 127 | if isinstance(tokenizer, str): 128 | tokenizer = TAPETokenizer(vocab=tokenizer) 129 | self.tokenizer = tokenizer 130 | 131 | data_path = Path(data_path) 132 | data_file = f'binding_sites/binding_site_{split}.lmdb' 133 | self.data = dataset_factory(data_path / data_file, in_memory) 134 | self.max_seqlen = max_seqlen 135 | 136 | def __len__(self) -> int: 137 | return len(self.data) 138 | 139 | def __getitem__(self, index: int): 140 | item = self.data[index] 141 | sequence = item['primary'] 142 | positions = item['positions'] 143 | if self.max_seqlen: 144 | sequence = sequence[:self.max_seqlen] 145 | positions = positions[:self.max_seqlen] 146 | 147 | if isinstance(self.tokenizer, BertTokenizer) or isinstance(self.tokenizer, XLNetTokenizer) or \ 148 | isinstance(self.tokenizer, AlbertTokenizer): 149 | token_ids = np.array(self.tokenizer.encode(list(sequence)), np.int64) 150 | elif isinstance(self.tokenizer, TAPETokenizer): 151 | token_ids = self.tokenizer.encode(sequence) 152 | else: 153 | raise NotImplementedError 154 | 155 | input_mask = np.ones_like(token_ids) 156 | 157 | labels = [1 if seq_pos in item['sites'] else 0 for seq_pos in positions] 158 | 159 | if isinstance(self.tokenizer, XLNetTokenizer): 160 | # pad with two -1s at end because of sep/cls tokens 161 | labels = np.pad(labels, (0, 2), 'constant', constant_values=-1) 162 | elif isinstance(self.tokenizer, BertTokenizer) or isinstance(self.tokenizer, TAPETokenizer) or \ 163 | isinstance(self.tokenizer, AlbertTokenizer): 164 | # pad with -1s at beginning and end because of cls/sep tokens 165 | labels = np.pad(labels, (1, 1), 'constant', constant_values=-1) 166 | else: 167 | raise NotImplementedError 168 | 169 | return token_ids, input_mask, labels 170 | 171 | def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]: 172 | input_ids, input_mask, label = tuple(zip(*batch)) 173 | input_ids = torch.from_numpy(pad_sequences(input_ids, 0)) 174 | input_mask = torch.from_numpy(pad_sequences(input_mask, 0)) 175 | label = torch.from_numpy(pad_sequences(label, -1)) 176 | 177 | output = {'input_ids': input_ids, 178 | 'input_mask': input_mask, 179 | 'targets': label} 180 | 181 | return output 182 | 183 | class ProteinModificationDataset(Dataset): 184 | 185 | def __init__(self, 186 | data_path: Union[str, Path], 187 | split: str, 188 | tokenizer: Union[str, TAPETokenizer, BertTokenizer, XLNetTokenizer] = 'iupac', 189 | in_memory: bool = False, 190 | max_seqlen: int = 512): 191 | 192 | allowed_splits = ('train', 'valid') 193 | if split not in allowed_splits: 194 | raise ValueError(f"Unrecognized split: {split}. 
Must be one of: {', '.join(allowed_splits)}") 195 | 196 | if isinstance(tokenizer, str): 197 | tokenizer = TAPETokenizer(vocab=tokenizer) 198 | self.tokenizer = tokenizer 199 | 200 | data_path = Path(data_path) 201 | data_file = f'protein_modifications/protein_modification_{split}.lmdb' 202 | self.data = dataset_factory(data_path / data_file, in_memory) 203 | self.max_seqlen = max_seqlen 204 | 205 | def __len__(self) -> int: 206 | return len(self.data) 207 | 208 | def __getitem__(self, index: int): 209 | item = self.data[index] 210 | sequence = item['primary'] 211 | positions = item['positions'] 212 | if self.max_seqlen: 213 | sequence = sequence[:self.max_seqlen] 214 | positions = positions[:self.max_seqlen] 215 | 216 | if isinstance(self.tokenizer, BertTokenizer) or isinstance(self.tokenizer, XLNetTokenizer) or \ 217 | isinstance(self.tokenizer, AlbertTokenizer): 218 | token_ids = np.array(self.tokenizer.encode(list(sequence)), np.int64) 219 | elif isinstance(self.tokenizer, TAPETokenizer): 220 | token_ids = self.tokenizer.encode(sequence) 221 | else: 222 | raise NotImplementedError 223 | 224 | input_mask = np.ones_like(token_ids) 225 | 226 | labels = [1 if seq_pos in item['modifications'] else 0 for seq_pos in positions] 227 | 228 | if isinstance(self.tokenizer, XLNetTokenizer): 229 | # pad with two -1s at end because of sep/cls tokens 230 | labels = np.pad(labels, (0, 2), 'constant', constant_values=-1) 231 | elif isinstance(self.tokenizer, BertTokenizer) or isinstance(self.tokenizer, TAPETokenizer) or \ 232 | isinstance(self.tokenizer, AlbertTokenizer): 233 | # pad with -1s at beginning and end because of cls/sep tokens 234 | labels = np.pad(labels, (1, 1), 'constant', constant_values=-1) 235 | else: 236 | raise NotImplementedError 237 | 238 | return token_ids, input_mask, labels 239 | 240 | def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]: 241 | input_ids, input_mask, label = tuple(zip(*batch)) 242 | input_ids = torch.from_numpy(pad_sequences(input_ids, 0)) 243 | input_mask = torch.from_numpy(pad_sequences(input_mask, 0)) 244 | label = torch.from_numpy(pad_sequences(label, -1)) 245 | 246 | output = {'input_ids': input_ids, 247 | 'input_mask': input_mask, 248 | 'targets': label} 249 | 250 | return output 251 | 252 | 253 | class ProteinnetDataset(Dataset): 254 | 255 | def __init__(self, 256 | data_path: Union[str, Path], 257 | split: str, 258 | tokenizer: Union[str, TAPETokenizer] = 'iupac', 259 | in_memory: bool = False, 260 | max_seq_len=None): 261 | 262 | if split not in ('train', 'train_unfiltered', 'valid', 'test'): 263 | raise ValueError(f"Unrecognized split: {split}. 
Must be one of " 264 | f"['train', 'train_unfiltered', 'valid', 'test']") 265 | 266 | if isinstance(tokenizer, str): 267 | tokenizer = TAPETokenizer(vocab=tokenizer) 268 | self.tokenizer = tokenizer 269 | 270 | data_path = Path(data_path) 271 | data_file = f'proteinnet/proteinnet_{split}.lmdb' 272 | self.data = dataset_factory(data_path / data_file, in_memory) 273 | self.max_seq_len = max_seq_len 274 | 275 | def __len__(self) -> int: 276 | return len(self.data) 277 | 278 | def __getitem__(self, index: int): 279 | item = self.data[index] 280 | primary = item['primary'] 281 | tertiary = item['tertiary'] 282 | valid_mask = item['valid_mask'] 283 | if self.max_seq_len: 284 | primary = primary[:self.max_seq_len] 285 | tertiary = tertiary[:self.max_seq_len] 286 | valid_mask = valid_mask[:self.max_seq_len] 287 | 288 | protein_length = len(primary) 289 | token_ids = self.tokenizer.encode(primary) 290 | input_mask = np.ones_like(token_ids) 291 | 292 | 293 | contact_map = np.less(squareform(pdist(tertiary)), 8.0).astype(np.int64) 294 | 295 | yind, xind = np.indices(contact_map.shape) 296 | invalid_mask = ~(valid_mask[:, None] & valid_mask[None, :]) 297 | invalid_mask |= np.abs(yind - xind) < 6 298 | contact_map[invalid_mask] = -1 299 | 300 | return token_ids, input_mask, contact_map, protein_length 301 | 302 | def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]: 303 | input_ids, input_mask, contact_labels, protein_length = tuple(zip(*batch)) 304 | input_ids = torch.from_numpy(pad_sequences(input_ids, 0)) 305 | input_mask = torch.from_numpy(pad_sequences(input_mask, 0)) 306 | contact_labels = torch.from_numpy(pad_sequences(contact_labels, -1)) 307 | protein_length = torch.LongTensor(protein_length) # type: ignore 308 | 309 | return {'input_ids': input_ids, 310 | 'input_mask': input_mask, 311 | 'targets': contact_labels, 312 | 'protein_length': protein_length} 313 | 314 | 315 | def pad_sequences(sequences: Sequence, constant_value=0, dtype=None) -> np.ndarray: 316 | batch_size = len(sequences) 317 | shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist() 318 | 319 | if dtype is None: 320 | dtype = sequences[0].dtype 321 | 322 | if isinstance(sequences[0], np.ndarray): 323 | array = np.full(shape, constant_value, dtype=dtype) 324 | elif isinstance(sequences[0], torch.Tensor): 325 | array = torch.full(shape, constant_value, dtype=dtype) 326 | 327 | for arr, seq in zip(array, sequences): 328 | arrslice = tuple(slice(dim) for dim in seq.shape) 329 | arr[arrslice] = seq 330 | 331 | return array 332 | -------------------------------------------------------------------------------- /protein_attention/probing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/protein_attention/probing/__init__.py -------------------------------------------------------------------------------- /protein_attention/probing/metrics.py: -------------------------------------------------------------------------------- 1 | """Metrics for evaluating probing classifiers""" 2 | 3 | from typing import Sequence, Union 4 | 5 | import numpy as np 6 | from scipy.special import softmax 7 | 8 | 9 | def precision(target: Union[Sequence[int], Sequence[Sequence[int]]], 10 | prediction: Union[Sequence[float], Sequence[Sequence[float]]]) -> float: 11 | if isinstance(target[0], int): 12 | raise NotImplementedError 13 | else: 14 | tp = 0 15 | fp = 0 16 | tn = 0 17 | fn = 0 
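        # Each `score` below is a (seq_len, num_classes) array of logits: argmax picks the
        # predicted class per position, and positions labeled -1 (padding/special tokens)
        # are masked out before the confusion counts are accumulated.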
18 | 19 | for label, score in zip(target, prediction): 20 | label_array = np.asarray(label) 21 | pred_array = np.asarray(score).argmax(-1) 22 | mask = label_array != -1 23 | is_correct = label_array[mask] == pred_array[mask] 24 | is_incorrect = ~is_correct 25 | 26 | is_predicted_true = pred_array[mask] == 1 27 | is_predicted_false = ~is_predicted_true 28 | 29 | tp += (is_predicted_true & is_correct).sum() 30 | fp += (is_predicted_true & is_incorrect).sum() 31 | tn += (is_predicted_false & is_correct).sum() 32 | fn += (is_predicted_false & is_incorrect).sum() 33 | 34 | print('tp:', tp, 'fp:', fp, 'tn:', tn, 'fn:', fn) 35 | return tp / (tp + fp) 36 | 37 | 38 | def recall(target: Union[Sequence[int], Sequence[Sequence[int]]], 39 | prediction: Union[Sequence[float], Sequence[Sequence[float]]]) -> float: 40 | if isinstance(target[0], int): 41 | raise NotImplementedError 42 | else: 43 | tp = 0 44 | fp = 0 45 | tn = 0 46 | fn = 0 47 | 48 | for label, score in zip(target, prediction): 49 | label_array = np.asarray(label) 50 | pred_array = np.asarray(score).argmax(-1) 51 | mask = label_array != -1 52 | is_correct = label_array[mask] == pred_array[mask] 53 | is_incorrect = ~is_correct 54 | 55 | is_predicted_true = pred_array[mask] == 1 56 | is_predicted_false = ~is_predicted_true 57 | 58 | tp += (is_predicted_true & is_correct).sum() 59 | fp += (is_predicted_true & is_incorrect).sum() 60 | tn += (is_predicted_false & is_correct).sum() 61 | fn += (is_predicted_false & is_incorrect).sum() 62 | 63 | print('tp:', tp, 'fp:', fp, 'tn:', tn, 'fn:', fn) 64 | return tp / (tp + fn) 65 | 66 | 67 | def f1(target: Union[Sequence[int], Sequence[Sequence[int]]], 68 | prediction: Union[Sequence[float], Sequence[Sequence[float]]]) -> float: 69 | p = precision(target, prediction) 70 | r = recall(target, prediction) 71 | return 2 * p * r / (p + r) 72 | 73 | 74 | def precision_at_ks(ks: Sequence[int], target: Union[Sequence[int], Sequence[Sequence[int]]], 75 | prediction: Union[Sequence[float], Sequence[Sequence[float]]]) -> float: 76 | if isinstance(target[0], int): 77 | raise NotImplementedError 78 | else: 79 | top_k_all = [] 80 | for k, label, score in zip(ks, target, prediction): 81 | label_array = np.asarray(label) 82 | pred_array = np.asarray(score) 83 | num_classes = pred_array.shape[-1] 84 | if num_classes != 2: 85 | raise NotImplementedError('Currently only support binary classification tasks') 86 | probs = softmax(pred_array, axis=-1) 87 | pos_probs = probs[:, 1] 88 | mask = label_array != -1 89 | score_labels = [] 90 | num_pos = 0 91 | num_total = 0 92 | 93 | for label, pos_prob, m in zip(label_array, pos_probs, mask): 94 | if m: 95 | score_labels.append((pos_prob, label)) 96 | num_total += 1 97 | if label == 1: 98 | num_pos += 1 99 | if label not in (0, 1): 100 | print(label) 101 | # print('added', (score, label)) 102 | if len(score_labels) == 0: 103 | continue 104 | top = sorted(score_labels, reverse=True)[:k] 105 | top_labels = list(zip(*top))[1] 106 | top_k_all.extend(top_labels) 107 | return sum(top_k_all) / len(top_k_all) 108 | 109 | 110 | if __name__ == '__main__': 111 | 112 | target = [ 113 | np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]), 114 | np.array([1, 0, 1, 0, 1, 0, 1, 1])] 115 | prediction = [ 116 | np.array([[.2, .8], [0., .2], [0., 1], [.1, .9], [.9, .1], [1., 0.], [0., 1.], [0., .3], [0., .7], [0., .2], 117 | [0., 1.], [0., 0.], [0., 0.], [0., 0.]]), 118 | np.array([[0., .9], [0., .1], [0., .1], [0., .8], [0., .9], [0., 0], [0., 0]]) 119 | ] 120 | ks = [ 121 | 6, # 5/6 are correct 
122 | 3, # 2/3 are correct 123 | ] 124 | 125 | assert precision_at_ks(ks, target, prediction) == (5 + 2) / (6 + 3) 126 | -------------------------------------------------------------------------------- /protein_attention/probing/models.py: -------------------------------------------------------------------------------- 1 | """Diagnostic classifiers for probing analysis. Based on TAPE models.""" 2 | 3 | import torch 4 | from tape.models.modeling_bert import ProteinBertAbstractModel, ProteinBertModel 5 | from tape.models.modeling_utils import PairwiseContactPredictionHead 6 | from torch import nn 7 | from transformers import BertModel 8 | import torch.nn.functional as F 9 | 10 | 11 | class ProteinBertForLinearSequenceToSequenceProbingFromAttention(ProteinBertAbstractModel): 12 | """Bert head for token-level prediction tasks (secondary structure, binding sites) from attention weights""" 13 | 14 | def __init__(self, config): 15 | super().__init__(config) 16 | config.output_attentions = True 17 | self.bert = ProteinBertModel(config) 18 | self.predict = LinearSequenceToSequenceClassificationFromAttentionHead(config, 19 | ignore_index=-1) 20 | 21 | for param in self.bert.parameters(): 22 | param.requires_grad = False 23 | 24 | self.init_weights() 25 | 26 | def forward(self, input_ids, input_mask=None, targets=None): 27 | 28 | outputs = self.bert(input_ids, input_mask=input_mask) 29 | attention = outputs[-1] 30 | print('Sum of attentions', attention[0][0, 0, 0].sum()) 31 | last_layer_attention = attention[-1] 32 | outputs = self.predict(last_layer_attention, targets) #+ outputs[2:] 33 | return outputs 34 | 35 | class LinearSequenceToSequenceClassificationFromAttentionHead(nn.Module): 36 | 37 | def __init__(self, 38 | config, 39 | ignore_index=-100, 40 | dropout=0.1, 41 | num_top_weights=10): 42 | super().__init__() 43 | if hasattr(config, 'probing_heads'): 44 | self.probing_heads = config.probing_heads 45 | else: 46 | self.probing_heads = list(range(config.num_attention_heads)) 47 | self.classify = nn.Sequential( 48 | nn.Dropout(dropout), 49 | nn.Linear(len(self.probing_heads) * num_top_weights, config.num_labels)) 50 | self.num_labels = config.num_labels 51 | self._ignore_index = ignore_index 52 | self.num_top_weights = num_top_weights 53 | 54 | def forward(self, attention, targets=None): 55 | """ Args: 56 | attention: tensor of shape 57 | ``(batch_size, num_heads, sequence_length, sequence_length)`` 58 | """ 59 | batch_size = attention.shape[0] 60 | seq_len = attention.shape[2] 61 | assert attention.shape[3] == seq_len 62 | head_attentions = [] 63 | for head in self.probing_heads: 64 | head_attention = attention[:, head].squeeze(1) 65 | assert head_attention.shape == (batch_size, seq_len, seq_len) 66 | head_attentions.append(head_attention) 67 | 68 | stacked = torch.stack(head_attentions) 69 | assert stacked.shape == (len(self.probing_heads), batch_size, seq_len, seq_len) 70 | stacked = stacked.permute(1, 0, 3, 2) 71 | assert stacked.shape == (batch_size, len(self.probing_heads), seq_len, seq_len) 72 | # Now dim 2 has the attention TO a particular position and dim 3 has the attention FROM a position 73 | 74 | features = stacked.topk(self.num_top_weights)[0] 75 | assert features.shape == (batch_size, len(self.probing_heads), seq_len, self.num_top_weights) 76 | # Last dimension is the top K attention weights to a given sequence position 77 | 78 | features = features.permute(0, 2, 1, 3) 79 | assert features.shape == (batch_size, seq_len, len(self.probing_heads), self.num_top_weights) 80 | 81 |
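        # e.g. with 12 probing heads and num_top_weights=10, the flatten below yields a
        # 120-dimensional feature vector per sequence position for the linear classifier.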
features = features.flatten(start_dim=2) 82 | # Flatten last two dimension to create a single feature vector for each sequence position 83 | assert features.shape == (batch_size, seq_len, len(self.probing_heads) * self.num_top_weights) 84 | 85 | sequence_logits = self.classify(features) 86 | outputs = (sequence_logits,) 87 | if targets is not None: 88 | loss_fct = nn.CrossEntropyLoss(ignore_index=self._ignore_index) 89 | classification_loss = loss_fct( 90 | sequence_logits.view(-1, self.num_labels), targets.view(-1)) 91 | metrics = { 92 | 'accuracy': accuracy(sequence_logits.view(-1, self.num_labels), targets.view(-1), self._ignore_index), 93 | } 94 | loss_and_metrics = (classification_loss, metrics) 95 | outputs = (loss_and_metrics,) + outputs 96 | return outputs # (loss), sequence_logits 97 | 98 | 99 | 100 | class ProteinBertForContactPredictionFromAttention(ProteinBertAbstractModel): 101 | """Bert head for token-pair contact prediction from attention weights""" 102 | 103 | def __init__(self, config): 104 | super().__init__(config) 105 | config.output_attentions = True 106 | self.bert = ProteinBertModel(config) 107 | self.predict = PairwiseContactPredictionFromAttentionHead(config, 108 | ignore_index=-1) 109 | 110 | for param in self.bert.parameters(): 111 | param.requires_grad = False 112 | 113 | self.init_weights() 114 | 115 | def forward(self, input_ids, protein_length, input_mask=None, targets=None): 116 | 117 | outputs = self.bert(input_ids, input_mask=input_mask) 118 | attention = outputs[-1] 119 | last_layer_attention = attention[-1] 120 | outputs = self.predict(last_layer_attention, protein_length, targets) #+ outputs[2:] 121 | return outputs 122 | 123 | 124 | class PairwiseContactPredictionFromAttentionHead(nn.Module): 125 | 126 | def __init__(self, config, ignore_index=-100): 127 | super().__init__() 128 | if hasattr(config, 'probing_heads'): 129 | self.probing_heads = config.probing_heads 130 | else: 131 | self.probing_heads = list(range(config.num_attention_heads)) 132 | self.classifier = nn.Sequential( 133 | nn.Dropout(), nn.Linear(len(self.probing_heads), 2)) 134 | self._ignore_index = ignore_index 135 | 136 | def forward(self, attention, sequence_lengths, targets=None): 137 | """ Args: 138 | attention: tensor of shape 139 | ``(batch_size, num_heads, sequence_length, sequence_length)`` 140 | """ 141 | 142 | batch_size = attention.shape[0] 143 | seq_len = attention.shape[2] 144 | assert attention.shape[3] == seq_len 145 | head_attentions = [] 146 | for head in self.probing_heads: 147 | head_attention = attention[:, head].squeeze(1) 148 | assert head_attention.shape == (batch_size, seq_len, seq_len) 149 | head_attentions.append(head_attention) 150 | 151 | num_features = len(head_attentions) 152 | stacked = torch.stack(head_attentions) 153 | assert stacked.shape == (num_features, batch_size, seq_len, seq_len) 154 | pairwise_features = stacked.permute(1, 2, 3, 0) 155 | assert pairwise_features.shape == (batch_size, seq_len, seq_len, num_features) 156 | pairwise_features = (pairwise_features + pairwise_features.transpose(1,2))/2 # Mean attention from both directions 157 | prediction = self.classifier(pairwise_features) 158 | prediction = prediction[:, 1:-1, 1:-1].contiguous() # remove start/stop tokens 159 | assert prediction.shape == (batch_size, seq_len - 2, seq_len - 2, 2) 160 | outputs = (prediction,) 161 | 162 | if targets is not None: 163 | loss_fct = nn.CrossEntropyLoss(ignore_index=self._ignore_index) 164 | loss_prediction = prediction.view(-1, 2) 165 | assert 
loss_prediction.shape == (batch_size * (seq_len - 2)**2, 2) 166 | loss_targets = targets.view(-1) 167 | assert loss_targets.shape == (batch_size * (seq_len - 2)**2, ) 168 | contact_loss = loss_fct(loss_prediction, loss_targets) 169 | metrics = {'precision_at_l5': 170 | self.compute_precision_at_l5(sequence_lengths, prediction, targets)} 171 | loss_and_metrics = (contact_loss, metrics) 172 | outputs = (loss_and_metrics,) + outputs 173 | 174 | return outputs 175 | 176 | def compute_precision_at_l5(self, sequence_lengths, prediction, labels): 177 | with torch.no_grad(): 178 | valid_mask = labels != self._ignore_index 179 | seqpos = torch.arange(valid_mask.size(1), device=sequence_lengths.device) 180 | x_ind, y_ind = torch.meshgrid(seqpos, seqpos) 181 | valid_mask &= ((y_ind - x_ind) >= 6).unsqueeze(0) 182 | probs = F.softmax(prediction, 3)[:, :, :, 1] 183 | valid_mask = valid_mask.type_as(probs) 184 | correct = 0 185 | total = 0 186 | for length, prob, label, mask in zip(sequence_lengths, probs, labels, valid_mask): 187 | masked_prob = (prob * mask).view(-1) 188 | most_likely = masked_prob.topk(length // 5, sorted=False) 189 | selected = label.view(-1).gather(0, most_likely.indices) 190 | correct += selected.sum().float() 191 | total += selected.numel() 192 | return correct / total 193 | 194 | 195 | 196 | class ProteinBertForContactProbing(ProteinBertAbstractModel): 197 | 198 | def __init__(self, config): 199 | super().__init__(config) 200 | 201 | self.bert = ProteinBertModel(config) 202 | self.predict = PairwiseContactPredictionHead(config.hidden_size, ignore_index=-1) 203 | 204 | for param in self.bert.parameters(): 205 | param.requires_grad = False 206 | 207 | self.init_weights() 208 | 209 | def forward(self, input_ids, protein_length, input_mask=None, targets=None): 210 | outputs = self.bert(input_ids, input_mask=input_mask) 211 | sequence_output, pooled_output = outputs[:2] 212 | outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:] 213 | # (loss), prediction_scores, (hidden_states), (attentions) 214 | return outputs 215 | 216 | 217 | class ProteinBertForLinearSequenceToSequenceProbing(ProteinBertAbstractModel): 218 | """ProteinBert head for token-level prediction tasks (secondary structure, binding sites)""" 219 | 220 | def __init__(self, config): 221 | super().__init__(config) 222 | 223 | self.bert = ProteinBertModel(config) 224 | 225 | self.classify = LinearSequenceToSequenceClassificationHead( 226 | config.hidden_size, 227 | config.num_labels, 228 | ignore_index=-1, 229 | dropout=0.5) 230 | 231 | for param in self.bert.parameters(): 232 | param.requires_grad = False 233 | 234 | self.init_weights() 235 | 236 | def forward(self, input_ids, input_mask=None, targets=None): 237 | outputs = self.bert(input_ids, input_mask=input_mask) 238 | 239 | sequence_output, pooled_output = outputs[:2] 240 | outputs = self.classify(sequence_output, targets) + outputs[2:] 241 | return outputs 242 | 243 | class BertForLinearSequenceToSequenceProbing(ProteinBertAbstractModel): 244 | """Bert head for token-level prediction tasks (secondary structure, binding sites)""" 245 | 246 | def __init__(self, config): 247 | super().__init__(config) 248 | 249 | self.bert = BertModel(config) 250 | 251 | self.classify = LinearSequenceToSequenceClassificationHead( 252 | config.hidden_size, 253 | config.num_labels, 254 | ignore_index=-1, 255 | dropout=0.5) 256 | 257 | for param in self.bert.parameters(): 258 | param.requires_grad = False 259 | 260 | self.init_weights() 261 | 262 | def 
    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.bert(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.classify(sequence_output, targets) + outputs[2:]
        return outputs


class LinearSequenceToSequenceClassificationHead(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 num_labels: int,
                 ignore_index=-100,
                 dropout=0.1):
        super().__init__()
        self.classify = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels))
        self.num_labels = num_labels
        self._ignore_index = ignore_index

    def forward(self, sequence_output, targets=None):
        sequence_logits = self.classify(sequence_output)
        outputs = (sequence_logits,)
        if targets is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self._ignore_index)
            classification_loss = loss_fct(
                sequence_logits.view(-1, self.num_labels), targets.view(-1))
            metrics = {
                'accuracy': accuracy(sequence_logits.view(-1, self.num_labels), targets.view(-1), self._ignore_index),
            }
            loss_and_metrics = (classification_loss, metrics)
            outputs = (loss_and_metrics,) + outputs
        return outputs  # (loss), sequence_logits


def accuracy(logits, labels, ignore_index: int = -100):
    with torch.no_grad():
        valid_mask = (labels != ignore_index)
        predictions = logits.float().argmax(-1)
        correct = (predictions == labels) * valid_mask
        return correct.sum().float() / valid_mask.sum().float()


def f1(logits, labels, ignore_index: int = -100):
    with torch.no_grad():
        valid_mask = (labels != ignore_index)
        unique_labels = set(labels * valid_mask)
        for label in unique_labels:
            if label not in (0, 1):
                raise NotImplementedError('F1 is only supported for binary labels')
        predictions = logits.float().argmax(-1)
        tp = (((predictions == 1) & (labels == 1)) * valid_mask).sum().float()
        fp = (((predictions == 1) & (labels == 0)) * valid_mask).sum().float()
        fn = (((predictions == 0) & (labels == 1)) * valid_mask).sum().float()
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        return 2 * precision * recall / (precision + recall)
--------------------------------------------------------------------------------
/protein_attention/probing/report.py:
--------------------------------------------------------------------------------
"""Report on diagnostic classifiers for probing analysis"""

import json
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from protein_attention.utils import get_data_path
from protein_attention.utils import get_reports_path

sns.set()
sns.set_context("paper")

ss4_names = {
    0: 'Helix',
    1: 'Strand',
    2: 'Turn/Bend'
}

feature_order = [
    'Helix',
    'Turn/Bend',
    'Strand',
    'Binding Site',
    'Contact Map'
]

feature_to_metric = {
    'Helix': 'F1',
    'Turn/Bend': 'F1',
    'Strand': 'F1',
    'Binding Site': 'Precision @ L/20',
    'Contact Map': 'Precision @ L/5'
}

feature_to_title = {
    'Helix': 'Secondary Structure: Helix',
    'Turn/Bend': 'Secondary Structure: Turn/Bend',
    'Strand': 'Secondary Structure: Strand',
    'Binding Site': 'Binding Site',
    'Contact Map': 'Contact Map'
}
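
# A minimal sketch (hypothetical values) of the inputs that report() below
# expects: one score per encoder layer for each feature in feature_order, plus
# attention-probe scores for at least 'Contact Map' for the combined plot.
#
#   feature_to_scores = {'Helix': [0.48, 0.53, ...], ...}        # 12 entries each
#   attn_feature_to_scores = {'Contact Map': [0.21, 0.25, ...]}  # 12 entries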


def report(feature_to_scores, attn_feature_to_scores, report_dir, filetype='pdf'):

    # Create detail plots
    for i, feature in enumerate(feature_order):
        scores = feature_to_scores[feature]
        fig, ax = plt.subplots()
        ax.plot(list(range(12)), scores)
        ax.set_xlabel('Layer', labelpad=10, fontsize=13)
        ax.set_title(feature_to_title[feature], pad=12, fontsize=13)
        ax.set_ylabel(feature_to_metric.get(feature, ''), labelpad=10, fontsize=13)
        fname = report_dir / f'layer_probing_{feature.replace(" ", "_").replace("/", "")}.{filetype}'
        print('Saving', fname)
        plt.xticks(range(12), range(1, 13))
        plt.tight_layout()
        plt.savefig(fname, format=filetype)  # , bbox_inches='tight')
        plt.close()
        scores = np.array(scores)
        if scores.sum() > 0:
            normalized = scores / scores.sum()
            assert np.allclose(normalized.sum(), 1)
            mean_center = sum(i * normalized[i] for i in range(12))
            print(feature, 'center:', mean_center)

    # Create combined plot of layer differences
    figsize = (3, 5)
    plt.figure(figsize=figsize)
    fig, ax = plt.subplots(len(feature_order), figsize=figsize, sharex=True,
                           gridspec_kw={'wspace': 0, 'hspace': .17})
    for i, feature in enumerate(feature_order):
        scores = feature_to_scores[feature]
        diffs = [scores[i] - scores[i - 1] for i in range(1, 12)]
        ax[i].bar(list(range(11)), diffs)
        ax[i].tick_params(labelsize=6)
        ax[i].set_ylabel(feature.replace('Contact Map', 'Contact'), fontsize=8)
        ax[i].yaxis.tick_right()
    plt.xticks(list(range(11)), list(range(2, 13)))
    plt.xlabel('Layer', fontsize=8)
    fname = report_dir / f'multichart_layer_delta_probing.{filetype}'
    print('Saving', fname)
    plt.savefig(fname, format=filetype, bbox_inches='tight')
    plt.close()

    # Create combined plot
    figsize = (3, 5)
    plt.figure(figsize=figsize)
    fig, ax = plt.subplots(len(feature_order), figsize=figsize, sharex=True,
                           gridspec_kw={'wspace': 0, 'hspace': .17})

    for i, feature in enumerate(feature_order):

        scores = feature_to_scores[feature]
        ax[i].plot(list(range(12)), scores, label='Embedding probe', color='#DD8353')
        if feature == 'Contact Map':
            scores = attn_feature_to_scores[feature]
            ax[i].plot(list(range(12)), scores, label='Attention probe', color='#4D71B0')
            l = ax[i].legend(fontsize=6.3, handlelength=1, handletextpad=0.4, frameon=False)
            for text in l.get_texts():
                text.set_color('#3B3838')
        ax[i].tick_params(labelsize=6)
        ax[i].set_ylabel(feature.replace('Contact Map', 'Contact'), fontsize=8)
        ax[i].yaxis.tick_right()
        ax[i].grid(True, axis='x', color='#F3F2F3', lw=1.2)
        ax[i].grid(True, axis='y', color='#F3F2F3', lw=1.2)

    plt.xticks(range(12), range(1, 13))
    plt.xlabel('Layer', fontsize=8)
    fname = report_dir / f'multichart_layer_probing.{filetype}'
    print('Saving', fname)
    plt.savefig(fname, format=filetype, bbox_inches='tight')
    plt.close()


if __name__ == "__main__":

    data_path = get_data_path()

    feature_to_scores = {}
    attn_feature_to_scores = {}

    # Probing sec struct results
    ss_cds = [0, 1, 2]
    ss_names = ss4_names
    for ss_cd in ss_cds:
        feature = ss_names[ss_cd]
        scores = [0] * 12
        for num_layers in list(range(1, 13)):
            fname = data_path / 'probing' / f'secondary_{ss_cd}_{num_layers}/results.json'
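            # Each probing run is assumed to write a results.json of summary
            # metrics, e.g. {"f1": ..., "precision": ..., "recall": ...,
            # "precision_at_k": ...}; only the keys read below are required.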
            try:
                with open(fname) as infile:
                    results = json.load(infile)
                f1 = results['f1']
                print(feature, num_layers, f1)
                scores[num_layers - 1] = f1
            except FileNotFoundError:
                print('Skipping', fname)
                continue
        feature_to_scores[feature] = scores

    # Probing binding site results
    feature = 'Binding Site'
    scores = [0] * 12
    for num_layers in list(range(1, 13)):
        fname = data_path / 'probing' / f'binding_sites_{num_layers}/results.json'
        try:
            with open(fname) as infile:
                results = json.load(infile)
            print('binding sites', num_layers, 'f1:', results['f1'], 'precision:', results['precision'], 'recall:',
                  results['recall'], 'precision at k:', results['precision_at_k'])
            scores[num_layers - 1] = results['precision_at_k']
        except FileNotFoundError:
            print('Skipping', fname)
            continue
    feature_to_scores[feature] = scores

    # Probing contact map results
    feature = 'Contact Map'
    for use_attn in False, True:
        scores = [0] * 12
        for num_layers in list(range(1, 13)):
            fname = data_path / 'probing' / f'contact_map{"_attn" if use_attn else ""}_{num_layers}/results.json'
            try:
                with open(fname) as infile:
                    results = json.load(infile)
                print('contact maps', num_layers, 'f1:', results['f1'], 'precision:', results['precision'], 'recall:',
                      results['recall'], 'precision at k:', results['precision_at_k'])
                scores[num_layers - 1] = results['precision_at_k']
            except FileNotFoundError:
                print('Skipping', fname)
                continue
        if use_attn:
            attn_feature_to_scores[feature] = scores
        else:
            feature_to_scores[feature] = scores

    report_dir = get_reports_path() / 'probing'
    pathlib.Path(report_dir).mkdir(parents=True, exist_ok=True)
    report(feature_to_scores, attn_feature_to_scores, report_dir)
--------------------------------------------------------------------------------
/protein_attention/probing/scripts/probe_contact.sh:
--------------------------------------------------------------------------------
python probe.py \
  contact_map \
  --batch_size 2 \
  --learning_rate .00005 \
  --warmup_steps 2000 \
  --num_train_epochs 50 \
  --save_freq improvement \
  --patience 3 \
  --num_workers 0 \
  --max_seq_len 512
--------------------------------------------------------------------------------
/protein_attention/probing/scripts/probe_contact_attention.sh:
--------------------------------------------------------------------------------
python probe.py \
  contact_map \
  --attention_probe \
  --batch_size 16 \
  --learning_rate .0001 \
  --warmup_steps 2000 \
  --num_train_epochs 50 \
  --save_freq improvement \
  --patience 3 \
  --num_workers 0 \
  --max_seq_len 512
--------------------------------------------------------------------------------
/protein_attention/probing/scripts/probe_sites.sh:
--------------------------------------------------------------------------------
python probe.py \
  binding_sites \
  --batch_size 8 \
  --learning_rate .0001 \
  --warmup_steps 500 \
  --num_train_epochs 50 \
  --save_freq improvement \
  --patience 3 \
  --num_workers 0 \
  --max_seq_len 512
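# report.py is assumed to read this run's results.json from
# data/probing/binding_sites_<num_layers>/ and plot precision_at_k per layer.
# The probe_ss4_*.sh scripts below differ only in --one_vs_all_label
# (0=Helix, 1=Strand, 2=Turn/Bend, matching ss4_names in report.py).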
--------------------------------------------------------------------------------
/protein_attention/probing/scripts/probe_ss4_0.sh:
--------------------------------------------------------------------------------
python probe.py \
  secondary \
  --label_scheme ss4 \
  --one_vs_all_label 0 \
  --batch_size 8 \
  --learning_rate .0001 \
  --warmup_steps 500 \
  --num_train_epochs 50 \
  --save_freq improvement \
  --patience 3 \
  --num_workers 0 \
  --max_seq_len 512
--------------------------------------------------------------------------------
/protein_attention/probing/scripts/probe_ss4_1.sh:
--------------------------------------------------------------------------------
python probe.py \
  secondary \
  --label_scheme ss4 \
  --one_vs_all_label 1 \
  --batch_size 8 \
  --learning_rate .0001 \
  --warmup_steps 500 \
  --num_train_epochs 50 \
  --save_freq improvement \
  --patience 3 \
  --num_workers 0 \
  --max_seq_len 512
--------------------------------------------------------------------------------
/protein_attention/probing/scripts/probe_ss4_2.sh:
--------------------------------------------------------------------------------
python probe.py \
  secondary \
  --label_scheme ss4 \
  --one_vs_all_label 2 \
  --batch_size 8 \
  --learning_rate .0001 \
  --warmup_steps 500 \
  --num_train_epochs 50 \
  --save_freq improvement \
  --patience 3 \
  --num_workers 0 \
  --max_seq_len 512
--------------------------------------------------------------------------------
/protein_attention/utils.py:
--------------------------------------------------------------------------------
from pathlib import Path


def get_project_root() -> Path:
    """Returns project root folder."""
    return Path(__file__).parent.parent


def get_data_path() -> Path:
    """Returns data root folder."""
    return get_project_root() / Path('data')


def get_models_path() -> Path:
    """Returns models root folder."""
    return get_project_root() / Path('models')


def get_reports_path() -> Path:
    """Returns reports root folder."""
    return get_project_root() / Path('reports')


def get_cache_path() -> Path:
    """Returns cache folder under the data root."""
    return get_data_path() / Path('cache')


### From https://github.com/songlab-cal/tape/blob/master/tape/utils/utils.py
import typing

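# Usage sketch for the commented-out write_lmdb helper below (hypothetical
# entries; the layout matches what it writes -- pickled records keyed by
# index, plus a b'num_examples' count record):
#
#   write_lmdb('my_dataset.lmdb', [{'seq': 'MKT', 'label': 0}], map_size=2 ** 30)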
# def write_lmdb(filename: str, iterable: typing.Iterable, map_size: int = 2 ** 20):
#     """Utility for writing a dataset to an LMDB file.
#     Args:
#         filename (str): Output filename to write to
#         iterable (Iterable): An iterable dataset to write to. Entries must be pickleable.
#         map_size (int, optional): Maximum allowable size of database in bytes. Required by LMDB.
#             You will likely have to increase this. Default: 1MB.
#     """
#     import lmdb
#     import pickle as pkl
#     env = lmdb.open(filename, map_size=map_size)
#
#     with env.begin(write=True) as txn:
#         for i, entry in enumerate(iterable):
#             txn.put(str(i).encode(), pkl.dumps(entry))
#         txn.put(b'num_examples', pkl.dumps(i + 1))
#     env.close()
--------------------------------------------------------------------------------
/reports/attention_analysis/blosum/edge_features_aa_prot_bert/aa_corr_to.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/blosum/edge_features_aa_prot_bert/aa_corr_to.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/blosum/edge_features_aa_prot_bert/args.json:
--------------------------------------------------------------------------------
{"exp_name": "edge_features_aa_prot_bert", "features": ["aa"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false}
--------------------------------------------------------------------------------
/reports/attention_analysis/blosum/edge_features_aa_prot_bert/blosum62.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/blosum/edge_features_aa_prot_bert/blosum62.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/blosum/edge_features_aa_prot_bert_bfd/aa_corr_to.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/blosum/edge_features_aa_prot_bert_bfd/aa_corr_to.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/blosum/edge_features_aa_prot_bert_bfd/args.json:
--------------------------------------------------------------------------------
{"exp_name": "edge_features_aa_prot_bert_bfd", "features": ["aa"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert_bfd", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false}
--------------------------------------------------------------------------------
/reports/attention_analysis/blosum/edge_features_aa_prot_bert_bfd/blosum62.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/blosum/edge_features_aa_prot_bert_bfd/blosum62.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/blosum/edge_features_aa_prot_xlnet/aa_corr_to.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/blosum/edge_features_aa_prot_xlnet/aa_corr_to.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/blosum/edge_features_aa_prot_xlnet/args.json:
--------------------------------------------------------------------------------
{"exp_name": "edge_features_aa_prot_xlnet", "features": ["aa"], "dataset": "proteinnet", "num_sequences": 5000, "model": "xlnet", "model_version": "prot_xlnet", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false}
--------------------------------------------------------------------------------
/reports/attention_analysis/blosum/edge_features_aa_prot_xlnet/blosum62.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/blosum/edge_features_aa_prot_xlnet/blosum62.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/blosum/edge_features_aa_tape_bert/aa_corr_to.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/blosum/edge_features_aa_tape_bert/aa_corr_to.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/blosum/edge_features_aa_tape_bert/args.json:
--------------------------------------------------------------------------------
{"exp_name": "edge_features_aa_tape_bert", "features": ["aa"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": null, "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false}
--------------------------------------------------------------------------------
/reports/attention_analysis/blosum/edge_features_aa_tape_bert/blosum62.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/blosum/edge_features_aa_tape_bert/blosum62.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_A.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_A.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_C.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_C.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_D.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_D.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_E.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_E.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_F.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_F.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_G.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_G.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_H.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_H.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_I.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_I.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_K.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_K.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_L.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_L.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_M.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_M.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_N.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_N.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_P.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_P.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_Q.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_Q.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_R.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_R.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_S.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_S.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_T.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_T.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_V.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_V.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_W.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_W.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_Y.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_albert/aa_to_Y.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_albert/args.json:
--------------------------------------------------------------------------------
{"exp_name": "edge_features_aa_prot_albert", "features": ["aa"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": "prot_albert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false}
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_A.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_A.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_C.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_C.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_D.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_D.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_E.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_E.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_F.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_F.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_G.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_G.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_H.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_H.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_I.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_I.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_K.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_K.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_L.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_L.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_M.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_M.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_N.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_N.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_P.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_P.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_Q.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_Q.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_R.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_R.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_S.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_S.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_T.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_T.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_V.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_V.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_W.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_W.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_Y.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert/aa_to_Y.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert/args.json:
--------------------------------------------------------------------------------
{"exp_name": "edge_features_aa_prot_bert", "features": ["aa"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false}
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_A.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_A.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_C.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_C.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_D.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_D.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_E.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_E.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_F.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_F.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_G.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_G.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_H.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_H.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_I.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_I.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_K.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_K.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_L.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_L.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_M.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_M.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_N.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_N.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_P.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_P.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_Q.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_Q.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_R.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_R.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_S.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_S.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_T.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_T.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_V.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_V.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_W.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_W.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_Y.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_bert_bfd/aa_to_Y.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_bert_bfd/args.json:
--------------------------------------------------------------------------------
{"exp_name": "edge_features_aa_prot_bert_bfd", "features": ["aa"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert_bfd", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false}
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_A.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_A.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_C.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_C.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_D.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_D.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_E.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_E.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_F.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_F.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_G.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_G.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_H.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_H.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_I.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_I.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_K.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_K.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_L.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_L.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_M.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_M.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_N.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_N.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_P.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_P.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_Q.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_Q.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_R.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_R.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_S.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_S.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_T.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_T.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_V.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_V.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_W.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_W.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_Y.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_prot_xlnet/aa_to_Y.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_prot_xlnet/args.json:
--------------------------------------------------------------------------------
{"exp_name": "edge_features_aa_prot_xlnet", "features": ["aa"], "dataset": "proteinnet", "num_sequences": 5000, "model": "xlnet", "model_version": "prot_xlnet", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false}
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_A.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_A.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_C.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_C.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_D.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_D.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_E.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_E.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_F.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_F.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_G.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_G.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_H.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_H.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_I.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_I.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_K.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_K.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_L.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_L.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_M.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_M.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_N.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_N.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_P.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_P.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_Q.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_Q.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_R.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_R.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_S.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_S.pdf
--------------------------------------------------------------------------------
/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_T.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_T.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_aa_tape_bert/aa_to_V.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_V.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_aa_tape_bert/aa_to_W.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_W.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_aa_tape_bert/aa_to_Y.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_aa_tape_bert/aa_to_Y.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_aa_tape_bert/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_aa_tape_bert", "features": ["aa"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": null, "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_combined_tape_bert/args_edge_features_contact_tape_bert.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_contact_tape_bert", "features": ["contact_map"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": null, "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_combined_tape_bert/args_edge_features_sec_tape_bert.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sec_tape_bert", "features": ["ss4"], "dataset": "secondary", "num_sequences": 5000, "model": "bert", "model_version": null, "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_combined_tape_bert/args_edge_features_sites_tape_bert.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sites_tape_bert", "features": ["binding_sites"], "dataset": "binding_sites", "num_sequences": 5000, "model": "bert", "model_version": null, "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- 
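Each args.json above is the serialized run configuration for one attention-analysis experiment; the same twelve keys (exp_name, features, dataset, num_sequences, model, model_version, model_dir, shuffle, max_seq_len, seed, min_attn, no_cuda) recur across every report directory. A minimal sketch of how such a file could be produced, assuming argparse-style flags that mirror the JSON keys; the real flag definitions live in protein_attention/attention_analysis/compute_edge_features.py and may differ:

    import argparse
    import json

    # Hypothetical sketch: flag names mirror the keys seen in the args.json
    # files above; the actual definitions in compute_edge_features.py may differ.
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', required=True)
    parser.add_argument('--features', nargs='+', default=['aa'])
    parser.add_argument('--dataset', default='proteinnet')
    parser.add_argument('--num_sequences', type=int, default=5000)
    parser.add_argument('--model', default='bert')
    parser.add_argument('--model_version', default=None)
    parser.add_argument('--model_dir', default=None)
    parser.add_argument('--shuffle', action='store_true')
    parser.add_argument('--max_seq_len', type=int, default=512)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--min_attn', type=float, default=0.3)
    parser.add_argument('--no_cuda', action='store_true')
    args = parser.parse_args()

    # vars(args) yields exactly the flat dict serialized into args.json.
    with open('args.json', 'w') as f:
        json.dump(vars(args), f)

Persisting the parsed namespace next to the rendered PDFs is what makes each report reproducible from its directory alone.
--------------------------------------------------------------------------------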
/reports/attention_analysis/edge_features_combined_tape_bert/combined_features.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_combined_tape_bert/combined_features.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_albert/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_contact_prot_albert", "features": ["contact_map"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": "prot_albert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_albert/contact_map.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_contact_prot_albert/contact_map.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_albert_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_contact_prot_albert", "features": ["contact_map"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": "prot_albert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_albert_topheads/contact_map.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_contact_prot_albert_topheads/contact_map.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_bert/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_contact_prot_bert", "features": ["contact_map"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_bert/contact_map.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_contact_prot_bert/contact_map.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_bert_bfd/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_contact_prot_bert_bfd", "features": ["contact_map"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": 
"prot_bert_bfd", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_bert_bfd/contact_map.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_contact_prot_bert_bfd/contact_map.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_bert_bfd_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_contact_prot_bert_bfd", "features": ["contact_map"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert_bfd", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_bert_bfd_topheads/contact_map.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_contact_prot_bert_bfd_topheads/contact_map.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_bert_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_contact_prot_bert", "features": ["contact_map"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_bert_topheads/contact_map.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_contact_prot_bert_topheads/contact_map.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_xlnet/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_contact_prot_xlnet", "features": ["contact_map"], "dataset": "proteinnet", "num_sequences": 5000, "model": "xlnet", "model_version": "prot_xlnet", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_xlnet/contact_map.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_contact_prot_xlnet/contact_map.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_xlnet_topheads/args.json: 
-------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_contact_prot_xlnet", "features": ["contact_map"], "dataset": "proteinnet", "num_sequences": 5000, "model": "xlnet", "model_version": "prot_xlnet", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_prot_xlnet_topheads/contact_map.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_contact_prot_xlnet_topheads/contact_map.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_tape_bert/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_contact_tape_bert", "features": ["contact_map"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": null, "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_tape_bert/contact_map.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_contact_tape_bert/contact_map.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_tape_bert_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_contact_tape_bert", "features": ["contact_map"], "dataset": "proteinnet", "num_sequences": 5000, "model": "bert", "model_version": null, "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_contact_tape_bert_topheads/contact_map.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_contact_tape_bert_topheads/contact_map.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_albert/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_modifications_prot_albert", "features": ["protein_modifications"], "dataset": "protein_modifications", "num_sequences": 5000, "model": "bert", "model_version": "prot_albert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_albert/protein_modification_to.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_modifications_prot_albert/protein_modification_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_albert_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_modifications_prot_albert", "features": ["protein_modifications"], "dataset": "protein_modifications", "num_sequences": 5000, "model": "bert", "model_version": "prot_albert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_albert_topheads/protein_modification_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_modifications_prot_albert_topheads/protein_modification_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_bert/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_modifications_prot_bert", "features": ["protein_modifications"], "dataset": "protein_modifications", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_bert/protein_modification_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_modifications_prot_bert/protein_modification_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_bert_bfd/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_modifications_prot_bert_bfd", "features": ["protein_modifications"], "dataset": "protein_modifications", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert_bfd", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_bert_bfd/protein_modification_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_modifications_prot_bert_bfd/protein_modification_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_bert_bfd_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_modifications_prot_bert_bfd", "features": 
["protein_modifications"], "dataset": "protein_modifications", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert_bfd", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_bert_bfd_topheads/protein_modification_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_modifications_prot_bert_bfd_topheads/protein_modification_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_bert_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_modifications_prot_bert", "features": ["protein_modifications"], "dataset": "protein_modifications", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_bert_topheads/protein_modification_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_modifications_prot_bert_topheads/protein_modification_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_xlnet/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_modifications_prot_xlnet", "features": ["protein_modifications"], "dataset": "protein_modifications", "num_sequences": 5000, "model": "xlnet", "model_version": "prot_xlnet", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_xlnet/protein_modification_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_modifications_prot_xlnet/protein_modification_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_xlnet_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_modifications_prot_xlnet", "features": ["protein_modifications"], "dataset": "protein_modifications", "num_sequences": 5000, "model": "xlnet", "model_version": "prot_xlnet", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_prot_xlnet_topheads/protein_modification_to.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_modifications_prot_xlnet_topheads/protein_modification_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_tape_bert/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_modifications_tape_bert", "features": ["protein_modifications"], "dataset": "protein_modifications", "num_sequences": 5000, "model": "bert", "model_version": null, "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_tape_bert/protein_modification_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_modifications_tape_bert/protein_modification_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_tape_bert_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_modifications_tape_bert", "features": ["protein_modifications"], "dataset": "protein_modifications", "num_sequences": 5000, "model": "bert", "model_version": null, "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_modifications_tape_bert_topheads/protein_modification_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_modifications_tape_bert_topheads/protein_modification_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_albert/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sec_prot_albert", "features": ["ss4"], "dataset": "secondary", "num_sequences": 5000, "model": "bert", "model_version": "prot_albert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_albert/sec_struct_to_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_albert/sec_struct_to_0.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_albert/sec_struct_to_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_albert/sec_struct_to_1.pdf 
-------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_albert/sec_struct_to_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_albert/sec_struct_to_2.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_albert/sec_struct_to_3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_albert/sec_struct_to_3.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_bert/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sec_prot_bert", "features": ["ss4"], "dataset": "secondary", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_bert/sec_struct_to_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_bert/sec_struct_to_0.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_bert/sec_struct_to_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_bert/sec_struct_to_1.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_bert/sec_struct_to_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_bert/sec_struct_to_2.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_bert/sec_struct_to_3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_bert/sec_struct_to_3.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_bert_bfd/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sec_prot_bert_bfd", "features": ["ss4"], "dataset": "secondary", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert_bfd", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- 
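Every report directory pairs one args.json with its rendered PDFs, and the configs differ only in the features, dataset, model, and model_version fields (model_version is null for the TAPE BERT runs). A hypothetical helper, not part of the repository, that collates these configs to show which model and dataset produced each report:

    import json
    from pathlib import Path

    # Hypothetical helper: tabulate model, weights, and dataset for each
    # report directory under reports/attention_analysis.
    for args_path in sorted(Path('reports/attention_analysis').rglob('args.json')):
        cfg = json.loads(args_path.read_text())
        version = cfg.get('model_version') or 'tape (default)'
        print(f"{args_path.parent.name:55s} {cfg['model']:6s} "
              f"{version:14s} {cfg['dataset']}")

--------------------------------------------------------------------------------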
/reports/attention_analysis/edge_features_sec_prot_bert_bfd/sec_struct_to_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_bert_bfd/sec_struct_to_0.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_bert_bfd/sec_struct_to_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_bert_bfd/sec_struct_to_1.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_bert_bfd/sec_struct_to_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_bert_bfd/sec_struct_to_2.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_bert_bfd/sec_struct_to_3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_bert_bfd/sec_struct_to_3.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_xlnet/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sec_prot_xlnet", "features": ["ss4"], "dataset": "secondary", "num_sequences": 5000, "model": "xlnet", "model_version": "prot_xlnet", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_xlnet/sec_struct_to_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_xlnet/sec_struct_to_0.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_xlnet/sec_struct_to_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_xlnet/sec_struct_to_1.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_xlnet/sec_struct_to_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_xlnet/sec_struct_to_2.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_prot_xlnet/sec_struct_to_3.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_prot_xlnet/sec_struct_to_3.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_tape_bert/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sec_tape_bert", "features": ["ss4"], "dataset": "secondary", "num_sequences": 5000, "model": "bert", "model_version": null, "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_tape_bert/sec_struct_to_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_tape_bert/sec_struct_to_0.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_tape_bert/sec_struct_to_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_tape_bert/sec_struct_to_1.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_tape_bert/sec_struct_to_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_tape_bert/sec_struct_to_2.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sec_tape_bert/sec_struct_to_3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sec_tape_bert/sec_struct_to_3.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_albert/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sites_prot_albert", "features": ["binding_sites"], "dataset": "binding_sites", "num_sequences": 5000, "model": "bert", "model_version": "prot_albert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_albert/binding_site_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sites_prot_albert/binding_site_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_albert_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sites_prot_albert", "features": ["binding_sites"], 
"dataset": "binding_sites", "num_sequences": 5000, "model": "bert", "model_version": "prot_albert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_albert_topheads/binding_site_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sites_prot_albert_topheads/binding_site_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_bert/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sites_prot_bert", "features": ["binding_sites"], "dataset": "binding_sites", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_bert/binding_site_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sites_prot_bert/binding_site_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_bert_bfd/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sites_prot_bert_bfd", "features": ["binding_sites"], "dataset": "binding_sites", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert_bfd", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_bert_bfd/binding_site_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sites_prot_bert_bfd/binding_site_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_bert_bfd_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sites_prot_bert_bfd", "features": ["binding_sites"], "dataset": "binding_sites", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert_bfd", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_bert_bfd_topheads/binding_site_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sites_prot_bert_bfd_topheads/binding_site_to.pdf 
-------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_bert_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sites_prot_bert", "features": ["binding_sites"], "dataset": "binding_sites", "num_sequences": 5000, "model": "bert", "model_version": "prot_bert", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_bert_topheads/binding_site_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sites_prot_bert_topheads/binding_site_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_xlnet/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sites_prot_xlnet", "features": ["binding_sites"], "dataset": "binding_sites", "num_sequences": 5000, "model": "xlnet", "model_version": "prot_xlnet", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_xlnet/binding_site_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sites_prot_xlnet/binding_site_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_xlnet_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sites_prot_xlnet", "features": ["binding_sites"], "dataset": "binding_sites", "num_sequences": 5000, "model": "xlnet", "model_version": "prot_xlnet", "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_prot_xlnet_topheads/binding_site_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sites_prot_xlnet_topheads/binding_site_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_tape_bert/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sites_tape_bert", "features": ["binding_sites"], "dataset": "binding_sites", "num_sequences": 5000, "model": "bert", "model_version": null, "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_tape_bert/binding_site_to.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sites_tape_bert/binding_site_to.pdf -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_tape_bert_topheads/args.json: -------------------------------------------------------------------------------- 1 | {"exp_name": "edge_features_sites_tape_bert", "features": ["binding_sites"], "dataset": "binding_sites", "num_sequences": 5000, "model": "bert", "model_version": null, "model_dir": null, "shuffle": true, "max_seq_len": 512, "seed": 123, "min_attn": 0.3, "no_cuda": false} -------------------------------------------------------------------------------- /reports/attention_analysis/edge_features_sites_tape_bert_topheads/binding_site_to.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/attention_analysis/edge_features_sites_tape_bert_topheads/binding_site_to.pdf -------------------------------------------------------------------------------- /reports/probing/multichart_layer_probing.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/provis/29d4c53de825476cd9e02c38ea40288208eaea61/reports/probing/multichart_layer_probing.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tape-proteins==0.5 2 | torch==1.4.0 3 | biopython==1.77 4 | scikit-learn==0.23.1 5 | seaborn==0.12.2 6 | matplotlib==3.2.1 7 | statsmodels==0.12.0 8 | transformers==2.4.1 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from setuptools import setup, find_packages 4 | 5 | with open('requirements.txt', 'r') as reqs: 6 | requirements = reqs.read().split() 7 | 8 | setup( 9 | name='provis', 10 | packages=["protein_attention"], 11 | version='0.0.1', 12 | install_requires=requirements, 13 | ) --------------------------------------------------------------------------------
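setup.py reads the pinned requirements.txt into install_requires, so installing the package (for example with pip install -e .) resolves to the exact versions listed above. The reqs.read().split() idiom works because each pinned spec contains no internal whitespace; a slightly more defensive variant, shown here only as a sketch, would also skip blank lines and comments:

    # Sketch only: setup.py's reqs.read().split() already suffices because each
    # pinned spec ("tape-proteins==0.5", ...) contains no internal whitespace;
    # this variant would additionally skip blank lines and '#' comments.
    with open('requirements.txt') as reqs:
        requirements = [
            line.strip()
            for line in reqs
            if line.strip() and not line.strip().startswith('#')
        ]

--------------------------------------------------------------------------------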