├── .gitignore ├── LICENSE ├── PostPruning.ipynb ├── README.md ├── pics └── overview.png ├── scripts ├── bert_base_uncased.json ├── glue │ ├── config_prunedistiller.py │ ├── main.py │ ├── predict_function.py │ ├── utils.py │ └── utils_glue.py ├── modeling_prunebert.py ├── pruners_and_distiller │ ├── distiller.py │ ├── pruners.py │ └── utils.py ├── run_glue.sh ├── run_squad.sh └── squad │ ├── config_prunedistiller.py │ ├── evaluate_squad.py │ ├── main.py │ └── utils.py ├── teacher_models ├── config.json └── vocab.txt └── textpruner ├── __init__.py ├── commands ├── __init__.py ├── functions.py ├── textpruner_cli.py └── utils.py ├── configurations.py ├── extentions ├── configurations.py └── pruner.py ├── model_map.py ├── model_utils ├── __init__.py ├── albert.py ├── bart.py ├── bert.py ├── electra.py ├── model_structure.py ├── mt5.py ├── roberta.py ├── t5.py ├── utils.py ├── xlm.py └── xlm_roberta.py ├── pruners ├── __init__.py ├── pipeline_pruner.py ├── transformer_pruner.py ├── utils.py └── vocabulary_pruner.py ├── tokenizer_utils ├── __init__.py ├── mt5_sp_tokenizer.py ├── roberta_gpt2_tokenizer.py ├── sp_tokenizer.py ├── subword_tokenizer.py ├── t5_sp_tokenizer.py ├── utils.py ├── xlm_tokenizer.py └── xlmr_sp_tokenizer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /PostPruning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "bert_config_file ='teacher_models/config.json'\n", 20 | "# specify your pruned model\n", 21 | "ckpt_file='pruned_models/pd-sst2-05/lr3e20_s_bs32_0.4_pf1_IS0.998_Reg3e-1_E192/gs42080.pt'" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "device='cuda'\n", 31 | "import torch\n", 32 | "import os\n", 33 | "os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION']='python'\n", 34 | "\n", 35 | "from modeling_prunebert import BertModel as PrunedBertModel\n", 36 | "from modeling_prunebert import BertForSequenceClassification\n", 37 | "from modeling_prunebert import set_head_cuts\n", 38 | "from transformers import BertConfig\n", 39 | "from textpruner import summary,inference_time\n", 40 | "from textpruner import TransformerPruner\n", 41 | "from textpruner.extentions.pruner import FineGrainedPruner\n", 42 | "\n", 43 | "config = BertConfig.from_json_file(bert_config_file)\n", 44 | "config.proj_size = 192\n", 45 | "\n", 46 | "state_dict = torch.load(ckpt_file,map_location=device)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# restore weights\n", 56 | "state_dict_items=list(state_dict.items())\n", 57 | "for k,v in state_dict_items:\n", 58 | " if k.endswith('_mask'):\n", 59 | " state_dict[k[:-5]] = state_dict[k] * state_dict[k[:-5]+'_orig']\n", 60 | "keys = [k for k in state_dict.keys() if k.endswith('_orig')]\n", 61 | "for k in keys:\n", 62 | " del state_dict[k]" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "scrolled": true 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "model = BertForSequenceClassification.from_pretrained(None,config=config,state_dict=state_dict)\n", 74 | "model.to(device)\n", 75 | "model.eval();" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "ffn_mask_list = [state_dict[f'bert.encoder.layer.{i}.output.dense.weight_mask'][0] for i in range(12)]\n", 85 | "ffn_mask = torch.stack(ffn_mask_list)\n", 86 | "qk_mask_list = [state_dict[f'bert.encoder.layer.{i}.attention.self.query.bias_mask'] for i in range(12)]\n", 87 | "vo_mask_list = [state_dict[f'bert.encoder.layer.{i}.attention.self.value.bias_mask'] for i in range(12)]\n", 88 | "qk_head_size_list = [t.reshape(12,64).sum(-1) for t in qk_mask_list]\n", 89 | "vo_head_size_list = [t.reshape(12,64).sum(-1) for t in vo_mask_list]\n", 90 | "\n", 91 | "# make qk_mask and vo_mask consistent\n", 92 | "def make_qk_vo_consistency(qk_mask_list,vo_mask_list):\n", 93 | " new_qk_mask_list = []\n", 94 | " new_vo_mask_list = []\n", 95 | " assert len(qk_mask_list)==len(vo_mask_list)\n", 96 | " for qk_mask, vo_mask in zip(qk_mask_list, vo_mask_list):\n", 97 | " if vo_mask.sum()==0: #important for empty MHA\n", 98 | " new_qk_mask = []\n", 99 | " new_vo_mask = []\n", 100 | " else:\n", 101 | " new_qk_mask = []\n", 102 
| " new_vo_mask = []\n", 103 | " qk_head_mask = qk_mask.reshape(12,64)\n", 104 | " vo_head_mask = vo_mask.reshape(12,64)\n", 105 | " for i,(qk_head, vo_head) in enumerate(zip(qk_head_mask, vo_head_mask)):\n", 106 | " if vo_head.sum()==0 and qk_head.sum()==0 :\n", 107 | " continue\n", 108 | " else:\n", 109 | " new_qk_mask.append(qk_head.clone())\n", 110 | " new_vo_mask.append(vo_head.clone())\n", 111 | " new_qk_mask = torch.stack(new_qk_mask)\n", 112 | " new_vo_mask = torch.stack(new_vo_mask)\n", 113 | " new_qk_mask_list.append(new_qk_mask)\n", 114 | " new_vo_mask_list.append(new_vo_mask)\n", 115 | " return new_qk_mask_list,new_vo_mask_list\n", 116 | "\n", 117 | "consistent_qk_mask_list,consistent_vo_mask_list = make_qk_vo_consistency(qk_mask_list,vo_mask_list)\n", 118 | "consistent_qk_head_size_list = [t.reshape(-1,64).sum(-1).int() if isinstance(t,torch.Tensor) else t for t in consistent_qk_mask_list ]\n", 119 | "consistent_vo_head_size_list = [t.reshape(-1,64).sum(-1).int() if isinstance(t,torch.Tensor) else t for t in consistent_vo_mask_list ]\n", 120 | "\n", 121 | "qk_head_cuts_list = [torch.tensor([0]+list(t)).cumsum(-1) for t in consistent_qk_head_size_list]\n", 122 | "vo_head_cuts_list = [torch.tensor([0]+list(t)).cumsum(-1) for t in consistent_vo_head_size_list]" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "def show_masks(state_dict):\n", 132 | " ffn_mask_list = torch.stack([state_dict[f'bert.encoder.layer.{i}.output.dense.weight_mask'][0] for i in range(12)]).int()\n", 133 | " qk_mask_list = torch.stack([state_dict[f'bert.encoder.layer.{i}.attention.self.query.bias_mask'] for i in range(12)]).int()\n", 134 | " vo_mask_list = torch.stack([state_dict[f'bert.encoder.layer.{i}.attention.self.value.bias_mask'] for i in range(12)]).int()\n", 135 | " qk_head_size_list = [t.reshape(12,64).sum(-1) for t in qk_mask_list]\n", 136 | " #qk_head_size_list = [t[t>0] for t in qk_head_size_list]\n", 137 | " vo_head_size_list = vo_mask_list.reshape(12,12,64).sum(-1)\n", 138 | " #vo_head_size_list = [t[t>0] for t in vo_head_size_list]\n", 139 | " print(\"=====VO=====\")\n", 140 | " for i in range(12):\n", 141 | " print(f\"{i}: {[i for i in vo_head_size_list[i].tolist() if i >0]}, {vo_head_size_list[i].sum().item()}, {(vo_head_size_list[i]>0).sum().item()}\")\n", 142 | " print(\"Total number of heads:\",(vo_head_size_list>0).sum().item())\n", 143 | " print(\"Total number of MHA layer:\",(vo_head_size_list.sum(-1)>0).sum().item())\n", 144 | " \n", 145 | " print(\"=====FFN=====\")\n", 146 | " print(f\"FFN size/12: {ffn_mask_list.sum(-1).tolist()} {(ffn_mask_list).sum().item()/12:.1f}\")\n", 147 | " print(\"Total number of FFN layers:\",(ffn_mask_list.sum(-1)>0).sum().item())\n", 148 | "show_masks(state_dict)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "inputs = torch.randint(low=0,high=10000,size=(128,512),device=device)\n", 158 | "with torch.no_grad():\n", 159 | " mean,std = inference_time(model,[inputs])\n", 160 | " print(mean,std)\n", 161 | " print(summary(model))\n", 162 | " original_outputs = model(inputs)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "# Remove weights where mask==1\n", 172 | "pruner = TransformerPruner(model)\n", 173 | "pruner.prune(ffn_mask=ffn_mask, 
save_model=False)\n", 174 | "pruner =FineGrainedPruner(model)\n", 175 | "pruner.prune(QK_mask_list=qk_mask_list,VO_mask_list=vo_mask_list,save_model=False)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "# Remove empty FFN layers and empty MHA layers\n", 185 | "\n", 186 | "from torch import nn\n", 187 | "import types\n", 188 | "def feed_forward_chunk_for_empty_ffn(self, attention_output):\n", 189 | " layer_output = self.output(attention_output)\n", 190 | " return layer_output\n", 191 | "\n", 192 | "def output_forward(self, input_tensor):\n", 193 | " return self.LayerNorm(self.dense.bias + input_tensor)\n", 194 | "\n", 195 | "def attetion_forward_for_empty_attention(self,\n", 196 | " hidden_states,\n", 197 | " attention_mask=None,\n", 198 | " head_mask=None,\n", 199 | " encoder_hidden_states=None,\n", 200 | " encoder_attention_mask=None,\n", 201 | " past_key_value=None,\n", 202 | " output_attentions=False):\n", 203 | " hidden_states = self.output.LayerNorm(self.output.dense.bias + hidden_states)\n", 204 | " return (hidden_states,)\n", 205 | "\n", 206 | "def transform(model: nn.Module,always_ffn=False, always_mha=False):\n", 207 | " base_model = model.base_model\n", 208 | " bert_layers = base_model.encoder.layer\n", 209 | " for layer in bert_layers:\n", 210 | " output = layer.output\n", 211 | " if always_ffn or output.dense.weight.numel()==0: #empty ffn\n", 212 | " print(\"replace ffn\")\n", 213 | " layer.feed_forward_chunk = types.MethodType(feed_forward_chunk_for_empty_ffn,layer)\n", 214 | " layer.output.forward = types.MethodType(output_forward,layer.output)\n", 215 | " attention_output = layer.attention.output\n", 216 | " if always_mha or attention_output.dense.weight.numel()==0: #empty attention\n", 217 | " print(\"replace mha\")\n", 218 | " layer.attention.forward = types.MethodType(attetion_forward_for_empty_attention,layer.attention)\n", 219 | "\n", 220 | "transform(model)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "set_head_cuts(model,qk_head_cuts_list,vo_head_cuts_list)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "model.eval()\n", 239 | "with torch.no_grad():\n", 240 | " pruned_outputs = model(inputs)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "# calcuate the discrepency between unpruned and pruned models\n", 250 | "torch.max((pruned_outputs.logits-original_outputs.logits).abs())" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "# show model size\n", 260 | "print(summary(model))" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "# inference time\n", 270 | "\n", 271 | "inputs = torch.randint(low=0,high=10000,size=(128,512),device=device)\n", 272 | "with torch.no_grad():\n", 273 | " mean,std = inference_time(model,[inputs])\n", 274 | "\n", 275 | "print(\"Mean: \", mean)\n", 276 | "print(\"Std: \", std)\n" 277 | ] 278 | } 279 | ], 280 | "metadata": { 281 | "kernelspec": { 282 | "display_name": "Python 3 (ipykernel)", 283 | "language": "python", 284 | "name": 
"python3" 285 | }, 286 | "language_info": { 287 | "codemirror_mode": { 288 | "name": "ipython", 289 | "version": 3 290 | }, 291 | "file_extension": ".py", 292 | "mimetype": "text/x-python", 293 | "name": "python", 294 | "nbconvert_exporter": "python", 295 | "pygments_lexer": "ipython3", 296 | "version": "3.8.16" 297 | } 298 | }, 299 | "nbformat": 4, 300 | "nbformat_minor": 2 301 | } 302 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GRAIN 2 | 3 | This is the repo of paper *Gradient-based Intra-attention Pruning on Pre-trained Language Models* accepted to ACL 2023. 4 | 5 | The repo is under construction. 6 | 7 | ![overview](pics/overview.png) 8 |

The workflow of GRAIN 9 |
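The loop below is a minimal, illustrative sketch of the workflow shown in the figure: a fine-tuned teacher is distilled into the student while the student is gradually pruned using gradient-based importance scores. It is not the actual implementation (that lives in `scripts/pruners_and_distiller/distiller.py` and `pruners.py`), and `pruner.gather_importance` / `pruner.update_masks` are placeholder names.

```
# Illustrative sketch only; the real logic is in scripts/pruners_and_distiller/.
import torch
import torch.nn.functional as F

def prune_while_distilling(model_T, model_S, pruner, dataloader, optimizer, scheduler,
                           num_epochs, temperature=8,
                           start_pruning_at=0.2, end_pruning_at=0.7, pruning_frequency=50):
    total_steps = len(dataloader) * num_epochs
    step = 0
    model_T.eval()
    for _ in range(num_epochs):
        for batch in dataloader:
            optimizer.zero_grad()
            with torch.no_grad():
                logits_T = model_T(**batch).logits
            logits_S = model_S(**batch).logits
            # soft-label (KD) loss; the repo additionally uses hidden-state matching losses
            loss = F.kl_div(F.log_softmax(logits_S / temperature, dim=-1),
                            F.softmax(logits_T / temperature, dim=-1),
                            reduction='batchmean')
            loss.backward()
            step += 1
            # inside the pruning window, accumulate gradient-based importance scores
            # and update the pruning masks every `pruning_frequency` optimizer steps
            if start_pruning_at * total_steps <= step <= end_pruning_at * total_steps:
                pruner.gather_importance(model_S)                    # placeholder name
                if step % pruning_frequency == 0:
                    pruner.update_masks(model_S, step, total_steps)  # placeholder name
            optimizer.step()
            scheduler.step()
```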

10 | 11 | ## Usage 12 | 13 | ### Step 1: Preparation 14 | 15 | 1. Prepare the teacher models. We provide the teacher models for the GLUE (MNLI, QNLI, QQP and SST2) and SQuAD tasks, which can be downloaded from [Google Drive](https://drive.google.com/file/d/1gLrdtwS4xfakvWo5-2_BTP3mPB55hohv/view?usp=sharing). Unzip `teacher_models.zip`; the content of `teacher_models` should be 16 | ``` 17 | teacher_models\ 18 | mnli\ 19 | pytorch_model.bin 20 | qnli\ 21 | pytorch_model.bin 22 | qqp\ 23 | pytorch_model.bin 24 | sst2\ 25 | pytorch_model.bin 26 | squad\ 27 | pytorch_model.bin 28 | config.json 29 | vocab.txt 30 | ``` 31 | 32 | 2. Prepare the GLUE and SQuAD datasets. Put the datasets in `datasets`. 33 | 34 | 35 | ### Step 2: Training/Distillation with Pruning 36 | 37 | We offer examples of training on GLUE and SQuAD. 38 | 39 | **GLUE** 40 | 41 | ``` 42 | cd scripts 43 | bash run_glue.sh 44 | ``` 45 | Change `TASK` in the script to one of `sst2|mnli|qnli|qqp` to run different tasks. 46 | 47 | 48 | **SQuAD** 49 | 50 | ``` 51 | cd scripts 52 | bash run_squad.sh 53 | ``` 54 | 55 | ### Post Pruning 56 | 57 | 58 | The model obtained in the above step is stored with full parameters and pruning masks. We then perform a post-pruning operation to remove the pruned weights from the model. 59 | 60 | Run `PostPruning.ipynb` and follow the steps there to remove the redundant weights and test the inference speed of the pruned model. -------------------------------------------------------------------------------- /pics/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/GRAIN/8d02105e583265f7385c052c7c532e55f24609d0/pics/overview.png -------------------------------------------------------------------------------- /scripts/bert_base_uncased.json: -------------------------------------------------------------------------------- 1 | [{ 2 | "model_type":"bert", 3 | "ckpt_file":"pretrained-models/bert/base_uncased", 4 | "config_file":"pretrained-models/bert/base_uncased/config.json", 5 | "vocab_file":"pretrained-models/bert/base_uncased/vocab.txt", 6 | "prefix":"bert_uncased", 7 | "tokenizer_kwargs":{"do_lower_case":true} 8 | }] 9 | -------------------------------------------------------------------------------- /scripts/glue/config_prunedistiller.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from transformers import BertConfig, BertTokenizer, BertForSequenceClassification 5 | MODEL_CLASSES = { 6 | "bert": (BertConfig, BertTokenizer, BertForSequenceClassification), 7 | } 8 | 9 | def parse_specs(speclist): # list of specifications 10 | if isinstance(speclist,str): 11 | with open(speclist,'r') as f: 12 | speclist = json.load(f) 13 | else: 14 | assert isinstance(speclist,dict) 15 | for item in speclist: 16 | model_type = item['model_type'] 17 | config_class, tokenizer_class, model_class = MODEL_CLASSES[model_type] 18 | 19 | item['model_class'] = model_class 20 | 21 | if item['config_file'] is not None: 22 | config = config_class.from_json_file(item['config_file']) 23 | else: 24 | config = None 25 | item['config'] = config 26 | 27 | if item['vocab_file'] is not None: 28 | kwargs = item.get('tokenizer_kwargs',{}) 29 | tokenizer = tokenizer_class(vocab_file=item['vocab_file'],**kwargs) 30 | else: 31 | tokenizer= None 32 | item['tokenizer'] = tokenizer 33 | 34 | return speclist 35 | 36 | def parse_args(opt=None): 37 | parser = argparse.ArgumentParser() 38 | 39 |
parser.add_argument("--output_dir", default=None, type=str, required=True, 40 | help="The output directory where the model checkpoints will be written.") 41 | 42 | ## Other parameters 43 | parser.add_argument("--data_dir", default=None, type=str) 44 | parser.add_argument("--max_seq_length", default=128, type=int) 45 | parser.add_argument("--do_train", default=False, action='store_true', help="Whether to run training.") 46 | parser.add_argument("--do_predict", default=False, action='store_true', help="Whether to run eval on the dev set.") 47 | parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") 48 | parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.") 49 | parser.add_argument("--learning_rate", default=3e-5, type=float, help="The initial learning rate for Adam.") 50 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 51 | help="Total number of training epochs to perform.") 52 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 53 | help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% " 54 | "of training.") 55 | parser.add_argument("--no_cuda", 56 | default=False, 57 | action='store_true', 58 | help="Whether not to use CUDA when available") 59 | parser.add_argument('--gradient_accumulation_steps', 60 | type=int, 61 | default=1, 62 | help="Number of updates steps to accumualte before performing a backward/update pass.") 63 | parser.add_argument("--local_rank", 64 | type=int, 65 | default=-1, 66 | help="local_rank for distributed training on gpus") 67 | parser.add_argument('--fp16', 68 | default=False, 69 | action='store_true', 70 | help="Whether to use 16-bit float precisoin instead of 32-bit") 71 | 72 | parser.add_argument('--seed',type=int,default=10236797) 73 | parser.add_argument('--weight_decay_rate',type=float,default=0.01) 74 | parser.add_argument('--do_eval',action='store_true') 75 | parser.add_argument('--do_test',action='store_true') 76 | parser.add_argument('--PRINT_EVERY',type=int,default=200) 77 | parser.add_argument('--ckpt_frequency',type=int,default=2) 78 | 79 | parser.add_argument('--model_spec_file',type=str) 80 | parser.add_argument('--teacher_model_path',type=str) 81 | parser.add_argument('--max_grad_norm',type=float,default=1.0) 82 | parser.add_argument('--use_MultilingualTSdataset',action='store_true') 83 | parser.add_argument('--taskname',type=str) 84 | parser.add_argument("--adam_epsilon",default=1e-6,type=float) 85 | parser.add_argument("--do_lower_case",action='store_true') #used in decoding? 
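    # The arguments below configure the pruning schedule; they are consumed by
    # PruningConfig in scripts/glue/main.py and pruners_and_distiller/pruners.py.
    # Judging from that usage, --start_pruning_at and --end_pruning_at appear to be
    # fractions of the total training steps between which the pruning masks are
    # updated every --pruning_frequency optimizer steps, and --end_weights_ratio
    # the fraction of weights kept once pruning ends.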
86 | 87 | parser.add_argument("--end_pruning_at",default=0.7,type=float) 88 | parser.add_argument("--start_pruning_at",default=0.2,type=float) 89 | 90 | parser.add_argument("--end_weights_ratio",default=0.33,type=float) 91 | parser.add_argument("--pruning_frequency",default=50,type=int) 92 | parser.add_argument("--pruner_type",default="Pruner",type=str) 93 | parser.add_argument("--IS_beta",default=0.99,type=float) 94 | parser.add_argument("--is_global",action='store_true') 95 | parser.add_argument("--is_reweight",type=float,default=1) 96 | parser.add_argument("--is_two_ratios",action='store_true') 97 | parser.add_argument("--FFN_weights_ratio",default=None,type=float) 98 | parser.add_argument("--MHA_weights_ratio",default=None,type=float) 99 | parser.add_argument("--score_type",default='grad',type=str,choices=['grad','magnitude-sumabs','magnitude-Linf','magnitude-L1','random']) 100 | parser.add_argument("--output_hidden_states",action='store_true') 101 | parser.add_argument("--dynamic_head_size",action='store_true') 102 | 103 | parser.add_argument("--matching_layers_S",type=str,default=None) 104 | parser.add_argument("--matching_layers_T",type=str,default=None) 105 | 106 | parser.add_argument("--IS_gamma",default=0,type=float) 107 | parser.add_argument("--IS_alpha",default=0.0001,type=float) 108 | parser.add_argument("--IS_alpha_head",default=None,type=float) 109 | parser.add_argument("--IS_alpha_ffn",default=None,type=float) 110 | parser.add_argument("--IS_alpha_mha",default=None,type=float) 111 | parser.add_argument("--no_dbw",action='store_true') 112 | parser.add_argument("--transform_embed",default=0,type=int) 113 | 114 | parser.add_argument("--freeze_embeddings",action='store_true') 115 | 116 | global args 117 | if opt is None: 118 | args = parser.parse_args() 119 | else: 120 | args = parser.parse_args(opt) 121 | return args 122 | 123 | if __name__ == '__main__': 124 | print (args) 125 | parse_args(['--SAVE_DIR','test']) 126 | print(args) 127 | -------------------------------------------------------------------------------- /scripts/glue/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | # Setup logging 3 | logging.basicConfig( 4 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 5 | level=logging.INFO,) 6 | logger = logging.getLogger(__name__) 7 | 8 | import os,random 9 | import numpy as np 10 | import torch 11 | from transformers import AdamW, get_linear_schedule_with_warmup 12 | from utils import divide_parameters 13 | from pruners_and_distiller.distiller import TrainingConfig, DistillationConfig, PruningConfig 14 | from pruners_and_distiller.distiller import PruneDistiller as PruneDistillerHidden 15 | from pruners_and_distiller.utils import show_masks, transform_embed 16 | from torch.utils.data import DataLoader, RandomSampler 17 | from functools import partial 18 | from predict_function import predict 19 | from utils_glue import output_modes, get_glue_dataset 20 | from config_prunedistiller import parse_specs, parse_args 21 | 22 | def set_seed(args): 23 | random.seed(args.seed) 24 | np.random.seed(args.seed) 25 | torch.manual_seed(args.seed) 26 | if args.n_gpu > 0: 27 | torch.cuda.manual_seed_all(args.seed) 28 | 29 | def args_check(args): 30 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 31 | logger.warning(f"Output directory ({args.output_dir}) already exists and is not empty.") 32 | # Setup CUDA, GPU & distributed training 33 | if args.local_rank == -1 or 
args.no_cuda: 34 | if not args.no_cuda and not torch.cuda.is_available(): 35 | raise ValueError("No CUDA available!") 36 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 37 | args.n_gpu = torch.cuda.device_count() if not args.no_cuda else 0 38 | else: 39 | # Initializes the distributed backend which sychronizes nodes/GPUs 40 | #torch.cuda.set_device(args.local_rank) 41 | device = torch.device("cuda", args.local_rank) 42 | #torch.distributed.init_process_group(backend="nccl") 43 | args.n_gpu = 1 44 | args.device = device 45 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 46 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 47 | return device, args.n_gpu 48 | 49 | def main(): 50 | args = parse_args() 51 | logger.setLevel(logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 52 | 53 | for k,v in vars(args).items(): 54 | logger.info(f"{k}:{v}") 55 | 56 | device, args.n_gpu = args_check(args) 57 | set_seed(args) 58 | os.makedirs(args.output_dir, exist_ok=True) 59 | 60 | if args.local_rank not in [-1, 0]: 61 | torch.distributed.barrier() 62 | 63 | #Build Model and load checkpoint 64 | speclist = parse_specs(args.model_spec_file) 65 | spec = speclist[0] 66 | config, tokenizer, model_class = spec['config'], spec['tokenizer'], spec['model_class'] 67 | ckpt_file = spec['ckpt_file'] 68 | prefix = spec['prefix'] 69 | model_type = spec['model_type'] 70 | return_token_type_ids = (model_type=='bert') 71 | print("debug:",return_token_type_ids) 72 | 73 | 74 | 75 | if args.local_rank == 0: 76 | torch.distributed.barrier() 77 | 78 | #read data 79 | train_dataset = None 80 | eval_datasets = None 81 | num_train_steps = None 82 | 83 | eval_langs=train_langs=['en'] 84 | train_dataset = get_glue_dataset(args.taskname, args.data_dir,'train', tokenizer,args.max_seq_length,return_token_type_ids) 85 | num_labels = train_dataset.features['label'].num_classes 86 | eval_datasets = [] 87 | split = 'test' if args.do_test else 'validation' 88 | if args.taskname=='mnli': 89 | eval_langs = ['m','mm'] 90 | eval_datasets = [get_glue_dataset(args.taskname, args.data_dir,split+'_matched', tokenizer,args.max_seq_length,return_token_type_ids), 91 | get_glue_dataset(args.taskname, args.data_dir,split+'_mismatched', tokenizer,args.max_seq_length,return_token_type_ids)] 92 | else: 93 | eval_dataset = get_glue_dataset(args.taskname, args.data_dir,split, tokenizer,args.max_seq_length,return_token_type_ids) 94 | eval_datasets = [eval_dataset] 95 | 96 | 97 | logger.info("Data loaded") 98 | 99 | config.num_labels = num_labels 100 | if args.output_hidden_states is True: 101 | config.output_hidden_states=True 102 | model = model_class.from_pretrained(ckpt_file,config=config) 103 | state_dict = torch.load(args.teacher_model_path,map_location='cpu') 104 | if args.transform_embed>0: 105 | transform_embed(model,args.transform_embed) 106 | model_T = model_class.from_pretrained(None,config=config,state_dict=state_dict) 107 | callback_func = None 108 | if args.do_predict: 109 | callback_func = partial(predict, eval_datasets=eval_datasets,eval_lang=eval_langs, args=args,taskname=args.taskname) 110 | if args.do_train: 111 | forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) 112 | args.forward_batch_size = forward_batch_size 113 | train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=args.forward_batch_size,drop_last=True) 114 | 115 | 116 | 
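    # The adaptor functions below map raw model outputs to the dict format expected
    # by PruneDistiller (see pruners_and_distiller/distiller.py): 'logits' feeds the
    # soft-label distillation loss, and 'hidden' plus 'inputs_mask' (returned when
    # --output_hidden_states is set) feed the hidden-state matching loss.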
117 | def AdaptorTrain(batch, model_outputs): 118 | return {'losses':(model_outputs[0],)} 119 | def AdaptorLogits(batch, model_outputs): 120 | return {'logits': (model_outputs.logits,)} 121 | def AdaptorLogitsHidden(batch, model_outputs): 122 | return {'logits': (model_outputs.logits,), 123 | 'hidden': (model_outputs.hidden_states), 124 | 'inputs_mask': batch['attention_mask'],} 125 | 126 | if args.output_hidden_states is True: 127 | print("use hidden") 128 | Adaptor = AdaptorLogitsHidden 129 | else: 130 | Adaptor = AdaptorLogits 131 | PruneDistiller = PruneDistillerHidden 132 | 133 | #parameters 134 | params = list(model.named_parameters()) 135 | frozen_params = [] 136 | if args.freeze_embeddings: 137 | frozen_params = ['word_embeddings'] 138 | #all_trainable_params = divide_parameters(params, lr=args.learning_rate) 139 | no_decay = ['bias','LayerNorm.weight'] 140 | large_lr = ['attention_head_scale'] 141 | all_trainable_params = [ 142 | { 143 | "params":[p for n,p in params if not any(nd in n for nd in no_decay+frozen_params)], 144 | "weight_decay": args.weight_decay_rate, 145 | }, 146 | { 147 | 'params': [p for n,p in params if any(nd in n for nd in no_decay)], 148 | 'weight_decay':0.0 149 | }, 150 | { 151 | 'params': [p for n,p in params if any(nd in n for nd in frozen_params)], 152 | 'lr':0.0 153 | } 154 | ] 155 | logger.info("Length of all_trainable_params: %d", len(all_trainable_params)) 156 | 157 | ########## PruneDistiller ########### 158 | train_config = TrainingConfig( 159 | gradient_accumulation_steps = args.gradient_accumulation_steps, 160 | ckpt_frequency = args.ckpt_frequency, 161 | #ckpt_steps = int(num_train_steps//args.num_train_epochs//2), 162 | log_dir = args.output_dir, 163 | output_dir = args.output_dir, 164 | fp16 = args.fp16, 165 | device = args.device) 166 | if args.matching_layers_S is not None: 167 | matching_layers = list(zip(map(int,args.matching_layers_S.split(',')),map(int,args.matching_layers_T.split(',')))) 168 | else: 169 | matching_layers = None 170 | distill_config = DistillationConfig(temperature=8, 171 | matching_layers=matching_layers) 172 | prune_config = PruningConfig(end_pruning_at=args.end_pruning_at, start_pruning_at=args.start_pruning_at, 173 | end_weights_ratio=args.end_weights_ratio, 174 | pruning_frequency=args.pruning_frequency, 175 | IS_beta=args.IS_beta, is_global=args.is_global, is_reweight=args.is_reweight, 176 | is_two_ratios=args.is_two_ratios,FFN_weights_ratio=args.FFN_weights_ratio,MHA_weights_ratio=args.MHA_weights_ratio, 177 | score_type=args.score_type, 178 | pruner_type=args.pruner_type, 179 | dynamic_head_size=args.dynamic_head_size, 180 | IS_gamma=args.IS_gamma, 181 | IS_alpha=args.IS_alpha, 182 | IS_alpha_head=args.IS_alpha_head, 183 | IS_alpha_ffn=args.IS_alpha_ffn, 184 | IS_alpha_mha=args.IS_alpha_mha, 185 | dbw=(not args.no_dbw) 186 | ) 187 | distiller = PruneDistiller(train_config = train_config, distill_config=distill_config, 188 | prune_config=prune_config, model_T = model_T, model_S = model, 189 | adaptor = Adaptor) 190 | num_train_steps = int(len(train_dataloader)//args.gradient_accumulation_steps * args.num_train_epochs) 191 | optimizer = AdamW(all_trainable_params,lr=args.learning_rate,eps=args.adam_epsilon) 192 | scheduler_args = {'num_warmup_steps': int(args.warmup_proportion*num_train_steps), 193 | 'num_training_steps': num_train_steps} 194 | scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,**scheduler_args) 195 | 196 | logger.info("***** Running Prune Distiller *****") 197 | logger.info(" Num 
examples = %d", len(train_dataset)) 198 | logger.info(" Forward batch size = %d", forward_batch_size) 199 | logger.info(" Num backward steps = %d", num_train_steps) 200 | 201 | 202 | def batch_postprocessor(batch): 203 | if 'token_type_ids' in batch: 204 | return {'input_ids':batch['input_ids'],'attention_mask':batch['attention_mask'],'token_type_ids':batch['token_type_ids'], 205 | 'labels':batch['label']} 206 | else: 207 | return {'input_ids':batch['input_ids'],'attention_mask':batch['attention_mask'], 208 | 'labels':batch['label']} 209 | with distiller: 210 | distiller.train(train_dataloader, optimizer, scheduler, args.num_train_epochs, 211 | max_grad_norm=args.max_grad_norm, callback=callback_func, batch_postprocessor=batch_postprocessor) 212 | del optimizer 213 | logger.info("*********************Prune Distiller Finished*****************") 214 | 215 | if not args.do_train and args.do_predict: 216 | model.to(device) 217 | res = predict(model,eval_datasets,step=0,eval_lang=eval_langs, args=args,taskname=args.taskname) 218 | print (res) 219 | 220 | show_masks(model.state_dict()) 221 | 222 | if __name__ == "__main__": 223 | main() 224 | -------------------------------------------------------------------------------- /scripts/glue/predict_function.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | from torch.utils.data import SequentialSampler,DistributedSampler,DataLoader 5 | from utils_glue import compute_metrics 6 | from tqdm import tqdm 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def predict(model,eval_datasets,step, eval_lang, args, taskname=None): 12 | eval_task = taskname 13 | eval_output_dir = args.output_dir 14 | lang_results = {} 15 | for lang,eval_dataset in zip(eval_lang, eval_datasets): 16 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 17 | os.makedirs(eval_output_dir) 18 | logger.info("Predicting...") 19 | logger.info("***** Running predictions *****") 20 | logger.info(" task name = %s", eval_task) 21 | logger.info(" lang : %s", lang) 22 | logger.info(" Num examples = %d", len(eval_dataset)) 23 | logger.info(" Batch size = %d", args.predict_batch_size) 24 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) 25 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.predict_batch_size) 26 | model.eval() 27 | 28 | pred_logits = [] 29 | label_ids = [] 30 | for batch in tqdm(eval_dataloader, desc="Evaluating", disable=None): 31 | token_type_ids = batch.get('token_type_ids',None) 32 | if token_type_ids is not None: 33 | token_type_ids = token_type_ids.to(args.device) 34 | input_ids, input_mask, labels = batch['input_ids'],batch['attention_mask'],batch['label'] 35 | input_ids = input_ids.to(args.device) 36 | input_mask = input_mask.to(args.device) 37 | #segment_ids = segment_ids.to(args.device) 38 | with torch.no_grad(): 39 | outputs= model(input_ids, input_mask,token_type_ids=token_type_ids) 40 | logits = outputs[0] 41 | pred_logits.append(logits.detach().cpu()) 42 | label_ids.append(labels) 43 | pred_logits = np.array(torch.cat(pred_logits),dtype=np.float32) 44 | label_ids = np.array(torch.cat(label_ids),dtype=np.int64) 45 | 46 | preds = np.argmax(pred_logits, axis=1) 47 | results = compute_metrics(eval_task, preds, label_ids) 48 | 49 | logger.info("***** Eval results {} Lang {} *****".format(step, lang)) 50 | for key in sorted(results.keys()): 51 | 
logger.info(f"{lang} {key} = {results[key]:.5f}") 52 | lang_results[lang] = results 53 | 54 | output_eval_file = os.path.join(eval_output_dir, "eval_results.txt") 55 | 56 | write_results(output_eval_file,step,lang_results, eval_lang) 57 | model.train() 58 | return lang_results 59 | 60 | def write_results(output_eval_file,step,lang_results, eval_lang): 61 | with open(output_eval_file, "a") as writer: 62 | writer.write(f"step: {step:<8d} ") 63 | line = "Acc/F1:" 64 | 65 | for lang in eval_lang: 66 | acc = lang_results[lang]['acc'] 67 | if 'f1' in lang_results[lang]: 68 | f1 = lang_results[lang]['f1'] 69 | line += f"{lang}={acc:.5f}/{f1:.5f} " 70 | else: 71 | line += f"{lang}={acc:.5f} " 72 | writer.write(line+'\n') -------------------------------------------------------------------------------- /scripts/glue/utils.py: -------------------------------------------------------------------------------- 1 | def divide_parameters(named_parameters, weight_decay_rate=0.01, lr=None): 2 | no_decay = ['bias', 'LayerNorm.bias','LayerNorm.weight'] 3 | decay_parameters_names = list(zip(*[(p,n) for n,p in named_parameters if not any((di in n) for di in no_decay)])) 4 | no_decay_parameters_names = list(zip(*[(p,n) for n,p in named_parameters if any((di in n) for di in no_decay)])) 5 | param_group = [] 6 | if len(decay_parameters_names)>0: 7 | decay_parameters, decay_names = decay_parameters_names 8 | if lr is not None: 9 | decay_group = {'params':decay_parameters, 'weight_decay': weight_decay_rate, 'lr':lr} 10 | else: 11 | decay_group = {'params': decay_parameters, 'weight_decay': weight_decay_rate} 12 | param_group.append(decay_group) 13 | 14 | if len(no_decay_parameters_names)>0: 15 | no_decay_parameters, no_decay_names = no_decay_parameters_names 16 | if lr is not None: 17 | no_decay_group = {'params': no_decay_parameters, 'weight_decay': 0.0, 'lr': lr} 18 | else: 19 | no_decay_group = {'params': no_decay_parameters, 'weight_decay': 0.0} 20 | param_group.append(no_decay_group) 21 | 22 | assert len(param_group)>0 23 | return param_group -------------------------------------------------------------------------------- /scripts/glue/utils_glue.py: -------------------------------------------------------------------------------- 1 | import csv, json 2 | import logging 3 | import os 4 | import sys 5 | import random 6 | from io import open 7 | 8 | from scipy.stats import pearsonr, spearmanr 9 | from sklearn.metrics import f1_score 10 | 11 | import torch 12 | import datasets 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def get_glue_dataset(taskname,data_dir,split,tokenizer,max_length,return_token_type_ids=False): 18 | raw_all_dataset = datasets.load_from_disk(data_dir) 19 | if taskname=='mnli': 20 | encoded_dataset = raw_all_dataset[split].map( 21 | lambda examples: tokenizer(examples['premise'],examples['hypothesis'], 22 | return_token_type_ids=return_token_type_ids,padding='max_length', 23 | max_length=max_length,truncation=True),batched=True) 24 | if taskname=='mrpc': 25 | encoded_dataset = raw_all_dataset[split].map( 26 | lambda examples: tokenizer(examples['sentence1'],examples['sentence2'], 27 | return_token_type_ids=return_token_type_ids,padding='max_length', 28 | max_length=max_length,truncation=True),batched=True) 29 | elif taskname=='qqp': 30 | encoded_dataset = raw_all_dataset[split].map( 31 | lambda examples: tokenizer(examples['question1'],examples['question2'], 32 | return_token_type_ids=return_token_type_ids,padding='max_length', 33 | 
max_length=max_length,truncation=True),batched=True) 34 | elif taskname=='sst2': 35 | encoded_dataset = raw_all_dataset[split].map( 36 | lambda examples: tokenizer(examples['sentence'], 37 | return_token_type_ids=return_token_type_ids,padding='max_length', 38 | max_length=max_length,truncation=True),batched=True) 39 | elif taskname=='qnli': 40 | encoded_dataset = raw_all_dataset[split].map( 41 | lambda examples: tokenizer(examples['question'],examples['sentence'], 42 | return_token_type_ids=return_token_type_ids,padding='max_length', 43 | max_length=max_length,truncation=True),batched=True) 44 | 45 | if return_token_type_ids is False: 46 | encoded_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label']) 47 | else: 48 | encoded_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'label']) 49 | 50 | return encoded_dataset 51 | 52 | def simple_accuracy(preds, labels): 53 | return (preds == labels).mean() 54 | 55 | 56 | def acc_and_f1(preds, labels): 57 | acc = simple_accuracy(preds, labels) 58 | f1 = f1_score(y_true=labels, y_pred=preds) 59 | return { 60 | "acc": acc, 61 | "f1": f1, 62 | "acc_and_f1": (acc + f1) / 2, 63 | } 64 | 65 | 66 | def pearson_and_spearman(preds, labels): 67 | pearson_corr = pearsonr(preds, labels)[0] 68 | spearman_corr = spearmanr(preds, labels)[0] 69 | return { 70 | "pearson": pearson_corr, 71 | "spearmanr": spearman_corr, 72 | "corr": (pearson_corr + spearman_corr) / 2, 73 | } 74 | 75 | 76 | def compute_metrics(task_name, preds, labels): 77 | assert len(preds) == len(labels) 78 | if task_name == "xnli" or task_name=='qnli' or task_name=='sst2' or task_name=='mnli': 79 | return {"acc": simple_accuracy(preds, labels)} 80 | if task_name == 'mrpc' or task_name == 'qqp': 81 | return {"acc": simple_accuracy(preds, labels),'f1':f1_score(y_true=labels,y_pred=preds)} 82 | elif task_name == "lcqmc": 83 | return {"acc": simple_accuracy(preds, labels)} 84 | else: 85 | raise KeyError(task_name) 86 | 87 | 88 | output_modes = { 89 | "xnli": "classification", 90 | "lcqmc":"classification", 91 | "pawsx":"classification", 92 | "amazon":"classification", 93 | } 94 | -------------------------------------------------------------------------------- /scripts/pruners_and_distiller/distiller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import logging 4 | from tqdm import tqdm 5 | import os 6 | from accelerate import Accelerator 7 | logger = logging.getLogger(__name__) 8 | 9 | try: 10 | from tensorboardX import SummaryWriter 11 | except ImportError: 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | from .pruners import ISPruner, FineISPruner 15 | from .utils import TrainingConfig, DistillationConfig, PruningConfig, DistillationContext, kd_ce_loss, hid_mse_loss 16 | from .utils import schedule_threshold, select_logits_with_mask 17 | 18 | def initializer_builder(std): 19 | _std = std 20 | def init_weights(module): 21 | if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): 22 | module.weight.data.normal_(mean=0.0, std=_std) 23 | if isinstance(module, torch.nn.Linear) and module.bias is not None: 24 | module.bias.data.zero_() 25 | return init_weights 26 | def linear_projection(dim_in, dim_out,init='gaussian'): 27 | model = torch.nn.Linear(in_features=dim_in, out_features=dim_out, bias=True) 28 | if init=='gaussian': 29 | initializer = initializer_builder(0.02) 30 | model.apply(initializer) 31 | elif 
init=='identity': 32 | torch.nn.init.zeros_(model.bias) 33 | torch.nn.init.eye_(model.weight) 34 | else: 35 | raise NotImplementedError 36 | return model 37 | 38 | 39 | class PruneDistiller(DistillationContext): 40 | def __init__(self, train_config, distill_config: DistillationConfig, prune_config : PruningConfig, model_T, model_S, adaptor): 41 | super(PruneDistiller, self).__init__() 42 | self.t_config = train_config 43 | self.d_config : DistillationConfig = distill_config 44 | self.p_config : PruningConfig = prune_config 45 | self.model_T = model_T 46 | self.model_S = model_S 47 | self.model = self.model_S 48 | self.adaptor = adaptor 49 | 50 | self.print_freq = 20 51 | self.tb_writer = None 52 | self.accelerator = None 53 | 54 | if self.p_config.pruner_type=='ISPruner': 55 | self.pruner = ISPruner(self.model) 56 | elif self.p_config.pruner_type=='FineISPruner': 57 | self.pruner = FineISPruner(self.model) 58 | else: 59 | raise ValueError 60 | 61 | self.projs = dict() 62 | if self.d_config.matching_layers is not None: 63 | for layer_s,_ in self.d_config.matching_layers: #range(0,13,2): 64 | self.projs[layer_s]=linear_projection(768,768,'identity') 65 | 66 | self.global_status = dict() 67 | self.metrics = {} 68 | 69 | def train(self, dataloader, optimizer, lr_scheduler, num_epochs, num_steps = None, max_grad_norm = None, callback=None, batch_postprocessor=None): 70 | 71 | mixed_precision = 'fp16' if self.t_config.fp16 is True else 'no' 72 | self.accelerator = Accelerator(mixed_precision=mixed_precision) 73 | 74 | if self.accelerator.is_main_process: 75 | self.tb_writer = SummaryWriter(log_dir = self.t_config.log_dir) 76 | self.device = self.accelerator.device 77 | for proj in (self.projs.values()): 78 | optimizer.add_param_group({**{'params':proj.parameters()},}) 79 | 80 | self.model, self.model_T, optimizer, dataloader, lr_scheduler, *projs = self.accelerator.prepare( 81 | self.model, self.model_T, optimizer, dataloader, lr_scheduler, *list(self.projs.values())) 82 | for idx,proj in zip(self.projs.keys(),projs): 83 | self.projs[idx] = proj 84 | self.model_S = self.model 85 | 86 | self.pruner.initialize(self.model) 87 | 88 | if num_epochs is not None: 89 | num_epochs = int(num_epochs) 90 | self.train_epochs(dataloader, optimizer, lr_scheduler, num_epochs, max_grad_norm, callback, batch_postprocessor) 91 | 92 | 93 | def train_epochs(self, dataloader, optimizer, lr_scheduler, num_epochs, max_grad_norm = None, callback=None, batch_postprocessor=None): 94 | 95 | train_steps_per_epoch = len(dataloader) // self.t_config.gradient_accumulation_steps 96 | print_every = train_steps_per_epoch // self.print_freq 97 | if print_every == 0: 98 | print_every = train_steps_per_epoch 99 | checkpoints = [int(train_steps_per_epoch*ci/self.t_config.ckpt_frequency) for ci in range(self.t_config.ckpt_frequency)] 100 | 101 | total_global_steps = train_steps_per_epoch * num_epochs 102 | logger.info(f"Training steps per epoch: {train_steps_per_epoch}") 103 | logger.info(f"Training total global steps: {total_global_steps}") 104 | logger.info(f"Checkpoints(step): {checkpoints}") 105 | 106 | global_step = 0 107 | 108 | scalar_total_loss = 0 109 | tqdm_disable = None if self.accelerator.is_main_process else True 110 | 111 | # only works with gradient_accumulation_steps==1 112 | assert self.t_config.gradient_accumulation_steps == 1 113 | 114 | for current_epoch in tqdm(range(num_epochs),disable=tqdm_disable): 115 | 116 | logger.info(f"Epoch {current_epoch+1}") 117 | logger.info(f"Length of current epoch in forward batch: 
{len(dataloader)}") 118 | 119 | for forward_step, batch in tqdm(enumerate(dataloader),disable=tqdm_disable): 120 | 121 | #init 122 | optimizer.zero_grad() 123 | # forward and get loss 124 | batch = batch_postprocessor(batch) if batch_postprocessor is not None else batch 125 | ce_loss, matching_loss = self.train_on_batch(batch) 126 | if matching_loss is not None: 127 | sum_loss = ce_loss+matching_loss 128 | else: 129 | sum_loss = ce_loss 130 | scalar_total_loss += sum_loss.cpu().item() 131 | 132 | global_step += 1 133 | 134 | # backward 135 | if matching_loss is not None and (self.p_config.dbw is True): 136 | self.accelerator.backward(ce_loss,retain_graph=True) 137 | _,prune_ended = self.maybe_should_prune(global_step, total_global_steps,do_prune=False, do_gather=True,do_update=False, gamma=1) 138 | self.accelerator.backward(matching_loss,retain_graph=False) 139 | _,prune_ended = self.maybe_should_prune(global_step, total_global_steps,do_prune=True, do_gather=False, do_update=True, gamma=self.p_config.IS_gamma) 140 | else: 141 | self.accelerator.backward(sum_loss) 142 | _,prune_ended = self.maybe_should_prune(global_step, total_global_steps) 143 | #Tensorboard logging 144 | if self.accelerator.is_main_process: 145 | self.tb_writer.add_scalar('scalar/total_loss', scalar_total_loss, global_step) 146 | # gradient clipping 147 | if max_grad_norm is not None: 148 | self.accelerator.clip_grad_norm_(self.model.parameters(), max_grad_norm) 149 | #optimizer step 150 | optimizer.step() 151 | lr_scheduler.step() 152 | 153 | if (global_step) % print_every == 0: 154 | logger.info(f"Global step: {global_step}, epoch forward_step:{forward_step+1}") 155 | 156 | if (global_step % train_steps_per_epoch in checkpoints) \ 157 | and ((current_epoch+1) % self.t_config.ckpt_epoch_frequency==0 or current_epoch+1==num_epochs): 158 | 159 | self.accelerator.wait_for_everyone() 160 | if self.accelerator.is_main_process and prune_ended: 161 | logger.info(f"Saving at global step {global_step}, epoch forward_step {forward_step+1} epoch {current_epoch+1}") 162 | coreModel = self.accelerator.unwrap_model(self.model) 163 | state_dict = coreModel.state_dict() 164 | self.accelerator.save(state_dict, os.path.join(self.t_config.output_dir,f"gs{global_step}.pt")) 165 | self.accelerator.wait_for_everyone() 166 | if callback is not None: 167 | logger.info("Running callback function...") 168 | res = callback(model=self.model, step=global_step) 169 | self.metrics[global_step] = res 170 | self.model.train() 171 | 172 | logger.info(f"Epoch {current_epoch+1} finished") 173 | 174 | def train_on_batch(self, batch) -> torch.Tensor: 175 | #batch = move_to_device(batch, self.t_config.device) 176 | if isinstance(batch,(list,tuple)): 177 | results_S = self.model(*batch) 178 | with torch.no_grad(): 179 | results_T = self.model_T(*batch) 180 | else: 181 | results_S = self.model(**batch) 182 | with torch.no_grad(): 183 | results_T = self.model_T(**batch) 184 | results_S = post_adaptor(self.adaptor(batch,results_S)) 185 | results_T = post_adaptor(self.adaptor(batch,results_T)) 186 | 187 | ce_loss, matching_loss = self.compute_loss(results_S,results_T) 188 | 189 | return ce_loss, matching_loss #, losses_dict 190 | 191 | 192 | def compute_loss(self, results_S, results_T): 193 | total_loss = 0 194 | matching_loss = None 195 | losses_dict = dict() 196 | logits_list_T = results_T['logits'] # list of tensor 197 | logits_list_S = results_S['logits'] # list of tensor 198 | 199 | if 'logits_mask' in results_S: 200 | masks_list_S = results_S['logits_mask'] 
201 | logits_list_S = select_logits_with_mask(logits_list_S,masks_list_S) #(mask_sum, num_of_class) 202 | if 'logits_mask' in results_T: 203 | masks_list_T = results_T['logits_mask'] 204 | logits_list_T = select_logits_with_mask(logits_list_T,masks_list_T) #(mask_sum, num_of_class) 205 | 206 | total_kd_loss = 0 207 | for l_T,l_S in zip(logits_list_T,logits_list_S): 208 | temperature = self.d_config.temperature 209 | total_kd_loss += kd_ce_loss(l_S, l_T, temperature) 210 | total_loss += total_kd_loss * self.d_config.kd_loss_weight 211 | losses_dict['unweighted_kd_loss'] = total_kd_loss 212 | 213 | if 'losses' in results_S: 214 | total_hl_loss = 0 215 | total_hl_loss = sum(loss.mean() for loss in results_S['losses']) # in case of multi-GPU 216 | total_loss += total_hl_loss * self.d_config.hard_label_weight 217 | losses_dict['unweighted_hard_label_loss'] = total_hl_loss 218 | if 'hidden' in results_T and 'hidden' in results_S and (self.d_config.matching_layers is not None): 219 | matching_loss = 0 220 | 221 | loss_weight_pairs = [] 222 | for layer_s, layer_t in self.d_config.matching_layers: 223 | inter_S = self.projs[layer_s](results_S['hidden'][layer_s]) 224 | inter_T = results_T['hidden'][layer_t] 225 | inputs_mask_S = results_S.get('inputs_mask',None) 226 | loss_weight = 1 227 | 228 | match_loss = hid_mse_loss(inter_S, inter_T, mask=inputs_mask_S) 229 | loss_weight_pairs.append((match_loss,loss_weight)) 230 | 231 | weights_sum = sum(w for _,w in loss_weight_pairs[1:]) 232 | num_matchings = len(loss_weight_pairs) - 1 #excluding embeddings 233 | rescaled_weights = [w/weights_sum for _,w in loss_weight_pairs[1:]] #embddings + trm 234 | rescaled_weights_sum = sum(rescaled_weights) 235 | normalized_weights = [1] + [w/rescaled_weights_sum * num_matchings for w in rescaled_weights] 236 | 237 | self.global_status['normalized_weights'] = normalized_weights 238 | 239 | assert len(normalized_weights)==len(loss_weight_pairs) 240 | matching_loss += sum(p[0]*w for p,w in zip(loss_weight_pairs,normalized_weights)) 241 | return total_loss, matching_loss 242 | 243 | 244 | def maybe_should_prune(self,global_step, total_global_steps, do_gather=True, do_update=True, do_prune=True, gamma:float=1): 245 | pruning_frequency = self.p_config.pruning_frequency 246 | start_pruning_steps = int(self.p_config.start_pruning_at * total_global_steps) 247 | end_pruning_steps = int(self.p_config.end_pruning_at * total_global_steps) 248 | if do_gather is True: 249 | self.pruner.gather_IS(self.model, gamma,self.p_config.score_type, global_step>=start_pruning_steps) 250 | if do_update is True: 251 | self.pruner.update_IS(beta = self.p_config.IS_beta, alpha=self.p_config.IS_alpha, 252 | alpha_head=self.p_config.IS_alpha_head, 253 | alpha_ffn=self.p_config.IS_alpha_ffn, 254 | alpha_mha=self.p_config.IS_alpha_mha) 255 | if do_prune is True: 256 | if (global_step % pruning_frequency == 0) and \ 257 | global_step >= start_pruning_steps and global_step < end_pruning_steps: 258 | if self.p_config.pruner_type=='ISPruner' or self.p_config.pruner_type=='FineISPruner': 259 | self.pruner.do_prune(self.model, global_step, total_global_steps, self.p_config) 260 | else: 261 | raise NotImplementedError 262 | 263 | # logging 264 | if (global_step % 100 == 0): 265 | ffn_density = self.pruner.ffn_mask.sum()/self.pruner.ffn_mask.numel() 266 | if self.p_config.pruner_type=='ISPruner': 267 | head_density = self.pruner.head_mask.sum()/self.pruner.head_mask.numel() 268 | print(f"group density FFN/MHA {ffn_density:.4f} {head_density:.4f}") 269 | 
print(f"weighted density FFN/MHA {ffn_density * self.pruner.total_ffn_ratio + head_density * self.pruner.total_head_ratio:.4f}") 270 | elif self.p_config.pruner_type=='FineISPruner': 271 | qk_density = self.pruner.qk_mask.sum()/self.pruner.qk_mask.numel() 272 | vo_density = self.pruner.vo_mask.sum()/self.pruner.vo_mask.numel() 273 | print(f"Num zeros {self.pruner.pre_num_zeros / self.pruner.total_num}") 274 | print(f"group density FFN/QK/VO {ffn_density:.4f} {qk_density:.4f} {vo_density:.4f}") 275 | print(f"weighted density FFN/QK/VO {ffn_density * self.pruner.total_ffn_ratio + qk_density * self.pruner.total_qk_ratio + vo_density * self.pruner.total_vo_ratio:.4f}") 276 | 277 | return global_step >= start_pruning_steps, global_step >= end_pruning_steps 278 | 279 | def post_adaptor(dict_object): 280 | if 'logits' in dict_object: 281 | logits = dict_object['logits'] 282 | if not isinstance(logits,(list,tuple)): 283 | dict_object['logits'] = [ logits ] 284 | if 'logits_mask' in dict_object: 285 | logits_mask = dict_object['logits_mask'] 286 | if not isinstance(logits_mask,(list,tuple)): 287 | dict_object['logits_mask'] = [ logits_mask ] 288 | if 'losses' in dict_object: 289 | losses = dict_object['losses'] 290 | if not isinstance(losses,(list,tuple)): 291 | dict_object['losses'] = [ losses ] 292 | if 'labels' in dict_object: 293 | labels = dict_object['labels'] 294 | if not isinstance(labels,(list,tuple)): 295 | dict_object['labels'] = [ labels ] 296 | return dict_object 297 | 298 | 299 | def cycle(iterable): 300 | while True: 301 | for x in iterable: 302 | yield x 303 | -------------------------------------------------------------------------------- /scripts/pruners_and_distiller/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import logging 4 | import os, json 5 | logger = logging.getLogger(__name__) 6 | 7 | def select_logits_with_mask(logits_list, masks_list): 8 | output_logits = [] 9 | if len(masks_list)==len(logits_list): 10 | for logits,mask in zip(logits_list,masks_list): 11 | if len(logits.shape)==3: 12 | mask = mask.unsqueeze(-1).expand_as(logits).to(torch.bool) 13 | logits_select = torch.masked_select(logits,mask).view(-1,logits.size(-1)) 14 | else: 15 | logits_select = logits #Logits_mask has no effect on logits of shape (batch_size, logits_to_be_softmaxed) 16 | output_logits.append(logits_select) 17 | elif len(masks_list)==1: 18 | mask = masks_list[0] 19 | for logits in logits_list: 20 | if len(logits.shape)==3: 21 | mask = mask.unsqueeze(-1).expand_as(logits).to(torch.bool) 22 | logits_select = torch.masked_select(logits,mask).view(-1,logits.size(-1)) 23 | else: 24 | logits_select = logits #Logits_mask has no effect on logits of shape (batch_size, logits_to_be_softmaxed) 25 | output_logits.append(logits_select) 26 | else: 27 | raise AssertionError("lengths of logits list and masks list mismatch") 28 | return output_logits 29 | 30 | def kd_ce_loss(logits_S, logits_T, temperature=1): 31 | ''' 32 | Calculate the cross entropy between logits_S and logits_T 33 | 34 | :param logits_S: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels) 35 | :param logits_T: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels) 36 | :param temperature: A float or a tensor of shape (batch_size, length) or (batch_size,) 37 | ''' 38 | if isinstance(temperature, torch.Tensor) and temperature.dim() > 0: 39 | temperature = temperature.unsqueeze(-1) 40 | 
beta_logits_T = logits_T / temperature 41 | beta_logits_S = logits_S / temperature 42 | p_T = F.softmax(beta_logits_T, dim=-1) 43 | loss = -(p_T * F.log_softmax(beta_logits_S, dim=-1)).sum(dim=-1).mean() 44 | return loss 45 | 46 | def hid_mse_loss(state_S, state_T, mask=None): 47 | ''' 48 | * Calculates the mse loss between `state_S` and `state_T`, which are the hidden state of the models. 49 | * If the `inputs_mask` is given, masks the positions where ``input_mask==0``. 50 | * If the hidden sizes of student and teacher are different, 'proj' option is required in `inetermediate_matches` to match the dimensions. 51 | 52 | :param torch.Tensor state_S: tensor of shape (*batch_size*, *length*, *hidden_size*) 53 | :param torch.Tensor state_T: tensor of shape (*batch_size*, *length*, *hidden_size*) 54 | :param torch.Tensor mask: tensor of shape (*batch_size*, *length*) 55 | ''' 56 | if mask is None: 57 | loss = F.mse_loss(state_S, state_T) 58 | else: 59 | mask = mask.to(state_S) 60 | valid_count = mask.sum() * state_S.size(-1) 61 | loss = (F.mse_loss(state_S, state_T, reduction='none') * mask.unsqueeze(-1)).sum() / valid_count 62 | return loss 63 | 64 | class DistillationContext: 65 | def __init__(self): 66 | self.model = self.model_S = None 67 | self.model_T = None 68 | def __enter__(self): 69 | if isinstance(self.model_T,(list,tuple)): 70 | self.model_T_is_training = [model_t.training for model_t in self.model_T] 71 | for model_t in self.model_T: 72 | model_t.eval() 73 | elif isinstance(self.model_T,dict): 74 | self.model_T_is_training = {name:model.training for name,model in self.model_T.items()} 75 | for name in self.model_T: 76 | self.model_T[name].eval() 77 | else: 78 | self.model_T_is_training = self.model_T.training 79 | self.model_T.eval() 80 | 81 | if isinstance(self.model_S,(list,tuple)): 82 | self.model_S_is_training = [model_s.training for model_s in self.model_S] 83 | for model_s in self.model_S: 84 | model_s.train() 85 | elif isinstance(self.model_S,dict): 86 | self.model_S_is_training = {name:model.training for name,model in self.model_S.items()} 87 | for name in self.model_S: 88 | self.model_S[name].train() 89 | else: 90 | self.model_S_is_training = self.model_S.training 91 | self.model_S.train() 92 | 93 | def __exit__(self, exc_type, exc_val, exc_tb): 94 | #Restore model status 95 | if isinstance(self.model_T,(list,tuple)): 96 | for i in range(len(self.model_T_is_training)): 97 | self.model_T[i].train(self.model_T_is_training[i]) 98 | elif isinstance(self.model_T,dict): 99 | for name,is_training in self.model_T_is_training.items(): 100 | self.model_T[name].train(is_training) 101 | else: 102 | self.model_T.train(self.model_T_is_training) 103 | 104 | if isinstance(self.model_S,(list,tuple)): 105 | for i in range(len(self.model_S_is_training)): 106 | self.model_S[i].train(self.model_S_is_training[i]) 107 | elif isinstance(self.model_S,dict): 108 | for name,is_training in self.model_S_is_training.items(): 109 | self.model_S[name].train(is_training) 110 | else: 111 | self.model_S.train(self.model_S_is_training) 112 | 113 | 114 | class Config: 115 | """Base class for TrainingConfig and DistillationConfig.""" 116 | def __init__(self,**kwargs): 117 | pass 118 | 119 | @classmethod 120 | def from_json_file(cls, json_filename): 121 | """Construct configurations from a json file.""" 122 | with open(json_filename,'r') as f: 123 | json_data = json.load(f) 124 | return cls.from_dict(json_data) 125 | 126 | @classmethod 127 | def from_dict(cls, dict_object): 128 | """Construct configurations 
from a dict.""" 129 | config = cls(**dict_object) 130 | return config 131 | 132 | def __str__(self): 133 | str = "" 134 | for k,v in self.__dict__.items(): 135 | str += f"{k} : {v}\n" 136 | return str 137 | 138 | def __repr__(self): 139 | classname = self.__class__.__name__ 140 | return classname +":\n"+self.__str__() 141 | 142 | 143 | class TrainingConfig(Config): 144 | def __init__(self,gradient_accumulation_steps = 1, 145 | ckpt_frequency = 1, 146 | ckpt_epoch_frequency = 1, 147 | ckpt_steps = None, 148 | log_dir = None, 149 | output_dir = './saved_models', 150 | device = 'cuda', 151 | fp16 = False, 152 | fp16_opt_level = 'O1', 153 | data_parallel = False, 154 | local_rank = -1 155 | ): 156 | super(TrainingConfig, self).__init__() 157 | 158 | self.gradient_accumulation_steps =gradient_accumulation_steps 159 | self.ckpt_frequency = ckpt_frequency 160 | self.ckpt_epoch_frequency = ckpt_epoch_frequency 161 | self.ckpt_steps = ckpt_steps 162 | self.log_dir = log_dir 163 | self.output_dir = output_dir 164 | self.device = device 165 | self.fp16 = fp16 166 | self.fp16_opt_level = fp16_opt_level 167 | self.data_parallel = data_parallel 168 | 169 | self.local_rank = local_rank 170 | if self.local_rank == -1 or torch.distributed.get_rank() == 0: 171 | if not os.path.exists(self.output_dir): 172 | os.makedirs(self.output_dir) 173 | 174 | 175 | class DistillationConfig(Config): 176 | 177 | def __init__(self,temperature=4, 178 | hard_label_weight=0, 179 | kd_loss_weight=1, 180 | matching_layers = None): 181 | super(DistillationConfig, self).__init__() 182 | 183 | self.temperature = temperature 184 | self.hard_label_weight = hard_label_weight 185 | self.kd_loss_weight = kd_loss_weight 186 | self.matching_layers = matching_layers 187 | 188 | 189 | #------pruning related---------# 190 | 191 | class PruningConfig(Config): 192 | def __init__(self, 193 | start_pruning_at = 0.2, 194 | start_weights_ratio = 1.0, 195 | end_pruning_at = 0.7, 196 | end_weights_ratio = 0.33, 197 | pruning_frequency = 50, 198 | IS_alpha = 0.0001, # the dgree that biases towards pruning whole head 199 | IS_alpha_head = None, 200 | IS_alpha_ffn = None, 201 | IS_alpha_mha = None, 202 | IS_gamma = 1, 203 | IS_beta = None, 204 | is_global = False, 205 | is_reweight = 0, 206 | is_two_ratios = False, 207 | FFN_weights_ratio = None, 208 | MHA_weights_ratio = None, 209 | score_type = 'grad', 210 | dbw = True, 211 | pruner_type = "Pruner", 212 | dynamic_head_size = False 213 | ): 214 | super(PruningConfig, self).__init__() 215 | 216 | self.start_pruning_at =start_pruning_at 217 | self.start_weights_ratio = start_weights_ratio 218 | self.end_pruning_at = end_pruning_at 219 | self.end_weights_ratio = end_weights_ratio 220 | self.pruning_frequency = pruning_frequency 221 | self.IS_alpha = IS_alpha 222 | self.IS_alpha_head = IS_alpha_head 223 | self.IS_alpha_ffn = IS_alpha_ffn 224 | self.IS_alpha_mha = IS_alpha_mha 225 | self.IS_beta = IS_beta 226 | self.IS_gamma = IS_gamma 227 | self.is_global = is_global 228 | self.is_reweight = is_reweight 229 | self.pruner_type = pruner_type 230 | self.is_two_ratios = is_two_ratios 231 | self.FFN_weights_ratio = FFN_weights_ratio 232 | self.MHA_weights_ratio = MHA_weights_ratio 233 | self.score_type = score_type 234 | self.dbw = dbw 235 | self.dynamic_head_size = dynamic_head_size 236 | 237 | 238 | def schedule_threshold( 239 | step: int, 240 | total_step: int, 241 | p_config : PruningConfig, 242 | overwrite_end_ratio : float = None 243 | ): 244 | start_pruning_steps = int(p_config.start_pruning_at * 
total_step) 245 | end_pruning_steps = int(p_config.end_pruning_at * total_step) 246 | start_weights_ratio = p_config.start_weights_ratio 247 | end_weights_ratio = p_config.end_weights_ratio if overwrite_end_ratio is None else overwrite_end_ratio 248 | if step <= start_pruning_steps: 249 | weights_ratio = p_config.start_weights_ratio 250 | elif step > end_pruning_steps: 251 | weights_ratio = p_config.end_weights_ratio if overwrite_end_ratio is None else overwrite_end_ratio 252 | else: 253 | mul_coeff = 1 - (step - start_pruning_steps) / (end_pruning_steps - start_pruning_steps) 254 | weights_ratio = end_weights_ratio + (start_weights_ratio - end_weights_ratio) * (mul_coeff**3) 255 | return weights_ratio 256 | 257 | 258 | def show_masks(state_dict): 259 | 260 | if 'bert.encoder.layer.0.attention.self.query.bias_mask' in state_dict: 261 | print("=====VO======") 262 | qk_mask_list = torch.stack([state_dict[f'bert.encoder.layer.{i}.attention.self.query.bias_mask'] for i in range(12)]).int() 263 | vo_mask_list = torch.stack([state_dict[f'bert.encoder.layer.{i}.attention.self.value.bias_mask'] for i in range(12)]).int() 264 | qk_head_size_list = [t.reshape(12,64).sum(-1) for t in qk_mask_list] 265 | vo_head_size_list = vo_mask_list.reshape(12,12,64).sum(-1) 266 | for i in range(12): 267 | print(f"{i}: {[i for i in vo_head_size_list[i].tolist() if i >0]}, {vo_head_size_list[i].sum().item()}, {(vo_head_size_list[i]>0).sum().item()}") 268 | print(f"avg head size: {(vo_head_size_list).sum().item()/(vo_head_size_list>0).sum().item():.2f}") 269 | print("Total number of heads:",(vo_head_size_list>0).sum().item()) 270 | print("Total number of MHA layer:",(vo_head_size_list.sum(-1)>0).sum().item()) 271 | 272 | elif 'bert.encoder.layer.0.attention.output.dense.weight_mask' in state_dict: 273 | print("=====HEAD=====") 274 | head_mask_list = torch.stack([state_dict[f'bert.encoder.layer.{i}.attention.output.dense.weight_mask'] for i in range(12)]).int() 275 | number_heads_per_layer = (head_mask_list[:,0,:].view(12,12,64).sum(-1)==64).sum(-1) 276 | print("heads per layer:",number_heads_per_layer.tolist()) 277 | print("Total number of heads:",(number_heads_per_layer).sum().item()) 278 | print("Total number of MHA layer:",(number_heads_per_layer>0).sum().item()) 279 | 280 | if 'bert.encoder.layer.0.output.dense.weight_mask' in state_dict: 281 | print("=====FFN======") 282 | ffn_mask_list = torch.stack([state_dict[f'bert.encoder.layer.{i}.output.dense.weight_mask'][0] for i in range(12)]).int() 283 | print(f"FFN size/12: {ffn_mask_list.sum(-1).tolist()} {(ffn_mask_list).sum().item()/12:.1f}") 284 | print("Total number of FFN layers:",(ffn_mask_list.sum(-1)>0).sum().item()) 285 | 286 | from torch import nn 287 | import types 288 | 289 | def feed_forward_chunk_for_empty_ffn(self, attention_output): 290 | layer_output = self.output(attention_output) 291 | return layer_output 292 | 293 | def output_forward(self, input_tensor): 294 | #dropped_bias = self.dropout(self.dense.bias) 295 | #return self.LayerNorm(dropped_bias + input_tensor) 296 | return self.LayerNorm(self.dense.bias + input_tensor) 297 | 298 | def attetion_forward_for_empty_attention(self, 299 | hidden_states, 300 | attention_mask=None, 301 | head_mask=None, 302 | encoder_hidden_states=None, 303 | encoder_attention_mask=None, 304 | past_key_value=None, 305 | output_attentions=False): 306 | #dropped_bias = self.output.dropout(self.output.dense.bias) 307 | hidden_states = self.output.LayerNorm(self.output.dense.bias + hidden_states) 308 | return 
(hidden_states,) 309 | 310 | def transform(model: nn.Module, always_ffn=False, always_mha=False): 311 | base_model = model.base_model 312 | bert_layers = base_model.encoder.layer 313 | for layer in bert_layers: 314 | output = layer.output 315 | if always_ffn or output.dense.weight.numel()==0: #empty ffn 316 | print("replace ffn") 317 | layer.feed_forward_chunk = types.MethodType(feed_forward_chunk_for_empty_ffn,layer) 318 | layer.output.forward = types.MethodType(output_forward,layer.output) 319 | attention_output = layer.attention.output 320 | if always_mha or attention_output.dense.weight.numel()==0: #empty attention 321 | print("replace mha") 322 | layer.attention.forward = types.MethodType(attetion_forward_for_empty_attention,layer.attention) 323 | 324 | 325 | 326 | 327 | def fact_embedding_forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0): 328 | if input_ids is not None: 329 | input_shape = input_ids.size() 330 | else: 331 | input_shape = inputs_embeds.size()[:-1] 332 | 333 | seq_length = input_shape[1] 334 | 335 | if position_ids is None: 336 | position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] 337 | 338 | if token_type_ids is None: 339 | if hasattr(self, "token_type_ids"): 340 | buffered_token_type_ids = self.token_type_ids[:, :seq_length] 341 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) 342 | token_type_ids = buffered_token_type_ids_expanded 343 | else: 344 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 345 | 346 | if inputs_embeds is None: 347 | inputs_embeds = self.word_embeddings(input_ids) 348 | inputs_embeds = self.proj(inputs_embeds) 349 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 350 | 351 | embeddings = inputs_embeds + token_type_embeddings 352 | if self.position_embedding_type == "absolute": 353 | position_embeddings = self.position_embeddings(position_ids) 354 | embeddings += position_embeddings 355 | embeddings = self.LayerNorm(embeddings) 356 | embeddings = self.dropout(embeddings) 357 | return embeddings 358 | 359 | 360 | def transform_embed(model: nn.Module,dim=128): 361 | if dim==0: 362 | return 363 | print('Word embedding reduced dim:',dim) 364 | base_model = model.base_model 365 | embedding_layer = base_model.embeddings 366 | u,s,v = torch.linalg.svd(embedding_layer.word_embeddings.weight) 367 | sm = torch.vstack([torch.diag(s),torch.zeros(u.size(0)-s.size(0),s.size(0))]) 368 | sm128=sm[:,:dim] 369 | v128 = v[:dim] 370 | reduced_embeddings =u@sm128 371 | print(reduced_embeddings.shape, v128.shape) 372 | 373 | embedding_layer.proj = torch.nn.Linear(in_features=dim,out_features=embedding_layer.word_embeddings.weight.size(1),bias=None) 374 | 375 | vocab_size = embedding_layer.word_embeddings.num_embeddings 376 | pad_token_id = embedding_layer.word_embeddings.padding_idx 377 | embedding_layer.word_embeddings = nn.Embedding(vocab_size, dim, padding_idx=pad_token_id) 378 | embedding_layer.word_embeddings.weight.data = reduced_embeddings 379 | embedding_layer.proj.weight.data = v128.t() 380 | embedding_layer.forward = types.MethodType(fact_embedding_forward,embedding_layer) 381 | -------------------------------------------------------------------------------- /scripts/run_glue.sh: -------------------------------------------------------------------------------- 1 | TASK=sst2 2 | 3 | OUTPUT_ROOT_DIR=pruned_models 4 | 
DATA_ROOT_DIR=/path/to/glue/datasets 5 | 6 | teacher_model_path=teacher_models/${TASK}/pytorch_model.bin 7 | 8 | IS_alpha_head=3e-1 9 | accu=1 10 | ngpu=1 11 | batch_size=32 12 | length=128 13 | ep=20 14 | lr=3 15 | seed=1337 16 | weights_ratio=05 17 | 18 | taskname=${TASK} 19 | DATA_DIR=${DATA_ROOT_DIR}/${taskname} 20 | end_at=0.4 21 | pf=1 22 | IS_beta=0.998 23 | embsize=192 24 | NAME=lr${lr}e${ep}_s${i}_bs${batch_size}_${end_at}_pf${pf}_IS${IS_beta}_Reg${IS_alpha_head}_E${embsize} 25 | OUTPUT_DIR=${OUTPUT_ROOT_DIR}/bert/${taskname}-${weights_ratio}/${NAME} 26 | 27 | mkdir -p $OUTPUT_DIR 28 | model_config_json_file=bert_base_uncased.json 29 | 30 | python -u glue/main.py \ 31 | --data_dir $DATA_DIR \ 32 | --do_train \ 33 | --do_eval \ 34 | --do_predict \ 35 | --max_seq_length ${length} \ 36 | --train_batch_size ${batch_size} \ 37 | --seed $seed \ 38 | --num_train_epochs ${ep} \ 39 | --learning_rate ${lr}e-5 \ 40 | --ckpt_frequency 2 \ 41 | --output_dir $OUTPUT_DIR \ 42 | --gradient_accumulation_steps ${accu} \ 43 | --taskname ${taskname} \ 44 | --model_spec_file ${model_config_json_file} \ 45 | --teacher_model_path $teacher_model_path \ 46 | --max_grad_norm 1 \ 47 | --fp16 \ 48 | --end_pruning_at ${end_at} \ 49 | --end_weights_ratio 0.${weights_ratio} \ 50 | --pruning_frequency ${pf} \ 51 | --IS_beta ${IS_beta} \ 52 | --is_global \ 53 | --output_hidden_states \ 54 | --matching_layers_S 0,2,4,6,8,10,12 \ 55 | --matching_layers_T 0,2,4,6,8,10,12 \ 56 | --IS_alpha_head ${IS_alpha_head} \ 57 | --pruner_type FineISPruner \ 58 | --transform_embed ${embsize} \ 59 | -------------------------------------------------------------------------------- /scripts/run_squad.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_ROOT_DIR=pruned_models 2 | DATA_ROOT_DIR=datasets/squad 3 | 4 | teacher_model_path=teacher_models/squad/pytorch_model.bin 5 | 6 | IS_alpha_head=3e-1 7 | accu=1 8 | ngpu=1 9 | batch_size=32 10 | length=384 11 | ep=20 12 | lr=3 13 | seed=1337 14 | weights_ratio=05 15 | 16 | taskname=squad 17 | DATA_DIR=${DATA_ROOT_DIR}/${taskname} 18 | end_at=0.4 19 | pf=1 20 | IS_beta=0.998 21 | embsize=192 22 | NAME=lr${lr}e${ep}_s${i}_bs${batch_size}_${end_at}_pf${pf}_IS${IS_beta}_Reg${IS_alpha_head}_E${embsize} 23 | OUTPUT_DIR=${OUTPUT_ROOT_DIR}/bert/${taskname}-${weights_ratio}/${NAME} 24 | 25 | mkdir -p $OUTPUT_DIR 26 | model_config_json_file=bert_base_uncased.json 27 | 28 | python -u squad/main.py \ 29 | --train_file $DATA_ROOT_DIR/squad.translate.train.en.json \ 30 | --test_file $DATA_ROOT_DIR/squad.translate.dev.en.json \ 31 | --do_train \ 32 | --do_predict \ 33 | --max_seq_length ${length} \ 34 | --train_batch_size ${batch_size} \ 35 | --seed $seed \ 36 | --num_train_epochs ${ep} \ 37 | --learning_rate ${lr}e-5 \ 38 | --ckpt_frequency 2 \ 39 | --output_dir $OUTPUT_DIR \ 40 | --gradient_accumulation_steps ${accu} \ 41 | --model_spec_file ${model_config_json_file} \ 42 | --teacher_model_path $teacher_model_path \ 43 | --max_grad_norm 1 \ 44 | --fp16 \ 45 | --do_lower_case \ 46 | --end_pruning_at ${end_at} \ 47 | --end_weights_ratio 0.${weights_ratio} \ 48 | --pruning_frequency ${pf} \ 49 | --IS_beta ${IS_beta} \ 50 | --is_global \ 51 | --output_hidden_states \ 52 | --matching_layers_S 0,2,4,6,8,10,12 \ 53 | --matching_layers_T 0,2,4,6,8,10,12 \ 54 | --IS_alpha_head ${IS_alpha_head} \ 55 | --pruner_type FineISPruner \ 56 | --transform_embed ${embsize} \ 57 | -------------------------------------------------------------------------------- 
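The two launch scripts above set end_pruning_at=0.4, pruning_frequency=1 and end_weights_ratio=0.05 (via weights_ratio=05), and leave start_pruning_at at its 0.2 default; these values feed the cubic sparsity schedule implemented by schedule_threshold() in scripts/pruners_and_distiller/utils.py. Below is a minimal standalone sketch of that schedule under these settings; the function name and the 10000-step figure are illustrative and not part of the repository:

    def target_weights_ratio(step, total_steps,
                             start_at=0.2, end_at=0.4,
                             start_ratio=1.0, end_ratio=0.05):
        # mirrors schedule_threshold(): hold start_ratio until start_at*total_steps,
        # then decay cubically to end_ratio by end_at*total_steps and hold it there
        start_step = int(start_at * total_steps)
        end_step = int(end_at * total_steps)
        if step <= start_step:
            return start_ratio
        if step > end_step:
            return end_ratio
        mul_coeff = 1 - (step - start_step) / (end_step - start_step)
        return end_ratio + (start_ratio - end_ratio) * mul_coeff ** 3

    # e.g. with total_steps=10000 the ratio stays at 1.0 until step 2000,
    # reaches 0.05 by step 4000, and is held at 0.05 for the rest of training.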
/scripts/squad/config_prunedistiller.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from transformers import BertConfig, BertTokenizer, BertForQuestionAnswering 5 | MODEL_CLASSES = { 6 | "bert": (BertConfig, BertTokenizer, BertForQuestionAnswering ), 7 | } 8 | 9 | def parse_specs(speclist): # list of specifications 10 | if isinstance(speclist,str): 11 | with open(speclist,'r') as f: 12 | speclist = json.load(f) 13 | else: 14 | assert isinstance(speclist,dict) 15 | for item in speclist: 16 | model_type = item['model_type'] 17 | config_class, tokenizer_class, model_class = MODEL_CLASSES[model_type] 18 | 19 | item['model_class'] = model_class 20 | 21 | if item['config_file'] is not None: 22 | config = config_class.from_json_file(item['config_file']) 23 | else: 24 | config = None 25 | item['config'] = config 26 | 27 | if item['vocab_file'] is not None: 28 | kwargs = item.get('tokenizer_kwargs',{}) 29 | tokenizer = tokenizer_class(vocab_file=item['vocab_file'],**kwargs) 30 | else: 31 | tokenizer= None 32 | item['tokenizer'] = tokenizer 33 | 34 | return speclist 35 | 36 | def parse_args(opt=None): 37 | parser = argparse.ArgumentParser() 38 | 39 | parser.add_argument("--output_dir", default=None, type=str, required=True, 40 | help="The output directory where the model checkpoints will be written.") 41 | 42 | ## Other parameters 43 | parser.add_argument("--train_file", default=None, type=str) 44 | parser.add_argument("--test_file", default=None, type=str) 45 | parser.add_argument("--max_seq_length", default=384, type=int) 46 | parser.add_argument("--max_query_length",default=64, type=int) 47 | parser.add_argument("--max_answer_length",default=30,type=int) 48 | parser.add_argument("--do_train", default=False, action='store_true') 49 | parser.add_argument("--do_predict", default=False, action='store_true', help="Whether to run eval on the dev set.") 50 | parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") 51 | parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.") 52 | parser.add_argument("--learning_rate", default=3e-5, type=float, help="The initial learning rate for Adam.") 53 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 54 | help="Total number of training epochs to perform.") 55 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 56 | help="Proportion of training to perform linear learning rate warmup for. 
E.g., 0.1 = 10% " 57 | "of training.") 58 | parser.add_argument("--no_cuda", 59 | default=False, 60 | action='store_true', 61 | help="Whether not to use CUDA when available") 62 | parser.add_argument('--gradient_accumulation_steps', 63 | type=int, 64 | default=1, 65 | help="Number of updates steps to accumualte before performing a backward/update pass.") 66 | parser.add_argument("--local_rank", 67 | type=int, 68 | default=-1, 69 | help="local_rank for distributed training on gpus") 70 | parser.add_argument('--fp16', 71 | default=False, 72 | action='store_true', 73 | help="Whether to use 16-bit float precisoin instead of 32-bit") 74 | 75 | parser.add_argument('--seed',type=int,default=10236797) 76 | parser.add_argument('--weight_decay_rate',type=float,default=0.01) 77 | parser.add_argument('--do_eval',action='store_true') 78 | parser.add_argument('--do_test',action='store_true') 79 | parser.add_argument('--PRINT_EVERY',type=int,default=200) 80 | parser.add_argument('--ckpt_frequency',type=int,default=2) 81 | 82 | parser.add_argument('--model_spec_file',type=str) 83 | parser.add_argument('--teacher_model_path',type=str) 84 | parser.add_argument('--max_grad_norm',type=float,default=1.0) 85 | 86 | parser.add_argument('--n_best_size',default=20,type=int) 87 | parser.add_argument('--doc_stride',default=128,type=int) 88 | parser.add_argument("--null_score_diff_threshold",type=float,default=0.0) 89 | parser.add_argument("--version_2_with_negative",action='store_true') 90 | parser.add_argument("--threads",type=int,default=4) 91 | parser.add_argument("--do_lower_case",action='store_true') #used in decoding? 92 | parser.add_argument("--adam_epsilon",default=1e-6,type=float) 93 | parser.add_argument("--is_save_logits",action='store_true') 94 | 95 | parser.add_argument("--end_pruning_at",default=0.7,type=float) 96 | parser.add_argument("--start_pruning_at",default=0.2,type=float) 97 | 98 | parser.add_argument("--end_weights_ratio",default=0.33,type=float) 99 | parser.add_argument("--pruning_frequency",default=50,type=int) 100 | parser.add_argument("--pruner_type",default="Pruner",type=str) 101 | parser.add_argument("--IS_beta",default=0.99,type=float) 102 | parser.add_argument("--is_global",action='store_true') 103 | parser.add_argument("--is_reweight",type=float,default=1) 104 | parser.add_argument("--is_two_ratios",action='store_true') 105 | parser.add_argument("--FFN_weights_ratio",default=None,type=float) 106 | parser.add_argument("--MHA_weights_ratio",default=None,type=float) 107 | parser.add_argument("--score_type",default='grad',type=str,choices=['grad','magnitude-sumabs','magnitude-Linf','magnitude-L1','random']) 108 | parser.add_argument("--output_hidden_states",action='store_true') 109 | parser.add_argument("--dynamic_head_size",action='store_true') 110 | 111 | parser.add_argument("--matching_layers_S",type=str,default=None) 112 | parser.add_argument("--matching_layers_T",type=str,default=None) 113 | 114 | parser.add_argument("--IS_gamma",default=0,type=float) 115 | parser.add_argument("--IS_alpha",default=0.0001,type=float) 116 | parser.add_argument("--IS_alpha_head",default=None,type=float) 117 | parser.add_argument("--IS_alpha_ffn",default=None,type=float) 118 | parser.add_argument("--IS_alpha_mha",default=None,type=float) 119 | parser.add_argument("--no_dbw",action='store_true') 120 | parser.add_argument("--transform_embed",default=0,type=int) 121 | 122 | global args 123 | if opt is None: 124 | args = parser.parse_args() 125 | else: 126 | args = parser.parse_args(opt) 127 | return args 
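# Illustrative usage sketch (not part of this file): squad/main.py consumes the
# namespace returned by parse_args() and turns the comma-separated
# --matching_layers_S / --matching_layers_T options into (student, teacher)
# layer pairs for DistillationConfig, roughly as follows:
#
#     args = parse_args(['--output_dir', 'out',          # 'out' is a placeholder
#                        '--matching_layers_S', '0,2,4,6,8,10,12',
#                        '--matching_layers_T', '0,2,4,6,8,10,12'])
#     matching_layers = list(zip(map(int, args.matching_layers_S.split(',')),
#                                map(int, args.matching_layers_T.split(','))))
#     # -> [(0, 0), (2, 2), (4, 4), (6, 6), (8, 8), (10, 10), (12, 12)]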
128 | 129 | if __name__ == '__main__': 130 | print (args) 131 | parse_args(['--SAVE_DIR','test']) 132 | print(args) 133 | -------------------------------------------------------------------------------- /scripts/squad/evaluate_squad.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Based on the SQuAD evaluation script from: 3 | # https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py 16 | """ Official evaluation script for v1.1 of the SQuAD dataset. """ 17 | from __future__ import print_function 18 | from collections import Counter 19 | import string 20 | import re 21 | import argparse 22 | import json 23 | import sys 24 | 25 | 26 | def normalize_answer(s): 27 | """Lower text and remove punctuation, articles and extra whitespace.""" 28 | def remove_articles(text): 29 | return re.sub(r'\b(a|an|the)\b', ' ', text) 30 | 31 | def white_space_fix(text): 32 | return ' '.join(text.split()) 33 | 34 | def remove_punc(text): 35 | exclude = set(string.punctuation) 36 | return ''.join(ch for ch in text if ch not in exclude) 37 | 38 | def lower(text): 39 | return text.lower() 40 | 41 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 42 | 43 | 44 | def f1_score(prediction, ground_truth): 45 | prediction_tokens = normalize_answer(prediction).split() 46 | ground_truth_tokens = normalize_answer(ground_truth).split() 47 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 48 | num_same = sum(common.values()) 49 | if num_same == 0: 50 | return 0 51 | precision = 1.0 * num_same / len(prediction_tokens) 52 | recall = 1.0 * num_same / len(ground_truth_tokens) 53 | f1 = (2 * precision * recall) / (precision + recall) 54 | return f1 55 | 56 | 57 | def exact_match_score(prediction, ground_truth): 58 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 59 | 60 | 61 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 62 | scores_for_ground_truths = [] 63 | for ground_truth in ground_truths: 64 | score = metric_fn(prediction, ground_truth) 65 | scores_for_ground_truths.append(score) 66 | return max(scores_for_ground_truths) 67 | 68 | 69 | def evaluate(dataset, predictions): 70 | f1 = exact_match = total = 0 71 | for article in dataset: 72 | for paragraph in article['paragraphs']: 73 | for qa in paragraph['qas']: 74 | total += 1 75 | if qa['id'] not in predictions: 76 | message = 'Unanswered question ' + qa['id'] + \ 77 | ' will receive score 0.' 
78 | print(message, file=sys.stderr) 79 | continue 80 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 81 | prediction = predictions[qa['id']] 82 | exact_match += metric_max_over_ground_truths( 83 | exact_match_score, prediction, ground_truths) 84 | f1 += metric_max_over_ground_truths( 85 | f1_score, prediction, ground_truths) 86 | 87 | exact_match = 100.0 * exact_match / total 88 | f1 = 100.0 * f1 / total 89 | 90 | return {'exact_match': exact_match, 'f1': f1} 91 | 92 | 93 | if __name__ == '__main__': 94 | expected_version = '1.1' 95 | parser = argparse.ArgumentParser( 96 | description='Evaluation for SQuAD ' + expected_version) 97 | parser.add_argument('dataset_file', help='Dataset file') 98 | parser.add_argument('prediction_file', help='Prediction File') 99 | args = parser.parse_args() 100 | with open(args.dataset_file) as dataset_file: 101 | dataset_json = json.load(dataset_file) 102 | if (dataset_json['version'] != expected_version): 103 | print('Evaluation expects v-' + expected_version + 104 | ', but got dataset with v-' + dataset_json['version'], 105 | file=sys.stderr) 106 | dataset = dataset_json['data'] 107 | with open(args.prediction_file) as prediction_file: 108 | predictions = json.load(prediction_file) 109 | print(json.dumps(evaluate(dataset, predictions))) 110 | -------------------------------------------------------------------------------- /scripts/squad/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | # Setup logging 3 | logging.basicConfig( 4 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 5 | level=logging.INFO,) 6 | logger = logging.getLogger(__name__) 7 | 8 | import os,random 9 | import numpy as np 10 | import torch 11 | from transformers import AdamW, get_linear_schedule_with_warmup 12 | from pruners_and_distiller.distiller import TrainingConfig, DistillationConfig, PruningConfig 13 | from pruners_and_distiller.distiller import PruneDistiller as PruneDistillerHidden 14 | from pruners_and_distiller.utils import show_masks, transform_embed 15 | from torch.utils.data import DataLoader, RandomSampler 16 | from functools import partial 17 | from utils import predict, MultilingualSQuADDataset 18 | from config_prunedistiller import parse_specs, parse_args 19 | 20 | def set_seed(args): 21 | random.seed(args.seed) 22 | np.random.seed(args.seed) 23 | torch.manual_seed(args.seed) 24 | if args.n_gpu > 0: 25 | torch.cuda.manual_seed_all(args.seed) 26 | 27 | def args_check(args): 28 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 29 | logger.warning(f"Output directory ({args.output_dir}) already exists and is not empty.") 30 | # Setup CUDA, GPU & distributed training 31 | if args.local_rank == -1 or args.no_cuda: 32 | if not args.no_cuda and not torch.cuda.is_available(): 33 | raise ValueError("No CUDA available!") 34 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 35 | args.n_gpu = torch.cuda.device_count() if not args.no_cuda else 0 36 | else: 37 | # Initializes the distributed backend which sychronizes nodes/GPUs 38 | #torch.cuda.set_device(args.local_rank) 39 | device = torch.device("cuda", args.local_rank) 40 | #torch.distributed.init_process_group(backend="nccl") 41 | args.n_gpu = 1 42 | args.device = device 43 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 44 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 
45 | return device, args.n_gpu 46 | 47 | def main(): 48 | args = parse_args() 49 | logger.setLevel(logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 50 | 51 | for k,v in vars(args).items(): 52 | logger.info(f"{k}:{v}") 53 | 54 | device, args.n_gpu = args_check(args) 55 | set_seed(args) 56 | os.makedirs(args.output_dir, exist_ok=True) 57 | 58 | if args.local_rank not in [-1, 0]: 59 | torch.distributed.barrier() 60 | 61 | #Build Model and load checkpoint 62 | speclist = parse_specs(args.model_spec_file) 63 | spec = speclist[0] 64 | config, tokenizer, model_class = spec['config'], spec['tokenizer'], spec['model_class'] 65 | ckpt_file = spec['ckpt_file'] 66 | prefix = spec['prefix'] 67 | if args.output_hidden_states is True: 68 | config.output_hidden_states=True 69 | model = model_class.from_pretrained(ckpt_file,config=config) 70 | if args.transform_embed>0: 71 | transform_embed(model,args.transform_embed) 72 | state_dict = torch.load(args.teacher_model_path,map_location='cpu') 73 | model_T = model_class.from_pretrained(None,config=config,state_dict=state_dict) 74 | 75 | if args.local_rank == 0: 76 | torch.distributed.barrier() 77 | 78 | #read data 79 | train_dataset = None 80 | num_train_steps = None 81 | 82 | train_langs = ['en'] 83 | if args.do_train: 84 | train_dataset = MultilingualSQuADDataset(args, train_langs,'train', prefix, tokenizer) 85 | 86 | if args.do_predict: 87 | eval_langs = ['en'] 88 | 89 | split = 'test' if args.do_test else 'dev' 90 | assert split =='dev' or split=='test' 91 | eval_dataset = MultilingualSQuADDataset(args, eval_langs, split, prefix, tokenizer) 92 | 93 | logger.info("Data loaded") 94 | 95 | callback_func = None 96 | if args.do_predict: 97 | callback_func = partial(predict, eval_dataset=eval_dataset, args=args,tokenizer=tokenizer) 98 | if args.do_train: 99 | forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) 100 | args.forward_batch_size = forward_batch_size 101 | train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=args.forward_batch_size,drop_last=True) 102 | 103 | 104 | def AdaptorTrain(batch, model_outputs): 105 | return {'losses': (model_outputs[0])} 106 | def batch_postprocessor(batch): 107 | batch = { "input_ids": batch[0], 108 | "attention_mask": batch[1], 109 | "token_type_ids": batch[2], 110 | "start_positions": batch[3], 111 | "end_positions": batch[4]} 112 | return batch 113 | 114 | def AdaptorLogits(batch, model_outputs): 115 | return {'logits': (model_outputs.start_logits,model_outputs.end_logits)} 116 | def AdaptorLogitsHidden(batch, model_outputs): 117 | return {'logits': (model_outputs.start_logits,model_outputs.end_logits), 118 | 'hidden': (model_outputs.hidden_states), 119 | 'inputs_mask': batch['attention_mask']} 120 | if args.output_hidden_states is True: 121 | print("use hidden") 122 | Adaptor = AdaptorLogitsHidden 123 | else: 124 | Adaptor = AdaptorLogits 125 | PruneDistiller = PruneDistillerHidden 126 | 127 | #parameters 128 | params = list(model.named_parameters()) 129 | #all_trainable_params = divide_parameters(params, lr=args.learning_rate) 130 | no_decay = ['bias','LayerNorm.weight'] 131 | large_lr = ['attention_head_scale'] 132 | all_trainable_params = [ 133 | { 134 | "params":[p for n,p in params if not any(nd in n for nd in no_decay)], 135 | "weight_decay": args.weight_decay_rate, 136 | }, 137 | { 138 | 'params': [p for n,p in params if any(nd in n for nd in no_decay)], 139 | 'weight_decay':0.0 140 | }, 141 | ] 142 | logger.info("Length of 
all_trainable_params: %d", len(all_trainable_params)) 143 | 144 | ########## PruneDistiller ########### 145 | train_config = TrainingConfig( 146 | gradient_accumulation_steps = args.gradient_accumulation_steps, 147 | ckpt_frequency = args.ckpt_frequency, 148 | #ckpt_steps = int(num_train_steps//args.num_train_epochs//2), 149 | log_dir = args.output_dir, 150 | output_dir = args.output_dir, 151 | fp16 = args.fp16, 152 | device = args.device) 153 | if args.matching_layers_S is not None: 154 | matching_layers = list(zip(map(int,args.matching_layers_S.split(',')),map(int,args.matching_layers_T.split(',')))) 155 | else: 156 | matching_layers = None 157 | distill_config = DistillationConfig(temperature=8, 158 | matching_layers=matching_layers) 159 | prune_config = PruningConfig(end_pruning_at=args.end_pruning_at, start_pruning_at=args.start_pruning_at, 160 | end_weights_ratio=args.end_weights_ratio, 161 | pruning_frequency=args.pruning_frequency, 162 | IS_beta=args.IS_beta, is_global=args.is_global, is_reweight=args.is_reweight, 163 | is_two_ratios=args.is_two_ratios,FFN_weights_ratio=args.FFN_weights_ratio,MHA_weights_ratio=args.MHA_weights_ratio, 164 | score_type=args.score_type, 165 | pruner_type=args.pruner_type, 166 | dynamic_head_size=args.dynamic_head_size, 167 | IS_gamma=args.IS_gamma, 168 | IS_alpha=args.IS_alpha, 169 | IS_alpha_head=args.IS_alpha_head, 170 | IS_alpha_ffn=args.IS_alpha_ffn, 171 | IS_alpha_mha=args.IS_alpha_mha, 172 | dbw=(not args.no_dbw) 173 | ) 174 | distiller = PruneDistiller(train_config = train_config, distill_config=distill_config, 175 | prune_config=prune_config, model_T = model_T, model_S = model, 176 | adaptor = Adaptor) 177 | num_train_steps = int(len(train_dataloader)//args.gradient_accumulation_steps * args.num_train_epochs) 178 | optimizer = AdamW(all_trainable_params,lr=args.learning_rate,eps=args.adam_epsilon) 179 | scheduler_args = {'num_warmup_steps': int(args.warmup_proportion*num_train_steps), 180 | 'num_training_steps': num_train_steps} 181 | scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,**scheduler_args) 182 | 183 | logger.info("***** Running Prune Distiller *****") 184 | logger.info(" Num examples = %d", len(train_dataset)) 185 | logger.info(" Forward batch size = %d", forward_batch_size) 186 | logger.info(" Num backward steps = %d", num_train_steps) 187 | 188 | 189 | with distiller: 190 | distiller.train(train_dataloader, optimizer, scheduler, args.num_train_epochs, 191 | max_grad_norm=args.max_grad_norm, callback=callback_func, batch_postprocessor=batch_postprocessor) 192 | del optimizer 193 | logger.info("*********************Prune Distiller Finished*****************") 194 | 195 | if not args.do_train and args.do_predict: 196 | model.to(device) 197 | res = predict(model, eval_dataset=eval_dataset, 198 | args=args, tokenizer=tokenizer,step=0) 199 | print (res) 200 | 201 | show_masks(model.state_dict()) 202 | 203 | if __name__ == "__main__": 204 | main() 205 | -------------------------------------------------------------------------------- /scripts/squad/utils.py: -------------------------------------------------------------------------------- 1 | #remove token_type_ids and p_mask 2 | import re 3 | import timeit 4 | from tqdm import tqdm 5 | import os, json 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | import torch 9 | from torch.utils.data import Dataset, ConcatDataset 10 | from torch.utils.data import SequentialSampler, DataLoader 11 | from typing import List 12 | from transformers.data.metrics.squad_metrics 
import ( 13 | compute_predictions_log_probs, 14 | compute_predictions_logits, 15 | squad_evaluate, 16 | ) 17 | 18 | from transformers.data.processors.squad import ( 19 | SquadResult, 20 | SquadV1Processor, 21 | SquadV2Processor, 22 | squad_convert_examples_to_features 23 | ) 24 | 25 | class MultilingualSQuADDataset(Dataset): 26 | def __init__(self, args, langs: List[str], split: str, prefix: str, tokenizer=None): 27 | if args.local_rank not in [-1, 0]: 28 | torch.distributed.barrier() 29 | self.split = split 30 | 31 | max_seq_length = args.max_seq_length 32 | 33 | self.lang_datasets = {} 34 | self.lang_features = {} 35 | self.lang_examples = {} 36 | 37 | self.test_files = None 38 | if split=='train': 39 | self.data_files = {'en': args.train_file} 40 | else: 41 | self.data_files = {'en':args.test_file} 42 | self.data_dir = os.path.dirname(args.train_file) 43 | 44 | self.cached_features_files = {lang : os.path.join(self.data_dir, f'{prefix}_{split}_{max_seq_length}_{lang}') for lang in langs} 45 | 46 | for lang, cached_features_file in self.cached_features_files.items(): 47 | if os.path.exists(cached_features_file): 48 | logger.info("Loading features from cached file %s", cached_features_file) 49 | features_and_dataset = torch.load(cached_features_file) 50 | features, dataset, examples = features_and_dataset["features"], features_and_dataset["dataset"], features_and_dataset["examples"] 51 | else: 52 | logger.info("Creating features from dataset file at %s", cached_features_file) 53 | processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor() 54 | if split == 'train': 55 | print (self.data_files[lang]) 56 | examples = processor.get_train_examples(self.data_dir, filename=self.data_files[lang]) 57 | elif split == 'dev' or split=='test': 58 | print (self.data_files[lang]) 59 | examples = processor.get_dev_examples(self.data_dir, filename=self.data_files[lang]) 60 | else: 61 | raise ValueError 62 | 63 | features, dataset = squad_convert_examples_to_features( 64 | examples=examples, 65 | tokenizer=tokenizer, 66 | max_seq_length=args.max_seq_length, 67 | doc_stride=args.doc_stride, 68 | max_query_length=args.max_query_length, 69 | is_training=(split=='train'), 70 | return_dataset="pt", 71 | threads=args.threads 72 | ) 73 | 74 | if args.local_rank in [-1, 0]: 75 | logger.info("Saving features into cached file %s", cached_features_file) 76 | if split == 'train': 77 | examples = None 78 | features = None 79 | torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file) 80 | 81 | 82 | self.lang_datasets[lang] = dataset 83 | self.lang_features[lang] = features 84 | self.lang_examples[lang] = examples 85 | 86 | if args.local_rank == 0: 87 | torch.distributed.barrier() 88 | 89 | self.all_dataset = ConcatDataset(list(self.lang_datasets.values())) 90 | 91 | def __getitem__(self, index): 92 | return self.all_dataset[index] 93 | 94 | def __len__(self): 95 | return len(self.all_dataset) 96 | 97 | 98 | def to_list(tensor): 99 | return tensor.detach().cpu().tolist() 100 | 101 | def predict( model, eval_dataset, args, tokenizer, step, is_save_logits=False): 102 | lang_results = {} 103 | for lang in eval_dataset.lang_datasets: 104 | dataset = eval_dataset.lang_datasets[lang] 105 | examples = eval_dataset.lang_examples[lang] 106 | features = eval_dataset.lang_features[lang] 107 | # Note that DistributedSampler samples randomly 108 | eval_sampler = SequentialSampler(dataset) 109 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, 
batch_size=args.predict_batch_size) 110 | 111 | # multi-gpu evaluate 112 | if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): 113 | model = torch.nn.DataParallel(model) 114 | 115 | # Eval! 116 | logger.info("***** Running evaluation {} {}*****".format(step, lang)) 117 | logger.info(" Num examples = %d", len(dataset)) 118 | logger.info(" Batch size = %d", args.predict_batch_size) 119 | 120 | all_results = [] 121 | start_time = timeit.default_timer() 122 | 123 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 124 | model.eval() 125 | batch = tuple(t.to(args.device) for t in batch) 126 | 127 | with torch.no_grad(): 128 | inputs = { 129 | "input_ids": batch[0], 130 | "attention_mask": batch[1], 131 | "token_type_ids": batch[2], #None if model_type in ["xlm", "distilbert", "xlmr",] else batch[2], 132 | } 133 | example_indices = batch[3] 134 | 135 | outputs = model(**inputs)[:2] #start logits and end logits 136 | 137 | for i, example_index in enumerate(example_indices): 138 | eval_feature = features[example_index.item()] 139 | unique_id = int(eval_feature.unique_id) 140 | 141 | output = [to_list(output[i]) for output in outputs] 142 | 143 | # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler" 144 | # models only use two. 145 | if len(output) >= 5: 146 | start_logits = output[0] 147 | start_top_index = output[1] 148 | end_logits = output[2] 149 | end_top_index = output[3] 150 | cls_logits = output[4] 151 | 152 | result = SquadResult( 153 | unique_id, 154 | start_logits, 155 | end_logits, 156 | start_top_index=start_top_index, 157 | end_top_index=end_top_index, 158 | cls_logits=cls_logits, 159 | ) 160 | 161 | else: 162 | start_logits, end_logits = output 163 | result = SquadResult(unique_id, start_logits, end_logits) 164 | 165 | all_results.append(result) 166 | if is_save_logits: 167 | logger.info("Save logits") 168 | output_logits_file = os.path.join(args.output_dir, str(step), f"all_results-{lang}.logits") 169 | os.makedirs(os.path.dirname(output_logits_file),exist_ok=True) 170 | torch.save(all_results,output_logits_file) 171 | 172 | evalTime = timeit.default_timer() - start_time 173 | logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset)) 174 | 175 | # Compute predictions 176 | output_prediction_file = os.path.join(args.output_dir, str(step), f"test-{lang}.json") 177 | output_nbest_file = os.path.join(args.output_dir, str(step), f"nbest_predictions-{lang}.json") 178 | os.makedirs(os.path.dirname(output_prediction_file),exist_ok=True) 179 | if args.version_2_with_negative: 180 | output_null_log_odds_file = os.path.join(args.output_dir, str(step), "null_odds_{}.json".format(step)) 181 | else: 182 | output_null_log_odds_file = None 183 | 184 | predictions = compute_predictions_logits( 185 | examples, 186 | features, 187 | all_results, 188 | args.n_best_size, 189 | args.max_answer_length, 190 | args.do_lower_case, 191 | output_prediction_file, 192 | output_nbest_file, 193 | output_null_log_odds_file, 194 | False, #args.verbose_logging, 195 | args.version_2_with_negative, 196 | args.null_score_diff_threshold, 197 | tokenizer 198 | ) 199 | 200 | # Compute the F1 and exact scores. 
201 | results = squad_evaluate(examples, predictions) 202 | logger.info("{} :Results: {}".format(lang, results)) 203 | lang_results[lang] = results 204 | 205 | eval_results_file = os.path.join(args.output_dir, 'eval_results.txt') 206 | with open(eval_results_file,'a') as f: 207 | line = f'Step {step} -- '+ ' '.join([f"{lang}:{results['f1']:.1f}/{results['exact']:.1f}" for lang, results in lang_results.items()]) 208 | avg_f1 = sum(results['f1'] for results in lang_results.values()) / len(lang_results) 209 | avg_em = sum(results['exact'] for results in lang_results.values()) / len(lang_results) 210 | line += f' avg:{avg_f1:.1f}/{avg_em:.1f}\n' 211 | f.write(line) 212 | 213 | return lang_results 214 | -------------------------------------------------------------------------------- /teacher_models/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /textpruner/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1" 2 | 3 | from .pruners import VocabularyPruner, TransformerPruner, PipelinePruner 4 | from .configurations import GeneralConfig, VocabularyPruningConfig, TransformerPruningConfig 5 | from .utils import summary, inference_time 6 | -------------------------------------------------------------------------------- /textpruner/commands/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/GRAIN/8d02105e583265f7385c052c7c532e55f24609d0/textpruner/commands/__init__.py -------------------------------------------------------------------------------- /textpruner/commands/functions.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | from ..pruners import VocabularyPruner, TransformerPruner, PipelinePruner 3 | from ..utils import summary 4 | from .utils import read_file_line_by_line 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | 8 | def call_vocabulary_pruning(configurations, model, tokenizer, vocabulary_file): 9 | general_config = configurations["GeneralConfig"] 10 | vocabulary_pruning_config = configurations["VocabularyPruningConfig"] 11 | pruner = VocabularyPruner(model, tokenizer, vocabulary_pruning_config, general_config) 12 | texts,is_token_ids = read_file_line_by_line(vocabulary_file) 13 | if is_token_ids is False: 14 | output_dir = pruner.prune(dataiter=texts, save_model=True) 15 | else: 16 | output_dir = pruner.prune(additional_token_ids=texts, save_model=True) 17 | 18 | print("After pruning:") 19 | print(summary(model)) 20 | 21 | 22 | def call_transformer_pruning(configurations, model, dataloader, adaptor): 23 | general_config = configurations["GeneralConfig"] 24 | transformer_pruning_config = configurations["TransformerPruningConfig"] 25 | pruner = TransformerPruner(model, transformer_pruning_config, general_config) 26 | 27 | keep_shape = False 28 | if transformer_pruning_config.ffn_even_masking is False: 29 | logger.warning("ffn_even_masking is False. 
Cannot save pruned model with different ffn size. \ 30 | A full model with the relevant weights set to zero will be saved. \ 31 | You can save a pruned TorchScript model, \ 32 | use the textpruner.TransformerPruner.save_jit_model in your python script.") 33 | keep_shape = True 34 | output_dir = pruner.prune(dataloader=dataloader, adaptor=adaptor, keep_shape=keep_shape, save_model=True) 35 | print("After pruning:") 36 | print(summary(model)) 37 | 38 | 39 | def call_pipeling_pruning(configurations, model, tokenizer, vocabulary_file, dataloader, adaptor): 40 | general_config = configurations["GeneralConfig"] 41 | vocabulary_pruning_config = configurations["VocabularyPruningConfig"] 42 | transformer_pruning_config = configurations["TransformerPruningConfig"] 43 | pruner = PipelinePruner(model, tokenizer, 44 | transformer_pruning_config, 45 | vocabulary_pruning_config, 46 | general_config) 47 | texts,is_token_ids = read_file_line_by_line(vocabulary_file) 48 | keep_shape = False 49 | if transformer_pruning_config.ffn_even_masking is False: 50 | logger.warning("ffn_even_masking is False. Cannot save pruned model with different ffn size. \ 51 | A full model with the relevant weights set to zero will be saved. \ 52 | You can save a pruned TorchScript model, \ 53 | use the textpruner.TransformerPruner.save_jit_model in your python script.") 54 | keep_shape = True 55 | if is_token_ids is False: 56 | output_dir = pruner.prune(dataloader=dataloader, adaptor=adaptor, dataiter=texts, keep_shape=keep_shape, save_model=True) 57 | else: 58 | output_dir = pruner.prune(dataloader=dataloader, adaptor=adaptor, additional_token_ids=texts, keep_shape=keep_shape, save_model=True) 59 | 60 | print("After pruning:") 61 | print(summary(model)) -------------------------------------------------------------------------------- /textpruner/commands/textpruner_cli.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from ..configurations import Config 3 | from .functions import call_vocabulary_pruning, call_transformer_pruning, call_pipeling_pruning 4 | from .utils import create_configurations, create_model_and_tokenizer 5 | from .utils import create_dataloader_and_adaptor 6 | import logging 7 | logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def main(): 12 | parser = ArgumentParser("TextPruner CLI tool") 13 | 14 | parser.add_argument("--configurations", type=str, nargs='*', help="The configurations (json files) passed to the pruner. Seperate the filenames by space. \ 15 | TextPruner uses the default configurations if omitted") 16 | parser.add_argument("--pruning_mode", choices=['vocabulary','transformer','pipeline'], required=True, help="One of the three pruning modes.") 17 | parser.add_argument("--model_class", type=str,required=True, help="The class of your model. It must be accessible from the current directory.") 18 | parser.add_argument("--tokenizer_class", type=str,required=True, help="The class of your tokenizer. 
It must be accessible from the current directory.") 19 | parser.add_argument("--model_path", type=str, required=True, help="The directory where the weights and the configs of the pretrained model and the tokenizer locate.") 20 | parser.add_argument("--vocabulary", type=str, help="A text file that is used to count tokens for vocabulay pruning.") 21 | parser.add_argument("--dataloader_and_adaptor",type=str, help="The script that contains the dataloader and the adaptor. \ 22 | For example: foo/bar/dataloader_script.py or foo/bar/Processing.dataloader_script (in the latter case dataloader_script is in the package Processing.") 23 | args = parser.parse_args() 24 | 25 | 26 | # initialize model and tokenizer 27 | model, tokenizer = create_model_and_tokenizer(model_class_name=args.model_class, 28 | tokenizer_class_name=args.tokenizer_class, 29 | model_path = args.model_path) 30 | 31 | 32 | # initialize configurations 33 | configurations = create_configurations(args.configurations) 34 | 35 | 36 | # import functions 37 | dataloader, adaptor = create_dataloader_and_adaptor(args.dataloader_and_adaptor) 38 | 39 | 40 | 41 | if args.pruning_mode == 'vocabulary': 42 | call_vocabulary_pruning(configurations, model, tokenizer, args.vocabulary) 43 | elif args.pruning_mode == 'transformer': 44 | call_transformer_pruning(configurations, model, dataloader, adaptor) 45 | elif args.pruning_mode == 'pipeline': 46 | call_pipeling_pruning(configurations, model, tokenizer, args.vocabulary, dataloader, adaptor) 47 | 48 | 49 | if __name__ == '__main__': 50 | main() -------------------------------------------------------------------------------- /textpruner/commands/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from ..configurations import Config 3 | import importlib 4 | import importlib.machinery 5 | import sys,os 6 | sys.path.append(os.getcwd()) 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def import_factory(model_class_name: str): 11 | module_name, class_name = model_class_name.rsplit(".", 1) 12 | module = importlib.import_module(module_name,package=None) 13 | try: 14 | cls = getattr(module, class_name) 15 | except AttributeError: 16 | logger.info(f"Cannot get {class_name} for {module}, return None") 17 | cls = None 18 | return cls 19 | 20 | 21 | def get_class(class_name: str): 22 | if len(class_name.split('.'))==1: 23 | class_name = 'transformers.' + class_name 24 | return import_factory(class_name) 25 | 26 | 27 | def create_from_class(model_class_name: str, model_path: str): 28 | model_class = get_class(model_class_name) 29 | model = model_class.from_pretrained(model_path) 30 | return model 31 | 32 | 33 | def read_file_line_by_line(texts_file: str): 34 | ''' 35 | Read the file line by line. if the file contains only digits, treat the digits as the token_ids. 36 | 37 | Args: 38 | text_file : a text file that contains texts or ids. 39 | 40 | Returns: 41 | (List[str], False) if the file contains normal texts; (List[int], True) if the file contains only digits. 42 | ''' 43 | lines = [] 44 | is_token_ids = False 45 | with open(texts_file,'r') as f: 46 | for line in f: 47 | sline = line.strip() 48 | if len(sline)>0: 49 | lines.append(sline) 50 | try: 51 | token_ids = [int(token) for token in lines] 52 | except ValueError: 53 | is_token_ids = False 54 | return lines, is_token_ids 55 | logger.info("All contexts are digits. 
Treat them as the token ids.") 56 | is_token_ids = True 57 | return token_ids, is_token_ids 58 | 59 | 60 | def create_configurations(configurations_list): 61 | configurations_dict = {"GeneralConfig": None, "VocabularyPruningConfig": None, "TransformerPruningConfig": None} 62 | 63 | if configurations_list is not None: 64 | for configuration_file in configurations_list: 65 | configuration = Config.from_json(configuration_file) 66 | configurations_dict[configuration.config_class] = configuration 67 | return configurations_dict 68 | 69 | 70 | def create_model_and_tokenizer(model_class_name: str, tokenizer_class_name: str, model_path: str): 71 | model = create_from_class(model_class_name, model_path) 72 | tokenizer = create_from_class(tokenizer_class_name, model_path) 73 | return model, tokenizer 74 | 75 | 76 | def create_dataloader_and_adaptor(dataloader_and_adaptor_script: str): 77 | if dataloader_and_adaptor_script is None: 78 | return None, None 79 | if os.path.sep in dataloader_and_adaptor_script: 80 | dirname = os.path.dirname(dataloader_and_adaptor_script) 81 | filename = os.path.basename(dataloader_and_adaptor_script) 82 | if filename.endswith('.py'): 83 | filename = filename[:-3] 84 | sys.path.insert(0, os.path.abspath(dirname)) 85 | dataloader_name = filename + '.dataloader' 86 | adaptor_name = filename + '.adaptor' 87 | else: 88 | dataloader_name = dataloader_and_adaptor_script + '.dataloader' 89 | adaptor_name = dataloader_and_adaptor_script + '.adaptor' 90 | 91 | dataloader = import_factory(dataloader_name) 92 | adaptor = import_factory(adaptor_name) 93 | 94 | return dataloader, adaptor -------------------------------------------------------------------------------- /textpruner/configurations.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | import torch 3 | import json 4 | import logging 5 | from typing import Union, Optional 6 | from dataclasses import dataclass, asdict 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | 11 | @dataclass 12 | class Config: 13 | """Base class for :class:`~textpruner.configurations.GeneralConfig`, 14 | :class:`~textpruner.configurations.VocabularyPruningConfig` and :class:`~textpruner.configurations.TransformerPruningConfig`.""" 15 | 16 | @classmethod 17 | def from_json(cls, json_filename: str): 18 | """Construct the configuration from a json file.""" 19 | with open(json_filename,'r') as f: 20 | config_map = json.load(f) 21 | config = CONFIG_CLASS[config_map['config_class']].from_dict(config_map) 22 | return config 23 | 24 | @classmethod 25 | def from_dict(cls, config_map: dict): 26 | """Construct the configuration from a dict.""" 27 | config = CONFIG_CLASS[config_map['config_class']](**config_map) 28 | return config 29 | 30 | 31 | def save_to_json(self, json_filename: str): 32 | """Save the configuration the a json file.""" 33 | config_map = asdict(self) 34 | with open(json_filename,'w') as f: 35 | json.dump(config_map, f, indent = 2) 36 | 37 | 38 | @dataclass 39 | class GeneralConfig(Config): 40 | 41 | ''' 42 | Configurations for the device and the output directory. 43 | 44 | Args: 45 | device: ``'cpu'`` or ``'cuda'`` or ``'cuda:0'`` etc. Specify which device to use. If it is set to ``'auto'``, 46 | TextPruner will try to use the CUDA device if there is one; otherwise uses CPU. 47 | output_dir: The diretory to save the pruned models. 48 | config_class: Type of the configurations. Users should not change its value. 
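Example (illustrative; the values shown are the defaults)::    general_config = GeneralConfig(use_device='auto', output_dir='./pruned_models')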
49 | ''' 50 | use_device: str = 'auto' 51 | output_dir: str = './pruned_models' 52 | config_class : str = "GeneralConfig" 53 | def __post_init__(self): 54 | if self.use_device == 'auto': 55 | if torch.cuda.is_available(): 56 | logger.info(f"Using current cuda device") 57 | self.device = ('cuda') 58 | else: 59 | logger.info(f"Using cpu device") 60 | self.device = ('cpu') 61 | else: 62 | self.device = self.use_device 63 | 64 | @dataclass 65 | class VocabularyPruningConfig(Config): 66 | ''' 67 | Configurations for vocabulary pruning. 68 | 69 | Args: 70 | min_count: The threshold to decide if the token should be removed. 71 | The token will be removed from the vocabulary if it appears less than ``min_count`` times in the corpus. 72 | prune_lm_head: whether to prune the lm_head if the model has one. If ``prune_lm_head==False``, TextPruner will not prune the lm_head; 73 | if ``prune_lm_head==True``, TextPruner will prune the lm_head and raise an error if the model does not have an lm_head; 74 | if ``prune_lm_head=='auto'``, TextPruner will try to prune the lm_head and will continue if the model does not have an lm_head. 75 | config_class: Type of the configurations. Users should not change its value. 76 | ''' 77 | min_count: int = 1 78 | prune_lm_head : Union[bool,str] = 'auto' 79 | config_class: str = "VocabularyPruningConfig" 80 | 81 | 82 | @dataclass 83 | class TransformerPruningConfig(Config): 84 | """ 85 | Configurations for transformer pruning. 86 | 87 | Args: 88 | target_ffn_size : the target average FFN size per layer. 89 | target_num_of_heads : the target average number of heads per layer. 90 | pruning_method : ``'masks'`` or ``'iterative'``. If set to ``'masks'``, the pruner prunes the model with the given masks (``head_mask`` and ``ffn_mask``). 91 | If set to ``'iterative'``, the pruner calculates the importance scores of the neurons based on the data provided by the ``dataloader`` and then prunes the model based on the scores. 92 | ffn_even_masking : Whether the FFN size of each layer should be the same. 93 | head_even_masking : Whether the number of attention heads of each layer should be the same. 94 | n_iters : if ``pruning_method`` is set to ``'iterative'``, ``n_iters`` is the number of pruning iterations to prune the model progressively. 95 | multiple_of : if ``ffn_even_masking`` is ``False``, restrict the target FFN size of each layer to be a multiple of ``multiple_of``. 96 | pruning_order: ``None`` or ``'head-first'`` or ``'ffn-first'``. ``None``: prune the attention heads and ffn layers simultaneously; if set to ``'head-first'`` or ``'ffn-first'``, the actual number of iterations is ``2*n_iters``. 97 | use_logits : if ``True``, performs self-supervised pruning, where the logits are treated as the soft labels. 98 | config_class: Type of the configurations. Users should not change its value. 99 | 100 | Warning: 101 | if ``ffn_even_masking`` is ``False``, the pruned model cannot be saved normally (we cannot load the model with the transformers library with the saved weights). 102 | So make sure to set ``save_model=False`` when calling ``TransformerPruner.prune()`` or ``PipelinePruner.prune()``. 103 | There are two ways to avoid this: 104 | 105 | * Save the model in TorchScript format manually; 106 | * Set ``keep_shape=True`` when calling ``TransformerPruner.prune()`` or ``PipelinePruner.prune()``, so the full model can be saved. Then save the ``ffn_masks`` and ``head_masks``. When loading the model, load the full model and then prune it with the masks. 
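Example (illustrative; the target sizes below are arbitrary placeholder values, not recommendations)::    transformer_pruning_config = TransformerPruningConfig(        target_ffn_size=2048, target_num_of_heads=8,        pruning_method='iterative', n_iters=4)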
107 | """ 108 | 109 | target_ffn_size : Optional[int] = None 110 | target_num_of_heads: Optional[int] = None 111 | pruning_method : str = 'masks' 112 | ffn_even_masking : Optional[bool] = True 113 | head_even_masking : Optional[bool] = True 114 | n_iters : Optional[int] = 1 115 | multiple_of : int = 1 116 | pruning_order : Optional[str] = None 117 | use_logits : bool = False 118 | config_class: str = "TransformerPruningConfig" 119 | def __post_init__(self): 120 | assert self.pruning_method in ('masks','iterative'), "Unrecgonized pruning method" 121 | assert (self.pruning_order is None) or (self.pruning_order in ('head-first','ffn-first')), "Unrecgonized pruning order" 122 | if self.ffn_even_masking is False: 123 | logger.warning("ffn_even_masking is False. Pruned model can only be save in TorchScript format manually.") 124 | 125 | CONFIG_CLASS = { 126 | 'GeneralConfig': GeneralConfig, 127 | 'VocabularyPruningConfig': VocabularyPruningConfig, 128 | 'TransformerPruningConfig': TransformerPruningConfig 129 | } -------------------------------------------------------------------------------- /textpruner/extentions/configurations.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | import torch 3 | import json 4 | import logging 5 | from typing import Union, Optional 6 | from dataclasses import dataclass, asdict 7 | 8 | from ..configurations import Config 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | @dataclass 13 | class FineGrainedPruningConfig(Config): 14 | """ 15 | Configurations for transformer pruning. 16 | """ 17 | 18 | target_QK_head_size: Optional[int] = None 19 | target_VO_head_size: Optional[int] = None 20 | pruning_method : str = 'masks' 21 | n_iters : Optional[int] = 1 22 | multiple_of : int = 1 23 | use_logits : bool = False 24 | config_class: str = "FineGrainedPruningConfig" 25 | def __post_init__(self): 26 | assert self.pruning_method in ('masks','iterative'), "Unrecgonized pruning method" -------------------------------------------------------------------------------- /textpruner/model_map.py: -------------------------------------------------------------------------------- 1 | from . import model_utils 2 | from . 
import tokenizer_utils 3 | 4 | MODEL_MAP = { 5 | 'albert': 6 | {'resizer': model_utils.AlbertVocabResizer, 7 | 'tokenizer_helper': tokenizer_utils.SentencepieceTokenizer, 8 | 'structure': model_utils.AlbertStructure}, 9 | 'bert': 10 | {'resizer': model_utils.BertVocabResizer, 11 | 'tokenizer_helper': tokenizer_utils.SubwordTokenizer, 12 | 'structure': model_utils.BertStructure}, 13 | 'electra': 14 | {'resizer': model_utils.ElectraVocabResizer, 15 | 'tokenizer_helper': tokenizer_utils.SubwordTokenizer, 16 | 'structure': model_utils.ElectraStructure}, 17 | 'roberta': 18 | {'resizer': model_utils.RobertaVocabResizer, 19 | 'tokenizer_helper' : tokenizer_utils.RobertaGPT2Tokenizer, 20 | 'structure': model_utils.RobertaStructure}, 21 | 'xlm-roberta': 22 | {'resizer':model_utils.XLMRobertaVocabResizer, 23 | 'tokenizer_helper': tokenizer_utils.XLMRSentencepieceTokenizer, 24 | 'structure': model_utils.XLMRobertaStructure}, 25 | 'xlm': 26 | {'resizer':model_utils.XLMVocabResizer, 27 | 'tokenizer_helper':tokenizer_utils.XLMTokenizer, 28 | 'structure':model_utils.XLMStructure}, 29 | 'bart': 30 | {'resizer' : model_utils.BartVocabResizer, 31 | 'tokenizer_helper' : tokenizer_utils.RobertaGPT2Tokenizer, 32 | 'structure': model_utils.BartStructure}, 33 | 't5': 34 | {'resizer' : model_utils.T5VocabResizer, 35 | 'tokenizer_helper' : tokenizer_utils.T5SentencepieceTokenizer, 36 | 'structure' : model_utils.T5Structure}, 37 | 'mt5': 38 | {'resizer' : model_utils.MT5VocabResizer, 39 | 'tokenizer_helper' : tokenizer_utils.MT5SentencepieceTokenizer, 40 | 'structure' : model_utils.MT5Structure}, 41 | } -------------------------------------------------------------------------------- /textpruner/model_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .albert import AlbertVocabResizer, AlbertStructure 2 | from .bert import BertVocabResizer, BertStructure 3 | from .electra import ElectraVocabResizer, ElectraStructure 4 | from .roberta import RobertaVocabResizer, RobertaStructure 5 | from .xlm_roberta import XLMRobertaVocabResizer, XLMRobertaStructure 6 | from .xlm import XLMStructure, XLMVocabResizer 7 | from .bart import BartVocabResizer, BartStructure 8 | from .t5 import T5VocabResizer, T5Structure 9 | from .mt5 import MT5Structure, MT5VocabResizer -------------------------------------------------------------------------------- /textpruner/model_utils/albert.py: -------------------------------------------------------------------------------- 1 | from .utils import DefaultModelVocabResizer 2 | from .model_structure import ModelStructure 3 | 4 | class AlbertVocabResizer(DefaultModelVocabResizer): 5 | model_name : str = 'albert' 6 | 7 | class AlbertStructure(ModelStructure): 8 | MODEL_PREFIX: str = "albert." 9 | ENCODER_PREFIX: str = r"encoder.albert_layer_groups.0.albert_layers.0." 
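# (added note) LAYER_PATTERNS below maps logical roles (attention query/key/value, attention output, FFN intermediate/output) to module-name patterns; ModelStructure matches them against named_modules() with MODEL_PREFIX + ENCODER_PREFIX prepended (see model_structure.py).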
10 | LAYER_PATTERNS = dict( 11 | query="attention.query", 12 | key="attention.key", 13 | value="attention.value", 14 | att_dense="attention.dense", 15 | interm_dense=r"ffn$", 16 | output_dense="ffn_output", 17 | ) 18 | ATTENTION_PREFIX = ("attention",) 19 | ATTENTION_LAYERS = ("query", "key", "value") 20 | MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",) 21 | NAME_CONFIG = dict( 22 | hidden_size="hidden_size", 23 | intermediate_size="intermediate_size", 24 | num_hidden_layers="num_hidden_layers", 25 | num_attention_heads="num_attention_heads", 26 | attention_head_size="attention_head_size", 27 | ) -------------------------------------------------------------------------------- /textpruner/model_utils/bart.py: -------------------------------------------------------------------------------- 1 | from .utils import DefaultModelVocabResizer 2 | from .model_structure import ModelStructure 3 | import torch 4 | from torch import nn 5 | class BartVocabResizer(DefaultModelVocabResizer): 6 | model_name : str = 'bart' 7 | 8 | @classmethod 9 | def set_embeddings(cls, model, token_ids): 10 | def _prun(old_weight, token_ids): 11 | pruned_word_embeddings_weight = torch.index_select( 12 | old_weight, 0, index=torch.LongTensor(token_ids).to(old_weight.device)) 13 | return pruned_word_embeddings_weight 14 | 15 | old_word_embeddings_shared, old_word_embeddings_encoder, old_word_embeddings_decoder = \ 16 | model.shared, model.encoder.embed_tokens, model.decoder.embed_tokens 17 | 18 | old_word_embeddings_shared_weight, old_word_embeddings_encoder_weight, old_word_embeddings_decoder_weight = \ 19 | old_word_embeddings_shared.weight, old_word_embeddings_encoder.weight, old_word_embeddings_decoder.weight 20 | 21 | pruned_word_embeddings_shared_weight, pruned_word_embeddings_encoder_weight, pruned_word_embeddings_decoder_weight = \ 22 | _prun(old_word_embeddings_shared_weight, token_ids), _prun(old_word_embeddings_encoder_weight, token_ids), _prun(old_word_embeddings_decoder_weight, token_ids) 23 | 24 | pruned_num_tokens, embedding_dim = pruned_word_embeddings_shared_weight.shape 25 | 26 | pruned_word_embeddings_shared = nn.Embedding( 27 | pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) 28 | pruned_word_embeddings_shared.weight.data[:] = pruned_word_embeddings_shared_weight[:] 29 | 30 | pruned_word_embeddings_encoder = nn.Embedding( 31 | pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) 32 | pruned_word_embeddings_encoder.weight.data[:] = pruned_word_embeddings_encoder_weight[:] 33 | 34 | pruned_word_embeddings_decoder = nn.Embedding( 35 | pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) 36 | pruned_word_embeddings_decoder.weight.data[:] = pruned_word_embeddings_decoder_weight[:] 37 | 38 | model.shared = pruned_word_embeddings_shared 39 | model.encoder.embed_tokens = pruned_word_embeddings_encoder 40 | model.decoder.embed_tokens = pruned_word_embeddings_decoder 41 | 42 | class BartStructure(ModelStructure): 43 | MODEL_PREFIX: str = "model." 44 | ENCODER_PREFIX: str = r"encoder.layers.[0-9]+\." 
45 | LAYER_PATTERNS = dict( 46 | query="self_attn.q_proj", 47 | key="self_attn.k_proj", 48 | value="self_attn.v_proj", 49 | att_dense="self_attn.out_proj", 50 | interm_dense="fc1", 51 | output_dense="fc2", 52 | ) 53 | ATTENTION_PREFIX = ("self_attn",) 54 | ATTENTION_LAYERS = ("q_proj", "k_proj", "v_proj") 55 | MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",) 56 | NAME_CONFIG = dict( 57 | hidden_size="d_model", 58 | intermediate_size="encoder_ffn_dim", 59 | num_hidden_layers="encoder_layers", 60 | num_attention_heads="num_attention_heads", 61 | attention_head_size="", 62 | ) -------------------------------------------------------------------------------- /textpruner/model_utils/bert.py: -------------------------------------------------------------------------------- 1 | from .utils import DefaultModelVocabResizer 2 | from .model_structure import ModelStructure 3 | 4 | class BertVocabResizer(DefaultModelVocabResizer): 5 | model_name : str = 'bert' 6 | 7 | class BertStructure(ModelStructure): 8 | MODEL_PREFIX: str = "bert." 9 | ENCODER_PREFIX: str = r"encoder.layer.[0-9]+\." 10 | LAYER_PATTERNS = dict( 11 | query="attention.self.query", 12 | key="attention.self.key", 13 | value="attention.self.value", 14 | att_dense="attention.output.dense", 15 | interm_dense="intermediate.dense", 16 | output_dense="output.dense", 17 | ) 18 | ATTENTION_PREFIX = ("attention.self",) 19 | ATTENTION_LAYERS = ("query", "key", "value") 20 | MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",) 21 | NAME_CONFIG = dict( 22 | hidden_size="hidden_size", 23 | intermediate_size="intermediate_size", 24 | num_hidden_layers="num_hidden_layers", 25 | num_attention_heads="num_attention_heads", 26 | attention_head_size="attention_head_size", 27 | ) -------------------------------------------------------------------------------- /textpruner/model_utils/electra.py: -------------------------------------------------------------------------------- 1 | from .utils import DefaultModelVocabResizer 2 | from .model_structure import ModelStructure 3 | 4 | class ElectraVocabResizer(DefaultModelVocabResizer): 5 | model_name : str = 'electra' 6 | 7 | class ElectraStructure(ModelStructure): 8 | MODEL_PREFIX: str = "electra." 9 | ENCODER_PREFIX: str = r"encoder.layer.[0-9]+\." 
10 | LAYER_PATTERNS = dict( 11 | query="attention.self.query", 12 | key="attention.self.key", 13 | value="attention.self.value", 14 | att_dense="attention.output.dense", 15 | interm_dense="intermediate.dense", 16 | output_dense="output.dense", 17 | ) 18 | ATTENTION_PREFIX = ("attention.self",) 19 | ATTENTION_LAYERS = ("query", "key", "value") 20 | MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",) 21 | NAME_CONFIG = dict( 22 | hidden_size="hidden_size", 23 | intermediate_size="intermediate_size", 24 | num_hidden_layers="num_hidden_layers", 25 | num_attention_heads="num_attention_heads", 26 | attention_head_size="attention_head_size", 27 | ) -------------------------------------------------------------------------------- /textpruner/model_utils/model_structure.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from torch import nn 4 | import logging 5 | from typing import Dict, List 6 | logger = logging.getLogger(__name__) 7 | 8 | # adapted from huggingface/nn_pruning/model_structure.py 9 | class ModelStructure: 10 | MODEL_PREFIX: str = "" 11 | ENCODER_PREFIX: str = "" 12 | ATTENTION_LAYERS = ("query", "key", "value") 13 | FFN_LAYERS = ("interm_dense", "output_dense") 14 | 15 | @classmethod 16 | def get_att_query(cls, model, ignore_model_prefix=False): 17 | pattern = cls.ENCODER_PREFIX + cls.LAYER_PATTERNS['query'] 18 | if ignore_model_prefix is False: 19 | pattern = cls.MODEL_PREFIX + pattern 20 | rs = [] 21 | for k in model.named_modules(): 22 | name = k[0] 23 | r = re.search(pattern, name) 24 | if r is not None: 25 | rs.append(get_submodule(model,r.group())) 26 | return rs 27 | 28 | 29 | @classmethod 30 | def get_att_key(cls, model, ignore_model_prefix=False): 31 | pattern = cls.ENCODER_PREFIX + cls.LAYER_PATTERNS['key'] 32 | if ignore_model_prefix is False: 33 | pattern = cls.MODEL_PREFIX + pattern 34 | rs = [] 35 | for k in model.named_modules(): 36 | name = k[0] 37 | r = re.search(pattern, name) 38 | if r is not None: 39 | rs.append(get_submodule(model,r.group())) 40 | return rs 41 | 42 | 43 | @classmethod 44 | def get_att_value(cls, model, ignore_model_prefix=False): 45 | pattern = cls.ENCODER_PREFIX + cls.LAYER_PATTERNS['value'] 46 | if ignore_model_prefix is False: 47 | pattern = cls.MODEL_PREFIX + pattern 48 | rs = [] 49 | for k in model.named_modules(): 50 | name = k[0] 51 | r = re.search(pattern, name) 52 | if r is not None: 53 | rs.append(get_submodule(model,r.group())) 54 | return rs 55 | 56 | 57 | @classmethod 58 | def get_att_output(cls, model, ignore_model_prefix=False): 59 | pattern = cls.ENCODER_PREFIX + cls.LAYER_PATTERNS['att_dense'] 60 | if ignore_model_prefix is False: 61 | pattern = cls.MODEL_PREFIX + pattern 62 | rs = [] 63 | for k in model.named_modules(): 64 | name = k[0] 65 | r = re.search(pattern, name) 66 | if r is not None: 67 | rs.append(get_submodule(model,r.group())) 68 | return rs 69 | 70 | 71 | @classmethod 72 | def get_ffn_interm(cls, model, ignore_model_prefix=False): 73 | pattern = cls.ENCODER_PREFIX + cls.LAYER_PATTERNS['interm_dense'] 74 | if ignore_model_prefix is False: 75 | pattern = cls.MODEL_PREFIX + pattern 76 | rs = [] 77 | for k in model.named_modules(): 78 | name = k[0] 79 | r = re.search(pattern, name) 80 | if r is not None: 81 | rs.append(get_submodule(model,r.group())) 82 | return rs 83 | 84 | 85 | @classmethod 86 | def get_ffn_output(cls, model, ignore_model_prefix=False): 87 | pattern = cls.ENCODER_PREFIX + cls.LAYER_PATTERNS['output_dense'] 88 | if ignore_model_prefix is 
False: 89 | pattern = cls.MODEL_PREFIX + pattern 90 | rs = [] 91 | for k in model.named_modules(): 92 | name = k[0] 93 | r = re.search(pattern, name) 94 | if r is not None: 95 | rs.append(get_submodule(model,r.group())) 96 | return rs 97 | 98 | @classmethod 99 | def get_num_layers(cls, model, ignore_model_prefix=False): 100 | pattern = cls.ENCODER_PREFIX 101 | if ignore_model_prefix is False: 102 | pattern = cls.MODEL_PREFIX + pattern 103 | rs = [] 104 | for k in model.named_modules(): 105 | name = k[0] 106 | r = re.search(pattern, name) 107 | if r is not None: 108 | rs.append(r.group()) 109 | return len(set(rs)) 110 | 111 | @classmethod 112 | def layer_index(cls, child_module_name): 113 | extracts = re.findall(r"[0-9]+", child_module_name) 114 | return int(extracts[0]) 115 | 116 | 117 | 118 | # from PyTorch 1.9.0 119 | def get_submodule(model: nn.Module, target: str) -> nn.Module: 120 | """ 121 | Returns the submodule given by ``target`` if it exists, 122 | otherwise throws an error. 123 | 124 | For example, let's say you have an ``nn.Module`` ``A`` that 125 | looks like this: 126 | 127 | .. code-block::text 128 | 129 | A( 130 | (net_b): Module( 131 | (net_c): Module( 132 | (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) 133 | ) 134 | (linear): Linear(in_features=100, out_features=200, bias=True) 135 | ) 136 | ) 137 | 138 | (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested 139 | submodule ``net_b``, which itself has two submodules ``net_c`` 140 | and ``linear``. ``net_c`` then has a submodule ``conv``.) 141 | 142 | To check whether or not we have the ``linear`` submodule, we 143 | would call ``get_submodule("net_b.linear")``. To check whether 144 | we have the ``conv`` submodule, we would call 145 | ``get_submodule("net_b.net_c.conv")``. 146 | 147 | The runtime of ``get_submodule`` is bounded by the degree 148 | of module nesting in ``target``. A query against 149 | ``named_modules`` achieves the same result, but it is O(N) in 150 | the number of transitive modules. So, for a simple check to see 151 | if some submodule exists, ``get_submodule`` should always be 152 | used. 153 | 154 | Args: 155 | target: The fully-qualified string name of the submodule 156 | to look for. (See above example for how to specify a 157 | fully-qualified string.) 
158 | 159 | Returns: 160 | torch.nn.Module: The submodule referenced by ``target`` 161 | 162 | Raises: 163 | AttributeError: If the target string references an invalid 164 | path or resolves to something that is not an 165 | ``nn.Module`` 166 | """ 167 | if target == "": 168 | return model 169 | 170 | atoms: List[str] = target.split(".") 171 | mod: torch.nn.Module = model 172 | 173 | for item in atoms: 174 | 175 | if not hasattr(mod, item): 176 | raise AttributeError(mod._get_name() + " has no " 177 | "attribute `" + item + "`") 178 | 179 | mod = getattr(mod, item) 180 | 181 | if not isinstance(mod, torch.nn.Module): 182 | raise AttributeError("`" + item + "` is not " 183 | "an nn.Module") 184 | 185 | return mod -------------------------------------------------------------------------------- /textpruner/model_utils/mt5.py: -------------------------------------------------------------------------------- 1 | from .utils import DefaultModelVocabResizer 2 | from .model_structure import ModelStructure 3 | import torch 4 | from torch import nn 5 | class MT5VocabResizer(DefaultModelVocabResizer): 6 | model_name : str = 'mt5' 7 | 8 | @classmethod 9 | def set_embeddings(cls, model, token_ids): 10 | def _prun(old_weight, token_ids): 11 | pruned_word_embeddings_weight = torch.index_select( 12 | old_weight, 0, index=torch.LongTensor(token_ids).to(old_weight.device)) 13 | return pruned_word_embeddings_weight 14 | 15 | 16 | vocab_size = model.shared.weight.shape[0] 17 | max_token_ids = token_ids[-1] 18 | tokens_in_embed_notin_tokenizer_ids = list(range(max_token_ids+1, vocab_size)) 19 | token_ids_temp = token_ids[:] 20 | token_ids_temp.extend(tokens_in_embed_notin_tokenizer_ids) 21 | 22 | 23 | model.config.vocab_size = len(token_ids_temp) 24 | 25 | old_word_embeddings_shared, old_word_embeddings_encoder, old_word_embeddings_decoder = \ 26 | model.shared, model.encoder.embed_tokens, model.decoder.embed_tokens 27 | 28 | old_word_embeddings_shared_weight, old_word_embeddings_encoder_weight, old_word_embeddings_decoder_weight = \ 29 | old_word_embeddings_shared.weight, old_word_embeddings_encoder.weight, old_word_embeddings_decoder.weight 30 | 31 | pruned_word_embeddings_shared_weight, pruned_word_embeddings_encoder_weight, pruned_word_embeddings_decoder_weight = \ 32 | _prun(old_word_embeddings_shared_weight, token_ids_temp), _prun(old_word_embeddings_encoder_weight, token_ids_temp), _prun(old_word_embeddings_decoder_weight, token_ids_temp) 33 | 34 | pruned_num_tokens, embedding_dim = pruned_word_embeddings_shared_weight.shape 35 | 36 | pruned_word_embeddings_shared = nn.Embedding( 37 | pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) 38 | pruned_word_embeddings_shared.weight.data[:] = pruned_word_embeddings_shared_weight[:] 39 | 40 | pruned_word_embeddings_encoder = nn.Embedding( 41 | pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) 42 | pruned_word_embeddings_encoder.weight.data[:] = pruned_word_embeddings_encoder_weight[:] 43 | 44 | pruned_word_embeddings_decoder = nn.Embedding( 45 | pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) 46 | pruned_word_embeddings_decoder.weight.data[:] = pruned_word_embeddings_decoder_weight[:] 47 | 48 | model.shared = pruned_word_embeddings_shared 49 | model.encoder.embed_tokens = pruned_word_embeddings_encoder 50 | model.decoder.embed_tokens = pruned_word_embeddings_decoder 51 | 52 | class MT5Structure(ModelStructure): 53 | MODEL_PREFIX: str = "transformer." 
54 | ENCODER_PREFIX: str = r"encoder.block.[0-9]+\.layer." 55 | LAYER_PATTERNS = dict( 56 | query="0.SelfAttention.q", 57 | key="0.SelfAttention.k", 58 | value="0.SelfAttention.v", 59 | att_dense="0.SelfAttention.o", 60 | interm_dense="1.DenseReluDense.wi", 61 | output_dense="1.DenseReluDense.wo", 62 | ) 63 | ATTENTION_PREFIX = ("0.SelfAttention",) 64 | ATTENTION_LAYERS = ("q", "k", "v") 65 | MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",) 66 | NAME_CONFIG = dict( 67 | hidden_size="d_model", 68 | intermediate_size="d_ff", 69 | num_hidden_layers="num_layers", 70 | num_attention_heads="num_heads", 71 | attention_head_size="", 72 | ) -------------------------------------------------------------------------------- /textpruner/model_utils/roberta.py: -------------------------------------------------------------------------------- 1 | from .utils import DefaultModelVocabResizer 2 | from .model_structure import ModelStructure 3 | 4 | class RobertaVocabResizer(DefaultModelVocabResizer): 5 | model_name : str = 'roberta' 6 | 7 | class RobertaStructure(ModelStructure): 8 | MODEL_PREFIX: str = "roberta." 9 | ENCODER_PREFIX: str = r"encoder.layer.[0-9]+\." 10 | LAYER_PATTERNS = dict( 11 | query="attention.self.query", 12 | key="attention.self.key", 13 | value="attention.self.value", 14 | att_dense="attention.output.dense", 15 | interm_dense="intermediate.dense", 16 | output_dense="output.dense", 17 | ) 18 | ATTENTION_PREFIX = ("attention.self",) 19 | ATTENTION_LAYERS = ("query", "key", "value") 20 | MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",) 21 | NAME_CONFIG = dict( 22 | hidden_size="hidden_size", 23 | intermediate_size="intermediate_size", 24 | num_hidden_layers="num_hidden_layers", 25 | num_attention_heads="num_attention_heads", 26 | attention_head_size="attention_head_size", 27 | ) -------------------------------------------------------------------------------- /textpruner/model_utils/t5.py: -------------------------------------------------------------------------------- 1 | from .utils import DefaultModelVocabResizer 2 | from .model_structure import ModelStructure 3 | import torch 4 | from torch import nn 5 | class T5VocabResizer(DefaultModelVocabResizer): 6 | model_name : str = 't5' 7 | 8 | @classmethod 9 | def set_embeddings(cls, model, token_ids): 10 | def _prun(old_weight, token_ids): 11 | pruned_word_embeddings_weight = torch.index_select( 12 | old_weight, 0, index=torch.LongTensor(token_ids).to(old_weight.device)) 13 | return pruned_word_embeddings_weight 14 | 15 | vocab_size = model.shared.weight.shape[0] 16 | max_token_ids = token_ids[-1] 17 | tokens_in_embed_notin_tokenizer_ids = list(range(max_token_ids+1, vocab_size)) 18 | token_ids_temp = token_ids[:] 19 | token_ids_temp.extend(tokens_in_embed_notin_tokenizer_ids) 20 | 21 | model.config.vocab_size = len(token_ids_temp) 22 | 23 | old_word_embeddings_shared, old_word_embeddings_encoder, old_word_embeddings_decoder = \ 24 | model.shared, model.encoder.embed_tokens, model.decoder.embed_tokens 25 | 26 | old_word_embeddings_shared_weight, old_word_embeddings_encoder_weight, old_word_embeddings_decoder_weight = \ 27 | old_word_embeddings_shared.weight, old_word_embeddings_encoder.weight, old_word_embeddings_decoder.weight 28 | 29 | pruned_word_embeddings_shared_weight, pruned_word_embeddings_encoder_weight, pruned_word_embeddings_decoder_weight = \ 30 | _prun(old_word_embeddings_shared_weight, token_ids_temp), _prun(old_word_embeddings_encoder_weight, token_ids_temp), _prun(old_word_embeddings_decoder_weight, token_ids_temp) 31 | 32 | 
pruned_num_tokens, embedding_dim = pruned_word_embeddings_shared_weight.shape 33 | 34 | pruned_word_embeddings_shared = nn.Embedding( 35 | pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) 36 | pruned_word_embeddings_shared.weight.data[:] = pruned_word_embeddings_shared_weight[:] 37 | 38 | pruned_word_embeddings_encoder = nn.Embedding( 39 | pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) 40 | pruned_word_embeddings_encoder.weight.data[:] = pruned_word_embeddings_encoder_weight[:] 41 | 42 | pruned_word_embeddings_decoder = nn.Embedding( 43 | pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) 44 | pruned_word_embeddings_decoder.weight.data[:] = pruned_word_embeddings_decoder_weight[:] 45 | 46 | model.shared = pruned_word_embeddings_shared 47 | model.encoder.embed_tokens = pruned_word_embeddings_encoder 48 | model.decoder.embed_tokens = pruned_word_embeddings_decoder 49 | 50 | class T5Structure(ModelStructure): 51 | MODEL_PREFIX: str = "transformer." 52 | ENCODER_PREFIX: str = r"encoder.block.[0-9]+\.layer." 53 | LAYER_PATTERNS = dict( 54 | query="0.SelfAttention.q", 55 | key="0.SelfAttention.k", 56 | value="0.SelfAttention.v", 57 | att_dense="0.SelfAttention.o", 58 | interm_dense="1.DenseReluDense.wi", 59 | output_dense="1.DenseReluDense.wo", 60 | ) 61 | ATTENTION_PREFIX = ("0.SelfAttention",) 62 | ATTENTION_LAYERS = ("q", "k", "v") 63 | MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",) 64 | NAME_CONFIG = dict( 65 | hidden_size="d_model", 66 | intermediate_size="d_ff", 67 | num_hidden_layers="num_layers", 68 | num_attention_heads="num_heads", 69 | attention_head_size="", 70 | ) -------------------------------------------------------------------------------- /textpruner/model_utils/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from torch import nn 4 | import logging 5 | from typing import Dict 6 | logger = logging.getLogger(__name__) 7 | 8 | class DefaultModelVocabResizer: 9 | @classmethod 10 | def set_embeddings(cls, model, token_ids): 11 | # self.model.get_input_embeddings() 12 | old_word_embeddings = model.embeddings.word_embeddings 13 | old_word_embeddings_weight = old_word_embeddings.weight 14 | 15 | pruned_word_embeddings_weight = torch.index_select( 16 | old_word_embeddings_weight, 0, index=torch.LongTensor(token_ids).to(old_word_embeddings_weight.device)) 17 | pruned_num_tokens, embedding_dim = pruned_word_embeddings_weight.shape 18 | 19 | pruned_word_embeddings = nn.Embedding( 20 | pruned_num_tokens, embedding_dim).to(old_word_embeddings_weight.device) 21 | pruned_word_embeddings.weight.data[:] = pruned_word_embeddings_weight[:] 22 | 23 | model.embeddings.word_embeddings = pruned_word_embeddings 24 | 25 | @classmethod 26 | def set_lm_head(cls, model, token_ids) -> bool: 27 | try: 28 | output_embedding_layer = model.get_output_embeddings() 29 | except AttributeError: 30 | return False 31 | if output_embedding_layer is None: 32 | return False 33 | output_embedding_layer.weight = model.get_input_embeddings().weight 34 | output_embedding_layer.bias.data = torch.index_select( 35 | output_embedding_layer.bias.data, 0, index=torch.LongTensor(token_ids).to(output_embedding_layer.weight.device)) 36 | return True 37 | 38 | 39 | #bert, roberta, xlmr, ... 
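# (added note) The two helpers below inspect the state_dict keys directly: get_word_embeddings() returns the single "embeddings.word_embeddings" weight, and get_num_of_trms() counts encoder layers by parsing "encoder.layer.<index>." prefixes from the parameter names.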
40 | def get_word_embeddings(model): 41 | state_dict = model.state_dict() 42 | layer_template = "embeddings.word_embeddings" 43 | layer_names = [] 44 | for key in state_dict: 45 | if layer_template in key: 46 | layer_names.append(key) 47 | assert len( 48 | layer_names) == 1, f"Invalid model structure with ambiguous word embeddings: {layer_names}" 49 | word_embedding_weight = state_dict[layer_names[0]] 50 | return word_embedding_weight 51 | 52 | 53 | #bert, roberta, xlmr, ... 54 | def get_num_of_trms(model): 55 | layer_template_regex = "encoder.layer\.(\d+)\." 56 | layer_template = "encoder.layer.LAYER_INDEX." 57 | layer_indices = set() 58 | layer_names = set() 59 | state_dict = model.state_dict() 60 | for key in state_dict: 61 | matched = re.findall(layer_template_regex, key) 62 | if len(matched) > 0: 63 | assert len( 64 | matched) == 1, f"Invalid model structure. Cannot parse {key}" 65 | layer_index = int(matched[0]) 66 | layer_indices.add(layer_index) 67 | 68 | layer_name = layer_template.replace("LAYER_INDEX", matched[0]) 69 | layer_name = key[:key.find(layer_name)]+layer_name 70 | layer_names.add(layer_name) 71 | 72 | print("Found transfomr layers:", layer_indices) 73 | print("Layer name prefixes:", layer_names) 74 | 75 | return len(layer_indices), layer_names 76 | -------------------------------------------------------------------------------- /textpruner/model_utils/xlm.py: -------------------------------------------------------------------------------- 1 | from .utils import DefaultModelVocabResizer 2 | from .model_structure import ModelStructure 3 | import torch 4 | from torch import nn 5 | class XLMVocabResizer(DefaultModelVocabResizer): 6 | model_name : str = 'xlm' 7 | 8 | @classmethod 9 | def set_embeddings(cls, model, token_ids): 10 | # self.model.get_input_embeddings() 11 | 12 | if hasattr(model.embeddings, 'word_embeddings'): #XLM 13 | old_word_embeddings = model.embeddings.word_embeddings 14 | else: 15 | old_word_embeddings = model.embeddings 16 | 17 | 18 | 19 | # old_word_embeddings = model.embeddings.word_embeddings 20 | old_word_embeddings_weight = old_word_embeddings.weight 21 | 22 | pruned_word_embeddings_weight = torch.index_select( 23 | old_word_embeddings_weight, 0, index=torch.LongTensor(token_ids).to(old_word_embeddings_weight.device)) 24 | pruned_num_tokens, embedding_dim = pruned_word_embeddings_weight.shape 25 | 26 | pruned_word_embeddings = nn.Embedding( 27 | pruned_num_tokens, embedding_dim).to(old_word_embeddings_weight.device) 28 | pruned_word_embeddings.weight.data[:] = pruned_word_embeddings_weight[:] 29 | 30 | 31 | if hasattr(model.embeddings, 'word_embeddings'): 32 | model.embeddings.word_embeddings = pruned_word_embeddings 33 | else: 34 | model.embeddings = pruned_word_embeddings 35 | 36 | 37 | 38 | 39 | class XLMStructure(ModelStructure): 40 | MODEL_PREFIX: str = "transformer." 41 | ENCODER_PREFIX: str = r"attention.[0-9]+\." 
42 | LAYER_PATTERNS = dict( 43 | query=r"attentions\.[0-9]+\.q_lin", 44 | key=r"attentions\.[0-9]+\.k_lin", 45 | value=r"attentions\.[0-9]+\.v_lin", 46 | att_dense=r"attentions\.[0-9]+\.out_lin", 47 | interm_dense=r"ffns\.[0-9]+\.lin1", 48 | output_dense=r"ffns\.[0-9]+\.lin2", 49 | ) 50 | ATTENTION_PREFIX = (r"attentions\.[0-9]",) 51 | ATTENTION_LAYERS = ("q_lin", "k_lin", "v_lin") 52 | MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",) 53 | NAME_CONFIG = dict( 54 | hidden_size="emb_dim", 55 | intermediate_size="emb_dim", 56 | num_hidden_layers="n_layers", 57 | num_attention_heads="n_heads", 58 | attention_head_size="attention_head_size", 59 | ) -------------------------------------------------------------------------------- /textpruner/model_utils/xlm_roberta.py: -------------------------------------------------------------------------------- 1 | from .utils import DefaultModelVocabResizer 2 | from .model_structure import ModelStructure 3 | 4 | class XLMRobertaVocabResizer(DefaultModelVocabResizer): 5 | model_name : str = 'xlm-roberta' 6 | 7 | class XLMRobertaStructure(ModelStructure): 8 | MODEL_PREFIX: str = "roberta." 9 | ENCODER_PREFIX: str = r"encoder.layer.[0-9]+\." 10 | LAYER_PATTERNS = dict( 11 | query="attention.self.query", 12 | key="attention.self.key", 13 | value="attention.self.value", 14 | att_dense="attention.output.dense", 15 | interm_dense="intermediate.dense", 16 | output_dense="output.dense", 17 | ) 18 | ATTENTION_PREFIX = ("attention.self",) 19 | ATTENTION_LAYERS = ("query", "key", "value") 20 | MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",) 21 | NAME_CONFIG = dict( 22 | hidden_size="hidden_size", 23 | intermediate_size="intermediate_size", 24 | num_hidden_layers="num_hidden_layers", 25 | num_attention_heads="num_attention_heads", 26 | attention_head_size="attention_head_size", 27 | ) -------------------------------------------------------------------------------- /textpruner/pruners/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer_pruner import TransformerPruner 2 | from .vocabulary_pruner import VocabularyPruner 3 | from .pipeline_pruner import PipelinePruner -------------------------------------------------------------------------------- /textpruner/pruners/pipeline_pruner.py: -------------------------------------------------------------------------------- 1 | from .transformer_pruner import TransformerPruner 2 | from .vocabulary_pruner import VocabularyPruner 3 | from typing import Optional 4 | from ..configurations import GeneralConfig,VocabularyPruningConfig,TransformerPruningConfig 5 | import torch 6 | from torch import nn 7 | import os 8 | import logging 9 | logger = logging.getLogger(__name__) 10 | from .utils import infer_model_type 11 | from ..model_map import MODEL_MAP 12 | 13 | class PipelinePruner: 14 | ''' 15 | Args: 16 | model : The model to be pruned. 17 | tokenizer : The tokenizer for the model. 18 | vocabulary_pruning_config : a :class:`~textpruner.configurations.VocabularyPruningConfig` object. 19 | transformer_pruning_config : a :class:`~textpruner.configurations.TransformerPruningConfig` object. 20 | general_config : a :class:`~textpruner.configurations.GeneralConfig` object. 21 | base_model_prefix : The prefix of the base model, i.e., the name of the base model as a member in the model. \ 22 | For example, if ``model.bert_encoder = BertModel(...)``, then the ``base_model_prefix`` is ``bert_encoder``. 
\ 23 | TextPruner will infer the ``base_model_prefix`` so we can leave its value as ``None``. But if it fails, users have to set its value explicitly. 24 | ''' 25 | def __init__(self, 26 | model: nn.Module, 27 | tokenizer, 28 | transformer_pruning_config: Optional[TransformerPruningConfig] = None, 29 | vocabulary_pruning_config : Optional[VocabularyPruningConfig] = None, 30 | general_config: Optional[GeneralConfig] = None, 31 | base_model_prefix : Optional[str] = None): 32 | self.model = model 33 | self.tokenizer = tokenizer 34 | 35 | self.general_config = GeneralConfig() if general_config is None else general_config 36 | self.transformer_pruning_config = TransformerPruningConfig() if transformer_pruning_config is None else transformer_pruning_config 37 | self.vocabulary_pruning_config = VocabularyPruningConfig() if vocabulary_pruning_config is None else vocabulary_pruning_config 38 | 39 | 40 | self.output_dir = self.general_config.output_dir 41 | base_model, model_type = infer_model_type(model, base_model_prefix) 42 | assert model_type in MODEL_MAP, \ 43 | f"Model type {model_type} is not supported, or not understood. Model type must be one of {list(MODEL_MAP.keys())}" 44 | self.base_model = base_model 45 | self.model_type = model_type 46 | 47 | self.vocabulary_pruner = VocabularyPruner(model, tokenizer, vocabulary_pruning_config, general_config, base_model_prefix=base_model_prefix) 48 | self.transformer_pruner = TransformerPruner(model, transformer_pruning_config, general_config, base_model_prefix=base_model_prefix) 49 | self.save_dir = None 50 | 51 | def prune(self, 52 | dataloader=None, 53 | adaptor=None, 54 | batch_postprocessor=None, 55 | head_mask: Optional[torch.Tensor] =None, 56 | ffn_mask: Optional[torch.Tensor]=None, 57 | keep_shape=False, 58 | dataiter=None, 59 | additional_tokens=None, 60 | additional_token_ids=None, 61 | save_model=True) -> Optional[str]: 62 | ''' 63 | Prunes the transformers, then prunes the vocabulary. 64 | 65 | Args: 66 | dataloader : a dataloader that generates batches. Each batch should contain both the inputs and the labels. 67 | adaptor : a function that takes the model output and returns the loss. 68 | batch_postprocessor : a function that takes the batch produced by the dataloader and returns a batch. It is used for post-processing the batches if needed. 69 | head_mask : a tensor of shape ``(num_layers, num_attention_heads)``. `1` means to keep, `0` means to prune. 70 | ffn_mask : a tensor of shape ``(num_layers, intermediate_hidden_size)``. `1` means to keep, `0` means to prune. 71 | keep_shape : if ``True``, the model is not actually pruned and the model structure is not changed, but the weights that *should be pruned* are set to zero. 72 | dataiter : a list of pre-tokenized strings. These strings will be tokenized by the tokenizer to generate a set of tokens. 73 | additional_tokens : a list of tokens. These tokens must exist in the original vocabulary. 74 | additional_token_ids : a list of ints representing the token ids. 75 | save_model : whether to save the model when the pruning is finished. 
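Example (illustrative; ``dataloader``, ``adaptor`` and ``texts`` are user-supplied placeholders)::    pruner = PipelinePruner(model, tokenizer,                            transformer_pruning_config=transformer_pruning_config,                            vocabulary_pruning_config=vocabulary_pruning_config)    save_dir = pruner.prune(dataloader=dataloader, adaptor=adaptor,                            dataiter=texts, save_model=True)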
76 | ''' 77 | 78 | logger.info("Transfomer pruning...") 79 | self.transformer_pruner.prune(dataloader, 80 | adaptor, 81 | batch_postprocessor=batch_postprocessor, 82 | keep_shape=keep_shape, 83 | head_mask=head_mask, 84 | ffn_mask=ffn_mask, 85 | save_model=False) 86 | logger.info("Vocabulary pruning...") 87 | self.vocabulary_pruner.prune(dataiter=dataiter, 88 | additional_tokens=additional_tokens, 89 | additional_token_ids=additional_token_ids, 90 | save_model=False) 91 | 92 | if save_model is True: 93 | self.save_dir = self.save_model() 94 | return self.save_dir 95 | 96 | def save_model(self, dir_name=None) -> str: 97 | ffn_sizes = self.transformer_pruner.ffn_mask.to(int).sum(-1).tolist() 98 | if self.transformer_pruner.keep_shape is False: 99 | ffn_size = ffn_sizes[0] 100 | num_of_heads = self.transformer_pruner.head_mask.sum().item() / self.transformer_pruner.head_mask.size(0) 101 | if len(set(ffn_sizes)) != 1: 102 | raise NotImplementedError("Cannot save pruned model with different ffn size per layer with keep_shape=False. \ 103 | Call PipelinePruner.save_masks or PipelinePruner.save_jit_model manually instead.") 104 | else: 105 | self.base_model.config.intermediate_size = ffn_size 106 | else: 107 | ffn_size = self.transformer_pruner.ffn_mask.size(1) #base_model.config.intermediate_size 108 | num_of_heads = self.transformer_pruner.head_mask.size(1) #self.transformer_pruning_config.target_num_of_heads 109 | 110 | vocab_size = len(self.vocabulary_pruner.pruned_token_ids) 111 | self.base_model.config.vocab_size = vocab_size 112 | 113 | 114 | if dir_name is None: 115 | save_dir = os.path.join(self.general_config.output_dir,f'pruned_V{vocab_size}H{num_of_heads}F{ffn_size}') 116 | else: 117 | save_dir = os.path.join(self.general_config.output_dir,dir_name) 118 | os.makedirs(save_dir, exist_ok=True) 119 | torch.save(self.model.state_dict(),os.path.join(save_dir,'pytorch_model.bin')) 120 | # save config 121 | self.base_model.config.save_pretrained(save_dir) 122 | # save tokenizer 123 | self.vocabulary_pruner.tokenizer_helper.save_vocab(self.tokenizer, self.vocabulary_pruner.pruned_token_ids, save_dir) 124 | 125 | logger.info(f"Model and configuration have been saved to {save_dir}") 126 | 127 | return save_dir 128 | 129 | def save_jit_model(self, example_inputs, dir_name=None) -> str: 130 | self.model.eval() 131 | with torch.no_grad(): 132 | traced_model = torch.jit.trace(self.model, example_inputs=example_inputs, strict=False) 133 | if dir_name is None: 134 | save_dir = os.path.join(self.general_config.output_dir,'pruned_H{num_of_heads}F{ffn_size}_traced') 135 | else: 136 | save_dir = os.path.join(self.general_config.output_dir,dir_name) 137 | os.makedirs(save_dir, exist_ok=True) 138 | torch.jit.save(traced_model, os.path.join(save_dir,'pytorch_model.ts')) 139 | 140 | return save_dir -------------------------------------------------------------------------------- /textpruner/pruners/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from collections import abc 4 | from typing import Tuple,Optional 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | 8 | def move_to_device(batch, device): 9 | r"""Puts each data field to the device""" 10 | if isinstance(batch, torch.Tensor): 11 | return batch.to(device) 12 | elif isinstance(batch,(list,tuple)): 13 | return tuple(move_to_device(item,device) for item in batch) 14 | elif isinstance(batch, abc.Mapping): 15 | return {key: move_to_device(value,device) for 
key, value in batch.items()} 16 | else: 17 | return batch 18 | 19 | 20 | def infer_model_type(model, base_model_prefix) -> Tuple[nn.Module,str]: 21 | if base_model_prefix is not None: 22 | base_model = getattr(model, base_model_prefix, model) 23 | model_type = base_model.config.model_type 24 | else: 25 | if hasattr(model, 'base_model_prefix'): 26 | base_model = getattr(model, model.base_model_prefix, model) 27 | if hasattr(base_model, 'config'): 28 | model_type = base_model.config.model_type 29 | else: 30 | raise ValueError("Cannot get model_type! You should provide base_model_prefix") 31 | else: 32 | raise ValueError("Cannot get model_type! You should provide base_model_prefix") 33 | return base_model, model_type 34 | 35 | 36 | def random_mask_tensor(shape: Tuple[int,int], p : float = 0.5, dtype=None, even_masks=True): 37 | tensor = torch.zeros(shape) 38 | if even_masks is False: 39 | tensor = tensor.bernoulli_(p=0.5) 40 | else: 41 | num_masks_per_row = int(shape[1] * p) 42 | for i in range(shape[0]): 43 | tensor[i][:num_masks_per_row] = 1 44 | randindex = torch.randperm(shape[1]) 45 | tensor[i] = tensor[i][randindex] 46 | if dtype is not None: 47 | return tensor.to(dtype) 48 | else: 49 | return tensor 50 | 51 | 52 | def generate_mask(importance : torch.Tensor, total_target_size : int, even_masking : bool = False, 53 | layer_start : Optional[int] = None, layer_end: Optional[int] = None, multiple_of : int = 1 ) -> torch.Tensor: 54 | if layer_start is not None and layer_end is not None: 55 | target_size_per_layer = total_target_size // importance.size(0) 56 | mask = torch.ones_like(importance) 57 | for i in range(layer_start,layer_end): 58 | layer = importance[i] 59 | importance_layer_order = torch.argsort(layer) 60 | mask[i][importance_layer_order[:-target_size_per_layer]] = 0 61 | elif even_masking is True: 62 | target_size_per_layer = total_target_size // importance.size(0) 63 | mask = torch.ones_like(importance) 64 | for i,layer in enumerate(importance): 65 | importance_layer_order = torch.argsort(layer) 66 | mask[i][importance_layer_order[:-target_size_per_layer]] = 0 67 | elif multiple_of == 1: 68 | importance_flat = importance.reshape(-1) 69 | importance_order = torch.argsort(importance_flat) # ascending 70 | mask_flat = torch.ones_like(importance_flat) 71 | for pos in importance_order[:-total_target_size]: 72 | mask_flat[pos] = 0 73 | mask = mask_flat.reshape(importance.shape) 74 | else: 75 | num_layers = importance.size(0) 76 | num_groups = importance.size(1) // multiple_of 77 | importance_order_2d = torch.argsort(importance,dim=-1) 78 | importance_3d = torch.zeros(num_layers, num_groups, multiple_of).to(importance) 79 | for i, layer_order in enumerate(importance_order_2d): 80 | layer_sorted_by_importance = importance[i][layer_order].view(-1,multiple_of) # (num_head // multiple_of, multiple_of) 81 | importance_3d[i] = layer_sorted_by_importance 82 | importance_2d_order_2d = importance_order_2d.view(num_layers * num_groups, multiple_of) 83 | 84 | importance_3d_s_flat = importance_3d.sum(-1).view(-1) # num_layers * num_groups 85 | importance_3d_s_flat_order_flat = torch.argsort(importance_3d_s_flat) # ascending 86 | 87 | total_group_target_size = total_target_size // multiple_of 88 | mask = torch.ones_like(importance) 89 | 90 | for pos in importance_3d_s_flat_order_flat[:-total_group_target_size]: 91 | x = int(pos) // num_groups 92 | mask[x,importance_2d_order_2d[pos]] = 0 93 | 94 | # check for disconnected graph 95 | mask_sum = mask.sum(-1) 96 | for i in range(len(mask_sum)): 97 | 
if mask_sum[i]==0: 98 | print("Warning") 99 | most_imp = torch.argmax(importance[i]) 100 | mask[i][most_imp] = 1 101 | return mask 102 | 103 | 104 | def infer_logits(outputs,adaptor=None): 105 | if adaptor is None: 106 | try: 107 | if isinstance(outputs, torch.Tensor): 108 | logits = outputs 109 | assert len(logits.size())>0 110 | elif isinstance(outputs, (list,tuple)): 111 | logits = outputs[0] 112 | assert len(logits.size())>0 113 | elif isinstance(outputs, abc.Mapping): 114 | logits = outputs['logits'] 115 | else: 116 | logits = outputs.logits 117 | except (KeyError, AttributeError, AssertionError) as e: 118 | logger.error("Cannot infer logits from the outputs automatically! An adaptor is needed") 119 | raise e 120 | else: 121 | logits = adaptor(outputs) 122 | return logits 123 | 124 | 125 | def infer_loss(outputs, adaptor=None): 126 | if adaptor is None: 127 | try: 128 | if isinstance(outputs, torch.Tensor): 129 | loss = outputs 130 | assert len(loss.size())==0 131 | elif isinstance(outputs, (list,tuple)): 132 | loss = outputs[0] 133 | assert len(loss.size())==0 134 | elif isinstance(outputs, abc.Mapping): 135 | loss = outputs['loss'] 136 | else: 137 | loss = outputs.loss 138 | except (KeyError, AttributeError, AssertionError) as e: 139 | logger.error("Cannot infer loss from the outputs automatically! An adaptor is needed") 140 | raise e 141 | else: 142 | loss = adaptor(outputs) 143 | return loss -------------------------------------------------------------------------------- /textpruner/pruners/vocabulary_pruner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import os 4 | from ..model_map import MODEL_MAP 5 | from ..configurations import VocabularyPruningConfig, GeneralConfig 6 | from .utils import infer_model_type 7 | import logging 8 | from tqdm import tqdm 9 | from collections import abc 10 | from typing import Optional 11 | logger = logging.getLogger(__name__) 12 | 13 | class VocabularyPruner: 14 | """ 15 | Args: 16 | model : The model to be pruned. 17 | tokenizer : The tokenizer for the model. 18 | vocabulary_pruning_config : a :class:`~textpruner.configurations.VocabularyPruningConfig` object. 19 | general_config : a :class:`~textpruner.configurations.GeneralConfig` object. 20 | base_model_prefix : The prefix of the base model, i.e., the name of the base model as a member in the model. \ 21 | For example, if ``model.bert_encoder = BertModel(...)``, then the ``base_model_prefix`` is ``bert_encoder``. \ 22 | TextPruner will infer the ``base_model_prefix`` so we can leave its value as ``None``. But if it fails, users have to set its value explicitly. 23 | 24 | """ 25 | def __init__(self, 26 | model : nn.Module, 27 | tokenizer, 28 | vocabulary_pruning_config : Optional[VocabularyPruningConfig] = None, 29 | general_config : Optional[GeneralConfig] = None, 30 | base_model_prefix : Optional[str] = None): 31 | 32 | self.model = model 33 | self.tokenizer = tokenizer 34 | 35 | #infer model type 36 | base_model, model_type = infer_model_type(model, base_model_prefix) 37 | assert model_type in MODEL_MAP, \ 38 | f"Model type {self.model_type} is not supported, or not understood. 
Model type must be one of {list(MODEL_MAP.keys())}"
39 | self.base_model = base_model
40 | self.model_type = model_type
41 | 
42 | 
43 | self.general_config = GeneralConfig() if general_config is None else general_config
44 | self.vocabulary_pruning_config = VocabularyPruningConfig() if vocabulary_pruning_config is None else vocabulary_pruning_config
45 | 
46 | self.model.to(self.general_config.device)
47 | 
48 | self.model_vocab_resizer = MODEL_MAP[self.model_type]['resizer']
49 | self.tokenizer_helper = MODEL_MAP[self.model_type]['tokenizer_helper']
50 | self.pruned_token_ids = []
51 | os.makedirs(self.general_config.output_dir, exist_ok=True)
52 | self.save_dir = None
53 | 
54 | def prune(self, dataiter=None, additional_tokens=None,
55 | additional_token_ids=None, save_model=True) -> Optional[str]:
56 | '''
57 | Prunes the vocabulary of the model and the tokenizer. The pruner will only keep the tokens in ``dataiter``, ``additional_tokens`` and ``additional_token_ids``.
58 | 
59 | * Use ``dataiter`` to generate a set of tokens from the raw texts.
60 | * Use ``additional_tokens`` or ``additional_token_ids`` to specify the tokens or token_ids directly, without running tokenization.
61 | 
62 | Args:
63 | dataiter : a list of strings. These strings will be tokenized by the tokenizer to generate a set of tokens.
64 | additional_tokens : a list of tokens. These tokens must already exist in the original vocabulary.
65 | additional_token_ids : a list of ints representing the token ids.
66 | save_model : whether to save the model when the pruning is finished.
67 | '''
68 | min_count = self.vocabulary_pruning_config.min_count
69 | lm_head_pruning = self.vocabulary_pruning_config.prune_lm_head
70 | pruned_token_ids = self.tokenizer_helper.get_token_ids(tokenizer=self.tokenizer,
71 | dataiter=dataiter,
72 | additional_tokens=additional_tokens,
73 | additional_token_ids=additional_token_ids,
74 | min_count=min_count)
75 | self.model_vocab_resizer.set_embeddings(model=self.base_model, token_ids=pruned_token_ids)
76 | 
77 | if lm_head_pruning == 'auto' or lm_head_pruning is True:
78 | is_success = self.model_vocab_resizer.set_lm_head(self.model, pruned_token_ids)
79 | if is_success is False:
80 | if lm_head_pruning is True:
81 | logger.info("Cannot get output embeddings! Does your model have an MLM prediction head?")
82 | else:
83 | logger.info("Cannot get output embeddings. 
No LM head pruning.") 84 | self.pruned_token_ids = pruned_token_ids 85 | 86 | if save_model is True: 87 | self.save_dir = self.save_model() 88 | return self.save_dir 89 | 90 | 91 | def save_model(self, dir_name = None) -> str: 92 | 93 | if self.model_type.lower() in ['t5', 'mt5']: 94 | vocab_size = self.base_model.shared.weight.shape[0] 95 | else: 96 | vocab_size = len(self.pruned_token_ids) 97 | self.base_model.config.vocab_size = vocab_size 98 | 99 | if dir_name is None: 100 | save_dir = os.path.join(self.general_config.output_dir, f'pruned_V{vocab_size}') 101 | else: 102 | save_dir = os.path.join(self.general_config.output_dir, dir_name) 103 | os.makedirs(save_dir, exist_ok=True) 104 | 105 | # save tokenizer 106 | self.tokenizer_helper.save_vocab(self.tokenizer, self.pruned_token_ids, save_dir) 107 | 108 | # save weights 109 | torch.save(self.model.state_dict(),os.path.join(save_dir,f'pytorch_model.bin')) 110 | 111 | # save config 112 | config_dir = os.path.join(save_dir) 113 | self.base_model.config.save_pretrained(config_dir) 114 | logger.info(f"Model and configuration have been saved to {save_dir}") 115 | 116 | return save_dir 117 | -------------------------------------------------------------------------------- /textpruner/tokenizer_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .roberta_gpt2_tokenizer import RobertaGPT2Tokenizer 2 | from .subword_tokenizer import SubwordTokenizer 3 | from .sp_tokenizer import SentencepieceTokenizer 4 | from .xlmr_sp_tokenizer import XLMRSentencepieceTokenizer 5 | from .xlm_tokenizer import XLMTokenizer 6 | from .t5_sp_tokenizer import T5SentencepieceTokenizer 7 | from .mt5_sp_tokenizer import MT5SentencepieceTokenizer -------------------------------------------------------------------------------- /textpruner/tokenizer_utils/mt5_sp_tokenizer.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import re 4 | from .utils import count_unique_tokens 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | try: 8 | from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model 9 | except ImportError: 10 | logger.warning("Could not import sentencepiece. 
Pruning embeddings of sentencepiece-based model is not available.") 11 | 12 | 13 | class MT5SentencepieceTokenizer: 14 | additional_special_token_ids = [] 15 | 16 | @classmethod 17 | def find_addition_special_token_ids(cls, tokenizer): 18 | add_spe_bound = ['▁', '▁'] 19 | lower, upper = tokenizer.convert_tokens_to_ids(add_spe_bound) 20 | add_spe_tokens_ids_not_in_tokenizer = list(range(lower, upper+1)) 21 | cls.additional_special_token_ids.extend(add_spe_tokens_ids_not_in_tokenizer) 22 | cls.additional_special_token_ids = sorted(list(set(cls.additional_special_token_ids))) 23 | 24 | @classmethod 25 | def get_token_ids(cls, tokenizer, dataiter=None, additional_tokens=None, additional_token_ids=None, min_count=1): 26 | base_token_ids = list(range(3, 3+256)) 27 | token_ids = [] 28 | special_token_ids = list(tokenizer.all_special_ids) 29 | cls.additional_special_token_ids = tokenizer.additional_special_tokens_ids 30 | if len(cls.additional_special_token_ids) == 0: 31 | cls.find_addition_special_token_ids(tokenizer) 32 | special_token_ids.extend(cls.additional_special_token_ids) 33 | special_token_ids = sorted(list(set(special_token_ids))) 34 | 35 | normal_token_ids = [] 36 | if dataiter is not None: 37 | token_ids_counter = count_unique_tokens(dataiter, tokenizer) 38 | normal_token_ids += [k for k,v in token_ids_counter.items() if v >= min_count] 39 | if additional_tokens is not None and len(additional_tokens) > 0: 40 | normal_token_ids += list( 41 | tokenizer.convert_tokens_to_ids(additional_tokens)) 42 | if additional_token_ids is not None and len(additional_token_ids) > 0: 43 | normal_token_ids += list(additional_token_ids) 44 | normal_token_ids = list(set(normal_token_ids)-set(special_token_ids)) 45 | token_ids = sorted(list(set(special_token_ids + normal_token_ids + base_token_ids))) 46 | 47 | return token_ids 48 | 49 | @classmethod 50 | def save_vocab(cls, tokenizer, token_ids, outdir): 51 | 52 | spm_token_ids = token_ids 53 | 54 | spm_token_ids = sorted(spm_token_ids) 55 | 56 | m = sp_pb2_model.ModelProto() 57 | m.ParseFromString(tokenizer.sp_model.serialized_model_proto()) 58 | spm_tokens = set([m.pieces[i].piece for i in spm_token_ids]) 59 | new_pieces = [p for p in m.pieces if p.piece in spm_tokens] 60 | 61 | del m.pieces[:] 62 | m.pieces.extend(new_pieces) 63 | 64 | pruned_vocab_file = os.path.join(outdir, 'spiece.model') 65 | with open(pruned_vocab_file, 'wb') as f: 66 | f.write(m.SerializeToString()) 67 | print(f"New embedding pruned vocab file has been saved to {pruned_vocab_file}. 
Reintialize the tokenizer!") -------------------------------------------------------------------------------- /textpruner/tokenizer_utils/roberta_gpt2_tokenizer.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import logging 4 | import json 5 | logger = logging.getLogger(__name__) 6 | from .utils import count_unique_tokens 7 | 8 | 9 | class RobertaGPT2Tokenizer: 10 | 11 | @staticmethod 12 | def get_token_ids(tokenizer, dataiter=None, additional_tokens=None, additional_token_ids=None, min_count=1): 13 | token_ids = [] 14 | # add special tokens 15 | special_token_ids = [0, 1, 2, 3] 16 | special_token_ids += [len(tokenizer)-4+i for i in range(4)] # ["unusedword0000","unusedword0001","unusedword0002",""] 17 | # remove special tokens, special tokens + normal tokens 18 | 19 | normal_token_ids = [] 20 | if dataiter is not None: 21 | token_ids_counter = count_unique_tokens(dataiter, tokenizer) 22 | normal_token_ids += [k for k,v in token_ids_counter.items() if v >= min_count] 23 | if additional_tokens is not None and len(additional_tokens) > 0: 24 | normal_token_ids += list( 25 | tokenizer.convert_tokens_to_ids(additional_tokens)) 26 | if additional_token_ids is not None and len(additional_token_ids) > 0: 27 | normal_token_ids += list(additional_token_ids) 28 | normal_token_ids = list(set(normal_token_ids)-set(special_token_ids)) 29 | token_ids = sorted(special_token_ids + normal_token_ids) # to make sure [0,1,2,3, ...., ] 30 | return token_ids 31 | 32 | @staticmethod 33 | def save_vocab(tokenizer, token_ids, outdir): 34 | 35 | assert len(token_ids) == len(set(token_ids)) 36 | 37 | tokens = tokenizer.convert_ids_to_tokens(token_ids) 38 | 39 | token_dict = {} 40 | for i in range(len(tokens)): 41 | token_dict[tokens[i]] = i 42 | 43 | pruned_vocab_file = os.path.join(outdir, 'vocab.json') 44 | with open(pruned_vocab_file, 'w', encoding='utf-8') as f: 45 | json.dump(token_dict, f) 46 | print(f"New embedding size {len(token_ids)} pruned vocab file has been saved to {pruned_vocab_file}. Reintialize the tokenizer!") 47 | 48 | index = 0 49 | bpe_ranks = sorted(tokenizer.bpe_ranks.items(), key = lambda k: k[1]) 50 | pruned_merges_file = os.path.join(outdir, 'merges.txt') 51 | with open(pruned_merges_file, "w", encoding="utf-8") as writer: 52 | writer.write("#version: 0.2\n") 53 | for bpe_tokens, _ in bpe_ranks: 54 | writer.write(bpe_tokens[0] + " " + bpe_tokens[1] + "\n") 55 | -------------------------------------------------------------------------------- /textpruner/tokenizer_utils/sp_tokenizer.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from .utils import count_unique_tokens 4 | import logging 5 | logger = logging.getLogger(__name__) 6 | try: 7 | from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model 8 | except ImportError: 9 | logger.warning("Could not import sentencepiece. 
Pruning embeddings of sentencepiece-based model is not available.") 10 | 11 | 12 | class SentencepieceTokenizer: 13 | 14 | @staticmethod 15 | def get_token_ids(tokenizer, dataiter=None, additional_tokens=None, additional_token_ids=None, min_count=1): 16 | token_ids = [] 17 | special_token_ids = list(tokenizer.all_special_ids) 18 | 19 | normal_token_ids = [] 20 | if dataiter is not None: 21 | token_ids_counter = count_unique_tokens(dataiter, tokenizer) 22 | normal_token_ids += [k for k,v in token_ids_counter.items() if v >= min_count] 23 | if additional_tokens is not None and len(additional_tokens) > 0: 24 | normal_token_ids += list( 25 | tokenizer.convert_tokens_to_ids(additional_tokens)) 26 | if additional_token_ids is not None and len(additional_token_ids) > 0: 27 | normal_token_ids += list(additional_token_ids) 28 | normal_token_ids = list(set(normal_token_ids)-set(special_token_ids)) 29 | token_ids = sorted(special_token_ids + normal_token_ids) 30 | return token_ids 31 | 32 | @staticmethod 33 | def save_vocab(tokenizer, token_ids, outdir): 34 | ''' 35 | fairseq_offset = 1 36 | # {"": 0, "": 1, "": 2, "": 3} 37 | fairseq_special_tokens_ids = [0, 1, 2, 3] 38 | fairseq_special_tokens_ids.append( 39 | len(tokenizer.sp_model) + fairseq_offset) # [""] 40 | # remove special tokens 41 | token_ids = [ 42 | t for t in token_ids if t not in fairseq_special_tokens_ids] 43 | 44 | # special tokens + normal tokens 45 | spm_token_ids = [0, 1, 2] + \ 46 | [t-fairseq_offset for t in token_ids] 47 | assert len(spm_token_ids) == len(set(spm_token_ids)) 48 | ''' 49 | 50 | spm_token_ids = token_ids 51 | m = sp_pb2_model.ModelProto() 52 | m.ParseFromString(tokenizer.sp_model.serialized_model_proto()) 53 | 54 | spm_tokens = set([m.pieces[i].piece for i in spm_token_ids]) 55 | new_pieces = [p for p in m.pieces if p.piece in spm_tokens] 56 | 57 | # delete all 58 | del m.pieces[:] 59 | m.pieces.extend(new_pieces) 60 | 61 | pruned_vocab_file = os.path.join(outdir, 'spiece.model') 62 | with open(pruned_vocab_file, 'wb') as f: 63 | f.write(m.SerializeToString()) 64 | print(f"New embedding size {len(new_pieces)+2} pruned vocab file has been saved to {pruned_vocab_file}. 
Reintialize the tokenizer!") -------------------------------------------------------------------------------- /textpruner/tokenizer_utils/subword_tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import count_unique_tokens 3 | 4 | class SubwordTokenizer: 5 | @staticmethod 6 | def get_token_ids(tokenizer, dataiter=None, additional_tokens=None, additional_token_ids=None, min_count=1): 7 | token_ids = [] 8 | # add special tokens 9 | special_token_ids = list(tokenizer.all_special_ids) 10 | 11 | normal_token_ids = [] 12 | if dataiter is not None: 13 | token_ids_counter = count_unique_tokens(dataiter, tokenizer) 14 | normal_token_ids += [k for k,v in token_ids_counter.items() if v >= min_count] 15 | if additional_tokens is not None and len(additional_tokens) > 0: 16 | normal_token_ids += list( 17 | tokenizer.convert_tokens_to_ids(additional_tokens)) 18 | if additional_token_ids is not None and len(additional_token_ids) > 0: 19 | normal_token_ids += list(additional_token_ids) 20 | normal_token_ids = list(set(normal_token_ids)-set(special_token_ids)) 21 | token_ids = sorted(special_token_ids + normal_token_ids) 22 | return token_ids 23 | 24 | @staticmethod 25 | def save_vocab(tokenizer, token_ids, outdir): 26 | tokens = tokenizer.convert_ids_to_tokens(token_ids) 27 | pruned_vocab_file = os.path.join(outdir, 'vocab.txt') 28 | with open(pruned_vocab_file, 'w', encoding='utf-8') as f: 29 | for token in tokens: 30 | f.write(token+'\n') 31 | print(f"New embedding size {len(token_ids)} pruned vocab file has been saved to {pruned_vocab_file}. Reintialize the tokenizer!") -------------------------------------------------------------------------------- /textpruner/tokenizer_utils/t5_sp_tokenizer.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import re 4 | from .utils import count_unique_tokens 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | try: 8 | from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model 9 | except ImportError: 10 | logger.warning("Could not import sentencepiece. 
Pruning embeddings of sentencepiece-based model is not available.") 11 | 12 | 13 | class T5SentencepieceTokenizer: 14 | additional_special_token_ids = [] 15 | 16 | 17 | 18 | @classmethod 19 | def get_token_ids(cls, tokenizer, dataiter=None, additional_tokens=None, additional_token_ids=None, min_count=1): 20 | token_ids = [] 21 | #special_token_ids = list(set(tokenizer.all_special_ids) - set(tokenizer.additional_special_tokens_ids)) 22 | special_token_ids = list(tokenizer.all_special_ids) 23 | cls.additional_special_token_ids = tokenizer.additional_special_tokens_ids 24 | 25 | 26 | normal_token_ids = [] 27 | if dataiter is not None: 28 | token_ids_counter = count_unique_tokens(dataiter, tokenizer) 29 | normal_token_ids += [k for k,v in token_ids_counter.items() if v >= min_count] 30 | if additional_tokens is not None and len(additional_tokens) > 0: 31 | normal_token_ids += list( 32 | tokenizer.convert_tokens_to_ids(additional_tokens)) 33 | if additional_token_ids is not None and len(additional_token_ids) > 0: 34 | normal_token_ids += list(additional_token_ids) 35 | normal_token_ids = list(set(normal_token_ids)-set(special_token_ids)) 36 | token_ids = sorted(special_token_ids + normal_token_ids) 37 | 38 | return token_ids 39 | 40 | @classmethod 41 | def save_vocab(cls, tokenizer, token_ids, outdir): 42 | 43 | 44 | spm_token_ids = list(set(token_ids) - set(cls.additional_special_token_ids)) 45 | m = sp_pb2_model.ModelProto() 46 | m.ParseFromString(tokenizer.sp_model.serialized_model_proto()) 47 | 48 | spm_tokens = set([m.pieces[i].piece for i in spm_token_ids]) 49 | new_pieces = [p for p in m.pieces if p.piece in spm_tokens] 50 | 51 | # delete all 52 | del m.pieces[:] 53 | m.pieces.extend(new_pieces) 54 | 55 | pruned_vocab_file = os.path.join(outdir, 'spiece.model') 56 | with open(pruned_vocab_file, 'wb') as f: 57 | f.write(m.SerializeToString()) 58 | print(f"New embedding pruned vocab file has been saved to {pruned_vocab_file}. Reintialize the tokenizer!") -------------------------------------------------------------------------------- /textpruner/tokenizer_utils/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from collections import Counter 3 | from collections.abc import Iterable 4 | from typing import Callable, Optional 5 | from tqdm import tqdm 6 | import logging 7 | import json 8 | logger = logging.getLogger(__name__) 9 | 10 | def count_frequency(self, texts : Iterable): 11 | token_counter = Counter() 12 | 13 | for text in texts: 14 | tokens = self.tokenizer.tokenize(text) 15 | token_counter.update(tokens) 16 | all_tokens = [k for (k, v) in token_counter.most_common()] 17 | all_token_indices = self.tokenizer.convert_tokens_to_ids(all_tokens) 18 | return all_tokens, all_token_indices 19 | 20 | 21 | 22 | def count_unique_tokens(dataiter, tokenizer, fn : Optional[Callable] =None) -> Counter : 23 | assert not isinstance(dataiter,str), "dataiter is assumed to be a collection (list, tuple, ...) 
of strings, not a single string" 24 | token_ids = Counter() 25 | for item in tqdm(dataiter): 26 | if fn is not None: 27 | item = fn(item) # pre-transform 28 | if isinstance(item, str): 29 | token_ids.update(tokenizer.encode(item, add_special_tokens=True)) 30 | else: 31 | assert isinstance(item[0],str) # list of string 32 | token_ids.update(list(chain(*(tokenizer.encode(i, add_special_tokens=True) for i in item)))) 33 | return token_ids -------------------------------------------------------------------------------- /textpruner/tokenizer_utils/xlm_tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import count_unique_tokens 3 | import logging 4 | import json 5 | logger = logging.getLogger(__name__) 6 | 7 | class XLMTokenizer: 8 | @staticmethod 9 | def get_token_ids(tokenizer, dataiter=None, additional_tokens=None, additional_token_ids=None, min_count=1): 10 | token_ids = [] 11 | # add special tokens 12 | special_tokens = ['', '', '', '', '', '', '', '', '', '', '', '', '', ''] 13 | special_token_ids = list(range(0, 14)) 14 | normal_token_ids = [] 15 | 16 | if dataiter is not None: 17 | token_ids_counter = count_unique_tokens(dataiter, tokenizer) 18 | normal_token_ids += [k for k,v in token_ids_counter.items() if v >= min_count] 19 | if additional_tokens is not None and len(additional_tokens) > 0: 20 | normal_token_ids += list( 21 | tokenizer.convert_tokens_to_ids(additional_tokens)) 22 | if additional_token_ids is not None and len(additional_token_ids) > 0: 23 | normal_token_ids += list(additional_token_ids) 24 | 25 | normal_token_ids = list(set(normal_token_ids)-set(special_token_ids)) 26 | token_ids = sorted(special_token_ids + normal_token_ids) 27 | return token_ids 28 | 29 | @staticmethod 30 | def save_vocab(tokenizer, token_ids, outdir): 31 | assert len(token_ids) == len(set(token_ids)) 32 | 33 | tokens = tokenizer.convert_ids_to_tokens(token_ids) 34 | token_dict = {} 35 | for i in range(len(tokens)): 36 | token_dict[tokens[i]] = i 37 | 38 | 39 | tokenizer.save_pretrained(outdir) 40 | pruned_vocab_file = os.path.join(outdir, 'vocab.json') 41 | with open(pruned_vocab_file, 'w', encoding='utf-8') as f: 42 | json.dump(token_dict, f) 43 | 44 | print(f"New embedding size {len(token_ids)} pruned vocab file has been saved to {pruned_vocab_file}. Reintialize the tokenizer!") 45 | 46 | 47 | bpe_ranks = sorted(tokenizer.bpe_ranks.items(), key = lambda k: k[1]) 48 | 49 | 50 | pruned_merges_file = os.path.join(outdir, 'merges.txt') 51 | with open(pruned_merges_file, "w", encoding="utf-8") as writer: 52 | for bpe_tokens, _ in bpe_ranks: 53 | if len(bpe_tokens) != 2: 54 | continue 55 | writer.write(bpe_tokens[0] + " " + bpe_tokens[1] + "\n") 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /textpruner/tokenizer_utils/xlmr_sp_tokenizer.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from .utils import count_unique_tokens 4 | import logging 5 | logger = logging.getLogger(__name__) 6 | try: 7 | from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model 8 | except ImportError: 9 | logger.warning("Could not import sentencepiece. 
Pruning embeddings of sentencepiece-based model is not available.") 10 | 11 | 12 | class XLMRSentencepieceTokenizer: 13 | 14 | @staticmethod 15 | def get_token_ids(tokenizer, dataiter=None, additional_tokens=None, additional_token_ids=None, min_count=1): 16 | token_ids = [] 17 | # add special tokens 18 | # should equal to [0,1,2,3,size +1] 19 | special_token_ids = list(tokenizer.all_special_ids) 20 | 21 | normal_token_ids = [] 22 | if dataiter is not None: 23 | token_ids_counter = count_unique_tokens(dataiter, tokenizer) 24 | normal_token_ids += [k for k,v in token_ids_counter.items() if v >= min_count] 25 | if additional_tokens is not None and len(additional_tokens) > 0: 26 | normal_token_ids += list( 27 | tokenizer.convert_tokens_to_ids(additional_tokens)) 28 | if additional_token_ids is not None and len(additional_token_ids) > 0: 29 | normal_token_ids += list(additional_token_ids) 30 | normal_token_ids = list(set(normal_token_ids)-set(special_token_ids)) 31 | token_ids = sorted(special_token_ids + normal_token_ids) # to make sure [0,1,2,3, ...., ] 32 | return token_ids 33 | 34 | @staticmethod 35 | def save_vocab(tokenizer, token_ids, outdir): 36 | fairseq_offset = 1 37 | # {"": 0, "": 1, "": 2, "": 3} 38 | fairseq_special_tokens_ids = [0, 1, 2, 3] 39 | fairseq_special_tokens_ids.append( 40 | len(tokenizer.sp_model) + fairseq_offset) # [""] 41 | # remove special tokens 42 | token_ids = [ 43 | t for t in token_ids if t not in fairseq_special_tokens_ids] 44 | 45 | # special tokens + normal tokens 46 | spm_token_ids = [0, 1, 2] + \ 47 | [t-fairseq_offset for t in token_ids] 48 | assert len(spm_token_ids) == len(set(spm_token_ids)) 49 | 50 | 51 | m = sp_pb2_model.ModelProto() 52 | m.ParseFromString(tokenizer.sp_model.serialized_model_proto()) 53 | 54 | spm_tokens = set([m.pieces[i].piece for i in spm_token_ids]) 55 | new_pieces = [p for p in m.pieces if p.piece in spm_tokens] 56 | 57 | # delete all 58 | del m.pieces[:] 59 | m.pieces.extend(new_pieces) 60 | 61 | # #debug 62 | # #debug 63 | # print ("spm_token_ids:",spm_token_ids) 64 | # print ("spm_tokens:",spm_tokens) 65 | # print ('new pieces:',[p.piece for p in m.pieces]) 66 | 67 | pruned_vocab_file = os.path.join(outdir, 'sentencepiece.bpe.model') 68 | with open(pruned_vocab_file, 'wb') as f: 69 | f.write(m.SerializeToString()) 70 | print(f"New embedding size {len(new_pieces)+2} pruned vocab file has been saved to {pruned_vocab_file}. 
Reintialize the tokenizer!") 71 | -------------------------------------------------------------------------------- /textpruner/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections.abc import Mapping 3 | from tqdm import tqdm 4 | import time 5 | from typing import Tuple, Union,Dict,Optional,List 6 | 7 | class LayerNode: 8 | def __init__(self,name,parent=None,value=None,fullname=None): 9 | self.name = name 10 | self.fullname = fullname 11 | self.value = None 12 | self.children_name = {} 13 | self.parent = parent 14 | def __contains__(self, key): 15 | return key in self.children_name 16 | def __getitem__(self,key): 17 | return self.children_name[key] 18 | def __setitem__(self,key,value): 19 | self.children_name[key]=value 20 | def update(self,value): 21 | if self.parent: 22 | if self.parent.value is None: 23 | self.parent.value = value 24 | else: 25 | if isinstance(value,(tuple,list)): 26 | old_value = self.parent.value 27 | new_value = [old_value[i]+value[i] for i in range(len(value))] 28 | self.parent.value = new_value 29 | else: 30 | self.parent.value += value 31 | if self.name.endswith('(shared)'): 32 | if self.parent.name.endswith('shared)'): 33 | pass 34 | elif self.parent.value[0] == 0: 35 | self.parent.name += '(shared)' 36 | else: 37 | self.parent.name += '(partially shared)' 38 | 39 | self.parent.update(value) 40 | 41 | def format(self, level=0, total=None ,indent='--',max_level=None,max_length=None): 42 | string ='' 43 | if total is None: 44 | total = self.value[0] 45 | if level ==0: 46 | max_length = self._max_name_length(indent,' ',max_level=max_level) + 1 47 | string += '\n' 48 | string +=f"{'LAYER NAME':<{max_length}}\t{'#PARAMS':>15}\t{'RATIO':>10}\t{'MEM(MB)':>8}\n" 49 | 50 | if max_level is not None and level==max_level: 51 | string += f"{indent+self.name+':':<{max_length}}\t{self.value[0]:15,d}\t{self.value[0]/total:>10.2%}\t{self.value[1]:>8.2f}\n" 52 | else: 53 | if len(self.children_name)==1: 54 | string += f"{indent+self.name:{max_length}}\n" 55 | else: 56 | string += f"{indent+self.name+':':<{max_length}}\t{self.value[0]:15,d}\t{self.value[0]/total:>10.2%}\t{self.value[1]:>8.2f}\n" 57 | for child_name, child in self.children_name.items(): 58 | string += child.format(level+1, total, 59 | indent=' '+indent, max_level=max_level,max_length=max_length) 60 | return string 61 | 62 | def _max_name_length(self,indent1='--', indent2=' ',level=0,max_level=None): 63 | length = len(self.name) + len(indent1) + level *len(indent2) 64 | if max_level is not None and level >= max_level: 65 | child_lengths = [] 66 | else: 67 | child_lengths = [child._max_name_length(indent1,indent2,level=level+1,max_level=max_level) 68 | for child in self.children_name.values()] 69 | max_length = max(child_lengths+[length]) 70 | return max_length 71 | 72 | 73 | def summary(model : Union[torch.nn.Module,Dict], max_level : Optional[int] = 2): 74 | """ 75 | Show the summary of model parameters. 76 | 77 | Args: 78 | model: the model to be inspected, can be a torch module or a state_dict. 79 | max_level: The max level to display. If ``max_level==None``, show all the levels. 80 | Returns: 81 | A formatted string. 
82 | 
83 | Example::
84 | 
85 | print(textpruner.summary(model))
86 | 
87 | """
88 | if isinstance(model,torch.nn.Module):
89 | state_dict = model.state_dict()
90 | elif isinstance(model,dict):
91 | state_dict = model
92 | else:
93 | raise TypeError("model should be either torch.nn.Module or a dict")
94 | hash_set = set()
95 | model_node = LayerNode('model',fullname='model')
96 | current = model_node
97 | for key,value in state_dict.items():
98 | names = key.split('.')
99 | for i,name in enumerate(names):
100 | if name not in current:
101 | current[name] = LayerNode(name,parent=current,fullname='.'.join(names[:i+1]))
102 | current = current[name]
103 | 
104 | if (value.data_ptr()) in hash_set:
105 | current.value = [0,0]
106 | current.name += "(shared)"
107 | current.fullname += "(shared)"
108 | current.update(current.value)
109 | else:
110 | hash_set.add(value.data_ptr())
111 | current.value = [value.numel(),value.numel() * value.element_size() / 1024 / 1024]
112 | current.update(current.value)
113 | 
114 | current = model_node
115 | 
116 | result = model_node.format(max_level=max_level)
117 | 
118 | return result
119 | 
120 | 
121 | def inference_time(model : torch.nn.Module, dummy_inputs : Union[List,Tuple,Dict], warm_up : int = 5, repetitions : int = 10):
122 | """
123 | Measure and print the inference time of the model.
124 | 
125 | Args:
126 | model: the torch module to be measured.
127 | dummy_inputs: the inputs to be fed into the model, can be a list, tuple or dict.
128 | warm_up: Number of steps to warm up the device.
129 | repetitions: Number of steps to perform forward propagation. More repetitions result in more accurate measurements.
130 | 
131 | Example::
132 | 
133 | input_ids = torch.randint(low=0,high=10000,size=(32,256))
134 | textpruner.inference_time(model,dummy_inputs=[input_ids])
135 | 
136 | """
137 | device = model.device
138 | is_train = model.training
139 | model.eval()
140 | 
141 | if device.type == 'cpu':
142 | mean, std = cpu_inference_time(model, dummy_inputs, warm_up, repetitions)
143 | elif device.type == 'cuda':
144 | mean, std = cuda_inference_time(model, dummy_inputs, warm_up, repetitions)
145 | else:
146 | raise ValueError(f"Unknown device {device}")
147 | 
148 | model.train(is_train)
149 | print(f"Device: {device}")
150 | print(f"Mean inference time: {mean:.2f}ms")
151 | print(f"Standard deviation: {std:.2f}ms")
152 | 
153 | return mean, std
154 | 
155 | 
156 | def cuda_inference_time(model : torch.nn.Module, dummy_inputs, warm_up, repetitions):
157 | device = model.device
158 | starter = torch.cuda.Event(enable_timing=True)
159 | ender = torch.cuda.Event(enable_timing=True)
160 | timings=torch.zeros(repetitions)
161 | with torch.no_grad():
162 | for _ in tqdm(range(warm_up),desc='cuda-warm-up'):
163 | if isinstance(dummy_inputs, Mapping):
164 | inputs = {k: v.to(device) for k,v in dummy_inputs.items()}
165 | _ = model(**inputs)
166 | else:
167 | inputs = [t.to(device) for t in dummy_inputs]
168 | _ = model(*inputs)
169 | for rep in tqdm(range(repetitions),desc='cuda-repetitions'):
170 | if isinstance(dummy_inputs, Mapping):
171 | inputs = {k: v.to(device) for k,v in dummy_inputs.items()}
172 | starter.record()
173 | _ = model(**inputs)
174 | ender.record()
175 | else:
176 | inputs = [t.to(device) for t in dummy_inputs]
177 | starter.record()
178 | _ = model(*inputs)
179 | ender.record()
180 | torch.cuda.synchronize()
181 | elapsed_time_ms = starter.elapsed_time(ender)
182 | timings[rep] = elapsed_time_ms
183 | mean = timings.sum().item() / repetitions
184 | std = 
timings.std().item() 185 | 186 | return mean, std 187 | 188 | 189 | def cpu_inference_time(model : torch.nn.Module, dummy_inputs, warm_up, repetitions): 190 | device = model.device 191 | timings=torch.zeros(repetitions) 192 | with torch.no_grad(): 193 | for _ in tqdm(range(warm_up),desc='cpu-warm-up'): 194 | if isinstance(dummy_inputs, Mapping): 195 | inputs = {k: v.to(device) for k,v in dummy_inputs.items()} 196 | _ = model(**inputs) 197 | else: 198 | inputs = [t.to(device) for t in dummy_inputs] 199 | _ = model(*inputs) 200 | for rep in tqdm(range(repetitions),desc='cpu-repetitions'): 201 | if isinstance(dummy_inputs, Mapping): 202 | inputs = {k: v.to(device) for k,v in dummy_inputs.items()} 203 | start = time.time() 204 | _ = model(**inputs) 205 | end = time.time() 206 | else: 207 | inputs = [t.to(device) for t in dummy_inputs] 208 | start = time.time() 209 | _ = model(*inputs) 210 | end = time.time() 211 | elapsed_time_ms = (end - start) * 1000 212 | timings[rep] = elapsed_time_ms 213 | mean = timings.sum().item() / repetitions 214 | std = timings.std().item() 215 | 216 | return mean, std --------------------------------------------------------------------------------
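
The snippet below is a minimal usage sketch, not part of the repository, showing how the utilities listed above fit together: VocabularyPruner shrinks the embedding matrix (and, when it can locate one, the LM head) down to the tokens observed in a corpus, while summary and inference_time from textpruner/utils.py report parameter counts and latency before and after pruning. The checkpoint name, the texts list, and the tensor shapes are placeholder assumptions.

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from textpruner.pruners.vocabulary_pruner import VocabularyPruner
from textpruner.utils import summary, inference_time

# Placeholder checkpoint; any model type listed in MODEL_MAP works the same way.
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Placeholder corpus: only tokens seen here (plus the special tokens) are kept.
texts = ["The quick brown fox jumps over the lazy dog.",
         "TextPruner keeps only the vocabulary the task actually needs."]

print(summary(model))                # parameter breakdown before pruning

pruner = VocabularyPruner(model, tokenizer)
save_dir = pruner.prune(dataiter=texts, save_model=True)
print(f"Pruned model and vocabulary saved to {save_dir}")

print(summary(model))                # embedding (and LM head) rows are now smaller

# Rough latency check; batch size and sequence length are arbitrary.
input_ids = torch.randint(low=0, high=model.config.vocab_size, size=(8, 128))
inference_time(model, dummy_inputs=[input_ids])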