├── .github
│   └── workflows
│       └── pre-commit.yaml
├── .gitignore
├── .idea
│   └── workspace.xml
├── .pre-commit-config.yaml
├── .pylintrc
├── LICENSE
├── Makefile
├── Model Merge And Analysis Tools
│   ├── Enhanced_Mixer.py
│   ├── Enhanced_Mixer_Requirements.txt
│   ├── LM_BlockMerge.py
│   ├── LM_BlockMerge_Requirements.txt
│   ├── StratusScope.py
│   ├── StratusScope_BarGraph.png
│   ├── StratusScope_ConsoleOutput.png
│   ├── StratusScope_Requirements.txt
│   └── __Quick Tool Explainer__.txt
├── README.md
├── clustering
│   ├── download.py
│   ├── feature_extractor.py
│   ├── hierarchical_clustering.py
│   ├── memmap_utils.py
│   └── train_clusterer.py
├── conda-mdel.yml
├── configs
│   ├── fp16_1-3B_4M_bs_1.4T_tok_summit_pp3_mp2_256_nodes.yml
│   ├── fp16_2-7B_4M_bs_1.4T_tok_summit_pp6_mp2_256nodes.yml
│   └── fp16_6-7B_4M_bs_1T_tok_summit_pp12_mp2_mbs2_512_nodes_real.yml
├── distillation_sparsification
│   ├── README.md
│   ├── datautils.py
│   ├── distill.py
│   ├── falcon.py
│   ├── lion.py
│   ├── lm_seqs_dataset.py
│   ├── make_student.py
│   ├── modelutils.py
│   ├── process_data.py
│   ├── quant.py
│   ├── sparsegpt.py
│   ├── test.py
│   ├── test1.py
│   ├── tracker.py
│   ├── train.py
│   └── utils.py
├── docs
│   └── .gitkeep
├── lora-x
│   ├── README.md
│   ├── bpt_attention_plugin.py
│   ├── bpt_pt.py
│   ├── bpt_triton.py
│   ├── config.py
│   ├── configs
│   │   └── zero3_offload_config.json
│   ├── data.py
│   ├── experimental
│   │   ├── qlora_lomo.py
│   │   └── train_qlora_lomo.py
│   ├── flash_patch.py
│   ├── lora.py
│   ├── qlora_bpt.py
│   ├── requirements.txt
│   ├── scripts
│   │   └── juwels_booster.sh
│   └── utils.py
├── notebooks
│   ├── .gitkeep
│   ├── CalculatePerplexity.ipynb
│   └── Merge_N_Experts.ipynb
├── requirements.txt
├── resources.md
├── scripts
│   ├── c-btmInference.py
│   ├── calc_perplexities.sh
│   ├── calc_perplexities_slurm.sh
│   ├── create_domain_pile_mix.sh
│   ├── get_pile_shard1_data.sh
│   └── upload_to_hf.sh
├── setup.py
└── src
    └── mdel
        ├── __init__.py
        ├── calculate_perplexity.py
        ├── configs
        │   ├── config.yaml
        │   └── zero_config.json
        ├── eval_merges.py
        ├── iterate_layers.sh
        ├── merge_experts.py
        ├── pile_upload.py
        ├── pile_utils.py
        ├── train.sh
        ├── train_cbtm_classifier.py
        ├── train_chat.sh
        ├── train_ds.sh
        ├── trainer.py
        ├── trainer_chat.bat
        └── trainer_chat.py
/.github/workflows/pre-commit.yaml:
--------------------------------------------------------------------------------
1 | name: pre-commit
2 |
3 | on:
4 | workflow_call:
5 | pull_request_target:
6 |
7 | jobs:
8 | pre-commit:
9 | runs-on: ubuntu-latest
10 | steps:
11 | # in case of PR, check out the PR's head branch
12 | - uses: actions/checkout@v3
13 | if: github.event_name == 'pull_request_target'
14 | with:
15 | ref: ${{ github.event.pull_request.head.sha }}
16 |
17 | # in case of push, check out the main branch
18 | - uses: actions/checkout@v3
19 | if: github.event_name != 'pull_request_target'
20 |
21 | - uses: actions/setup-python@v4
22 | with:
23 | python-version: "3.10"
24 | cache: "pip"
25 | cache-dependency-path: "**/requirements*.txt"
26 | - uses: pre-commit/action@v3.0.0
27 | - name: Post PR comment on failure
28 | if: failure() && github.event_name == 'pull_request_target'
29 | uses: peter-evans/create-or-update-comment@v2
30 | with:
31 | issue-number: ${{ github.event.pull_request.number }}
32 | body: |
33 | :x: **pre-commit** failed.
34 | Please run `pre-commit run --all-files` locally and commit the changes.
35 | Find more information in the repository's CONTRIBUTING.md
36 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .DS_Store
132 |
133 | # SLURM results
134 | *.out
135 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
34 | {
35 | "keyToString": {
36 | "RunOnceActivity.OpenProjectViewOnStart": "true",
37 | "RunOnceActivity.ShowReadmeOnStart": "true",
38 | "last_opened_file_path": "/Users/kentsui/opensource/MDEL",
39 | "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable"
40 | }
41 | }
51 | 1683965013951
54 | 1683965013951
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # WARNING!
2 | #
3 | # When making changes to auto-formatters used in pre-commit hooks, you are
4 | # likely to cause merge conflicts with main and/or other pull requests.
5 | # Fixing them might revert other people's work. Expect pain!
6 | # To avoid accidental reversions and keep it easy to review, please make sure
7 | # that changes here are in a pull request by themselves, that it consists of
8 | # two commits:
9 | #
10 | # 1. The changes to this file
11 | # 2. Changes made by running `python3 -m pre_commit run --all-files`.
12 | #
13 | # Then each time your pull request is blocked by a merge conflict, do the
14 | # following steps:
15 | #
16 | # git reset HEAD^1 && git checkout -f # discard the change commit
17 | # git rebase main # re-apply other people's changes
18 | # python3 -m pre_commit run --all-files # re-run the rules
19 | # git add . # add the newly changed files
20 | # git commit -m 'apply pre-commit' # commit it
21 | # git push -f # force push back to your branch
22 | #
23 | # Keep in mind you may have to do this a few times, as changes here may impact
24 | # other pull requests. Try to keep it up-to-date so they can go in when it'll
25 | # cause least disruption.
26 | #
27 | # /WARNING!
28 |
29 | exclude: build|stubs|^bot/templates/$|openassistant/templates|docs/docs/api/openapi.json
30 |
31 | repos:
32 | - repo: https://github.com/pre-commit/pre-commit-hooks
33 | rev: v4.4.0
34 | hooks:
35 | - id: trailing-whitespace
36 | - id: check-ast
37 | - id: check-yaml
38 | # Always check YAML but skip a few YAML files that are auto-generated
39 | # and which break the standard YAML check. The alternative would be to
40 | # skip any unsafe errors (and thus break YAML compatibility) or use
41 | # some other checker that may not work in general.
42 | exclude: ^copilot/.*/addons/.*$
43 | - id: check-json
44 | - id: check-case-conflict
45 | - id: detect-private-key
46 | - id: fix-encoding-pragma
47 | args: [--remove]
48 | - id: forbid-submodules
49 | - id: mixed-line-ending
50 | - id: requirements-txt-fixer
51 | - id: check-executables-have-shebangs
52 | - id: check-shebang-scripts-are-executable
53 | - id: check-byte-order-marker
54 | - id: check-symlinks
55 | - id: check-merge-conflict
56 | - id: check-added-large-files
57 | args: [--maxkb=1024]
58 | - id: end-of-file-fixer
59 |
60 | - repo: https://github.com/pycqa/flake8
61 | rev: 6.0.0
62 | hooks:
63 | - id: flake8
64 | args: [--max-line-length=120, --ignore=C901]
65 |
66 | - repo: https://github.com/pycqa/isort
67 | rev: 5.12.0
68 | hooks:
69 | - id: isort
70 |
71 | - repo: https://github.com/pre-commit/mirrors-prettier
72 | rev: v3.0.0-alpha.4
73 | hooks:
74 | - id: prettier
75 | args: [--prose-wrap=always, --write]
76 |
77 | - repo: local
78 | hooks:
79 | - id: next-lint-website
80 | name: Lint website
81 | files: ^website/
82 | exclude: ^website/node_modules/
83 | types_or: [javascript, jsx, ts, tsx]
84 | language: node
85 | pass_filenames: false
86 | entry: website/next-lint.js
87 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [BASIC]
2 |
3 | # Naming style matching correct argument names.
4 | argument-naming-style=snake_case
5 |
6 | # Regular expression matching correct argument names. Overrides argument-
7 | # naming-style.
8 | #argument-rgx=
9 |
10 | # Naming style matching correct attribute names.
11 | attr-naming-style=snake_case
12 |
13 | # Regular expression matching correct attribute names. Overrides attr-naming-
14 | # style.
15 | #attr-rgx=
16 |
17 | # Bad variable names which should always be refused, separated by a comma.
18 | bad-names=foo,
19 | bar,
20 | baz,
21 | toto,
22 | tutu,
23 | tata
24 |
25 | # Naming style matching correct class attribute names.
26 | class-attribute-naming-style=any
27 |
28 | # Regular expression matching correct class attribute names. Overrides class-
29 | # attribute-naming-style.
30 | #class-attribute-rgx=
31 |
32 | # Naming style matching correct class names.
33 | class-naming-style=PascalCase
34 |
35 | # Regular expression matching correct class names. Overrides class-naming-
36 | # style.
37 | #class-rgx=
38 |
39 | # Naming style matching correct constant names.
40 | const-naming-style=UPPER_CASE
41 |
42 | # Regular expression matching correct constant names. Overrides const-naming-
43 | # style.
44 | #const-rgx=
45 |
46 | # Minimum line length for functions/classes that require docstrings, shorter
47 | # ones are exempt.
48 | docstring-min-length=-1
49 |
50 | # Naming style matching correct function names.
51 | function-naming-style=snake_case
52 |
53 | # Regular expression matching correct function names. Overrides function-
54 | # naming-style.
55 | #function-rgx=
56 |
57 | # Good variable names which should always be accepted, separated by a comma.
58 | good-names=i,
59 | j,
60 | k,
61 | ex,
62 | Run,
63 | _
64 |
65 | # Include a hint for the correct naming format with invalid-name.
66 | include-naming-hint=no
67 |
68 | # Naming style matching correct inline iteration names.
69 | inlinevar-naming-style=any
70 |
71 | # Regular expression matching correct inline iteration names. Overrides
72 | # inlinevar-naming-style.
73 | #inlinevar-rgx=
74 |
75 | # Naming style matching correct method names.
76 | method-naming-style=snake_case
77 |
78 | # Regular expression matching correct method names. Overrides method-naming-
79 | # style.
80 | #method-rgx=
81 |
82 | # Naming style matching correct module names.
83 | module-naming-style=snake_case
84 |
85 | # Regular expression matching correct module names. Overrides module-naming-
86 | # style.
87 | #module-rgx=
88 |
89 | # Colon-delimited sets of names that determine each other's naming style when
90 | # the name regexes allow several styles.
91 | name-group=
92 |
93 | # Regular expression which should only match function or class names that do
94 | # not require a docstring.
95 | no-docstring-rgx=^_
96 |
97 | # List of decorators that produce properties, such as abc.abstractproperty. Add
98 | # to this list to register other decorators that produce valid properties.
99 | # These decorators are taken in consideration only for invalid-name.
100 | property-classes=abc.abstractproperty
101 |
102 | # Naming style matching correct variable names.
103 | variable-naming-style=snake_case
104 |
105 | [CLASSES]
106 |
107 | # List of method names used to declare (i.e. assign) instance attributes.
108 | defining-attr-methods=__init__,
109 | __new__,
110 | setUp,
111 | __post_init__
112 |
113 | # List of member names, which should be excluded from the protected access
114 | # warning.
115 | exclude-protected=_asdict,
116 | _fields,
117 | _replace,
118 | _source,
119 | _make
120 |
121 | # List of valid names for the first argument in a class method.
122 | valid-classmethod-first-arg=cls
123 |
124 | # List of valid names for the first argument in a metaclass class method.
125 | valid-metaclass-classmethod-first-arg=cls
126 |
127 |
128 | [DESIGN]
129 |
130 | # Maximum number of arguments for function / method.
131 | max-args=10
132 |
133 | # Maximum number of attributes for a class (see R0902).
134 | max-attributes=20
135 |
136 | # Maximum number of boolean expressions in an if statement (see R0916).
137 | max-bool-expr=5
138 |
139 | # Maximum number of branch for function / method body.
140 | max-branches=12
141 |
142 | # Maximum number of locals for function / method body.
143 | max-locals=15
144 |
145 | # Maximum number of parents for a class (see R0901).
146 | max-parents=7
147 |
148 | # Maximum number of public methods for a class (see R0904).
149 | max-public-methods=20
150 |
151 | # Maximum number of return / yield for function / method body.
152 | max-returns=6
153 |
154 | # Maximum number of statements in function / method body.
155 | max-statements=50
156 |
157 | # Minimum number of public methods for a class (see R0903).
158 | min-public-methods=2
159 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | VENV = venv
2 | VENV_PYTHON = $(VENV)/bin/python
3 | SYSTEM_PYTHON = $(or $(shell which python3), $(shell which python))
4 | PYTHON = $(or $(wildcard $(VENV_PYTHON)), $(SYSTEM_PYTHON))
5 |
6 | $(VENV_PYTHON):
7 | if [ ! -d "$(VENV)" ]; then $(SYSTEM_PYTHON) -m venv $(VENV); \
8 | else \
9 | echo "Virtual environment already exists in directory $(VENV)"; \
10 | fi
11 |
12 | venv: $(VENV_PYTHON)
13 |
14 | setup_dev:
15 | pip install -r requirements.txt
16 | pre-commit install
17 | pip install -e .
18 |
19 |
20 | .PHONY: venv $(VENV_PYTHON)
21 |
--------------------------------------------------------------------------------
/Model Merge And Analysis Tools/Enhanced_Mixer.py:
--------------------------------------------------------------------------------
1 | '''
 2 | Original script by Concedo/LostRuins, the mastermind behind what they once called a "Rubbish experiment".
 3 | Now, an incredible leap forward in Language Model engineering and experimentation.
4 |
5 | Script modified by Chasm/Digitous
6 | '''
7 |
8 | import json
9 | import os
10 | import shutil
11 | import subprocess
12 | from tkinter.filedialog import askdirectory, askopenfilename
13 |
14 | import torch
15 | from colorama import Fore, Style, init
16 | from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
17 | LlamaConfig, LlamaForCausalLM, LlamaTokenizer,
18 | PreTrainedTokenizer, PreTrainedTokenizerFast)
19 |
20 | newline = '\n'
21 | def clear_console():
22 | if os.name == "nt": # For Windows
23 | subprocess.call("cls", shell=True)
24 | else: # For Linux and macOS
25 | subprocess.call("clear", shell=True)
26 |
27 | clear_console()
28 | print(f"{Fore.YELLOW}Starting script, please wait...{Style.RESET_ALL}")
29 |
30 | #mixer output settings
31 | blend_ratio = 0.5 #setting to 0 gives first model, and 1 gives second model
32 | fp16 = False #perform operations in fp16. Saves memory, but CPU inference will not be possible.
33 | always_output_fp16 = True #if true, will output fp16 even if operating in fp32
34 | max_shard_size = "10000MiB" #set output shard size
35 | force_cpu = True #only use cpu
36 | load_sharded = True #load both models shard by shard
37 |
38 | print(f"Blend Ratio set to: {Fore.GREEN}{blend_ratio}{Style.RESET_ALL}")
39 | print(f"Operations in fp16 is: {Fore.GREEN}{fp16}{Style.RESET_ALL}")
40 | print(f"Save Result in fp16: {Fore.GREEN}{always_output_fp16}{Style.RESET_ALL}")
41 | print(f"CPU RAM Only: {Fore.GREEN}{force_cpu}{Style.RESET_ALL}{newline}")
42 |
43 | #test generation settings, only for fp32
44 | deterministic_test = True #determines if outputs are always the same
45 | test_prompt = "" #test prompt for generation. only for fp32. set to empty string to skip generating.
46 | test_max_length = 32 #test generation length
47 |
48 |
49 | blend_ratio_b = 1.0 - blend_ratio
50 |
51 | def get_model_info(model):
52 | with torch.no_grad():
53 | outfo = ""
54 | cntent = 0
55 | outfo += "\n==============================\n"
56 | for name, para in model.named_parameters():
57 | cntent += 1
58 | outfo += ('{}: {}'.format(name, para.shape))+"\n"
59 | outfo += ("Num Entries: " + str(cntent))+"\n"
60 | outfo += ("==============================\n")
61 | return outfo
62 |
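# Blends the two models in place: each parameter of model1 becomes
# blend_ratio * p1 + blend_ratio_b * p2 (model2's parameters are scaled as a side effect).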
63 | def merge_models(model1,model2):
64 | with torch.no_grad():
65 | tensornum = 0
66 | for p1, p2 in zip(model1.parameters(), model2.parameters()):
67 | p1 *= blend_ratio
68 | p2 *= blend_ratio_b
69 | p1 += p2
70 | tensornum += 1
71 | print("Merging tensor "+str(tensornum))
72 | pass
73 |
74 | def read_index_filenames(sourcedir):
75 | index = json.load(open(sourcedir + '/pytorch_model.bin.index.json','rt'))
76 | fl = []
77 | for k,v in index['weight_map'].items():
78 | if v not in fl:
79 | fl.append(v)
80 | return fl
81 |
82 | print("Opening file dialog, please select FIRST model directory...")
83 | model_path1 = askdirectory(title="Select Directory of FIRST model to merge")
84 | print(f"First Model is: {model_path1}")
85 | print("Opening file dialog, please select SECOND model directory...")
86 | model_path2 = askdirectory(title="Select Directory of SECOND model to merge")
87 | print(f"Second Model is: {model_path2}")
88 | print("Opening file dialog, please select OUTPUT model directory...")
89 | model_path3 = askdirectory(title="Select Output Directory of merged model")
90 | print(f"Merged Save Directory is: {model_path3}{newline}")
91 | if not model_path1 or not model_path2:
92 | print("\nYou must select two directories containing models to merge and one output directory. Exiting.")
93 | exit()
94 |
95 | with torch.no_grad():
96 | if fp16:
97 | torch.set_default_dtype(torch.float16)
98 | else:
99 | torch.set_default_dtype(torch.float32)
100 |
101 | device = torch.device("cuda") if (torch.cuda.is_available() and not force_cpu) else torch.device("cpu")
102 | print(device)
103 |
104 | print("Loading Model 1...")
105 | model1 = AutoModelForCausalLM.from_pretrained(model_path1) #,torch_dtype=torch.float16
106 | model1 = model1.to(device)
107 | model1.eval()
108 | print("Model 1 Loaded. Dtype: " + str(model1.dtype))
109 | print("Loading Model 2...")
110 | model2 = AutoModelForCausalLM.from_pretrained(model_path2) #,torch_dtype=torch.float16
111 | model2 = model2.to(device)
112 | model2.eval()
113 | print("Model 2 Loaded. Dtype: " + str(model2.dtype))
114 |
115 | # Saving for posterity reasons, handy for troubleshooting if model result is broken
116 | # #ensure both models have the exact same layout
117 | # m1_info = get_model_info(model1)
118 | # m2_info = get_model_info(model2)
119 | # if m1_info != m2_info:
120 | # print("Model 1 Info: " + m1_info)
121 | # print("Model 2 Info: " + m2_info)
122 | # print("\nERROR:\nThe two selected models are not compatible! They must have identical structure!")
123 | # exit()
124 |
125 | print("Merging models...")
126 | merge_models(model1,model2)
127 |
128 | if model_path3:
129 | print("Saving new model...")
130 | if always_output_fp16 and not fp16:
131 | model1.half()
132 | model1.save_pretrained(model_path3, max_shard_size=max_shard_size)
133 | print("\nSaved to: " + model_path3)
134 | print("\nCopying files to: " + model_path3)
135 | files_to_copy = ["tokenizer.model", "special_tokens_map.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
136 | for filename in files_to_copy:
137 | src_path = os.path.join(model_path1, filename)
138 | dst_path = os.path.join(model_path3, filename)
139 | try:
140 | shutil.copy2(src_path, dst_path)
141 | except FileNotFoundError:
142 | print("\nFile " + filename + " not found in" + model_path1 + ". Skipping.")
143 | else:
144 | print("\nOutput model was not saved as no output path was selected.")
145 | print("\nScript Completed.")
146 |
--------------------------------------------------------------------------------
/Model Merge And Analysis Tools/Enhanced_Mixer_Requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/huggingface/transformers
2 | torch
3 | numpy
4 | matplotlib
5 | colorama
6 |
--------------------------------------------------------------------------------
/Model Merge And Analysis Tools/LM_BlockMerge_Requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/huggingface/transformers
2 | torch
3 | numpy
4 |
--------------------------------------------------------------------------------
/Model Merge And Analysis Tools/StratusScope.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import subprocess
4 | from tkinter import Tk, filedialog
5 |
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import torch
9 | from colorama import Fore, Style, init
10 | from transformers import AutoConfig, AutoModel, logging
11 |
12 | logging.set_verbosity_warning()
13 | logging.set_verbosity_error()
14 |
15 | def select_folder():
16 | Tk().withdraw()
17 | folder = filedialog.askdirectory()
18 | return folder
19 |
20 | def clear_console():
21 | if os.name == "nt": # For Windows
22 | subprocess.call("cls", shell=True)
23 | else: # For Linux and macOS
24 | subprocess.call("clear", shell=True)
25 |
26 | def load_sharded_layer(folder, target_layer):
27 | files = os.listdir(folder)
28 | model_files = sorted([f for f in files if f.startswith('pytorch_model-') and f.endswith('.bin')])
29 |
30 | layer_state_dict = {}
31 | for model_file in model_files:
32 | shard = torch.load(os.path.join(folder, model_file), map_location=torch.device('cpu'))
33 | for name, param in shard.items():
34 | layer_number = get_layer_number(name)
35 | if layer_number == target_layer:
36 | layer_state_dict[name] = param
37 |
38 | return layer_state_dict
39 |
40 | def get_layer_number(name):
41 | parts = name.split('.')
42 | for part in parts:
43 | if part.isdigit():
44 | return int(part)
45 | return None
46 |
47 | def get_total_layers(model_folder):
48 | files = os.listdir(model_folder)
49 | model_files = sorted([f for f in files if f.startswith('pytorch_model-') and f.endswith('.bin')])
50 |
51 | all_layers = set()
52 |
53 | for model_file in model_files:
54 | shard = torch.load(os.path.join(model_folder, model_file), map_location=torch.device('cpu'))
55 | for name in shard.keys():
56 | layer_number = get_layer_number(name)
57 | all_layers.add(layer_number)
58 |
59 | return len(all_layers)
60 |
61 | # https://pytorch.org/docs/stable/tensors.html
62 |
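# Sums the element-wise absolute differences of every matching tensor pair, per layer,
# across the two models' shards (tensors are cast to float32 first; pairs containing inf are skipped).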
63 | def compare_layers(model1_folder, model2_folder):
64 | layer_diffs = []
65 | newline = '\n'
66 | num_layers = get_total_layers(model1_folder) -1
67 | print(f"Torch Version: {torch.__version__}")
68 | print(f"Total Layers Found: {num_layers}{newline}")
69 |
70 | for layer_number in range(num_layers):
71 | layer_diff = 0
72 |
73 | model1_layer = load_sharded_layer(model1_folder, layer_number)
74 | model2_layer = load_sharded_layer(model2_folder, layer_number)
75 |
76 | for n1, p1 in model1_layer.items():
77 | p2 = model2_layer[n1]
78 |
79 | print(f"{newline}{Fore.YELLOW}--------Found Tensor Pair--------{newline}")
80 | print(f"p1 = {p1}")
81 | print(f"p2 = {p2}")
82 | print(f"{newline}{Fore.GREEN}--------Casting p1 & p2 tensor pair to float32--------{newline}")
83 | p1 = p1.detach().to(torch.float32)
84 | print(f"p1 = {p1}")
85 | p2 = p2.detach().to(torch.float32)
86 | print(f"p2 = {p2}")
87 |
88 | if not (torch.isinf(p1).any() or torch.isinf(p2).any()):
89 | diff = torch.abs(p1 - p2).sum().item()
90 | layer_diff += diff
91 |
92 | print(f"{newline}{Fore.CYAN}----------- Layer {layer_number}: Aggregate Difference = {layer_diff} -----------{Style.RESET_ALL}{newline}")
93 | layer_diffs.append(layer_diff)
94 |
95 | return layer_diffs
96 |
97 | def plot_layer_diff(layer_diffs, model1_name, model2_name):
98 | plt.figure(figsize=(20, 6))
99 | num_layers = len(layer_diffs)
100 | layer_indices = range(num_layers)
101 | plt.bar(layer_indices, layer_diffs)
102 | plt.xticks(layer_indices)
103 | plt.xlabel('Layer')
104 | plt.ylabel('Difference')
105 | plt.title(f"{model1_name} vs {model2_name} Layer Difference")
106 | plt.ylim(bottom=0)
107 | print("Script completed, close graph to unload models and return to commandline.")
108 | plt.show()
109 |
110 | def main():
111 | print("Select model1 folder:")
112 | model1_folder = select_folder()
113 | model1_name = os.path.basename(model1_folder)
114 | print("Select model2 folder:")
115 | model2_folder = select_folder()
116 | model2_name = os.path.basename(model2_folder)
117 |
118 | print("Examining Models...")
119 | clear_console()
120 | layer_diffs = compare_layers(model1_folder, model2_folder)
121 |
122 | plot_layer_diff(layer_diffs, model1_name, model2_name)
123 |
124 | torch.cuda.empty_cache()
125 | import gc
126 | gc.collect()
127 |
128 | if __name__ == "__main__":
129 | main()
130 |
--------------------------------------------------------------------------------
/Model Merge And Analysis Tools/StratusScope_BarGraph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huu4ontocord/MDEL/d84a598e765accfb723edd58f6c0a426d8c16d8d/Model Merge And Analysis Tools/StratusScope_BarGraph.png
--------------------------------------------------------------------------------
/Model Merge And Analysis Tools/StratusScope_ConsoleOutput.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huu4ontocord/MDEL/d84a598e765accfb723edd58f6c0a426d8c16d8d/Model Merge And Analysis Tools/StratusScope_ConsoleOutput.png
--------------------------------------------------------------------------------
/Model Merge And Analysis Tools/StratusScope_Requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/huggingface/transformers
2 | torch
3 | numpy
4 | matplotlib
5 | colorama
6 |
--------------------------------------------------------------------------------
/Model Merge And Analysis Tools/__Quick Tool Explainer__.txt:
--------------------------------------------------------------------------------
1 | ---All tools provided originate from the Ontocord-associated family (KoboldAI community)---
2 |
3 | StratusScope is a language model layer analysis tool that charts the per-layer aggregate difference between a base model
4 | and any fine-tuned variant as a bar graph, and prints the tensors and the differences it finds to the console.
5 |
6 | LM_BlockMerge is a per-layer language model merging tool that lets you choose, for each layer, what percentage to weight-sum merge from model A into model B.
7 |
8 | EnhancedMixer is a simple weight-sum merge tool: run it and choose two models to combine; the merge percentage and other simple parameters can be adjusted inside the well-documented .py script.
9 |
10 | ------------------------------------
11 | In-Depth Explainer on Each Tool:
12 | ------------------------------------
13 |
14 | [[StratusScope]]
15 |
16 | A language model tool that uses Hugging Face's Transformers library to load two language models of the same architecture and parameter size, consolidates the weights and biases within each layer of both models in memory, computes the aggregate difference per layer, and generates a matplotlib bar graph showing which layers differ the most between the two models.
17 |
18 | Use Case - This is an invaluable tool to measure layer differences between a base model and a fine-tune of that model to determine which layers inherited the most change from fine-tuning.
19 |
20 | Original Git [Author: Digitous]
21 | https://github.com/Digitous/StratusScope
22 |
23 | ------------------------------------
24 |
25 | [[Language Model Transformer Block Merge]]
26 |
27 | Uses a tkinter GUI that analyzes two selected models and allows per-layer percentage
28 | weight merging controlled by the user (inspired by an Image Diffusion block merging
29 | technique and applied to transformers based Language Models).
30 |
31 | Usage:
32 | Start the script via the command
33 | [On Linux] python ./LM_BlockMerge.py
34 | [On Windows] python LM_BlockMerge.py
35 | The script will then prompt you for three folders:
36 | The first model
37 | The first model will be used as a base for the transformers configuration, and will be used as reference when handling the layers and weights.
38 | The second model
39 | The second model will only be used for providing the secondary layers for the merge.
40 | The output folder
41 | The resulting model will be saved inside the selected directory, in a folder called "./this/is/your/chosen_path/" + "/converted_model"
42 | The script will then load the weights in memory, according to the precision (32/16 bit) chosen in the configuration header, and will subsequently prompt the user with a popup GUI listing all of the layers available in the selected models.
43 |
44 | The user will be able to merge layers according to any strategy, ranging from:
45 | Creating the output_model by completely replacing N layers at will;
46 | Creating the output_model by choosing an individual mix ratio on N layers;
47 | Any mix of the two strategies on N chosen layers.
48 | The layers will be merged according to the individual choices per layer, and the resulting model weights will be saved onto the output folder, alongside the first model's config.json file.
49 |
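As a rough sketch of the idea only (not the tool's actual code), the per-layer merge amounts to a weighted
sum with a user-chosen ratio per layer; layer_of() below is a hypothetical helper that maps a parameter
name to its layer index:

    def block_merge(state_a, state_b, ratio_per_layer, layer_of):
        # ratio 1.0 keeps model A's tensor, 0.0 takes model B's,
        # anything in between mixes the two element-wise
        merged = {}
        for name, tensor_a in state_a.items():
            r = ratio_per_layer.get(layer_of(name), 1.0)
            merged[name] = r * tensor_a + (1.0 - r) * state_b[name]
        return merged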
50 | Available output settings:
51 | fp16 = False # Perform operations in fp16. Saves memory, but CPU inference will not be possible.
52 | always_output_fp16 = True # If true, will output fp16 even if operating in fp32
53 | max_shard_size = "2000MiB" # Set output shard size
54 | verbose_info = True # Will show model information when loading
55 | force_cpu = True # Only use cpu
56 |
57 | Supported Models:
58 | GPT-NeoX, Pythia, GPT-J, Opt, Llama
59 | Pseudo-Supported Models:
60 |
61 | BERT (testing required to validate implementation)
62 | Notes:
63 |
64 | Performing the operation in FP16 mode halves the memory requirements but massively slows down loading the models into memory. Always outputting in fp16 is preferable to save storage space, especially if the original weights were already quantized down to 16-bit. If your original models use 32-bit precision, decide whether or not you wish to halve the precision of the resulting file.
65 |
66 | Model loading is automatic; the script determines the model type and adjusts accordingly, with no special command-line flags required. Current GPT-NeoX support is hacky: it tends to error mid-merge with 20B, but it should work on smaller GPT-NeoX and Pythia models (6B or lower) until a proper fix is implemented.
67 |
68 | Original Git [Collaborators: LostRuins aka Concedo and TeH_Venom aka TehVenomm and Chasm aka Digitous aka Erik]
69 | https://github.com/TehVenomm/LM_Transformers_BlockMerge
70 |
71 | ------------------------------------
72 |
73 | [[Enhanced Mixer]]
74 |
75 | A simple tool that lets a user select two models of the same architecture and parameter size and merge them. Settings for the merge percentage and more are documented inside the script. Simply run it and select which models to merge.
76 |
77 | No GitHub Page [Original Author: LostRuins | Enhancements by: Digitous]
78 |
79 | ------------------------------------
80 |
81 | Authors GitHub Pages:
82 | https://github.com/LostRuins
83 | https://github.com/TehVenomm
84 | https://github.com/Digitous
85 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MDEL
2 |
3 | Multi-Domain Expert Learning
4 |
5 | # Environment Setup
6 |
7 | To set up the development environment, run `make setup_dev`. This installs the
8 | requirements, installs the package in editable mode, and sets up the pre-commit hooks.
9 |
10 | ## Creating Expert Datasets
11 |
12 | First, make sure you followed the Environment Setup guidelines.
13 |
14 | To create an expert dataset using the Pile data, follow these steps:
15 |
16 | 1. Download the Pile shard 1 data: `./scripts/get_pile_shard1_data.sh`
17 | 2. To set the domain, edit the variable `SUBSET_NAME` in
18 | `scripts/create_domain_pile_mix.sh`. This should be set to a valid value of
19 | the Pile's variable `pile_set_name`. A list of valid values can be found
20 | below.
21 | 3. Run the above script to process the dataset (see the sketch after this list)
22 | 4. Authenticate with Hugging Face:
23 | `export HF_ACCESS_TOKEN={YOUR HUGGINGFACE TOKEN}`
24 | 5. Set the dataset name in `scripts/upload_to_hf.sh`
25 | 6. Run the above script to upload the processed dataset to HuggingFace
26 |
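Steps 2 and 3 boil down to keeping only the documents whose `pile_set_name` matches the
chosen domain. A minimal sketch of that filtering, assuming a locally downloaded jsonl
shard (file names below are placeholders; the actual script may differ):

```python
import json

SUBSET_NAME = "ArXiv"            # must be a valid pile_set_name value
IN_PATH = "pile_shard1.jsonl"    # hypothetical local shard
OUT_PATH = "arxiv_subset.jsonl"  # hypothetical output file

with open(IN_PATH) as src, open(OUT_PATH, "w") as dst:
    for line in src:
        doc = json.loads(line)
        if doc.get("meta", {}).get("pile_set_name") == SUBSET_NAME:
            dst.write(json.dumps(doc) + "\n")
```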
27 | ### Pile Subsets
28 |
29 | - Pile-CC
30 | - PubMed Central
31 | - Books3†
32 | - OpenWebText2
33 | - ArXiv
34 | - Github
35 | - FreeLaw
36 | - Stack Exchange
37 | - USPTO Backgrounds
38 | - PubMed Abstracts
39 | - Gutenberg (PG-19)†
40 | - OpenSubtitles†
41 | - Wikipedia (en)†
42 | - DM Mathematics†
43 | - Ubuntu IRC
44 | - BookCorpus2
45 | - EuroParl†
46 | - HackerNews
47 | - YoutubeSubtitles
48 | - PhilPapers
49 | - NIH ExPorter
50 | - Enron Emails†
51 |
52 | # Training Expert Models
53 |
54 | 1. Clone this repo and follow the Environment Setup instructions
55 | 2. Set up HF authentication: `export HUGGING_FACE_HUB_TOKEN=[FILL ME]`
56 | 3. Set up W&B authentication: `export WANDB_API_KEY=[FILL ME]`
57 | 4. Edit the variable `DATASET` in script `src/mdel/train.sh` to match a valid
58 | dataset name on the
59 | [MDEL HF](https://huggingface.co/Multi-Domain-Expert-Layers).
60 | 5. Run the above script in background mode to start the training: `./train.sh &`
61 | 6. The trained model should be uploaded to the MDEL HF
62 |
63 | # Merging Expert Models
64 |
65 | 1. Clone this repo and follow the Environment Setup instructions
66 | 2. Set up HF authentication: `export HUGGING_FACE_HUB_TOKEN=[FILL ME]`
67 | 3. Run the merge script
68 |
69 | ```bash
70 | python src/mdel/merge_experts.py \
71 | --hf-repo your_hf_username/desired_name_of_merged_model \
72 | -e mdel/expert_1 \
73 | -e mdel/expert_2 \
74 | -e mdel/expert_n
75 | ```
76 |
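As a rough illustration of what merging experts means, the sketch below simply averages the
parameters of same-architecture experts. `merge_experts.py` is the supported entry point and
its actual strategy may differ; the repo IDs below are placeholders:

```python
import torch
from transformers import AutoModelForCausalLM

expert_repos = ["mdel/expert_1", "mdel/expert_2"]  # placeholders
experts = [AutoModelForCausalLM.from_pretrained(r) for r in expert_repos]
state_dicts = [e.state_dict() for e in experts]

# Uniformly average floating-point tensors; copy everything else from the first expert.
merged_state = {}
for key, value in state_dicts[0].items():
    if value.is_floating_point():
        merged_state[key] = torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
    else:
        merged_state[key] = value

merged = experts[0]
merged.load_state_dict(merged_state)
merged.save_pretrained("merged-expert")  # push_to_hub can then upload it
```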
77 | # Evaluating Perplexity of Models
78 |
79 | 1. Clone this repo and follow the Environment Setup instructions
80 | 2. Set up HF authentication: `export HUGGING_FACE_HUB_TOKEN=[FILL ME]`
81 | 3. Run the perplexity script
82 |
83 | ```bash
84 | python3 src/mdel/calculate_perplexity.py \
85 | --model Multi-Domain-Expert-Layers/expert-arxiv \
86 | --dataset Multi-Domain-Expert-Layers/arxiv \
87 | --split validation_domain
88 | ```
89 |
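Perplexity here is the exponentiated average negative log-likelihood of the evaluation split
under the model. `calculate_perplexity.py` is the supported entry point; the sketch below only
illustrates the underlying computation (batching and striding details may differ):

```python
import math
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Multi-Domain-Expert-Layers/expert-arxiv"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

def perplexity(texts, max_length=1024):
    total_nll, total_tokens = 0.0, 0
    with torch.no_grad():
        for text in texts:
            ids = tokenizer(text, return_tensors="pt",
                            truncation=True, max_length=max_length).input_ids
            if ids.size(1) < 2:
                continue
            # The model shifts labels internally; loss is the mean NLL per predicted token.
            loss = model(ids, labels=ids).loss
            total_nll += loss.item() * (ids.size(1) - 1)
            total_tokens += ids.size(1) - 1
    return math.exp(total_nll / total_tokens)
```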
90 | # References
91 |
92 | Gao, L., Biderman, S., Black, S., Golding, L., Hoppe, T., Foster, C., ... &
93 | Leahy, C. (2020). The Pile: An 800GB dataset of diverse text for language
94 | modeling. _arXiv preprint arXiv:2101.00027_.
95 |
--------------------------------------------------------------------------------
/clustering/download.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import logging
3 |
4 | import trafilatura
5 |
6 | from transformers import pipeline
7 | from transformers import AutoTokenizer
8 |
9 | import numpy as np
10 |
11 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12 |
13 | max_embedding_characters = 128 # This is a deliberately low value, as the current model is not intended for document embedding
14 |
15 | feature_extractor_checkpoint = 'sentence-transformers/LaBSE'
16 | tokenizer_checkpoint = 'gpt2'
17 |
18 | feature_extractor = pipeline('feature-extraction', framework='pt', model=feature_extractor_checkpoint)
19 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
20 |
21 | def fetch_and_parse(url):
22 | try:
23 | response = requests.get(url, timeout=10)
24 |
25 | response.raise_for_status()
26 | except (requests.HTTPError, requests.ConnectionError, requests.Timeout) as error:
27 | logging.error(f'Failed to fetch {url}: {error}')
28 |
29 | return None, None
30 |
31 | content = response.text
32 |
33 | markdown = trafilatura.extract(content, output_format='txt', include_formatting=True, \
34 | include_tables=True, include_images=True, no_fallback=True, include_links=True)
35 |
36 | return content, markdown
37 |
38 | def embed(text):
39 | embedding = feature_extractor(text)
40 |
41 | return embedding
42 |
43 | def tokenize(text):
44 | tokens = tokenizer.encode(text)
45 |
46 | return tokens
47 |
48 | def process_url(url):
49 | content, markdown = fetch_and_parse(url)
50 |
51 | content_short = content[:max_embedding_characters]
52 |
53 | tokens = tokenize(content)
54 | embedding = embed(content_short)
55 |
56 | embedding = np.array(embedding)
57 |
58 | return content, markdown, tokens, embedding
59 |
60 | def main():
61 | url = 'https://huggingface.co'
62 |
63 | content, markdown, tokens, embedding = process_url(url)
64 |
65 | for current in [content, markdown, embedding.shape]:
66 | print(f'{"-" * 32}\n{current}')
67 |
68 | print('-' * 32)
69 |
70 | if __name__ == '__main__':
71 | main()
72 |
--------------------------------------------------------------------------------
/clustering/feature_extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformers import AutoTokenizer, AutoModelForCausalLM
4 |
5 |
6 | class FeatureExtractor:
7 | def __init__(self, device='cpu', model_id='bigscience/bloom-560m', num_decoder_blocks=8):
8 | self.device = device
9 |
10 | self.num_decoder_blocks = num_decoder_blocks
11 | self.model_id = model_id
12 |
13 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
14 |
15 | self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
16 |
17 | h = self.model.transformer.h[:num_decoder_blocks] # Note that this will change for different families of models
18 | self.model.transformer.h = h
19 |
20 | self.model = self.model.to(device)
21 |
22 |
23 | def encode(self, text):
24 | tokens = self.tokenizer(text, padding=True, return_tensors='pt').to(self.device)
25 |
26 | output = self.model(**tokens, output_hidden_states=True).hidden_states[-1]
27 | output = output.detach().cpu().numpy()
28 |
29 | return output
30 |
31 |
32 | def __call__(self, text):
33 | output = self.encode(text)
34 |
35 | return output
36 |
37 |
38 | def main():
39 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40 | print(f'Using {device} device')
41 |
42 | feature_extractor = FeatureExtractor(device=device)
43 |
44 | output = feature_extractor('Hello world!')
45 | print(output)
46 |
47 |
48 | if __name__ == '__main__':
49 | main()
50 |
--------------------------------------------------------------------------------
/clustering/hierarchical_clustering.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import functional as F
7 | from torch import arange, argmax
8 | from tqdm import tqdm
9 | from collections import Counter
10 |
11 | import uuid
12 |
13 | import numpy as np
14 | from fast_pytorch_kmeans import KMeans
15 |
16 | from feature_extractor import FeatureExtractor
17 | from memmap_utils import np_memmap, get_np_memmap_length
18 |
19 |
20 | class ClusterAnalysis(nn.Module):
21 | def __init__(
22 | self,
23 | mmap_file=None,
24 | embed_dim=128,
25 | dtype=np.float32,
26 | ):
27 | super().__init__()
28 |
29 | self.mmap_file = mmap_file
30 | self.embed_dim = embed_dim
31 |
32 | self.dtype = dtype
33 |
34 | self.clusters = {}
35 | self.span_to_cluster_label = {}
36 |
37 |
38 | @staticmethod
39 | def _cluster_one_batch(
40 | true_k,
41 | spans,
42 | clusters,
43 | span_to_cluster_label,
44 | level,
45 | cluster_embeddings,
46 | min_overlap_merge_cluster,
47 | device
48 | ):
49 | with torch.no_grad():
50 | embeddings = torch.from_numpy(cluster_embeddings)
51 |
52 | km = KMeans(n_clusters=true_k, mode='cosine')
53 | km_labels = km.fit_predict(embeddings.to(device=device, dtype=torch.float32)).tolist()
54 |
55 | embeddings = None
56 |
57 | if not clusters:
58 | label_to_label = {}
59 |
60 | for span, label in zip(spans, km_labels):
61 | label = (label, level)
62 |
63 | if label not in label_to_label:
64 | label_to_label[label] = (span[0], level)
65 |
66 | label = label_to_label[label]
67 |
68 | clusters[label] = clusters.get(label, []) +[ span]
69 | span_to_cluster_label[span] = label
70 |
71 | output = list(clusters.keys())
72 |
73 | return output
74 |
75 | tmp_cluster = {}
76 |
77 | for span, label in zip(spans, km_labels):
78 | tmp_cluster[label] = tmp_cluster.get(label, [])+[span]
79 |
80 | new_labels = []
81 |
82 | for a_cluster in tmp_cluster.values():
83 | for span in a_cluster:
84 | need_labels = [span for span in a_cluster if span not in span_to_cluster_label or span_to_cluster_label[span][1] != level]
85 | cluster_labels = [span_to_cluster_label[span] for span in a_cluster if span in span_to_cluster_label and span_to_cluster_label[span][1] == level]
86 |
87 | if not need_labels:
88 | continue
89 |
90 | if not cluster_labels:
91 |
92 | label = (span[0], level)
93 |
94 | else:
95 | most_common = Counter(cluster_labels).most_common(1)[0]
96 |
97 | if most_common[1] < min_overlap_merge_cluster:
98 | label = (span[0], level)
99 |
100 | else:
101 | label = most_common[0]
102 |
103 | new_labels.append(label)
104 |
105 | for span in need_labels:
106 | clusters[label] = clusters.get(label, []) + [span]
107 | span_to_cluster_label[span] = label
108 |
109 | return new_labels
110 |
111 |
112 | def create_hiearchical_clusters(
113 | self,
114 | force_recluster_idxs=None,
115 | max_level=4,
116 | max_cluster_size=32, # Small value for debug purposes
117 | min_overlap_merge_cluster=2,
118 | prefered_leaf_node_size=None,
119 | kmeans_batch_size=250000,
120 | use_tqdm=False,
121 | device='cuda:0'
122 | ):
123 | mmap_file = self.mmap_file
124 | embed_dim = self.embed_dim
125 | dtype = self.dtype
126 |
127 | mmap_len = get_np_memmap_length(mmap_file, [0, embed_dim], dtype=dtype)
128 |
129 | clusters = self.clusters
130 | span_to_cluster_label = self.span_to_cluster_label
131 |
132 | if force_recluster_idxs:
133 | force_recluster_idxs = set(force_recluster_idxs)
134 | else:
135 | force_recluster_idxs = ()
136 |
137 | already_clustered = set([span[0] for span in span_to_cluster_label if span[1] == 0 and span[0] not in force_recluster_idxs])
138 |
139 | idxs = []
140 |
141 | if force_recluster_idxs:
142 | idxs = list(force_recluster_idxs)
143 | force_recluster_idxs = None
144 |
145 | idxs.extend([idx for idx in range(mmap_len) if idx not in already_clustered])
146 |
147 | if not idxs:
148 | return
149 |
150 | already_clustered = list(already_clustered)
151 |
152 | if len(already_clustered) > int(0.5 * kmeans_batch_size):
153 | idxs.extend(random.sample(already_clustered, int(0.5 * kmeans_batch_size)))
154 | else:
155 | idxs.extend(already_clustered)
156 |
157 | already_clustered = None
158 |
159 | idxs.extend([span[0] for span in span_to_cluster_label if span[1] != 0])
160 | idxs = list(set(idxs))
161 | random.shuffle(idxs)
162 |
163 | if not prefered_leaf_node_size:
164 | prefered_leaf_node_size= int(max_cluster_size * 0.7)
165 |
166 | for level in range(max_level):
167 | all_spans = [(idx, level) for idx in idxs]
168 | len_spans = len(all_spans)
169 |
170 | step_size = int(0.7 * kmeans_batch_size)
171 | num_times = max(3, math.ceil(len_spans / step_size))
172 |
173 | if use_tqdm:
174 | num_times_2 = tqdm.tqdm(range(num_times))
175 |
176 | else:
177 | num_times_2 = range(num_times)
178 |
179 | for times in num_times_2:
180 | max_rng = min(len_spans, step_size)
181 |
182 | spans = all_spans[:max_rng]
183 |
184 | not_already_clustered = [span for span in all_spans[:max_rng - step_size] if span not in span_to_cluster_label]
185 |
186 | if len(not_already_clustered) > int(0.5 * kmeans_batch_size):
187 | spans.extend(random.sample(not_already_clustered, int(0.5 * kmeans_batch_size)))
188 | else:
189 | spans.extend(not_already_clustered)
190 |
191 | if len(spans) == 0: break
192 |
193 | already_clustered = [span for span in all_spans[:max_rng - step_size] if span in span_to_cluster_label]
194 |
195 | if len(already_clustered) > int(0.5 * kmeans_batch_size):
196 | spans.extend(random.sample(already_clustered, int(0.5 * kmeans_batch_size)))
197 |
198 | else:
199 | spans.extend(already_clustered)
200 |
201 | embedding_idxs = [span[0] for span in spans]
202 |
203 | if level == 0:
204 | true_k = int(len(embedding_idxs) / prefered_leaf_node_size)
205 |
206 | else:
207 |                     true_k = int(len(embedding_idxs) / max_cluster_size)
208 |
209 | cluster_embeddings = np_memmap(mmap_file, shape=[mmap_len, embed_dim], idxs=embedding_idxs, dtype=dtype)
210 |
211 | new_labels = self._cluster_one_batch(true_k, spans, clusters, span_to_cluster_label, level, cluster_embeddings, min_overlap_merge_cluster, device)
212 |
213 | if not new_labels:
214 | break
215 |
216 | need_more = False
217 |
218 |                 assert prefered_leaf_node_size <= max_cluster_size, 'prefered_leaf_node_size must not exceed max_cluster_size'
219 |
220 | if times <= num_times - 2:
221 | for label in new_labels:
222 | if len(clusters[label]) < prefered_leaf_node_size:
223 | del clusters[label]
224 |
225 | need_more = True
226 |
227 | if not need_more:
228 | break
229 |
230 | idxs = [val[0][0] for key, val in clusters.items() if key[1] == level]
231 |
232 | if len(idxs) < max_cluster_size:
233 | break
234 |
235 |
236 | def main():
237 | cluster_analysis = ClusterAnalysis(
238 | mmap_file='output/embeddings.mmap',
239 | embed_dim=1024
240 | )
241 |
242 | cluster_analysis.create_hiearchical_clusters()
243 |
244 | print(list(cluster_analysis.clusters.keys()))
245 |
246 |
247 | if __name__ == '__main__':
248 | main()
249 |
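
# Editor's note: the helper below is an illustrative sketch (not part of the
# original module) for inspecting the result of create_hiearchical_clusters();
# it only relies on the `clusters` dict built above, which is keyed by
# (index, level) tuples and maps to lists of spans.

def summarize_clusters(cluster_analysis):
    """Print, per level, how many clusters were formed and their mean size."""
    sizes_per_level = {}
    for (_, level), members in cluster_analysis.clusters.items():
        sizes_per_level.setdefault(level, []).append(len(members))

    for level, sizes in sorted(sizes_per_level.items()):
        print(f'level {level}: {len(sizes)} clusters, mean size {sum(sizes) / len(sizes):.1f}')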
--------------------------------------------------------------------------------
/clustering/memmap_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 |
6 | def is_contiguous(arr):
7 | start = None
8 | prev = None
9 | contiguous = True
10 |
11 | for idx in arr:
12 | if start is None:
13 | start = idx
14 | if prev is None or idx == prev + 1:
15 | prev = idx
16 |
17 | continue
18 |
19 | contiguous = False
20 |
21 | break
22 |
23 | return contiguous, start, idx + 1
24 |
25 |
26 | def np_memmap(file_name, data=None, idxs=None, shape=None, dtype=np.float32, offset=0, order='C'):
27 | if not file_name.endswith('.mmap'):
28 | file_name += '.mmap'
29 |
30 | if os.path.exists(file_name):
31 | mode = 'r+'
32 | else:
33 | mode = 'w+'
34 |
35 | if shape is None and data is not None:
36 | shape = data.shape
37 |
38 | if not shape:
39 | shape = [0, 1]
40 |
41 | memmap = np.memmap(file_name, mode=mode, dtype=dtype, shape=tuple(shape), offset=offset, order=order)
42 |
43 | if idxs:
44 | contiguous, start, end = is_contiguous(idxs)
45 |
46 | if data is not None:
47 | if tuple(shape) == tuple(data.shape):
48 | memmap[:] = data
49 | elif contiguous:
50 | memmap[start:end] = data
51 | else:
52 | memmap[idxs] = data
53 |
54 | return memmap
55 |
56 |
57 | def get_np_memmap_length(file_name, shape, dtype=np.float32):
58 | if not os.path.exists(file_name):
59 | return shape[0]
60 |
61 | else:
62 | size = np.dtype(dtype).itemsize * np.prod(shape[1:])
63 |
64 | return int(os.path.getsize(file_name) / size)
65 |
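
# Editor's note: an illustrative usage sketch (not part of the original module);
# the file name and shapes below are placeholders.

if __name__ == '__main__':
    demo_file = 'demo_embeddings.mmap'
    rows = np.random.rand(8, 4).astype(np.float32)

    # Write all rows at once; the shape is inferred from `data`.
    np_memmap(demo_file, data=rows)

    # Recover the row count from the file size and the per-row byte width.
    n_rows = get_np_memmap_length(demo_file, [0, 4], dtype=np.float32)

    # Re-open the memmap and slice it like a normal array.
    emb = np_memmap(demo_file, shape=[n_rows, 4], dtype=np.float32)
    print(n_rows, emb[2:5].shape)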
--------------------------------------------------------------------------------
/clustering/train_clusterer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 |
4 | import numpy as np
5 |
6 | import pickle
7 | from pathlib import Path
8 |
9 | import torch
10 |
11 | from tqdm.auto import tqdm
12 |
13 | from torch.utils.data import DataLoader, IterableDataset
14 |
15 | from kmeans_pytorch import KMeans as BalancedKMeans
16 |
17 | from transformers import pipeline
18 |
19 | from datasets import load_dataset
20 |
21 | from sklearn.manifold import TSNE
22 | from sklearn.decomposition import PCA
23 |
24 | import matplotlib.pyplot as plt
25 |
26 | from feature_extractor import FeatureExtractor
27 | from memmap_utils import np_memmap
28 |
29 |
30 | def load_model(path_to_model: Path):
31 | with open(path_to_model, 'rb') as file:
32 | output = pickle.load(file)
33 |
34 | file.close()
35 |
36 | return output
37 |
38 |
39 | def extract_features(corpus, feature_extractor, batch_size=32, max_chars=256):
40 | corpus = [element[:max_chars] for element in corpus]
41 | batches = np.array_split(corpus, len(corpus) // batch_size, axis=0)
42 |
43 | features = []
44 |
45 | for batch in tqdm(batches):
46 | batch = list(batch) # batches is a list of numpy arrays
47 |
48 | features_current = feature_extractor(batch)
49 | features_current = np.max(features_current, axis=1)
50 |
51 | features.append(features_current)
52 |
53 | features = np.concatenate(features, axis=0)
54 |
55 | return features
56 |
57 |
58 | def train_kmeans(features, n_clusters, path_to_kmeans, balanced=False, device='cpu'):
59 | kmeans = BalancedKMeans(n_clusters=n_clusters, device=device, balanced=balanced)
60 |
61 | batch_size = 512 # Hyperparameter
62 | batch_size = min(batch_size, len(features))
63 |
64 | batches = np.array_split(features, features.shape[0] // batch_size, axis=0)
65 |
66 | for idx, batch in tqdm(enumerate(batches)):
67 | kmeans.fit(torch.from_numpy(batch), iter_limit=20, online=True, iter_k=idx)
68 |
69 | with open(path_to_kmeans, 'wb+') as file:
70 | pickle.dump(kmeans, file)
71 |
72 | file.close()
73 |
74 | return kmeans
75 |
76 |
77 | def main(n_clusters=16, balanced=False, output_dir=Path('cluster_output/'), shuffle_dataset=True, take_sample=None, embed_only=False, seed=42, visualize=False):
78 | dataset_name_train = 'JeanKaddour/minipile'
79 | content_column_train = 'text'
80 |
81 | subset_train = None # 'p3'
82 | split_train = 'train'
83 |
84 | dataset_train = load_dataset(dataset_name_train, subset_train, split=split_train, streaming=(take_sample is not None))
85 |
86 | if shuffle_dataset:
87 | dataset_train = dataset_train.shuffle(seed=seed)
88 |
89 | corpus = []
90 |
91 | for idx, element in enumerate(dataset_train):
92 | corpus.append(element[content_column_train])
93 |
94 | if take_sample:
95 | if idx >= take_sample:
96 | break
97 |
98 | if not output_dir.is_dir():
99 | output_dir.mkdir(parents=True, exist_ok=True)
100 |
101 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
102 | print(f'Using {device} device')
103 |
104 | feature_extractor_batch_size = 1
105 |
106 | feature_extractor_checkpoint = 'bigscience/bloom-560m' # 'sentence-transformers/LaBSE' # 'xlm-roberta-large'
107 | feature_extractor = FeatureExtractor(device=device) # pipeline('feature-extraction', framework='pt', model=feature_extractor_checkpoint)
108 |
109 | features = extract_features(corpus, feature_extractor, batch_size=feature_extractor_batch_size)
110 |
111 | memmap_file_path = 'output/embeddings.mmap' # TODO: Create a configs.py file
112 |
113 | np_memmap(memmap_file_path, data=features)
114 |
115 | if embed_only:
116 | return
117 |
118 | path_to_kmeans = output_dir / 'kmeans.pkl'
119 | kmeans = train_kmeans(features, n_clusters, path_to_kmeans, balanced=balanced, device=device)
120 |
121 | if visualize:
122 | tsne = TSNE(n_components=2)
123 | features_2d = tsne.fit_transform(features)
124 |
125 | plt.scatter(features_2d[:, 0], features_2d[:, 1], c=kmeans.predict(torch.from_numpy(features)).cpu())
126 | plt.show()
127 |
128 | return kmeans
129 |
130 |
131 | if __name__ == '__main__':
132 | parser = argparse.ArgumentParser()
133 |
134 | parser.add_argument('--n-clusters', required=True, type=int)
135 | parser.add_argument('--balanced', action='store_true')
136 | parser.add_argument('--output-dir', required=True, type=Path)
137 | parser.add_argument('--eval-only', action='store_true')
138 |     parser.add_argument('--shuffle-dataset', required=False, type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=True)
139 |     parser.add_argument('--take-sample', required=False, type=int)
140 |     parser.add_argument('--embed-only', required=False, type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=False)
141 | parser.add_argument('--visualize', action='store_true')
142 |
143 | args = parser.parse_args()
144 |
145 | if not args.eval_only:
146 | kmeans = main(
147 | n_clusters=args.n_clusters,
148 | balanced=args.balanced,
149 | output_dir=args.output_dir,
150 | take_sample=args.take_sample,
151 | shuffle_dataset=args.shuffle_dataset,
152 | embed_only=args.embed_only,
153 | visualize=args.visualize
154 | )
155 |
156 | path_to_kmeans = args.output_dir / 'kmeans.pkl'
157 | kmeans = load_model(path_to_kmeans)
158 |
159 |
160 | # Usage
161 |
162 | # python3 train_clusterer.py --n-clusters 4 --output-dir output/ --take-sample 128 --embed-only False --visualize
163 |
164 | # Warning! You need to install kmeans from https://github.com/kernelmachine/balanced-kmeans.git
165 |
166 | # cd ..
167 | # git clone https://github.com/kernelmachine/balanced-kmeans.git
168 | # cd balanced-kmeans
169 | # pip3 install -e .
170 |
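
# Editor's note: an illustrative sketch (not part of the original script) of using
# the saved clusterer on new text. It assumes FeatureExtractor behaves as in
# extract_features() above and that the balanced-kmeans object exposes predict(),
# as it does in main(); the default path is a placeholder.

def assign_clusters(texts, kmeans_path=Path('output/kmeans.pkl'), device='cpu'):
    extractor = FeatureExtractor(device=device)
    features = extract_features(texts, extractor, batch_size=1)
    kmeans = load_model(kmeans_path)

    return kmeans.predict(torch.from_numpy(features)).tolist()

# Example: assign_clusters(['An example paragraph to route to a domain cluster.'])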
--------------------------------------------------------------------------------
/conda-mdel.yml:
--------------------------------------------------------------------------------
1 | name: mdel
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - pytorch==2.0.*
9 | - pytorch-cuda==11.8
10 | - cuda
11 | - pip
12 | - git=2.35.1
13 | - ninja
14 | - make
15 | - cxx-compiler
16 | - wget
17 | - git-lfs
18 | - pip:
19 | - accelerate~=0.18.0
20 | - datasets~=2.11.0
21 | - deepspeed==0.9.2
22 | - evaluate~=0.4.0
23 | - pre-commit~=2.21.0
24 | - scikit-learn~=1.2.2
25 | - tqdm~=4.65.0
26 | - transformers~=4.28.1
27 | - ujson~=5.7.0
28 | - wandb==0.15.2
29 | - zstandard~=0.21.0
30 |
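# Editor's note: to create and activate this environment with a recent conda:
#
#   conda env create -f conda-mdel.yml
#   conda activate mdel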
--------------------------------------------------------------------------------
/configs/fp16_1-3B_4M_bs_1.4T_tok_summit_pp3_mp2_256_nodes.yml:
--------------------------------------------------------------------------------
1 | # GPT-2 pretraining setup
2 | {
3 | # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
4 | # across the node boundaries )
5 | "pipe-parallel-size": 3,
6 | "model-parallel-size": 2, # one copy of the model per node
7 |
8 | # model settings
9 | "num-layers": 24,
10 | "hidden-size": 2048,
11 | "num-attention-heads": 16,
12 | "seq-length": 2048,
13 | "max-position-embeddings": 2048,
14 | "norm": "layernorm",
15 | "pos-emb": "rotary",
16 | "no-weight-tying": true,
17 | "gpt_j_residual": false,
18 | "output_layer_parallelism": "column",
19 |
20 |   # these should provide some speedup but take a while to build; set to true if desired
21 | "scaled-upper-triang-masked-softmax-fusion": true,
22 | "bias-gelu-fusion": true,
23 |
24 | # init methods
25 | "init_method": "small_init",
26 | "output_layer_init_method": "wang_init",
27 |
28 | # optimizer settings
29 | "optimizer":
30 | {
31 | "type": "Adam",
32 | "params": { "lr": 0.0002, "betas": [0.9, 0.95], "eps": 1.0e-8 },
33 | },
34 | "min_lr": 0.00002,
35 | # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
36 | "zero_optimization":
37 | {
38 | "stage": 1,
39 | "allgather_partitions": True,
40 | "allgather_bucket_size": 500000000,
41 | "overlap_comm": True,
42 | "reduce_scatter": True,
43 | "reduce_bucket_size": 500000000,
44 | "contiguous_gradients": True,
45 | },
46 |
47 | # batch / data settings
48 | "train_batch_size": 2048, # across 1024 nodes... fingers crossed
49 | # "train_micro_batch_size_per_gpu": 4,
50 | "gradient_accumulation_steps": 4,
51 | "data-impl": "mmap",
52 | "split": "949,50,1",
53 |
54 | # activation checkpointing
55 | "checkpoint-activations": true,
56 | "checkpoint-num-layers": 1,
57 | "partition-activations": true,
58 | "synchronize-each-layer": true,
59 |
60 | # regularization
61 | "gradient_clipping": 1.0,
62 | "weight-decay": 0.1,
63 | "hidden-dropout": 0.0,
64 | "attention-dropout": 0.0,
65 |
66 | # precision settings
67 | "fp16": {
68 | "enabled": true,
69 | # "type": "bfloat16", # set bf16 as precision
70 | "loss_scale": 0,
71 | "loss_scale_window": 1000,
72 | "hysteresis": 2,
73 | "min_loss_scale": 1,
74 | },
75 |
76 | # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32
77 | # misc. training settings
78 | "train-iters": 250000,
79 | "lr-decay-iters": 250000,
80 | "distributed-backend": "nccl",
81 | "lr-decay-style": "cosine",
82 | "warmup": 0.01,
83 | "checkpoint-factor": 1000,
84 | "eval-interval": 1000,
85 | "eval-iters": 10,
86 |
87 | # logging
88 | "log-interval": 1,
89 | "steps_per_print": 1,
90 | "keep-last-n-checkpoints": 1000,
91 | "wall_clock_breakdown": true,
92 | }
93 |
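# Editor's note: these keys follow the EleutherAI GPT-NeoX config schema. Assuming
# that launcher is the intended consumer, a run would look roughly like
#
#   python ./deepy.py train.py configs/fp16_1-3B_4M_bs_1.4T_tok_summit_pp3_mp2_256_nodes.yml
#
# with the exact entry point depending on the GPT-NeoX checkout used.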
--------------------------------------------------------------------------------
/configs/fp16_2-7B_4M_bs_1.4T_tok_summit_pp6_mp2_256nodes.yml:
--------------------------------------------------------------------------------
1 | # GPT-2 pretraining setup
2 | {
3 | # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
4 | # across the node boundaries )
5 | "pipe-parallel-size": 6,
6 | "model-parallel-size": 2, # one copy of the model per node
7 |
8 | # model settings
9 | "num-layers": 32,
10 | "hidden-size": 2560,
11 | "num-attention-heads": 32,
12 | "seq-length": 2048,
13 | "max-position-embeddings": 2048,
14 | "norm": "layernorm",
15 | "pos-emb": "rotary",
16 | "no-weight-tying": true,
17 | "gpt_j_residual": false,
18 | "output_layer_parallelism": "column",
19 |
20 |   # these should provide some speedup but take a while to build; set to true if desired
21 | "scaled-upper-triang-masked-softmax-fusion": true,
22 | "bias-gelu-fusion": true,
23 |
24 | # init methods
25 | "init_method": "small_init",
26 | "output_layer_init_method": "wang_init",
27 |
28 | # optimizer settings
29 | "optimizer":
30 | {
31 | "type": "Adam",
32 | "params": { "lr": 0.00016, "betas": [0.9, 0.95], "eps": 1.0e-8 },
33 | },
34 | "min_lr": 0.000016,
35 | # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
36 | "zero_optimization":
37 | {
38 | "stage": 1,
39 | "allgather_partitions": True,
40 | "allgather_bucket_size": 500000000,
41 | "overlap_comm": True,
42 | "reduce_scatter": True,
43 | "reduce_bucket_size": 500000000,
44 | "contiguous_gradients": True,
45 | },
46 |
47 | # batch / data settings
48 | "train_batch_size": 2048, # across 1024 nodes... fingers crossed
49 | # "train_micro_batch_size_per_gpu": 4,
50 | "gradient_accumulation_steps": 8,
51 | "data-impl": "mmap",
52 | "split": "949,50,1",
53 |
54 | # activation checkpointing
55 | "checkpoint-activations": true,
56 | "checkpoint-num-layers": 1,
57 | "partition-activations": true,
58 | "synchronize-each-layer": true,
59 |
60 | # regularization
61 | "gradient_clipping": 1.0,
62 | "weight-decay": 0.1,
63 | "hidden-dropout": 0.0,
64 | "attention-dropout": 0.0,
65 |
66 | # precision settings
67 | "fp16": {
68 | "enabled": true,
69 | # "type": "bfloat16", # set bf16 as precision
70 | "loss_scale": 0,
71 | "loss_scale_window": 1000,
72 | "hysteresis": 2,
73 | "min_loss_scale": 1,
74 | },
75 |
76 | # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32
77 | # misc. training settings
78 | "train-iters": 250000,
79 | "lr-decay-iters": 250000,
80 | "distributed-backend": "nccl",
81 | "lr-decay-style": "cosine",
82 | "warmup": 0.01,
83 | "checkpoint-factor": 1000,
84 | "eval-interval": 1000,
85 | "eval-iters": 10,
86 |
87 | # logging
88 | "log-interval": 1,
89 | "steps_per_print": 1,
90 | "keep-last-n-checkpoints": 1000,
91 | "wall_clock_breakdown": true,
92 | }
93 |
--------------------------------------------------------------------------------
/configs/fp16_6-7B_4M_bs_1T_tok_summit_pp12_mp2_mbs2_512_nodes_real.yml:
--------------------------------------------------------------------------------
1 | # GPT-2 pretraining setup
2 | {
3 | # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
4 | # across the node boundaries )
5 | "pipe-parallel-size": 12,
6 | "model-parallel-size": 2, # one copy of the model per node
7 |
8 | # model settings
9 | "num-layers": 32,
10 | "hidden-size": 4096,
11 | "num-attention-heads": 32,
12 | "seq-length": 2048,
13 | "max-position-embeddings": 2048,
14 | "norm": "layernorm",
15 | "pos-emb": "rotary",
16 | "no-weight-tying": true,
17 | "gpt_j_residual": false,
18 | "output_layer_parallelism": "column",
19 |
20 |   # these should provide some speedup but take a while to build; set to true if desired
21 | "scaled-upper-triang-masked-softmax-fusion": true,
22 | "bias-gelu-fusion": true,
23 |
24 | # init methods
25 | "init_method": "small_init",
26 | "output_layer_init_method": "wang_init",
27 |
28 | # optimizer settings
29 | "optimizer":
30 | {
31 | "type": "Adam",
32 | "params": { "lr": 0.00012, "betas": [0.9, 0.95], "eps": 1.0e-8 },
33 | },
34 | "min_lr": 0.000012,
35 | # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
36 | "zero_optimization":
37 | {
38 | "stage": 1,
39 | "allgather_partitions": True,
40 | "allgather_bucket_size": 500000000,
41 | "overlap_comm": True,
42 | "reduce_scatter": True,
43 | "reduce_bucket_size": 500000000,
44 | "contiguous_gradients": True,
45 | },
46 |
47 | # batch / data settings
48 | "train_batch_size": 2048, # across 1024 nodes... fingers crossed
49 | # "train_micro_batch_size_per_gpu": 4,
50 | # "gradient_accumulation_steps": 2,
51 | "gradient_accumulation_steps": 8,
52 | "data-impl": "mmap",
53 | "split": "949,50,1",
54 |
55 | # activation checkpointing
56 | "checkpoint-activations": true,
57 | "checkpoint-num-layers": 1,
58 | "partition-activations": true,
59 | "synchronize-each-layer": true,
60 |
61 | # regularization
62 | "gradient_clipping": 1.0,
63 | "weight-decay": 0.1,
64 | "hidden-dropout": 0.0,
65 | "attention-dropout": 0.0,
66 |
67 | # precision settings
68 | "fp16": {
69 | "enabled": true,
70 | # "type": "bfloat16", # set bf16 as precision
71 | "loss_scale": 0,
72 | "loss_scale_window": 1000,
73 | "hysteresis": 2,
74 | "min_loss_scale": 1,
75 | },
76 |
77 | # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32
78 | # misc. training settings
79 | "train-iters": 250000,
80 | "lr-decay-iters": 250000,
81 | "distributed-backend": "nccl",
82 | "lr-decay-style": "cosine",
83 | "warmup": 0.01,
84 | "checkpoint-factor": 1000,
85 | "eval-interval": 1000,
86 | "eval-iters": 10,
87 |
88 | # logging
89 | "log-interval": 1,
90 | "steps_per_print": 1,
91 | "keep-last-n-checkpoints": 1000,
92 | "wall_clock_breakdown": true,
93 | }
94 |
--------------------------------------------------------------------------------
/distillation_sparsification/README.md:
--------------------------------------------------------------------------------
1 | ## Tools for sparsifying and distilling models
2 |
3 |
--------------------------------------------------------------------------------
/distillation_sparsification/datautils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def set_seed(seed):
6 | np.random.seed(seed)
7 | torch.random.manual_seed(seed)
8 |
9 |
10 | def get_wikitext2(nsamples, seed, seqlen, model):
11 | from datasets import load_dataset
12 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
13 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
14 |
15 | from transformers import AutoTokenizer, LlamaTokenizer
16 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
17 | trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt')
18 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
19 |
20 | import random
21 | random.seed(seed)
22 | trainloader = []
23 | for _ in range(nsamples):
24 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
25 | j = i + seqlen
26 | inp = trainenc.input_ids[:, i:j]
27 | tar = inp.clone()
28 | tar[:, :-1] = -100
29 | trainloader.append((inp, tar))
30 | return trainloader, testenc
31 |
32 | def get_ptb(nsamples, seed, seqlen, model):
33 | from datasets import load_dataset
34 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
35 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
36 |
37 | from transformers import AutoTokenizer, LlamaTokenizer
38 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
39 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
40 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
41 |
42 | import random
43 | random.seed(seed)
44 | trainloader = []
45 | for _ in range(nsamples):
46 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
47 | j = i + seqlen
48 | inp = trainenc.input_ids[:, i:j]
49 | tar = inp.clone()
50 | tar[:, :-1] = -100
51 | trainloader.append((inp, tar))
52 | return trainloader, testenc
53 |
54 | def get_c4(nsamples, seed, seqlen, model):
55 | from datasets import load_dataset
56 | traindata = load_dataset(
57 | 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train'
58 | )
59 | valdata = load_dataset(
60 | 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation'
61 | )
62 |
63 | from transformers import AutoTokenizer
64 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
65 |
66 | import random
67 | random.seed(seed)
68 | trainloader = []
69 | for _ in range(nsamples):
70 | while True:
71 | i = random.randint(0, len(traindata) - 1)
72 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
73 | if trainenc.input_ids.shape[1] >= seqlen:
74 | break
75 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
76 | j = i + seqlen
77 | inp = trainenc.input_ids[:, i:j]
78 | tar = inp.clone()
79 | tar[:, :-1] = -100
80 | trainloader.append((inp, tar))
81 |
82 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
83 | valenc = valenc.input_ids[:, :(256 * seqlen)]
84 |
85 | class TokenizerWrapper:
86 | def __init__(self, input_ids):
87 | self.input_ids = input_ids
88 | valenc = TokenizerWrapper(valenc)
89 |
90 | return trainloader, valenc
91 |
92 |
93 | def get_loaders(
94 | name, nsamples=128, seed=0, seqlen=2048, model=''
95 | ):
96 | if 'wikitext2' in name:
97 | return get_wikitext2(nsamples, seed, seqlen, model)
98 | if 'ptb' in name:
99 | return get_ptb(nsamples, seed, seqlen, model)
100 | if 'c4' in name:
101 | return get_c4(nsamples, seed, seqlen, model)
102 |
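
# Editor's note: illustrative usage sketch (not part of the original module); the
# model name is a placeholder for any HF tokenizer path.
#
# trainloader, testenc = get_loaders('wikitext2', nsamples=8, seed=0, seqlen=2048,
#                                    model='tiiuae/falcon-7b')
# inp, tar = trainloader[0]  # inp: (1, seqlen) token ids; tar masks every position
#                            # except the last one with -100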
--------------------------------------------------------------------------------
/distillation_sparsification/distill.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import logging
4 | import os
5 | import sys
6 | import time
7 | from collections import defaultdict
8 | from pathlib import Path
9 | from typing import Dict, List, Tuple
10 |
11 | import numpy as np
12 | import torch
13 | from torch import nn
14 | from torch.utils.data import DataLoader
15 | from transformers import Trainer
16 | from make_student import *
17 | from utils import label_smoothed_nll_loss, freeze_params
18 |
19 |
20 | def get_falcon(model):
21 | def skip(*args, **kwargs):
22 | pass
23 |
24 | torch.nn.init.kaiming_uniform_ = skip
25 | torch.nn.init.uniform_ = skip
26 | torch.nn.init.normal_ = skip
27 | tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
28 | model = AutoModelForCausalLM.from_pretrained(model, trust_remote_code=True, torch_dtype=torch.bfloat16)
29 | model.seqlen = 2048
30 | return model, tokenizer
31 |
32 |
33 |
34 | class DistillationTrainer(Trainer):
35 | # length_field_name should possibly be part of TrainingArguments instead
36 | def __init__(self, length_field_name=None, swap_lang_prob=3, alpha_ce=0.5, alpha_hid=0.5, alpha_clm=0.5, temperature=2.0, all_teachers={}, normalize_hidden=False, *args, **kwargs):
37 | super().__init__(*args, **kwargs)
38 | self.length_field_name = length_field_name
39 |
40 | self.t_model, self.tokenizer = get_falcon("tiiuae/falcon-7b")
41 |
42 | # self.s_model, layer_ids = create_student_by_copying_alternating_layers(t_model, d=16, save_path='student/')
43 |
44 | freeze_params(self.t_model)
45 |
46 | self.d_matches = get_layers_to_supervise(
47 | n_student=16, n_teacher=32
48 | )
49 |
50 | self.restrict_ce_to_mask = False
51 | self.alpha_ce = alpha_ce
52 | self.alpha_hid = alpha_hid
53 | self.alpha_clm = alpha_clm
54 | self.temperature = temperature
55 | self.normalize_hidden = normalize_hidden
56 |
57 | self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
58 | self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
59 |
60 |
61 | def calc_hidden_loss(self, attention_mask, lm_labels,hidden_states, hidden_states_T, matches, normalize_hidden):
62 | """MSE(student_hid, teacher_hid[matches]). Called "Intermediate supervision" in paper. Inspired by TinyBERT."""
63 | msg = "expected list or tuple for hidden_states, got tensor of shape: "
64 | assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
65 | assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
66 | if self.restrict_ce_to_mask:
67 | mask = lm_labels > -1 # (bs, seq_length, voc_size)
68 | else:
69 | mask = attention_mask # (bs, seq_length, voc_size)
70 |
71 | mask = mask.to(hidden_states[0])
72 | valid_count = mask.sum() * hidden_states[0].size(-1)
73 | s_states = torch.stack([hidden_states[i] for i in range(len(matches))])
74 | t_states = torch.stack([hidden_states_T[j] for j in matches])
75 | assert s_states.shape == t_states.shape, f"{s_states.shape} != {t_states.shape}"
76 | if normalize_hidden:
77 | s_states = nn.functional.layer_norm(s_states, s_states.shape[1:])
78 | t_states = nn.functional.layer_norm(t_states, t_states.shape[1:])
79 | mse = nn.functional.mse_loss(s_states, t_states, reduction="none")
80 | masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
81 | return masked_mse
82 |
83 | def calc_ce_loss(self, attention_mask, lm_labels, s_logits, t_logits):
84 |         """Copied from the DistilBERT distillation example (transformers/examples/distillation/)."""
85 | # mask has False at padding_idx
86 | if self.restrict_ce_to_mask:
87 | mask = (lm_labels > -1) # (bs, seq_length, voc_size)
88 | else:
89 | mask = attention_mask # (bs, seq_length, voc_size)
90 |
91 | sel_mask = mask.unsqueeze(-1).expand_as(s_logits)
92 | vocab_size = s_logits.size(-1)
93 | s_logits_slct = torch.masked_select(s_logits, sel_mask) # (bs * seq_length * voc_size) modulo the 1s in mask
94 | t_logits_slct = torch.masked_select(t_logits, sel_mask) # (bs * seq_length * voc_size) modulo the 1s in mask
95 | s_logits_slct = s_logits_slct.view(-1, vocab_size) # (bs * seq_length, voc_size) modulo the 1s in mask
96 | t_logits_slct = t_logits_slct.view(-1, vocab_size) # (bs * seq_length, voc_size) modulo the 1s in mask
97 | assert t_logits_slct.size() == s_logits_slct.size()
98 | loss_ce = (
99 | self.ce_loss_fct(
100 | nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
101 | nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
102 | )
103 | * (self.temperature) ** 2
104 | )
105 | return loss_ce
106 |
107 | def compute_loss(self, model, inputs, return_outputs=False):
108 | pad_token_id = self.tokenizer.pad_token_id
109 | input_ids, attn_mask, labels = inputs["input_ids"], inputs["attention_mask"], inputs["labels"]
110 | lm_labels = input_ids.new(input_ids.size()).copy_(input_ids)
111 | lm_labels[~attn_mask] = -100 # previously `lm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
112 | # sanity checks
113 | assert 0 <= input_ids.min() <= input_ids.max() < self.tokenizer.vocab_size
114 | # noinspection PyCallingNonCallable
115 | s_outputs = model(
116 | input_ids,
117 | attention_mask=None,
118 | )
119 |
120 | t_outputs = self.t_model(
121 | input_ids,
122 | attention_mask=None,
123 | )
124 |
125 | s_logits, s_hidden_states = s_outputs["logits"], s_outputs["hidden_states"]
126 | t_logits, t_hidden_states = t_outputs["logits"], t_outputs["hidden_states"]
127 |
128 |
129 | if self.alpha_ce > 0.0:
130 | loss_ce = self.calc_ce_loss(attn_mask, lm_labels, s_logits, t_logits)
131 |
132 | if self.alpha_hid > 0.0:
133 | hid_loss = self.calc_hidden_loss(
134 | attn_mask,
135 | lm_labels,
136 | s_hidden_states,
137 | t_hidden_states,
138 | self.d_matches,
139 | normalize_hidden=self.normalize_hidden,
140 | )
141 |
142 | if self.alpha_clm > 0.0:
143 | shift_logits = s_logits[..., :-1, :].contiguous()
144 | shift_labels = lm_labels[..., 1:].contiguous()
145 | loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
146 |
147 | loss = self.alpha_ce * loss_ce + self.alpha_clm * loss_clm + self.alpha_hid * hid_loss
148 |
149 | return (loss, loss_ce, loss_clm, hid_loss, s_outputs) if return_outputs else (loss, loss_ce, loss_clm, hid_loss)
150 |
151 |
152 |
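# Editor's note: a hedged sketch of wiring DistillationTrainer into a run (not part
# of the original file). The student path, dataset variable and TrainingArguments
# values are placeholders; note that compute_loss() above returns a tuple of losses,
# so the stock Trainer training loop may need a thin wrapper that reduces it to a
# scalar.
#
# from transformers import TrainingArguments, DataCollatorForLanguageModeling
#
# tokenizer = AutoTokenizer.from_pretrained('student/', trust_remote_code=True)
# student = AutoModelForCausalLM.from_pretrained('student/', trust_remote_code=True)
# trainer = DistillationTrainer(
#     model=student,
#     args=TrainingArguments(output_dir='distill_out/', per_device_train_batch_size=1),
#     train_dataset=train_dataset,
#     data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
#     alpha_ce=0.5, alpha_hid=0.5, alpha_clm=0.5,
# )
# trainer.train()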
--------------------------------------------------------------------------------
/distillation_sparsification/lion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer
3 |
4 |
5 | class Lion(Optimizer):
6 | r"""Implements Lion algorithm."""
7 |
8 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
9 | """Initialize the hyperparameters.
10 |
11 | Args:
12 | params (iterable): iterable of parameters to optimize or dicts defining
13 | parameter groups
14 | lr (float, optional): learning rate (default: 1e-4)
15 |             betas (Tuple[float, float], optional): coefficients for interpolating the
16 |                 update direction and for the gradient running average (default: (0.9, 0.99))
17 | weight_decay (float, optional): weight decay coefficient (default: 0)
18 | """
19 |
20 | if not 0.0 <= lr:
21 | raise ValueError('Invalid learning rate: {}'.format(lr))
22 | if not 0.0 <= betas[0] < 1.0:
23 | raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
24 | if not 0.0 <= betas[1] < 1.0:
25 | raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
26 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
27 | super().__init__(params, defaults)
28 |
29 | @torch.no_grad()
30 | def step(self, closure=None):
31 | """Performs a single optimization step.
32 |
33 | Args:
34 | closure (callable, optional): A closure that reevaluates the model
35 | and returns the loss.
36 |
37 | Returns:
38 | the loss.
39 | """
40 | loss = None
41 | if closure is not None:
42 | with torch.enable_grad():
43 | loss = closure()
44 |
45 | for group in self.param_groups:
46 | for p in group['params']:
47 | if p.grad is None:
48 | continue
49 |
50 | # Perform stepweight decay
51 | p.data.mul_(1 - group['lr'] * group['weight_decay'])
52 |
53 | grad = p.grad
54 | state = self.state[p]
55 | # State initialization
56 | if len(state) == 0:
57 | # Exponential moving average of gradient values
58 | state['exp_avg'] = torch.zeros_like(p)
59 |
60 | exp_avg = state['exp_avg']
61 | beta1, beta2 = group['betas']
62 |
63 | # Weight update
64 | update = exp_avg * beta1 + grad * (1 - beta1)
65 | p.add_(torch.sign(update), alpha=-group['lr'])
66 | # Decay the momentum running average coefficient
67 | exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
68 |
69 | return loss
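
# Editor's note: minimal usage sketch (not part of the original file). Lion is
# usually run with a smaller learning rate and larger weight decay than AdamW;
# the values below are illustrative only.
#
# model = torch.nn.Linear(10, 2)
# optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
# loss = model(torch.randn(4, 10)).sum()
# loss.backward()
# optimizer.step()
# optimizer.zero_grad()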
--------------------------------------------------------------------------------
/distillation_sparsification/lm_seqs_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset
4 |
5 | from utils import logger
6 |
7 |
8 | class LmSeqsDataset(Dataset):
9 | """Custom Dataset wrapping language modeling sequences.
10 |
11 | Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths.
12 |
13 | Input:
14 | ------
15 | params: `NameSpace` parameters
16 |         data: `List[np.array[int]]`
17 | """
18 |
19 | def __init__(self, params, data):
20 | self.params = params
21 |
22 | self.token_ids = np.array(data)
23 | self.lengths = np.array([len(t) for t in data])
24 |
25 | self.check()
26 | self.remove_long_sequences()
27 | self.remove_empty_sequences()
28 | self.remove_unknown_sequences()
29 | self.check()
30 | self.print_statistics()
31 |
32 | def __getitem__(self, index):
33 | return (self.token_ids[index], self.lengths[index])
34 |
35 | def __len__(self):
36 | return len(self.lengths)
37 |
38 | def check(self):
39 | """
40 | Some sanity checks
41 | """
42 | assert len(self.token_ids) == len(self.lengths)
43 | assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths)))
44 |
45 | def remove_long_sequences(self):
46 | """
47 |         Sequences that are too long are split into chunks of max_model_input_size.
48 | """
49 | max_len = self.params.max_model_input_size
50 | indices = self.lengths > max_len
51 | logger.info(f"Splitting {sum(indices)} too long sequences.")
52 |
53 | def divide_chunks(l, n):
54 | return [l[i : i + n] for i in range(0, len(l), n)]
55 |
56 | new_tok_ids = []
57 | new_lengths = []
58 | if self.params.mlm:
59 | cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"]
60 | else:
61 | cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"]
62 |
63 | for seq_, len_ in zip(self.token_ids, self.lengths):
64 | assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
65 | if len_ <= max_len:
66 | new_tok_ids.append(seq_)
67 | new_lengths.append(len_)
68 | else:
69 | sub_seqs = []
70 | for sub_s in divide_chunks(seq_, max_len - 2):
71 | if sub_s[0] != cls_id:
72 | sub_s = np.insert(sub_s, 0, cls_id)
73 | if sub_s[-1] != sep_id:
74 | sub_s = np.insert(sub_s, len(sub_s), sep_id)
75 | assert len(sub_s) <= max_len
76 | assert (sub_s[0] == cls_id) and (sub_s[-1] == sep_id), sub_s
77 | sub_seqs.append(sub_s)
78 |
79 | new_tok_ids.extend(sub_seqs)
80 | new_lengths.extend([len(l) for l in sub_seqs])
81 |
82 | self.token_ids = np.array(new_tok_ids)
83 | self.lengths = np.array(new_lengths)
84 |
85 | def remove_empty_sequences(self):
86 | """
87 | Too short sequences are simply removed. This could be tuned.
88 | """
89 | init_size = len(self)
90 | indices = self.lengths > 11
91 | self.token_ids = self.token_ids[indices]
92 | self.lengths = self.lengths[indices]
93 | new_size = len(self)
94 | logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.")
95 |
96 | def remove_unknown_sequences(self):
97 | """
98 | Remove sequences with a (too) high level of unknown tokens.
99 | """
100 | if "unk_token" not in self.params.special_tok_ids:
101 | return
102 | else:
103 | unk_token_id = self.params.special_tok_ids["unk_token"]
104 | init_size = len(self)
105 | unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids])
106 | indices = (unk_occs / self.lengths) < 0.5
107 | self.token_ids = self.token_ids[indices]
108 | self.lengths = self.lengths[indices]
109 | new_size = len(self)
110 | logger.info(f"Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).")
111 |
112 | def print_statistics(self):
113 | """
114 | Print some statistics on the corpus. Only the master process.
115 | """
116 | if not self.params.is_master:
117 | return
118 | logger.info(f"{len(self)} sequences")
119 | # data_len = sum(self.lengths)
120 | # nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
121 | # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
122 |
123 | # unk_idx = self.params.special_tok_ids['unk_token']
124 | # nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids])
125 | # logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)')
126 |
127 | def batch_sequences(self, batch):
128 | """
129 | Do the padding and transform into torch.tensor.
130 | """
131 | token_ids = [t[0] for t in batch]
132 | lengths = [t[1] for t in batch]
133 | assert len(token_ids) == len(lengths)
134 |
135 | # Max for paddings
136 | max_seq_len_ = max(lengths)
137 |
138 | # Pad token ids
139 | if self.params.mlm:
140 | pad_idx = self.params.special_tok_ids["pad_token"]
141 | else:
142 | pad_idx = self.params.special_tok_ids["unk_token"]
143 | tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
144 | assert len(tk_) == len(token_ids)
145 | assert all(len(t) == max_seq_len_ for t in tk_)
146 |
147 | tk_t = torch.tensor(tk_) # (bs, max_seq_len_)
148 | lg_t = torch.tensor(lengths) # (bs)
149 | return tk_t, lg_t
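
# Editor's note: sketch of plugging this dataset into a DataLoader (not part of the
# original file). `params` is whatever namespace the surrounding training script
# builds; the token ids and special-token mapping below are placeholders. Keep in
# mind that sequences of 11 tokens or fewer are dropped by remove_empty_sequences().
#
# from types import SimpleNamespace
# from torch.utils.data import DataLoader
#
# params = SimpleNamespace(
#     max_model_input_size=512, mlm=False, is_master=True,
#     special_tok_ids={'bos_token': 0, 'eos_token': 1, 'unk_token': 2, 'pad_token': 3},
# )
# data = [np.array([0] + list(range(10, 30)) + [1]) for _ in range(100)]
# dataset = LmSeqsDataset(params, data)
# loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.batch_sequences)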
--------------------------------------------------------------------------------
/distillation_sparsification/make_student.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from pathlib import Path
3 | from typing import List, Tuple, Union
4 |
5 | import fire
6 | from torch import nn
7 |
8 | from transformers import AutoTokenizer, PreTrainedModel, AutoModelForCausalLM
9 | from transformers.utils import logging
10 |
11 | logger = logging.get_logger(__name__)
12 |
13 |
14 | def copy_layers(src_layers: nn.ModuleList, dest_layers: nn.ModuleList, layers_to_copy: List[int]) -> None:
15 | layers_to_copy = nn.ModuleList([src_layers[i] for i in layers_to_copy])
16 | assert len(dest_layers) == len(layers_to_copy), f"{len(dest_layers)} != {len(layers_to_copy)}"
17 | dest_layers.load_state_dict(layers_to_copy.state_dict())
18 |
19 |
20 | LAYERS_TO_COPY = {
21 | # maps num layers in teacher -> num_layers in student -> which teacher layers to copy.
22 | 32: {
23 | 1: [0],
24 | 2: [0, 31],
25 | 3: [0, 16, 31],
26 | # 4: [0, 10, 20, 31],
27 | # 6: [0, 3, 6, 9, 12, 31],
28 | 8: [0, 4, 8, 12, 16, 20, 24, 31],
29 | # 9: [0, 1, 3, 5, 7, 9, 11, 13, 31],
30 | 12: [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 31],
31 | 16: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 31],
32 | 32: list(range(32))
33 | },
34 | }
35 |
36 | LAYERS_TO_SUPERVISE = {
37 | # maps num layers in student -> which teacher layers to copy.
38 | 6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]},
39 | 12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]},
40 | 16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]},
41 | 32: {8: [3, 7, 11, 15, 19, 23, 27, 31], 16: [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31]},
42 | }
43 |
44 | def get_layers_to_supervise(n_student, n_teacher) -> List[int]:
45 |     """Used for the --supervise_forward kwarg"""
46 | if n_student > n_teacher:
47 | raise ValueError(f"Cannot perform intermediate supervision for student {n_student} > teacher {n_teacher}")
48 | elif n_teacher == n_student:
49 | return list(range(n_teacher))
50 | elif n_student == 1:
51 | return [n_teacher - 1]
52 | else:
53 | return LAYERS_TO_SUPERVISE[n_teacher][n_student]
54 |
55 |
56 | def pick_layers_to_copy(n_student, n_teacher):
57 | try:
58 | val = LAYERS_TO_COPY[n_teacher][n_student]
59 | return val
60 | except KeyError:
61 | if n_student != n_teacher:
62 | warnings.warn(
63 | f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first"
64 | f" {n_student}"
65 | )
66 | return list(range(n_student))
67 |
68 |
69 | def create_student_by_copying_alternating_layers(
70 | teacher: Union[str, PreTrainedModel],
71 | save_path: Union[str, Path] = "student",
72 | d: Union[int, None] = None,
73 | copy_first_teacher_layers=False,
74 | d_layers_to_copy=None,
75 | **extra_config_kwargs,
76 | ) -> Tuple[PreTrainedModel, List[int]]:
77 | """Make a student by copying alternating layers from a teacher, save it to save_path.
78 | Args:
79 | teacher: str or PreTrainedModel if str, this will call AutoModelForCausalLM.from_pretrained(teacher) before
80 | copying layers
81 | save_path: where to save the student, defaults to student directory.
82 |         d: how many decoder layers the student should have; must be provided (see the assert below)
83 |         copy_first_teacher_layers: [bool] don't copy alternating layers, just the first d layers.
84 | **extra_config_kwargs: extra kwargs to pass to the student, by default the teacher config is used.
85 |
86 | Returns:
87 | student: new, smaller model. (Also saves it to save_path)
88 | d_layers_to_copy: list of which teacher decoder layers were used
89 | """
90 |     _msg = "d (number of student decoder layers) cannot be None -- you would just have an identical teacher."
91 | assert (d is not None), _msg
92 | if isinstance(teacher, str):
93 | AutoTokenizer.from_pretrained(teacher).save_pretrained(save_path) # purely for convenience
94 | teacher = AutoModelForCausalLM.from_pretrained(teacher).eval()
95 | else:
96 | assert isinstance(teacher, PreTrainedModel), f"teacher must be a model or string got type {type(teacher)}"
97 | init_kwargs = teacher.config.to_diff_dict()
98 |
99 | try:
100 | teacher_d = teacher.config.decoder_layers
101 | if d is None:
102 | d = teacher_d
103 | init_kwargs.update({"n_layer": d})
104 | except AttributeError: # T5
105 | if hasattr(teacher.config, "n_layer"):
106 | teacher_d = teacher.config.n_layer
107 | else:
108 | teacher_d = teacher.config.n_layer
109 | if d is None:
110 | d = teacher_d
111 | if hasattr(teacher.config, "n_layer"):
112 | init_kwargs.update({"n_layer": d})
113 | else:
114 | init_kwargs.update({"n_layer": d})
115 |
116 | # Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs
117 | init_kwargs.update(extra_config_kwargs)
118 |
119 | # Copy weights
120 | student_cfg = teacher.config_class(**init_kwargs)
121 | student = AutoModelForCausalLM.from_config(student_cfg, trust_remote_code=True)
122 | # Start by copying the full teacher state dict this will copy the first N teacher layers to the student.
123 | info = student.load_state_dict(teacher.state_dict(), strict=False)
124 | assert info.missing_keys == [], info.missing_keys # every student key should have a teacher keys.
125 |
126 | if copy_first_teacher_layers: # Our copying is done. We just log and save
127 | d_layers_to_copy = list(range(d))
128 | logger.info(
129 | f"Copied decoder layers {d_layers_to_copy}. Saving them to"
130 | f" {save_path}"
131 | )
132 | student.save_pretrained(save_path)
133 | return student, d_layers_to_copy
134 |
135 | # Decide which layers of the teacher to copy. Not exactly alternating -- we try to keep first and last layer.
136 | if d_layers_to_copy is None:
137 | d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d)
138 |
139 | try:
140 | if hasattr(
141 | teacher, "prophetnet"
142 | ):
143 | copy_layers(teacher.prophetnet.decoder.layers, student.prophetnet.decoder.layers, d_layers_to_copy)
144 | else:
145 | copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
146 | except AttributeError:
147 | copy_layers(teacher.transformer.h, student.transformer.h, d_layers_to_copy)
148 | logger.info(
149 | f"Copied decoder layers {d_layers_to_copy}. Saving them to {save_path}"
150 | )
151 | student.config.init_metadata = {
152 | "teacher_type": teacher.config.model_type,
153 | "copied_decoder_layers": d_layers_to_copy,
154 | }
155 | student.save_pretrained(save_path)
156 | # Save information about copying for easier reproducibility
157 |
158 | return student, d_layers_to_copy
159 |
160 |
161 | if __name__ == "__main__":
162 | fire.Fire(create_student_by_copying_alternating_layers)
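
# Editor's note: usage sketch (not part of the original file). With fire, the
# function's keyword arguments become CLI flags; the teacher checkpoint matches the
# one used elsewhere in this repo and d=16 corresponds to an entry in
# LAYERS_TO_COPY[32]. Remote-code models may additionally need trust_remote_code.
#
#   python make_student.py --teacher tiiuae/falcon-7b --d 16 --save_path student/
#
# or from Python:
#
#   student, copied_layers = create_student_by_copying_alternating_layers(
#       'tiiuae/falcon-7b', save_path='student/', d=16)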
--------------------------------------------------------------------------------
/distillation_sparsification/modelutils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | DEV = torch.device('cuda:0')
6 |
7 |
8 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
9 | if type(module) in layers or 'Linear' in str(type(module)):
10 | return {name: module}
11 | res = {}
12 |
13 | for name1, child in module.named_children():
14 | res.update(find_layers(
15 | child, layers=layers, name=name + '.' + name1 if name != '' else name1
16 | ))
17 | return res
18 |
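
# Editor's note: small usage sketch (not part of the original file); find_layers
# collects every Linear/Conv2d submodule keyed by its dotted module name, which is
# how the pruning code locates per-layer weights.
#
# block = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
# print(find_layers(block))  # {'0': Linear(...), '2': Linear(...)}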
--------------------------------------------------------------------------------
/distillation_sparsification/process_data.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 | import transformers
7 |
8 | from sparsegpt import *
9 | from modelutils import *
10 | import sys
11 | from transformers import AutoTokenizer, AutoModelForCausalLM
12 | from falcon7b import modelling_RW
13 | from datasets import load_dataset
14 | from make_student import *
15 | import multiprocessing
16 | from datasets import load_from_disk
17 | from itertools import chain
18 |
19 |
20 |
21 | tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True)
22 | if tokenizer.pad_token is None:
23 | tokenizer.add_special_tokens({'pad_token': '[PAD]'})
24 |
25 | ds = load_dataset('JeanKaddour/minipile')
26 | # print(ds.column_names)
27 |
28 | def preprocess_function(examples):
29 |     return tokenizer(examples["text"])
30 |
31 | block_size = 2048
32 |
33 | def group_texts(examples):
34 | # Concatenate all texts.
35 | concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
36 | total_length = len(concatenated_examples[list(examples.keys())[0]])
37 |     # We drop the small remainder; we could pad instead of dropping if the model supported it.
38 |     # You can customize this part to your needs.
39 | if total_length >= block_size:
40 | total_length = (total_length // block_size) * block_size
41 | else:
42 | total_length = 0
43 | # Split by chunks of block_size.
44 | result = {
45 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
46 | for k, t in concatenated_examples.items()
47 | }
48 | result["labels"] = result["input_ids"].copy()
49 | del result['token_type_ids']
50 | return result
51 |
52 |
53 | tokenized_minipile = ds.map(
54 | preprocess_function,
55 | batched=True,
56 | num_proc=multiprocessing.cpu_count(),
57 | remove_columns=ds['validation'].column_names,
58 | )
59 |
60 | tokenized_minipile.save_to_disk("ds/raw")
61 |
62 | # tokenized_minipile = load_from_disk("ds/raw")
63 |
64 | ds = tokenized_minipile.map(group_texts, batched=True, num_proc=multiprocessing.cpu_count())
65 |
66 |
67 | ds.save_to_disk("ds/processed")
68 |
69 | # ds = load_from_disk("ds/processed")
70 | # print(ds)
71 | # train_dataset = ds["train"]
72 | # eval_dataset = ds["validation"]
73 | # test_dataset = ds["test"]
74 | # print(test_dataset)
75 | # eval_dataset = eval_dataset.select(
76 | # (
77 | # i for i in range(len(eval_dataset))
78 | # if len(eval_dataset[i]['input_ids']) == 2048
79 | # )
80 | # )
81 |
82 | # test_dataset = test_dataset.select(
83 | # (
84 | # i for i in range(len(test_dataset))
85 | # if len(test_dataset[i]['input_ids']) == 2048
86 | # )
87 | # )
88 |
89 | # train_dataset = train_dataset.select(
90 | # (
91 | # i for i in range(len(train_dataset))
92 | # if len(train_dataset[i]['input_ids']) == 2048
93 | # )
94 | # )
95 |
96 | # ds["train"] = train_dataset
97 | # ds["validation"] = eval_dataset
98 | # ds["test"] = test_dataset
99 |
100 | print(ds)
101 | # print(eval_dataset)
102 | # print(test_dataset)
103 | # print(train_dataset)
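
# Editor's note: a toy illustration (not part of the original script) of what
# group_texts() does: concatenate the tokenised examples, re-cut them into
# block_size chunks, drop the remainder, and copy input_ids into labels. The tiny
# block size here is hypothetical; the script above uses block_size = 2048.
#
# toy = {'input_ids': [[1, 2, 3], [4, 5, 6, 7]],
#        'attention_mask': [[1, 1, 1], [1, 1, 1, 1]],
#        'token_type_ids': [[0, 0, 0], [0, 0, 0, 0]]}
# # with block_size = 4, the result keeps one chunk [1, 2, 3, 4] per key,
# # drops the trailing [5, 6, 7], and adds labels = input_ids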
--------------------------------------------------------------------------------
/distillation_sparsification/quant.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | def quantize(x, scale, zero, maxq):
7 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
8 | return scale * (q - zero)
9 |
10 | class Quantizer(nn.Module):
11 |
12 | def __init__(self, shape=1):
13 | super(Quantizer, self).__init__()
14 | self.register_buffer('maxq', torch.tensor(0))
15 | self.register_buffer('scale', torch.zeros(shape))
16 | self.register_buffer('zero', torch.zeros(shape))
17 |
18 | def configure(
19 | self,
20 | bits, perchannel=False, sym=True,
21 | mse=False, norm=2.4, grid=100, maxshrink=.8,
22 | grouprows=1
23 | ):
24 | self.maxq = torch.tensor(2 ** bits - 1)
25 | self.perchannel = perchannel
26 | self.sym = sym
27 | self.mse = mse
28 | self.norm = norm
29 | self.grid = grid
30 | self.maxshrink = maxshrink
31 | self.grouprows = grouprows
32 |
33 | def find_params(self, x, weight=False):
34 | dev = x.device
35 | self.maxq = self.maxq.to(dev)
36 |
37 | shape = x.shape
38 | if self.perchannel:
39 | if weight:
40 | x = x.flatten(1)
41 | if self.grouprows > 1:
42 | x = x.reshape((x.shape[0] // self.grouprows, -1))
43 | else:
44 | if len(shape) == 4:
45 | x = x.permute([1, 0, 2, 3])
46 | x = x.flatten(1)
47 | if len(shape) == 3:
48 | x = x.reshape((-1, shape[-1])).t()
49 | if len(shape) == 2:
50 | x = x.t()
51 | else:
52 | x = x.flatten().unsqueeze(0)
53 |
54 | tmp = torch.zeros(x.shape[0], device=dev)
55 | xmin = torch.minimum(x.min(1)[0], tmp)
56 | xmax = torch.maximum(x.max(1)[0], tmp)
57 |
58 | if self.sym:
59 | xmax = torch.maximum(torch.abs(xmin), xmax)
60 | tmp = xmin < 0
61 | if torch.any(tmp):
62 | xmin[tmp] = -xmax[tmp]
63 | tmp = (xmin == 0) & (xmax == 0)
64 | xmin[tmp] = -1
65 | xmax[tmp] = +1
66 |
67 | self.scale = (xmax - xmin) / self.maxq
68 | if self.sym:
69 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
70 | else:
71 | self.zero = torch.round(-xmin / self.scale)
72 |
73 | if self.mse:
74 | best = torch.full([x.shape[0]], float('inf'), device=dev)
75 | for i in range(int(self.maxshrink * self.grid)):
76 | p = 1 - i / self.grid
77 | xmin1 = p * xmin
78 | xmax1 = p * xmax
79 | scale1 = (xmax1 - xmin1) / self.maxq
80 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
81 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
82 | q -= x
83 | q.abs_()
84 | q.pow_(self.norm)
85 | err = torch.sum(q, 1)
86 | tmp = err < best
87 | if torch.any(tmp):
88 | best[tmp] = err[tmp]
89 | self.scale[tmp] = scale1[tmp]
90 | self.zero[tmp] = zero1[tmp]
91 | if not self.perchannel:
92 | if weight:
93 | tmp = shape[0]
94 | else:
95 | tmp = shape[1] if len(shape) != 3 else shape[2]
96 | self.scale = self.scale.repeat(tmp)
97 | self.zero = self.zero.repeat(tmp)
98 |
99 | if weight:
100 | if self.grouprows > 1:
101 | self.scale = self.scale.unsqueeze(1).repeat(1, self.grouprows)
102 | self.zero = self.zero.unsqueeze(1).repeat(1, self.grouprows)
103 | shape = [-1] + [1] * (len(shape) - 1)
104 | self.scale = self.scale.reshape(shape)
105 | self.zero = self.zero.reshape(shape)
106 | return
107 | if len(shape) == 4:
108 | self.scale = self.scale.reshape((1, -1, 1, 1))
109 | self.zero = self.zero.reshape((1, -1, 1, 1))
110 | if len(shape) == 3:
111 | self.scale = self.scale.reshape((1, 1, -1))
112 | self.zero = self.zero.reshape((1, 1, -1))
113 | if len(shape) == 2:
114 | self.scale = self.scale.unsqueeze(0)
115 | self.zero = self.zero.unsqueeze(0)
116 |
117 | def quantize(self, x):
118 | if self.ready():
119 | return quantize(x, self.scale, self.zero, self.maxq)
120 | return x
121 |
122 | def enabled(self):
123 | return self.maxq > 0
124 |
125 | def ready(self):
126 | return torch.all(self.scale != 0)
127 |
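# Editor's note: sketch of the calling sequence this class is written for (see the
# optional quantizer hook in sparsegpt.py): configure() fixes the bit-width,
# find_params() fits scale/zero on a weight tensor, quantize() snaps values onto
# that grid. Shapes below are placeholders.
#
# q = Quantizer()
# q.configure(bits=4, perchannel=True, sym=False, mse=False)
# W = torch.randn(16, 64)
# q.find_params(W, weight=True)
# W_q = q.quantize(W)  # same shape, values rounded to the 4-bit grid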
--------------------------------------------------------------------------------
/distillation_sparsification/sparsegpt.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 | import transformers
7 |
8 | from quant import *
9 |
10 |
11 | DEBUG = False
12 |
13 | torch.backends.cuda.matmul.allow_tf32 = False
14 | torch.backends.cudnn.allow_tf32 = False
15 |
16 |
17 | class SparseGPT:
18 |
19 | def __init__(self, layer):
20 | self.layer = layer
21 | self.dev = self.layer.weight.device
22 | W = layer.weight.data.clone()
23 | if isinstance(self.layer, nn.Conv2d):
24 | W = W.flatten(1)
25 | if isinstance(self.layer, transformers.Conv1D):
26 | W = W.t()
27 | self.rows = W.shape[0]
28 | self.columns = W.shape[1]
29 | self.H = torch.zeros((self.columns, self.columns), device=self.dev)
30 | self.nsamples = 0
31 |
32 | def add_batch(self, inp, out, blocksize=1024):
33 | if DEBUG:
34 | self.inp1 = inp
35 | self.out1 = out
36 | if len(inp.shape) == 2:
37 | inp = inp.unsqueeze(0)
38 | tmp = inp.shape[0]
39 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
40 | if len(inp.shape) == 3:
41 | inp = inp.reshape((-1, inp.shape[-1]))
42 | inp = inp.t()
43 | self.H *= self.nsamples / (self.nsamples + tmp)
44 | self.nsamples += tmp
45 | inp = math.sqrt(2 / self.nsamples) * inp.float()
46 | self.H += inp.matmul(inp.t())
47 |
48 | def fasterprune(
49 | self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01
50 | ):
51 | W = self.layer.weight.data.clone()
52 | if isinstance(self.layer, nn.Conv2d):
53 | W = W.flatten(1)
54 | if isinstance(self.layer, transformers.Conv1D):
55 | W = W.t()
56 | W = W.float()
57 |
58 | if hasattr(self, 'quantizer'):
59 | if not self.quantizer.ready():
60 | self.quantizer.find_params(W, weight=True)
61 |
62 | tick = time.time()
63 |
64 | H = self.H
65 | del self.H
66 | dead = torch.diag(H) == 0
67 | H[dead, dead] = 1
68 | W[:, dead] = 0
69 |
70 | Losses = torch.zeros(self.rows, device=self.dev)
71 |
72 | damp = percdamp * torch.mean(torch.diag(H))
73 | diag = torch.arange(self.columns, device=self.dev)
74 | H[diag, diag] += damp
75 | H = torch.linalg.cholesky(H)
76 | H = torch.cholesky_inverse(H)
77 | H = torch.linalg.cholesky(H, upper=True)
78 | Hinv = H
79 |
80 | mask = None
81 |
82 | for i1 in range(0, self.columns, blocksize):
83 | i2 = min(i1 + blocksize, self.columns)
84 | count = i2 - i1
85 |
86 | W1 = W[:, i1:i2].clone()
87 | Q1 = torch.zeros_like(W1)
88 | Err1 = torch.zeros_like(W1)
89 | Losses1 = torch.zeros_like(W1)
90 | Hinv1 = Hinv[i1:i2, i1:i2]
91 |
92 | if prunen == 0:
93 | if mask is not None:
94 | mask1 = mask[:, i1:i2]
95 | else:
96 | tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
97 | thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
98 | mask1 = tmp <= thresh
99 | else:
100 | mask1 = torch.zeros_like(W1) == 1
101 |
102 | for i in range(count):
103 | w = W1[:, i]
104 | d = Hinv1[i, i]
105 |
106 | if prunen != 0 and i % prunem == 0:
107 | tmp = W1[:, i:(i + prunem)] ** 2 / (torch.diag(Hinv1)[i:(i + prunem)].reshape((1, -1))) ** 2
108 | mask1.scatter_(1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True)
109 |
110 | q = w.clone()
111 | q[mask1[:, i]] = 0
112 |
113 | if hasattr(self, 'quantizer'):
114 | q = quantize(
115 | q.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
116 | ).flatten()
117 |
118 | Q1[:, i] = q
119 | Losses1[:, i] = (w - q) ** 2 / d ** 2
120 |
121 | err1 = (w - q) / d
122 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
123 | Err1[:, i] = err1
124 |
125 | W[:, i1:i2] = Q1
126 | Losses += torch.sum(Losses1, 1) / 2
127 |
128 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
129 |
130 | if DEBUG:
131 | self.layer.weight.data[:, :i2] = W[:, :i2]
132 | self.layer.weight.data[:, i2:] = W[:, i2:]
133 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
134 | print(torch.sum(Losses))
135 |
136 | torch.cuda.synchronize()
137 | print('time %.2f' % (time.time() - tick))
138 | print('error', torch.sum(Losses).item())
139 |
140 | if isinstance(self.layer, transformers.Conv1D):
141 | W = W.t()
142 | self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
143 | if DEBUG:
144 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
145 |
146 | def free(self):
147 | if DEBUG:
148 | self.inp1 = None
149 | self.out1 = None
150 | self.H = None
151 | torch.cuda.empty_cache()
152 |
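# Editor's note: a hedged sketch of the per-layer calibration/pruning loop this
# class is designed for, mirroring the upstream SparseGPT recipe; `layer` and
# `calibration_batches` are placeholders, and the hooks simply route activations
# into add_batch() before fasterprune() is called.
#
# subset = find_layers(layer)  # from modelutils
# gpts = {name: SparseGPT(mod) for name, mod in subset.items()}
# handles = [mod.register_forward_hook(
#                lambda m, inp, out, name=name: gpts[name].add_batch(inp[0].data, out.data))
#            for name, mod in subset.items()]
# for batch in calibration_batches:
#     layer(batch)  # populates the Hessian estimate H in each SparseGPT object
# for h in handles:
#     h.remove()
# for name, gpt in gpts.items():
#     gpt.fasterprune(sparsity=0.5, prunen=2, prunem=4)  # 2:4 structured sparsity
#     gpt.free()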
--------------------------------------------------------------------------------
/distillation_sparsification/test.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 | import transformers
7 |
8 | from sparsegpt import *
9 | from modelutils import *
10 | import sys
11 | #from transformers import AutoTokenizer, AutoModelForCausalLM
12 | from transformers import (
13 | CONFIG_MAPPING,
14 | MODEL_MAPPING,
15 | AutoConfig,
16 | AutoModelForCausalLM,
17 | AutoTokenizer,
18 | SchedulerType,
19 | default_data_collator,
20 | get_scheduler,
21 | )
22 |
23 | from falcon7b import modelling_RW
24 | from datasets import load_dataset
25 | from make_student import *
26 | import multiprocessing
27 | from transformers import DataCollatorForLanguageModeling
28 | from distill import DistillationTrainer
29 | from datasets import load_from_disk
30 | from torch.utils.data import DataLoader
31 |
32 | from accelerate import Accelerator, DistributedType, notebook_launcher
33 | from accelerate.logging import get_logger
34 | from accelerate.utils import set_seed
35 | from tqdm.auto import tqdm
36 |
37 | from transformers import Trainer, TrainingArguments
38 | import deepspeed
39 |
40 | ds = load_from_disk("processed_ds/processed")
41 | print(ds)
42 | train_dataset = ds["train"]
43 | eval_dataset = ds["validation"]
44 |
45 |
46 | eval_dataset = eval_dataset.select((i for i in range(len(eval_dataset)-1)))
47 |
48 | for k,v in eval_dataset[-1].items():
49 | print(k, torch.tensor(v).shape)
50 | # for k,v in eval_dataset[-1].items():
51 | # print(k, v.shape)
52 |
53 |
54 | # tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True)
55 | # data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
56 | # if tokenizer.pad_token is None:
57 | # tokenizer.add_special_tokens({'pad_token': '[PAD]'})
58 |
59 | # model = AutoModelForCausalLM.from_pretrained('student/', trust_remote_code=True, torch_dtype=torch.bfloat16)
60 |
61 | # # ds_config = {
62 | # # "tensor_parallel": {"tp_size": 1},
63 | # # "dtype": "fp16",
64 | # # "replace_with_kernel_inject": True,
65 | # # "replace_method": "auto",
66 | # # }
67 |
68 | # # ds_model = deepspeed.init_inference(model=model, config=ds_config)
69 |
70 | # per_device_train_batch_size = 2
71 | # per_device_eval_batch_size = 2
72 | # learning_rate = 3e-4
73 | # lr_scheduler_type = "cosine"
74 | # num_warmup_steps = 100
75 | # max_train_steps = 1_000
76 | # num_train_epochs = 1
77 | # weight_decay = 0.01
78 | # gradient_accumulation_steps = 1
79 | # output_dir = 'log/'
80 | # with_tracking = True
81 | # report_to = "tensorboard"
82 |
83 | # accelerator_log_kwargs = {}
84 |
85 | # if with_tracking:
86 | # accelerator_log_kwargs["log_with"] = report_to
87 | # accelerator_log_kwargs["project_dir"] = output_dir
88 |
89 | # accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, log_with="tensorboard", project_dir="log")
90 |
91 | # train_dataloader = DataLoader(
92 | # train_dataset, shuffle=True, collate_fn=data_collator, batch_size=per_device_train_batch_size
93 | # )
94 | # eval_dataloader = DataLoader(
95 | # eval_dataset, collate_fn=data_collator, batch_size=per_device_eval_batch_size
96 | # )
97 |
98 | # no_decay = ["bias", "layer_norm.weight"]
99 | # optimizer_grouped_parameters = [
100 | # {
101 | # "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
102 | # "weight_decay": weight_decay,
103 | # },
104 | # {
105 | # "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
106 | # "weight_decay": 0.0,
107 | # },
108 | # ]
109 | # optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate)
110 |
111 | # overrode_max_train_steps = False
112 | # num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
113 | # if max_train_steps is None:
114 | # max_train_steps = num_train_epochs * num_update_steps_per_epoch
115 | # overrode_max_train_steps = True
116 |
117 | # lr_scheduler = get_scheduler(
118 | # name=lr_scheduler_type,
119 | # optimizer=optimizer,
120 | # num_warmup_steps=num_warmup_steps * gradient_accumulation_steps,
121 | # num_training_steps=max_train_steps * gradient_accumulation_steps,
122 | # )
123 |
124 | # model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
125 | # model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
126 | # )
127 |
128 | # experiment_config = {
129 | # "num_iterations": 1,
130 | # "learning_rate": 3e-4,
131 | # }
132 |
133 |
134 | # accelerator.init_trackers("clm_no_trainer", experiment_config)
135 |
136 | # # We need to recalculate our total training steps as the size of the training dataloader may have changed.
137 | # num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
138 | # if overrode_max_train_steps:
139 | # max_train_steps = num_train_epochs * num_update_steps_per_epoch
140 | # # Afterwards we recalculate our number of training epochs
141 | # num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
142 |
143 | # # Train!
144 | # total_batch_size = per_device_train_batch_size * accelerator.num_processes * gradient_accumulation_steps
145 |
146 | # progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
147 | # completed_steps = 0
148 | # starting_epoch = 0
149 |
150 | # # update the progress_bar if load from checkpoint
151 | # progress_bar.update(completed_steps)
152 | # print("training started")
153 | # for epoch in range(starting_epoch, num_train_epochs):
154 | # print("Train!")
155 | # model.train()
156 | # if with_tracking:
157 | # total_loss = 0
158 |
159 | # active_dataloader = train_dataloader
160 | # for step, batch in enumerate(active_dataloader):
161 | # with accelerator.accumulate(model):
162 | # batch.pop("token_type_ids")
163 | # outputs = model(**batch)
164 | # loss = outputs.loss
165 | # # We keep track of the loss at each epoch
166 | # if with_tracking:
167 | # total_loss += loss.detach().float()
168 | # accelerator.backward(loss)
169 | # optimizer.step()
170 | # lr_scheduler.step()
171 | # optimizer.zero_grad()
172 | # accelerator.log({"training_loss": loss}, step=step)
173 |
174 | # # Checks if the accelerator has performed an optimization step behind the scenes
175 | # if accelerator.sync_gradients:
176 | # progress_bar.update(1)
177 | # completed_steps += 1
178 |
179 | # if completed_steps >= max_train_steps:
180 | # break
181 |
182 | # model.eval()
183 | # losses = []
184 | # for step, batch in enumerate(eval_dataloader):
185 | # with torch.no_grad():
186 | # batch.pop("token_type_ids")
187 | # outputs = model(**batch)
188 |
189 | # loss = outputs.loss
190 | # losses.append(accelerator.gather_for_metrics(loss.repeat(per_device_eval_batch_size)))
191 |
192 | # losses = torch.cat(losses)
193 | # try:
194 | # eval_loss = torch.mean(losses)
195 | # perplexity = math.exp(eval_loss)
196 | # except OverflowError:
197 | # perplexity = float("inf")
198 |
199 | # logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}")
200 | # if with_tracking:
201 | # accelerator.log(
202 | # {
203 | # "perplexity": perplexity,
204 | # "eval_loss": eval_loss,
205 | # "train_loss": total_loss.item() / len(train_dataloader),
206 | # "epoch": epoch,
207 | # "step": completed_steps,
208 | # },
209 | # step=completed_steps,
210 | # )
211 | # print(f"epoch: {epoch}")
212 | # print(f"eval_loss: {eval_loss}")
213 | # print(f"train_loss: {total_loss.item() / len(train_dataloader)}")
214 | # print(f"perplexity: {perplexity}")
215 | # accelerator.end_training()
216 |
217 |
218 |
219 |
--------------------------------------------------------------------------------
/distillation_sparsification/test1.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 | import transformers
7 |
8 | from sparsegpt import *
9 | from modelutils import *
10 | import sys
11 | #from transformers import AutoTokenizer, AutoModelForCausalLM
12 | from transformers import (
13 | CONFIG_MAPPING,
14 | MODEL_MAPPING,
15 | AutoConfig,
16 | AutoModelForCausalLM,
17 | AutoTokenizer,
18 | SchedulerType,
19 | default_data_collator,
20 | get_scheduler,
21 | )
22 |
23 | from falcon7b import modelling_RW
24 | from datasets import load_dataset
25 | from make_student import *
26 | import multiprocessing
27 | from transformers import DataCollatorForLanguageModeling
28 | from distill import DistillationTrainer
29 | from datasets import load_from_disk
30 | from torch.utils.data import DataLoader
31 |
32 | from accelerate import Accelerator, DistributedType, notebook_launcher
33 | from accelerate.logging import get_logger
34 | from accelerate.utils import set_seed
35 | from tqdm.auto import tqdm
36 |
37 | from transformers import Trainer, TrainingArguments
38 | import deepspeed
39 |
40 | device = torch.device('cuda:3')
41 | ds = load_from_disk("processed_ds/processed")
42 | print(ds)
43 | train_dataset = ds["test"]
44 | eval_dataset = ds["validation"]
45 |
46 | tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True)
47 | print(tokenizer)
48 |
49 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
50 | if tokenizer.pad_token is None:
51 | tokenizer.add_special_tokens({'pad_token': '[PAD]'})
52 | t_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
53 | print(t_model.config)
54 | t_model.config.output_hidden_states=True
55 |
56 | # model = AutoModelForCausalLM.from_pretrained('student/', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
57 | model, layer_ids = create_student_by_copying_alternating_layers(t_model, d=8, save_path='student1/')
58 | print(model)
59 | print(layer_ids)
60 | # pytorch_total_params = sum(p.numel() for p in model.parameters())
61 |
62 | # print(pytorch_total_params)
63 | # train_dataloader = DataLoader(
64 | # train_dataset, shuffle=True, collate_fn=data_collator, batch_size=1
65 | # )
66 |
67 | # batch = next(iter(train_dataloader))
68 | # inputs_id = batch['input_ids'].to(device)
69 | # print(inputs_id.shape)
70 |
71 | # outputs = t_model(input_ids=inputs_id, attention_mask=None)
72 | # print(list(outputs.keys()))
--------------------------------------------------------------------------------
/distillation_sparsification/tracker.py:
--------------------------------------------------------------------------------
1 | from accelerate.tracking import GeneralTracker, on_main_process
2 | from typing import Optional
3 |
4 | import wandb
5 |
6 |
7 | class MyCustomTracker(GeneralTracker):
8 | name = "wandb"
9 | requires_logging_directory = False
10 |
11 |
12 |     @on_main_process
13 |     def __init__(self, run_name: str):
14 |         self.run_name = run_name
15 |         self.run = wandb.init(project=self.run_name)
16 |
17 |     @property
18 |     def tracker(self):
19 |         return self.run
20 |
21 |     @on_main_process
22 |     def store_init_configuration(self, values: dict):
23 |         wandb.config.update(values)
24 |
25 |     @on_main_process
26 |     def log(self, values: dict, step: Optional[int] = None):
27 |         wandb.log(values, step=step)
--------------------------------------------------------------------------------
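A minimal usage sketch for the custom tracker above (assuming it lives in `tracker.py` next to the training script; the run name and logged values are placeholders):

```
from accelerate import Accelerator
from tracker import MyCustomTracker  # the module above; path assumed

tracker = MyCustomTracker(run_name="mdel-distillation")   # run name is a placeholder
accelerator = Accelerator(log_with=tracker)
accelerator.init_trackers("mdel-distillation", config={"learning_rate": 3e-4})
accelerator.log({"training_loss": 0.0}, step=0)
accelerator.end_training()
```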
/distillation_sparsification/utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import json
3 | import linecache
4 | import math
5 | import os
6 | import pickle
7 | import socket
8 | from logging import getLogger
9 | from pathlib import Path
10 | from typing import Callable, Dict, Iterable, List, Tuple, Union
11 |
12 | import git
13 | import numpy as np
14 | import torch
15 | import torch.distributed as dist
16 | from torch import nn
17 | from torch.utils.data import Dataset, Sampler
18 |
19 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
20 | """From fairseq"""
21 | if target.dim() == lprobs.dim() - 1:
22 | target = target.unsqueeze(-1)
23 | nll_loss = -lprobs.gather(dim=-1, index=target)
24 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
25 | if ignore_index is not None:
26 | pad_mask = target.eq(ignore_index)
27 | nll_loss.masked_fill_(pad_mask, 0.0)
28 | smooth_loss.masked_fill_(pad_mask, 0.0)
29 | else:
30 | nll_loss = nll_loss.squeeze(-1)
31 | smooth_loss = smooth_loss.squeeze(-1)
32 |
33 | nll_loss = nll_loss.sum() # mean()? Scared to break other math.
34 | smooth_loss = smooth_loss.sum()
35 | eps_i = epsilon / lprobs.size(-1)
36 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
37 | return loss, nll_loss
38 |
39 | def freeze_params(model: nn.Module):
40 | """Set requires_grad=False for each of model.parameters()"""
41 | for par in model.parameters():
42 | par.requires_grad = False
43 |
44 |
45 | def calc_ce_loss(attention_mask, lm_labels, s_logits, t_logits, temperature, restrict_ce_to_mask):
46 | """Copy pasted from distillbert (transformers/examples/distillation/)"""
47 | # mask has False at padding_idx
48 | ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
49 | if restrict_ce_to_mask:
50 | mask = (lm_labels > -1) # (bs, seq_length, voc_size)
51 | else:
52 | mask = attention_mask # (bs, seq_length, voc_size)
53 |
54 | mask = mask.unsqueeze(-1).expand_as(s_logits)
55 | mask = torch.gt(mask, 0)
56 | s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
57 | s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
58 | t_logits_slct = torch.masked_select(t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
59 | t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
60 | assert t_logits_slct.size() == s_logits_slct.size()
61 |
62 | loss_ce = (
63 | ce_loss_fct(
64 | nn.functional.log_softmax(s_logits_slct / temperature, dim=-1),
65 | nn.functional.softmax(t_logits_slct / temperature, dim=-1),
66 | )
67 | * (temperature) ** 2
68 | )
69 | return loss_ce
70 |
71 |
72 | def calc_hidden_loss(attention_mask, lm_labels,hidden_states, hidden_states_T, matches, normalize_hidden, restrict_ce_to_mask):
73 | """MSE(student_hid, teacher_hid[matches]). Called "Intermediate supervision" in paper. Inspired by TinyBERT."""
74 | msg = "expected list or tuple for hidden_states, got tensor of shape: "
75 | assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
76 | assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
77 | if restrict_ce_to_mask:
78 | mask = lm_labels > -1 # (bs, seq_length, voc_size)
79 | else:
80 | mask = attention_mask # (bs, seq_length, voc_size)
81 |
82 | mask = mask.to(hidden_states[0])
83 | valid_count = mask.sum() * hidden_states[0].size(-1)
84 | s_states = torch.stack([hidden_states[i] for i in range(len(matches))])
85 | t_states = torch.stack([hidden_states_T[j] for j in matches])
86 | assert s_states.shape == t_states.shape, f"{s_states.shape} != {t_states.shape}"
87 | if normalize_hidden:
88 | s_states = nn.functional.layer_norm(s_states, s_states.shape[1:])
89 | t_states = nn.functional.layer_norm(t_states, t_states.shape[1:])
90 | mse = nn.functional.mse_loss(s_states, t_states, reduction="none")
91 | masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
92 | return masked_mse
93 |
94 |
95 | def eval(s_model, ds, with_tracking=False):  # NOTE: relies on `accelerator`, `logger`, `total_loss`, `train_dataloader`, `completed_steps`, `epoch` and `per_device_eval_batch_size` from the caller's globals
96 | s_model.eval()
97 | lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
98 | losses = []
99 | for step, batch in enumerate(ds):
100 | with torch.no_grad():
101 | input_ids, attn_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
102 | lm_labels = batch["input_ids"]
103 |
104 | s_outputs = s_model(
105 | input_ids,
106 | attention_mask=None,
107 | )
108 |
109 | s_logits, s_hidden_states = s_outputs["logits"], s_outputs["hidden_states"]
110 | shift_logits = s_logits[..., :-1, :].contiguous()
111 | shift_labels = lm_labels[..., 1:].contiguous()
112 | loss = lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
113 |
114 | losses.append(accelerator.gather_for_metrics(loss.repeat(per_device_eval_batch_size)))
115 |
116 | losses = torch.cat(losses)
117 | try:
118 | eval_loss = torch.mean(losses)
119 | perplexity = math.exp(eval_loss)
120 | except OverflowError:
121 | perplexity = float("inf")
122 |
123 | logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}")
124 | if with_tracking:
125 | accelerator.log(
126 | {
127 | "perplexity": perplexity,
128 | "eval_loss": eval_loss,
129 | "step": completed_steps,
130 | "train_loss": total_loss.item() / len(train_dataloader),
131 | },
132 | step=completed_steps,
133 | )
134 | accelerator.print(f"eval_loss: {eval_loss}")
135 | accelerator.print(f"perplexity: {perplexity}")
136 | accelerator.print(f"train_loss: {total_loss.item() / len(train_dataloader)}")
--------------------------------------------------------------------------------
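A small, self-contained sketch of the distillation loss defined in `calc_ce_loss` above, run on random tensors (shapes and values are illustrative only; the import assumes the script is run from `distillation_sparsification/`):

```
import torch
from utils import calc_ce_loss  # the module above; path assumed

bs, seq_len, vocab = 2, 8, 32
s_logits = torch.randn(bs, seq_len, vocab)           # student logits
t_logits = torch.randn(bs, seq_len, vocab)           # teacher logits
attention_mask = torch.ones(bs, seq_len, dtype=torch.long)
lm_labels = torch.randint(0, vocab, (bs, seq_len))   # only consulted when restrict_ce_to_mask=True

loss = calc_ce_loss(attention_mask, lm_labels, s_logits, t_logits,
                    temperature=2.0, restrict_ce_to_mask=False)
print(loss)  # scalar KL divergence between teacher and student token distributions, scaled by temperature**2
```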
/docs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huu4ontocord/MDEL/d84a598e765accfb723edd58f6c0a426d8c16d8d/docs/.gitkeep
--------------------------------------------------------------------------------
/lora-x/README.md:
--------------------------------------------------------------------------------
1 | # Lora-X - A library to support long-context multimodal, multilingual, and multi-domain training
2 |
3 | Lora-X is intended to create a base model for MDEL. We are targeting the Llama 2 architecture.
4 |
5 | - TODO: confirm A100 40GB training of a 70B model at 8K context on JUWELS
6 | - TODO: work on ways to increase context length given the memory constraint
7 |   + test different LoRA configs and methods for improving performance
8 | - TODO: use the NTK scaling available in HF
9 | - TODO: add LLavaR code - in particular the projection layer
10 | - TODO: integrate BPT and test the speed difference
11 | - TODO: add Hummingbird's dynamic token and media loading during training
12 | - TODO: add proxy embeddings (m-clip)
13 |
14 | # Credits
15 | This library is based on the code of participants in MDEL
16 | - https://github.com/jordiclive/scaled-rope (which is in turn based on jquesnelle/scaled-rope and Open Assistant code)
17 | - https://github.com/arnavdantuluri/long-context-transformers
18 | - reference for Hummingbird
19 |
20 |
21 | # Multinode LoRA + Flash + DS
22 | This project adds LoRA, a Flash Attention patch and DeepSpeed (DS) to [Scaled Rope](https://github.com/jquesnelle/scaled-rope), and can be run multi-node.
23 |
24 | Flash Attention 2 and LLaMA 2 ready 🚀
25 |
26 | ## Setup and Installation
27 |
28 | 1. Install the required Python packages with pip:
29 |
30 | ```
31 | pip install -r requirements.txt
32 | ```
33 |
34 | 2. Update the `config.yaml` file as per your requirements.
35 |
36 | ## Usage
37 |
38 | To run the application, use the following command:
39 |
40 | ```
41 | python finetune.py --configs defaults
42 | ```
43 |
44 | Replace `defaults` with the configuration(s) you have defined in `configs/config.yaml`. Individual config values can also be overridden with command-line arguments.
45 |
46 | **Please note:** This uses the recent Hugging Face PR, so models remain HF compatible. The linear-scaling argument is `interpolation_factor`, i.e. how much you want to scale the context length. If set to `None`, the factor defaults to `config.max_position_embeddings / 4096`, since 4096 is the default context length for LLaMA 2.
47 |
48 |
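For example, a minimal sketch of the default factor (illustrative values only; `config` stands for the loaded model config):

```
# If interpolation_factor is None, the effective linear-scaling factor is
#   config.max_position_embeddings / 4096
# e.g. extending LLaMA 2 (4096 tokens) to an 8K context:
max_position_embeddings = 8192                           # target context length (example value)
interpolation_factor = max_position_embeddings / 4096    # -> 2.0
```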
49 | ## Data
50 | - Specify packed, untokenized datasets on the hub under `dataset_names` (e.g. `Multi-Domain-Expert-Layers/the_pile_books3_packed_128k`)
51 | - If `pretokenized=True`, specify a single pre-tokenized dataset on the hub under `dataset_names` (e.g. `conceptofmind/rp-packed-32k-no-filter` for OpenLLaMA); see the sketch below for a quick way to inspect such a dataset
52 |
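A quick way to peek at one of the packed datasets named above (a sketch, assuming the `datasets` library is installed and the dataset is public; `data.py` expects a `train` split with a `text` field):

```
from datasets import load_dataset

ds = load_dataset("Multi-Domain-Expert-Layers/the_pile_books3_packed_128k",
                  split="train", streaming=True)
print(next(iter(ds)).keys())  # expect a "text" field, as LMDataset in data.py reads
```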
53 | ### Running on a Single Node
54 |
55 | Use the following commands for running on a single node:
56 |
57 | 1. Export the necessary paths:
58 |
59 | ```
60 | export PYTHONPATH="/mnt/data/jordiclive/scaled-rope:$PYTHONPATH"
61 | export TRANSFORMERS_CACHE="/mnt/data/jordiclive/transformers_cache"
62 | export HF_DATASETS_CACHE="/mnt/data/jordiclive/transformers_cache"
63 | export HF_HOME="/mnt/data/jordiclive/transformers_cache"
64 | export WANDB_API_KEY=""
65 | ```
66 |
67 | 2. Run the script:
68 |
69 | ```
70 | deepspeed --include=localhost:0,1,2,3,4,5,6,7 --master_port 61500 finetune.py --output_dir saved_ckpts_32k --configs defaults lora-7b-llama2 --deepspeed
71 | ```
72 |
73 | ### Running on Multiple Nodes
74 |
75 | Example script using the slurm launcher with deepspeed: `scripts/juwels_booster.sh`
76 |
--------------------------------------------------------------------------------
/lora-x/bpt_pt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.checkpoint import checkpoint
3 | import torch.nn as nn
4 | from transformers.activations import ACT2FN
5 |
6 | class GPTNeoXMLP(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 | self.dense_h_to_4h = nn.Linear(512, 2048)
10 | self.dense_4h_to_h = nn.Linear(2048, 512)
11 | self.act = ACT2FN["gelu"]
12 |
13 | def forward(self, hidden_states):
14 | hidden_states = self.dense_h_to_4h(hidden_states)
15 | hidden_states = self.act(hidden_states)
16 | hidden_states = self.dense_4h_to_h(hidden_states)
17 | return hidden_states
18 |
19 | import torch
20 | from functools import partial
21 | from torch import nn, einsum
22 | from torch.utils.checkpoint import checkpoint
23 | import torch.nn.functional as F
24 | from bpt_triton import matmul, add  # absolute import so the benchmark below can run as a script
25 |
26 | from torch.nn.functional import scaled_dot_product_attention
27 | from einops import rearrange
28 | from datetime import datetime
29 | # helper functions
30 |
31 | def exists(val):
32 | return val is not None
33 |
34 | def default(val, d):
35 | return val if exists(val) else d
36 |
37 | # regular attention
38 |
39 | def attention(
40 | q, k, v,
41 | mask = None,
42 | causal = False,
43 | attn_bias = None,
44 | **kwargs
45 | ):
46 | scale = q.shape[-1] ** -0.5
47 | q = q * scale
48 |
49 | sim = einsum('b h i d, b h j d -> b h i j', q, k)
50 |
51 | if exists(attn_bias):
52 | sim = sim + attn_bias
53 |
54 | mask_value = -torch.finfo(sim.dtype).max
55 |
56 | if exists(mask):
57 | mask = rearrange(mask, 'b j -> b 1 1 j')
58 | sim = sim.masked_fill(~mask, mask_value)
59 |
60 | if causal:
61 | i, j = sim.shape[-2:]
62 | mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
63 | sim = sim.masked_fill(mask, mask_value)
64 |
65 | sim = sim - sim.amax(dim = -1, keepdim = True).detach()
66 | attn = sim.softmax(dim = -1)
67 |
68 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
69 | return out
70 |
71 | # memory efficient attention
72 |
73 | def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout):
74 | q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device
75 |
76 | weight = einsum('b h i d, b h j d -> b h i j', q, k)
77 |
78 | if exists(attn_bias_chunk):
79 | weight = weight + attn_bias_chunk
80 |
81 | mask_value = -torch.finfo(weight.dtype).max
82 |
83 | if exists(mask):
84 | mask = rearrange(mask, 'b j -> b 1 1 j')
85 | weight = weight.masked_fill(~mask, mask_value)
86 |
87 | if causal and q_start_index < (k_start_index + k_chunk_size - 1):
88 | causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
89 | weight = weight.masked_fill(causal_mask, mask_value)
90 |
91 | weight_max = weight.amax(dim = -1, keepdim = True).detach()
92 | weight = weight - weight_max
93 |
94 | exp_weight = weight.exp()
95 |
96 | exp_weight = F.dropout(exp_weight, p = dropout)
97 |
98 | weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
99 |
100 | return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')
101 |
102 | checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)
103 |
104 | def memory_efficient_attention(
105 | q, k, v,
106 | mask = None,
107 | causal = False,
108 | attn_bias = None,
109 | q_bucket_size = 512,
110 | k_bucket_size = 1024,
111 | eps = 1e-8,
112 | dropout = 0.,
113 | training = False
114 | ):
115 | scale = q.shape[-1] ** -0.5
116 | q = q * scale
117 |
118 | # function
119 |
120 | needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
121 | summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk
122 |
123 | # chunk all the inputs
124 |
125 | q_chunks = q.split(q_bucket_size, dim = -2)
126 | k_chunks = k.split(k_bucket_size, dim = -2)
127 | v_chunks = v.split(k_bucket_size, dim = -2)
128 | mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))
129 |
130 | if exists(attn_bias):
131 | i, j = attn_bias.shape[-2:]
132 | attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
133 | attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks))
134 |
135 | # loop through all chunks and accumulate
136 |
137 | out = []
138 | for q_index, q_chunk in enumerate(q_chunks):
139 | exp_weights = []
140 | weighted_values = []
141 | weight_maxes = []
142 |
143 | for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
144 | q_start_index = q_index * q_bucket_size
145 | k_start_index = k_index * k_bucket_size
146 |
147 | if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1):
148 | # if chunk is to be all masked out causally, skip
149 | continue
150 |
151 | attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None
152 |
153 | exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
154 | q_chunk,
155 | k_chunk,
156 | v_chunk,
157 | mask_chunk,
158 | attn_bias_chunk,
159 | causal,
160 | (q_start_index, k_start_index),
161 | dropout if training else 0.
162 | )
163 |
164 | exp_weights.append(exp_weight_chunk)
165 | weighted_values.append(weighted_value_chunk)
166 | weight_maxes.append(weight_max_chunk)
167 |
168 | weight_maxes = torch.stack(weight_maxes, dim = -1)
169 |
170 | weighted_values = torch.stack(weighted_values, dim = -1)
171 | exp_weights = torch.stack(exp_weights, dim = -1)
172 |
173 | global_max = weight_maxes.amax(dim = -1, keepdim = True)
174 | renorm_factor = (weight_maxes - global_max).exp().detach()
175 |
176 | exp_weights = exp_weights * renorm_factor
177 | weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')
178 |
179 | all_values = weighted_values.sum(dim = -1)
180 | all_weights = exp_weights.sum(dim = -1)
181 |
182 | normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
183 | out.append(normalized_values)
184 |
185 | return torch.cat(out, dim = -2)
186 |
187 | def blockwise_compute_ffn(cell, inputs, chunk_size):
188 | inputs = torch.split(inputs, chunk_size, dim=1)
189 | num_q = len(inputs)
190 |
191 | def ffn(cell, _, hidden_states):
192 | outputs = cell(hidden_states)
193 | return outputs
194 |
195 | outputs = []
196 | for i in range(num_q):
197 | outputs.append(ffn(cell, None, inputs[i]))
198 |
199 | res = torch.concat(outputs, dim=1)
200 | # res = rearrange(res, 'n b c d -> b (n c) d')
201 | return res
202 |
203 | if __name__ == "__main__":
204 | # Blocked mem stuff
205 | q = torch.rand(2, 512, 16, 128)
206 | k = torch.rand(2, 512, 16, 128)
207 | v = torch.rand(2, 512, 16, 128)
208 | bias = torch.rand(2, 1, 512, 2048)
209 |
210 | # Blocked FFN Stuff
211 | x = torch.rand(2, 256, 512)
212 | cell = GPTNeoXMLP()
213 | startTime = datetime.now()
214 | y_pt_mem = memory_efficient_attention(q, k, v, q_bucket_size=512, k_bucket_size=512)
215 | print('pythonic mem eff attn', datetime.now() - startTime)
216 |
217 |     # enable the flash kernel for torch's fused scaled_dot_product_attention
218 |     torch.backends.cuda.enable_flash_sdp(True)
219 |     startTime = datetime.now()
220 |     y_pt_sdpa = scaled_dot_product_attention(q, k, v)
221 |     print('torch scaled_dot_product_attention', datetime.now() - startTime)
222 |
223 | startTime = datetime.now()
224 | y_pt_ffn = blockwise_compute_ffn(cell, x, 256)
225 | print('pythonic blocked ffn', datetime.now() - startTime)
226 |
227 |     # blockwise_compute_ffn_triton is not defined in this file; the Triton-backed
228 |     # benchmark is left commented out so the script runs end to end.
229 |     # y_pt_ffn = blockwise_compute_ffn_triton(cell, x)  # would raise NameError if uncommented
230 |     print(y_pt_ffn.shape)
231 |
--------------------------------------------------------------------------------
/lora-x/bpt_triton.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import triton
4 | import triton.language as tl
5 | from datetime import datetime
6 |
7 | @triton.jit
8 | def matmul_kernel(
9 | # Pointers to matrices
10 | a_ptr, b_ptr, c_ptr,
11 | # Matrix dimensions
12 | M, N, K,
13 | # The stride variables represent how much to increase the ptr by when moving by 1
14 | # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
15 | # by to get the element one row down (A has M rows).
16 | stride_am, stride_ak,
17 | stride_bk, stride_bn,
18 | stride_cm, stride_cn,
19 | # Meta-parameters
20 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
21 | GROUP_SIZE_M: tl.constexpr,
22 | ):
23 | """Kernel for computing the matmul C = A x B.
24 | A has shape (M, K), B has shape (K, N) and C has shape (M, N)
25 | """
26 | # -----------------------------------------------------------
27 | # Map program ids `pid` to the block of C it should compute.
28 | # This is done in a grouped ordering to promote L2 data reuse.
29 |     # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
30 | pid = tl.program_id(axis=0)
31 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
32 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
33 | num_pid_in_group = GROUP_SIZE_M * num_pid_n
34 | group_id = pid // num_pid_in_group
35 | first_pid_m = group_id * GROUP_SIZE_M
36 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
37 | pid_m = first_pid_m + (pid % group_size_m)
38 | pid_n = (pid % num_pid_in_group) // group_size_m
39 |
40 | # ----------------------------------------------------------
41 | # Create pointers for the first blocks of A and B.
42 | # We will advance this pointer as we move in the K direction
43 | # and accumulate
44 | # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
45 | # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
46 |     # See the `Pointer Arithmetic` section of the Triton matmul tutorial for details
47 | offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
48 | offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
49 | offs_k = tl.arange(0, BLOCK_SIZE_K)
50 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
51 | b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
52 |
53 | # -----------------------------------------------------------
54 | # Iterate to compute a block of the C matrix.
55 | # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
56 | # of fp32 values for higher accuracy.
57 | # `accumulator` will be converted back to fp16 after the loop.
58 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
59 | for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
60 | # Load the next block of A and B, generate a mask by checking the K dimension.
61 | # If it is out of bounds, set it to 0.
62 | a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
63 | b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
64 | # We accumulate along the K dimension.
65 | accumulator += tl.dot(a, b)
66 | # Advance the ptrs to the next K block.
67 | a_ptrs += BLOCK_SIZE_K * stride_ak
68 | b_ptrs += BLOCK_SIZE_K * stride_bk
69 | # You can fuse arbitrary activation functions here
70 | # while the accumulator is still in FP32!
71 | c = accumulator.to(tl.float16)
72 |
73 | # -----------------------------------------------------------
74 | # Write back the block of the output matrix C with masks.
75 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
76 | offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
77 | c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
78 | c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
79 | tl.store(c_ptrs, c, mask=c_mask)
80 |
81 |
82 | # Example element-wise activation; it could be fused into the matmul epilogue while the accumulator is still in FP32.
83 | @triton.jit
84 | def leaky_relu(x):
85 | x = x + 1
86 | return tl.where(x >= 0, x, 0.01 * x)
87 |
88 | @torch.jit.script
89 | def gelu(x):
90 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
91 |
92 | def matmul(a, b, block_size_m=32, block_size_n=32, block_size_k=32, group_size_m=8):
93 | # Check constraints.
94 | assert a.shape[1] == b.shape[0], "Incompatible dimensions"
95 | assert a.is_contiguous(), "Matrix A must be contiguous"
96 | assert b.is_contiguous(), "Matrix B must be contiguous"
97 | M, K = a.shape
98 | K, N = b.shape
99 | # Allocates output.
100 | c = torch.empty((M, N), device=a.device, dtype=a.dtype)
101 | # 1D launch kernel where each block gets its own program.
102 | grid = lambda META: (
103 | triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
104 | )
105 | matmul_kernel[grid](
106 | a, b, c,
107 | M, N, K,
108 | a.stride(0), a.stride(1),
109 | b.stride(0), b.stride(1),
110 | c.stride(0), c.stride(1),
111 | block_size_m, block_size_n, block_size_k, group_size_m
112 | )
113 | return c
114 |
115 | @triton.jit
116 | def add_kernel(
117 | x_ptr, # *Pointer* to first input vector.
118 | y_ptr, # *Pointer* to second input vector.
119 | output_ptr, # *Pointer* to output vector.
120 | n_elements, # Size of the vector.
121 | BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
122 | # NOTE: `constexpr` so it can be used as a shape value.
123 | ):
124 | # There are multiple 'programs' processing different data. We identify which program
125 | # we are here:
126 | pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
127 | # This program will process inputs that are offset from the initial data.
128 | # For instance, if you had a vector of length 256 and block_size of 64, the programs
129 | # would each access the elements [0:64, 64:128, 128:192, 192:256].
130 | # Note that offsets is a list of pointers:
131 | block_start = pid * BLOCK_SIZE
132 | offsets = block_start + tl.arange(0, BLOCK_SIZE)
133 | # Create a mask to guard memory operations against out-of-bounds accesses.
134 | mask = offsets < n_elements
135 | # Load x and y from DRAM, masking out any extra elements in case the input is not a
136 | # multiple of the block size.
137 | x = tl.load(x_ptr + offsets, mask=mask)
138 | y = tl.load(y_ptr + offsets, mask=mask)
139 | output = x + y
140 | # Write x + y back to DRAM after activation
141 | tl.store(output_ptr + offsets, output, mask=mask)
142 |
143 | def add(x: torch.Tensor, y: torch.Tensor):
144 | # We need to preallocate the output.
145 | output = torch.empty_like(x)
146 | assert x.is_cuda and y.is_cuda and output.is_cuda
147 | n_elements = output.numel()
148 | # The SPMD launch grid denotes the number of kernel instances that run in parallel.
149 | # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
150 | # In this case, we use a 1D grid where the size is the number of blocks:
151 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
152 | # NOTE:
153 | # - Each torch.tensor object is implicitly converted into a pointer to its first element.
154 | # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
155 | # - Don't forget to pass meta-parameters as keywords arguments.
156 | add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
157 | # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
158 | # running asynchronously at this point.
159 | return output
160 |
161 | def forward(x: torch.Tensor, w: torch.Tensor, b:torch.Tensor):
162 | output = matmul(x, w)
163 | # output = add(output, b)
164 | return output
165 |
166 | if __name__ == "__main__":
167 | torch.manual_seed(0)
168 | a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
169 | layer = torch.nn.Linear(512, 512).cuda().half()
170 | start = datetime.now()
171 | triton_output = forward(a, layer.weight.T.contiguous(), layer.bias)
172 | print("triton", datetime.now()- start)
173 |
174 | start = datetime.now()
175 | torch_output = layer(a)
176 | print("torch", datetime.now()- start)
177 |
178 | print("diff", (torch_output - triton_output).abs())
179 | if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):
180 | print("Triton and Torch match")
181 | else:
182 | print("Triton and Torch differ")
183 |
--------------------------------------------------------------------------------
/lora-x/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from distutils.util import strtobool
4 | from pathlib import Path
5 |
6 | import yaml
7 |
8 |
9 | def _strtobool(x):
10 | return bool(strtobool(x))
11 |
12 |
13 | def read_yamls(dir):
14 | conf = {}
15 | no_conf = True
16 |
17 | for config_file in Path(dir).glob("**/*.yaml"):
18 | no_conf = False
19 | with config_file.open("r") as f:
20 | conf.update(yaml.safe_load(f))
21 |
22 | if no_conf:
23 | print(f"WARNING: No yaml files found in {dir}")
24 |
25 | return conf
26 |
27 |
28 | def rank_zero_info(msg) -> None:
29 | local_rank = int(os.getenv("LOCAL_RANK", "0"))
30 | if local_rank in (None, 0):
31 | print(msg)
32 |
33 |
34 | def argument_parsing(notebook=False, notebook_args=None):
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument(
37 | "--configs",
38 | nargs="+",
39 | required=True,
40 | help="""
41 | Multiple configs can be passed to set different options.
42 | For example, run as:
43 |
44 | ./trainer_sft.py --configs default juweles
45 |
46 | """,
47 | )
48 | parser.add_argument("--local_rank", type=int, default=-1)
49 | parser.add_argument("--deepspeed", action="store_true")
50 | parser.add_argument("--no_deepspeed", action="store_true")
51 | parser.add_argument("--wandb-entity", type=str, default="open-assistant")
52 | parser.add_argument(
53 | "--resume_from_checkpoint",
54 | action="store_true",
55 | help="Resume from last saved checkpoint",
56 | )
57 | parser.add_argument("--rng_seed", type=int, help="rng seed")
58 | parser.add_argument(
59 | "--show_dataset_stats",
60 | action="store_true",
61 | help="Show dataset stats",
62 | default=False,
63 | )
64 | parser.set_defaults(deepspeed=False)
65 |
66 | if notebook:
67 | args, remaining = parser.parse_known_args(notebook_args)
68 | else:
69 | args, remaining = parser.parse_known_args()
70 |
71 | # Config from YAML
72 | conf = {}
73 | configs = read_yamls("configs/")
74 |
75 | conf.update(configs["defaults"])
76 | try:
77 | for name in args.configs:
78 | if "," in name:
79 | for n in name.split(","):
80 | conf.update(configs[n])
81 | else:
82 | conf.update(configs[name])
83 | except KeyError as e:
84 | print(f'Error: Could not find the config "{e.args[0]}" in config.yaml')
85 | exit(1)
86 |
87 | conf["wandb_entity"] = args.wandb_entity
88 | conf["local_rank"] = args.local_rank
89 | conf["deepspeed"] = args.deepspeed
90 | if args.no_deepspeed:
91 | conf["deepspeed"] = None
92 | conf["resume_from_checkpoint"] = args.resume_from_checkpoint
93 | if args.rng_seed is not None:
94 | conf["rng_seed"] = args.rng_seed
95 | conf["show_dataset_stats"] = args.show_dataset_stats
96 |
97 |     # get the world size from deepspeed
98 | if conf["deepspeed"]:
99 | conf["world_size"] = int(os.getenv("WORLD_SIZE", default="1"))
100 | else:
101 | conf["world_size"] = 1
102 |
103 | # Override config from command-line
104 | parser = argparse.ArgumentParser()
105 | for key, value in conf.items():
106 | type_ = type(value) if value is not None else str
107 | if type_ == bool:
108 | type_ = _strtobool
109 | parser.add_argument(f"--{key}", type=type_, default=value)
110 | # Allow --no-{key} to remove it completely
111 | parser.add_argument(f"--no-{key}", dest=key, action="store_const", const=None)
112 |
113 | return parser.parse_args(remaining)
114 |
--------------------------------------------------------------------------------
/lora-x/configs/zero3_offload_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "scheduler": {
6 | "type": "WarmupLR",
7 | "params": {
8 | "warmup_min_lr": "auto",
9 | "warmup_max_lr": "auto",
10 | "warmup_num_steps": "auto"
11 | }
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 3,
24 | "offload_optimizer": {
25 | "device": "cpu",
26 | "pin_memory": true
27 | },
28 | "offload_param": {
29 | "device": "cpu",
30 | "pin_memory": true
31 | },
32 | "overlap_comm": true,
33 | "contiguous_gradients": true,
34 | "sub_group_size": 1e9,
35 | "reduce_bucket_size": "auto",
36 | "stage3_prefetch_bucket_size": "auto",
37 | "stage3_param_persistence_threshold": "auto",
38 | "stage3_max_live_parameters": 1e9,
39 | "stage3_max_reuse_distance": 1e9,
40 | "stage3_gather_16bit_weights_on_model_save": true
41 | },
42 | "gradient_accumulation_steps": "auto",
43 | "gradient_clipping": "auto",
44 | "steps_per_print": 2000,
45 | "train_batch_size": "auto",
46 | "train_micro_batch_size_per_gpu": "auto",
47 | "wall_clock_breakdown": false
48 | }
49 |
--------------------------------------------------------------------------------
/lora-x/data.py:
--------------------------------------------------------------------------------
1 | #from https://github.com/jordiclive/scaled-rope
2 | from dataclasses import dataclass
3 | from typing import NamedTuple, Optional, Union
4 |
5 | import datasets
6 | import numpy as np
7 | from sklearn.model_selection import train_test_split
8 | from torch.utils.data import ConcatDataset, Dataset, Subset
9 | from transformers.tokenization_utils_base import (PaddingStrategy,
10 | PreTrainedTokenizerBase,
11 | TruncationStrategy)
12 |
13 |
14 | class DatasetEntryLm(NamedTuple):
15 | """Language modelling dataset entry"""
16 |
17 | text: Union[str, None] = None
18 |
19 |
20 | class LMDataset(Dataset):
21 | name = "LMDataset"
22 |
23 |     def __init__(self, dataset_name, char_max_len: int = 200000) -> None:
24 | super().__init__()
25 | self.char_max_len = char_max_len
26 | self.dataset = datasets.load_dataset(dataset_name)["train"]
27 |
28 | def __len__(self) -> int:
29 | return len(self.dataset)
30 |
31 | def __getitem__(self, index) -> DatasetEntryLm:
32 | dialogue = DatasetEntryLm(text=self.dataset[index]["text"][: self.char_max_len])
33 | return dialogue
34 |
35 |
36 | @dataclass
37 | class DataCollator:
38 | tokenizer: PreTrainedTokenizerBase
39 | padding: Union[bool, str, PaddingStrategy] = True
40 | max_length: Optional[int] = None
41 | mix_length_threshold: Optional[int] = 256
42 | mix_probability: Optional[float] = 0.6
43 | pad_to_multiple_of: Optional[int] = None
44 | samples_mixing: Optional[bool] = False
45 |
46 | def __post_init__(self):
47 | assert self.tokenizer.eos_token
48 |
49 | def process_one(self, messages, return_length=False):
50 | truncation = TruncationStrategy.LONGEST_FIRST
51 | max_length = self.max_length
52 |
53 | messages = messages.text
54 |
55 | flatten_message = self.tokenizer(
56 | "".join(messages),
57 | max_length=max_length,
58 | truncation=truncation,
59 | padding=False,
60 | return_token_type_ids=False,
61 | )
62 |
63 | label_mask = np.ones(len(flatten_message.input_ids), dtype=bool)
64 | return flatten_message, label_mask, 0
65 |
66 | def __call__(self, features):
67 | flatten_messages = []
68 | label_masks = []
69 | total_short_context = 0
70 | for messages in features:
71 | flatten_message, label_mask, total_short_context_one = self.process_one(
72 | messages
73 | )
74 | flatten_messages.append(flatten_message)
75 | label_masks.append(label_mask)
76 | total_short_context += total_short_context_one
77 |
78 | batch = self.tokenizer.pad(
79 | flatten_messages,
80 | padding=self.padding,
81 | pad_to_multiple_of=self.pad_to_multiple_of,
82 | return_tensors="pt",
83 | )
84 | batch["labels"] = batch["input_ids"].clone()
85 |
86 | return batch
87 |
88 |
89 | def train_val_dataset(dataset, val_split=0.2):
90 | if val_split == 0:
91 | return dataset, None
92 |
93 | train_idx, val_idx = train_test_split(
94 | list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True
95 | )
96 | return Subset(dataset, train_idx), Subset(dataset, val_idx)
97 |
98 |
99 | def get_one_dataset(
100 | conf,
101 | val_split: float = 0.025,
102 | data_path: str = None,
103 | mode: str = "sft",
104 | max_val_set: Optional[int] = 50,
105 | **kwargs,
106 | ):
107 | data_path = data_path or conf.cache_dir
108 | # dataset_name = dataset_name.lower()
109 | train_datasets = []
110 | eval_datasets = []
111 | for data_file in conf.dataset_names:
112 | dataset = LMDataset(data_file)
113 |
114 | # if eval not already defined
115 | if not ("eval" in locals() and "train" in locals()):
116 | train, eval = train_val_dataset(dataset, val_split=val_split)
117 |
118 | if eval and max_val_set and len(eval) > max_val_set:
119 | subset_indices = np.random.choice(len(eval), max_val_set)
120 | eval = Subset(eval, subset_indices)
121 | train_datasets.append(train)
122 | eval_datasets.append(eval)
123 |
124 | train = ConcatDataset(train_datasets)
125 | eval = ConcatDataset(eval_datasets)
126 | return train, eval
127 |
--------------------------------------------------------------------------------
/lora-x/experimental/train_qlora_lomo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
3 |
4 | model_id = "NousResearch/Llama-2-7b-hf"
5 | bnb_config = BitsAndBytesConfig(
6 | load_in_4bit=True,
7 | bnb_4bit_use_double_quant=True,
8 | bnb_4bit_quant_type="nf4",
9 | bnb_4bit_compute_dtype=torch.bfloat16
10 | )
11 |
12 | tokenizer = AutoTokenizer.from_pretrained(model_id)
13 | model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})
14 |
15 | from peft import prepare_model_for_kbit_training
16 |
17 | model.gradient_checkpointing_enable()
18 | model = prepare_model_for_kbit_training(model)
19 |
20 | def print_trainable_parameters(model):
21 | """
22 | Prints the number of trainable parameters in the model.
23 | """
24 | trainable_params = 0
25 | all_param = 0
26 | for _, param in model.named_parameters():
27 | all_param += param.numel()
28 | if param.requires_grad:
29 | trainable_params += param.numel()
30 | print(
31 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
32 | )
33 | from datasets import load_dataset
34 |
35 | data = load_dataset("Abirate/english_quotes")
36 | data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
37 | from peft import LoraConfig, get_peft_model
38 |
39 | config = LoraConfig(
40 | r=8,
41 | lora_alpha=32,
42 | #target_modules=["query_key_value"],
43 | lora_dropout=0.05,
44 | bias="none",
45 | task_type="CAUSAL_LM"
46 | )
47 |
48 | model = get_peft_model(model, config)
49 | print_trainable_parameters(model)
50 | import transformers
51 |
52 | # needed for gpt-neo-x tokenizer
53 | tokenizer.pad_token = tokenizer.eos_token
54 | import transformers
55 | import qlora_lomo
56 | trainer = qlora_lomo.Trainer(
57 | model=model,
58 | train_dataset=data["train"],
59 | args=qlora_lomo.TrainingArguments(
60 | per_device_train_batch_size=1,
61 | gradient_accumulation_steps=4,
62 | warmup_steps=2,
63 | max_steps=10,
64 | learning_rate=2e-4,
65 | fp16=True,
66 | logging_steps=1,
67 | output_dir="outputs",
68 | optim="LOMO" # paged_adamw_8bit
69 | ),
70 | data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
71 | )
72 | model.config.use_cache = False # silence the warnings. Please re-enable for inference!
73 | trainer.train()
74 |
75 |
--------------------------------------------------------------------------------
/lora-x/lora.py:
--------------------------------------------------------------------------------
1 | #from https://github.com/jordiclive/scaled-rope/blob/hf_version/lora.py
2 | from pathlib import Path
3 |
4 | import torch
5 | from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
6 |
7 |
8 | def prepare_model_for_gradient_checkpointing(model):
9 | r"""
10 | Prepares the model for gradient checkpointing if necessary
11 | """
12 | if not getattr(model, "is_loaded_in_8bit", False):
13 | if hasattr(model, "enable_input_require_grads"):
14 | model.enable_input_require_grads()
15 | else:
16 |
17 | def make_inputs_require_grad(module, input, output):
18 | output.requires_grad_(True)
19 |
20 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
21 | return model
22 |
23 |
24 | def peft_model(
25 | model, peft_config, model_name, int8_training=False, gradient_checkpointing=False
26 | ):
27 |
28 | if "falcon" in model_name:
29 | target_modules = ["dense_4h_to_h", "dense", "query_key_value", "dense_h_to_4h"]
30 |
31 | elif "llama" in model_name:
32 | target_modules = [
33 | "down_proj",
34 | "k_proj",
35 | "q_proj",
36 | "gate_proj",
37 | "o_proj",
38 | "up_proj",
39 | "v_proj",
40 | ]
41 | else:
42 | raise ValueError(
43 | f"Invalid model name '{model_name}'. The model name should contain 'falcon' or 'llama'"
44 | )
45 | config = LoraConfig(
46 | r=peft_config["r"],
47 | lora_alpha=peft_config["alpha"],
48 | target_modules=target_modules,
49 | lora_dropout=peft_config["dropout"],
50 | bias="none",
51 | task_type="CAUSAL_LM",
52 | )
53 |
54 | model = get_peft_model(model, config)
55 | if int8_training:
56 | model = prepare_model_for_int8_training(model)
57 |
58 | if gradient_checkpointing:
59 | model = prepare_model_for_gradient_checkpointing(model)
60 | model.print_trainable_parameters()
61 | return model
62 |
63 |
64 | def load_peft_finetuned_model(model, peft_model_path):
65 | adapters_weights = torch.load(
66 | Path(peft_model_path).joinpath("adapter_model.bin"), map_location=model.device
67 | )
68 | model.load_state_dict(adapters_weights, strict=False)
69 | return model
70 |
--------------------------------------------------------------------------------
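A brief sketch of how `peft_model` above might be applied to a base model (the model id and hyperparameter values are placeholders; the `peft_config` keys follow the function's own usage):

```
import torch
from transformers import AutoModelForCausalLM
from lora import peft_model  # the module above; path assumed

# Model id is a placeholder; peft_model's check is case-sensitive, so the name
# must contain a lowercase "llama" or "falcon" substring.
model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

peft_config = {"r": 8, "alpha": 32, "dropout": 0.05}  # keys read by peft_model
model = peft_model(model, peft_config, model_name,
                   int8_training=False, gradient_checkpointing=True)
```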
/lora-x/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | accelerate==0.20.3
3 | datasets==2.13.1
4 | deepspeed==0.9.5
5 | git+https://github.com/Dao-AILab/flash-attention.git
6 | #flash-attn==1.0.5
7 | peft==0.3.0
8 | protobuf==3.19.6
9 | scikit-learn==1.3.0
10 | sentencepiece==0.1.99
11 | setuptools==59.5.0
12 | transformers==4.31.0
13 | wandb==0.15.5
14 | einops==0.6.1
15 |
--------------------------------------------------------------------------------
/lora-x/scripts/juwels_booster.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # SLURM Configuration
3 | #SBATCH --account=cstdl
4 | #SBATCH --nodes=1
5 | #SBATCH --gres=gpu:4
6 | #SBATCH --ntasks-per-node=1
7 | #SBATCH --cpus-per-task=8
8 | #SBATCH --partition develbooster
9 |
10 | # JUWELS Configuration
11 | conda deactivate
12 | module purge
13 | ml use $OTHERSTAGES
14 | module load Stages/2023 GCC/11.3.0 OpenMPI/4.1.4
15 | module load CUDA/11.7
16 |
17 | # Network Configuration
18 | export NCCL_IB_TIMEOUT=50
19 | export UCX_RC_TIMEOUT=4s
20 | export NCCL_IB_RETRY_CNT=10
21 | export NCCL_ASYNC_ERROR_HANDLING=1
22 |
23 | # Environment Configuration
24 | source /p/home/jusers/clive1/juwels/clive1/miniconda3/bin/activate jordan_lora
25 | export WANDB_API_KEY="d8216641d549f9bb3d0c5074baa39e15dfd55030"
26 | export HUGGING_FACE_HUB_TOKEN="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"
27 | export PYTHONPATH="/p/home/jusers/clive1/juwels/clive1/scaled-rope:$PYTHONPATH"
28 | export TRANSFORMERS_CACHE="/p/home/jusers/clive1/juwels/clive1/transformers_cache"
29 | export HF_DATASETS_CACHE="/p/home/jusers/clive1/juwels/clive1/transformers_cache"
30 | export HF_HOME="/p/home/jusers/clive1/juwels/clive1/transformers_cache"
31 | export PATH="/p/software/juwelsbooster/stages/2023/software/OpenMPI/4.1.4-GCC-11.3.0/bin:$PATH"
32 |
33 | # JUWELS-specific env
34 | export CUDA_VISIBLE_DEVICES="0,1,2,3"
35 | export WANDB_MODE="offline"
36 | export TRANSFORMERS_OFFLINE=1
37 |
38 | # SLURM Host Configuration
39 | hostfile='/p/home/jusers/clive1/juwels/hostfiles/hostfile.txt'
40 | rm $hostfile
41 |
42 | for i in `scontrol show hostnames $SLURM_NODELIST`
43 | do
44 | echo $i slots=4 >>$hostfile
45 | done
46 |
47 | export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
48 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
49 | export MASTER_PORT=12802
50 | export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`
51 | export DLTS_HOSTFILE=$hostfile
52 |
53 |
54 | # Print System Information
55 | echo "GPUs available to job: $SLURM_JOB_GPUS"
56 | echo "Total tasks: $SLURM_NTASKS"
57 |
58 | deepspeed --master_port 12802 \
59 | --launcher slurm \
60 | --hostfile '/p/home/jusers/clive1/juwels/hostfiles/hostfile.txt' \
61 | --master_addr $MASTER_ADDR \
62 | --no_ssh_check \
63 | /p/home/jusers/clive1/juwels/clive1/scaled-rope/finetune.py \
64 | --output_dir saved_ckpts_32k \
65 | --configs defaults lora-7b-llama2 \
66 | --deepspeed
67 |
--------------------------------------------------------------------------------
/lora-x/utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 | from transformers.activations import (FastGELUActivation, GELUActivation,
5 | NewGELUActivation, QuickGELUActivation)
6 |
7 |
8 | def rsetattr(obj, attr, val):
9 | pre, _, post = attr.rpartition(".")
10 | return setattr(rgetattr(obj, pre) if pre else obj, post, val)
11 |
12 |
13 | def rgetattr(obj, attr, *args):
14 | def _getattr(obj, attr):
15 | return getattr(obj, attr, *args)
16 |
17 | return functools.reduce(_getattr, [obj] + attr.split("."))
18 |
19 |
20 | def fuse_gelu(model):
21 | @torch.jit.script
22 | def gelu_fwd(x):
23 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
24 |
25 | @torch.jit.script
26 | def gelu_bwd(g, x):
27 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
28 | ff = 0.5 * x * (
29 | (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
30 | ) + 0.5 * (1 + tanh_out)
31 | return ff * g
32 |
33 | class _FusedGeLUFunction(torch.autograd.Function):
34 | @staticmethod
35 | # bias is an optional argument
36 | def forward(ctx, input):
37 | ctx.input_tensor = input
38 | return gelu_fwd(input)
39 |
40 | @staticmethod
41 | def backward(ctx, grad_output):
42 | input = ctx.input_tensor
43 | tmp = gelu_bwd(grad_output, input)
44 | return tmp
45 |
46 | class FusedGelu(torch.nn.Module):
47 | def forward(self, input):
48 | return _FusedGeLUFunction.apply(input)
49 |
50 | fused_gelu_module = FusedGelu()
51 | hf_gelu_functions = [
52 | GELUActivation,
53 | FastGELUActivation,
54 | NewGELUActivation,
55 | QuickGELUActivation,
56 | ]
57 |
58 | for name, module in model.named_modules():
59 | for hf_gelu_function in hf_gelu_functions:
60 | if isinstance(module, hf_gelu_function):
61 | rsetattr(model, name, fused_gelu_module)
62 |
63 | return model
64 |
--------------------------------------------------------------------------------
/notebooks/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huu4ontocord/MDEL/d84a598e765accfb723edd58f6c0a426d8c16d8d/notebooks/.gitkeep
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate~=0.18.0
2 | datasets~=2.11.0
3 | deepspeed==0.9.2
4 | evaluate~=0.4.0
5 | pre-commit~=2.21.0
6 | scikit-learn~=1.2.2
7 | torch~=2.0.0
8 | tqdm~=4.65.0
9 | transformers~=4.28.1
10 | ujson~=5.7.0
11 | wandb==0.15.2
12 | zstandard~=0.21.0
13 |
--------------------------------------------------------------------------------
/resources.md:
--------------------------------------------------------------------------------
1 | https://github.com/TehVenomm/LM_Transformers_BlockMerge
2 |
--------------------------------------------------------------------------------
/scripts/c-btmInference.py:
--------------------------------------------------------------------------------
1 | from clustering.feature_extractor import FeatureExtractor
2 | from transformers import AutoModelForCausalLM, AutoTokenizer
3 | import torch
4 | from torch.nn import PairwiseDistance
5 | import numpy as np
6 | import fire
7 | import json
8 |
9 | def topKFilter(ensembleWeights, k):
10 | """
11 | Filters and normalizes top k ensemble weights.
12 |
13 | Parameters:
14 | - ensembleWeights: list of ensemble weights as calculated in findEnsembleWeights
15 | - k: number of top experts to choose
16 |
17 | Returns:
18 | Tuple with the normalized weights and corresponding indices as they pertain to the top k domains.
19 | """
20 | topK = torch.topk(ensembleWeights, k=k)
21 | indices = topK.indices
22 |
23 | if torch.sum(topK.values) == 0:
24 | return [1/k for i in range(k)], indices
25 |
26 | normalizedTopKValues = [float(p)/torch.sum(topK.values) for p in topK.values]
27 |
28 | return normalizedTopKValues, indices
29 |
30 |
31 |
32 | def findTopKModelProbabilities(indices, query, models, tokenizers):
33 | """
34 | Find the probabilities of all tokens appearing in the next step for each model in the given list.
35 |
36 | Parameters:
37 | - indices: A list of indices representing the models to consider.
38 | - query: The input query for which to predict the next token probabilities.
39 | - models: A list of PyTorch models for predicting next token probabilities.
40 | - tokenizers: A list of tokenizers corresponding to the models.
41 |
42 | Returns:
43 | List of tensors, where each tensor contains the probabilities of the next token
44 | for the corresponding model in the input list.
45 |
46 | The function iterates over the specified models, tokenizes the input query, and
47 | calculates the probabilities of all tokens in the next step for each model. The results are
48 | collected into a list and returned.
49 |
50 | Note: The models and tokenizers in the input lists should correspond to each other.
51 | The length of indices, models, and tokenizers should be the same.
52 | """
53 | all_models_next_token_probs = []
54 |
55 |     for index in indices:
56 |
57 | input_ids = tokenizers[index].encode(query, return_tensors="pt")
58 | output = models[index](input_ids)
59 | next_token_probs = output.logits[:, -1, :].softmax(dim=-1)
60 | all_models_next_token_probs.append(next_token_probs)
61 |
62 | return all_models_next_token_probs
63 |
64 |
65 |
66 | def findEnsembleWeights(embedder, query, clusterCenters, T, firstToken=False, prompt_data=None):
67 | """
68 | Finds ensemble weights based on distance of query and cluster centers.
69 |
70 | Parameters:
71 | - embedder: the embedder that was used to train the clustering model
72 | - query: the string prompt from input_file
73 | - clusterCenters: dictionary of cluster centers with the keys as domain{i}, as per the input_file
74 | - T: temperature parameter for softmax
75 | - firstToken: boolean to indicate if the first token is yet to be generated
76 | - prompt_data: the entire json object for the prompt
77 |
78 | Returns:
79 |     Tensor of unnormalized ensemble weights, exp(-distance/T) for each cluster center.
80 | """
81 | ensembleWeights = []
82 | pdist = torch.nn.PairwiseDistance(p=2)
83 |
84 | for i in range(len(clusterCenters)):
85 |
86 | if firstToken:
87 |
88 | distance = torch.tensor(prompt_data['meta'][f'domain_score{i+1}']) # this should be the L2 norm between the prompt and the cluster center
89 |
90 | else:
91 |
92 | cluster = clusterCenters[f'domain{i+1}']
93 |
94 | embedded_query = embedder(query)
95 |
96 | if type(embedded_query) is not torch.Tensor:
97 |
98 | embedded_query = torch.Tensor(embedded_query)
99 |
100 | distance = torch.pow(pdist(embedded_query, cluster),2)
101 |
102 | ensembleWeights.append(torch.exp(-1 * distance / T).mean())
103 |
104 | return torch.tensor(ensembleWeights)
105 |
106 |
107 | def findModelProbabilities(embedder, query, clusterCenters, models, tokenizers, T, k, firstToken, prompt_data):
108 | """
109 | Calculates the sum of the probabilities of the ensembled tokens using the top k most relevant domains.
110 |
111 | Parameters:
112 | - embedder: the embedder that was used to train the clustering model
113 | - query: the string prompt from input_file
114 | - models: list of models that were trained on most relevant domains
115 | - tokenizers: list of tokenizers that were trained on most relevant domains
116 | - T: temperature parameter for softmax
117 | - k: number of most relevant domains to choose from
118 | - firstToken: boolean to indicate if the first token is yet to be generated
119 | - prompt_data: the entire json object for the prompt
120 |
121 | Returns:
122 | the ensembled probabilities of all tokens at the next step.
123 | """
124 | modelProbs = 0
125 | ensembleWeights = findEnsembleWeights(embedder, query, clusterCenters, T, firstToken, prompt_data)
126 | weights, indices = topKFilter(ensembleWeights, k)
127 | modelWeights = findTopKModelProbabilities(indices, query, models, tokenizers)
128 |
129 | for i in range(k):
130 |         # modelWeights and weights are both ordered by position within the top-k selection
131 |         modelProbs += modelWeights[i] * weights[i]
132 |
133 | return modelProbs
134 |
135 |
136 | def findNextToken(embedder, query, clusterCenters, models, tokenizers, T, k, firstToken, prompt_data):
137 | """
138 | Returns decoded next-step token with highest probability.
139 |
140 | Parameters:
141 | - embedder: the embedder that was used to train the clustering model
142 | - query: the string prompt from input_file
143 | - clusterCenters: dictionary of cluster centers with the keys as domain{i}, as per the input_file
144 | - models: list of models that were trained on most relevant domains
145 | - tokenizers: list of tokenizers that were trained on most relevant domains
146 | - T: temperature parameter for softmax
147 | - k: number of most relevant domains to choose from
148 | - firstToken: boolean to indicate if the first token is yet to be generated
149 | - prompt_data: the entire json object for the prompt
150 | """
151 |     modelProbs = findModelProbabilities(embedder, query, clusterCenters, models, tokenizers, T, k, firstToken, prompt_data)
152 |     return tokenizers[0].decode(torch.argmax(modelProbs).item())  # any tokenizer works because they are all identical
153 |
154 |
155 | def generateSequence(embedder, prompt_data, end_token, clusterCenters, models, tokenizers, T, k, maxLength):
156 | """
157 | Takes in prompt, end_token which ideally is uniform across all tokenizers, parameter k,
158 | temperature T and cluster centers and finds most likely token from most relevant domains based on prompt. Then, it
159 | builds sequence until end token is generated or maxLength is reached.
160 |
161 | Parameters:
162 | - embedder: the embedder that was used to train the clustering model
163 |     - prompt_data: the json object for the prompt; its 'text' field is the prompt string
164 | - end_token: the token that ideally is uniform across all tokenizers
165 | - clusterCenters: dictionary of cluster centers with the keys as domain{i}, as per the input_file
166 | - models: list of models that were trained on most relevant domains
167 | - tokenizers: list of tokenizers that were trained on most relevant domains
168 | - T: temperature parameter for softmax
169 | - k: number of most relevant domains to choose from
170 | - maxLength: the maximum length the generated output can reach
171 |     - firstToken is handled internally: the first call uses the prompt's precomputed
172 |       domain scores, later calls re-embed the growing sequence
173 |
174 | Returns:
175 | Generated string sequence.
176 | """
177 | prompt = prompt_data['text']
178 | currToken, currSequence = None, prompt
179 | while (len(currSequence) < maxLength) and (currToken != end_token):
180 |
181 | if not currToken:
182 | currToken = findNextToken(
183 | embedder, currSequence, clusterCenters, models, tokenizers, T, k, firstToken=True, prompt_data=prompt_data)
184 | currSequence += currToken
185 | continue
186 |
187 | if currToken == end_token:
188 | break
189 |
190 | currToken = findNextToken(
191 | embedder, currSequence, clusterCenters, models, tokenizers, T, k, firstToken=False, prompt_data=None)
192 | currSequence += currToken
193 |
194 | return currSequence
195 |
196 |
197 | def load_models(model_names):
198 | """
199 | Loads models and tokenizers as per the ids provided in model_names.
200 |
201 | Parameters:
202 | model_names: takes in list of model names and loads model and corresponding tokenizers
203 |
204 | Returns:
205 | Separate lists of models and tokenizers to be accessed by other functions later
206 | """
207 | models = []
208 | tokenizers = []
209 |
210 | for model_name in model_names:
211 | model = AutoModelForCausalLM.from_pretrained(model_name)
212 | tokenizer = AutoTokenizer.from_pretrained(model_name)
213 |
214 | # Freeze the parameters of the loaded models if needed
215 | model.eval()
216 | for param in model.parameters():
217 | param.requires_grad = False
218 |
219 | models.append(model)
220 | tokenizers.append(tokenizer)
221 |
222 | return models, tokenizers
223 |
224 |
225 | def run_inference(embedder, model_names,
226 | input_file, output_file, end_token, maxLength, k, T, clusterCenters):
227 | """
228 | Function that takes in an input file of prompts and writes generated outputs to an output file.
229 | Parameters:
230 | - embedder: the embedder that was used to train the clustering model
231 | - model_names: list of model names
232 |     - input_file: input file name; each line is a json object with 'text' and 'meta' fields
233 | - output_file: where generated sequences are written to
234 | - end_token: end token to signify termination of sequence
235 | - maxLength: max length of generated sequence
236 | - k: number of most relevant domains to choose from
237 | - T: temperature parameter for softmax
238 | - clusterCenters: dictionary of cluster centers with the keys as domain{i}, as per the input_file
239 | Returns:
240 | Should write generated output to output file.
241 | """
242 | models, tokenizers = load_models(model_names)
243 |
244 | with open(input_file, 'r') as file:
245 | input_data = file.readlines()
246 |
247 | results = []
248 |     for line in input_data:
249 |         prompt_data = json.loads(line)
250 |         results.append(generateSequence(
251 |             embedder, prompt_data, end_token, clusterCenters, models, tokenizers, T, k, maxLength))
252 |
253 | with open(output_file, 'w') as file:
254 | for result in results:
255 | file.write(f"{result}\n")
256 |
257 |
258 |
259 | if __name__ == '__main__':
260 | fire.Fire(run_inference)
261 |
--------------------------------------------------------------------------------
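A self-contained sketch of the weighting scheme implemented by findEnsembleWeights and topKFilter above; the distances, temperature, vocabulary size, and expert distributions are all made up for illustration:

import torch

torch.manual_seed(0)

# Hypothetical squared L2 distances from the prompt embedding to 4 cluster centers.
sq_dists = torch.tensor([0.8, 2.5, 0.3, 4.0])
T = 0.5  # softmax temperature, as in findEnsembleWeights

ensemble_weights = torch.exp(-sq_dists / T)      # closer domains get larger weights
top_k = torch.topk(ensemble_weights, k=2)        # keep the 2 most relevant domains
weights = top_k.values / top_k.values.sum()      # normalize, as topKFilter does

# Made-up next-token distributions from each expert over a 5-token vocabulary.
expert_probs = torch.softmax(torch.randn(4, 5), dim=-1)

# Ensemble: weighted sum of the selected experts' distributions.
mixed = (weights.unsqueeze(1) * expert_probs[top_k.indices]).sum(dim=0)
print(top_k.indices.tolist(), int(torch.argmax(mixed)))

--------------------------------------------------------------------------------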
/scripts/calc_perplexities.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if command -v python3 &>/dev/null; then
4 | PYTHON_CMD=python3
5 | else
6 | PYTHON_CMD=python
7 | fi
8 |
9 | for MODEL in "Multi-Domain-Expert-Layers/expert-arxiv" "Multi-Domain-Expert-Layers/expert-freelaw" "Multi-Domain-Expert-Layers/expert-github" "EleutherAI/pythia-1b-deduped"
10 | do
11 | for DATASET in "Multi-Domain-Expert-Layers/arxiv" "Multi-Domain-Expert-Layers/freelaw" "Multi-Domain-Expert-Layers/github"
12 | do
13 | for SPLIT in "validation_domain" "train" "validation_pile"
14 | do
15 | $PYTHON_CMD ../src/mdel/calculate_perplexity.py --model $MODEL --dataset $DATASET --split $SPLIT
16 | done
17 | done
18 | done
19 |
--------------------------------------------------------------------------------
/scripts/calc_perplexities_slurm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if command -v python3 &>/dev/null; then
4 | PYTHON_CMD=python3
5 | else
6 | PYTHON_CMD=python
7 | fi
8 |
9 | for MODEL in "Multi-Domain-Expert-Layers/expert-arxiv" "Multi-Domain-Expert-Layers/expert-freelaw" "Multi-Domain-Expert-Layers/expert-github" "EleutherAI/pythia-1b-deduped"
10 | do
11 | for DATASET in "Multi-Domain-Expert-Layers/arxiv" "Multi-Domain-Expert-Layers/freelaw" "Multi-Domain-Expert-Layers/github"
12 | do
13 | for SPLIT in "validation_domain" "train" "validation_pile"
14 | do
15 | JOB_NAME="${MODEL}-${DATASET}-${SPLIT}"
16 | sbatch --job-name="$JOB_NAME" <
--------------------------------------------------------------------------------
/scripts/create_domain_pile_mix.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | if command -v python3 &>/dev/null; then
5 | PYTHON_CMD=python3
6 | else
7 | PYTHON_CMD=python
8 | fi
9 |
10 | SUBSET_NAME="USPTO Backgrounds"
11 |
12 | for SPLIT in "test" "val" "train"
13 | do
14 | PILE_FILE_PATH="../data/pile/$SPLIT/*.jsonl.zst"
15 | OUTPUT_DIR="../data/mix_uspto_all/$SPLIT"
16 |
17 | # Shard the input data
18 | #$PYTHON_CMD -c "from mdel.pile_utils import *; split_pile('$PILE_FILE_PATH')"
19 |
20 | $PYTHON_CMD -c "from mdel.pile_utils import *; create_pile_domain_mix('$PILE_FILE_PATH', '$PILE_FILE_PATH', '$OUTPUT_DIR', '$SUBSET_NAME')"
21 | done
22 |
--------------------------------------------------------------------------------
/scripts/get_pile_shard1_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # This script downloads the Pile shard 1 training data plus the test and val sets and puts them under data/pile.
3 | mkdir -p ../data
4 | mkdir -p ../data/pile
5 | mkdir -p ../data/pile/train
6 | mkdir -p ../data/pile/test
7 | mkdir -p ../data/pile/val
8 |
9 | wget https://the-eye.eu/public/AI/pile/train/01.jsonl.zst -P ../data/pile/train
10 | wget https://the-eye.eu/public/AI/pile/test.jsonl.zst -P ../data/pile/test
11 | wget https://the-eye.eu/public/AI/pile/val.jsonl.zst -P ../data/pile/val
12 |
--------------------------------------------------------------------------------
/scripts/upload_to_hf.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | if command -v python3 &>/dev/null; then
3 | PYTHON_CMD=python3
4 | else
5 | PYTHON_CMD=python
6 | fi
7 |
8 | HF_REPO=Multi-Domain-Expert-Layers/uspto
9 |
10 | for SPLIT in "test" "val" "train"
11 | do
12 | FOLDER_PATH=$(readlink -f ../data/mix_uspto_all/$SPLIT/)
13 | $PYTHON_CMD ../src/mdel/pile_upload.py --folder-path "$FOLDER_PATH" --hf-repo $HF_REPO --split $SPLIT
14 | done
15 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name="mdel",
5 | version="0.1",
6 | packages=find_packages(where="src"),
7 | package_dir={"": "src"},
8 | include_package_data=True
9 | )
10 |
--------------------------------------------------------------------------------
/src/mdel/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huu4ontocord/MDEL/d84a598e765accfb723edd58f6c0a426d8c16d8d/src/mdel/__init__.py
--------------------------------------------------------------------------------
/src/mdel/calculate_perplexity.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import math
4 | import time
5 |
6 | from datasets import load_dataset
7 | from transformers import (AutoModelForCausalLM, AutoTokenizer,
8 | DataCollatorForLanguageModeling, Trainer,
9 | TrainingArguments)
10 |
11 |
12 | def load_model(args):
13 | tokenizer = AutoTokenizer.from_pretrained(
14 | args.tokenizer if args.tokenizer else args.model
15 | )
16 | model = AutoModelForCausalLM.from_pretrained(args.model).cuda()
17 |
18 | tokenizer.pad_token = tokenizer.eos_token
19 |
20 | return tokenizer, model
21 |
22 |
23 | def prep_dataset(args, tokenizer):
24 | ds = load_dataset(args.dataset, split=args.split)
25 | ds = ds.flatten()
26 | ds = ds.select(range(min(args.num_samples, len(ds))))
27 | print("Loaded Dataset with {} samples".format(len(ds)))
28 |
29 | def preprocess_function(examples):
30 | return tokenizer(
31 | [" ".join(x) for x in examples[args.dataset_key]],
32 | truncation=True,
33 | max_length=args.max_length,
34 | )
35 |
36 | tokenized_ds = ds.map(
37 | preprocess_function,
38 | batched=True,
39 | num_proc=4,
40 | remove_columns=ds.column_names,
41 | )
42 |
43 | return tokenized_ds
44 |
45 |
46 | def parse_args():
47 | parser = argparse.ArgumentParser(
48 | description="Calculate perplexity of a model on a dataset"
49 | )
50 | parser.add_argument(
51 | "--model",
52 | type=str,
53 | required=True,
54 | help="Name of the HF model e.g. MDEL/merged-arxiv-github",
55 | )
56 | parser.add_argument(
57 | "--tokenizer",
58 | type=str,
59 | required=False,
60 | help="Optional tokenizer name e.g. MDEL/merged-arxiv-github. If not provided, will use the model name",
61 | )
62 | parser.add_argument(
63 | "--dataset",
64 | type=str,
65 | required=True,
66 | help="Name of the HF dataset e.g. MDEL/pubmed_abstracts",
67 | )
68 | parser.add_argument(
69 | "--split",
70 | type=str,
71 | required=True,
72 | help="Name of the split to evaluate on e.g. validation",
73 | )
74 | parser.add_argument(
75 | "--max-length",
76 | type=int,
77 | required=False,
78 | default=1024,
79 | help="Max length of the input sequence",
80 | )
81 | parser.add_argument(
82 | "--dataset-key",
83 | type=str,
84 | required=False,
85 | default='text',
86 | help="Key to use to access the dataset e.g. text or answers.text",
87 | )
88 |
89 | parser.add_argument(
90 | "--num-samples",
91 | type=int,
92 | required=False,
93 | default=10000,
94 | help="Max number of samples to evaluate on",
95 | )
96 |
97 | return parser.parse_args()
98 |
99 |
100 | if __name__ == "__main__":
101 | args = parse_args()
102 |
103 | tokenizer, model = load_model(args)
104 | dataset = prep_dataset(args, tokenizer)
105 |
106 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
107 |
108 | training_args = TrainingArguments(
109 | output_dir="perplexity-results",
110 | push_to_hub=False,
111 | )
112 |
113 | trainer = Trainer(
114 | model=model,
115 | args=training_args,
116 | eval_dataset=dataset,
117 | data_collator=data_collator,
118 | )
119 |
120 | eval_results = trainer.evaluate()
121 | perplexity = math.exp(eval_results['eval_loss'])
122 | message = f"Perplexity for {args.model} on {args.dataset}[{args.split}]: {perplexity}"
123 | print(message)
124 |
125 | # write to jsonl
126 | data = {
127 | "date": time.time(),
128 | "runtime": eval_results['eval_runtime'],
129 | "model": args.model,
130 | "tokenizer": args.tokenizer if args.tokenizer else args.model,
131 | "dataset": args.dataset,
132 | "split": args.split,
133 | "max_length": args.max_length,
134 | "dataset_key": args.dataset_key,
135 | "num_samples": args.num_samples,
136 | "perplexity": perplexity,
137 | }
138 |
139 | with open("perplexity-results.jsonl", "a") as f:
140 | f.write(json.dumps(data) + "\n")
141 |
--------------------------------------------------------------------------------
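The script above derives perplexity from the Trainer's mean eval loss; here is a minimal sketch of that relationship on made-up logits (shapes and values are arbitrary):

import math

import torch
import torch.nn.functional as F

# Made-up logits for a 3-token sequence over a 6-token vocabulary.
logits = torch.randn(3, 6)
targets = torch.tensor([1, 4, 2])

loss = F.cross_entropy(logits, targets)   # mean negative log-likelihood per token
perplexity = math.exp(loss.item())        # same formula the script applies to eval_loss
print(round(loss.item(), 3), round(perplexity, 3))

--------------------------------------------------------------------------------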
/src/mdel/configs/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | dataset_name: Multi-Domain-Expert-Layers/arxiv
3 | model_name_or_path: EleutherAI/pythia-1b-deduped
4 | output_dir: "ckpts/pythia-1b-deduped/uspto/layer_9,10,11,12,13"
5 | training_layers: "4,5"
6 | per_device_train_batch_size: 32
7 | per_device_eval_batch_size: 20
8 | preprocessing_num_workers: 32
9 | learning_rate: 0.0001
10 | block_size: 512
11 | num_train_epochs: 1
12 | gradient_accumulation_steps: 1
13 | do_train: true
14 | do_eval: true
15 | evaluation_strategy: "steps"
16 | save_total_limit: 2
17 | max_grad_norm: 2.0
18 | save_steps: 500
19 | deepspeed: "configs/zero_config.json"
20 | overwrite_output_dir: true
21 | logging_steps: 20
22 | dtype: "float16"
23 | wandb_entity: "ontocord"
24 | wandb_project: "layer-experts"
25 | wandb_run_name: "default"
26 | # Model and Tokenizer related settings
27 | model_revision: "main"
28 | validation_splits:
29 | config_name:
30 | tokenizer_name:
31 | use_fast_tokenizer: true
32 | low_cpu_mem_usage: false
33 | eval_steps: 200
34 |
35 | # Dataset related settings
36 | dataset_config_name:
37 | max_train_samples:
38 | max_eval_samples:
39 | streaming:
40 | overwrite_cache: false
41 |
42 | # Training related settings
43 | gradient_checkpointing: false
44 | warmup_steps: 0
45 | adam_beta1: 0.9
46 | adam_beta2: 0.99
47 | adam_epsilon: 1e-8
48 | weight_decay: 0
49 | save_strategy: "steps"
50 | quantization:
51 |
52 | # Logging and Caching
53 | log_wandb: true
54 | cache_dir:
55 | use_auth_token:
56 |
57 | push_to_hub:
58 | push_to_hub_organization:
59 | push_to_hub_model_id:
60 |
61 | max_steps: -1
62 |
63 | debug:
64 | model_name_or_path: "EleutherAI/pythia-70m-deduped"
65 | output_dir: "layer9"
66 | training_layers: "4,5"
67 | per_device_train_batch_size: 1
68 | per_device_eval_batch_size: 1
69 | preprocessing_num_workers: 32
70 | learning_rate: 1e-4
71 | block_size: 512
72 | num_train_epochs: 1
73 | gradient_accumulation_steps: 1
74 | do_train: true
75 | do_eval: true
76 | evaluation_strategy: "steps"
77 | save_total_limit: 2
78 | max_grad_norm: 2.0
79 | save_steps: 1
80 | eval_steps: 1
81 | max_steps: 1
82 | dtype: "float32"
83 | overwrite_output_dir: true
84 | logging_steps: 20
85 | validation_splits: "validation_pile,validation_domain"
86 | max_train_samples: 2
87 | max_eval_samples: 2
88 |
--------------------------------------------------------------------------------
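A sketch of how a named section of this file could be read and overlaid on the defaults with PyYAML; the flat dict merge is only an assumption about how the trainer's `--configs` flag combines sections (trainer.py itself is not reproduced here):

import yaml  # PyYAML

with open("configs/config.yaml") as f:
    cfg = yaml.safe_load(f)

# Assumed merge order: start from "defaults", then overlay the "debug" overrides.
merged = {**cfg["defaults"], **cfg["debug"]}
print(merged["model_name_or_path"])  # EleutherAI/pythia-70m-deduped
print(merged["max_steps"])           # 1

--------------------------------------------------------------------------------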
/src/mdel/configs/zero_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "scheduler": {
14 | "type": "WarmupDecayLR",
15 | "params": {
16 | "warmup_min_lr": "auto",
17 | "warmup_max_lr": "auto",
18 | "warmup_num_steps": "auto",
19 | "total_num_steps": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 2,
24 | "allgather_partitions": true,
25 | "allgather_bucket_size": 1e9,
26 | "overlap_comm": false,
27 | "reduce_scatter": true,
28 | "reduce_bucket_size": 1e9,
29 | "contiguous_gradients": true
30 | },
31 | "gradient_accumulation_steps": "auto",
32 | "gradient_clipping": "auto",
33 | "steps_per_print": 2000,
34 | "train_batch_size": "auto",
35 | "train_micro_batch_size_per_gpu": "auto",
36 | "wall_clock_breakdown": false
37 | }
38 |
--------------------------------------------------------------------------------
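For reference, a small sketch that loads the ZeRO config above and inspects it; the "auto" entries are placeholders that the HF Trainer resolves from its own TrainingArguments at launch time:

import json

with open("configs/zero_config.json") as f:
    ds_cfg = json.load(f)

# Stage 2 shards optimizer state and gradients across data-parallel workers.
print(ds_cfg["zero_optimization"]["stage"])        # 2
print(ds_cfg["train_micro_batch_size_per_gpu"])    # "auto", filled in by the Trainer

--------------------------------------------------------------------------------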
/src/mdel/iterate_layers.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export WANDB_PROJECT=pythia-6.9b-layer-test
3 |
4 |
5 | for i in 6 18 30 9 20 28 12 19 26 15 22 24 ;
6 | do
7 | export WANDB_NAME="layer_${i}"
8 | accelerate launch trainer.py \
9 | --train_file data/train_data.txt \
10 | --validation_file data/book_val.txt \
11 | --model_name_or_path EleutherAI/pythia-6.9b-deduped \
12 | --output_dir "ckpts/pythia-6.9b/books/layer_${i}" \
13 | --training_layer ${i} \
14 | --per_device_train_batch_size 1 \
15 | --per_device_eval_batch_size 1 \
16 | --preprocessing_num_workers 32 \
17 | --learning_rate 1e-4 \
18 | --block_size 512 \
19 | --num_train_epochs 1 \
20 | --gradient_accumulation_steps 8 \
21 | --do_train \
22 | --do_eval \
23 | --overwrite_output_dir \
24 | --logging_steps 20 \
25 | --max_steps 1000
26 | done
27 |
--------------------------------------------------------------------------------
/src/mdel/merge_experts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from huggingface_hub import HfApi
5 | from transformers import AutoModelForCausalLM
6 |
7 |
8 | class HFModel:
9 | def __init__(self, model_name):
10 | self.name = model_name
11 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
12 | print("Using device:", self.device)
13 |
14 | def __enter__(self):
15 | print(f"Loading Model {self.name}")
16 | with torch.no_grad():
17 | self.model = AutoModelForCausalLM.from_pretrained(self.name)
18 | self.model = self.model.to(self.device)
19 | self.model.eval()
20 | return self.model
21 |
22 | def __exit__(self, type, value, traceback):
23 | print(f"Unloading Model {self.name}")
24 | del self.model
25 |
26 |
27 | def merge_n_models(model_names):
28 | with torch.no_grad():
29 | with HFModel(model_names[0]) as blended_model:
30 |
31 | # zero out blended models params
32 | for p in blended_model.parameters():
33 | p.data *= 0
34 |
35 | for mn in model_names:
36 | with HFModel(mn) as temp_model:
37 | for p1, p2 in zip(blended_model.parameters(), temp_model.parameters()):
38 | p1.data += p2.data * (1 / len(model_names))
39 | del temp_model
40 | return blended_model
41 |
42 |
43 | def upload_expert(merged_model, args):
44 | sources = "\n".join(map(lambda mn: f"- [{mn}](https://huggingface.co/{mn})", args.experts))
45 |
46 | model_card_yaml = f"""---
47 | tags:
48 | - MDEL
49 | ---
50 |
51 | # Model Name
52 | {args.hf_repo}
53 |
54 | # Model Description
55 | This model was generated by averaging the weights of the following models
56 | {sources}
57 | """
58 | print(model_card_yaml)
59 |
60 | print("Pushing Model to Hub")
61 | merged_model.push_to_hub(args.hf_repo, model_card=model_card_yaml)
62 |
63 | print("Pushing Model Card to Hub")
64 |
65 | # Upload model card
66 | with open("./README.md", "w") as f:
67 | f.write(model_card_yaml)
68 | api = HfApi()
69 | api.upload_file(
70 | path_or_fileobj="./README.md",
71 | path_in_repo="README.md",
72 | repo_id=args.hf_repo,
73 | repo_type="model",
74 | )
75 |
76 |
77 | def parse_args():
78 | parser = argparse.ArgumentParser(description='Merge expert models and upload to the HuggingFace Hub')
79 | parser.add_argument('--hf-repo', type=str, required=True,
80 | help='Name of the repo to upload the merged model e.g. MDEL/merged-arxiv-github')
81 | parser.add_argument('-e', '--expert', action='append',
82 | help='Name of an expert repo e.g MDEL/expert-arxiv (use this flag multiple times)',
83 | required=True, dest='experts')
84 |
85 | return parser.parse_args()
86 |
87 |
88 | if __name__ == '__main__':
89 | args = parse_args()
90 |
91 | if (len(args.experts) < 2):
92 | print("Must specify at least 2 experts to merge")
93 | exit(1)
94 |
95 | merged_model = merge_n_models(args.experts)
96 | upload_expert(merged_model, args)
97 |
--------------------------------------------------------------------------------
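A minimal, self-contained sketch of the uniform weight averaging performed by merge_n_models above, using two toy linear layers in place of full HF checkpoints:

import torch
import torch.nn as nn

torch.manual_seed(0)
experts = [nn.Linear(4, 4) for _ in range(2)]   # stand-ins for two expert checkpoints

merged = nn.Linear(4, 4)
with torch.no_grad():
    for p in merged.parameters():
        p.data *= 0                              # zero out, as merge_n_models does
    for expert in experts:
        for p_m, p_e in zip(merged.parameters(), expert.parameters()):
            p_m.data += p_e.data * (1 / len(experts))   # accumulate the running average

expected = (experts[0].weight + experts[1].weight) / 2
print(torch.allclose(merged.weight, expected))   # True

--------------------------------------------------------------------------------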
/src/mdel/pile_upload.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | from huggingface_hub import HfApi
5 |
6 | api = HfApi()
7 |
8 |
9 | def upload_dataset(folder_path, hf_repo, split):
10 | auth_token = os.getenv('HF_ACCESS_TOKEN')
11 |
12 | print(f"Uploading {folder_path} to {hf_repo}...")
13 |
14 | api.upload_folder(folder_path=folder_path,
15 | repo_id=hf_repo,
16 | repo_type="dataset",
17 | path_in_repo=f"data/{split}",
18 | use_auth_token=auth_token)
19 |
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser(description='Upload dataset to Hugging Face Hub')
23 | parser.add_argument('--folder-path', type=str, required=True, help='Path to dataset file')
24 |     parser.add_argument('--hf-repo', type=str, required=True, help='Name of the HF dataset repo e.g. Multi-Domain-Expert-Layers/uspto')
25 | parser.add_argument('--split', type=str, required=True, help='Split')
26 | return parser.parse_args()
27 |
28 |
29 | if __name__ == '__main__':
30 | args = parse_args()
31 | folder_path = args.folder_path
32 | hf_repo = args.hf_repo
33 | split = args.split
34 | upload_dataset(folder_path, hf_repo, split)
35 |
--------------------------------------------------------------------------------
/src/mdel/pile_utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import io
3 | import multiprocessing
4 | import os.path
5 | import pathlib
6 |
7 | import ujson
8 | import zstandard
9 | from tqdm import tqdm
10 | from tqdm.contrib.concurrent import process_map
11 |
12 | __HERE__ = pathlib.Path(__file__).parent.resolve()
13 |
14 | SHARD_SIZE = 100000
15 |
16 |
17 | def jsonl_extract_transform_write(infile, outfile, map_func, filter_func, max_lines=-1):
18 | line_counter = 0
19 | num_failed = 0
20 | dctx = zstandard.ZstdDecompressor()
21 |
22 | with dctx.stream_reader(infile) as reader:
23 | cctx = zstandard.ZstdCompressor()
24 | with cctx.stream_writer(outfile, closefd=False) as writer:
25 | text_stream = io.TextIOWrapper(reader, encoding='utf-8')
26 | for line in text_stream:
27 | try:
28 | if filter_func is not None and filter_func(line):
29 | continue
30 |
31 | func_output = map_func(line)
32 | writer.write(f'{func_output}\n'.encode())
33 | line_counter += 1
34 | if -1 < max_lines <= line_counter:
35 | print(f'Stopping after {line_counter} lines')
36 | break
37 | except Exception:
38 | num_failed += 1
39 | continue
40 |
41 | return line_counter, num_failed
42 |
43 |
44 | def pile_get_text(line):
45 | # Just getting rid of the meta tag for a consistent schema
46 | json_obj = ujson.loads(line)
47 | text = json_obj['text']
48 | text_json = ujson.dumps({'text': text})
49 | return text_json
50 |
51 |
52 | def pile_filter_subset(line, subset_name):
53 | json_obj = ujson.loads(line)
54 | return json_obj['meta']['pile_set_name'] != subset_name
55 |
56 |
57 | def domain_mapper(args):
58 | input_file_path, output_file_path, subset_name = args
59 | with open(output_file_path, 'wb+') as outfile:
60 | with open(input_file_path, 'rb') as infile:
61 | num_domain_samples = jsonl_extract_transform_write(infile,
62 | outfile,
63 | pile_get_text,
64 | lambda x: pile_filter_subset(x, subset_name))
65 | return num_domain_samples
66 |
67 |
68 | def pile_mapper(args):
69 | input_file_path, output_file_path, max_lines = args
70 | with open(output_file_path, 'wb+') as outfile:
71 | with open(input_file_path, 'rb') as infile_d:
72 | num_samples = jsonl_extract_transform_write(infile_d, outfile, pile_get_text, None,
73 | max_lines=max_lines)
74 | return num_samples
75 |
76 |
77 | def create_pile_domain_mix(domain_data_file_path: str,
78 | pile_file_path: str,
79 | output_dir: str,
80 | subset_name: str,
81 | max_files: int = -1,
82 | max_workers: int = multiprocessing.cpu_count()):
83 | if not os.path.exists(output_dir):
84 | os.makedirs(output_dir)
85 | else:
86 | raise IOError('Output path already exists')
87 |
88 | domain_data_file_path_expanded, domain_data_processed_paths = process_mix_file_paths(domain_data_file_path,
89 | max_files, output_dir,
90 | 'domain')
91 | print('Processing domain data samples')
92 | file_sample_counts = process_map(domain_mapper,
93 | zip(domain_data_file_path_expanded,
94 | domain_data_processed_paths,
95 | len(domain_data_processed_paths) * [subset_name]),
96 | max_workers=max_workers)
97 |
98 | num_domain_samples = sum([x[0] for x in file_sample_counts])
99 | num_failed_samples = sum([x[1] for x in file_sample_counts])
100 |     fail_rate = 100 * num_failed_samples / num_domain_samples
101 | print(f'Number of domain samples: {num_domain_samples}, rate of samples failed to parse {fail_rate}%')
102 |
103 | print('Processing Pile data samples')
104 | pile_file_path_expanded, pile_processed_paths = process_mix_file_paths(pile_file_path,
105 | -1, output_dir, 'pile')
106 | num_pile_samples = 0
107 | pile_file_idx = 0
108 | while num_pile_samples < num_domain_samples:
109 | cur_num_samples = pile_mapper(
110 | (pile_file_path_expanded[pile_file_idx],
111 | pile_processed_paths[pile_file_idx],
112 | num_domain_samples))
113 | num_pile_samples += cur_num_samples[0]
114 | pile_file_idx += 1
115 |
116 |
117 | def process_mix_file_paths(domain_data_file_path, max_files, output_dir, name_prefix):
118 | domain_data_file_path_expanded = glob.glob(domain_data_file_path)
119 | if max_files > 0:
120 | print(f'Using {max_files} data files')
121 | domain_data_file_path_expanded = domain_data_file_path_expanded[:max_files]
122 | domain_data_processed_paths = [os.path.join(output_dir, name_prefix + '_' + os.path.basename(x))
123 | for x in domain_data_file_path_expanded]
124 | return domain_data_file_path_expanded, domain_data_processed_paths
125 |
126 |
127 | def read_pile_texts(input_file_path):
128 | """
129 | Reads a Pile dataset file in zstd-compressed JSON format and returns a list of 'text' fields.
130 |
131 | :param input_file_path: The path to the input file.
132 | :type input_file_path: str
133 | :return: A list of 'text' fields from each line of the input file.
134 | :rtype: List[str]
135 | :raises FileNotFoundError: If the input file path does not exist.
136 | :raises ValueError: If the input file path is not a string.
137 | :example: read_pile_texts('pile_texts.zst')
138 | """
139 | with open(input_file_path, 'rb') as infile:
140 | dctx = zstandard.ZstdDecompressor()
141 | with dctx.stream_reader(infile) as reader:
142 | text_stream = io.TextIOWrapper(reader, encoding='utf-8')
143 | return [ujson.loads(line)['text'] for line in text_stream]
144 |
145 |
146 | def split_pile(input_file_path, shard_size=SHARD_SIZE):
147 | print(input_file_path)
148 | resolved_files = glob.glob(os.path.abspath(input_file_path))
149 |
150 | for resolved_file in resolved_files:
151 | dctx = zstandard.ZstdDecompressor()
152 | with open(resolved_file, 'rb') as infile:
153 | with dctx.stream_reader(infile) as reader:
154 | text_stream = io.TextIOWrapper(reader, encoding='utf-8')
155 | cctx = zstandard.ZstdCompressor()
156 | shard_num = -1
157 | writer = None
158 | outfile = None
159 |
160 | for line_counter, line in enumerate(tqdm(text_stream)):
161 | if line_counter % shard_size == 0:
162 | if writer is not None:
163 | writer.close()
164 | if outfile is not None:
165 | outfile.close()
166 |
167 | shard_num += 1
168 | output_file_path = os.path.join(
169 | os.path.dirname(resolved_file),
170 | os.path.splitext(resolved_file)[0].replace('.jsonl', '') + f'_{shard_num}.jsonl.zst')
171 | outfile = open(output_file_path, 'wb')
172 | writer = cctx.stream_writer(outfile, closefd=False)
173 |
174 | writer.write(line.encode(encoding='utf-8'))
175 | os.remove(resolved_file)
176 |
177 |
178 | if __name__ == '__main__':
179 | repo_root = __HERE__.parent.parent
180 | # domain_data_file_path = str(repo_root / 'data/pile_uspto/*.jsonl.zst')
181 | # pile_file_path = str(repo_root / 'data/pile_01/01.jsonl.zst')
182 | pile_file_path = '/Users/vmay/Documents/git/MDEL/data/pile/val/*.jsonl.zst'
183 | output_dir = str(repo_root / 'data/mix_uspto_all_3' / 'val')
184 |
185 | # create_pile_domain_mix(pile_file_path, pile_file_path, output_dir, max_files=-1, max_workers=4)
186 | # split_pile('/Users/vmay/Documents/git/MDEL/data/pile/train/*.jsonl.zst')
187 | # print(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/val/domain_val_0.jsonl.zst')[150])
188 | # print(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/val/domain_val_0.jsonl.zst')[151])
189 | # print(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/test/domain_test_0.jsonl.zst')[150])
190 | # print(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/test/domain_test_0.jsonl.zst')[151])
191 | # print(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/test/pile_test_0.jsonl.zst')[150])
192 | # print(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/test/pile_test_0.jsonl.zst')[151])
193 | print(len(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/train/domain_01_0.jsonl.zst')))
194 | print(len(read_pile_texts('/Users/vmay/Documents/git/MDEL/data/mix_uspto_all/train/pile_01_0.jsonl.zst')))
195 |
--------------------------------------------------------------------------------
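The core pattern in jsonl_extract_transform_write above is streaming zstd-compressed JSONL with zstandard and ujson; a self-contained sketch with made-up records and a throwaway file name (toy.jsonl.zst):

import io

import ujson
import zstandard

# Write a tiny zstd-compressed JSONL file with Pile-style records (hypothetical data).
records = [
    {"text": "claim 1 ...", "meta": {"pile_set_name": "USPTO Backgrounds"}},
    {"text": "def foo(): pass", "meta": {"pile_set_name": "Github"}},
]
with open("toy.jsonl.zst", "wb") as f:
    with zstandard.ZstdCompressor().stream_writer(f) as writer:
        for rec in records:
            writer.write((ujson.dumps(rec) + "\n").encode())

# Stream it back, keeping only the USPTO subset and dropping meta, as domain_mapper does.
with open("toy.jsonl.zst", "rb") as f:
    with zstandard.ZstdDecompressor().stream_reader(f) as reader:
        for line in io.TextIOWrapper(reader, encoding="utf-8"):
            obj = ujson.loads(line)
            if obj["meta"]["pile_set_name"] == "USPTO Backgrounds":
                print(ujson.dumps({"text": obj["text"]}))

--------------------------------------------------------------------------------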
/src/mdel/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATASET=uspto
4 | TRAINING_LAYERS=9,10,11,12,13
5 |
6 | export WANDB_PROJECT=pythia-1b-deduped-layer-test-$DATASET
7 | export WANDB_NAME="layer_$TRAINING_LAYERS"
8 | export WANDB_ENTITY=ontocord
9 |
10 | # check if venv or conda is activated
11 | if [ -n "$CONDA_DEFAULT_ENV" ] || [ -n "$VIRTUAL_ENV" ]; then
12 | echo "Virtual environment is activated"
13 | else
14 | echo "Error: virtual environment is not activated"
15 | exit 1
16 | fi
17 |
18 | accelerate launch trainer.py \
19 | --configs defaults \
20 | --dataset_name Multi-Domain-Expert-Layers/$DATASET \
21 | --model_name_or_path EleutherAI/pythia-1b-deduped \
22 | --output_dir "ckpts/pythia-1b-deduped/$DATASET/layer_$TRAINING_LAYERS" \
23 | --training_layers $TRAINING_LAYERS \
24 | --per_device_train_batch_size 1 \
25 | --per_device_eval_batch_size 8 \
26 | --preprocessing_num_workers 32 \
27 | --learning_rate 1e-4 \
28 | --block_size 512 \
29 | --num_train_epochs 1 \
30 | --gradient_accumulation_steps 8 \
31 | --evaluation_strategy steps \
32 | --eval_steps 200 \
33 | --logging_steps 20 \
34 | --max_steps 1000 \
35 | --push_to_hub true \
36 | --push_to_hub_model_id expert-$DATASET \
37 | --push_to_hub_organization Multi-Domain-Expert-Layers \
38 | --wandb_entity $WANDB_ENTITY \
39 | --wandb_project $WANDB_PROJECT \
40 | --wandb_run_name $WANDB_NAME \
41 | --validation_splits "validation_pile,validation_domain" \
42 | --dtype "float32" \
43 | --no_deepspeed
44 |
--------------------------------------------------------------------------------
/src/mdel/train_cbtm_classifier.py:
--------------------------------------------------------------------------------
1 | import joblib
2 | from datasets import load_dataset
3 | from sklearn.feature_extraction import FeatureHasher
4 | from sklearn.linear_model import LogisticRegression
5 | from sklearn.metrics import classification_report
6 | from sklearn.model_selection import train_test_split
7 | from sklearn.pipeline import Pipeline
8 | from tokenizers import Tokenizer
9 |
10 | data_url = [
11 | "Multi-Domain-Expert-Layers/pubmed_abstracts",
12 | "Multi-Domain-Expert-Layers/philpapers",
13 | "Multi-Domain-Expert-Layers/pubmed_central",
14 | "Multi-Domain-Expert-Layers/freelaw",
15 | "Multi-Domain-Expert-Layers/arxiv",
16 | "Multi-Domain-Expert-Layers/github",
17 | "Multi-Domain-Expert-Layers/uspto",
18 | ]
19 |
20 | expert_datasets = [load_dataset(i) for i in data_url]
21 |
22 |
23 | tokenizer = Tokenizer.from_pretrained("EleutherAI/pythia-1b-deduped")
24 | tokenizer.enable_truncation(max_length=1024)
25 |
26 |
27 | tokenized_datasets = []
28 | for ed in expert_datasets:
29 | tokenized_datasets.append(ed['train'].map(lambda x: {
30 | "token": [i.tokens for i in tokenizer.encode_batch(x["text"])]
31 | }, batched=True))
32 |
33 |
34 | features = []
35 | label = []
36 | for ed, lab in zip(tokenized_datasets, data_url):
37 | for i in ed.select(range(min(10000, len(ed)))).iter(batch_size=10000):
38 | features.extend(i['token'])
39 | label.extend([lab] * len(i['token']))
40 |
41 |
42 | X_train, X_test, y_train, y_test = train_test_split(features, label, test_size=0.2, random_state=42)
43 |
44 |
45 | pipeline = Pipeline(
46 | [
47 | ('hasher', FeatureHasher(n_features=512, input_type="string")),
48 | ('lr', LogisticRegression(multi_class='multinomial', solver='lbfgs'))
49 | ]
50 | )
51 | pipeline.fit(X_train, y_train)
52 |
53 |
54 | y_train_pred = pipeline.predict(X_train)
55 | y_test_pred = pipeline.predict(X_test)
56 |
57 |
58 | print(classification_report(y_train, y_train_pred))
59 | print(classification_report(y_test, y_test_pred))
60 |
61 |
62 | joblib.dump(pipeline, 'cbtm_classifier.pkl')
63 |
--------------------------------------------------------------------------------
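A minimal sketch of the FeatureHasher + LogisticRegression pipeline above on toy token lists; the documents, labels, and hash width are made up for illustration:

from sklearn.feature_extraction import FeatureHasher
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Toy documents, already split into token strings, with made-up domain labels.
X = [
    ["theorem", "proof", "lemma", "bound"],
    ["def", "return", "import", "self"],
    ["patent", "claim", "apparatus", "embodiment"],
    ["class", "self", "import", "lambda"],
]
y = ["arxiv", "github", "uspto", "github"]

clf = Pipeline([
    ("hasher", FeatureHasher(n_features=64, input_type="string")),
    ("lr", LogisticRegression(max_iter=200)),
])
clf.fit(X, y)
print(clf.predict([["import", "def", "yield", "self"]]))  # most likely "github"

--------------------------------------------------------------------------------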
/src/mdel/train_chat.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATASET=mini-pile-instruct
4 | TRAINING_LAYERS=4,5,6,7,8
5 |
6 | export WANDB_PROJECT=pythia-1b-deduped-layer-test-$DATASET
7 | export WANDB_NAME="layer_$TRAINING_LAYERS"
8 | export WANDB_ENTITY=ontocord
9 |
10 | # check if venv or conda is activated
11 | if [ -n "$CONDA_DEFAULT_ENV" ] || [ -n "$VIRTUAL_ENV" ]; then
12 | echo "Virtual environment is activated"
13 | else
14 | echo "Error: virtual environment is not activated"
15 | exit 1
16 | fi
17 |
18 | accelerate launch trainer.py \
19 | --dataset_name Multi-Domain-Expert-Layers/$DATASET \
20 | --model_name_or_path EleutherAI/pythia-1b-deduped \
21 | --output_dir "ckpts/pythia-1b-deduped/$DATASET/layer_$TRAINING_LAYERS" \
22 | --training_layers $TRAINING_LAYERS \
23 | --per_device_train_batch_size 1 \
24 | --per_device_eval_batch_size 8 \
25 | --preprocessing_num_workers 32 \
26 | --learning_rate 1e-4 \
27 | --block_size 512 \
28 | --num_train_epochs 1 \
29 | --gradient_accumulation_steps 8 \
30 | --do_train \
31 | --do_eval \
32 | --evaluation_strategy steps \
33 | --eval_steps 200 \
34 | --overwrite_output_dir \
35 | --logging_steps 20 \
36 | --max_steps 1000 \
37 | --push_to_hub true \
38 | --push_to_hub_model_id expert-$DATASET \
39 | --push_to_hub_organization Multi-Domain-Expert-Layers
40 |
--------------------------------------------------------------------------------
/src/mdel/train_ds.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATASET=uspto
4 | TRAINING_LAYERS=9,10,11,12,13
5 |
6 | export WANDB_PROJECT=pythia-1b-deduped-layer-test-$DATASET
7 | export WANDB_NAME="layer_$TRAINING_LAYERS"
8 | export WANDB_ENTITY=ontocord
9 |
10 | # check if venv or conda is activated
11 | if [ -n "$CONDA_DEFAULT_ENV" ] || [ -n "$VIRTUAL_ENV" ]; then
12 | echo "Virtual environment is activated"
13 | else
14 | echo "Error: virtual environment is not activated"
15 | exit 1
16 | fi
17 |
18 | deepspeed trainer.py \
19 | --configs defaults \
20 | --dataset_name Multi-Domain-Expert-Layers/$DATASET \
21 | --model_name_or_path EleutherAI/pythia-1b-deduped \
22 | --output_dir "ckpts/pythia-1b-deduped/$DATASET/layer_$TRAINING_LAYERS" \
23 | --training_layers $TRAINING_LAYERS \
24 | --push_to_hub true \
25 | --push_to_hub_model_id expert-$DATASET \
26 | --push_to_hub_organization Multi-Domain-Expert-Layers \
27 | --wandb_entity $WANDB_ENTITY \
28 | --wandb_project $WANDB_PROJECT \
29 | --wandb_run_name $WANDB_NAME \
30 |     --validation_splits "validation_pile,validation_domain"
31 |
--------------------------------------------------------------------------------
/src/mdel/trainer_chat.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | accelerate launch trainer_chat.py ^
3 | --dataset_name Dahoas/full-hh-rlhf ^
4 | --model_name_or_path EleutherAI/pythia-160m ^
5 | --output_dir "output_chat" ^
6 | --training_layers "5,6,7" ^
7 | --separator "Assistant:" ^
8 | --prompt_column "prompt" ^
9 | --answer_column "response" ^
10 | --per_device_train_batch_size 1 ^
11 | --per_device_eval_batch_size 8 ^
12 | --preprocessing_num_workers 32 ^
13 | --learning_rate 1e-4 ^
14 | --block_size 512 ^
15 | --num_train_epochs 1 ^
16 | --gradient_accumulation_steps 8 ^
17 | --do_train ^
18 | --do_eval ^
19 | --evaluation_strategy steps ^
20 | --eval_steps 200 ^
21 | --overwrite_output_dir ^
22 | --logging_steps 20 ^
23 | --max_steps 1000
24 |
25 | pause
26 |
--------------------------------------------------------------------------------