├── .bash_profile ├── .circleci └── config.yml ├── .pylintrc ├── LICENSE ├── Pipfile ├── README.md ├── data ├── batch_context.pickle ├── en.wiki.bpe.vs25000.model └── sample-dataset.json ├── lint.sh ├── requirements-dev.txt ├── requirements.txt ├── setup.cfg ├── src ├── __init__.py ├── config.py ├── criterion.py ├── dataset.py ├── lr_decay.py ├── model.py └── model_components.py └── tests ├── __init__.py ├── test_dataset.py ├── test_model.py └── test_model_components.py /.bash_profile: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=. -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | jobs: 3 | build: 4 | docker: 5 | - image: circleci/python:3.7 6 | steps: 7 | - checkout 8 | 9 | - restore_cache: 10 | # Read about caching dependencies: https://circleci.com/docs/2.0/caching/ 11 | key: deps9-{{ .Branch }}-{{ checksum "Pipfile.lock" }} 12 | - run: 13 | command: | 14 | sudo pip install pipenv 15 | pipenv install --skip-lock --dev 16 | source .bash_profile 17 | - save_cache: # cache Python dependencies using checksum of Pipfile as the cache-key 18 | key: deps9-{{ .Branch }}-{{ checksum "Pipfile.lock" }} 19 | paths: 20 | - "venv" 21 | 22 | - run: 23 | name: run evaluation tests 24 | command: | 25 | pipenv run pytest -s --ignore=tests/test_model.py tests/ 26 | - run: 27 | name: run linting 28 | when: always 29 | command: | 30 | pipenv run bash lint.sh 31 | 32 | - store_artifacts: 33 | path: test-reports 34 | destination: test-reports 35 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MESSAGES CONTROL] 2 | 3 | # Disable the message(s) with the given id(s). 4 | # E1130 - invalid-unary-operand-type false positive https://github.com/PyCQA/pylint/issues/1472 5 | # E1136 - unsubscriptable (unsubscriptable-object) - Pylint is failing to infer correct type from astroid https://github.com/PyCQA/pylint/issues/2849 6 | # R0801 - similar lines across files 7 | # W0511 - TODO comments 8 | # W1202 - logging-format-interpolation - Behavior barring fstrings in logging https://github.com/PyCQA/pylint/issues/2395 9 | # missing-function-dosctring: docstyle handles 10 | # bad-continuation: disagrees with black formatter 11 | disable=E1130,E1136,R0801,W0511,W1202,missing-function-docstring,bad-continuation 12 | # LAST AUDITED: 2019-01-09 13 | 14 | [MASTER] 15 | 16 | # A comma-separated list of package or module names from where C extensions may 17 | # be loaded. Extensions are loading into the active Python interpreter and may 18 | # run arbitrary code 19 | extension-pkg-whitelist=numpy 20 | 21 | [TYPECHECK] 22 | 23 | # List of module names for which member attributes should not be checked 24 | # (useful for modules/projects where namespaces are manipulated during runtime 25 | # and thus existing member attributes cannot be deduced by static analysis. It 26 | # supports qualified module names, as well as Unix pattern matching. 27 | ignored-modules=cv2,numpy,tensorflow,torch 28 | 29 | # List of classes names for which member attributes should not be checked 30 | # (useful for classes with attributes dynamically set). This supports can work 31 | # with qualified names. 32 | ignored-classes=cv2,numpy,tensorflow,torch 33 | 34 | [BASIC] 35 | 36 | # Good variable names which should always be accepted, separated by a comma 37 | good-names = _, e, f, fn, i, j, k, n, N, m, M, D, p, t, v, x, X, y, Y, w, h, W, H, x1, x2, y1, y2, ax, df 38 | 39 | # Regular expression which should only match correct function names 40 | function-rgx=[a-z_][a-z0-9_]{2,70}$ 41 | 42 | # Regular expression which should only match correct method names 43 | method-rgx=[a-z_][a-z0-9_]{2,70}$ 44 | 45 | [FORMAT] 46 | 47 | # Maximum number of characters on a single line. 48 | max-line-length = 120 49 | 50 | [DESIGN] 51 | # Minimum number of public methods for a class (see R0903). 52 | min-public-methods = 0 53 | 54 | # Maximum number of attributes for a class (see R0902). 55 | max-attributes = 15 56 | 57 | max-locals = 18 58 | 59 | max-args = 8 60 | 61 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | pandas = "*" 8 | black = "*" 9 | flake8 = "*" 10 | isort = "*" 11 | pytest = "*" 12 | pytest-cov = "*" 13 | codecov = "*" 14 | coverage = "*" 15 | pandas-sphinx-theme = {git = "https://github.com/pandas-dev/pandas-sphinx-theme.git", ref = "master"} 16 | Sphinx = "*" 17 | pydocstyle='*' 18 | autoflake='*' 19 | 20 | [packages] 21 | tqdm = "==4.48.2" 22 | sentencepiece = "==0.1.92" 23 | torch = "==1.3.1" 24 | sklearn = "==0.0" 25 | jupyter = "*" 26 | pandas = "==1.1.0" 27 | pytorch-lightning = "==0.9.1rc1" 28 | pytest = "*" 29 | pylint = "*" 30 | safety = "==1.9.0" 31 | spellcheck = "==1.0.2" 32 | bandit = "==1.6.2" 33 | 34 | [requires] 35 | python_version = "3.7" 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![LICENSE](https://img.shields.io/github/license/jordiclive/Convert-PolyAI-Torch.svg)](https://github.com/jordiclive/Convert-PolyAI-Torch/blob/master/LICENSE) 2 | [![CircleCI](https://circleci.com/gh/jordiclive/Convert-PolyAI-Torch.svg?style=shield)](https://circleci.com/gh/jordiclive/Convert-PolyAI-Torch) 3 | ![GitHub issues](https://img.shields.io/github/issues/jordiclive/Convert-PolyAI-Torch.svg) 4 | 5 | 6 | # First complete Pytorch implmentation of PolyAI's [ConveRT](https://paperswithcode.com/paper/convert-efficient-and-accurate-conversational) 7 | ConveRT: Efficient and Accurate Conversational Representations from Transformers for Pytorch 8 | ## Developed By 9 | 10 | Jordan Clive(jordan.clive19@imperial.ac.uk). If you have any questions or ideas/improvements please contact me. 11 | 12 | ## Background 13 | 14 | PolyAI built the model in TensorFlow 1—they did not release the code—although, they did release the model object on TensorFlow Hub, so it can be used, fine tuned and the graph/model weights inspected. 15 | 16 | This is a Pytorch implementation built from scratch, with inspiration from [codertimo](https://github.com/codertimo/ConveRT-pytorch) who began a similar project but did not get round to implementing the whole model. 17 | 18 | 19 | ## Implementation details 20 | 21 | Note: this is only for the single context model for the moment. 22 | ... 23 | 24 | 25 | ## Discrepancies (+ possible discrepancies) with original implementation 26 | ... 27 | 28 | 29 | ## TODO 30 | 31 | - [ ] Finish optimizing on a few batches, efficiency checks (apex fused optimizer etc.) 32 | - [ ] write further training evaluation tests, Continuous Integration tests, artifacts. 33 | - [ ] Write new apache beam Dataflow script, find cheapest way to store on GCP bucket 34 | - [ ] work out tmp/ file transfer bash scripts during training for logs and checkpoints . GCSFuse 35 | - [ ] more advanced quantization akin to original paper 36 | - [ ] Pretrain on 12 GPU nodes with one Tesla K80 each for 18 hours 37 | - [ ] Do fine tuning downstream benchmarks and compare results 38 | 39 | ## Training & Logging & Checkpointing 40 | 41 | The trainer is in model.py, pass in Pytorch Lightning trainer args if familiar with those, as well as [ConveRTTrainConfig](https://github.com/jordiclive/Convert-PolyAI-Torch/blob/c4ddec5a2ef9c4077d02aeb139029f520d642b9f/src/config.py#L21) arguments. Although a lot of the Lightning had to be overriden, Lightning hooks make this rather simple, so it is well worth putting it in the Lightning framework—so it iseasier to scale up the model, and carry out distributed training and FP16 training. Although the original paper is heavily optimized for floating point 'quantization aware' optimization eg. 8 bit per embedding parameters with dynamic quantization ranges during training, which I need to look into. (One of the main points of ConveRT is it's quantization). Currently viewing logs in default /lightning_logs with Tensorboard. 42 | 43 | 44 | ``` 45 | python model.py \ 46 | --gpus 8 \ 47 | --precision 16 \ 48 | --batch_size 512 \ 49 | -- distributed_backend 'ddp' 50 | ``` 51 | 52 | 53 | 54 | ## Dataset 55 | PolyAI Reddit data corpus details on how to run on dataflow 56 | 57 | 58 | 59 | ## Repository structure 60 | 61 | ``` 62 | ├── LICENSE 63 | ├── Pipfile 64 | ├── README.md 65 | ├── data 66 | │   ├── batch_context.pickle # example model input object for testing 67 | │   ├── en.wiki.bpe.vs25000.model # tokenizer model 68 | │   └── sample-dataset.json # mini dataset for running overfit batch tests etc. 69 | ├── lint.sh 70 | ├── requirements-dev.txt 71 | ├── requirements.txt 72 | ├── setup.cfg 73 | ├── src 74 | │   ├── __init__.py 75 | │   ├── config.py #Modelconfig and training config 76 | │   ├── criterion.py 77 | │   ├── dataset.py # prepare dataloaders, with pytorch lightning DataModule 78 | │   ├── lr_decay.py # Lightning callback fn to implement linear warm up of learning rate, followed by cosine annealing 79 | │   ├── model.py # trainer in here, uses Pytorch Lightning for scale 80 | │   └── model_components.py # All model consituent components, context and reply share Transformer blocks before model forks into distinct projection mlps 81 | └── tests 82 | ├── __init__.py 83 | ├── test_dataset.py 84 | ├── test_model.py # run overfitting on small batch tests etc. check actually trains. 85 | └── test_model_components.py # check shapes etc. 86 | ``` 87 | 88 | ## License 89 | 90 | Apache License 91 | 92 | ## Citations 93 | 94 | - [ConveRT: Efficient and Accurate Conversational Representations from Transformers](https://arxiv.org/abs/1911.03688) 95 | 96 | ```bibtext 97 | @misc{1911.03688, 98 | Author = {Matthew Henderson and Iñigo Casanueva and Nikola Mrkšić and Pei-Hao Su and Tsung-Hsien Wen and Ivan Vulić}, 99 | Title = {ConveRT: Efficient and Accurate Conversational Representations from Transformers}, 100 | Year = {2019}, 101 | Eprint = {arXiv:1911.03688}, 102 | } 103 | ``` 104 | 105 | ## References 106 | 107 | The [dataset](https://github.com/jordiclive/Convert-PolyAI-Torch/blob/master/src/dataset.py) preparation code borrows heavily from [codertimo](https://github.com/codertimo). As well as seed code and inspiriation for some of the model components 108 | - [Codertimo's in progress Pytorch conveRT implementation](https://github.com/codertimo/ConveRT-pytorch) 109 | -------------------------------------------------------------------------------- /data/batch_context.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jordiclive/Convert-PolyAI-Torch/a33b44afa255ddf12c02d5037c373d2123d30b1e/data/batch_context.pickle -------------------------------------------------------------------------------- /data/en.wiki.bpe.vs25000.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jordiclive/Convert-PolyAI-Torch/a33b44afa255ddf12c02d5037c373d2123d30b1e/data/en.wiki.bpe.vs25000.model -------------------------------------------------------------------------------- /lint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -uo pipefail 4 | set +e 5 | 6 | FAILURE=false 7 | 8 | echo "safety" 9 | pipenv run safety check -r requirements.txt -r requirements-dev.txt || FAILURE=true 10 | 11 | echo 'black' 12 | pipenv run black src --line-length 120 || FAILURE=true 13 | 14 | # echo "pydocstyle" 15 | # pipenv run pydocstyle src || FAILURE=true 16 | 17 | #echo "mypy" 18 | #mypy src || FAILURE=true 19 | 20 | echo "autoflake" 21 | pipenv run autoflake --remove-all-unused-imports -i -r src || FAILURE=true 22 | 23 | #echo "pylint" 24 | #pylint src || FAILURE=true 25 | 26 | #echo "bandit" 27 | #bandit -r src || FAILURE=true 28 | 29 | echo "isort" 30 | pipenv run isort -rc src || FAILURE=true 31 | 32 | ##echo "pycodestyle" 33 | ##pycodestyle src || FAILURE=true 34 | # 35 | #echo "flake8" 36 | ##flake8 src || FAILURE=true 37 | 38 | echo "pytest" 39 | pipenv run pytest -s --ignore=tests/test_model.py tests/ || FAILURE=true 40 | 41 | echo "training evaluation" 42 | export PYTHONPATH=. 43 | pipenv run python -m unittest 44 | 45 | if [ "$FAILURE" = true ]; then 46 | echo "Linting failed" 47 | exit 1 48 | fi 49 | echo "Linting passed" 50 | exit 0 51 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | #dependency check 2 | safety 3 | bandit 4 | 5 | # for automatic code style 6 | black 7 | isort 8 | autoflake 9 | 10 | # linting 11 | flake8 12 | 13 | 14 | # for testing 15 | pytest 16 | pytest-cov 17 | codecov 18 | pcoverage 19 | 20 | # for documentation 21 | sphinx 22 | git+https://github.com/pandas-dev/pandas-sphinx-theme.git@master#egg=pandas-sphinx-theme 23 | 24 | 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.htmltqdm 2 | torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.htmltqdm 3 | sklearn 4 | tensorboard 5 | sentencepiece 6 | pytorch-lightning==0.9.0rc2 7 | 8 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # fixers 2 | [black] 3 | line-length = 120 4 | 5 | [tool:isort] 6 | line_length = 120 7 | multi_line_output = 3 8 | include_trailing_comma = True 9 | skip=docs/conf.py 10 | 11 | # linters 12 | [flake8] 13 | max-line-length = 120 14 | ignore = E501, E203 15 | 16 | [pycodestyle] 17 | max-line-length = 120 18 | ignore = E203,W503 19 | 20 | [pydocstyle] 21 | convention = numpy 22 | add-ignore = D102,D103,D104,D105,D200,D205,D400 23 | 24 | [mypy] 25 | ignore_missing_imports = True 26 | 27 | [mypy-scrap_Tests.pyi] 28 | ignore_errors = True 29 | 30 | #testing 31 | [tool:pytest] 32 | addopts = -ra -v -l 33 | # filter warnings from tensorflow 34 | # imp module -> tensorflow_core 35 | # abc module -> protobuf, botocore 36 | filterwarnings = 37 | ignore:the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses 38 | ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working 39 | ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working 40 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jordiclive/Convert-PolyAI-Torch/a33b44afa255ddf12c02d5037c373d2123d30b1e/src/__init__.py -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import NamedTuple 3 | 4 | dirname, _ = os.path.split(os.path.dirname(__file__)) 5 | 6 | class ConveRTModelConfig(NamedTuple): 7 | 8 | num_embed_hidden: int = 512 9 | feed_forward1_hidden: int = 2048 10 | feed_forward2_hidden: int = 1024 11 | num_attention_project: int = 64 12 | vocab_size: int = 25000 13 | num_encoder_layers: int = 6 14 | dropout_rate: float = 0.0 15 | n: int = 121 16 | relative_attns: list = [3, 5, 48, 48, 48, 48] 17 | num_attention_heads: int = 2 18 | token_sequence_truncation: int = 60 19 | 20 | 21 | class ConveRTTrainConfig(NamedTuple): 22 | 23 | sp_model_path: str = os.path.join(dirname, "data/en.wiki.bpe.vs25000.model") 24 | dataset_path: str = os.path.join(dirname, "data/sample-dataset.json") 25 | test_dataset_path: str = "data/sample-dataset.json" 26 | 27 | model_save_dir: str = "lightning_logs/checkpoints/" 28 | log_dir: str = "lightning_logs" 29 | device: str = "cpu" 30 | use_data_paraller: bool = True 31 | 32 | is_reddit: bool = True 33 | 34 | train_batch_size: int = 64 35 | test_batch_size: int = 256 36 | 37 | split_size: int = 8 38 | learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to' 39 | lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section 40 | lr_warmup_end: float = 1.0 # end of linear warmup section , annealing begin 41 | warmup_batch: float = 10000 # how many batches linear warm up for 42 | final_batch: float = 1e8 # final batch of training when want learning rate 43 | learning_rate_end: float = 0.0001 44 | epochs: int = 10 45 | grad_norm_clip: float = 1.0 46 | smoothing: float = 0.2 47 | l2_weight_decay: float = 1e-5 # note: different from L2 reg, as working with Adam. L2 regularization 48 | # (or any lagrange m on loss) not wise 49 | 50 | 51 | -------------------------------------------------------------------------------- /src/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LossFunction(nn.Module): 6 | @staticmethod 7 | def cosine_similarity_matrix( 8 | context_embed: torch.Tensor, reply_embed: torch.Tensor 9 | ) -> torch.Tensor: 10 | assert context_embed.size(0) == reply_embed.size(0) 11 | cosine_similarity = torch.matmul( 12 | context_embed, reply_embed.T 13 | ) # both normalized already from last layer. So cosine similarity for batch can be 14 | # efficiently calculated as simply similarity matrix 15 | return cosine_similarity 16 | 17 | def forward( 18 | self, context_embed: torch.Tensor, reply_embed: torch.Tensor 19 | ) -> torch.Tensor: 20 | cosine_similarity = self.cosine_similarity_matrix(context_embed, reply_embed) 21 | j = -torch.sum(torch.diagonal(cosine_similarity)) 22 | 23 | cosine_similarity.diagonal().copy_(torch.zeros(cosine_similarity.size(0))) 24 | # The abel smoothing implemented is not clear from the paper. As not CE loss, have negative sampling. 25 | # I assumed a lessening of how penalized the model is when assigns non zero probs 26 | # to wrong class by increasing the negative component of loss fn by label smoothing mass indicated in paper 27 | 28 | j = 0.8 * j + ( 29 | 0.2 / (cosine_similarity.size(0) * (cosine_similarity.size(0) - 1)) 30 | ) * torch.sum(cosine_similarity) 31 | 32 | 33 | j += torch.sum(torch.logsumexp(cosine_similarity, dim=0)) 34 | 35 | # torch.logsumexp(input, dim, keepdim=False, out=None) 36 | 37 | # Returns the log of summed exponentials of each row of the input tensor in the 38 | # given dimension dim. Very important The computation is numerically stable with logs/exp in loss. 39 | # This torch implementation is done in C. 40 | return j # negative of objective fn in paper as want loss fn 41 | 42 | 43 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from typing import List, NamedTuple 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | from sentencepiece import SentencePieceProcessor 8 | from torch.nn.functional import pad 9 | from torch.utils.data import DataLoader 10 | from src.config import ConveRTTrainConfig 11 | 12 | 13 | config = ConveRTTrainConfig() 14 | 15 | 16 | # Todo implement BPE from scratch with unk tokens hashed (although may achieve worse results on downstream tasks) as 17 | # perhaps not as general as bpemb's 25000.model 18 | 19 | 20 | @dataclass 21 | class EncoderInputFeature: 22 | input_ids: torch.Tensor 23 | attention_mask: torch.Tensor 24 | position_ids: torch.Tensor 25 | input_lengths: torch.Tensor 26 | 27 | def pad_sequence(self, seq_len: int): 28 | self.input_ids = pad( 29 | self.input_ids, [0, seq_len - self.input_ids.size(0)], "constant", 0 30 | ) 31 | self.attention_mask = pad( 32 | self.attention_mask, 33 | [0, seq_len - self.attention_mask.size(0)], 34 | "constant", 35 | 0, 36 | ) 37 | self.position_ids = pad( 38 | self.position_ids, [0, seq_len - self.position_ids.size(0)], "constant", 0 39 | ) 40 | 41 | 42 | @dataclass 43 | class EmbeddingPair: 44 | context: EncoderInputFeature 45 | reply: EncoderInputFeature 46 | 47 | 48 | class DataModule(pl.LightningDataModule): 49 | # using pytorch lightning, as will save a lot of time downstream, when using multi-gpu, distributed method 50 | # and for managing 16 precision. 51 | def __init__(self): 52 | super().__init__() 53 | self.input_attributes = [ 54 | "input_ids", 55 | "attention_mask", 56 | "position_ids", 57 | "input_lengths", 58 | ] 59 | 60 | def batching_input_features( 61 | self, encoder_inputs: List[EncoderInputFeature] 62 | ) -> EncoderInputFeature: 63 | max_seq_len = max( 64 | [ 65 | int(encoder_input.input_lengths.item()) 66 | for encoder_input in encoder_inputs 67 | ] 68 | ) 69 | for encoder_input in encoder_inputs: 70 | encoder_input.pad_sequence(max_seq_len) 71 | 72 | batch_features = { 73 | feature_name: torch.stack( 74 | [ 75 | getattr(encoder_input, feature_name) 76 | for encoder_input in encoder_inputs 77 | ], 78 | dim = 0, 79 | ) 80 | for feature_name in self.input_attributes 81 | } 82 | return EncoderInputFeature(**batch_features) 83 | 84 | def convert_collate_fn(self, features: List[EmbeddingPair]) -> EmbeddingPair: 85 | return EmbeddingPair( 86 | context = self.batching_input_features( 87 | [feature.context for feature in features] 88 | ), 89 | reply = self.batching_input_features([feature.reply for feature in features]), 90 | ) 91 | 92 | def train_dataloader(self, train_dataset): 93 | return DataLoader( 94 | train_dataset, 95 | config.train_batch_size, 96 | collate_fn = self.convert_collate_fn, 97 | drop_last = True, # drop last incomplete batch 98 | #num_workers = 8 99 | ) 100 | 101 | 102 | def val_dataloader(self): 103 | # Todo 104 | pass 105 | # return DataLoader() 106 | 107 | def test_dataloader(self): 108 | # Todo 109 | pass 110 | # return DataLoader() 111 | 112 | 113 | class DatasetInstance(NamedTuple): 114 | context: List[str] 115 | response: str 116 | 117 | 118 | def load_instances_from_reddit_json(dataset_path: str) -> List[DatasetInstance]: 119 | instances: List[DatasetInstance] = [] 120 | with open(dataset_path) as f: 121 | for line in f: 122 | x = json.loads(line) 123 | context_keys = sorted([key for key in x.keys() if "context" in key]) 124 | instance = DatasetInstance( 125 | context = [x[key] for key in context_keys], response = x["response"], 126 | ) 127 | instances.append(instance) 128 | return instances 129 | 130 | 131 | class RedditData(torch.utils.data.Dataset): 132 | def __init__( 133 | self, 134 | instances: List[DatasetInstance], 135 | sp_processor: SentencePieceProcessor, 136 | truncation_length: int, 137 | ): 138 | self.sp_processor = sp_processor 139 | self.instances = instances 140 | self.truncation_length = truncation_length 141 | 142 | def __len__(self): 143 | return len(self.instances) 144 | 145 | def __getitem__(self, item): 146 | context_str = self.instances[item].context[0] 147 | context_embedding = self._convert_instance_to_embedding(context_str) 148 | reply_embedding = self._convert_instance_to_embedding( 149 | self.instances[item].response 150 | ) 151 | return EmbeddingPair(context = context_embedding, reply = reply_embedding) 152 | 153 | def _convert_instance_to_embedding(self, input_str: str) -> EncoderInputFeature: 154 | input_ids = self.sp_processor.EncodeAsIds(input_str) 155 | if self.truncation_length: 156 | input_ids = input_ids[: self.truncation_length] 157 | attention_mask = [1 for _ in range(len(input_ids))] # [1]*len(input_ids) 158 | position_ids = [i for i in range(len(input_ids))] # list(range(len(input_ids)) 159 | 160 | return EncoderInputFeature( 161 | input_ids = torch.tensor(input_ids).to(config.device), 162 | attention_mask = torch.tensor(attention_mask).to(config.device), 163 | position_ids = torch.tensor(position_ids).to(config.device), 164 | input_lengths = torch.tensor(len(input_ids)).to(config.device), 165 | ) 166 | -------------------------------------------------------------------------------- /src/lr_decay.py: -------------------------------------------------------------------------------- 1 | """ 2 | Callback function to be passed to the lightning trainer 3 | –Implements the linear warm up schedule followed by cosine annealing, demarcated by current batch idx 4 | """ 5 | 6 | import math 7 | import pytorch_lightning as pl 8 | 9 | # Really not clear from paper, paper starts talking about cosine annealing when discussing 10 | # the cosine similarity measure. Needs clarification 11 | # I assume 0.1 to 1 linear warm up over first 10000 batches then annealed to 0.001 12 | 13 | 14 | class LearningRateDecayCallback(pl.Callback): 15 | def __init__( 16 | self, 17 | config, 18 | lr_decay=True, 19 | ): 20 | super().__init__() 21 | self.lr_warmup_end = config.lr_warmup_end 22 | self.lr_warmup_start = config.lr_warmup_start 23 | self.learning_rate = config.learning_rate 24 | self.warmup_batch = config.warmup_batch 25 | self.final_batch = config.final_batch 26 | 27 | self.lr_decay = lr_decay 28 | 29 | 30 | def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx): 31 | """ 32 | 33 | :param trainer: 34 | :type trainer: 35 | :param pl_module: 36 | :type pl_module: 37 | :param batch: 38 | :type batch: 39 | :param batch_idx: 40 | :type batch_idx: 41 | :param dataloader_idx: 42 | :type dataloader_idx: 43 | """ 44 | optimizer = trainer.optimizers[0] 45 | 46 | if self.lr_decay: 47 | if batch_idx < self.warmup_batch: 48 | # linear warmup, in paper: start from 0.1 to 1 over 10000 batches 49 | lr_mult = float(batch_idx) / float(max(1, self.warmup_batch)) 50 | lr = self.lr_warmup_start + lr_mult * ( 51 | self.lr_warmup_end - self.lr_warmup_start 52 | ) 53 | 54 | else: 55 | # Cosine learning rate decay 56 | progress = float(batch_idx - self.warmup_batch) / float( 57 | max(1, self.final_batch - self.warmup_batch) 58 | ) 59 | 60 | lr = max( 61 | self.learning_rate 62 | + 0.5 63 | * (1.0 + math.cos(math.pi * progress)) 64 | * (self.lr_warmup_end - self.learning_rate), 65 | self.learning_rate, 66 | ) 67 | for param_group in optimizer.param_groups: 68 | param_group["lr"] = lr 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | import pytorch_lightning as pl 7 | import torch 8 | 9 | from src.config import ConveRTModelConfig, ConveRTTrainConfig 10 | from src.criterion import LossFunction 11 | 12 | from src.model_components import FeedForward2, TransformerLayers 13 | 14 | import argparse 15 | from sentencepiece import SentencePieceProcessor 16 | from src.dataset import DataModule, RedditData, load_instances_from_reddit_json 17 | 18 | from src.lr_decay import LearningRateDecayCallback 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def set_seed(seed): 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | 29 | 30 | def find_subword_params(model): 31 | """Long winded helper fn to return Subword Embedding Params for clipping, as they are the only parameters that 32 | are gradient clipped in the paper, only calculated once after model instantiation, but before training""" 33 | embeds = set() 34 | for mn, m in model.named_modules(): 35 | for pn, p in m.named_parameters(): 36 | if mn.startswith("transformer_layers.subword_embedding"): 37 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name 38 | 39 | embeds.add(fpn) 40 | param_dict = {pn: p for pn, p in model.named_parameters()} 41 | 42 | return [param_dict[pn] for pn in sorted(list(embeds))], embeds 43 | 44 | 45 | # todo need to write own 46 | # lightning optimizer step to include torch.nn.utils.clip_grad_norm_(find_subword_params(model), config.grad_norm_clip), 47 | 48 | 49 | class SingleContextConvert(pl.LightningModule): 50 | def __init__( 51 | self, model_config: ConveRTModelConfig, train_config: ConveRTTrainConfig 52 | ): 53 | super().__init__() 54 | 55 | self.model_config = model_config 56 | self.train_config = train_config 57 | self.transformer_layers = TransformerLayers(model_config) 58 | self.ff2_context = FeedForward2(model_config) 59 | self.ff2_reply = FeedForward2(model_config) 60 | self.loss_function = LossFunction() 61 | 62 | self.weight_decay = train_config.l2_weight_decay 63 | 64 | self.hparams = self.train_config._field_defaults 65 | self.hparams.update(self.model_config._field_defaults) 66 | self.subword_params = None 67 | 68 | logger.info( 69 | "number of parameters: %e", sum(p.numel() for p in self.parameters()) 70 | ) 71 | def register_subword_params(self): 72 | self.subword_params = find_subword_params(self)[0] 73 | 74 | def forward(self, x): 75 | return self.transformer_layers(x) 76 | 77 | def backward(self, trainer, loss, optimizer, optimizer_idx): 78 | """override hook of lightning as want specific grad norm clip of only subword embedding parameters, after loss.backward() 79 | but before optimizer step""" 80 | loss.backward() 81 | torch.nn.utils.clip_grad_norm_(self.subword_params, self.train_config.grad_norm_clip) 82 | 83 | 84 | def configure_optimizers(self): 85 | """ 86 | here I did not implement weight decay on bias and Layernorm layers as is typical in modern NLP papers. 87 | I do not think the paper specified params to avoid weight decay on 88 | :return: 89 | :rtype: 90 | """ 91 | # create the optimizer, here I did not implement weight decay on bias and weight as is customary in modern 92 | # NLP papers. 93 | no_decay = ["bias", "LayerNorm.weight"] 94 | params_decay = [ 95 | p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay) 96 | ] 97 | params_nodecay = [ 98 | p for n, p in self.named_parameters() if any(nd in n for nd in no_decay) 99 | ] 100 | optim_groups = [ 101 | {"params": params_decay, "weight_decay": self.hparams.l2_weight_decay}, 102 | {"params": params_nodecay, "weight_decay": 0.0}, 103 | ] 104 | optimizer = torch.optim.AdamW( 105 | optim_groups, lr = self.hparams.learning_rate 106 | ) 107 | return optimizer 108 | 109 | def training_step(self, batch, batch_idx): 110 | batch_context = batch.context 111 | batch_reply = batch.reply 112 | rx = self(batch_context) 113 | ry = self(batch_reply) 114 | hx = self.ff2_context(rx, batch_context.attention_mask) 115 | hy = self.ff2_reply(ry, batch_reply.attention_mask) 116 | 117 | loss = self.loss_function(hx, hy) 118 | 119 | tqdm_dict = {"train_loss": loss} 120 | output = OrderedDict( 121 | {"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict} 122 | ) 123 | # result = pl.TrainResult(minimize=loss, checkpoint_on=loss) 124 | # result.log("train_loss", loss) 125 | return output 126 | 127 | def validation_step(self, batch, batch_idx): 128 | output = self.training_step(batch, batch_idx) 129 | val_output = {"val_loss": output["loss"]} 130 | return val_output 131 | 132 | 133 | def _parse_args(): 134 | """Parse command-line arguments.""" 135 | 136 | parser = argparse.ArgumentParser() 137 | #parser.add_argument("--gpus", type = int, default = 1) 138 | #parser.add_argument("--precision", type = int, default = 16) 139 | parser.add_argument("--progress_bar_refresh_rate", type = int, default = 1) 140 | parser.add_argument("--row_log_interval", type = int, default = 1) 141 | 142 | args = parser.parse_args() 143 | 144 | return args 145 | 146 | 147 | def main(**kwargs): 148 | set_seed(1) 149 | train_config = ConveRTTrainConfig() 150 | model_config = ConveRTModelConfig() 151 | tokenizer = SentencePieceProcessor() 152 | args = _parse_args() 153 | tokenizer.Load(train_config.sp_model_path) 154 | train_instances = load_instances_from_reddit_json(train_config.dataset_path) 155 | RD = RedditData(train_instances, tokenizer, 60) 156 | dm = DataModule() 157 | train_loader = dm.train_dataloader(RD) 158 | model = SingleContextConvert(model_config, train_config) 159 | lr_decay = LearningRateDecayCallback(train_config) 160 | model.register_subword_params() 161 | 162 | trainer = ( 163 | pl.Trainer.from_argparse_args(args, callbacks = [lr_decay],**kwargs) 164 | ) # ,checkpoint_callback = checkpoint_callback) # ,resume_from_checkpoint=) 165 | trainer.fit(model, train_dataloader = train_loader, val_dataloaders = train_loader) 166 | 167 | 168 | if __name__ == "__main__": 169 | main(fast_dev_run=True) 170 | -------------------------------------------------------------------------------- /src/model_components.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn.modules.normalization import LayerNorm 8 | 9 | from src.config import ConveRTModelConfig 10 | from src.dataset import EncoderInputFeature 11 | 12 | 13 | 14 | 15 | def circulant_mask(n: int, window: int) -> torch.Tensor: 16 | """ Calculate the relative attention mask, calculated once when model instatiated, as a subset of this matrix 17 | will be used for a input length less than max. 18 | i,j represent relative token positions in this matrix and in the attention scores matrix, 19 | this mask enables attention scores to be set to 0 if further than the specified window length 20 | 21 | :param n: a fixed parameter set to be larger than largest max sequence length across batches 22 | :param window: [window length], 23 | :return relative attention mask 24 | """ 25 | circulant_t = torch.zeros(n, n) 26 | # [0, 1, 2, ..., window, -1, -2, ..., window] 27 | offsets = [0] + [i for i in range(window + 1)] + [-i for i in range(window + 1)] 28 | if window >= n: 29 | return torch.ones(n, n) 30 | for offset in offsets: 31 | # size of the 1-tensor depends on the length of the diagonal 32 | circulant_t.diagonal(offset=offset).copy_(torch.ones(n - abs(offset))) 33 | return circulant_t 34 | 35 | 36 | class SubwordEmbedding(nn.Module): 37 | def __init__(self, config: ConveRTModelConfig): 38 | """ init embedding model 39 | 40 | :param config: model.config 41 | :type config: ConveRTModelConfig 42 | """ 43 | super().__init__() 44 | self.subword_embed = nn.Embedding( 45 | config.vocab_size, config.num_embed_hidden 46 | ) #eg. 25000 x 512 47 | self.m1_positional_embed = nn.Embedding(47, config.num_embed_hidden) 48 | self.m2_positional_embed = nn.Embedding(11, config.num_embed_hidden) 49 | 50 | def forward( 51 | self, input_ids: torch.Tensor, position_ids: torch.Tensor 52 | ) -> torch.Tensor: 53 | """Subword Embedding and Positional encoding, takes in sequence of sub words, calculates 54 | subword embeddings and adds positional encodings 55 | 56 | m1_positional_embed is calculated with m1_embed_weight(mod(position_ids, 47)) 57 | m2_positional_embed is calculated with m1_embed_weight(mod(position_ids, 11)) 58 | 59 | :param input_ids: raw token ids 60 | :type input_ids: torch.LongTensor 61 | :param position_ids: [description], defaults to None 62 | :type position_ids: torch.LongTensor, optional 63 | :return: return embedding sum (position{m1, m2} + sub-word) 64 | :rtype: torch.Tensor 65 | 66 | """ 67 | subword_embed = self.subword_embed.forward( 68 | input_ids 69 | ) # B x T x d_emb eg. 64 x 47 x 512 70 | m1_positional_embed = self.m1_positional_embed.forward( 71 | torch.fmod(position_ids, 47) 72 | ) 73 | m2_positional_embed = self.m2_positional_embed.forward( 74 | torch.fmod(position_ids, 11) 75 | ) # B x T x d_emb 76 | embedding = subword_embed + m1_positional_embed + m2_positional_embed 77 | return embedding 78 | 79 | 80 | class SelfAttention( 81 | nn.Module 82 | ): 83 | """normal query, key, value based self attention but with relative attention functionality 84 | and a learnable bias encoding relative token position which is added to the attention scores before the softmax""" 85 | 86 | def __init__(self, config: ConveRTModelConfig, relative_attention: int): 87 | """init self attention weight of each key, query, value and output projection layer. 88 | 89 | :param config: model config 90 | :type config: ConveRTModelConfig 91 | """ 92 | super().__init__() 93 | 94 | self.config = config 95 | self.query = nn.Linear(config.num_embed_hidden, config.num_attention_project) 96 | self.key = nn.Linear(config.num_embed_hidden, config.num_attention_project) 97 | self.value = nn.Linear(config.num_embed_hidden, config.num_attention_project) 98 | 99 | self.softmax = nn.Softmax(dim=-1) 100 | self.output_projection = nn.Linear( 101 | config.num_attention_project, config.num_embed_hidden 102 | ) 103 | self.bias = torch.nn.Parameter(torch.randn(config.n), requires_grad=True) 104 | stdv = 1.0 / math.sqrt(self.bias.data.size(0)) 105 | self.bias.data.uniform_(-stdv, stdv) 106 | self.relative_attention = relative_attention 107 | self.n = self.config.n 108 | self.half_n = self.n // 2 109 | self.register_buffer( 110 | "relative_mask", 111 | circulant_mask(config.token_sequence_truncation, self.relative_attention), 112 | ) 113 | 114 | def forward( 115 | self, attn_input: torch.Tensor, attention_mask: torch.Tensor 116 | ) -> torch.Tensor: 117 | """ calculate self-attention of query, key and weighted to value at the end. 118 | self-attention input is projected by linear layer at the first time. 119 | applying attention mask for ignore pad index attention weight. Relative attention mask 120 | applied and a learnable bias added to the attention scores. 121 | return value after apply output projection layer to value * attention 122 | 123 | :param attn_input: [description] 124 | :type attn_input: [type] 125 | :param attention_mask: [description], defaults to None 126 | :type attention_mask: [type], optional 127 | :return: [description] 128 | :rtype: [type] 129 | """ 130 | self.T = attn_input.size()[1] 131 | # input is B x max seq len x n_emb 132 | _query = self.query.forward(attn_input) 133 | _key = self.key.forward(attn_input) 134 | _value = self.value.forward(attn_input) 135 | 136 | # scaled dot product 137 | attention_scores = torch.matmul(_query, _key.transpose(1, 2)) 138 | attention_scores = attention_scores / math.sqrt( 139 | self.config.num_attention_project 140 | ) 141 | 142 | # Relative attention 143 | 144 | # extended_attention_mask = attention_mask.to(attention_scores.device) # fp16 compatibility 145 | extended_attention_mask = (1.0 - attention_mask.unsqueeze(-1)) * -10000.0 146 | attention_scores = attention_scores + extended_attention_mask 147 | 148 | 149 | # fix circulant_matrix to matrix of size 60 x60 (max token truncation_length, 150 | # register as buffer, so not keep creating masks of different sizes. 151 | 152 | attention_scores = attention_scores.masked_fill( 153 | self.relative_mask.unsqueeze(0)[:, : self.T, : self.T] == 0, float("-inf") 154 | ) 155 | 156 | # Learnable bias vector is used of max size,for each i, different subsets of it are added to the scores, where the permutations 157 | # depend on the relative position (i-j). this way cleverly allows no loops. bias vector is 2*max truncation length+1 158 | # so has a learnable parameter for each eg. (i-j) /in {-60,...60} . 159 | 160 | ii, jj = torch.meshgrid(torch.arange(self.T), torch.arange(self.T)) 161 | B_matrix = self.bias[self.n // 2 - ii + jj] 162 | 163 | attention_scores = attention_scores + B_matrix.unsqueeze(0) 164 | 165 | attention_scores = self.softmax(attention_scores) 166 | output = torch.matmul(attention_scores, _value) 167 | 168 | output = self.output_projection(output) 169 | 170 | return output # B x T x num embed hidden 64 x eg. 47 x 512 171 | 172 | 173 | class FeedForward1(nn.Module): 174 | """ feed-forward 1 is the 175 | standard FFN layer also used by Vaswani et al. (2017),""" 176 | 177 | def __init__( 178 | self, input_hidden: int, intermediate_hidden: int, dropout_rate: float = 0.0 179 | ): 180 | # 512 2048 181 | """ 182 | :param input_hidden: first-hidden layer input embed-dim 183 | :type input_hidden: int 184 | :param intermediate_hidden: layer-(hidden)-layer middle point weight 185 | :type intermediate_hidden: int 186 | :param dropout_rate: dropout rate, defaults to None 187 | :type dropout_rate: float, optional 188 | """ 189 | super().__init__() 190 | 191 | self.linear_1 = nn.Linear(input_hidden, intermediate_hidden) 192 | self.dropout = nn.Dropout(dropout_rate) 193 | self.linear_2 = nn.Linear(intermediate_hidden, input_hidden) 194 | 195 | def forward(self, x: torch.Tensor) -> torch.Tensor: 196 | """forward through fully-connected 2-layer 197 | 198 | :param x: F input 199 | :type x: torch.Tensor 200 | :return: return F output 201 | :rtype: torch.Tensor 202 | """ 203 | x = F.gelu(self.linear_1(x)) 204 | return self.linear_2(self.dropout(x)) 205 | 206 | 207 | class SharedInnerBlock(nn.Module): 208 | """ Inner 'Transformer' block, this block is repeated six times in the original paper with respective relative attentions 209 | [3, 5, 48, 48, 48, 48] 210 | 211 | """ 212 | 213 | def __init__(self, config: ConveRTModelConfig, relative_attn: int): 214 | super().__init__() 215 | """ 216 | :param config: model config 217 | :type config: ConveRTModelConfig 218 | :param config: relative attention 219 | :type config: int 220 | 221 | """ 222 | self.config = config 223 | self.self_attention = SelfAttention(config, relative_attn) 224 | self.norm1 = LayerNorm(config.num_embed_hidden) # 512 225 | self.dropout = nn.Dropout(config.dropout_rate) 226 | self.ff1 = FeedForward1( 227 | config.num_embed_hidden, config.feed_forward1_hidden, config.dropout_rate 228 | ) 229 | self.norm2 = LayerNorm(config.num_embed_hidden) 230 | 231 | def forward(self, x: torch.Tensor, attention_mask: int) -> torch.Tensor: 232 | """calculating single Transformer block 233 | 234 | 1. single-self attention (EMBED_DIM -> ATTEN_PROJ -> EMBED_DIM) 235 | 2. first residual connection -> layer norm 236 | 3. feed-forward-1 layer (EMBED_DIM -> FFD-1-DIM -> EMBED_DIM) 237 | 4. second residual connection -> layer norm 238 | 239 | :param x: embed_output: sub-word embedding + positional encoding 240 | :type x: embed_output: torch.Tensor 241 | :param attention_mask: 1.0 for token position, 0.0 for padding position, defaults to None 242 | :type attention_mask: Optional[torch.Tensor], optional 243 | :return: Transformer block forward output 244 | :rtype: torch.Tensor 245 | 246 | """ 247 | # think better practice to relabel same var , although is more confusing to read. 248 | x = x + self.self_attention(x, attention_mask=attention_mask) 249 | x = self.norm2(x) 250 | x = x + self.ff1(x) 251 | return self.norm2(x) 252 | 253 | 254 | 255 | 256 | # pretty basic, just single head. but done many times, stack to have another dimension (4 with batches).# so get stacks of B x H of attention scores T x T.. 257 | # then matrix multiply these extra stacks with the v 258 | # (B xnh)x T xT . (Bx nh xTx hs) gives (B Nh) T x hs stacks. now hs is set to be final dimension/ number of heads, so reorder the stacks (concatenating them) 259 | # can have optional extra projection layer, but doing that later 260 | 261 | 262 | class MultiheadAttention(nn.Module): 263 | """Standard non causal MHA, Half Hugging Face/Half Andrej Karpathy implementation, 264 | no need to mask as after previous layers""" 265 | 266 | def __init__(self, config: ConveRTModelConfig): 267 | super().__init__() 268 | self.num_attention_heads = config.num_attention_heads 269 | self.num_attn_proj = config.num_embed_hidden * config.num_attention_heads 270 | self.attention_head_size = int(self.num_attn_proj / self.num_attention_heads) 271 | self.all_head_size = self.num_attention_heads * self.attention_head_size 272 | 273 | self.query = nn.Linear(config.num_embed_hidden, self.num_attn_proj) 274 | self.key = nn.Linear(config.num_embed_hidden, self.num_attn_proj) 275 | self.value = nn.Linear(config.num_embed_hidden, self.num_attn_proj) 276 | 277 | self.dropout = nn.Dropout(config.dropout_rate) 278 | 279 | def forward( 280 | self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None 281 | ) -> torch.Tensor: 282 | B, T, _ = hidden_states.size() 283 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 284 | k = ( 285 | self.key(hidden_states) 286 | .view(B, T, self.num_attention_heads, self.attention_head_size) 287 | .transpose(1, 2) 288 | ) # (B, nh, T, hs) 289 | q = ( 290 | self.query(hidden_states) 291 | .view(B, T, self.num_attention_heads, self.attention_head_size) 292 | .transpose(1, 2) 293 | ) # (B, nh, T, hs) 294 | v = ( 295 | self.value(hidden_states) 296 | .view(B, T, self.num_attention_heads, self.attention_head_size) 297 | .transpose(1, 2) 298 | ) # (B, nh, T, hs) 299 | 300 | attention_scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 301 | 302 | if attention_mask is not None: 303 | attention_mask = attention_mask[:, None, None, :] 304 | attention_mask = (1.0 - attention_mask) * -10000.0 305 | 306 | attention_scores = attention_scores + attention_mask 307 | 308 | attention_scores = F.softmax(attention_scores, dim=-1) 309 | 310 | attention_scores = self.dropout(attention_scores) 311 | 312 | y = attention_scores @ v 313 | 314 | y = y.transpose(1, 2).contiguous().view(B, T, self.num_attn_proj) 315 | 316 | return y 317 | 318 | 319 | class TransformerLayers(nn.Module): 320 | def __init__(self, config: ConveRTModelConfig): 321 | super().__init__() 322 | self.config = config 323 | 324 | self.subword_embedding = SubwordEmbedding(config) 325 | self.transformer_layers = nn.ModuleList( 326 | [SharedInnerBlock(config, window) for window in config.relative_attns] 327 | ) 328 | self.MHA = MultiheadAttention(config) 329 | 330 | def forward(self, encoder_input: EncoderInputFeature) -> torch.Tensor: 331 | input_ids = encoder_input.input_ids 332 | position_ids = encoder_input.position_ids 333 | attention_mask = encoder_input.attention_mask 334 | output = self.subword_embedding(input_ids, position_ids) 335 | for l in self.transformer_layers: 336 | output = l(output, attention_mask) 337 | output = self.MHA(output) 338 | return output 339 | 340 | 341 | class FeedForward2( 342 | nn.Module 343 | ): # params are not shared for context and reply. so need two sets of weights 344 | """Fully-Connected 3-layer Linear Model""" 345 | 346 | def __init__(self, config): 347 | """ 348 | :param input_hidden: first-hidden layer input embed-dim 349 | :type input_hidden: int 350 | :param intermediate_hidden: layer-(hidden)-layer middle point weight 351 | :type intermediate_hidden: int 352 | :param dropout_rate: dropout rate, defaults to None 353 | :type dropout_rate: float, optional 354 | """ 355 | # paper specifies,skip connections,layer normalization, and orthogonal initialization 356 | 357 | super().__init__() 358 | # 3,679,744 x2 params 359 | self.linear_1 = nn.Linear( 360 | config.feed_forward2_hidden, config.feed_forward2_hidden 361 | ) 362 | self.linear_2 = nn.Linear( 363 | config.feed_forward2_hidden, config.feed_forward2_hidden 364 | ) 365 | # self.linear_3 = nn.Linear( 366 | # config.feed_forward2_hidden, config.feed_forward2_hidden 367 | # ) 368 | self.norm1 = LayerNorm(config.feed_forward2_hidden) 369 | self.norm2 = LayerNorm(config.feed_forward2_hidden) 370 | # self.norm3 = LayerNorm(config.feed_forward2_hidden) 371 | self.final = nn.Linear(config.feed_forward2_hidden, config.num_embed_hidden) 372 | self.orthogonal_initialization() # torch implementation works perfectly out the box, 373 | 374 | def orthogonal_initialization(self): 375 | for l in [ 376 | self.linear_1, 377 | self.linear_2, 378 | ]: # self.linear_3]: 379 | torch.nn.init.orthogonal_(l.weight) 380 | 381 | def forward(self, x: torch.Tensor, attn_msk: torch.Tensor) -> torch.Tensor: 382 | sentence_lengths = attn_msk.sum(1) 383 | 384 | # adding square root reduction projection separately as not a shared. 385 | # part of the diagram torch.Size([64, 50, 1024]) 386 | 387 | # x has dims B x T x 2*d_emb 388 | norms = 1 / torch.sqrt(sentence_lengths.double()).float() # 64 389 | x = norms.unsqueeze(1) * torch.sum(x, dim=1) # 64 x1024 390 | 391 | x = x + F.gelu(self.linear_1(self.norm1(x))) 392 | x = x + F.gelu(self.linear_2(self.norm2(x))) 393 | # x = x + F.gelu(self.linear_3(self.norm3(x))) 394 | 395 | return F.normalize(self.final(x), dim=1, p=2) # 64 512 396 | 397 | 398 | 399 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jordiclive/Convert-PolyAI-Torch/a33b44afa255ddf12c02d5037c373d2123d30b1e/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sentencepiece import SentencePieceProcessor 3 | 4 | 5 | from src.config import ConveRTTrainConfig 6 | from src.dataset import load_instances_from_reddit_json 7 | 8 | 9 | 10 | @pytest.fixture 11 | def config(): 12 | return ConveRTTrainConfig() 13 | 14 | 15 | @pytest.fixture 16 | def tokenizer() -> SentencePieceProcessor: 17 | tokenizer = SentencePieceProcessor() 18 | tokenizer.Load(config.sp_model_path) 19 | return tokenizer 20 | 21 | 22 | def test_load_instances_from_reddit_json(config): 23 | instances = load_instances_from_reddit_json(config.dataset_path) 24 | assert len(instances) == 1000 25 | 26 | 27 | if __name__ == "__main__": 28 | pytest.main() 29 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from src.model import main 3 | from time import time 4 | 5 | 6 | class TestModelTraining(unittest.TestCase): 7 | 'Check can overfit small batch etc. without issues' 8 | 9 | def test_fast_dev_run(self): 10 | t = time() 11 | try: 12 | main(fast_dev_run = True) 13 | except: 14 | self.fail("Obvious Training Problem!") 15 | time_taken = time() - t 16 | self.assertLess(time_taken, 10) 17 | 18 | 19 | if __name__ == '__main__': 20 | unittest.main() 21 | -------------------------------------------------------------------------------- /tests/test_model_components.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import torch 4 | import pytest 5 | import pickle 6 | from pathlib import Path 7 | 8 | from src.config import ConveRTModelConfig, ConveRTTrainConfig 9 | from src.model_components import ( 10 | SubwordEmbedding, 11 | SelfAttention, 12 | circulant_mask, 13 | FeedForward1, 14 | SharedInnerBlock, 15 | MultiheadAttention, 16 | TransformerLayers, 17 | FeedForward2, 18 | ) 19 | 20 | SEQ_LEN = 60 21 | relative_attention = 48 22 | 23 | 24 | @pytest.fixture 25 | def model_config(): 26 | return ConveRTModelConfig() 27 | 28 | 29 | @pytest.fixture 30 | def train_config(): 31 | return ConveRTTrainConfig(train_batch_size = 64, split_size = 8, learning_rate = 2e-5) 32 | 33 | 34 | def test_circulant_t(): 35 | assert circulant_mask(50, 47).sum().item() == 2494 36 | try: 37 | circulant_mask(47, 50) 38 | circulant_mask(47, 47) 39 | circulant_mask(47, 45) 40 | except ExceptionType: 41 | self.fail("ciculant_t Failed") 42 | 43 | 44 | def test_SubwordEmbedding(train_config, model_config): 45 | embedding = SubwordEmbedding(model_config) 46 | input_token_ids = torch.randint(high = model_config.vocab_size, size = (train_config.train_batch_size, SEQ_LEN)) 47 | positional_input = torch.randint(high = model_config.vocab_size, size = (train_config.train_batch_size, SEQ_LEN)) 48 | 49 | embedding_output = embedding(input_ids = input_token_ids, position_ids = positional_input) 50 | 51 | assert embedding_output.size() == (train_config.train_batch_size, SEQ_LEN, model_config.num_embed_hidden,) 52 | 53 | 54 | def test_SelfAttention(model_config, train_config): 55 | attention = SelfAttention(model_config, relative_attention) 56 | 57 | query = torch.rand(train_config.train_batch_size, SEQ_LEN, model_config.num_embed_hidden) 58 | attn_mask = torch.ones(query.size()[:-1], dtype = torch.float) 59 | output = attention(query, attn_mask) 60 | assert output.size() == (train_config.train_batch_size, SEQ_LEN, model_config.num_embed_hidden,) 61 | 62 | 63 | def test_FeedForward1(train_config, model_config): 64 | ff1 = FeedForward1(model_config.num_embed_hidden, model_config.feed_forward1_hidden, model_config.dropout_rate) 65 | embed = torch.rand(train_config.train_batch_size, SEQ_LEN, model_config.num_embed_hidden) 66 | output = ff1(embed) 67 | assert output.size() == embed.size() 68 | 69 | 70 | def test_SharedInnerBlock(train_config, model_config): 71 | from random import randrange 72 | 73 | SIB = SharedInnerBlock(model_config, model_config.relative_attns[randrange(6)]) 74 | embed = torch.rand(train_config.train_batch_size, SEQ_LEN, model_config.num_embed_hidden) 75 | attn_mask = torch.ones(embed.size()[:-1], dtype = torch.float) 76 | out1 = SIB(embed, attn_mask) 77 | assert out1.size() == embed.size() 78 | 79 | 80 | def test_MultiheadAttention(train_config, model_config): 81 | MHA = MultiheadAttention(model_config) 82 | embed = torch.rand(train_config.train_batch_size, SEQ_LEN, model_config.num_embed_hidden) 83 | attn_mask = torch.ones(embed.size()[:-1], dtype = torch.float) 84 | 85 | assert model_config.num_embed_hidden % MHA.num_attention_heads == 0 86 | 87 | assert MHA(embed, attn_mask).size() == ( 88 | train_config.train_batch_size, 89 | SEQ_LEN, 90 | model_config.num_embed_hidden * model_config.num_attention_heads, 91 | ) 92 | 93 | 94 | def test_TransformerLayers(model_config): 95 | TL = TransformerLayers(model_config) 96 | 97 | path = str(Path(__file__).parents[1].resolve() / "data" / "batch_context.pickle") 98 | with open(path, "rb") as input_file: 99 | encoder_input = pickle.load(input_file) 100 | print(type(encoder_input)) 101 | embedding = SubwordEmbedding(model_config) 102 | emb_output = embedding(encoder_input.input_ids, encoder_input.position_ids) 103 | 104 | assert TL(encoder_input).size() == emb_output.size()[:-1] + ( 105 | model_config.num_embed_hidden * model_config.num_attention_heads, 106 | ) 107 | 108 | 109 | def test_FeedForward2(model_config, train_config): 110 | embed = torch.rand( 111 | train_config.train_batch_size, SEQ_LEN, model_config.num_embed_hidden * model_config.num_attention_heads 112 | ) 113 | attn_mask = torch.ones(embed.size()[:-1], dtype = torch.float) 114 | 115 | FF2 = FeedForward2(model_config) 116 | assert FF2(embed, attn_mask).size() == (train_config.train_batch_size, model_config.num_embed_hidden) 117 | 118 | 119 | if __name__ == "__main__": 120 | pytest.main() 121 | --------------------------------------------------------------------------------