├── .github └── workflows │ └── pre-commit.yaml ├── .gitignore ├── .idea └── workspace.xml ├── .pre-commit-config.yaml ├── .pylintrc ├── LICENSE ├── Makefile ├── Model Merge And Analysis Tools ├── Enhanced_Mixer.py ├── Enhanced_Mixer_Requirements.txt ├── LM_BlockMerge.py ├── LM_BlockMerge_Requirements.txt ├── StratusScope.py ├── StratusScope_BarGraph.png ├── StratusScope_ConsoleOutput.png ├── StratusScope_Requirements.txt └── __Quick Tool Explainer__.txt ├── README.md ├── clustering ├── download.py ├── feature_extractor.py ├── hierarchical_clustering.py ├── memmap_utils.py └── train_clusterer.py ├── conda-mdel.yml ├── configs ├── fp16_1-3B_4M_bs_1.4T_tok_summit_pp3_mp2_256_nodes.yml ├── fp16_2-7B_4M_bs_1.4T_tok_summit_pp6_mp2_256nodes.yml └── fp16_6-7B_4M_bs_1T_tok_summit_pp12_mp2_mbs2_512_nodes_real.yml ├── distillation_sparsification ├── README.md ├── datautils.py ├── distill.py ├── falcon.py ├── lion.py ├── lm_seqs_dataset.py ├── make_student.py ├── modelutils.py ├── process_data.py ├── quant.py ├── sparsegpt.py ├── test.py ├── test1.py ├── tracker.py ├── train.py └── utils.py ├── docs └── .gitkeep ├── lora-x ├── README.md ├── bpt_attention_plugin.py ├── bpt_pt.py ├── bpt_triton.py ├── config.py ├── configs │ └── zero3_offload_config.json ├── data.py ├── experimental │ ├── qlora_lomo.py │ └── train_qlora_lomo.py ├── flash_patch.py ├── lora.py ├── qlora_bpt.py ├── requirements.txt ├── scripts │ └── juwels_booster.sh └── utils.py ├── notebooks ├── .gitkeep ├── CalculatePerplexity.ipynb └── Merge_N_Experts.ipynb ├── requirements.txt ├── resources.md ├── scripts ├── c-btmInference.py ├── calc_perplexities.sh ├── calc_perplexities_slurm.sh ├── create_domain_pile_mix.sh ├── get_pile_shard1_data.sh └── upload_to_hf.sh ├── setup.py └── src └── mdel ├── __init__.py ├── calculate_perplexity.py ├── configs ├── config.yaml └── zero_config.json ├── eval_merges.py ├── iterate_layers.sh ├── merge_experts.py ├── pile_upload.py ├── pile_utils.py ├── train.sh ├── train_cbtm_classifier.py ├── train_chat.sh ├── train_ds.sh ├── trainer.py ├── trainer_chat.bat └── trainer_chat.py /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | workflow_call: 5 | pull_request_target: 6 | 7 | jobs: 8 | pre-commit: 9 | runs-on: ubuntu-latest 10 | steps: 11 | # in case of PR, check out the PR's head branch 12 | - uses: actions/checkout@v3 13 | if: github.event_name == 'pull_request_target' 14 | with: 15 | ref: ${{ github.event.pull_request.head.sha }} 16 | 17 | # in case of push, check out the main branch 18 | - uses: actions/checkout@v3 19 | if: github.event_name != 'pull_request_target' 20 | 21 | - uses: actions/setup-python@v4 22 | with: 23 | python-version: "3.10" 24 | cache: "pip" 25 | cache-dependency-path: "**/requirements*.txt" 26 | - uses: pre-commit/action@v3.0.0 27 | - name: Post PR comment on failure 28 | if: failure() && github.event_name == 'pull_request_target' 29 | uses: peter-evans/create-or-update-comment@v2 30 | with: 31 | issue-number: ${{ github.event.pull_request.number }} 32 | body: | 33 | :x: **pre-commit** failed. 34 | Please run `pre-commit run --all-files` locally and commit the changes. 
35 | Find more information in the repository's CONTRIBUTING.md 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .DS_Store 132 | 133 | # SLURM results 134 | *.out 135 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 23 | 24 | 26 | 27 | 29 | 30 | 31 | 34 | { 35 | "keyToString": { 36 | "RunOnceActivity.OpenProjectViewOnStart": "true", 37 | "RunOnceActivity.ShowReadmeOnStart": "true", 38 | "last_opened_file_path": "/Users/kentsui/opensource/MDEL", 39 | "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" 40 | } 41 | } 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 1683965013951 52 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # WARNING! 2 | # 3 | # When making changes to auto-formatters used in pre-commit hooks, you are 4 | # likely to cause merge conflicts with main and/or other pull requests. 5 | # Fixing them might revert other people's work. Expect pain! 6 | # To avoid accidental reversions and keep it easy to review, please make sure 7 | # that changes here are in a pull request by themselves, that it consists of 8 | # two commits: 9 | # 10 | # 1. The changes to this file 11 | # 2. Changes made by running `python3 -m pre_commit run --all-files`. 12 | # 13 | # Then each time your pull request is blocked by a merge conflict, do the 14 | # following steps: 15 | # 16 | # git reset HEAD^1 && git checkout -f # discard the change commit 17 | # git rebase main # re-apply other people's changes 18 | # python3 -m pre_commit run --all-files # re-run the rules 19 | # git add . # add the newly changed files 20 | # git commit -m 'apply pre-commit' # commit it 21 | # git push -f # force push back to your branch 22 | # 23 | # Keep in mind you may have to do this a few times, as changes here may impact 24 | # other pull requests. Try to keep it up-to-date so they can go in when it'll 25 | # cause least disruption. 26 | # 27 | # /WARNING! 28 | 29 | exclude: build|stubs|^bot/templates/$|openassistant/templates|docs/docs/api/openapi.json 30 | 31 | repos: 32 | - repo: https://github.com/pre-commit/pre-commit-hooks 33 | rev: v4.4.0 34 | hooks: 35 | - id: trailing-whitespace 36 | - id: check-ast 37 | - id: check-yaml 38 | # Always check YAML but skip a few YAML files that are auto-generated 39 | # and which break the standard YAML check. The alternative would be to 40 | # skip any unsafe errors (and thus break YAML compatibility) or use 41 | # some other checker that may not work in general. 
42 | exclude: ^copilot/.*/addons/.*$ 43 | - id: check-json 44 | - id: check-case-conflict 45 | - id: detect-private-key 46 | - id: fix-encoding-pragma 47 | args: [--remove] 48 | - id: forbid-submodules 49 | - id: mixed-line-ending 50 | - id: requirements-txt-fixer 51 | - id: check-executables-have-shebangs 52 | - id: check-shebang-scripts-are-executable 53 | - id: check-byte-order-marker 54 | - id: check-symlinks 55 | - id: check-merge-conflict 56 | - id: check-added-large-files 57 | args: [--maxkb=1024] 58 | - id: end-of-file-fixer 59 | 60 | - repo: https://github.com/pycqa/flake8 61 | rev: 6.0.0 62 | hooks: 63 | - id: flake8 64 | args: [--max-line-length=120, --ignore=C901] 65 | 66 | - repo: https://github.com/pycqa/isort 67 | rev: 5.12.0 68 | hooks: 69 | - id: isort 70 | 71 | - repo: https://github.com/pre-commit/mirrors-prettier 72 | rev: v3.0.0-alpha.4 73 | hooks: 74 | - id: prettier 75 | args: [--prose-wrap=always, --write] 76 | 77 | - repo: local 78 | hooks: 79 | - id: next-lint-website 80 | name: Lint website 81 | files: ^website/ 82 | exclude: ^website/node_modules/ 83 | types_or: [javascript, jsx, ts, tsx] 84 | language: node 85 | pass_filenames: false 86 | entry: website/next-lint.js 87 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [BASIC] 2 | 3 | # Naming style matching correct argument names. 4 | argument-naming-style=snake_case 5 | 6 | # Regular expression matching correct argument names. Overrides argument- 7 | # naming-style. 8 | #argument-rgx= 9 | 10 | # Naming style matching correct attribute names. 11 | attr-naming-style=snake_case 12 | 13 | # Regular expression matching correct attribute names. Overrides attr-naming- 14 | # style. 15 | #attr-rgx= 16 | 17 | # Bad variable names which should always be refused, separated by a comma. 18 | bad-names=foo, 19 | bar, 20 | baz, 21 | toto, 22 | tutu, 23 | tata 24 | 25 | # Naming style matching correct class attribute names. 26 | class-attribute-naming-style=any 27 | 28 | # Regular expression matching correct class attribute names. Overrides class- 29 | # attribute-naming-style. 30 | #class-attribute-rgx= 31 | 32 | # Naming style matching correct class names. 33 | class-naming-style=PascalCase 34 | 35 | # Regular expression matching correct class names. Overrides class-naming- 36 | # style. 37 | #class-rgx= 38 | 39 | # Naming style matching correct constant names. 40 | const-naming-style=UPPER_CASE 41 | 42 | # Regular expression matching correct constant names. Overrides const-naming- 43 | # style. 44 | #const-rgx= 45 | 46 | # Minimum line length for functions/classes that require docstrings, shorter 47 | # ones are exempt. 48 | docstring-min-length=-1 49 | 50 | # Naming style matching correct function names. 51 | function-naming-style=snake_case 52 | 53 | # Regular expression matching correct function names. Overrides function- 54 | # naming-style. 55 | #function-rgx= 56 | 57 | # Good variable names which should always be accepted, separated by a comma. 58 | good-names=i, 59 | j, 60 | k, 61 | ex, 62 | Run, 63 | _ 64 | 65 | # Include a hint for the correct naming format with invalid-name. 66 | include-naming-hint=no 67 | 68 | # Naming style matching correct inline iteration names. 69 | inlinevar-naming-style=any 70 | 71 | # Regular expression matching correct inline iteration names. Overrides 72 | # inlinevar-naming-style. 73 | #inlinevar-rgx= 74 | 75 | # Naming style matching correct method names. 
76 | method-naming-style=snake_case 77 | 78 | # Regular expression matching correct method names. Overrides method-naming- 79 | # style. 80 | #method-rgx= 81 | 82 | # Naming style matching correct module names. 83 | module-naming-style=snake_case 84 | 85 | # Regular expression matching correct module names. Overrides module-naming- 86 | # style. 87 | #module-rgx= 88 | 89 | # Colon-delimited sets of names that determine each other's naming style when 90 | # the name regexes allow several styles. 91 | name-group= 92 | 93 | # Regular expression which should only match function or class names that do 94 | # not require a docstring. 95 | no-docstring-rgx=^_ 96 | 97 | # List of decorators that produce properties, such as abc.abstractproperty. Add 98 | # to this list to register other decorators that produce valid properties. 99 | # These decorators are taken in consideration only for invalid-name. 100 | property-classes=abc.abstractproperty 101 | 102 | # Naming style matching correct variable names. 103 | variable-naming-style=snake_case 104 | 105 | [CLASSES] 106 | 107 | # List of method names used to declare (i.e. assign) instance attributes. 108 | defining-attr-methods=__init__, 109 | __new__, 110 | setUp, 111 | __post_init__ 112 | 113 | # List of member names, which should be excluded from the protected access 114 | # warning. 115 | exclude-protected=_asdict, 116 | _fields, 117 | _replace, 118 | _source, 119 | _make 120 | 121 | # List of valid names for the first argument in a class method. 122 | valid-classmethod-first-arg=cls 123 | 124 | # List of valid names for the first argument in a metaclass class method. 125 | valid-metaclass-classmethod-first-arg=cls 126 | 127 | 128 | [DESIGN] 129 | 130 | # Maximum number of arguments for function / method. 131 | max-args=10 132 | 133 | # Maximum number of attributes for a class (see R0902). 134 | max-attributes=20 135 | 136 | # Maximum number of boolean expressions in an if statement (see R0916). 137 | max-bool-expr=5 138 | 139 | # Maximum number of branch for function / method body. 140 | max-branches=12 141 | 142 | # Maximum number of locals for function / method body. 143 | max-locals=15 144 | 145 | # Maximum number of parents for a class (see R0901). 146 | max-parents=7 147 | 148 | # Maximum number of public methods for a class (see R0904). 149 | max-public-methods=20 150 | 151 | # Maximum number of return / yield for function / method body. 152 | max-returns=6 153 | 154 | # Maximum number of statements in function / method body. 155 | max-statements=50 156 | 157 | # Minimum number of public methods for a class (see R0903). 158 | min-public-methods=2 159 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | VENV = venv 2 | VENV_PYTHON = $(VENV)/bin/python 3 | SYSTEM_PYTHON = $(or $(shell which python3), $(shell which python)) 4 | PYTHON = $(or $(wildcard $(VENV_PYTHON)), $(SYSTEM_PYTHON)) 5 | 6 | $(VENV_PYTHON): 7 | if [ ! -d "$(VENV)" ]; then $(SYSTEM_PYTHON) -m venv $(VENV); \ 8 | else \ 9 | echo "Virtual environment already exists in directory $(VENV)"; \ 10 | fi 11 | 12 | venv: $(VENV_PYTHON) 13 | 14 | setup_dev: 15 | pip install -r requirements.txt 16 | pre-commit install 17 | pip install -e . 18 | 19 | 20 | .PHONY: venv $(VENV_PYTHON) 21 | -------------------------------------------------------------------------------- /Model Merge And Analysis Tools/Enhanced_Mixer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Original script by Concedo/LostRuins; the mastermind behind what they once called a "Rubbish experiment" 3 | Now, an incredible leap forward in Language Model engineering and experimentation. 4 | 5 | Script modified by Chasm/Digitous 6 | ''' 7 | 8 | import json 9 | import os 10 | import shutil 11 | import subprocess 12 | from tkinter.filedialog import askdirectory, askopenfilename 13 | 14 | import torch 15 | from colorama import Fore, Style, init 16 | from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, 17 | LlamaConfig, LlamaForCausalLM, LlamaTokenizer, 18 | PreTrainedTokenizer, PreTrainedTokenizerFast) 19 | 20 | newline = '\n' 21 | def clear_console(): 22 | if os.name == "nt": # For Windows 23 | subprocess.call("cls", shell=True) 24 | else: # For Linux and macOS 25 | subprocess.call("clear", shell=True) 26 | 27 | clear_console() 28 | print(f"{Fore.YELLOW}Starting script, please wait...{Style.RESET_ALL}") 29 | 30 | #mixer output settings 31 | blend_ratio = 0.5 #setting to 0 gives first model, and 1 gives second model 32 | fp16 = False #perform operations in fp16. Saves memory, but CPU inference will not be possible. 33 | always_output_fp16 = True #if true, will output fp16 even if operating in fp32 34 | max_shard_size = "10000MiB" #set output shard size 35 | force_cpu = True #only use cpu 36 | load_sharded = True #load both models shard by shard 37 | 38 | print(f"Blend Ratio set to: {Fore.GREEN}{blend_ratio}{Style.RESET_ALL}") 39 | print(f"Operations in fp16 is: {Fore.GREEN}{fp16}{Style.RESET_ALL}") 40 | print(f"Save Result in fp16: {Fore.GREEN}{always_output_fp16}{Style.RESET_ALL}") 41 | print(f"CPU RAM Only: {Fore.GREEN}{force_cpu}{Style.RESET_ALL}{newline}") 42 | 43 | #test generation settings, only for fp32 44 | deterministic_test = True #determines if outputs are always the same 45 | test_prompt = "" #test prompt for generation. only for fp32. set to empty string to skip generating. 
46 | test_max_length = 32 #test generation length 47 | 48 | 49 | blend_ratio_b = 1.0 - blend_ratio 50 | 51 | def get_model_info(model): 52 | with torch.no_grad(): 53 | outfo = "" 54 | cntent = 0 55 | outfo += "\n==============================\n" 56 | for name, para in model.named_parameters(): 57 | cntent += 1 58 | outfo += ('{}: {}'.format(name, para.shape))+"\n" 59 | outfo += ("Num Entries: " + str(cntent))+"\n" 60 | outfo += ("==============================\n") 61 | return outfo 62 | 63 | def merge_models(model1,model2): 64 | with torch.no_grad(): 65 | tensornum = 0 66 | for p1, p2 in zip(model1.parameters(), model2.parameters()): 67 | p1 *= blend_ratio 68 | p2 *= blend_ratio_b 69 | p1 += p2 70 | tensornum += 1 71 | print("Merging tensor "+str(tensornum)) 72 | pass 73 | 74 | def read_index_filenames(sourcedir): 75 | index = json.load(open(sourcedir + '/pytorch_model.bin.index.json','rt')) 76 | fl = [] 77 | for k,v in index['weight_map'].items(): 78 | if v not in fl: 79 | fl.append(v) 80 | return fl 81 | 82 | print("Opening file dialog, please select FIRST model directory...") 83 | model_path1 = askdirectory(title="Select Directory of FIRST model to merge") 84 | print(f"First Model is: {model_path1}") 85 | print("Opening file dialog, please select SECOND model directory...") 86 | model_path2 = askdirectory(title="Select Directory of SECOND model to merge") 87 | print(f"Second Model is: {model_path2}") 88 | print("Opening file dialog, please select OUTPUT model directory...") 89 | model_path3 = askdirectory(title="Select Output Directory of merged model") 90 | print(f"Merged Save Directory is: {model_path3}{newline}") 91 | if not model_path1 or not model_path2: 92 | print("\nYou must select two directories containing models to merge and one output directory. Exiting.") 93 | exit() 94 | 95 | with torch.no_grad(): 96 | if fp16: 97 | torch.set_default_dtype(torch.float16) 98 | else: 99 | torch.set_default_dtype(torch.float32) 100 | 101 | device = torch.device("cuda") if (torch.cuda.is_available() and not force_cpu) else torch.device("cpu") 102 | print(device) 103 | 104 | print("Loading Model 1...") 105 | model1 = AutoModelForCausalLM.from_pretrained(model_path1) #,torch_dtype=torch.float16 106 | model1 = model1.to(device) 107 | model1.eval() 108 | print("Model 1 Loaded. Dtype: " + str(model1.dtype)) 109 | print("Loading Model 2...") 110 | model2 = AutoModelForCausalLM.from_pretrained(model_path2) #,torch_dtype=torch.float16 111 | model2 = model2.to(device) 112 | model2.eval() 113 | print("Model 2 Loaded. Dtype: " + str(model2.dtype)) 114 | 115 | # Saving for posterity reasons, handy for troubleshooting if model result is broken 116 | # #ensure both models have the exact same layout 117 | # m1_info = get_model_info(model1) 118 | # m2_info = get_model_info(model2) 119 | # if m1_info != m2_info: 120 | # print("Model 1 Info: " + m1_info) 121 | # print("Model 2 Info: " + m2_info) 122 | # print("\nERROR:\nThe two selected models are not compatible! 
They must have identical structure!") 123 | # exit() 124 | 125 | print("Merging models...") 126 | merge_models(model1,model2) 127 | 128 | if model_path3: 129 | print("Saving new model...") 130 | if always_output_fp16 and not fp16: 131 | model1.half() 132 | model1.save_pretrained(model_path3, max_shard_size=max_shard_size) 133 | print("\nSaved to: " + model_path3) 134 | print("\nCopying files to: " + model_path3) 135 | files_to_copy = ["tokenizer.model", "special_tokens_map.json", "tokenizer_config.json", "vocab.json", "merges.txt"] 136 | for filename in files_to_copy: 137 | src_path = os.path.join(model_path1, filename) 138 | dst_path = os.path.join(model_path3, filename) 139 | try: 140 | shutil.copy2(src_path, dst_path) 141 | except FileNotFoundError: 142 | print("\nFile " + filename + " not found in" + model_path1 + ". Skipping.") 143 | else: 144 | print("\nOutput model was not saved as no output path was selected.") 145 | print("\nScript Completed.") 146 | -------------------------------------------------------------------------------- /Model Merge And Analysis Tools/Enhanced_Mixer_Requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/huggingface/transformers 2 | numpy 3 | matplotlib 4 | tkinter 5 | colorama 6 | shutil 7 | -------------------------------------------------------------------------------- /Model Merge And Analysis Tools/LM_BlockMerge_Requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/huggingface/transformers 2 | torch 3 | numpy 4 | tkinter 5 | shutil 6 | -------------------------------------------------------------------------------- /Model Merge And Analysis Tools/StratusScope.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import subprocess 4 | from tkinter import Tk, filedialog 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | from colorama import Fore, Style, init 10 | from transformers import AutoConfig, AutoModel, logging 11 | 12 | logging.set_verbosity_warning() 13 | logging.set_verbosity_error() 14 | 15 | def select_folder(): 16 | Tk().withdraw() 17 | folder = filedialog.askdirectory() 18 | return folder 19 | 20 | def clear_console(): 21 | if os.name == "nt": # For Windows 22 | subprocess.call("cls", shell=True) 23 | else: # For Linux and macOS 24 | subprocess.call("clear", shell=True) 25 | 26 | def load_sharded_layer(folder, target_layer): 27 | files = os.listdir(folder) 28 | model_files = sorted([f for f in files if f.startswith('pytorch_model-') and f.endswith('.bin')]) 29 | 30 | layer_state_dict = {} 31 | for model_file in model_files: 32 | shard = torch.load(os.path.join(folder, model_file), map_location=torch.device('cpu')) 33 | for name, param in shard.items(): 34 | layer_number = get_layer_number(name) 35 | if layer_number == target_layer: 36 | layer_state_dict[name] = param 37 | 38 | return layer_state_dict 39 | 40 | def get_layer_number(name): 41 | parts = name.split('.') 42 | for part in parts: 43 | if part.isdigit(): 44 | return int(part) 45 | return None 46 | 47 | def get_total_layers(model_folder): 48 | files = os.listdir(model_folder) 49 | model_files = sorted([f for f in files if f.startswith('pytorch_model-') and f.endswith('.bin')]) 50 | 51 | all_layers = set() 52 | 53 | for model_file in model_files: 54 | shard = torch.load(os.path.join(model_folder, model_file), map_location=torch.device('cpu')) 55 
| for name in shard.keys(): 56 | layer_number = get_layer_number(name) 57 | all_layers.add(layer_number) 58 | 59 | return len(all_layers) 60 | 61 | # https://pytorch.org/docs/stable/tensors.html 62 | 63 | def compare_layers(model1_folder, model2_folder): 64 | layer_diffs = [] 65 | newline = '\n' 66 | num_layers = get_total_layers(model1_folder) -1 67 | print(f"Torch Version: {torch.__version__}") 68 | print(f"Total Layers Found: {num_layers}{newline}") 69 | 70 | for layer_number in range(num_layers): 71 | layer_diff = 0 72 | 73 | model1_layer = load_sharded_layer(model1_folder, layer_number) 74 | model2_layer = load_sharded_layer(model2_folder, layer_number) 75 | 76 | for n1, p1 in model1_layer.items(): 77 | p2 = model2_layer[n1] 78 | 79 | print(f"{newline}{Fore.YELLOW}--------Found Tensor Pair--------{newline}") 80 | print(f"p1 = {p1}") 81 | print(f"p2 = {p2}") 82 | print(f"{newline}{Fore.GREEN}--------Casting p1 & p2 tensor pair to float32--------{newline}") 83 | p1 = p1.detach().to(torch.float32) 84 | print(f"p1 = {p1}") 85 | p2 = p2.detach().to(torch.float32) 86 | print(f"p2 = {p2}") 87 | 88 | if not (torch.isinf(p1).any() or torch.isinf(p2).any()): 89 | diff = torch.abs(p1 - p2).sum().item() 90 | layer_diff += diff 91 | 92 | print(f"{newline}{Fore.CYAN}----------- Layer {layer_number}: Aggregate Difference = {layer_diff} -----------{Style.RESET_ALL}{newline}") 93 | layer_diffs.append(layer_diff) 94 | 95 | return layer_diffs 96 | 97 | def plot_layer_diff(layer_diffs, model1_name, model2_name): 98 | plt.figure(figsize=(20, 6)) 99 | num_layers = len(layer_diffs) 100 | layer_indices = range(num_layers) 101 | plt.bar(layer_indices, layer_diffs) 102 | plt.xticks(layer_indices) 103 | plt.xlabel('Layer') 104 | plt.ylabel('Difference') 105 | plt.title(f"{model1_name} vs {model2_name} Layer Difference") 106 | plt.ylim(bottom=0) 107 | print("Script completed, close graph to unload models and return to commandline.") 108 | plt.show() 109 | 110 | def main(): 111 | print("Select model1 folder:") 112 | model1_folder = select_folder() 113 | model1_name = os.path.basename(model1_folder) 114 | print("Select model2 folder:") 115 | model2_folder = select_folder() 116 | model2_name = os.path.basename(model2_folder) 117 | 118 | print("Examining Models...") 119 | clear_console() 120 | layer_diffs = compare_layers(model1_folder, model2_folder) 121 | 122 | plot_layer_diff(layer_diffs, model1_name, model2_name) 123 | 124 | torch.cuda.empty_cache() 125 | import gc 126 | gc.collect() 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /Model Merge And Analysis Tools/StratusScope_BarGraph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huu4ontocord/MDEL/d84a598e765accfb723edd58f6c0a426d8c16d8d/Model Merge And Analysis Tools/StratusScope_BarGraph.png -------------------------------------------------------------------------------- /Model Merge And Analysis Tools/StratusScope_ConsoleOutput.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huu4ontocord/MDEL/d84a598e765accfb723edd58f6c0a426d8c16d8d/Model Merge And Analysis Tools/StratusScope_ConsoleOutput.png -------------------------------------------------------------------------------- /Model Merge And Analysis Tools/StratusScope_Requirements.txt: -------------------------------------------------------------------------------- 1 | 
git+https://github.com/huggingface/transformers 2 | numpy 3 | matplotlib 4 | tkinter 5 | colorama 6 | -------------------------------------------------------------------------------- /Model Merge And Analysis Tools/__Quick Tool Explainer__.txt: -------------------------------------------------------------------------------- 1 | ---All tools provided originate from the Ontocord-associated family (KoboldAI community)--- 2 | 3 | StratusScope is a language model layer analysis tool that shows the per-layer aggregate difference between a base model 4 | and any fine-tuned variant as a bar graph, and prints the tensors and the differences it finds to the console. 5 | 6 | LM_BlockMerge is a per-layer language model merging tool that lets you choose, for each layer, what percentage of model A to weight-sum merge with model B. 7 | 8 | EnhancedMixer is a simple weight-sum merge tool: run it and choose two models to combine; the merge percentage and other simple parameters can be adjusted inside the well-documented .py script. 9 | 10 | ------------------------------------ 11 | In-Depth Explainer on Each Tool: 12 | ------------------------------------ 13 | 14 | [[StratusScope]] 15 | 16 | A language model tool that uses Hugging Face's Transformers library to load two language models of the same architecture and parameter size, consolidate the weights and biases within each layer of both models in memory, examine the aggregate difference between layers, and generate a matplotlib bar graph detailing which layers differ most between the two models. 17 | 18 | Use Case - This is an invaluable tool to measure layer differences between a base model and a fine-tune of that model to determine which layers inherited the most change from fine-tuning. 19 | 20 | Original Git [Author: Digitous] 21 | https://github.com/Digitous/StratusScope 22 | 23 | ------------------------------------ 24 | 25 | [[Language Model Transformer Block Merge]] 26 | 27 | Uses a tkinter GUI that analyzes two selected models and allows per-layer percentage 28 | weight merging controlled by the user (inspired by an Image Diffusion block merging 29 | technique and applied to transformers-based Language Models). 30 | 31 | Usage: 32 | Start the script via the command 33 | [On Linux] python ./LM_BlockMerge.py 34 | [On Windows] python LM_BlockMerge.py 35 | The script will then prompt you for three folders: 36 | The first model 37 | The first model will be used as a base for the transformers configuration, and will be used as the reference when handling the layers and weights. 38 | The second model 39 | The second model will only be used for providing the secondary layers for the merge. 40 | The output folder 41 | The resulting model will be saved inside the selected directory, in a folder called "./this/is/your/chosen_path/" + "/converted_model" 42 | The script will then load the weights into memory, according to the precision (32/16 bit) chosen in the configuration header, and will subsequently prompt the user with a popup GUI listing all of the layers available in the selected models. 43 | 44 | The user will be able to merge layers according to any strategy, ranging from: 45 | Creating the output_model by completely replacing N layers at will; 46 | Creating the output_model by choosing an individual mix ratio on N layers; 47 | Any mix of the two strategies on N chosen layers (the per-layer weighted sum is sketched below).
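To make the merge rule concrete, here is a minimal Python sketch of the per-layer weighted sum (illustrative only, not LM_BlockMerge's actual code; blend_state_dicts and layer_ratios are made-up names, and it assumes both models share identical state-dict keys and shapes):

def blend_state_dicts(sd_a, sd_b, layer_ratios, default_ratio=0.5):
    # For every parameter tensor: merged = ratio * A + (1 - ratio) * B,
    # where ratio is looked up per layer (1.0 keeps model A's layer,
    # 0.0 takes model B's layer).
    merged = {}
    for name, param_a in sd_a.items():
        param_b = sd_b[name]
        # Crude layer lookup: first numeric component of the parameter name.
        layer = next((int(p) for p in name.split('.') if p.isdigit()), None)
        ratio = layer_ratios.get(layer, default_ratio)
        merged[name] = ratio * param_a + (1.0 - ratio) * param_b
    return merged

For example, layer_ratios = {0: 1.0, 5: 0.25} keeps model A's layer 0 untouched and builds layer 5 from 25% of model A and 75% of model B; parameters without a numeric layer index (embeddings, final norm) fall back to default_ratio.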
48 | The layers will be merged according to the individual choices per layer, and the resulting model weights will be saved into the output folder, alongside the first model's config.json file. 49 | 50 | Available output settings: 51 | fp16 = False # Perform operations in fp16. Saves memory, but CPU inference will not be possible. 52 | always_output_fp16 = True # If true, will output fp16 even if operating in fp32 53 | max_shard_size = "2000MiB" # Set output shard size 54 | verbose_info = True # Will show model information when loading 55 | force_cpu = True # Only use cpu 56 | 57 | Supported Models: 58 | GPT-NeoX, Pythia, GPT-J, Opt, Llama 59 | Pseudo-Supported Models: 60 | 61 | BERT (testing required to validate implementation) 62 | Notes: 63 | 64 | Performing the operation in FP16 mode halves the memory requirements, but will massively slow down the process of loading the models into memory. Always outputting in fp16 is preferable to save storage space, especially if the original weights were already quantized down to 16-bit. But if your original models use 32-bit precision, decide whether you want to halve the precision of the resulting file. 65 | 66 | Model loading is automatic; the script determines the model type and adjusts accordingly, with no special command-line flags required. Current GPT-NeoX support is hacky: it tends to hit an error mid-merge with 20B; it might work on GPT-NeoX and Pythia models of a smaller size (6B or lower) for now until a solution is implemented. 67 | 68 | Original Git [Collaborators: LostRuins aka Concedo and TeH_Venom aka TehVenomm and Chasm aka Digitous aka Erik] 69 | https://github.com/TehVenomm/LM_Transformers_BlockMerge 70 | 71 | ------------------------------------ 72 | 73 | [[Enhanced Mixer]] 74 | 75 | A simple tool that allows a user to select two models of the same architecture and parameter size and merge them. Settings for the merge percentage and more are documented inside the script. Simply run it and select which models to merge. 76 | 77 | No GitHub Page [Original Author: LostRuins | Enhancements by: Digitous] 78 | 79 | ------------------------------------ 80 | 81 | Authors' GitHub Pages: 82 | https://github.com/LostRuins 83 | https://github.com/TehVenomm 84 | https://github.com/Digitous 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MDEL 2 | 3 | Multi-Domain Expert Learning 4 | 5 | # Environment Setup 6 | 7 | To set up the development environment, run `make setup_dev`. This will set up the 8 | pre-commit hooks. 9 | 10 | ## Creating Expert Datasets 11 | 12 | First, make sure you have followed the Environment Setup guidelines. 13 | 14 | To create an expert dataset using the Pile data, follow these steps: 15 | 16 | 1. Download the Pile shard 1 data: `./scripts/get_pile_shard1_data.sh` 17 | 2. To set the domain, edit the variable `SUBSET_NAME` in 18 | `scripts/create_domain_pile_mix.sh`. This should be set to a valid value of 19 | the Pile's variable `pile_set_name`. A list of valid values can be found 20 | below. 21 | 3. Run the above script to process the dataset 22 | 4. Authenticate with Hugging Face: 23 | `export HF_ACCESS_TOKEN={YOUR HUGGINGFACE TOKEN}` 24 | 5. Set the dataset name in `scripts/upload_to_hf.sh` 25 | 6.
Run the above script to upload the processed dataset to HuggingFace 26 | 27 | ### Pile Subsets 28 | 29 | - Pile-CC 30 | - PubMed Central 31 | - Books3† 32 | - OpenWebText2 33 | - ArXiv 34 | - Github 35 | - FreeLaw 36 | - Stack Exchange 37 | - USPTO Backgrounds 38 | - PubMed Abstracts 39 | - Gutenberg (PG-19)† 40 | - OpenSubtitles† 41 | - Wikipedia (en)† 42 | - DM Mathematics† 43 | - Ubuntu IRC 44 | - BookCorpus2 45 | - EuroParl† 46 | - HackerNews 47 | - YoutubeSubtitles 48 | - PhilPapers 49 | - NIH ExPorter 50 | - Enron Emails† 51 | 52 | # Training Expert Models 53 | 54 | 1. Clone this repo and follow the Environment Setup instructions 55 | 2. Set up HF authentication: `export HUGGING_FACE_HUB_TOKEN=[FILL ME]` 56 | 3. Set up W&B authentication: `export WANDB_API_KEY=[FILL ME]` 57 | 4. Edit the variable `DATASET` in script `src/mdel/train.sh` to match a valid 58 | dataset name on the 59 | [MDEL HF](https://huggingface.co/Multi-Domain-Expert-Layers). 60 | 5. Run the above script in background mode to start the training: `./train.sh &` 61 | 6. The trained model should be uploaded to the MDEL HF 62 | 63 | # Merging Expert Models 64 | 65 | 1. Clone this repo and follow the Environment Setup instructions 66 | 2. Set up HF authentication: `export HUGGING_FACE_HUB_TOKEN=[FILL ME]` 67 | 3. Run the merge script 68 | 69 | ```bash 70 | python src/mdel/merge_experts.py \ 71 | --hf-repo your_hf_username/desired_name_of_merged_model \ 72 | -e mdel/expert_1 \ 73 | -e mdel/expert_2 \ 74 | -e mdel/expert_n 75 | ``` 76 | 77 | # Evaluating Perplexity of Models 78 | 79 | 1. Clone this repo and follow the Environment Setup instructions 80 | 2. Set up HF authentication: `export HUGGING_FACE_HUB_TOKEN=[FILL ME]` 81 | 3. Run the perplexity script 82 | 83 | ```bash 84 | python3 src/mdel/calculate_perplexity.py \ 85 | --model Multi-Domain-Expert-Layers/expert-arxiv \ 86 | --dataset Multi-Domain-Expert-Layers/arxiv \ 87 | --split validation_domain 88 | ``` 89 | 90 | # References 91 | 92 | Gao, L., Biderman, S., Black, S., Golding, L., Hoppe, T., Foster, C., ... & 93 | Leahy, C. (2020).The pile: An 800gb dataset of diverse text for language 94 | modeling. _arXiv preprint arXiv:2101.00027_. 
95 | -------------------------------------------------------------------------------- /clustering/download.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import logging 3 | 4 | import trafilatura 5 | 6 | from transformers import pipeline 7 | from transformers import AutoTokenizer 8 | 9 | import numpy as np 10 | 11 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 12 | 13 | max_embedding_characters = 128 # This is a deliberately low value, as the current model is not intended for document embedding 14 | 15 | feature_extractor_checkpoint = 'sentence-transformers/LaBSE' 16 | tokenizer_checkpoint = 'gpt2' 17 | 18 | feature_extractor = pipeline('feature-extraction', framework='pt', model=feature_extractor_checkpoint) 19 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint) 20 | 21 | def fetch_and_parse(url): 22 | try: 23 | response = requests.get(url, timeout=10) 24 | 25 | response.raise_for_status() 26 | except (requests.HTTPError, requests.ConnectionError, requests.Timeout) as error: 27 | logging.error(f'Failed to fetch {url}: {error}') 28 | 29 | return None, None 30 | 31 | content = response.text 32 | 33 | markdown = trafilatura.extract(content, output_format='txt', include_formatting=True, \ 34 | include_tables=True, include_images=True, no_fallback=True, include_links=True) 35 | 36 | return content, markdown 37 | 38 | def embed(text): 39 | embedding = feature_extractor(text) 40 | 41 | return embedding 42 | 43 | def tokenize(text): 44 | tokens = tokenizer.encode(text) 45 | 46 | return tokens 47 | 48 | def process_url(url): 49 | content, markdown = fetch_and_parse(url) 50 | 51 | content_short = content[:max_embedding_characters] 52 | 53 | tokens = tokenize(content) 54 | embedding = embed(content_short) 55 | 56 | embedding = np.array(embedding) 57 | 58 | return content, markdown, tokens, embedding 59 | 60 | def main(): 61 | url = 'https://huggingface.co' 62 | 63 | content, markdown, tokens, embedding = process_url(url) 64 | 65 | for current in [content, markdown, embedding.shape]: 66 | print(f'{"-" * 32}\n{current}') 67 | 68 | print('-' * 32) 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /clustering/feature_extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import AutoTokenizer, AutoModelForCausalLM 4 | 5 | 6 | class FeatureExtractor: 7 | def __init__(self, device='cpu', model_id='bigscience/bloom-560m', num_decoder_blocks=8): 8 | self.device = device 9 | 10 | self.num_decoder_blocks = num_decoder_blocks 11 | self.model_id = model_id 12 | 13 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) 14 | 15 | self.model = AutoModelForCausalLM.from_pretrained(self.model_id) 16 | 17 | h = self.model.transformer.h[:num_decoder_blocks] # Note that this will change for different families of models 18 | self.model.transformer.h = h 19 | 20 | self.model = self.model.to(device) 21 | 22 | 23 | def encode(self, text): 24 | tokens = self.tokenizer(text, padding=True, return_tensors='pt').to(self.device) 25 | 26 | output = self.model(**tokens, output_hidden_states=True).hidden_states[-1] 27 | output = output.detach().cpu().numpy() 28 | 29 | return output 30 | 31 | 32 | def __call__(self, text): 33 | output = self.encode(text) 34 | 35 | return output 36 | 37 | 38 | def main(): 39 | device = torch.device('cuda' 
if torch.cuda.is_available() else 'cpu') 40 | print(f'Using {device} device') 41 | 42 | feature_extractor = FeatureExtractor(device=device) 43 | 44 | output = feature_extractor('Hello world!') 45 | print(output) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /clustering/hierarchical_clustering.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | from torch import arange, argmax 8 | from tqdm import tqdm 9 | from collections import Counter 10 | 11 | import uuid 12 | 13 | import numpy as np 14 | from fast_pytorch_kmeans import KMeans 15 | 16 | from feature_extractor import FeatureExtractor 17 | from memmap_utils import np_memmap, get_np_memmap_length 18 | 19 | 20 | class ClusterAnalysis(nn.Module): 21 | def __init__( 22 | self, 23 | mmap_file=None, 24 | embed_dim=128, 25 | dtype=np.float32, 26 | ): 27 | super().__init__() 28 | 29 | self.mmap_file = mmap_file 30 | self.embed_dim = embed_dim 31 | 32 | self.dtype = dtype 33 | 34 | self.clusters = {} 35 | self.span_to_cluster_label = {} 36 | 37 | 38 | @staticmethod 39 | def _cluster_one_batch( 40 | true_k, 41 | spans, 42 | clusters, 43 | span_to_cluster_label, 44 | level, 45 | cluster_embeddings, 46 | min_overlap_merge_cluster, 47 | device 48 | ): 49 | with torch.no_grad(): 50 | embeddings = torch.from_numpy(cluster_embeddings) 51 | 52 | km = KMeans(n_clusters=true_k, mode='cosine') 53 | km_labels = km.fit_predict(embeddings.to(device=device, dtype=torch.float32)).tolist() 54 | 55 | embeddings = None 56 | 57 | if not clusters: 58 | label_to_label = {} 59 | 60 | for span, label in zip(spans, km_labels): 61 | label = (label, level) 62 | 63 | if label not in label_to_label: 64 | label_to_label[label] = (span[0], level) 65 | 66 | label = label_to_label[label] 67 | 68 | clusters[label] = clusters.get(label, []) +[ span] 69 | span_to_cluster_label[span] = label 70 | 71 | output = list(clusters.keys()) 72 | 73 | return output 74 | 75 | tmp_cluster = {} 76 | 77 | for span, label in zip(spans, km_labels): 78 | tmp_cluster[label] = tmp_cluster.get(label, [])+[span] 79 | 80 | new_labels = [] 81 | 82 | for a_cluster in tmp_cluster.values(): 83 | for span in a_cluster: 84 | need_labels = [span for span in a_cluster if span not in span_to_cluster_label or span_to_cluster_label[span][1] != level] 85 | cluster_labels = [span_to_cluster_label[span] for span in a_cluster if span in span_to_cluster_label and span_to_cluster_label[span][1] == level] 86 | 87 | if not need_labels: 88 | continue 89 | 90 | if not cluster_labels: 91 | 92 | label = (span[0], level) 93 | 94 | else: 95 | most_common = Counter(cluster_labels).most_common(1)[0] 96 | 97 | if most_common[1] < min_overlap_merge_cluster: 98 | label = (span[0], level) 99 | 100 | else: 101 | label = most_common[0] 102 | 103 | new_labels.append(label) 104 | 105 | for span in need_labels: 106 | clusters[label] = clusters.get(label, []) + [span] 107 | span_to_cluster_label[span] = label 108 | 109 | return new_labels 110 | 111 | 112 | def create_hiearchical_clusters( 113 | self, 114 | force_recluster_idxs=None, 115 | max_level=4, 116 | max_cluster_size=32, # Small value for debug purposes 117 | min_overlap_merge_cluster=2, 118 | prefered_leaf_node_size=None, 119 | kmeans_batch_size=250000, 120 | use_tqdm=False, 121 | device='cuda:0' 122 | ): 123 | mmap_file = 
self.mmap_file 124 | embed_dim = self.embed_dim 125 | dtype = self.dtype 126 | 127 | mmap_len = get_np_memmap_length(mmap_file, [0, embed_dim], dtype=dtype) 128 | 129 | clusters = self.clusters 130 | span_to_cluster_label = self.span_to_cluster_label 131 | 132 | if force_recluster_idxs: 133 | force_recluster_idxs = set(force_recluster_idxs) 134 | else: 135 | force_recluster_idxs = () 136 | 137 | already_clustered = set([span[0] for span in span_to_cluster_label if span[1] == 0 and span[0] not in force_recluster_idxs]) 138 | 139 | idxs = [] 140 | 141 | if force_recluster_idxs: 142 | idxs = list(force_recluster_idxs) 143 | force_recluster_idxs = None 144 | 145 | idxs.extend([idx for idx in range(mmap_len) if idx not in already_clustered]) 146 | 147 | if not idxs: 148 | return 149 | 150 | already_clustered = list(already_clustered) 151 | 152 | if len(already_clustered) > int(0.5 * kmeans_batch_size): 153 | idxs.extend(random.sample(already_clustered, int(0.5 * kmeans_batch_size))) 154 | else: 155 | idxs.extend(already_clustered) 156 | 157 | already_clustered = None 158 | 159 | idxs.extend([span[0] for span in span_to_cluster_label if span[1] != 0]) 160 | idxs = list(set(idxs)) 161 | random.shuffle(idxs) 162 | 163 | if not prefered_leaf_node_size: 164 | prefered_leaf_node_size= int(max_cluster_size * 0.7) 165 | 166 | for level in range(max_level): 167 | all_spans = [(idx, level) for idx in idxs] 168 | len_spans = len(all_spans) 169 | 170 | step_size = int(0.7 * kmeans_batch_size) 171 | num_times = max(3, math.ceil(len_spans / step_size)) 172 | 173 | if use_tqdm: 174 | num_times_2 = tqdm.tqdm(range(num_times)) 175 | 176 | else: 177 | num_times_2 = range(num_times) 178 | 179 | for times in num_times_2: 180 | max_rng = min(len_spans, step_size) 181 | 182 | spans = all_spans[:max_rng] 183 | 184 | not_already_clustered = [span for span in all_spans[:max_rng - step_size] if span not in span_to_cluster_label] 185 | 186 | if len(not_already_clustered) > int(0.5 * kmeans_batch_size): 187 | spans.extend(random.sample(not_already_clustered, int(0.5 * kmeans_batch_size))) 188 | else: 189 | spans.extend(not_already_clustered) 190 | 191 | if len(spans) == 0: break 192 | 193 | already_clustered = [span for span in all_spans[:max_rng - step_size] if span in span_to_cluster_label] 194 | 195 | if len(already_clustered) > int(0.5 * kmeans_batch_size): 196 | spans.extend(random.sample(already_clustered, int(0.5 * kmeans_batch_size))) 197 | 198 | else: 199 | spans.extend(already_clustered) 200 | 201 | embedding_idxs = [span[0] for span in spans] 202 | 203 | if level == 0: 204 | true_k = int(len(embedding_idxs) / prefered_leaf_node_size) 205 | 206 | else: 207 | true_k = int(len(embedding_idxs ) / max_cluster_size) 208 | 209 | cluster_embeddings = np_memmap(mmap_file, shape=[mmap_len, embed_dim], idxs=embedding_idxs, dtype=dtype) 210 | 211 | new_labels = self._cluster_one_batch(true_k, spans, clusters, span_to_cluster_label, level, cluster_embeddings, min_overlap_merge_cluster, device) 212 | 213 | if not new_labels: 214 | break 215 | 216 | need_more = False 217 | 218 | assert prefered_leaf_node_size <= max_cluster_size, 'prefered_leaf_node_size Must not exceed max_cluster_size' 219 | 220 | if times <= num_times - 2: 221 | for label in new_labels: 222 | if len(clusters[label]) < prefered_leaf_node_size: 223 | del clusters[label] 224 | 225 | need_more = True 226 | 227 | if not need_more: 228 | break 229 | 230 | idxs = [val[0][0] for key, val in clusters.items() if key[1] == level] 231 | 232 | if len(idxs) < 
max_cluster_size: 233 | break 234 | 235 | 236 | def main(): 237 | cluster_analysis = ClusterAnalysis( 238 | mmap_file='output/embeddings.mmap', 239 | embed_dim=1024 240 | ) 241 | 242 | cluster_analysis.create_hiearchical_clusters() 243 | 244 | print(list(cluster_analysis.clusters.keys())) 245 | 246 | 247 | if __name__ == '__main__': 248 | main() 249 | -------------------------------------------------------------------------------- /clustering/memmap_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | 6 | def is_contiguous(arr): 7 | start = None 8 | prev = None 9 | contiguous = True 10 | 11 | for idx in arr: 12 | if start is None: 13 | start = idx 14 | if prev is None or idx == prev + 1: 15 | prev = idx 16 | 17 | continue 18 | 19 | contiguous = False 20 | 21 | break 22 | 23 | return contiguous, start, idx + 1 24 | 25 | 26 | def np_memmap(file_name, data=None, idxs=None, shape=None, dtype=np.float32, offset=0, order='C'): 27 | if not file_name.endswith('.mmap'): 28 | file_name += '.mmap' 29 | 30 | if os.path.exists(file_name): 31 | mode = 'r+' 32 | else: 33 | mode = 'w+' 34 | 35 | if shape is None and data is not None: 36 | shape = data.shape 37 | 38 | if not shape: 39 | shape = [0, 1] 40 | 41 | memmap = np.memmap(file_name, mode=mode, dtype=dtype, shape=tuple(shape), offset=offset, order=order) 42 | 43 | if idxs: 44 | contiguous, start, end = is_contiguous(idxs) 45 | 46 | if data is not None: 47 | if tuple(shape) == tuple(data.shape): 48 | memmap[:] = data 49 | elif contiguous: 50 | memmap[start:end] = data 51 | else: 52 | memmap[idxs] = data 53 | 54 | return memmap 55 | 56 | 57 | def get_np_memmap_length(file_name, shape, dtype=np.float32): 58 | if not os.path.exists(file_name): 59 | return shape[0] 60 | 61 | else: 62 | size = np.dtype(dtype).itemsize * np.prod(shape[1:]) 63 | 64 | return int(os.path.getsize(file_name) / size) 65 | -------------------------------------------------------------------------------- /clustering/train_clusterer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | 4 | import numpy as np 5 | 6 | import pickle 7 | from pathlib import Path 8 | 9 | import torch 10 | 11 | from tqdm.auto import tqdm 12 | 13 | from torch.utils.data import DataLoader, IterableDataset 14 | 15 | from kmeans_pytorch import KMeans as BalancedKMeans 16 | 17 | from transformers import pipeline 18 | 19 | from datasets import load_dataset 20 | 21 | from sklearn.manifold import TSNE 22 | from sklearn.decomposition import PCA 23 | 24 | import matplotlib.pyplot as plt 25 | 26 | from feature_extractor import FeatureExtractor 27 | from memmap_utils import np_memmap 28 | 29 | 30 | def load_model(path_to_model: Path): 31 | with open(path_to_model, 'rb') as file: 32 | output = pickle.load(file) 33 | 34 | file.close() 35 | 36 | return output 37 | 38 | 39 | def extract_features(corpus, feature_extractor, batch_size=32, max_chars=256): 40 | corpus = [element[:max_chars] for element in corpus] 41 | batches = np.array_split(corpus, len(corpus) // batch_size, axis=0) 42 | 43 | features = [] 44 | 45 | for batch in tqdm(batches): 46 | batch = list(batch) # batches is a list of numpy arrays 47 | 48 | features_current = feature_extractor(batch) 49 | features_current = np.max(features_current, axis=1) 50 | 51 | features.append(features_current) 52 | 53 | features = np.concatenate(features, axis=0) 54 | 55 | return features 56 | 57 | 58 | def 
train_kmeans(features, n_clusters, path_to_kmeans, balanced=False, device='cpu'): 59 | kmeans = BalancedKMeans(n_clusters=n_clusters, device=device, balanced=balanced) 60 | 61 | batch_size = 512 # Hyperparameter 62 | batch_size = min(batch_size, len(features)) 63 | 64 | batches = np.array_split(features, features.shape[0] // batch_size, axis=0) 65 | 66 | for idx, batch in tqdm(enumerate(batches)): 67 | kmeans.fit(torch.from_numpy(batch), iter_limit=20, online=True, iter_k=idx) 68 | 69 | with open(path_to_kmeans, 'wb+') as file: 70 | pickle.dump(kmeans, file) 71 | 72 | file.close() 73 | 74 | return kmeans 75 | 76 | 77 | def main(n_clusters=16, balanced=False, output_dir=Path('cluster_output/'), shuffle_dataset=True, take_sample=None, embed_only=False, seed=42, visualize=False): 78 | dataset_name_train = 'JeanKaddour/minipile' 79 | content_column_train = 'text' 80 | 81 | subset_train = None # 'p3' 82 | split_train = 'train' 83 | 84 | dataset_train = load_dataset(dataset_name_train, subset_train, split=split_train, streaming=(take_sample is not None)) 85 | 86 | if shuffle_dataset: 87 | dataset_train = dataset_train.shuffle(seed=seed) 88 | 89 | corpus = [] 90 | 91 | for idx, element in enumerate(dataset_train): 92 | corpus.append(element[content_column_train]) 93 | 94 | if take_sample: 95 | if idx >= take_sample: 96 | break 97 | 98 | if not output_dir.is_dir(): 99 | output_dir.mkdir(parents=True, exist_ok=True) 100 | 101 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 102 | print(f'Using {device} device') 103 | 104 | feature_extractor_batch_size = 1 105 | 106 | feature_extractor_checkpoint = 'bigscience/bloom-560m' # 'sentence-transformers/LaBSE' # 'xlm-roberta-large' 107 | feature_extractor = FeatureExtractor(device=device) # pipeline('feature-extraction', framework='pt', model=feature_extractor_checkpoint) 108 | 109 | features = extract_features(corpus, feature_extractor, batch_size=feature_extractor_batch_size) 110 | 111 | memmap_file_path = 'output/embeddings.mmap' # TODO: Create a configs.py file 112 | 113 | np_memmap(memmap_file_path, data=features) 114 | 115 | if embed_only: 116 | return 117 | 118 | path_to_kmeans = output_dir / 'kmeans.pkl' 119 | kmeans = train_kmeans(features, n_clusters, path_to_kmeans, balanced=balanced, device=device) 120 | 121 | if visualize: 122 | tsne = TSNE(n_components=2) 123 | features_2d = tsne.fit_transform(features) 124 | 125 | plt.scatter(features_2d[:, 0], features_2d[:, 1], c=kmeans.predict(torch.from_numpy(features)).cpu()) 126 | plt.show() 127 | 128 | return kmeans 129 | 130 | 131 | if __name__ == '__main__': 132 | parser = argparse.ArgumentParser() 133 | 134 | parser.add_argument('--n-clusters', required=True, type=int) 135 | parser.add_argument('--balanced', action='store_true') 136 | parser.add_argument('--output-dir', required=True, type=Path) 137 | parser.add_argument('--eval-only', action='store_true') 138 | parser.add_argument('--shuffle-dataset', required=False, type=bool, default=True) 139 | parser.add_argument('--take-sample', required=False, type=int) 140 | parser.add_argument('--embed-only', required=False, type=bool, default=False) 141 | parser.add_argument('--visualize', action='store_true') 142 | 143 | args = parser.parse_args() 144 | 145 | if not args.eval_only: 146 | kmeans = main( 147 | n_clusters=args.n_clusters, 148 | balanced=args.balanced, 149 | output_dir=args.output_dir, 150 | take_sample=args.take_sample, 151 | shuffle_dataset=args.shuffle_dataset, 152 | embed_only=args.embed_only, 153 | 
visualize=args.visualize 154 | ) 155 | 156 | path_to_kmeans = args.output_dir / 'kmeans.pkl' 157 | kmeans = load_model(path_to_kmeans) 158 | 159 | 160 | # Usage 161 | 162 | # python3 train_clusterer.py --n-clusters 4 --output-dir output/ --take-sample 128 --embed-only False --visualize 163 | 164 | # Warning! You need to install kmeans from https://github.com/kernelmachine/balanced-kmeans.git 165 | 166 | # cd .. 167 | # git clone https://github.com/kernelmachine/balanced-kmeans.git 168 | # cd balanced-kmeans 169 | # pip3 install -e . 170 | -------------------------------------------------------------------------------- /conda-mdel.yml: -------------------------------------------------------------------------------- 1 | name: mdel 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - pytorch==2.0.* 9 | - pytorch-cuda==11.8 10 | - cuda 11 | - pip 12 | - git=2.35.1 13 | - ninja 14 | - make 15 | - cxx-compiler 16 | - wget 17 | - git-lfs 18 | - pip: 19 | - accelerate~=0.18.0 20 | - datasets~=2.11.0 21 | - deepspeed==0.9.2 22 | - evaluate~=0.4.0 23 | - pre-commit~=2.21.0 24 | - scikit-learn~=1.2.2 25 | - tqdm~=4.65.0 26 | - transformers~=4.28.1 27 | - ujson~=5.7.0 28 | - wandb==0.15.2 29 | - zstandard~=0.21.0 30 | -------------------------------------------------------------------------------- /configs/fp16_1-3B_4M_bs_1.4T_tok_summit_pp3_mp2_256_nodes.yml: -------------------------------------------------------------------------------- 1 | # GPT-2 pretraining setup 2 | { 3 | # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages 4 | # across the node boundaries ) 5 | "pipe-parallel-size": 3, 6 | "model-parallel-size": 2, # one copy of the model per node 7 | 8 | # model settings 9 | "num-layers": 24, 10 | "hidden-size": 2048, 11 | "num-attention-heads": 16, 12 | "seq-length": 2048, 13 | "max-position-embeddings": 2048, 14 | "norm": "layernorm", 15 | "pos-emb": "rotary", 16 | "no-weight-tying": true, 17 | "gpt_j_residual": false, 18 | "output_layer_parallelism": "column", 19 | 20 | # these should provide some speedup but takes a while to build, set to true if desired 21 | "scaled-upper-triang-masked-softmax-fusion": true, 22 | "bias-gelu-fusion": true, 23 | 24 | # init methods 25 | "init_method": "small_init", 26 | "output_layer_init_method": "wang_init", 27 | 28 | # optimizer settings 29 | "optimizer": 30 | { 31 | "type": "Adam", 32 | "params": { "lr": 0.0002, "betas": [0.9, 0.95], "eps": 1.0e-8 }, 33 | }, 34 | "min_lr": 0.00002, 35 | # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training 36 | "zero_optimization": 37 | { 38 | "stage": 1, 39 | "allgather_partitions": True, 40 | "allgather_bucket_size": 500000000, 41 | "overlap_comm": True, 42 | "reduce_scatter": True, 43 | "reduce_bucket_size": 500000000, 44 | "contiguous_gradients": True, 45 | }, 46 | 47 | # batch / data settings 48 | "train_batch_size": 2048, # across 1024 nodes... 
fingers crossed 49 | # "train_micro_batch_size_per_gpu": 4, 50 | "gradient_accumulation_steps": 4, 51 | "data-impl": "mmap", 52 | "split": "949,50,1", 53 | 54 | # activation checkpointing 55 | "checkpoint-activations": true, 56 | "checkpoint-num-layers": 1, 57 | "partition-activations": true, 58 | "synchronize-each-layer": true, 59 | 60 | # regularization 61 | "gradient_clipping": 1.0, 62 | "weight-decay": 0.1, 63 | "hidden-dropout": 0.0, 64 | "attention-dropout": 0.0, 65 | 66 | # precision settings 67 | "fp16": { 68 | "enabled": true, 69 | # "type": "bfloat16", # set bf16 as precision 70 | "loss_scale": 0, 71 | "loss_scale_window": 1000, 72 | "hysteresis": 2, 73 | "min_loss_scale": 1, 74 | }, 75 | 76 | # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 77 | # misc. training settings 78 | "train-iters": 250000, 79 | "lr-decay-iters": 250000, 80 | "distributed-backend": "nccl", 81 | "lr-decay-style": "cosine", 82 | "warmup": 0.01, 83 | "checkpoint-factor": 1000, 84 | "eval-interval": 1000, 85 | "eval-iters": 10, 86 | 87 | # logging 88 | "log-interval": 1, 89 | "steps_per_print": 1, 90 | "keep-last-n-checkpoints": 1000, 91 | "wall_clock_breakdown": true, 92 | } 93 | -------------------------------------------------------------------------------- /configs/fp16_2-7B_4M_bs_1.4T_tok_summit_pp6_mp2_256nodes.yml: -------------------------------------------------------------------------------- 1 | # GPT-2 pretraining setup 2 | { 3 | # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages 4 | # across the node boundaries ) 5 | "pipe-parallel-size": 6, 6 | "model-parallel-size": 2, # one copy of the model per node 7 | 8 | # model settings 9 | "num-layers": 32, 10 | "hidden-size": 2560, 11 | "num-attention-heads": 32, 12 | "seq-length": 2048, 13 | "max-position-embeddings": 2048, 14 | "norm": "layernorm", 15 | "pos-emb": "rotary", 16 | "no-weight-tying": true, 17 | "gpt_j_residual": false, 18 | "output_layer_parallelism": "column", 19 | 20 | # these should provide some speedup but takes a while to build, set to true if desired 21 | "scaled-upper-triang-masked-softmax-fusion": true, 22 | "bias-gelu-fusion": true, 23 | 24 | # init methods 25 | "init_method": "small_init", 26 | "output_layer_init_method": "wang_init", 27 | 28 | # optimizer settings 29 | "optimizer": 30 | { 31 | "type": "Adam", 32 | "params": { "lr": 0.00016, "betas": [0.9, 0.95], "eps": 1.0e-8 }, 33 | }, 34 | "min_lr": 0.000016, 35 | # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training 36 | "zero_optimization": 37 | { 38 | "stage": 1, 39 | "allgather_partitions": True, 40 | "allgather_bucket_size": 500000000, 41 | "overlap_comm": True, 42 | "reduce_scatter": True, 43 | "reduce_bucket_size": 500000000, 44 | "contiguous_gradients": True, 45 | }, 46 | 47 | # batch / data settings 48 | "train_batch_size": 2048, # across 1024 nodes... 
fingers crossed 49 | # "train_micro_batch_size_per_gpu": 4, 50 | "gradient_accumulation_steps": 8, 51 | "data-impl": "mmap", 52 | "split": "949,50,1", 53 | 54 | # activation checkpointing 55 | "checkpoint-activations": true, 56 | "checkpoint-num-layers": 1, 57 | "partition-activations": true, 58 | "synchronize-each-layer": true, 59 | 60 | # regularization 61 | "gradient_clipping": 1.0, 62 | "weight-decay": 0.1, 63 | "hidden-dropout": 0.0, 64 | "attention-dropout": 0.0, 65 | 66 | # precision settings 67 | "fp16": { 68 | "enabled": true, 69 | # "type": "bfloat16", # set bf16 as precision 70 | "loss_scale": 0, 71 | "loss_scale_window": 1000, 72 | "hysteresis": 2, 73 | "min_loss_scale": 1, 74 | }, 75 | 76 | # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 77 | # misc. training settings 78 | "train-iters": 250000, 79 | "lr-decay-iters": 250000, 80 | "distributed-backend": "nccl", 81 | "lr-decay-style": "cosine", 82 | "warmup": 0.01, 83 | "checkpoint-factor": 1000, 84 | "eval-interval": 1000, 85 | "eval-iters": 10, 86 | 87 | # logging 88 | "log-interval": 1, 89 | "steps_per_print": 1, 90 | "keep-last-n-checkpoints": 1000, 91 | "wall_clock_breakdown": true, 92 | } 93 | -------------------------------------------------------------------------------- /configs/fp16_6-7B_4M_bs_1T_tok_summit_pp12_mp2_mbs2_512_nodes_real.yml: -------------------------------------------------------------------------------- 1 | # GPT-2 pretraining setup 2 | { 3 | # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages 4 | # across the node boundaries ) 5 | "pipe-parallel-size": 12, 6 | "model-parallel-size": 2, # one copy of the model per node 7 | 8 | # model settings 9 | "num-layers": 32, 10 | "hidden-size": 4096, 11 | "num-attention-heads": 32, 12 | "seq-length": 2048, 13 | "max-position-embeddings": 2048, 14 | "norm": "layernorm", 15 | "pos-emb": "rotary", 16 | "no-weight-tying": true, 17 | "gpt_j_residual": false, 18 | "output_layer_parallelism": "column", 19 | 20 | # these should provide some speedup but takes a while to build, set to true if desired 21 | "scaled-upper-triang-masked-softmax-fusion": true, 22 | "bias-gelu-fusion": true, 23 | 24 | # init methods 25 | "init_method": "small_init", 26 | "output_layer_init_method": "wang_init", 27 | 28 | # optimizer settings 29 | "optimizer": 30 | { 31 | "type": "Adam", 32 | "params": { "lr": 0.00012, "betas": [0.9, 0.95], "eps": 1.0e-8 }, 33 | }, 34 | "min_lr": 0.000012, 35 | # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training 36 | "zero_optimization": 37 | { 38 | "stage": 1, 39 | "allgather_partitions": True, 40 | "allgather_bucket_size": 500000000, 41 | "overlap_comm": True, 42 | "reduce_scatter": True, 43 | "reduce_bucket_size": 500000000, 44 | "contiguous_gradients": True, 45 | }, 46 | 47 | # batch / data settings 48 | "train_batch_size": 2048, # across 1024 nodes... 
fingers crossed 49 | # "train_micro_batch_size_per_gpu": 4, 50 | # "gradient_accumulation_steps": 2, 51 | "gradient_accumulation_steps": 8, 52 | "data-impl": "mmap", 53 | "split": "949,50,1", 54 | 55 | # activation checkpointing 56 | "checkpoint-activations": true, 57 | "checkpoint-num-layers": 1, 58 | "partition-activations": true, 59 | "synchronize-each-layer": true, 60 | 61 | # regularization 62 | "gradient_clipping": 1.0, 63 | "weight-decay": 0.1, 64 | "hidden-dropout": 0.0, 65 | "attention-dropout": 0.0, 66 | 67 | # precision settings 68 | "fp16": { 69 | "enabled": true, 70 | # "type": "bfloat16", # set bf16 as precision 71 | "loss_scale": 0, 72 | "loss_scale_window": 1000, 73 | "hysteresis": 2, 74 | "min_loss_scale": 1, 75 | }, 76 | 77 | # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 78 | # misc. training settings 79 | "train-iters": 250000, 80 | "lr-decay-iters": 250000, 81 | "distributed-backend": "nccl", 82 | "lr-decay-style": "cosine", 83 | "warmup": 0.01, 84 | "checkpoint-factor": 1000, 85 | "eval-interval": 1000, 86 | "eval-iters": 10, 87 | 88 | # logging 89 | "log-interval": 1, 90 | "steps_per_print": 1, 91 | "keep-last-n-checkpoints": 1000, 92 | "wall_clock_breakdown": true, 93 | } 94 | -------------------------------------------------------------------------------- /distillation_sparsification/README.md: -------------------------------------------------------------------------------- 1 | ## Tools for sparsifying and distilling models 2 | 3 | -------------------------------------------------------------------------------- /distillation_sparsification/datautils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def set_seed(seed): 6 | np.random.seed(seed) 7 | torch.random.manual_seed(seed) 8 | 9 | 10 | def get_wikitext2(nsamples, seed, seqlen, model): 11 | from datasets import load_dataset 12 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 13 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 14 | 15 | from transformers import AutoTokenizer, LlamaTokenizer 16 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 17 | trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt') 18 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 19 | 20 | import random 21 | random.seed(seed) 22 | trainloader = [] 23 | for _ in range(nsamples): 24 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 25 | j = i + seqlen 26 | inp = trainenc.input_ids[:, i:j] 27 | tar = inp.clone() 28 | tar[:, :-1] = -100 29 | trainloader.append((inp, tar)) 30 | return trainloader, testenc 31 | 32 | def get_ptb(nsamples, seed, seqlen, model): 33 | from datasets import load_dataset 34 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') 35 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') 36 | 37 | from transformers import AutoTokenizer, LlamaTokenizer 38 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 39 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') 40 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') 41 | 42 | import random 43 | random.seed(seed) 44 | trainloader = [] 45 | for _ in range(nsamples): 46 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 47 | j = i + seqlen 48 | inp = trainenc.input_ids[:, i:j] 49 | tar = inp.clone() 50 | 
tar[:, :-1] = -100 51 | trainloader.append((inp, tar)) 52 | return trainloader, testenc 53 | 54 | def get_c4(nsamples, seed, seqlen, model): 55 | from datasets import load_dataset 56 | traindata = load_dataset( 57 | 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' 58 | ) 59 | valdata = load_dataset( 60 | 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' 61 | ) 62 | 63 | from transformers import AutoTokenizer 64 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 65 | 66 | import random 67 | random.seed(seed) 68 | trainloader = [] 69 | for _ in range(nsamples): 70 | while True: 71 | i = random.randint(0, len(traindata) - 1) 72 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 73 | if trainenc.input_ids.shape[1] >= seqlen: 74 | break 75 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 76 | j = i + seqlen 77 | inp = trainenc.input_ids[:, i:j] 78 | tar = inp.clone() 79 | tar[:, :-1] = -100 80 | trainloader.append((inp, tar)) 81 | 82 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 83 | valenc = valenc.input_ids[:, :(256 * seqlen)] 84 | 85 | class TokenizerWrapper: 86 | def __init__(self, input_ids): 87 | self.input_ids = input_ids 88 | valenc = TokenizerWrapper(valenc) 89 | 90 | return trainloader, valenc 91 | 92 | 93 | def get_loaders( 94 | name, nsamples=128, seed=0, seqlen=2048, model='' 95 | ): 96 | if 'wikitext2' in name: 97 | return get_wikitext2(nsamples, seed, seqlen, model) 98 | if 'ptb' in name: 99 | return get_ptb(nsamples, seed, seqlen, model) 100 | if 'c4' in name: 101 | return get_c4(nsamples, seed, seqlen, model) 102 | -------------------------------------------------------------------------------- /distillation_sparsification/distill.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import logging 4 | import os 5 | import sys 6 | import time 7 | from collections import defaultdict 8 | from pathlib import Path 9 | from typing import Dict, List, Tuple 10 | 11 | import numpy as np 12 | import torch 13 | from torch import nn 14 | from torch.utils.data import DataLoader 15 | from transformers import Trainer 16 | from make_student import * 17 | from utils import label_smoothed_nll_loss, freeze_params 18 | from make_student import * 19 | 20 | def get_falcon(model): 21 | def skip(*args, **kwargs): 22 | pass 23 | 24 | torch.nn.init.kaiming_uniform_ = skip 25 | torch.nn.init.uniform_ = skip 26 | torch.nn.init.normal_ = skip 27 | tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) 28 | model = AutoModelForCausalLM.from_pretrained(model, trust_remote_code=True, torch_dtype=torch.bfloat16) 29 | model.seqlen = 2048 30 | return model, tokenizer 31 | 32 | 33 | 34 | class DistillationTrainer(Trainer): 35 | # length_field_name should possibly be part of TrainingArguments instead 36 | def __init__(self, length_field_name=None, swap_lang_prob=3, alpha_ce=0.5, alpha_hid=0.5, alpha_clm=0.5, temperature=2.0, all_teachers={}, normalize_hidden=False, *args, **kwargs): 37 | super().__init__(*args, **kwargs) 38 | self.length_field_name = length_field_name 39 | 40 | self.t_model, self.tokenizer = get_falcon("tiiuae/falcon-7b") 41 | 42 | # self.s_model, layer_ids = create_student_by_copying_alternating_layers(t_model, d=16, save_path='student/') 43 | 44 | freeze_params(self.t_model) 45 | 46 | self.d_matches = 
get_layers_to_supervise( 47 | n_student=16, n_teacher=32 48 | ) 49 | 50 | self.restrict_ce_to_mask = False 51 | self.alpha_ce = alpha_ce 52 | self.alpha_hid = alpha_hid 53 | self.alpha_clm = alpha_clm 54 | self.temperature = temperature 55 | self.normalize_hidden = normalize_hidden 56 | 57 | self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean") 58 | self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100) 59 | 60 | 61 | def calc_hidden_loss(self, attention_mask, lm_labels,hidden_states, hidden_states_T, matches, normalize_hidden): 62 | """MSE(student_hid, teacher_hid[matches]). Called "Intermediate supervision" in paper. Inspired by TinyBERT.""" 63 | msg = "expected list or tuple for hidden_states, got tensor of shape: " 64 | assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}" 65 | assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}" 66 | if self.restrict_ce_to_mask: 67 | mask = lm_labels > -1 # (bs, seq_length, voc_size) 68 | else: 69 | mask = attention_mask # (bs, seq_length, voc_size) 70 | 71 | mask = mask.to(hidden_states[0]) 72 | valid_count = mask.sum() * hidden_states[0].size(-1) 73 | s_states = torch.stack([hidden_states[i] for i in range(len(matches))]) 74 | t_states = torch.stack([hidden_states_T[j] for j in matches]) 75 | assert s_states.shape == t_states.shape, f"{s_states.shape} != {t_states.shape}" 76 | if normalize_hidden: 77 | s_states = nn.functional.layer_norm(s_states, s_states.shape[1:]) 78 | t_states = nn.functional.layer_norm(t_states, t_states.shape[1:]) 79 | mse = nn.functional.mse_loss(s_states, t_states, reduction="none") 80 | masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count 81 | return masked_mse 82 | 83 | def calc_ce_loss(self, attention_mask, lm_labels, s_logits, t_logits): 84 | """Copy pasted from distillbert (transformers/examples/distillation/)""" 85 | # mask has False at padding_idx 86 | if self.restrict_ce_to_mask: 87 | mask = (lm_labels > -1) # (bs, seq_length, voc_size) 88 | else: 89 | mask = attention_mask # (bs, seq_length, voc_size) 90 | 91 | sel_mask = mask.unsqueeze(-1).expand_as(s_logits) 92 | vocab_size = s_logits.size(-1) 93 | s_logits_slct = torch.masked_select(s_logits, sel_mask) # (bs * seq_length * voc_size) modulo the 1s in mask 94 | t_logits_slct = torch.masked_select(t_logits, sel_mask) # (bs * seq_length * voc_size) modulo the 1s in mask 95 | s_logits_slct = s_logits_slct.view(-1, vocab_size) # (bs * seq_length, voc_size) modulo the 1s in mask 96 | t_logits_slct = t_logits_slct.view(-1, vocab_size) # (bs * seq_length, voc_size) modulo the 1s in mask 97 | assert t_logits_slct.size() == s_logits_slct.size() 98 | loss_ce = ( 99 | self.ce_loss_fct( 100 | nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1), 101 | nn.functional.softmax(t_logits_slct / self.temperature, dim=-1), 102 | ) 103 | * (self.temperature) ** 2 104 | ) 105 | return loss_ce 106 | 107 | def compute_loss(self, model, inputs, return_outputs=False): 108 | pad_token_id = self.tokenizer.pad_token_id 109 | input_ids, attn_mask, labels = inputs["input_ids"], inputs["attention_mask"], inputs["labels"] 110 | lm_labels = input_ids.new(input_ids.size()).copy_(input_ids) 111 | lm_labels[~attn_mask] = -100 # previously `lm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility 112 | # sanity checks 113 | assert 0 <= input_ids.min() <= input_ids.max() < self.tokenizer.vocab_size 114 | # noinspection PyCallingNonCallable 115 | s_outputs = model( 116 | input_ids, 117 | 
attention_mask=None, 118 | ) 119 | 120 | t_outputs = self.t_model( 121 | input_ids, 122 | attention_mask=None, 123 | ) 124 | print(5) 125 | s_logits, s_hidden_states = s_outputs["logits"], s_outputs["hidden_states"] 126 | t_logits, t_hidden_states = t_outputs["logits"], t_outputs["hidden_states"] 127 | 128 | 129 | if self.alpha_ce > 0.0: 130 | loss_ce = self.calc_ce_loss(attn_mask, lm_labels, s_logits, t_logits) 131 | 132 | if self.alpha_hid > 0.0: 133 | hid_loss = self.calc_hidden_loss( 134 | attn_mask, 135 | lm_labels, 136 | s_hidden_states, 137 | t_hidden_states, 138 | self.d_matches, 139 | normalize_hidden=self.normalize_hidden, 140 | ) 141 | 142 | if self.alpha_clm > 0.0: 143 | shift_logits = s_logits[..., :-1, :].contiguous() 144 | shift_labels = lm_labels[..., 1:].contiguous() 145 | loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 146 | 147 | loss = self.alpha_ce * loss_ce + self.alpha_clm * loss_clm + self.alpha_hid * hid_loss 148 | 149 | return (loss, loss_ce, loss_clm, hid_loss, s_outputs) if return_outputs else (loss, loss_ce, loss_clm, hid_loss) 150 | 151 | 152 | -------------------------------------------------------------------------------- /distillation_sparsification/lion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | 5 | class Lion(Optimizer): 6 | r"""Implements Lion algorithm.""" 7 | 8 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0): 9 | """Initialize the hyperparameters. 10 | 11 | Args: 12 | params (iterable): iterable of parameters to optimize or dicts defining 13 | parameter groups 14 | lr (float, optional): learning rate (default: 1e-4) 15 | betas (Tuple[float, float], optional): coefficients used for computing 16 | running averages of gradient and its square (default: (0.9, 0.99)) 17 | weight_decay (float, optional): weight decay coefficient (default: 0) 18 | """ 19 | 20 | if not 0.0 <= lr: 21 | raise ValueError('Invalid learning rate: {}'.format(lr)) 22 | if not 0.0 <= betas[0] < 1.0: 23 | raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0])) 24 | if not 0.0 <= betas[1] < 1.0: 25 | raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1])) 26 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) 27 | super().__init__(params, defaults) 28 | 29 | @torch.no_grad() 30 | def step(self, closure=None): 31 | """Performs a single optimization step. 32 | 33 | Args: 34 | closure (callable, optional): A closure that reevaluates the model 35 | and returns the loss. 36 | 37 | Returns: 38 | the loss. 
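        For reference, the step below first applies decoupled weight decay, then moves each
        parameter by -lr * sign(beta1 * m + (1 - beta1) * grad), and finally updates the
        momentum buffer m towards grad with coefficient beta2.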
39 | """ 40 | loss = None 41 | if closure is not None: 42 | with torch.enable_grad(): 43 | loss = closure() 44 | 45 | for group in self.param_groups: 46 | for p in group['params']: 47 | if p.grad is None: 48 | continue 49 | 50 | # Perform stepweight decay 51 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 52 | 53 | grad = p.grad 54 | state = self.state[p] 55 | # State initialization 56 | if len(state) == 0: 57 | # Exponential moving average of gradient values 58 | state['exp_avg'] = torch.zeros_like(p) 59 | 60 | exp_avg = state['exp_avg'] 61 | beta1, beta2 = group['betas'] 62 | 63 | # Weight update 64 | update = exp_avg * beta1 + grad * (1 - beta1) 65 | p.add_(torch.sign(update), alpha=-group['lr']) 66 | # Decay the momentum running average coefficient 67 | exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) 68 | 69 | return loss -------------------------------------------------------------------------------- /distillation_sparsification/lm_seqs_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | from utils import logger 6 | 7 | 8 | class LmSeqsDataset(Dataset): 9 | """Custom Dataset wrapping language modeling sequences. 10 | 11 | Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths. 12 | 13 | Input: 14 | ------ 15 | params: `NameSpace` parameters 16 | data: `List[np.array[int]] 17 | """ 18 | 19 | def __init__(self, params, data): 20 | self.params = params 21 | 22 | self.token_ids = np.array(data) 23 | self.lengths = np.array([len(t) for t in data]) 24 | 25 | self.check() 26 | self.remove_long_sequences() 27 | self.remove_empty_sequences() 28 | self.remove_unknown_sequences() 29 | self.check() 30 | self.print_statistics() 31 | 32 | def __getitem__(self, index): 33 | return (self.token_ids[index], self.lengths[index]) 34 | 35 | def __len__(self): 36 | return len(self.lengths) 37 | 38 | def check(self): 39 | """ 40 | Some sanity checks 41 | """ 42 | assert len(self.token_ids) == len(self.lengths) 43 | assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths))) 44 | 45 | def remove_long_sequences(self): 46 | """ 47 | Sequences that are too long are split by chunk of max_model_input_size. 
48 | """ 49 | max_len = self.params.max_model_input_size 50 | indices = self.lengths > max_len 51 | logger.info(f"Splitting {sum(indices)} too long sequences.") 52 | 53 | def divide_chunks(l, n): 54 | return [l[i : i + n] for i in range(0, len(l), n)] 55 | 56 | new_tok_ids = [] 57 | new_lengths = [] 58 | if self.params.mlm: 59 | cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"] 60 | else: 61 | cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"] 62 | 63 | for seq_, len_ in zip(self.token_ids, self.lengths): 64 | assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_ 65 | if len_ <= max_len: 66 | new_tok_ids.append(seq_) 67 | new_lengths.append(len_) 68 | else: 69 | sub_seqs = [] 70 | for sub_s in divide_chunks(seq_, max_len - 2): 71 | if sub_s[0] != cls_id: 72 | sub_s = np.insert(sub_s, 0, cls_id) 73 | if sub_s[-1] != sep_id: 74 | sub_s = np.insert(sub_s, len(sub_s), sep_id) 75 | assert len(sub_s) <= max_len 76 | assert (sub_s[0] == cls_id) and (sub_s[-1] == sep_id), sub_s 77 | sub_seqs.append(sub_s) 78 | 79 | new_tok_ids.extend(sub_seqs) 80 | new_lengths.extend([len(l) for l in sub_seqs]) 81 | 82 | self.token_ids = np.array(new_tok_ids) 83 | self.lengths = np.array(new_lengths) 84 | 85 | def remove_empty_sequences(self): 86 | """ 87 | Too short sequences are simply removed. This could be tuned. 88 | """ 89 | init_size = len(self) 90 | indices = self.lengths > 11 91 | self.token_ids = self.token_ids[indices] 92 | self.lengths = self.lengths[indices] 93 | new_size = len(self) 94 | logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.") 95 | 96 | def remove_unknown_sequences(self): 97 | """ 98 | Remove sequences with a (too) high level of unknown tokens. 99 | """ 100 | if "unk_token" not in self.params.special_tok_ids: 101 | return 102 | else: 103 | unk_token_id = self.params.special_tok_ids["unk_token"] 104 | init_size = len(self) 105 | unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids]) 106 | indices = (unk_occs / self.lengths) < 0.5 107 | self.token_ids = self.token_ids[indices] 108 | self.lengths = self.lengths[indices] 109 | new_size = len(self) 110 | logger.info(f"Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).") 111 | 112 | def print_statistics(self): 113 | """ 114 | Print some statistics on the corpus. Only the master process. 115 | """ 116 | if not self.params.is_master: 117 | return 118 | logger.info(f"{len(self)} sequences") 119 | # data_len = sum(self.lengths) 120 | # nb_unique_tokens = len(Counter(list(chain(*self.token_ids)))) 121 | # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)') 122 | 123 | # unk_idx = self.params.special_tok_ids['unk_token'] 124 | # nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids]) 125 | # logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)') 126 | 127 | def batch_sequences(self, batch): 128 | """ 129 | Do the padding and transform into torch.tensor. 
130 | """ 131 | token_ids = [t[0] for t in batch] 132 | lengths = [t[1] for t in batch] 133 | assert len(token_ids) == len(lengths) 134 | 135 | # Max for paddings 136 | max_seq_len_ = max(lengths) 137 | 138 | # Pad token ids 139 | if self.params.mlm: 140 | pad_idx = self.params.special_tok_ids["pad_token"] 141 | else: 142 | pad_idx = self.params.special_tok_ids["unk_token"] 143 | tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids] 144 | assert len(tk_) == len(token_ids) 145 | assert all(len(t) == max_seq_len_ for t in tk_) 146 | 147 | tk_t = torch.tensor(tk_) # (bs, max_seq_len_) 148 | lg_t = torch.tensor(lengths) # (bs) 149 | return tk_t, lg_t -------------------------------------------------------------------------------- /distillation_sparsification/make_student.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | from typing import List, Tuple, Union 4 | 5 | import fire 6 | from torch import nn 7 | 8 | from transformers import AutoTokenizer, PreTrainedModel, AutoModelForCausalLM 9 | from transformers.utils import logging 10 | 11 | logger = logging.get_logger(__name__) 12 | 13 | 14 | def copy_layers(src_layers: nn.ModuleList, dest_layers: nn.ModuleList, layers_to_copy: List[int]) -> None: 15 | layers_to_copy = nn.ModuleList([src_layers[i] for i in layers_to_copy]) 16 | assert len(dest_layers) == len(layers_to_copy), f"{len(dest_layers)} != {len(layers_to_copy)}" 17 | dest_layers.load_state_dict(layers_to_copy.state_dict()) 18 | 19 | 20 | LAYERS_TO_COPY = { 21 | # maps num layers in teacher -> num_layers in student -> which teacher layers to copy. 22 | 32: { 23 | 1: [0], 24 | 2: [0, 31], 25 | 3: [0, 16, 31], 26 | # 4: [0, 10, 20, 31], 27 | # 6: [0, 3, 6, 9, 12, 31], 28 | 8: [0, 4, 8, 12, 16, 20, 24, 31], 29 | # 9: [0, 1, 3, 5, 7, 9, 11, 13, 31], 30 | 12: [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 31], 31 | 16: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 31], 32 | 32: list(range(32)) 33 | }, 34 | } 35 | 36 | LAYERS_TO_SUPERVISE = { 37 | # maps num layers in student -> which teacher layers to copy. 
38 | 6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]}, 39 | 12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]}, 40 | 16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]}, 41 | 32: {8: [3, 7, 11, 15, 19, 23, 27, 31], 16: [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31]}, 42 | } 43 | 44 | def get_layers_to_supervise(n_student, n_teacher) -> List[int]: 45 | """Used or the --supervise_forward kwarg""" 46 | if n_student > n_teacher: 47 | raise ValueError(f"Cannot perform intermediate supervision for student {n_student} > teacher {n_teacher}") 48 | elif n_teacher == n_student: 49 | return list(range(n_teacher)) 50 | elif n_student == 1: 51 | return [n_teacher - 1] 52 | else: 53 | return LAYERS_TO_SUPERVISE[n_teacher][n_student] 54 | 55 | 56 | def pick_layers_to_copy(n_student, n_teacher): 57 | try: 58 | val = LAYERS_TO_COPY[n_teacher][n_student] 59 | return val 60 | except KeyError: 61 | if n_student != n_teacher: 62 | warnings.warn( 63 | f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first" 64 | f" {n_student}" 65 | ) 66 | return list(range(n_student)) 67 | 68 | 69 | def create_student_by_copying_alternating_layers( 70 | teacher: Union[str, PreTrainedModel], 71 | save_path: Union[str, Path] = "student", 72 | d: Union[int, None] = None, 73 | copy_first_teacher_layers=False, 74 | d_layers_to_copy=None, 75 | **extra_config_kwargs, 76 | ) -> Tuple[PreTrainedModel, List[int], List[int]]: 77 | """Make a student by copying alternating layers from a teacher, save it to save_path. 78 | Args: 79 | teacher: str or PreTrainedModel if str, this will call AutoModelForCausalLM.from_pretrained(teacher) before 80 | copying layers 81 | save_path: where to save the student, defaults to student directory. 82 | d: how many Decoder layers should the student have, default is fully copy of teacher 83 | copy_first_teacher_layers: [bool] dont copy alternating layers, just the first e/d. 84 | **extra_config_kwargs: extra kwargs to pass to the student, by default the teacher config is used. 85 | 86 | Returns: 87 | student: new, smaller model. (Also saves it to save_path) 88 | d_layers_to_copy: list of which teacher decoder layers were used 89 | """ 90 | _msg = "decoder_layers cannot be both None-- you would just have an identical teacher." 
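    # Illustrative usage (a sketch, not from the original file; the save_path value is an
    # assumption): build a 16-layer student from the 32-layer Falcon-7B teacher and look up
    # the matching supervision map:
    #   student, copied_layers = create_student_by_copying_alternating_layers(
    #       "tiiuae/falcon-7b", save_path="student/", d=16)
    #   matches = get_layers_to_supervise(n_student=16, n_teacher=32)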
91 | assert (d is not None), _msg 92 | if isinstance(teacher, str): 93 | AutoTokenizer.from_pretrained(teacher).save_pretrained(save_path) # purely for convenience 94 | teacher = AutoModelForCausalLM.from_pretrained(teacher).eval() 95 | else: 96 | assert isinstance(teacher, PreTrainedModel), f"teacher must be a model or string got type {type(teacher)}" 97 | init_kwargs = teacher.config.to_diff_dict() 98 | 99 | try: 100 | teacher_d = teacher.config.decoder_layers 101 | if d is None: 102 | d = teacher_d 103 | init_kwargs.update({"n_layer": d}) 104 | except AttributeError: # T5 105 | if hasattr(teacher.config, "n_layer"): 106 | teacher_d = teacher.config.n_layer 107 | else: 108 | teacher_d = teacher.config.n_layer 109 | if d is None: 110 | d = teacher_d 111 | if hasattr(teacher.config, "n_layer"): 112 | init_kwargs.update({"n_layer": d}) 113 | else: 114 | init_kwargs.update({"n_layer": d}) 115 | 116 | # Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs 117 | init_kwargs.update(extra_config_kwargs) 118 | 119 | # Copy weights 120 | student_cfg = teacher.config_class(**init_kwargs) 121 | student = AutoModelForCausalLM.from_config(student_cfg, trust_remote_code=True) 122 | # Start by copying the full teacher state dict this will copy the first N teacher layers to the student. 123 | info = student.load_state_dict(teacher.state_dict(), strict=False) 124 | assert info.missing_keys == [], info.missing_keys # every student key should have a teacher keys. 125 | 126 | if copy_first_teacher_layers: # Our copying is done. We just log and save 127 | d_layers_to_copy = list(range(d)) 128 | logger.info( 129 | f"Copied decoder layers {d_layers_to_copy}. Saving them to" 130 | f" {save_path}" 131 | ) 132 | student.save_pretrained(save_path) 133 | return student, d_layers_to_copy 134 | 135 | # Decide which layers of the teacher to copy. Not exactly alternating -- we try to keep first and last layer. 136 | if d_layers_to_copy is None: 137 | d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d) 138 | 139 | try: 140 | if hasattr( 141 | teacher, "prophetnet" 142 | ): 143 | copy_layers(teacher.prophetnet.decoder.layers, student.prophetnet.decoder.layers, d_layers_to_copy) 144 | else: 145 | copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy) 146 | except AttributeError: 147 | copy_layers(teacher.transformer.h, student.transformer.h, d_layers_to_copy) 148 | logger.info( 149 | f"Copied decoder layers {d_layers_to_copy}. 
Saving them to {save_path}" 150 | ) 151 | student.config.init_metadata = { 152 | "teacher_type": teacher.config.model_type, 153 | "copied_decoder_layers": d_layers_to_copy, 154 | } 155 | student.save_pretrained(save_path) 156 | # Save information about copying for easier reproducibility 157 | 158 | return student, d_layers_to_copy 159 | 160 | 161 | if __name__ == "__main__": 162 | fire.Fire(create_student_by_copying_alternating_layers) -------------------------------------------------------------------------------- /distillation_sparsification/modelutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | DEV = torch.device('cuda:0') 6 | 7 | 8 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): 9 | if type(module) in layers or 'Linear' in str(type(module)): 10 | return {name: module} 11 | res = {} 12 | 13 | for name1, child in module.named_children(): 14 | res.update(find_layers( 15 | child, layers=layers, name=name + '.' + name1 if name != '' else name1 16 | )) 17 | return res 18 | -------------------------------------------------------------------------------- /distillation_sparsification/process_data.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | from sparsegpt import * 9 | from modelutils import * 10 | import sys 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | from falcon7b import modelling_RW 13 | from datasets import load_dataset 14 | from make_student import * 15 | import multiprocessing 16 | from datasets import load_from_disk 17 | from itertools import chain 18 | 19 | 20 | 21 | tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True) 22 | if tokenizer.pad_token is None: 23 | tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 24 | 25 | ds = load_dataset('JeanKaddour/minipile') 26 | # print(ds.column_names) 27 | 28 | def preprocess_function(examples): 29 | return tokenizer(examples["text"], ) 30 | 31 | block_size = 2048 32 | 33 | def group_texts(examples): 34 | # Concatenate all texts. 35 | concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} 36 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 37 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 38 | # customize this part to your needs. 39 | if total_length >= block_size: 40 | total_length = (total_length // block_size) * block_size 41 | else: 42 | total_length = 0 43 | # Split by chunks of block_size. 
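    # Illustrative example with a small block_size of 4: ten concatenated token ids give
    # total_length = (10 // 4) * 4 = 8, so the comprehension below keeps two chunks per key,
    # ids[0:4] and ids[4:8], and drops the trailing two ids.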
44 | result = { 45 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 46 | for k, t in concatenated_examples.items() 47 | } 48 | result["labels"] = result["input_ids"].copy() 49 | del result['token_type_ids'] 50 | return result 51 | 52 | 53 | tokenized_minipile = ds.map( 54 | preprocess_function, 55 | batched=True, 56 | num_proc=multiprocessing.cpu_count(), 57 | remove_columns=ds['validation'].column_names, 58 | ) 59 | 60 | tokenized_minipile.save_to_disk("ds/raw") 61 | 62 | # tokenized_minipile = load_from_disk("ds/raw") 63 | 64 | ds = tokenized_minipile.map(group_texts, batched=True, num_proc=multiprocessing.cpu_count()) 65 | 66 | 67 | ds.save_to_disk("ds/processed") 68 | 69 | # ds = load_from_disk("ds/processed") 70 | # print(ds) 71 | # train_dataset = ds["train"] 72 | # eval_dataset = ds["validation"] 73 | # test_dataset = ds["test"] 74 | # print(test_dataset) 75 | # eval_dataset = eval_dataset.select( 76 | # ( 77 | # i for i in range(len(eval_dataset)) 78 | # if len(eval_dataset[i]['input_ids']) == 2048 79 | # ) 80 | # ) 81 | 82 | # test_dataset = test_dataset.select( 83 | # ( 84 | # i for i in range(len(test_dataset)) 85 | # if len(test_dataset[i]['input_ids']) == 2048 86 | # ) 87 | # ) 88 | 89 | # train_dataset = train_dataset.select( 90 | # ( 91 | # i for i in range(len(train_dataset)) 92 | # if len(train_dataset[i]['input_ids']) == 2048 93 | # ) 94 | # ) 95 | 96 | # ds["train"] = train_dataset 97 | # ds["validation"] = eval_dataset 98 | # ds["test"] = test_dataset 99 | 100 | print(ds) 101 | # print(eval_dataset) 102 | # print(test_dataset) 103 | # print(train_dataset) -------------------------------------------------------------------------------- /distillation_sparsification/quant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def quantize(x, scale, zero, maxq): 7 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 8 | return scale * (q - zero) 9 | 10 | class Quantizer(nn.Module): 11 | 12 | def __init__(self, shape=1): 13 | super(Quantizer, self).__init__() 14 | self.register_buffer('maxq', torch.tensor(0)) 15 | self.register_buffer('scale', torch.zeros(shape)) 16 | self.register_buffer('zero', torch.zeros(shape)) 17 | 18 | def configure( 19 | self, 20 | bits, perchannel=False, sym=True, 21 | mse=False, norm=2.4, grid=100, maxshrink=.8, 22 | grouprows=1 23 | ): 24 | self.maxq = torch.tensor(2 ** bits - 1) 25 | self.perchannel = perchannel 26 | self.sym = sym 27 | self.mse = mse 28 | self.norm = norm 29 | self.grid = grid 30 | self.maxshrink = maxshrink 31 | self.grouprows = grouprows 32 | 33 | def find_params(self, x, weight=False): 34 | dev = x.device 35 | self.maxq = self.maxq.to(dev) 36 | 37 | shape = x.shape 38 | if self.perchannel: 39 | if weight: 40 | x = x.flatten(1) 41 | if self.grouprows > 1: 42 | x = x.reshape((x.shape[0] // self.grouprows, -1)) 43 | else: 44 | if len(shape) == 4: 45 | x = x.permute([1, 0, 2, 3]) 46 | x = x.flatten(1) 47 | if len(shape) == 3: 48 | x = x.reshape((-1, shape[-1])).t() 49 | if len(shape) == 2: 50 | x = x.t() 51 | else: 52 | x = x.flatten().unsqueeze(0) 53 | 54 | tmp = torch.zeros(x.shape[0], device=dev) 55 | xmin = torch.minimum(x.min(1)[0], tmp) 56 | xmax = torch.maximum(x.max(1)[0], tmp) 57 | 58 | if self.sym: 59 | xmax = torch.maximum(torch.abs(xmin), xmax) 60 | tmp = xmin < 0 61 | if torch.any(tmp): 62 | xmin[tmp] = -xmax[tmp] 63 | tmp = (xmin == 0) & (xmax == 0) 64 | xmin[tmp] = -1 65 | 
xmax[tmp] = +1 66 | 67 | self.scale = (xmax - xmin) / self.maxq 68 | if self.sym: 69 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 70 | else: 71 | self.zero = torch.round(-xmin / self.scale) 72 | 73 | if self.mse: 74 | best = torch.full([x.shape[0]], float('inf'), device=dev) 75 | for i in range(int(self.maxshrink * self.grid)): 76 | p = 1 - i / self.grid 77 | xmin1 = p * xmin 78 | xmax1 = p * xmax 79 | scale1 = (xmax1 - xmin1) / self.maxq 80 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 81 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 82 | q -= x 83 | q.abs_() 84 | q.pow_(self.norm) 85 | err = torch.sum(q, 1) 86 | tmp = err < best 87 | if torch.any(tmp): 88 | best[tmp] = err[tmp] 89 | self.scale[tmp] = scale1[tmp] 90 | self.zero[tmp] = zero1[tmp] 91 | if not self.perchannel: 92 | if weight: 93 | tmp = shape[0] 94 | else: 95 | tmp = shape[1] if len(shape) != 3 else shape[2] 96 | self.scale = self.scale.repeat(tmp) 97 | self.zero = self.zero.repeat(tmp) 98 | 99 | if weight: 100 | if self.grouprows > 1: 101 | self.scale = self.scale.unsqueeze(1).repeat(1, self.grouprows) 102 | self.zero = self.zero.unsqueeze(1).repeat(1, self.grouprows) 103 | shape = [-1] + [1] * (len(shape) - 1) 104 | self.scale = self.scale.reshape(shape) 105 | self.zero = self.zero.reshape(shape) 106 | return 107 | if len(shape) == 4: 108 | self.scale = self.scale.reshape((1, -1, 1, 1)) 109 | self.zero = self.zero.reshape((1, -1, 1, 1)) 110 | if len(shape) == 3: 111 | self.scale = self.scale.reshape((1, 1, -1)) 112 | self.zero = self.zero.reshape((1, 1, -1)) 113 | if len(shape) == 2: 114 | self.scale = self.scale.unsqueeze(0) 115 | self.zero = self.zero.unsqueeze(0) 116 | 117 | def quantize(self, x): 118 | if self.ready(): 119 | return quantize(x, self.scale, self.zero, self.maxq) 120 | return x 121 | 122 | def enabled(self): 123 | return self.maxq > 0 124 | 125 | def ready(self): 126 | return torch.all(self.scale != 0) 127 | -------------------------------------------------------------------------------- /distillation_sparsification/sparsegpt.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | from quant import * 9 | 10 | 11 | DEBUG = False 12 | 13 | torch.backends.cuda.matmul.allow_tf32 = False 14 | torch.backends.cudnn.allow_tf32 = False 15 | 16 | 17 | class SparseGPT: 18 | 19 | def __init__(self, layer): 20 | self.layer = layer 21 | self.dev = self.layer.weight.device 22 | W = layer.weight.data.clone() 23 | if isinstance(self.layer, nn.Conv2d): 24 | W = W.flatten(1) 25 | if isinstance(self.layer, transformers.Conv1D): 26 | W = W.t() 27 | self.rows = W.shape[0] 28 | self.columns = W.shape[1] 29 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 30 | self.nsamples = 0 31 | 32 | def add_batch(self, inp, out, blocksize=1024): 33 | if DEBUG: 34 | self.inp1 = inp 35 | self.out1 = out 36 | if len(inp.shape) == 2: 37 | inp = inp.unsqueeze(0) 38 | tmp = inp.shape[0] 39 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): 40 | if len(inp.shape) == 3: 41 | inp = inp.reshape((-1, inp.shape[-1])) 42 | inp = inp.t() 43 | self.H *= self.nsamples / (self.nsamples + tmp) 44 | self.nsamples += tmp 45 | inp = math.sqrt(2 / self.nsamples) * inp.float() 46 | self.H += inp.matmul(inp.t()) 47 | 48 | def fasterprune( 49 | self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01 
50 | ): 51 | W = self.layer.weight.data.clone() 52 | if isinstance(self.layer, nn.Conv2d): 53 | W = W.flatten(1) 54 | if isinstance(self.layer, transformers.Conv1D): 55 | W = W.t() 56 | W = W.float() 57 | 58 | if hasattr(self, 'quantizer'): 59 | if not self.quantizer.ready(): 60 | self.quantizer.find_params(W, weight=True) 61 | 62 | tick = time.time() 63 | 64 | H = self.H 65 | del self.H 66 | dead = torch.diag(H) == 0 67 | H[dead, dead] = 1 68 | W[:, dead] = 0 69 | 70 | Losses = torch.zeros(self.rows, device=self.dev) 71 | 72 | damp = percdamp * torch.mean(torch.diag(H)) 73 | diag = torch.arange(self.columns, device=self.dev) 74 | H[diag, diag] += damp 75 | H = torch.linalg.cholesky(H) 76 | H = torch.cholesky_inverse(H) 77 | H = torch.linalg.cholesky(H, upper=True) 78 | Hinv = H 79 | 80 | mask = None 81 | 82 | for i1 in range(0, self.columns, blocksize): 83 | i2 = min(i1 + blocksize, self.columns) 84 | count = i2 - i1 85 | 86 | W1 = W[:, i1:i2].clone() 87 | Q1 = torch.zeros_like(W1) 88 | Err1 = torch.zeros_like(W1) 89 | Losses1 = torch.zeros_like(W1) 90 | Hinv1 = Hinv[i1:i2, i1:i2] 91 | 92 | if prunen == 0: 93 | if mask is not None: 94 | mask1 = mask[:, i1:i2] 95 | else: 96 | tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2 97 | thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)] 98 | mask1 = tmp <= thresh 99 | else: 100 | mask1 = torch.zeros_like(W1) == 1 101 | 102 | for i in range(count): 103 | w = W1[:, i] 104 | d = Hinv1[i, i] 105 | 106 | if prunen != 0 and i % prunem == 0: 107 | tmp = W1[:, i:(i + prunem)] ** 2 / (torch.diag(Hinv1)[i:(i + prunem)].reshape((1, -1))) ** 2 108 | mask1.scatter_(1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True) 109 | 110 | q = w.clone() 111 | q[mask1[:, i]] = 0 112 | 113 | if hasattr(self, 'quantizer'): 114 | q = quantize( 115 | q.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq 116 | ).flatten() 117 | 118 | Q1[:, i] = q 119 | Losses1[:, i] = (w - q) ** 2 / d ** 2 120 | 121 | err1 = (w - q) / d 122 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 123 | Err1[:, i] = err1 124 | 125 | W[:, i1:i2] = Q1 126 | Losses += torch.sum(Losses1, 1) / 2 127 | 128 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 129 | 130 | if DEBUG: 131 | self.layer.weight.data[:, :i2] = W[:, :i2] 132 | self.layer.weight.data[:, i2:] = W[:, i2:] 133 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 134 | print(torch.sum(Losses)) 135 | 136 | torch.cuda.synchronize() 137 | print('time %.2f' % (time.time() - tick)) 138 | print('error', torch.sum(Losses).item()) 139 | 140 | if isinstance(self.layer, transformers.Conv1D): 141 | W = W.t() 142 | self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 143 | if DEBUG: 144 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 145 | 146 | def free(self): 147 | if DEBUG: 148 | self.inp1 = None 149 | self.out1 = None 150 | self.H = None 151 | torch.cuda.empty_cache() 152 | -------------------------------------------------------------------------------- /distillation_sparsification/test.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | from sparsegpt import * 9 | from modelutils import * 10 | import sys 11 | #from transformers import AutoTokenizer, AutoModelForCausalLM 12 | from transformers import ( 13 | CONFIG_MAPPING, 14 | MODEL_MAPPING, 15 | AutoConfig, 16 | 
AutoModelForCausalLM, 17 | AutoTokenizer, 18 | SchedulerType, 19 | default_data_collator, 20 | get_scheduler, 21 | ) 22 | 23 | from falcon7b import modelling_RW 24 | from datasets import load_dataset 25 | from make_student import * 26 | import multiprocessing 27 | from transformers import DataCollatorForLanguageModeling 28 | from distill import DistillationTrainer 29 | from datasets import load_from_disk 30 | from torch.utils.data import DataLoader 31 | 32 | from accelerate import Accelerator, DistributedType, notebook_launcher 33 | from accelerate.logging import get_logger 34 | from accelerate.utils import set_seed 35 | from tqdm.auto import tqdm 36 | 37 | from transformers import Trainer, TrainingArguments 38 | import deepspeed 39 | 40 | ds = load_from_disk("processed_ds/processed") 41 | print(ds) 42 | train_dataset = ds["train"] 43 | eval_dataset = ds["validation"] 44 | 45 | 46 | eval_dataset = eval_dataset.select((i for i in range(len(eval_dataset)-1))) 47 | 48 | for k,v in eval_dataset[-1].items(): 49 | print(k, torch.tensor(v).shape) 50 | # for k,v in eval_dataset[-1].items(): 51 | # print(k, v.shape) 52 | 53 | 54 | # tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True) 55 | # data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 56 | # if tokenizer.pad_token is None: 57 | # tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 58 | 59 | # model = AutoModelForCausalLM.from_pretrained('student/', trust_remote_code=True, torch_dtype=torch.bfloat16) 60 | 61 | # # ds_config = { 62 | # # "tensor_parallel": {"tp_size": 1}, 63 | # # "dtype": "fp16", 64 | # # "replace_with_kernel_inject": True, 65 | # # "replace_method": "auto", 66 | # # } 67 | 68 | # # ds_model = deepspeed.init_inference(model=model, config=ds_config) 69 | 70 | # per_device_train_batch_size = 2 71 | # per_device_eval_batch_size = 2 72 | # learning_rate = 3e-4 73 | # lr_scheduler_type = "cosine" 74 | # num_warmup_steps = 100 75 | # max_train_steps = 1_000 76 | # num_train_epochs = 1 77 | # weight_decay = 0.01 78 | # gradient_accumulation_steps = 1 79 | # output_dir = 'log/' 80 | # with_tracking = True 81 | # report_to = "tensorboard" 82 | 83 | # accelerator_log_kwargs = {} 84 | 85 | # if with_tracking: 86 | # accelerator_log_kwargs["log_with"] = report_to 87 | # accelerator_log_kwargs["project_dir"] = output_dir 88 | 89 | # accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, log_with="tensorboard", project_dir="log") 90 | 91 | # train_dataloader = DataLoader( 92 | # train_dataset, shuffle=True, collate_fn=data_collator, batch_size=per_device_train_batch_size 93 | # ) 94 | # eval_dataloader = DataLoader( 95 | # eval_dataset, collate_fn=data_collator, batch_size=per_device_eval_batch_size 96 | # ) 97 | 98 | # no_decay = ["bias", "layer_norm.weight"] 99 | # optimizer_grouped_parameters = [ 100 | # { 101 | # "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 102 | # "weight_decay": weight_decay, 103 | # }, 104 | # { 105 | # "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 106 | # "weight_decay": 0.0, 107 | # }, 108 | # ] 109 | # optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate) 110 | 111 | # overrode_max_train_steps = False 112 | # num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) 113 | # if max_train_steps is None: 114 | # max_train_steps = num_train_epochs * num_update_steps_per_epoch 115 | # 
overrode_max_train_steps = True 116 | 117 | # lr_scheduler = get_scheduler( 118 | # name=lr_scheduler_type, 119 | # optimizer=optimizer, 120 | # num_warmup_steps=num_warmup_steps * gradient_accumulation_steps, 121 | # num_training_steps=max_train_steps * gradient_accumulation_steps, 122 | # ) 123 | 124 | # model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( 125 | # model, optimizer, train_dataloader, eval_dataloader, lr_scheduler 126 | # ) 127 | 128 | # experiment_config = { 129 | # "num_iterations": 1, 130 | # "learning_rate": 3e-4, 131 | # } 132 | 133 | 134 | # accelerator.init_trackers("clm_no_trainer", experiment_config) 135 | 136 | # # We need to recalculate our total training steps as the size of the training dataloader may have changed. 137 | # num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) 138 | # if overrode_max_train_steps: 139 | # max_train_steps = num_train_epochs * num_update_steps_per_epoch 140 | # # Afterwards we recalculate our number of training epochs 141 | # num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) 142 | 143 | # # Train! 144 | # total_batch_size = per_device_train_batch_size * accelerator.num_processes * gradient_accumulation_steps 145 | 146 | # progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process) 147 | # completed_steps = 0 148 | # starting_epoch = 0 149 | 150 | # # update the progress_bar if load from checkpoint 151 | # progress_bar.update(completed_steps) 152 | # print("training started") 153 | # for epoch in range(starting_epoch, num_train_epochs): 154 | # print("Train!") 155 | # model.train() 156 | # if with_tracking: 157 | # total_loss = 0 158 | 159 | # active_dataloader = train_dataloader 160 | # for step, batch in enumerate(active_dataloader): 161 | # with accelerator.accumulate(model): 162 | # batch.pop("token_type_ids") 163 | # outputs = model(**batch) 164 | # loss = outputs.loss 165 | # # We keep track of the loss at each epoch 166 | # if with_tracking: 167 | # total_loss += loss.detach().float() 168 | # accelerator.backward(loss) 169 | # optimizer.step() 170 | # lr_scheduler.step() 171 | # optimizer.zero_grad() 172 | # accelerator.log({"training_loss": loss}, step=step) 173 | 174 | # # Checks if the accelerator has performed an optimization step behind the scenes 175 | # if accelerator.sync_gradients: 176 | # progress_bar.update(1) 177 | # completed_steps += 1 178 | 179 | # if completed_steps >= max_train_steps: 180 | # break 181 | 182 | # model.eval() 183 | # losses = [] 184 | # for step, batch in enumerate(eval_dataloader): 185 | # with torch.no_grad(): 186 | # batch.pop("token_type_ids") 187 | # outputs = model(**batch) 188 | 189 | # loss = outputs.loss 190 | # losses.append(accelerator.gather_for_metrics(loss.repeat(per_device_eval_batch_size))) 191 | 192 | # losses = torch.cat(losses) 193 | # try: 194 | # eval_loss = torch.mean(losses) 195 | # perplexity = math.exp(eval_loss) 196 | # except OverflowError: 197 | # perplexity = float("inf") 198 | 199 | # logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}") 200 | # if with_tracking: 201 | # accelerator.log( 202 | # { 203 | # "perplexity": perplexity, 204 | # "eval_loss": eval_loss, 205 | # "train_loss": total_loss.item() / len(train_dataloader), 206 | # "epoch": epoch, 207 | # "step": completed_steps, 208 | # }, 209 | # step=completed_steps, 210 | # ) 211 | # print(f"epoch: {epoch}") 212 | # print(f"eval_loss: {eval_loss}") 213 | 
# print(f"train_loss: {total_loss.item() / len(train_dataloader)}") 214 | # print(f"perplexity: {perplexity}") 215 | # accelerator.end_training() 216 | 217 | 218 | 219 | -------------------------------------------------------------------------------- /distillation_sparsification/test1.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | from sparsegpt import * 9 | from modelutils import * 10 | import sys 11 | #from transformers import AutoTokenizer, AutoModelForCausalLM 12 | from transformers import ( 13 | CONFIG_MAPPING, 14 | MODEL_MAPPING, 15 | AutoConfig, 16 | AutoModelForCausalLM, 17 | AutoTokenizer, 18 | SchedulerType, 19 | default_data_collator, 20 | get_scheduler, 21 | ) 22 | 23 | from falcon7b import modelling_RW 24 | from datasets import load_dataset 25 | from make_student import * 26 | import multiprocessing 27 | from transformers import DataCollatorForLanguageModeling 28 | from distill import DistillationTrainer 29 | from datasets import load_from_disk 30 | from torch.utils.data import DataLoader 31 | 32 | from accelerate import Accelerator, DistributedType, notebook_launcher 33 | from accelerate.logging import get_logger 34 | from accelerate.utils import set_seed 35 | from tqdm.auto import tqdm 36 | 37 | from transformers import Trainer, TrainingArguments 38 | import deepspeed 39 | 40 | device = torch.device('cuda:3') 41 | ds = load_from_disk("processed_ds/processed") 42 | print(ds) 43 | train_dataset = ds["test"] 44 | eval_dataset = ds["validation"] 45 | 46 | tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True) 47 | print(tokenizer) 48 | 49 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 50 | if tokenizer.pad_token is None: 51 | tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 52 | t_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device) 53 | print(t_model.config) 54 | t_model.config.output_hidden_states=True 55 | 56 | # model = AutoModelForCausalLM.from_pretrained('student/', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device) 57 | model, layer_ids = create_student_by_copying_alternating_layers(t_model, d=8, save_path='student1/') 58 | print(model) 59 | print(layer_ids) 60 | # pytorch_total_params = sum(p.numel() for p in model.parameters()) 61 | 62 | # print(pytorch_total_params) 63 | # train_dataloader = DataLoader( 64 | # train_dataset, shuffle=True, collate_fn=data_collator, batch_size=1 65 | # ) 66 | 67 | # batch = next(iter(train_dataloader)) 68 | # inputs_id = batch['input_ids'].to(device) 69 | # print(inputs_id.shape) 70 | 71 | # outputs = t_model(input_ids=inputs_id, attention_mask=None) 72 | # print(list(outputs.keys())) -------------------------------------------------------------------------------- /distillation_sparsification/tracker.py: -------------------------------------------------------------------------------- 1 | from accelerate.tracking import GeneralTracker, on_main_process 2 | from typing import Optional 3 | 4 | import wandb 5 | 6 | 7 | class MyCustomTracker(GeneralTracker): 8 | name = "wandb" 9 | requires_logging_directory = False 10 | 11 | 12 | def __init__(self, run_name: str): 13 | self.run_name = run_name 14 | run = wandb.init(self.run_name) 15 | 16 | 17 | def tracker(self): 18 | return self.run.run 19 | 20 | 21 | def store_init_configuration(self, 
values: dict): 22 | wandb.config(values) 23 | 24 | 25 | def log(self, values: dict, step: Optional[int] = None): 26 | wandb.log(values, step=step) -------------------------------------------------------------------------------- /distillation_sparsification/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import linecache 4 | import math 5 | import os 6 | import pickle 7 | import socket 8 | from logging import getLogger 9 | from pathlib import Path 10 | from typing import Callable, Dict, Iterable, List, Tuple, Union 11 | 12 | import git 13 | import numpy as np 14 | import torch 15 | import torch.distributed as dist 16 | from torch import nn 17 | from torch.utils.data import Dataset, Sampler 18 | 19 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): 20 | """From fairseq""" 21 | if target.dim() == lprobs.dim() - 1: 22 | target = target.unsqueeze(-1) 23 | nll_loss = -lprobs.gather(dim=-1, index=target) 24 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 25 | if ignore_index is not None: 26 | pad_mask = target.eq(ignore_index) 27 | nll_loss.masked_fill_(pad_mask, 0.0) 28 | smooth_loss.masked_fill_(pad_mask, 0.0) 29 | else: 30 | nll_loss = nll_loss.squeeze(-1) 31 | smooth_loss = smooth_loss.squeeze(-1) 32 | 33 | nll_loss = nll_loss.sum() # mean()? Scared to break other math. 34 | smooth_loss = smooth_loss.sum() 35 | eps_i = epsilon / lprobs.size(-1) 36 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss 37 | return loss, nll_loss 38 | 39 | def freeze_params(model: nn.Module): 40 | """Set requires_grad=False for each of model.parameters()""" 41 | for par in model.parameters(): 42 | par.requires_grad = False 43 | 44 | 45 | def calc_ce_loss(attention_mask, lm_labels, s_logits, t_logits, temperature, restrict_ce_to_mask): 46 | """Copy pasted from distillbert (transformers/examples/distillation/)""" 47 | # mask has False at padding_idx 48 | ce_loss_fct = nn.KLDivLoss(reduction="batchmean") 49 | if restrict_ce_to_mask: 50 | mask = (lm_labels > -1) # (bs, seq_length, voc_size) 51 | else: 52 | mask = attention_mask # (bs, seq_length, voc_size) 53 | 54 | mask = mask.unsqueeze(-1).expand_as(s_logits) 55 | mask = torch.gt(mask, 0) 56 | s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask 57 | s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask 58 | t_logits_slct = torch.masked_select(t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask 59 | t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask 60 | assert t_logits_slct.size() == s_logits_slct.size() 61 | 62 | loss_ce = ( 63 | ce_loss_fct( 64 | nn.functional.log_softmax(s_logits_slct / temperature, dim=-1), 65 | nn.functional.softmax(t_logits_slct / temperature, dim=-1), 66 | ) 67 | * (temperature) ** 2 68 | ) 69 | return loss_ce 70 | 71 | 72 | def calc_hidden_loss(attention_mask, lm_labels,hidden_states, hidden_states_T, matches, normalize_hidden, restrict_ce_to_mask): 73 | """MSE(student_hid, teacher_hid[matches]). Called "Intermediate supervision" in paper. 
Inspired by TinyBERT.""" 74 | msg = "expected list or tuple for hidden_states, got tensor of shape: " 75 | assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}" 76 | assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}" 77 | if restrict_ce_to_mask: 78 | mask = lm_labels > -1 # (bs, seq_length, voc_size) 79 | else: 80 | mask = attention_mask # (bs, seq_length, voc_size) 81 | 82 | mask = mask.to(hidden_states[0]) 83 | valid_count = mask.sum() * hidden_states[0].size(-1) 84 | s_states = torch.stack([hidden_states[i] for i in range(len(matches))]) 85 | t_states = torch.stack([hidden_states_T[j] for j in matches]) 86 | assert s_states.shape == t_states.shape, f"{s_states.shape} != {t_states.shape}" 87 | if normalize_hidden: 88 | s_states = nn.functional.layer_norm(s_states, s_states.shape[1:]) 89 | t_states = nn.functional.layer_norm(t_states, t_states.shape[1:]) 90 | mse = nn.functional.mse_loss(s_states, t_states, reduction="none") 91 | masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count 92 | return masked_mse 93 | 94 | 95 | def eval(s_model, ds, with_tracking=False): 96 | s_model.eval() 97 | lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100) 98 | losses = [] 99 | for step, batch in enumerate(ds): 100 | with torch.no_grad(): 101 | input_ids, attn_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"] 102 | lm_labels = batch["input_ids"] 103 | 104 | s_outputs = s_model( 105 | input_ids, 106 | attention_mask=None, 107 | ) 108 | 109 | s_logits, s_hidden_states = s_outputs["logits"], s_outputs["hidden_states"] 110 | shift_logits = s_logits[..., :-1, :].contiguous() 111 | shift_labels = lm_labels[..., 1:].contiguous() 112 | loss = lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 113 | 114 | losses.append(accelerator.gather_for_metrics(loss.repeat(per_device_eval_batch_size))) 115 | 116 | losses = torch.cat(losses) 117 | try: 118 | eval_loss = torch.mean(losses) 119 | perplexity = math.exp(eval_loss) 120 | except OverflowError: 121 | perplexity = float("inf") 122 | 123 | logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}") 124 | if with_tracking: 125 | accelerator.log( 126 | { 127 | "perplexity": perplexity, 128 | "eval_loss": eval_loss, 129 | "step": completed_steps, 130 | "train_loss": total_loss.item() / len(train_dataloader), 131 | }, 132 | step=completed_steps, 133 | ) 134 | accelerator.print(f"eval_loss: {eval_loss}") 135 | accelerator.print(f"perplexity: {perplexity}") 136 | accelerator.print(f"train_loss: {total_loss.item() / len(train_dataloader)}") -------------------------------------------------------------------------------- /docs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huu4ontocord/MDEL/d84a598e765accfb723edd58f6c0a426d8c16d8d/docs/.gitkeep -------------------------------------------------------------------------------- /lora-x/README.md: -------------------------------------------------------------------------------- 1 | # Lora-X - A library to support long context multimodal, multilingual, and multi-domain training 2 | 3 | Lora-X is intended to create a base model for MDEL. We are aiming at llama2 architecture. 
4 | 5 | - TODO: confirm A100 40GB training at 70B with 8K context on JUWELS 6 | - TODO: work on ways to increase length given the memory constraint 7 | + test different LoRA configs and methods for improving performance 8 | - TODO: use the NTK scaling in the HF implementation 9 | - TODO: Add LLavaR code - in particular the projection layer 10 | - TODO: integrate bpt and test the speed difference 11 | - TODO: add Hummingbird's dynamic token and media loading during training 12 | - TODO: add proxy embeddings (m-clip) 13 | 14 | # Credits 15 | This library is based on the code of participants in MDEL: 16 | - https://github.com/jordiclive/scaled-rope (which is in turn based on jquesnelle/scaled-rope and Open Assistant code) 17 | - https://github.com/arnavdantuluri/long-context-transformers 18 | - reference for Hummingbird 19 | 20 | 21 | # Multinode LoRA + Flash + DS 22 | This project adds LoRA, a Flash Attention patch and DS (DeepSpeed) to [Scaled Rope](https://github.com/jquesnelle/scaled-rope) so that it can be run multi-node. 23 | 24 | Flash Attention 2 and LLaMA 2 ready 🚀 25 | 26 | ## Setup and Installation 27 | 28 | 1. To install the necessary dependencies, use pip to install the required Python packages: 29 | 30 | ``` 31 | pip install -r requirements.txt 32 | ``` 33 | 34 | 2. Update the `configs/config.yaml` file as per your requirements. 35 | 36 | ## Usage 37 | 38 | To run the application, use the following command: 39 | 40 | ``` 41 | python finetune.py --configs defaults 42 | ``` 43 | 44 | Replace `defaults` with the specific configuration(s) defined in `configs/config.yaml`. Command-line arguments can also be overridden. 45 | 46 | **Please Note:** This uses the recent HF PR, so models are HF compatible. The linear scaling argument is `interpolation_factor`, i.e. how much you want to scale the model. If set to None, it will scale by `config.max_position_embeddings / 4096`, since 4096 is the default for LLaMA 2. 47 | 48 | 49 | ## Data 50 | - Specify packed, untokenized datasets on the hub under `dataset_names`, e.g. (`Multi-Domain-Expert-Layers/the_pile_books3_packed_128k`) 51 | - If `pretokenized=True`, specify a single pre-tokenized dataset on the hub under `dataset_names` (`conceptofmind/rp-packed-32k-no-filter` for OpenLLaMA) 52 | 53 | ### Running on a Single Node 54 | 55 | Use the following commands for running on a single node: 56 | 57 | 1. Export the necessary paths: 58 | 59 | ``` 60 | export PYTHONPATH="/mnt/data/jordiclive/scaled-rope:$PYTHONPATH" 61 | export TRANSFORMERS_CACHE="/mnt/data/jordiclive/transformers_cache" 62 | export HF_DATASETS_CACHE="/mnt/data/jordiclive/transformers_cache" 63 | export HF_HOME="/mnt/data/jordiclive/transformers_cache" 64 | export WANDB_API_KEY="" 65 | ``` 66 | 67 | 2. 
Run the script: 68 | 69 | ``` 70 | deepspeed --include=localhost:0,1,2,3,4,5,6,7 --master_port 61500 finetune.py --output_dir saved_ckpts_32k --configs defaults lora-7b-llama2 --deepspeed 71 | ``` 72 | 73 | ### Running on Multiple Nodes 74 | 75 | Example script using the slurm launcher with deepspeed: `scripts/juwels_booster.sh` 76 | -------------------------------------------------------------------------------- /lora-x/bpt_pt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.checkpoint import checkpoint 3 | import torch.nn as nn 4 | from transformers.activations import ACT2FN 5 | 6 | class GPTNeoXMLP(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.dense_h_to_4h = nn.Linear(512, 2048) 10 | self.dense_4h_to_h = nn.Linear(2048, 512) 11 | self.act = ACT2FN["gelu"] 12 | 13 | def forward(self, hidden_states): 14 | hidden_states = self.dense_h_to_4h(hidden_states) 15 | hidden_states = self.act(hidden_states) 16 | hidden_states = self.dense_4h_to_h(hidden_states) 17 | return hidden_states 18 | 19 | import torch 20 | from functools import partial 21 | from torch import nn, einsum 22 | from torch.utils.checkpoint import checkpoint 23 | import torch.nn.functional as F 24 | from .bpt_triton import matmul, add 25 | 26 | from torch.nn.functional import scaled_dot_product_attention 27 | from einops import rearrange 28 | from datetime import datetime 29 | # helper functions 30 | 31 | def exists(val): 32 | return val is not None 33 | 34 | def default(val, d): 35 | return val if exists(val) else d 36 | 37 | # regular attention 38 | 39 | def attention( 40 | q, k, v, 41 | mask = None, 42 | causal = False, 43 | attn_bias = None, 44 | **kwargs 45 | ): 46 | scale = q.shape[-1] ** -0.5 47 | q = q * scale 48 | 49 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 50 | 51 | if exists(attn_bias): 52 | sim = sim + attn_bias 53 | 54 | mask_value = -torch.finfo(sim.dtype).max 55 | 56 | if exists(mask): 57 | mask = rearrange(mask, 'b j -> b 1 1 j') 58 | sim = sim.masked_fill(~mask, mask_value) 59 | 60 | if causal: 61 | i, j = sim.shape[-2:] 62 | mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1) 63 | sim = sim.masked_fill(mask, mask_value) 64 | 65 | sim = sim - sim.amax(dim = -1, keepdim = True).detach() 66 | attn = sim.softmax(dim = -1) 67 | 68 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 69 | return out 70 | 71 | # memory efficient attention 72 | 73 | def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout): 74 | q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device 75 | 76 | weight = einsum('b h i d, b h j d -> b h i j', q, k) 77 | 78 | if exists(attn_bias_chunk): 79 | weight = weight + attn_bias_chunk 80 | 81 | mask_value = -torch.finfo(weight.dtype).max 82 | 83 | if exists(mask): 84 | mask = rearrange(mask, 'b j -> b 1 1 j') 85 | weight = weight.masked_fill(~mask, mask_value) 86 | 87 | if causal and q_start_index < (k_start_index + k_chunk_size - 1): 88 | causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1) 89 | weight = weight.masked_fill(causal_mask, mask_value) 90 | 91 | weight_max = weight.amax(dim = -1, keepdim = True).detach() 92 | weight = weight - weight_max 93 | 94 | exp_weight = weight.exp() 95 | 96 | exp_weight = F.dropout(exp_weight, p = dropout) 97 | 98 | weighted_value = einsum('b h i j, b h j 
d -> b h i d', exp_weight, v) 99 | 100 | return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...') 101 | 102 | checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk) 103 | 104 | def memory_efficient_attention( 105 | q, k, v, 106 | mask = None, 107 | causal = False, 108 | attn_bias = None, 109 | q_bucket_size = 512, 110 | k_bucket_size = 1024, 111 | eps = 1e-8, 112 | dropout = 0., 113 | training = False 114 | ): 115 | scale = q.shape[-1] ** -0.5 116 | q = q * scale 117 | 118 | # function 119 | 120 | needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad 121 | summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk 122 | 123 | # chunk all the inputs 124 | 125 | q_chunks = q.split(q_bucket_size, dim = -2) 126 | k_chunks = k.split(k_bucket_size, dim = -2) 127 | v_chunks = v.split(k_bucket_size, dim = -2) 128 | mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks)) 129 | 130 | if exists(attn_bias): 131 | i, j = attn_bias.shape[-2:] 132 | attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2) 133 | attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks)) 134 | 135 | # loop through all chunks and accumulate 136 | 137 | out = [] 138 | for q_index, q_chunk in enumerate(q_chunks): 139 | exp_weights = [] 140 | weighted_values = [] 141 | weight_maxes = [] 142 | 143 | for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)): 144 | q_start_index = q_index * q_bucket_size 145 | k_start_index = k_index * k_bucket_size 146 | 147 | if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1): 148 | # if chunk is to be all masked out causally, skip 149 | continue 150 | 151 | attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None 152 | 153 | exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn( 154 | q_chunk, 155 | k_chunk, 156 | v_chunk, 157 | mask_chunk, 158 | attn_bias_chunk, 159 | causal, 160 | (q_start_index, k_start_index), 161 | dropout if training else 0. 162 | ) 163 | 164 | exp_weights.append(exp_weight_chunk) 165 | weighted_values.append(weighted_value_chunk) 166 | weight_maxes.append(weight_max_chunk) 167 | 168 | weight_maxes = torch.stack(weight_maxes, dim = -1) 169 | 170 | weighted_values = torch.stack(weighted_values, dim = -1) 171 | exp_weights = torch.stack(exp_weights, dim = -1) 172 | 173 | global_max = weight_maxes.amax(dim = -1, keepdim = True) 174 | renorm_factor = (weight_maxes - global_max).exp().detach() 175 | 176 | exp_weights = exp_weights * renorm_factor 177 | weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c') 178 | 179 | all_values = weighted_values.sum(dim = -1) 180 | all_weights = exp_weights.sum(dim = -1) 181 | 182 | normalized_values = all_values / (rearrange(all_weights, '... -> ... 
1') + eps) 183 | out.append(normalized_values) 184 | 185 | return torch.cat(out, dim = -2) 186 | 187 | def blockwise_compute_ffn(cell, inputs, chunk_size): 188 | inputs = torch.split(inputs, chunk_size, dim=1) 189 | num_q = len(inputs) 190 | 191 | def ffn(cell, _, hidden_states): 192 | outputs = cell(hidden_states) 193 | return outputs 194 | 195 | outputs = [] 196 | for i in range(num_q): 197 | outputs.append(ffn(cell, None, inputs[i])) 198 | 199 | res = torch.concat(outputs, dim=1) 200 | # res = rearrange(res, 'n b c d -> b (n c) d') 201 | return res 202 | 203 | if __name__ == "__main__": 204 | # Blocked mem stuff 205 | q = torch.rand(2, 512, 16, 128) 206 | k = torch.rand(2, 512, 16, 128) 207 | v = torch.rand(2, 512, 16, 128) 208 | bias = torch.rand(2, 1, 512, 2048) 209 | 210 | # Blocked FFN Stuff 211 | x = torch.rand(2, 256, 512) 212 | cell = GPTNeoXMLP() 213 | startTime = datetime.now() 214 | y_pt_mem = memory_efficient_attention(q, k, v, q_bucket_size=512, k_bucket_size=512) 215 | print('pythonic mem eff attn', datetime.now() - startTime) 216 | 217 | torch.backends.cuda.sdp_kernel(True) 218 | torch.backends.cuda.enable_flash_sdp(True) 219 | startTime = datetime.now() 220 | y_pt_mem = scaled_dot_product_attention(q, k, v) 221 | print('pythonic mem eff attn', datetime.now() - startTime) 222 | 223 | startTime = datetime.now() 224 | y_pt_ffn = blockwise_compute_ffn(cell, x, 256) 225 | print('pythonic blocked ffn', datetime.now() - startTime) 226 | 227 | startTime = datetime.now() 228 | y_pt_ffn = blockwise_compute_ffn_triton(cell, x) 229 | print('pythonic blocked ffn', datetime.now() - startTime) 230 | print(y_pt_ffn.shape) 231 | -------------------------------------------------------------------------------- /lora-x/bpt_triton.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import triton 4 | import triton.language as tl 5 | from datetime import datetime 6 | 7 | @triton.jit 8 | def matmul_kernel( 9 | # Pointers to matrices 10 | a_ptr, b_ptr, c_ptr, 11 | # Matrix dimensions 12 | M, N, K, 13 | # The stride variables represent how much to increase the ptr by when moving by 1 14 | # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` 15 | # by to get the element one row down (A has M rows). 16 | stride_am, stride_ak, 17 | stride_bk, stride_bn, 18 | stride_cm, stride_cn, 19 | # Meta-parameters 20 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, 21 | GROUP_SIZE_M: tl.constexpr, 22 | ): 23 | """Kernel for computing the matmul C = A x B. 24 | A has shape (M, K), B has shape (K, N) and C has shape (M, N) 25 | """ 26 | # ----------------------------------------------------------- 27 | # Map program ids `pid` to the block of C it should compute. 28 | # This is done in a grouped ordering to promote L2 data reuse. 29 | # See above `L2 Cache Optimizations` section for details. 30 | pid = tl.program_id(axis=0) 31 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 32 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 33 | num_pid_in_group = GROUP_SIZE_M * num_pid_n 34 | group_id = pid // num_pid_in_group 35 | first_pid_m = group_id * GROUP_SIZE_M 36 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 37 | pid_m = first_pid_m + (pid % group_size_m) 38 | pid_n = (pid % num_pid_in_group) // group_size_m 39 | 40 | # ---------------------------------------------------------- 41 | # Create pointers for the first blocks of A and B. 
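    # (Illustrative example of the pointer arithmetic below, assuming a row-major A with
    #  stride_am == K and stride_ak == 1: for pid_m == 1 and BLOCK_SIZE_M == 2, offs_am
    #  covers rows 2..3, so a_ptrs holds a_ptr + row * K + col for col in 0..BLOCK_SIZE_K-1.)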
42 | # We will advance this pointer as we move in the K direction 43 | # and accumulate 44 | # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers 45 | # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers 46 | # See above `Pointer Arithmetics` section for details 47 | offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M 48 | offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N 49 | offs_k = tl.arange(0, BLOCK_SIZE_K) 50 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) 51 | b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) 52 | 53 | # ----------------------------------------------------------- 54 | # Iterate to compute a block of the C matrix. 55 | # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block 56 | # of fp32 values for higher accuracy. 57 | # `accumulator` will be converted back to fp16 after the loop. 58 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 59 | for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): 60 | # Load the next block of A and B, generate a mask by checking the K dimension. 61 | # If it is out of bounds, set it to 0. 62 | a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) 63 | b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) 64 | # We accumulate along the K dimension. 65 | accumulator += tl.dot(a, b) 66 | # Advance the ptrs to the next K block. 67 | a_ptrs += BLOCK_SIZE_K * stride_ak 68 | b_ptrs += BLOCK_SIZE_K * stride_bk 69 | # You can fuse arbitrary activation functions here 70 | # while the accumulator is still in FP32! 71 | c = accumulator.to(tl.float16) 72 | 73 | # ----------------------------------------------------------- 74 | # Write back the block of the output matrix C with masks. 75 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 76 | offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 77 | c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] 78 | c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) 79 | tl.store(c_ptrs, c, mask=c_mask) 80 | 81 | 82 | # We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`. 83 | @triton.jit 84 | def leaky_relu(x): 85 | x = x + 1 86 | return tl.where(x >= 0, x, 0.01 * x) 87 | 88 | @torch.jit.script 89 | def gelu(x): 90 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) 91 | 92 | def matmul(a, b, block_size_m=32, block_size_n=32, block_size_k=32, group_size_m=8): 93 | # Check constraints. 94 | assert a.shape[1] == b.shape[0], "Incompatible dimensions" 95 | assert a.is_contiguous(), "Matrix A must be contiguous" 96 | assert b.is_contiguous(), "Matrix B must be contiguous" 97 | M, K = a.shape 98 | K, N = b.shape 99 | # Allocates output. 100 | c = torch.empty((M, N), device=a.device, dtype=a.dtype) 101 | # 1D launch kernel where each block gets its own program. 102 | grid = lambda META: ( 103 | triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), 104 | ) 105 | matmul_kernel[grid]( 106 | a, b, c, 107 | M, N, K, 108 | a.stride(0), a.stride(1), 109 | b.stride(0), b.stride(1), 110 | c.stride(0), c.stride(1), 111 | block_size_m, block_size_n, block_size_k, group_size_m 112 | ) 113 | return c 114 | 115 | @triton.jit 116 | def add_kernel( 117 | x_ptr, # *Pointer* to first input vector. 118 | y_ptr, # *Pointer* to second input vector. 119 | output_ptr, # *Pointer* to output vector. 
120 | n_elements, # Size of the vector. 121 | BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. 122 | # NOTE: `constexpr` so it can be used as a shape value. 123 | ): 124 | # There are multiple 'programs' processing different data. We identify which program 125 | # we are here: 126 | pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. 127 | # This program will process inputs that are offset from the initial data. 128 | # For instance, if you had a vector of length 256 and block_size of 64, the programs 129 | # would each access the elements [0:64, 64:128, 128:192, 192:256]. 130 | # Note that offsets is a list of pointers: 131 | block_start = pid * BLOCK_SIZE 132 | offsets = block_start + tl.arange(0, BLOCK_SIZE) 133 | # Create a mask to guard memory operations against out-of-bounds accesses. 134 | mask = offsets < n_elements 135 | # Load x and y from DRAM, masking out any extra elements in case the input is not a 136 | # multiple of the block size. 137 | x = tl.load(x_ptr + offsets, mask=mask) 138 | y = tl.load(y_ptr + offsets, mask=mask) 139 | output = x + y 140 | # Write x + y back to DRAM after activation 141 | tl.store(output_ptr + offsets, output, mask=mask) 142 | 143 | def add(x: torch.Tensor, y: torch.Tensor): 144 | # We need to preallocate the output. 145 | output = torch.empty_like(x) 146 | assert x.is_cuda and y.is_cuda and output.is_cuda 147 | n_elements = output.numel() 148 | # The SPMD launch grid denotes the number of kernel instances that run in parallel. 149 | # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. 150 | # In this case, we use a 1D grid where the size is the number of blocks: 151 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 152 | # NOTE: 153 | # - Each torch.tensor object is implicitly converted into a pointer to its first element. 154 | # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. 155 | # - Don't forget to pass meta-parameters as keywords arguments. 156 | add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) 157 | # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still 158 | # running asynchronously at this point. 
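    # Illustrative caller-side sketch (hypothetical names; CUDA tensors assumed), showing why a
    # torch.cuda.synchronize() is needed before timing or reading the result:
    #   x = torch.rand(1 << 20, device='cuda'); y = torch.rand(1 << 20, device='cuda')
    #   z = add(x, y)                  # returns immediately; the kernel may still be running
    #   torch.cuda.synchronize()       # wait for completion before comparing z against x + y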
159 | return output 160 | 161 | def forward(x: torch.Tensor, w: torch.Tensor, b:torch.Tensor): 162 | output = matmul(x, w) 163 | # output = add(output, b) 164 | return output 165 | 166 | if __name__ == "__main__": 167 | torch.manual_seed(0) 168 | a = torch.randn((512, 512), device='cuda', dtype=torch.float16) 169 | layer = torch.nn.Linear(512, 512).cuda().half() 170 | start = datetime.now() 171 | triton_output = forward(a, layer.weight.T.contiguous(), layer.bias) 172 | print("triton", datetime.now()- start) 173 | 174 | start = datetime.now() 175 | torch_output = layer(a) 176 | print("torch", datetime.now()- start) 177 | 178 | print("diff", (torch_output - triton_output).abs()) 179 | if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0): 180 | print("Triton and Torch match") 181 | else: 182 | print("Triton and Torch differ") 183 | -------------------------------------------------------------------------------- /lora-x/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from distutils.util import strtobool 4 | from pathlib import Path 5 | 6 | import yaml 7 | 8 | 9 | def _strtobool(x): 10 | return bool(strtobool(x)) 11 | 12 | 13 | def read_yamls(dir): 14 | conf = {} 15 | no_conf = True 16 | 17 | for config_file in Path(dir).glob("**/*.yaml"): 18 | no_conf = False 19 | with config_file.open("r") as f: 20 | conf.update(yaml.safe_load(f)) 21 | 22 | if no_conf: 23 | print(f"WARNING: No yaml files found in {dir}") 24 | 25 | return conf 26 | 27 | 28 | def rank_zero_info(msg) -> None: 29 | local_rank = int(os.getenv("LOCAL_RANK", "0")) 30 | if local_rank in (None, 0): 31 | print(msg) 32 | 33 | 34 | def argument_parsing(notebook=False, notebook_args=None): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | "--configs", 38 | nargs="+", 39 | required=True, 40 | help=""" 41 | Multiple configs can be passed to set different options. 
42 | For example, run as: 43 | 44 | ./trainer_sft.py --configs default juweles 45 | 46 | """, 47 | ) 48 | parser.add_argument("--local_rank", type=int, default=-1) 49 | parser.add_argument("--deepspeed", action="store_true") 50 | parser.add_argument("--no_deepspeed", action="store_true") 51 | parser.add_argument("--wandb-entity", type=str, default="open-assistant") 52 | parser.add_argument( 53 | "--resume_from_checkpoint", 54 | action="store_true", 55 | help="Resume from last saved checkpoint", 56 | ) 57 | parser.add_argument("--rng_seed", type=int, help="rng seed") 58 | parser.add_argument( 59 | "--show_dataset_stats", 60 | action="store_true", 61 | help="Show dataset stats", 62 | default=False, 63 | ) 64 | parser.set_defaults(deepspeed=False) 65 | 66 | if notebook: 67 | args, remaining = parser.parse_known_args(notebook_args) 68 | else: 69 | args, remaining = parser.parse_known_args() 70 | 71 | # Config from YAML 72 | conf = {} 73 | configs = read_yamls("configs/") 74 | 75 | conf.update(configs["defaults"]) 76 | try: 77 | for name in args.configs: 78 | if "," in name: 79 | for n in name.split(","): 80 | conf.update(configs[n]) 81 | else: 82 | conf.update(configs[name]) 83 | except KeyError as e: 84 | print(f'Error: Could not find the config "{e.args[0]}" in config.yaml') 85 | exit(1) 86 | 87 | conf["wandb_entity"] = args.wandb_entity 88 | conf["local_rank"] = args.local_rank 89 | conf["deepspeed"] = args.deepspeed 90 | if args.no_deepspeed: 91 | conf["deepspeed"] = None 92 | conf["resume_from_checkpoint"] = args.resume_from_checkpoint 93 | if args.rng_seed is not None: 94 | conf["rng_seed"] = args.rng_seed 95 | conf["show_dataset_stats"] = args.show_dataset_stats 96 | 97 | # get the world size in deeepspeed 98 | if conf["deepspeed"]: 99 | conf["world_size"] = int(os.getenv("WORLD_SIZE", default="1")) 100 | else: 101 | conf["world_size"] = 1 102 | 103 | # Override config from command-line 104 | parser = argparse.ArgumentParser() 105 | for key, value in conf.items(): 106 | type_ = type(value) if value is not None else str 107 | if type_ == bool: 108 | type_ = _strtobool 109 | parser.add_argument(f"--{key}", type=type_, default=value) 110 | # Allow --no-{key} to remove it completely 111 | parser.add_argument(f"--no-{key}", dest=key, action="store_const", const=None) 112 | 113 | return parser.parse_args(remaining) 114 | -------------------------------------------------------------------------------- /lora-x/configs/zero3_offload_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "scheduler": { 6 | "type": "WarmupLR", 7 | "params": { 8 | "warmup_min_lr": "auto", 9 | "warmup_max_lr": "auto", 10 | "warmup_num_steps": "auto" 11 | } 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 3, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "offload_param": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "overlap_comm": true, 33 | "contiguous_gradients": true, 34 | "sub_group_size": 1e9, 35 | "reduce_bucket_size": "auto", 36 | "stage3_prefetch_bucket_size": "auto", 37 | "stage3_param_persistence_threshold": "auto", 38 | "stage3_max_live_parameters": 1e9, 39 | "stage3_max_reuse_distance": 1e9, 40 | "stage3_gather_16bit_weights_on_model_save": true 41 | }, 42 | "gradient_accumulation_steps": "auto", 43 
| "gradient_clipping": "auto", 44 | "steps_per_print": 2000, 45 | "train_batch_size": "auto", 46 | "train_micro_batch_size_per_gpu": "auto", 47 | "wall_clock_breakdown": false 48 | } 49 | -------------------------------------------------------------------------------- /lora-x/data.py: -------------------------------------------------------------------------------- 1 | #from https://github.com/jordiclive/scaled-rope 2 | from dataclasses import dataclass 3 | from typing import NamedTuple, Optional, Union 4 | 5 | import datasets 6 | import numpy as np 7 | from sklearn.model_selection import train_test_split 8 | from torch.utils.data import ConcatDataset, Dataset, Subset 9 | from transformers.tokenization_utils_base import (PaddingStrategy, 10 | PreTrainedTokenizerBase, 11 | TruncationStrategy) 12 | 13 | 14 | class DatasetEntryLm(NamedTuple): 15 | """Language modelling dataset entry""" 16 | 17 | text: Union[str, None] = None 18 | 19 | 20 | class LMDataset(Dataset): 21 | name = "LMDataset" 22 | 23 | def __init__(self, dataset_name, char_max_len: str = 200000) -> None: 24 | super().__init__() 25 | self.char_max_len = char_max_len 26 | self.dataset = datasets.load_dataset(dataset_name)["train"] 27 | 28 | def __len__(self) -> int: 29 | return len(self.dataset) 30 | 31 | def __getitem__(self, index) -> DatasetEntryLm: 32 | dialogue = DatasetEntryLm(text=self.dataset[index]["text"][: self.char_max_len]) 33 | return dialogue 34 | 35 | 36 | @dataclass 37 | class DataCollator: 38 | tokenizer: PreTrainedTokenizerBase 39 | padding: Union[bool, str, PaddingStrategy] = True 40 | max_length: Optional[int] = None 41 | mix_length_threshold: Optional[int] = 256 42 | mix_probability: Optional[float] = 0.6 43 | pad_to_multiple_of: Optional[int] = None 44 | samples_mixing: Optional[bool] = False 45 | 46 | def __post_init__(self): 47 | assert self.tokenizer.eos_token 48 | 49 | def process_one(self, messages, return_length=False): 50 | truncation = TruncationStrategy.LONGEST_FIRST 51 | max_length = self.max_length 52 | 53 | messages = messages.text 54 | 55 | flatten_message = self.tokenizer( 56 | "".join(messages), 57 | max_length=max_length, 58 | truncation=truncation, 59 | padding=False, 60 | return_token_type_ids=False, 61 | ) 62 | 63 | label_mask = np.ones(len(flatten_message.input_ids), dtype=bool) 64 | return flatten_message, label_mask, 0 65 | 66 | def __call__(self, features): 67 | flatten_messages = [] 68 | label_masks = [] 69 | total_short_context = 0 70 | for messages in features: 71 | flatten_message, label_mask, total_short_context_one = self.process_one( 72 | messages 73 | ) 74 | flatten_messages.append(flatten_message) 75 | label_masks.append(label_mask) 76 | total_short_context += total_short_context_one 77 | 78 | batch = self.tokenizer.pad( 79 | flatten_messages, 80 | padding=self.padding, 81 | pad_to_multiple_of=self.pad_to_multiple_of, 82 | return_tensors="pt", 83 | ) 84 | batch["labels"] = batch["input_ids"].clone() 85 | 86 | return batch 87 | 88 | 89 | def train_val_dataset(dataset, val_split=0.2): 90 | if val_split == 0: 91 | return dataset, None 92 | 93 | train_idx, val_idx = train_test_split( 94 | list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True 95 | ) 96 | return Subset(dataset, train_idx), Subset(dataset, val_idx) 97 | 98 | 99 | def get_one_dataset( 100 | conf, 101 | val_split: float = 0.025, 102 | data_path: str = None, 103 | mode: str = "sft", 104 | max_val_set: Optional[int] = 50, 105 | **kwargs, 106 | ): 107 | data_path = data_path or conf.cache_dir 108 | 
# dataset_name = dataset_name.lower() 109 | train_datasets = [] 110 | eval_datasets = [] 111 | for data_file in conf.dataset_names: 112 | dataset = LMDataset(data_file) 113 | 114 | # if eval not already defined 115 | if not ("eval" in locals() and "train" in locals()): 116 | train, eval = train_val_dataset(dataset, val_split=val_split) 117 | 118 | if eval and max_val_set and len(eval) > max_val_set: 119 | subset_indices = np.random.choice(len(eval), max_val_set) 120 | eval = Subset(eval, subset_indices) 121 | train_datasets.append(train) 122 | eval_datasets.append(eval) 123 | 124 | train = ConcatDataset(train_datasets) 125 | eval = ConcatDataset(eval_datasets) 126 | return train, eval 127 | -------------------------------------------------------------------------------- /lora-x/experimental/train_qlora_lomo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig 3 | 4 | model_id = "NousResearch/Llama-2-7b-hf" 5 | bnb_config = BitsAndBytesConfig( 6 | load_in_4bit=True, 7 | bnb_4bit_use_double_quant=True, 8 | bnb_4bit_quant_type="nf4", 9 | bnb_4bit_compute_dtype=torch.bfloat16 10 | ) 11 | 12 | tokenizer = AutoTokenizer.from_pretrained(model_id) 13 | model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}) 14 | 15 | from peft import prepare_model_for_kbit_training 16 | 17 | model.gradient_checkpointing_enable() 18 | model = prepare_model_for_kbit_training(model) 19 | 20 | def print_trainable_parameters(model): 21 | """ 22 | Prints the number of trainable parameters in the model. 23 | """ 24 | trainable_params = 0 25 | all_param = 0 26 | for _, param in model.named_parameters(): 27 | all_param += param.numel() 28 | if param.requires_grad: 29 | trainable_params += param.numel() 30 | print( 31 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" 32 | ) 33 | from datasets import load_dataset 34 | 35 | data = load_dataset("Abirate/english_quotes") 36 | data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True) 37 | from peft import LoraConfig, get_peft_model 38 | 39 | config = LoraConfig( 40 | r=8, 41 | lora_alpha=32, 42 | #target_modules=["query_key_value"], 43 | lora_dropout=0.05, 44 | bias="none", 45 | task_type="CAUSAL_LM" 46 | ) 47 | 48 | model = get_peft_model(model, config) 49 | print_trainable_parameters(model) 50 | import transformers 51 | 52 | # needed for gpt-neo-x tokenizer 53 | tokenizer.pad_token = tokenizer.eos_token 54 | import transformers 55 | import qlora_lomo 56 | trainer = qlora_lomo.Trainer( 57 | model=model, 58 | train_dataset=data["train"], 59 | args=qlora_lomo.TrainingArguments( 60 | per_device_train_batch_size=1, 61 | gradient_accumulation_steps=4, 62 | warmup_steps=2, 63 | max_steps=10, 64 | learning_rate=2e-4, 65 | fp16=True, 66 | logging_steps=1, 67 | output_dir="outputs", 68 | optim="LOMO" # paged_adamw_8bit 69 | ), 70 | data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False), 71 | ) 72 | model.config.use_cache = False # silence the warnings. Please re-enable for inference! 
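# Illustrative follow-up sketch (assumed usage, not part of the original script): once
# trainer.train() below has finished, re-enable the KV cache before generating, e.g.
#   model.config.use_cache = True
#   inputs = tokenizer("Quote:", return_tensors="pt").to(model.device)
#   print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0], skip_special_tokens=True))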
73 | trainer.train() 74 | 75 | -------------------------------------------------------------------------------- /lora-x/lora.py: -------------------------------------------------------------------------------- 1 | #from https://github.com/jordiclive/scaled-rope/blob/hf_version/lora.py 2 | from pathlib import Path 3 | 4 | import torch 5 | from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training 6 | 7 | 8 | def prepare_model_for_gradient_checkpointing(model): 9 | r""" 10 | Prepares the model for gradient checkpointing if necessary 11 | """ 12 | if not getattr(model, "is_loaded_in_8bit", False): 13 | if hasattr(model, "enable_input_require_grads"): 14 | model.enable_input_require_grads() 15 | else: 16 | 17 | def make_inputs_require_grad(module, input, output): 18 | output.requires_grad_(True) 19 | 20 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 21 | return model 22 | 23 | 24 | def peft_model( 25 | model, peft_config, model_name, int8_training=False, gradient_checkpointing=False 26 | ): 27 | 28 | if "falcon" in model_name: 29 | target_modules = ["dense_4h_to_h", "dense", "query_key_value", "dense_h_to_4h"] 30 | 31 | elif "llama" in model_name: 32 | target_modules = [ 33 | "down_proj", 34 | "k_proj", 35 | "q_proj", 36 | "gate_proj", 37 | "o_proj", 38 | "up_proj", 39 | "v_proj", 40 | ] 41 | else: 42 | raise ValueError( 43 | f"Invalid model name '{model_name}'. The model name should contain 'falcon' or 'llama'" 44 | ) 45 | config = LoraConfig( 46 | r=peft_config["r"], 47 | lora_alpha=peft_config["alpha"], 48 | target_modules=target_modules, 49 | lora_dropout=peft_config["dropout"], 50 | bias="none", 51 | task_type="CAUSAL_LM", 52 | ) 53 | 54 | model = get_peft_model(model, config) 55 | if int8_training: 56 | model = prepare_model_for_int8_training(model) 57 | 58 | if gradient_checkpointing: 59 | model = prepare_model_for_gradient_checkpointing(model) 60 | model.print_trainable_parameters() 61 | return model 62 | 63 | 64 | def load_peft_finetuned_model(model, peft_model_path): 65 | adapters_weights = torch.load( 66 | Path(peft_model_path).joinpath("adapter_model.bin"), map_location=model.device 67 | ) 68 | model.load_state_dict(adapters_weights, strict=False) 69 | return model 70 | -------------------------------------------------------------------------------- /lora-x/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | accelerate==0.20.3 3 | datasets==2.13.1 4 | deepspeed==0.9.5 5 | git+https://github.com/Dao-AILab/flash-attention.git 6 | #flash-attn==1.0.5 7 | peft==0.3.0 8 | protobuf==3.19.6 9 | scikit-learn==1.3.0 10 | sentencepiece==0.1.99 11 | setuptools==59.5.0 12 | transformers==4.31.0 13 | wandb==0.15.5 14 | einops==0.6.1 15 | -------------------------------------------------------------------------------- /lora-x/scripts/juwels_booster.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # SLURM Configuration 3 | #SBATCH --account=cstdl 4 | #SBATCH --nodes=1 5 | #SBATCH --gres=gpu:4 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH --partition develbooster 9 | 10 | # JUWELS Configuration 11 | conda deactivate 12 | module purge 13 | ml use $OTHERSTAGES 14 | module load Stages/2023 GCC/11.3.0 OpenMPI/4.1.4 15 | module load CUDA/11.7 16 | 17 | # Network Configuration 18 | export NCCL_IB_TIMEOUT=50 19 | export UCX_RC_TIMEOUT=4s 20 | export NCCL_IB_RETRY_CNT=10 21 | export 
NCCL_ASYNC_ERROR_HANDLING=1 22 | 23 | # Environment Configuration 24 | source /p/home/jusers/clive1/juwels/clive1/miniconda3/bin/activate jordan_lora 25 | export WANDB_API_KEY="d8216641d549f9bb3d0c5074baa39e15dfd55030" 26 | export HUGGING_FACE_HUB_TOKEN="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" 27 | export PYTHONPATH="/p/home/jusers/clive1/juwels/clive1/scaled-rope:$PYTHONPATH" 28 | export TRANSFORMERS_CACHE="/p/home/jusers/clive1/juwels/clive1/transformers_cache" 29 | export HF_DATASETS_CACHE="/p/home/jusers/clive1/juwels/clive1/transformers_cache" 30 | export HF_HOME="/p/home/jusers/clive1/juwels/clive1/transformers_cache" 31 | export PATH="/p/software/juwelsbooster/stages/2023/software/OpenMPI/4.1.4-GCC-11.3.0/bin:$PATH" 32 | 33 | # Juwls specific env 34 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 35 | export WANDB_MODE="offline" 36 | export TRANSFORMERS_OFFLINE=1 37 | 38 | # SLURM Host Configuration 39 | hostfile='/p/home/jusers/clive1/juwels/hostfiles/hostfile.txt' 40 | rm $hostfile 41 | 42 | for i in `scontrol show hostnames $SLURM_NODELIST` 43 | do 44 | echo $i slots=4 >>$hostfile 45 | done 46 | 47 | export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` 48 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 49 | export MASTER_PORT=12802 50 | export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` 51 | export DLTS_HOSTFILE=$hostfile 52 | 53 | 54 | # Print System Information 55 | echo "GPUs available to job: $SLURM_JOB_GPUS" 56 | echo "Total tasks: $SLURM_NTASKS" 57 | 58 | deepspeed --master_port 12802 \ 59 | --launcher slurm \ 60 | --hostfile '/p/home/jusers/clive1/juwels/hostfiles/hostfile.txt' \ 61 | --master_addr $MASTER_ADDR \ 62 | --no_ssh_check \ 63 | /p/home/jusers/clive1/juwels/clive1/scaled-rope/finetune.py \ 64 | --output_dir saved_ckpts_32k \ 65 | --configs defaults lora-7b-llama2 \ 66 | --deepspeed 67 | -------------------------------------------------------------------------------- /lora-x/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | from transformers.activations import (FastGELUActivation, GELUActivation, 5 | NewGELUActivation, QuickGELUActivation) 6 | 7 | 8 | def rsetattr(obj, attr, val): 9 | pre, _, post = attr.rpartition(".") 10 | return setattr(rgetattr(obj, pre) if pre else obj, post, val) 11 | 12 | 13 | def rgetattr(obj, attr, *args): 14 | def _getattr(obj, attr): 15 | return getattr(obj, attr, *args) 16 | 17 | return functools.reduce(_getattr, [obj] + attr.split(".")) 18 | 19 | 20 | def fuse_gelu(model): 21 | @torch.jit.script 22 | def gelu_fwd(x): 23 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) 24 | 25 | @torch.jit.script 26 | def gelu_bwd(g, x): 27 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) 28 | ff = 0.5 * x * ( 29 | (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) 30 | ) + 0.5 * (1 + tanh_out) 31 | return ff * g 32 | 33 | class _FusedGeLUFunction(torch.autograd.Function): 34 | @staticmethod 35 | # bias is an optional argument 36 | def forward(ctx, input): 37 | ctx.input_tensor = input 38 | return gelu_fwd(input) 39 | 40 | @staticmethod 41 | def backward(ctx, grad_output): 42 | input = ctx.input_tensor 43 | tmp = gelu_bwd(grad_output, input) 44 | return tmp 45 | 46 | class FusedGelu(torch.nn.Module): 47 | def forward(self, input): 48 | return _FusedGeLUFunction.apply(input) 49 | 50 | fused_gelu_module = FusedGelu() 51 | hf_gelu_functions = [ 52 | 
GELUActivation, 53 | FastGELUActivation, 54 | NewGELUActivation, 55 | QuickGELUActivation, 56 | ] 57 | 58 | for name, module in model.named_modules(): 59 | for hf_gelu_function in hf_gelu_functions: 60 | if isinstance(module, hf_gelu_function): 61 | rsetattr(model, name, fused_gelu_module) 62 | 63 | return model 64 | -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huu4ontocord/MDEL/d84a598e765accfb723edd58f6c0a426d8c16d8d/notebooks/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate~=0.18.0 2 | datasets~=2.11.0 3 | deepspeed==0.9.2 4 | evaluate~=0.4.0 5 | pre-commit~=2.21.0 6 | scikit-learn~=1.2.2 7 | torch~=2.0.0 8 | tqdm~=4.65.0 9 | transformers~=4.28.1 10 | ujson~=5.7.0 11 | wandb==0.15.2 12 | zstandard~=0.21.0 13 | -------------------------------------------------------------------------------- /resources.md: -------------------------------------------------------------------------------- 1 | https://github.com/TehVenomm/LM_Transformers_BlockMerge 2 | -------------------------------------------------------------------------------- /scripts/c-btmInference.py: -------------------------------------------------------------------------------- 1 | from clustering.feature_extractor import FeatureExtractor 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | import torch 4 | from torch.nn import PairwiseDistance 5 | import numpy as np 6 | import fire 7 | 8 | 9 | def topKFilter(ensembleWeights, k): 10 | """ 11 | Filters and normalizes top k ensemble weights. 12 | 13 | Parameters: 14 | - ensembleWeights: list of ensemble weights as calculated in findEnsembleWeights 15 | - k: number of top experts to choose 16 | 17 | Returns: 18 | Tuple with the normalized weights and corresponding indices as they pertain to the top k domains. 19 | """ 20 | topK = torch.topk(ensembleWeights, k=k) 21 | indices = topK.indices 22 | 23 | if torch.sum(topK.values) == 0: 24 | return [1/k for i in range(k)], indices 25 | 26 | normalizedTopKValues = [float(p)/torch.sum(topK.values) for p in topK.values] 27 | 28 | return normalizedTopKValues, indices 29 | 30 | 31 | 32 | def findTopKModelProbabilities(indices, query, models, tokenizers): 33 | """ 34 | Find the probabilities of all tokens appearing in the next step for each model in the given list. 35 | 36 | Parameters: 37 | - indices: A list of indices representing the models to consider. 38 | - query: The input query for which to predict the next token probabilities. 39 | - models: A list of PyTorch models for predicting next token probabilities. 40 | - tokenizers: A list of tokenizers corresponding to the models. 41 | 42 | Returns: 43 | List of tensors, where each tensor contains the probabilities of the next token 44 | for the corresponding model in the input list. 45 | 46 | The function iterates over the specified models, tokenizes the input query, and 47 | calculates the probabilities of all tokens in the next step for each model. The results are 48 | collected into a list and returned. 49 | 50 | Note: The models and tokenizers in the input lists should correspond to each other. 51 | The length of indices, models, and tokenizers should be the same. 
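    (In this script, models and tokenizers hold one entry per domain and correspond to each
    other, while indices selects which k of those domains to query, so indices may be shorter.)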
52 | """ 53 | all_models_next_token_probs = [] 54 | 55 | for index in range(len(indices)): 56 | 57 | input_ids = tokenizers[index].encode(query, return_tensors="pt") 58 | output = models[index](input_ids) 59 | next_token_probs = output.logits[:, -1, :].softmax(dim=-1) 60 | all_models_next_token_probs.append(next_token_probs) 61 | 62 | return all_models_next_token_probs 63 | 64 | 65 | 66 | def findEnsembleWeights(embedder, query, clusterCenters, T, firstToken=False, prompt_data=None): 67 | """ 68 | Finds ensemble weights based on distance of query and cluster centers. 69 | 70 | Parameters: 71 | - embedder: the embedder that was used to train the clustering model 72 | - query: the string prompt from input_file 73 | - clusterCenters: dictionary of cluster centers with the keys as domain{i}, as per the input_file 74 | - T: temperature parameter for softmax 75 | - firstToken: boolean to indicate if the first token is yet to be generated 76 | - prompt_data: the entire json object for the prompt 77 | 78 | Returns: 79 | list the distance between the queries and each of the cluster centers. 80 | """ 81 | ensembleWeights = [] 82 | pdist = torch.nn.PairwiseDistance(p=2) 83 | 84 | for i in range(len(clusterCenters)): 85 | 86 | if firstToken: 87 | 88 | distance = torch.tensor(prompt_data['meta'][f'domain_score{i+1}']) # this should be the L2 norm between the prompt and the cluster center 89 | 90 | else: 91 | 92 | cluster = clusterCenters[f'domain{i+1}'] 93 | 94 | embedded_query = embedder(query) 95 | 96 | if type(embedded_query) is not torch.Tensor: 97 | 98 | embedded_query = torch.Tensor(embedded_query) 99 | 100 | distance = torch.pow(pdist(embedded_query, cluster),2) 101 | 102 | ensembleWeights.append(torch.exp(-1 * distance / T).mean()) 103 | 104 | return torch.tensor(ensembleWeights) 105 | 106 | 107 | def findModelProbabilities(embedder, query, clusterCenters, models, tokenizers, T, k, firstToken, prompt_data): 108 | """ 109 | Calculates the sum of the probabilities of the ensembled tokens using the top k most relevant domains. 110 | 111 | Parameters: 112 | - embedder: the embedder that was used to train the clustering model 113 | - query: the string prompt from input_file 114 | - models: list of models that were trained on most relevant domains 115 | - tokenizers: list of tokenizers that were trained on most relevant domains 116 | - T: temperature parameter for softmax 117 | - k: number of most relevant domains to choose from 118 | - firstToken: boolean to indicate if the first token is yet to be generated 119 | - prompt_data: the entire json object for the prompt 120 | 121 | Returns: 122 | the ensembled probabilities of all tokens at the next step. 123 | """ 124 | modelProbs = 0 125 | ensembleWeights = findEnsembleWeights(embedder, query, clusterCenters, T, firstToken, prompt_data) 126 | weights, indices = topKFilter(ensembleWeights, k) 127 | modelWeights = findTopKModelProbabilities(indices, query, models, tokenizers) 128 | 129 | for i in range(k): 130 | index = indices[i] 131 | modelProbs += modelWeights[index] * weights[index] 132 | 133 | return modelProbs 134 | 135 | 136 | def findNextToken(embedder, query, clusterCenters, models, tokenizers, T, k, firstToken, prompt_data): 137 | """ 138 | Returns decoded next-step token with highest probability. 
139 | 140 | Parameters: 141 | - embedder: the embedder that was used to train the clustering model 142 | - query: the string prompt from input_file 143 | - clusterCenters: dictionary of cluster centers with the keys as domain{i}, as per the input_file 144 | - models: list of models that were trained on most relevant domains 145 | - tokenizers: list of tokenizers that were trained on most relevant domains 146 | - T: temperature parameter for softmax 147 | - k: number of most relevant domains to choose from 148 | - firstToken: boolean to indicate if the first token is yet to be generated 149 | - prompt_data: the entire json object for the prompt 150 | """ 151 | modelProbs = findModelProbabilities(embedder,query, clusterCenters, models, tokenizers, T, k, firstToken, prompt_data) 152 | return tokenizers[0].decode(np.argmax(modelProbs)) # doesn't matter which tokenizer bc they're all the same 153 | 154 | 155 | def generateSequence(embedder, prompt_data, end_token, clusterCenters, models, tokenizers, T, k, maxLength): 156 | """ 157 | Takes in prompt, end_token which ideally is uniform across all tokenizers, parameter k, 158 | temperature T and cluster centers and finds most likely token from most relevant domains based on prompt. Then, it 159 | builds sequence until end token is generated or maxLength is reached. 160 | 161 | Parameters: 162 | - embedder: the embedder that was used to train the clustering model 163 | - query: the string prompt from input_file 164 | - end_token: the token that ideally is uniform across all tokenizers 165 | - clusterCenters: dictionary of cluster centers with the keys as domain{i}, as per the input_file 166 | - models: list of models that were trained on most relevant domains 167 | - tokenizers: list of tokenizers that were trained on most relevant domains 168 | - T: temperature parameter for softmax 169 | - k: number of most relevant domains to choose from 170 | - maxLength: the maximum length the generated output can reach 171 | - firstToken: boolean to indicate if the first token is yet to be generated 172 | - prompt_data: the entire json object for the prompt 173 | 174 | Returns: 175 | Generated string sequence. 176 | """ 177 | prompt = prompt_data['text'] 178 | currToken, currSequence = None, prompt 179 | while (len(currSequence) < maxLength) and (currToken != end_token): 180 | 181 | if not currToken: 182 | currToken = findNextToken( 183 | embedder, currSequence, clusterCenters, models, tokenizers, T, k, firstToken=True, prompt_data=prompt_data) 184 | currSequence += currToken 185 | continue 186 | 187 | if currToken == end_token: 188 | break 189 | 190 | currToken = findNextToken( 191 | embedder, currSequence, clusterCenters, models, tokenizers, T, k, firstToken=False, prompt_data=None) 192 | currSequence += currToken 193 | 194 | return currSequence 195 | 196 | 197 | def load_models(model_names): 198 | """ 199 | Loads models and tokenizers as per the ids provided in model_names. 
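Example (illustrative; any causal LM repositories on the Hugging Face Hub can be substituted): models, tokenizers = load_models(["Multi-Domain-Expert-Layers/expert-arxiv", "Multi-Domain-Expert-Layers/expert-github"]).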
200 | 201 | Parameters: 202 | model_names: list of Hugging Face model ids; a model and its corresponding tokenizer are loaded for each 203 | 204 | Returns: 205 | Separate lists of models and tokenizers to be accessed by other functions later 206 | """ 207 | models = [] 208 | tokenizers = [] 209 | 210 | for model_name in model_names: 211 | model = AutoModelForCausalLM.from_pretrained(model_name) 212 | tokenizer = AutoTokenizer.from_pretrained(model_name) 213 | 214 | # Freeze the parameters of the loaded models (inference only) 215 | model.eval() 216 | for param in model.parameters(): 217 | param.requires_grad = False 218 | 219 | models.append(model) 220 | tokenizers.append(tokenizer) 221 | 222 | return models, tokenizers 223 | 224 | 225 | def run_inference(embedder, model_names, 226 | input_file, output_file, end_token, maxLength, k, T, clusterCenters): 227 | """ 228 | Function that takes in an input file of prompts and writes generated outputs to an output file. 229 | Parameters: 230 | - embedder: the embedder that was used to train the clustering model 231 | - model_names: list of model names 232 | - input_file: input file name; each line is a JSON record with 'text' and 'meta' fields (one prompt per line) 233 | - output_file: where generated sequences are written to 234 | - end_token: end token to signify termination of sequence 235 | - maxLength: max length of generated sequence 236 | - k: number of most relevant domains to choose from 237 | - T: temperature parameter for softmax 238 | - clusterCenters: dictionary of cluster centers with the keys as domain{i}, as per the input_file 239 | Returns: 240 | Nothing; the generated sequences are written to output_file, one per line. 241 | """ 242 | models, tokenizers = load_models(model_names) 243 | import json  # assumes a JSONL input; each record supplies the 'text' and 'meta' fields used downstream 244 | with open(input_file, 'r') as file: 245 | input_data = [json.loads(line) for line in file if line.strip()] 246 | 247 | results = [] 248 | for prompt_data in input_data: 249 | results.append(generateSequence( 250 | embedder, prompt_data, end_token, clusterCenters, models, tokenizers, T, k, maxLength) 251 | ) 252 | 253 | with open(output_file, 'w') as file: 254 | for result in results: 255 | file.write(f"{result}\n") 256 | 257 | 258 | 259 | if __name__ == '__main__': 260 | fire.Fire(run_inference) 261 | -------------------------------------------------------------------------------- /scripts/calc_perplexities.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if command -v python3 &>/dev/null; then 4 | PYTHON_CMD=python3 5 | else 6 | PYTHON_CMD=python 7 | fi 8 | 9 | for MODEL in "Multi-Domain-Expert-Layers/expert-arxiv" "Multi-Domain-Expert-Layers/expert-freelaw" "Multi-Domain-Expert-Layers/expert-github" "EleutherAI/pythia-1b-deduped" 10 | do 11 | for DATASET in "Multi-Domain-Expert-Layers/arxiv" "Multi-Domain-Expert-Layers/freelaw" "Multi-Domain-Expert-Layers/github" 12 | do 13 | for SPLIT in "validation_domain" "train" "validation_pile" 14 | do 15 | $PYTHON_CMD ../src/mdel/calculate_perplexity.py --model $MODEL --dataset $DATASET --split $SPLIT 16 | done 17 | done 18 | done 19 | -------------------------------------------------------------------------------- /scripts/calc_perplexities_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if command -v python3 &>/dev/null; then 4 | PYTHON_CMD=python3 5 | else 6 | PYTHON_CMD=python 7 | fi 8 | 9 | for MODEL in "Multi-Domain-Expert-Layers/expert-arxiv" "Multi-Domain-Expert-Layers/expert-freelaw" "Multi-Domain-Expert-Layers/expert-github" "EleutherAI/pythia-1b-deduped" 10 | do 11 | for DATASET in "Multi-Domain-Expert-Layers/arxiv"
"Multi-Domain-Expert-Layers/freelaw" "Multi-Domain-Expert-Layers/github" 12 | do 13 | for SPLIT in "validation_domain" "train" "validation_pile" 14 | do 15 | JOB_NAME="${MODEL}-${DATASET}-${SPLIT}" 16 | sbatch --job-name="$JOB_NAME" </dev/null; then 5 | PYTHON_CMD=python3 6 | else 7 | PYTHON_CMD=python 8 | fi 9 | 10 | SUBSET_NAME="USPTO Backgrounds" 11 | 12 | for SPLIT in "test" "val" "train" 13 | do 14 | PILE_FILE_PATH="../data/pile/$SPLIT/*.jsonl.zst" 15 | OUTPUT_DIR="../data/mix_uspto_all/$SPLIT" 16 | 17 | # Shard the input data 18 | #$PYTHON_CMD -c "from mdel.pile_utils import *; split_pile('$PILE_FILE_PATH')" 19 | 20 | $PYTHON_CMD -c "from mdel.pile_utils import *; create_pile_domain_mix('$PILE_FILE_PATH', '$PILE_FILE_PATH', '$OUTPUT_DIR', '$SUBSET_NAME')" 21 | done 22 | -------------------------------------------------------------------------------- /scripts/get_pile_shard1_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script downloads the Pile shard 1 data and puts it under data/pile_01. 3 | mkdir -p ../data 4 | mkdir -p ../data/pile 5 | mkdir -p ../data/pile/train 6 | mkdir -p ../data/pile/test 7 | mkdir -p ../data/pile/val 8 | 9 | wget https://the-eye.eu/public/AI/pile/train/01.jsonl.zst -P ../data/pile/train 10 | wget https://the-eye.eu/public/AI/pile/test.jsonl.zst -P ../data/pile/test 11 | wget https://the-eye.eu/public/AI/pile/val.jsonl.zst -P ../data/pile/val 12 | -------------------------------------------------------------------------------- /scripts/upload_to_hf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if command -v python3 &>/dev/null; then 3 | PYTHON_CMD=python3 4 | else 5 | PYTHON_CMD=python 6 | fi 7 | 8 | HF_REPO=Multi-Domain-Expert-Layers/uspto 9 | 10 | for SPLIT in "test" "val" "train" 11 | do 12 | FOLDER_PATH=$(readlink -f ../data/mix_uspto_all/$SPLIT/) 13 | $PYTHON_CMD src/mdel/pile_upload.py --folder-path "$FOLDER_PATH" --hf-repo $HF_REPO --split $SPLIT 14 | done 15 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="mdel", 5 | version="0.1", 6 | packages=find_packages(where="src"), 7 | package_dir={"": "src"}, 8 | include_package_data=True 9 | ) 10 | -------------------------------------------------------------------------------- /src/mdel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huu4ontocord/MDEL/d84a598e765accfb723edd58f6c0a426d8c16d8d/src/mdel/__init__.py -------------------------------------------------------------------------------- /src/mdel/calculate_perplexity.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | import time 5 | 6 | from datasets import load_dataset 7 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 8 | DataCollatorForLanguageModeling, Trainer, 9 | TrainingArguments) 10 | 11 | 12 | def load_model(args): 13 | tokenizer = AutoTokenizer.from_pretrained( 14 | args.tokenizer if args.tokenizer else args.model 15 | ) 16 | model = AutoModelForCausalLM.from_pretrained(args.model).cuda() 17 | 18 | tokenizer.pad_token = tokenizer.eos_token 19 | 20 | return tokenizer, model 21 | 22 | 23 | def prep_dataset(args, tokenizer): 24 | ds = 
load_dataset(args.dataset, split=args.split) 25 | ds = ds.flatten() 26 | ds = ds.select(range(min(args.num_samples, len(ds)))) 27 | print("Loaded Dataset with {} samples".format(len(ds))) 28 | 29 | def preprocess_function(examples): 30 | return tokenizer( 31 | [" ".join(x) for x in examples[args.dataset_key]], 32 | truncation=True, 33 | max_length=args.max_length, 34 | ) 35 | 36 | tokenized_ds = ds.map( 37 | preprocess_function, 38 | batched=True, 39 | num_proc=4, 40 | remove_columns=ds.column_names, 41 | ) 42 | 43 | return tokenized_ds 44 | 45 | 46 | def parse_args(): 47 | parser = argparse.ArgumentParser( 48 | description="Calculate perplexity of a model on a dataset" 49 | ) 50 | parser.add_argument( 51 | "--model", 52 | type=str, 53 | required=True, 54 | help="Name of the HF model e.g. MDEL/merged-arxiv-github", 55 | ) 56 | parser.add_argument( 57 | "--tokenizer", 58 | type=str, 59 | required=False, 60 | help="Optional tokenizer name e.g. MDEL/merged-arxiv-github. If not provided, will use the model name", 61 | ) 62 | parser.add_argument( 63 | "--dataset", 64 | type=str, 65 | required=True, 66 | help="Name of the HF dataset e.g. MDEL/pubmed_abstracts", 67 | ) 68 | parser.add_argument( 69 | "--split", 70 | type=str, 71 | required=True, 72 | help="Name of the split to evaluate on e.g. validation", 73 | ) 74 | parser.add_argument( 75 | "--max-length", 76 | type=int, 77 | required=False, 78 | default=1024, 79 | help="Max length of the input sequence", 80 | ) 81 | parser.add_argument( 82 | "--dataset-key", 83 | type=str, 84 | required=False, 85 | default='text', 86 | help="Key to use to access the dataset e.g. text or answers.text", 87 | ) 88 | 89 | parser.add_argument( 90 | "--num-samples", 91 | type=int, 92 | required=False, 93 | default=10000, 94 | help="Max number of samples to evaluate on", 95 | ) 96 | 97 | return parser.parse_args() 98 | 99 | 100 | if __name__ == "__main__": 101 | args = parse_args() 102 | 103 | tokenizer, model = load_model(args) 104 | dataset = prep_dataset(args, tokenizer) 105 | 106 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 107 | 108 | training_args = TrainingArguments( 109 | output_dir="perplexity-results", 110 | push_to_hub=False, 111 | ) 112 | 113 | trainer = Trainer( 114 | model=model, 115 | args=training_args, 116 | eval_dataset=dataset, 117 | data_collator=data_collator, 118 | ) 119 | 120 | eval_results = trainer.evaluate() 121 | perplexity = math.exp(eval_results['eval_loss']) 122 | message = f"Perplexity for {args.model} on {args.dataset}[{args.split}]: {perplexity}" 123 | print(message) 124 | 125 | # write to jsonl 126 | data = { 127 | "date": time.time(), 128 | "runtime": eval_results['eval_runtime'], 129 | "model": args.model, 130 | "tokenizer": args.tokenizer if args.tokenizer else args.model, 131 | "dataset": args.dataset, 132 | "split": args.split, 133 | "max_length": args.max_length, 134 | "dataset_key": args.dataset_key, 135 | "num_samples": args.num_samples, 136 | "perplexity": perplexity, 137 | } 138 | 139 | with open("perplexity-results.jsonl", "a") as f: 140 | f.write(json.dumps(data) + "\n") 141 | -------------------------------------------------------------------------------- /src/mdel/configs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | dataset_name: Multi-Domain-Expert-Layers/arxiv 3 | model_name_or_path: EleutherAI/pythia-1b-deduped 4 | output_dir: "ckpts/pythia-1b-deduped/uspto/layer_9,10,11,12,13" 5 | training_layers: "4,5" 6 | 
per_device_train_batch_size: 32 7 | per_device_eval_batch_size: 20 8 | preprocessing_num_workers: 32 9 | learning_rate: 0.0001 10 | block_size: 512 11 | num_train_epochs: 1 12 | gradient_accumulation_steps: 1 13 | do_train: true 14 | do_eval: true 15 | evaluation_strategy: "steps" 16 | save_total_limit: 2 17 | max_grad_norm: 2.0 18 | save_steps: 500 19 | deepspeed: "configs/zero_config.json" 20 | overwrite_output_dir: true 21 | logging_steps: 20 22 | dtype: "float16" 23 | wandb_entity: "ontocord" 24 | wandb_project: "layer-experts" 25 | wandb_run_name: "default" 26 | # Model and Tokenizer related settings 27 | model_revision: "main" 28 | validation_splits: 29 | config_name: 30 | tokenizer_name: 31 | use_fast_tokenizer: true 32 | low_cpu_mem_usage: false 33 | eval_steps: 200 34 | 35 | # Dataset related settings 36 | dataset_config_name: 37 | max_train_samples: 38 | max_eval_samples: 39 | streaming: 40 | overwrite_cache: false 41 | 42 | # Training related settings 43 | gradient_checkpointing: false 44 | warmup_steps: 0 45 | adam_beta1: 0.9 46 | adam_beta2: 0.99 47 | adam_epsilon: 1e-8 48 | weight_decay: 0 49 | save_strategy: "steps" 50 | quantization: 51 | 52 | # Logging and Caching 53 | log_wandb: true 54 | cache_dir: 55 | use_auth_token: 56 | 57 | push_to_hub: 58 | push_to_hub_organization: 59 | push_to_hub_model_id: 60 | 61 | max_steps: -1 62 | 63 | debug: 64 | model_name_or_path: "EleutherAI/pythia-70m-deduped" 65 | output_dir: "layer9" 66 | training_layers: "4,5" 67 | per_device_train_batch_size: 1 68 | per_device_eval_batch_size: 1 69 | preprocessing_num_workers: 32 70 | learning_rate: 1e-4 71 | block_size: 512 72 | num_train_epochs: 1 73 | gradient_accumulation_steps: 1 74 | do_train: true 75 | do_eval: true 76 | evaluation_strategy: "steps" 77 | save_total_limit: 2 78 | max_grad_norm: 2.0 79 | save_steps: 1 80 | eval_steps: 1 81 | max_steps: 1 82 | dtype: "float32" 83 | overwrite_output_dir: true 84 | logging_steps: 20 85 | validation_splits: "validation_pile,validation_domain" 86 | max_train_samples: 2 87 | max_eval_samples: 2 88 | -------------------------------------------------------------------------------- /src/mdel/configs/zero_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "scheduler": { 14 | "type": "WarmupDecayLR", 15 | "params": { 16 | "warmup_min_lr": "auto", 17 | "warmup_max_lr": "auto", 18 | "warmup_num_steps": "auto", 19 | "total_num_steps": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "allgather_partitions": true, 25 | "allgather_bucket_size": 1e9, 26 | "overlap_comm": false, 27 | "reduce_scatter": true, 28 | "reduce_bucket_size": 1e9, 29 | "contiguous_gradients": true 30 | }, 31 | "gradient_accumulation_steps": "auto", 32 | "gradient_clipping": "auto", 33 | "steps_per_print": 2000, 34 | "train_batch_size": "auto", 35 | "train_micro_batch_size_per_gpu": "auto", 36 | "wall_clock_breakdown": false 37 | } 38 | -------------------------------------------------------------------------------- /src/mdel/iterate_layers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export WANDB_PROJECT=pythia-6.9b-layer-test 3 | 4 | 5 | for i in 6 18 30 9 20 28 12 19 26 15 22 24 ; 6 | do 7 | export WANDB_NAME="layer_${i}" 8 | accelerate 
launch trainer.py \ 9 | --train_file data/train_data.txt \ 10 | --validation_file data/book_val.txt \ 11 | --model_name_or_path EleutherAI/pythia-6.9b-deduped \ 12 | --output_dir "ckpts/pythia-6.9b/books/layer_${i}" \ 13 | --training_layer ${i} \ 14 | --per_device_train_batch_size 1 \ 15 | --per_device_eval_batch_size 1 \ 16 | --preprocessing_num_workers 32 \ 17 | --learning_rate 1e-4 \ 18 | --block_size 512 \ 19 | --num_train_epochs 1 \ 20 | --gradient_accumulation_steps 8 \ 21 | --do_train \ 22 | --do_eval \ 23 | --overwrite_output_dir \ 24 | --logging_steps 20 \ 25 | --max_steps 1000 26 | done 27 | -------------------------------------------------------------------------------- /src/mdel/merge_experts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from huggingface_hub import HfApi 5 | from transformers import AutoModelForCausalLM 6 | 7 | 8 | class HFModel: 9 | def __init__(self, model_name): 10 | self.name = model_name 11 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 12 | print("Using device:", self.device) 13 | 14 | def __enter__(self): 15 | print(f"Loading Model {self.name}") 16 | with torch.no_grad(): 17 | self.model = AutoModelForCausalLM.from_pretrained(self.name) 18 | self.model = self.model.to(self.device) 19 | self.model.eval() 20 | return self.model 21 | 22 | def __exit__(self, type, value, traceback): 23 | print(f"Unloading Model {self.name}") 24 | del self.model 25 | 26 | 27 | def merge_n_models(model_names): 28 | with torch.no_grad(): 29 | with HFModel(model_names[0]) as blended_model: 30 | 31 | # zero out blended models params 32 | for p in blended_model.parameters(): 33 | p.data *= 0 34 | 35 | for mn in model_names: 36 | with HFModel(mn) as temp_model: 37 | for p1, p2 in zip(blended_model.parameters(), temp_model.parameters()): 38 | p1.data += p2.data * (1 / len(model_names)) 39 | del temp_model 40 | return blended_model 41 | 42 | 43 | def upload_expert(merged_model, args): 44 | sources = "\n".join(map(lambda mn: f"- [{mn}](https://huggingface.co/{mn})", args.experts)) 45 | 46 | model_card_yaml = f"""--- 47 | tags: 48 | - MDEL 49 | --- 50 | 51 | # Model Name 52 | {args.hf_repo} 53 | 54 | # Model Description 55 | This model was generated by averaging the weights of the following models 56 | {sources} 57 | """ 58 | print(model_card_yaml) 59 | 60 | print("Pushing Model to Hub") 61 | merged_model.push_to_hub(args.hf_repo, model_card=model_card_yaml) 62 | 63 | print("Pushing Model Card to Hub") 64 | 65 | # Upload model card 66 | with open("./README.md", "w") as f: 67 | f.write(model_card_yaml) 68 | api = HfApi() 69 | api.upload_file( 70 | path_or_fileobj="./README.md", 71 | path_in_repo="README.md", 72 | repo_id=args.hf_repo, 73 | repo_type="model", 74 | ) 75 | 76 | 77 | def parse_args(): 78 | parser = argparse.ArgumentParser(description='Merge expert models and upload to the HuggingFace Hub') 79 | parser.add_argument('--hf-repo', type=str, required=True, 80 | help='Name of the repo to upload the merged model e.g. 
MDEL/merged-arxiv-github') 81 | parser.add_argument('-e', '--expert', action='append', 82 | help='Name of an expert repo e.g MDEL/expert-arxiv (use this flag multiple times)', 83 | required=True, dest='experts') 84 | 85 | return parser.parse_args() 86 | 87 | 88 | if __name__ == '__main__': 89 | args = parse_args() 90 | 91 | if (len(args.experts) < 2): 92 | print("Must specify at least 2 experts to merge") 93 | exit(1) 94 | 95 | merged_model = merge_n_models(args.experts) 96 | upload_expert(merged_model, args) 97 | -------------------------------------------------------------------------------- /src/mdel/pile_upload.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from huggingface_hub import HfApi 5 | 6 | api = HfApi() 7 | 8 | 9 | def upload_dataset(folder_path, hf_repo, split): 10 | auth_token = os.getenv('HF_ACCESS_TOKEN') 11 | 12 | print(f"Uploading {folder_path} to {hf_repo}...") 13 | 14 | api.upload_folder(folder_path=folder_path, 15 | repo_id=hf_repo, 16 | repo_type="dataset", 17 | path_in_repo=f"data/{split}", 18 | use_auth_token=auth_token) 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='Upload dataset to Hugging Face Hub') 23 | parser.add_argument('--folder-path', type=str, required=True, help='Path to dataset file') 24 | parser.add_argument('--hf-repo', type=str, required=True, help='Path to dataset file') 25 | parser.add_argument('--split', type=str, required=True, help='Split') 26 | return parser.parse_args() 27 | 28 | 29 | if __name__ == '__main__': 30 | args = parse_args() 31 | folder_path = args.folder_path 32 | hf_repo = args.hf_repo 33 | split = args.split 34 | upload_dataset(folder_path, hf_repo, split) 35 | -------------------------------------------------------------------------------- /src/mdel/pile_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import io 3 | import multiprocessing 4 | import os.path 5 | import pathlib 6 | 7 | import ujson 8 | import zstandard 9 | from tqdm import tqdm 10 | from tqdm.contrib.concurrent import process_map 11 | 12 | __HERE__ = pathlib.Path(__file__).parent.resolve() 13 | 14 | SHARD_SIZE = 100000 15 | 16 | 17 | def jsonl_extract_transform_write(infile, outfile, map_func, filter_func, max_lines=-1): 18 | line_counter = 0 19 | num_failed = 0 20 | dctx = zstandard.ZstdDecompressor() 21 | 22 | with dctx.stream_reader(infile) as reader: 23 | cctx = zstandard.ZstdCompressor() 24 | with cctx.stream_writer(outfile, closefd=False) as writer: 25 | text_stream = io.TextIOWrapper(reader, encoding='utf-8') 26 | for line in text_stream: 27 | try: 28 | if filter_func is not None and filter_func(line): 29 | continue 30 | 31 | func_output = map_func(line) 32 | writer.write(f'{func_output}\n'.encode()) 33 | line_counter += 1 34 | if -1 < max_lines <= line_counter: 35 | print(f'Stopping after {line_counter} lines') 36 | break 37 | except Exception: 38 | num_failed += 1 39 | continue 40 | 41 | return line_counter, num_failed 42 | 43 | 44 | def pile_get_text(line): 45 | # Just getting rid of the meta tag for a consistent schema 46 | json_obj = ujson.loads(line) 47 | text = json_obj['text'] 48 | text_json = ujson.dumps({'text': text}) 49 | return text_json 50 | 51 | 52 | def pile_filter_subset(line, subset_name): 53 | json_obj = ujson.loads(line) 54 | return json_obj['meta']['pile_set_name'] != subset_name 55 | 56 | 57 | def domain_mapper(args): 58 | input_file_path, output_file_path, 
subset_name = args 59 | with open(output_file_path, 'wb+') as outfile: 60 | with open(input_file_path, 'rb') as infile: 61 | num_domain_samples = jsonl_extract_transform_write(infile, 62 | outfile, 63 | pile_get_text, 64 | lambda x: pile_filter_subset(x, subset_name)) 65 | return num_domain_samples 66 | 67 | 68 | def pile_mapper(args): 69 | input_file_path, output_file_path, max_lines = args 70 | with open(output_file_path, 'wb+') as outfile: 71 | with open(input_file_path, 'rb') as infile_d: 72 | num_samples = jsonl_extract_transform_write(infile_d, outfile, pile_get_text, None, 73 | max_lines=max_lines) 74 | return num_samples 75 | 76 | 77 | def create_pile_domain_mix(domain_data_file_path: str, 78 | pile_file_path: str, 79 | output_dir: str, 80 | subset_name: str, 81 | max_files: int = -1, 82 | max_workers: int = multiprocessing.cpu_count()): 83 | if not os.path.exists(output_dir): 84 | os.makedirs(output_dir) 85 | else: 86 | raise IOError('Output path already exists') 87 | 88 | domain_data_file_path_expanded, domain_data_processed_paths = process_mix_file_paths(domain_data_file_path, 89 | max_files, output_dir, 90 | 'domain') 91 | print('Processing domain data samples') 92 | file_sample_counts = process_map(domain_mapper, 93 | zip(domain_data_file_path_expanded, 94 | domain_data_processed_paths, 95 | len(domain_data_processed_paths) * [subset_name]), 96 | max_workers=max_workers) 97 | 98 | num_domain_samples = sum([x[0] for x in file_sample_counts]) 99 | num_failed_samples = sum([x[1] for x in file_sample_counts]) 100 | fail_rate = 1000 * num_failed_samples / num_domain_samples 101 | print(f'Number of domain samples: {num_domain_samples}, rate of samples failed to parse {fail_rate}%') 102 | 103 | print('Processing Pile data samples') 104 | pile_file_path_expanded, pile_processed_paths = process_mix_file_paths(pile_file_path, 105 | -1, output_dir, 'pile') 106 | num_pile_samples = 0 107 | pile_file_idx = 0 108 | while num_pile_samples < num_domain_samples: 109 | cur_num_samples = pile_mapper( 110 | (pile_file_path_expanded[pile_file_idx], 111 | pile_processed_paths[pile_file_idx], 112 | num_domain_samples)) 113 | num_pile_samples += cur_num_samples[0] 114 | pile_file_idx += 1 115 | 116 | 117 | def process_mix_file_paths(domain_data_file_path, max_files, output_dir, name_prefix): 118 | domain_data_file_path_expanded = glob.glob(domain_data_file_path) 119 | if max_files > 0: 120 | print(f'Using {max_files} data files') 121 | domain_data_file_path_expanded = domain_data_file_path_expanded[:max_files] 122 | domain_data_processed_paths = [os.path.join(output_dir, name_prefix + '_' + os.path.basename(x)) 123 | for x in domain_data_file_path_expanded] 124 | return domain_data_file_path_expanded, domain_data_processed_paths 125 | 126 | 127 | def read_pile_texts(input_file_path): 128 | """ 129 | Reads a Pile dataset file in zstd-compressed JSON format and returns a list of 'text' fields. 130 | 131 | :param input_file_path: The path to the input file. 132 | :type input_file_path: str 133 | :return: A list of 'text' fields from each line of the input file. 134 | :rtype: List[str] 135 | :raises FileNotFoundError: If the input file path does not exist. 136 | :raises ValueError: If the input file path is not a string. 
137 | :example: read_pile_texts('pile_texts.zst') 138 | """ 139 | with open(input_file_path, 'rb') as infile: 140 | dctx = zstandard.ZstdDecompressor() 141 | with dctx.stream_reader(infile) as reader: 142 | text_stream = io.TextIOWrapper(reader, encoding='utf-8') 143 | return [ujson.loads(line)['text'] for line in text_stream] 144 | 145 | 146 | def split_pile(input_file_path, shard_size=SHARD_SIZE): 147 | print(input_file_path) 148 | resolved_files = glob.glob(os.path.abspath(input_file_path)) 149 | 150 | for resolved_file in resolved_files: 151 | dctx = zstandard.ZstdDecompressor() 152 | with open(resolved_file, 'rb') as infile: 153 | with dctx.stream_reader(infile) as reader: 154 | text_stream = io.TextIOWrapper(reader, encoding='utf-8') 155 | cctx = zstandard.ZstdCompressor() 156 | shard_num = -1 157 | writer = None 158 | outfile = None 159 | 160 | for line_counter, line in enumerate(tqdm(text_stream)): 161 | if line_counter % shard_size == 0: 162 | if writer is not None: 163 | writer.close() 164 | if outfile is not None: 165 | outfile.close() 166 | 167 | shard_num += 1 168 | output_file_path = os.path.join( 169 | os.path.dirname(resolved_file), 170 | os.path.splitext(resolved_file)[0].replace('.jsonl', '') + f'_{shard_num}.jsonl.zst') 171 | outfile = open(output_file_path, 'wb') 172 | writer = cctx.stream_writer(outfile, closefd=False) 173 | 174 | writer.write(line.encode(encoding='utf-8')) 175 | os.remove(resolved_file) 176 | 177 | 178 | if __name__ == '__main__': 179 | repo_root = __HERE__.parent.parent 180 | # domain_data_file_path = str(repo_root / 'data/pile_uspto/*.jsonl.zst') 181 | # pile_file_path = str(repo_root / 'data/pile_01/01.jsonl.zst') 182 | pile_file_path = '/Users/vmay/Documents/git/MDEL/data/pile/val/*.jsonl.zst' 183 | output_dir = str(repo_root / 'data/mix_uspto_all_3' / 'val') 184 | 185 | # create_pile_domain_mix(pile_file_path, pile_file_path, output_dir, max_files=-1, max_workers=4) 186 | # split_pile('/Users/vmay/Documents/git/MDEL/data/pile/train/*.jsonl.zst') 187 | # print(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/val/domain_val_0.jsonl.zst')[150]) 188 | # print(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/val/domain_val_0.jsonl.zst')[151]) 189 | # print(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/test/domain_test_0.jsonl.zst')[150]) 190 | # print(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/test/domain_test_0.jsonl.zst')[151]) 191 | # print(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/test/pile_test_0.jsonl.zst')[150]) 192 | # print(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/test/pile_test_0.jsonl.zst')[151]) 193 | print(len(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/train/domain_01_0.jsonl.zst'))) 194 | print(len(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/train/pile_01_0.jsonl.zst'))) 195 | -------------------------------------------------------------------------------- /src/mdel/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET=uspto 4 | TRAINING_LAYERS=9,10,11,12,13 5 | 6 | export WANDB_PROJECT=pythia-1b-deduped-layer-test-$DATASET 7 | export WANDB_NAME="layer_$TRAINING_LAYERS" 8 | export WANDB_ENTITY=ontocord 9 | 10 | # check if venv or conda is activated 11 | if [ -n "$CONDA_DEFAULT_ENV" ] || [ -n "$VIRTUAL_ENV" ]; then 12 | echo "Virtual environment is activated" 13 | else 14 | echo "Error: virtual 
environment is not activated" 15 | exit 1 16 | fi 17 | 18 | accelerate launch trainer.py \ 19 | --configs defaults \ 20 | --dataset_name Multi-Domain-Expert-Layers/$DATASET \ 21 | --model_name_or_path EleutherAI/pythia-1b-deduped \ 22 | --output_dir "ckpts/pythia-1b-deduped/$DATASET/layer_$TRAINING_LAYERS" \ 23 | --training_layers $TRAINING_LAYERS \ 24 | --per_device_train_batch_size 1 \ 25 | --per_device_eval_batch_size 8 \ 26 | --preprocessing_num_workers 32 \ 27 | --learning_rate 1e-4 \ 28 | --block_size 512 \ 29 | --num_train_epochs 1 \ 30 | --gradient_accumulation_steps 8 \ 31 | --evaluation_strategy steps \ 32 | --eval_steps 200 \ 33 | --logging_steps 20 \ 34 | --max_steps 1000 \ 35 | --push_to_hub true \ 36 | --push_to_hub_model_id expert-$DATASET \ 37 | --push_to_hub_organization Multi-Domain-Expert-Layers \ 38 | --wandb_entity $WANDB_ENTITY \ 39 | --wandb_project $WANDB_PROJECT \ 40 | --wandb_run_name $WANDB_NAME \ 41 | --validation_splits "validation_pile,validation_domain" \ 42 | --dtype "float32" \ 43 | --no_deepspeed 44 | -------------------------------------------------------------------------------- /src/mdel/train_cbtm_classifier.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | from datasets import load_dataset 3 | from sklearn.feature_extraction import FeatureHasher 4 | from sklearn.linear_model import LogisticRegression 5 | from sklearn.metrics import classification_report 6 | from sklearn.model_selection import train_test_split 7 | from sklearn.pipeline import Pipeline 8 | from tokenizers import Tokenizer 9 | 10 | data_url = [ 11 | "Multi-Domain-Expert-Layers/pubmed_abstracts", 12 | "Multi-Domain-Expert-Layers/philpapers", 13 | "Multi-Domain-Expert-Layers/pubmed_central", 14 | "Multi-Domain-Expert-Layers/freelaw", 15 | "Multi-Domain-Expert-Layers/arxiv", 16 | "Multi-Domain-Expert-Layers/github", 17 | "Multi-Domain-Expert-Layers/uspto", 18 | ] 19 | 20 | expert_datasets = [load_dataset(i) for i in data_url] 21 | 22 | 23 | tokenizer = Tokenizer.from_pretrained("EleutherAI/pythia-1b-deduped") 24 | tokenizer.enable_truncation(max_length=1024) 25 | 26 | 27 | tokenized_datasets = [] 28 | for ed in expert_datasets: 29 | tokenized_datasets.append(ed['train'].map(lambda x: { 30 | "token": [i.tokens for i in tokenizer.encode_batch(x["text"])] 31 | }, batched=True)) 32 | 33 | 34 | features = [] 35 | label = [] 36 | for ed, lab in zip(expert_datasets, data_url): 37 | for i in ed.select(range(min(10000, len(ed)))).iter(batch_size=10000): 38 | features.extend(i['token']) 39 | label.extend([lab] * len(i['token'])) 40 | 41 | 42 | X_train, X_test, y_train, y_test = train_test_split(features, label, test_size=0.2, random_state=42) 43 | 44 | 45 | pipeline = Pipeline( 46 | [ 47 | ('hasher', FeatureHasher(n_features=512, input_type="string")), 48 | ('lr', LogisticRegression(multi_class='multinomial', solver='lbfgs')) 49 | ] 50 | ) 51 | pipeline.fit(X_train, y_train) 52 | 53 | 54 | y_train_pred = pipeline.predict(X_train) 55 | y_test_pred = pipeline.predict(X_test) 56 | 57 | 58 | print(classification_report(y_train, y_train_pred)) 59 | print(classification_report(y_test, y_test_pred)) 60 | 61 | 62 | joblib.dump(pipeline, 'cbtm_classifier.pkl') 63 | -------------------------------------------------------------------------------- /src/mdel/train_chat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET=mini-pile-instruct 4 | TRAINING_LAYERS=4,5,6,7,8 5 | 6 | export 
WANDB_PROJECT=pythia-1b-deduped-layer-test-$DATASET 7 | export WANDB_NAME="layer_$TRAINING_LAYERS" 8 | export WANDB_ENTITY=ontocord 9 | 10 | # check if venv or conda is activated 11 | if [ -n "$CONDA_DEFAULT_ENV" ] || [ -n "$VIRTUAL_ENV" ]; then 12 | echo "Virtual environment is activated" 13 | else 14 | echo "Error: virtual environment is not activated" 15 | exit 1 16 | fi 17 | 18 | accelerate launch trainer.py \ 19 | --dataset_name Multi-Domain-Expert-Layers/$DATASET \ 20 | --model_name_or_path EleutherAI/pythia-1b-deduped \ 21 | --output_dir "ckpts/pythia-1b-deduped/$DATASET/layer_$TRAINING_LAYERS" \ 22 | --training_layers $TRAINING_LAYERS \ 23 | --per_device_train_batch_size 1 \ 24 | --per_device_eval_batch_size 8 \ 25 | --preprocessing_num_workers 32 \ 26 | --learning_rate 1e-4 \ 27 | --block_size 512 \ 28 | --num_train_epochs 1 \ 29 | --gradient_accumulation_steps 8 \ 30 | --do_train \ 31 | --do_eval \ 32 | --evaluation_strategy steps \ 33 | --eval_steps 200 \ 34 | --overwrite_output_dir \ 35 | --logging_steps 20 \ 36 | --max_steps 1000 \ 37 | --push_to_hub true \ 38 | --push_to_hub_model_id expert-$DATASET \ 39 | --push_to_hub_organization Multi-Domain-Expert-Layers 40 | -------------------------------------------------------------------------------- /src/mdel/train_ds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET=uspto 4 | TRAINING_LAYERS=9,10,11,12,13 5 | 6 | export WANDB_PROJECT=pythia-1b-deduped-layer-test-$DATASET 7 | export WANDB_NAME="layer_$TRAINING_LAYERS" 8 | export WANDB_ENTITY=ontocord 9 | 10 | # check if venv or conda is activated 11 | if [ -n "$CONDA_DEFAULT_ENV" ] || [ -n "$VIRTUAL_ENV" ]; then 12 | echo "Virtual environment is activated" 13 | else 14 | echo "Error: virtual environment is not activated" 15 | exit 1 16 | fi 17 | 18 | deepspeed trainer.py \ 19 | --configs defaults \ 20 | --dataset_name Multi-Domain-Expert-Layers/$DATASET \ 21 | --model_name_or_path EleutherAI/pythia-1b-deduped \ 22 | --output_dir "ckpts/pythia-1b-deduped/$DATASET/layer_$TRAINING_LAYERS" \ 23 | --training_layers $TRAINING_LAYERS \ 24 | --push_to_hub true \ 25 | --push_to_hub_model_id expert-$DATASET \ 26 | --push_to_hub_organization Multi-Domain-Expert-Layers \ 27 | --wandb_entity $WANDB_ENTITY \ 28 | --wandb_project $WANDB_PROJECT \ 29 | --wandb_run_name $WANDB_NAME \ 30 | --validation_splits "validation_pile,validation_domain" \ 31 | -------------------------------------------------------------------------------- /src/mdel/trainer_chat.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | accelerate launch trainer_chat.py ^ 3 | --dataset_name Dahoas/full-hh-rlhf ^ 4 | --model_name_or_path EleutherAI/pythia-160m ^ 5 | --output_dir "output_chat" ^ 6 | --training_layers "5,6,7" ^ 7 | --separator "Assistant:" ^ 8 | --prompt_column "prompt" ^ 9 | --answer_column "response" ^ 10 | --per_device_train_batch_size 1 ^ 11 | --per_device_eval_batch_size 8 ^ 12 | --preprocessing_num_workers 32 ^ 13 | --learning_rate 1e-4 ^ 14 | --block_size 512 ^ 15 | --num_train_epochs 1 ^ 16 | --gradient_accumulation_steps 8 ^ 17 | --do_train ^ 18 | --do_eval ^ 19 | --evaluation_strategy steps ^ 20 | --eval_steps 200 ^ 21 | --overwrite_output_dir ^ 22 | --logging_steps 20 ^ 23 | --max_steps 1000 24 | 25 | pause 26 | --------------------------------------------------------------------------------