├── .gitignore ├── docs ├── .editorconfig ├── .github │ ├── release-drafter.yml │ └── workflows │ │ └── release-notes.yml ├── .gitignore ├── 404.html ├── Gemfile ├── Gemfile.lock ├── LICENSE.md ├── README.md ├── _config.yml ├── _includes │ ├── head.html │ └── mathjax_support.html ├── _layouts │ ├── default.html │ ├── page.html │ └── post.html ├── _posts │ └── 2024-06-28-StructuredFFN.md ├── _sass │ ├── _base.scss │ ├── _code.scss │ ├── _layout.scss │ ├── _masthead.scss │ ├── _message.scss │ ├── _pagination.scss │ ├── _posts.scss │ ├── _syntax.scss │ ├── _toc.scss │ ├── _type.scss │ └── _variables.scss ├── assets │ ├── apple-touch-icon-precomposed.png │ ├── author.png │ ├── favicon.ico │ ├── fig_sgt_lowrank.png │ ├── gpt.png │ ├── latency.png │ ├── latency_bs.png │ ├── method.png │ ├── scaling_law_lowrank.png │ ├── training_dynamic.png │ └── wide_structured.png ├── atom.xml ├── index.html ├── poole-for-jekyll.gemspec └── styles.scss ├── experiment ├── Dockerfile ├── basic.sh └── sgd.sh ├── image.png ├── readme.md └── src ├── benchmark_acc └── refinedweb_experiment.py ├── benchmark_eff ├── bench_kernel.py ├── benchmark_mlp_train.py ├── benchmark_model_infer.py ├── benchmark_model_train.py └── cac_batch.py ├── configs ├── data │ ├── Aug │ │ ├── mixup.yaml │ │ └── randomaugment.yaml │ ├── cifar10.yaml │ └── refinedweb.yaml ├── method │ ├── blockdense.yaml │ ├── blockshuffle.yaml │ ├── linear.yaml │ └── lowrank.yaml ├── model │ ├── gpt2.yaml │ ├── gpt2l.yaml │ ├── gpt2m.yaml │ └── gpt2xl.yaml ├── optimization │ ├── basic.yaml │ ├── lr_scheduler │ │ ├── cosineannealinglr.yaml │ │ └── multisteplr.yaml │ ├── optimizer │ │ ├── adam.yaml │ │ ├── adamw.yaml │ │ └── sgd.yaml │ └── training │ │ ├── regular_training.yaml │ │ └── self_guided_training.yaml └── refinedweb_config.yaml ├── modules ├── __init__.py ├── layer │ ├── __init__.py │ ├── basiclinear.py │ ├── blockdense.py │ ├── blockshuffle.py │ ├── customlinear.py │ ├── lowrank.py │ └── util.py ├── mlp │ ├── __init__.py │ ├── basic_mlp.py │ ├── blockdense_mlp.py │ ├── blockshuffle_mlp.py │ ├── lowrank_mlp.py │ └── mlp.py ├── model │ ├── __init__.py │ └── gpt2.py └── op │ ├── __init__.py │ ├── block_dense.py │ ├── block_shuffle.py │ ├── common │ ├── fused_bias_dropout_add.py │ ├── fused_gelu.py │ ├── fused_swiglu.py │ └── rotary_embeddings.py │ └── low_rank.py ├── optimization ├── __init__.py ├── scheduler.py └── trainer.py └── utils └── refinedweb_llama.py /.gitignore: -------------------------------------------------------------------------------- 1 | exp/ 2 | .DS_Store 3 | # Created by https://www.gitignore.io/api/python 4 | # Edit at https://www.gitignore.io/?templates=python 5 | 6 | old_test/ 7 | ### Python ### 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # celery beat schedule file 91 | celerybeat-schedule 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # Environments 97 | .env 98 | .venv 99 | env/ 100 | venv/ 101 | ENV/ 102 | env.bak/ 103 | venv.bak/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | .dmypy.json 118 | dmypy.json 119 | 120 | # Pyre type checker 121 | .pyre/ 122 | 123 | ### Python Patch ### 124 | .venv/ 125 | 126 | ### Python.VirtualEnv Stack ### 127 | # Virtualenv 128 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 129 | [Bb]in 130 | [Ii]nclude 131 | [Ll]ib 132 | [Ll]ib64 133 | [Ll]ocal 134 | [Ss]cripts 135 | pyvenv.cfg 136 | pip-selfcheck.json 137 | 138 | # End of https://www.gitignore.io/api/python -------------------------------------------------------------------------------- /docs/.editorconfig: -------------------------------------------------------------------------------- 1 | # editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | charset = utf-8 7 | end_of_line = lf 8 | indent_size = 2 9 | indent_style = space 10 | insert_final_newline = true 11 | trim_trailing_whitespace = true 12 | -------------------------------------------------------------------------------- /docs/.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: 'v$NEXT_MINOR_VERSION' 2 | tag-template: 'v$NEXT_MINOR_VERSION' 3 | prerelease: true 4 | exclude-labels: 5 | - 'skip-changelog' 6 | categories: 7 | - title: '🚀 Features' 8 | labels: 9 | - 'new-feature' 10 | - 'feature' 11 | - 'enhancement' 12 | - title: '🐛 Bug fixes' 13 | labels: 14 | - 'fix' 15 | - 'bugfix' 16 | - 'bug' 17 | - title: '📖 Docs' 18 | labels: 19 | - 'docs' 20 | - title: '📦 Dependencies' 21 | labels: 22 | - 'dependencies' 23 | - title: '🧰 Maintenance' 24 | label: 'chore' 25 | change-template: '- #$NUMBER: $TITLE' 26 | template: | 27 | ## Changes 28 | 29 | $CHANGES 30 | -------------------------------------------------------------------------------- /docs/.github/workflows/release-notes.yml: -------------------------------------------------------------------------------- 1 | name: Release notes 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | update_release_draft: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: release-drafter/release-drafter@v5 13 | env: 14 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 15 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore docs files 2 | _gh_pages 3 | 
_site 4 | .ruby-version 5 | .sass-cache 6 | .jekyll-cache 7 | 8 | # Numerous always-ignore extensions 9 | *.diff 10 | *.err 11 | *.orig 12 | *.log 13 | *.rej 14 | *.swo 15 | *.swp 16 | *.zip 17 | *.vi 18 | *~ 19 | 20 | # OS or Editor folders 21 | .DS_Store 22 | ._* 23 | Thumbs.db 24 | .cache 25 | .project 26 | .settings 27 | .tmproj 28 | *.esproj 29 | nbproject 30 | *.sublime-project 31 | *.sublime-workspace 32 | .idea 33 | 34 | # Komodo 35 | *.komodoproject 36 | .komodotools 37 | 38 | # grunt-html-validation 39 | validation-status.json 40 | validation-report.json 41 | 42 | # Folders to ignore 43 | node_modules 44 | -------------------------------------------------------------------------------- /docs/404.html: -------------------------------------------------------------------------------- 1 | --- 2 | layout: default 3 | title: "404: Page not found" 4 | permalink: 404.html 5 | --- 6 | 7 |
8 |

404: Page not found

9 |

Sorry, we've misplaced that URL or it's pointing to something that doesn't exist. Head back home to try finding it again.

10 |
11 | -------------------------------------------------------------------------------- /docs/Gemfile: -------------------------------------------------------------------------------- 1 | source "https://rubygems.org" 2 | 3 | gem "jekyll" 4 | gem "jekyll-gist" 5 | gem "jekyll-paginate" 6 | gem "jekyll-seo-tag" 7 | 8 | gem "webrick", "~> 1.8" 9 | -------------------------------------------------------------------------------- /docs/Gemfile.lock: -------------------------------------------------------------------------------- 1 | GEM 2 | remote: https://rubygems.org/ 3 | specs: 4 | addressable (2.8.7) 5 | public_suffix (>= 2.0.2, < 7.0) 6 | bigdecimal (3.1.1) 7 | colorator (1.1.0) 8 | concurrent-ruby (1.3.3) 9 | em-websocket (0.5.3) 10 | eventmachine (>= 0.12.9) 11 | http_parser.rb (~> 0) 12 | eventmachine (1.2.7) 13 | faraday (2.9.2) 14 | faraday-net_http (>= 2.0, < 3.2) 15 | faraday-net_http (3.1.0) 16 | net-http 17 | ffi (1.17.0-arm64-darwin) 18 | forwardable-extended (2.6.0) 19 | google-protobuf (4.27.2-arm64-darwin) 20 | bigdecimal 21 | rake (>= 13) 22 | http_parser.rb (0.8.0) 23 | i18n (1.14.5) 24 | concurrent-ruby (~> 1.0) 25 | jekyll (4.3.3) 26 | addressable (~> 2.4) 27 | colorator (~> 1.0) 28 | em-websocket (~> 0.5) 29 | i18n (~> 1.0) 30 | jekyll-sass-converter (>= 2.0, < 4.0) 31 | jekyll-watch (~> 2.0) 32 | kramdown (~> 2.3, >= 2.3.1) 33 | kramdown-parser-gfm (~> 1.0) 34 | liquid (~> 4.0) 35 | mercenary (>= 0.3.6, < 0.5) 36 | pathutil (~> 0.9) 37 | rouge (>= 3.0, < 5.0) 38 | safe_yaml (~> 1.0) 39 | terminal-table (>= 1.8, < 4.0) 40 | webrick (~> 1.7) 41 | jekyll-gist (1.5.0) 42 | octokit (~> 4.2) 43 | jekyll-paginate (1.1.0) 44 | jekyll-sass-converter (3.0.0) 45 | sass-embedded (~> 1.54) 46 | jekyll-seo-tag (2.8.0) 47 | jekyll (>= 3.8, < 5.0) 48 | jekyll-watch (2.2.1) 49 | listen (~> 3.0) 50 | kramdown (2.4.0) 51 | rexml 52 | kramdown-parser-gfm (1.1.0) 53 | kramdown (~> 2.0) 54 | liquid (4.0.4) 55 | listen (3.9.0) 56 | rb-fsevent (~> 0.10, >= 0.10.3) 57 | rb-inotify (~> 0.9, >= 0.9.10) 58 | mercenary (0.4.0) 59 | net-http (0.3.0.1) 60 | uri 61 | octokit (4.25.1) 62 | faraday (>= 1, < 3) 63 | sawyer (~> 0.9) 64 | pathutil (0.16.2) 65 | forwardable-extended (~> 2.6) 66 | public_suffix (6.0.0) 67 | rake (13.0.6) 68 | rb-fsevent (0.11.2) 69 | rb-inotify (0.11.1) 70 | ffi (~> 1.0) 71 | rexml (3.2.5) 72 | rouge (4.3.0) 73 | safe_yaml (1.0.5) 74 | sass-embedded (1.77.5-arm64-darwin) 75 | google-protobuf (>= 3.25, < 5.0) 76 | sawyer (0.9.2) 77 | addressable (>= 2.3.5) 78 | faraday (>= 0.17.3, < 3) 79 | terminal-table (3.0.2) 80 | unicode-display_width (>= 1.1.1, < 3) 81 | unicode-display_width (2.5.0) 82 | uri (0.12.2) 83 | webrick (1.8.1) 84 | 85 | PLATFORMS 86 | arm64-darwin 87 | 88 | DEPENDENCIES 89 | jekyll 90 | jekyll-gist 91 | jekyll-paginate 92 | jekyll-seo-tag 93 | webrick (~> 1.8) 94 | 95 | BUNDLED WITH 96 | 2.5.14 97 | -------------------------------------------------------------------------------- /docs/LICENSE.md: -------------------------------------------------------------------------------- 1 | # Released under MIT License 2 | 3 | Copyright (c) 2013 Mark Otto. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # StructuredFFN 2 | 3 | ![StructuredFFN](https://user-images.githubusercontent.com/13270895/89133355-26b3af80-d4e9-11ea-81cd-eacaa9c78320.png) 4 | 5 | StructuredFFN is a permanent dark theme of the Poole theme by [@mdo](https://github.com/mdo). I made the theme darker, inspired by [Derek Kedziora's site](https://derekkedziora.com/). Unlike default Poole that utilizes CSS media queries to activate dark mode, the theme will stay dark regardless of the user's preference. 6 | 7 | - I added a navbar that is easily customizable. Check out [Development](#development) to see how. 8 | - I also got rid of the "tagline" in the navbar. I think it looks cleaner without it. 9 | - Finally, I changed the default font size to 20px. I have 20/20 vision and still thought the original font size was too small. 10 | 11 | That's it! I tried to be least intrusive as possible to the Poole code base. 12 | 13 | **I noticed that Poole's documentation is slightly outdated and misleading. This documentation will try to address most, if not all, of these issues.** 14 | 15 | --- 16 | 17 | ## Contents 18 | 19 | - [Usage](#usage) 20 | - [Development](#development) 21 | - [Author](#author) 22 | - [License](#license) 23 | 24 | ## Usage 25 | 26 | ### 1. Install dependencies 27 | 28 | Poole is built on Jekyll and uses its built-in SCSS compiler to generate our CSS. Before getting started, you'll need to install the Jekyll gem and related dependencies: 29 | 30 | ```bash 31 | $ gem install jekyll jekyll-gist jekyll-sitemap jekyll-seo-tag 32 | ``` 33 | 34 | ### 2. Install bundler 35 | 36 | You must have bundler installed. If you already have bundler installed, please skip this step. 37 | 38 | ```bash 39 | # Update Rubygems 40 | $ gem update --system 41 | # Update bundler 42 | $ gem install bundler 43 | ``` 44 | 45 | ### 3. Running locally 46 | 47 | To see your Jekyll site with Poole applied, start a Jekyll server. In Terminal, from `/dark-poole` (or whatever your Jekyll site's root directory is named): 48 | 49 | ```bash 50 | $ bundle exec jekyll serve 51 | ``` 52 | 53 | Open in your browser, and voilà. 54 | 55 | ### 4. Serving it up 56 | 57 | If you host your code on GitHub, you can use [GitHub Pages](https://pages.github.com) to host your project. 58 | 59 | 1. Fork this repo and switch to the `gh-pages` branch. 60 | 1. 
If you're [using a custom domain name](https://help.github.com/articles/setting-up-a-custom-domain-with-github-pages), modify the `CNAME` file to point to your new domain. 61 | 1. If you're not using a custom domain name, **modify the `url` in `_config.yml`** to point to your GitHub Pages URL. Example: for a site hosted at `username.github.io`, use `http://username.github.io`. 62 | 1. If you want to use your repo name as a base url, **set the `url`** to your repo link and **set the `baseurl`** to your repo name in **`_config.yml`**. Example: for site hosted on `https://username.github.io/dark-poole`, set `url` as `https://username.github.io/dark-poole` and `baseurl` as `/dark-poole`. 63 | 1. Done! Head to your GitHub Pages URL or custom domain. 64 | 65 | No matter your production or hosting setup, be sure to verify the `baseurl` option file and `CNAME` settings. Not applying this correctly can mean broken styles on your site. 66 | 67 | ### 5. Pagination for sites with base urls 68 | 69 | If you are using a base url for your site, (for example, hosted on `https://username.github.io/dark-poole`) you have to make some changes to get jekyll-pagination to work correctly: 70 | 71 | In `_config.yml`, add this line: 72 | 73 | ```yaml 74 | paginate_path: "/baseurl/page:num/" 75 | ``` 76 | 77 | In `archive.md`, add `{{ site.baseurl }}` before `{{ post.url }}` 78 | 79 | ```html 80 | 81 |
  <a href="{{ site.baseurl }}{{ post.url }}">{{ post.title }}</a>
  • 82 | ``` 83 | 84 | In `index.html`, remove the `prepend:`: 85 | 86 | ```html 87 | 88 | Newer 93 | ``` 94 | 95 | ## Development 96 | 97 | Poole has two branches, but only one is used for active development. 98 | 99 | - `master` for development. **All pull requests should be to submitted against `master`.** 100 | - `gh-pages` for hosted demo **Please avoid using this branch.** 101 | 102 | CSS is handled via Jeykll's built-in Sass compiler. Source Sass files are located in `_sass/`, included into `styles.scss`, and compile to `styles.css`. 103 | 104 | ### Customize Navbar 105 | 106 | You can easily customize the navbar by tweaking the `_config.yml` file. Simply change the title and url of each of the nav elements, or add more. The order will be preserved in the site. 107 | 108 | ```yaml 109 | nav: 110 | - title: Blog 111 | url: /archive 112 | 113 | - title: About 114 | url: /about 115 | ``` 116 | 117 | ## Author 118 | 119 | **Mark Otto** 120 | 121 | - 122 | - 123 | 124 | ## License 125 | 126 | Open sourced under the [MIT license](LICENSE.md). 127 | 128 | <3 129 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | title: StructuredFFN 3 | url: https://claire-labo.github.io 4 | paginate: 1 5 | baseurl: "/StructuredFFN" 6 | permalink: pretty 7 | 8 | # Gems 9 | plugins: 10 | - jekyll-gist 11 | - jekyll-paginate 12 | - jekyll-seo-tag 13 | 14 | # Optimize Jekyll 15 | exclude: 16 | - .editorconfig 17 | - .git 18 | - .jekyll-cache 19 | - Gemfile 20 | - Gemfile.lock 21 | - LICENSE.md 22 | - README.md 23 | 24 | sass: 25 | sass_dir: _sass 26 | style: :compressed 27 | 28 | # Options 29 | 30 | # Replace this value and uncomment to enable Google Analytics tracking 31 | # ga_analytics: UA-000000-0 32 | 33 | # Specify the author for blog posts 34 | author: 35 | name: Mark Otto 36 | url: https://twitter.com/mdo 37 | email: markdotto@gmail.com 38 | 39 | # Custom vars 40 | version: 3.0.0 41 | 42 | # # Navbar page list 43 | # nav: 44 | # - title: Blog 45 | # url: /archive 46 | 47 | # - title: About 48 | # url: /about 49 | -------------------------------------------------------------------------------- /docs/_includes/head.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | {% if page.title == "Home" %} 7 | {{ site.title }}{% if site.tagline %} · {{ site.tagline }}{% endif %} 8 | {% else %} 9 | {{ page.title }} · {{ site.title }} 10 | {% endif %} 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | {% seo title=false %} 19 | 20 | -------------------------------------------------------------------------------- /docs/_includes/mathjax_support.html: -------------------------------------------------------------------------------- 1 | 21 | 24 | -------------------------------------------------------------------------------- /docs/_layouts/default.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | {% include head.html %} 4 | {% include mathjax_support.html %} 5 | 6 | 7 |
    8 |
    9 |

    10 | {{ site.title }} 11 | 12 | 17 |

    18 |
    19 | 20 |
    21 | {{ content }} 22 |
    23 | 24 |
    25 | 26 | © 27 | . All rights reserved. 30 | 31 |
    32 |
    33 | 34 | {% if site.ga_analytics %} 35 | 58 | {% endif %} 59 | 60 | 61 | -------------------------------------------------------------------------------- /docs/_layouts/page.html: -------------------------------------------------------------------------------- 1 | --- 2 | layout: default 3 | --- 4 | 5 |
    6 |

    {{ page.title }}

    7 | {{ content }} 8 |
    9 | -------------------------------------------------------------------------------- /docs/_layouts/post.html: -------------------------------------------------------------------------------- 1 | --- 2 | layout: default 3 | --- 4 | 5 |
    6 |

    {{ page.title }}

    7 | 8 | {{ content }} 9 |
10 | 11 | {% if site.related_posts != empty %} 12 | 25 | {% endif %} 26 | -------------------------------------------------------------------------------- /docs/_posts/2024-06-28-StructuredFFN.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: post 3 | title: "Building on Efficient Foundations: Effectively Training LLMs with Structured Feedforward Layers" 4 | --- 5 | 6 | **Author list: Xiuying Wei (CLAIRE, EPFL), Skander Moalla (CLAIRE, EPFL), Razvan Pascanu (Google DeepMind), Caglar Gulcehre (CLAIRE, EPFL)** 7 | 8 | ## Abstract 9 | 10 | State-of-the-art results in large language models (LLMs) often rely on scale, which becomes computationally expensive. This has sparked a research agenda to reduce these models' parameter count and computational costs without significantly impacting their performance. Our study focuses on transformer-based LLMs, specifically targeting the computationally intensive feedforward networks (FFN), which are less studied than attention blocks. We consider three candidate linear layer approximations in the FFN by combining efficient low-rank and block-diagonal matrices. In contrast to many previous works that examined these approximations, our study i) explores these structures from the training-from-scratch perspective, ii) scales up to 1.3B parameters, and iii) is conducted within recent Transformer-based LLMs rather than convolutional architectures. We first demonstrate that they can lead to actual computational gains in various scenarios, including online decoding when using a pre-merge technique. Additionally, we propose a novel training regime, called *self-guided training*, aimed at improving the poor training dynamics that these approximations exhibit when used from initialization. Experiments on the large RefinedWeb dataset show that our methods are both efficient and effective for training and inference. Interestingly, these structured FFNs exhibit steeper scaling curves than the original models. Further applying self-guided training to the structured matrices with 32% FFN parameters and a 2.5$\times$ speed-up results in only a 0.4 perplexity increase under the same training FLOPs. Finally, we develop wide and structured networks that surpass the current medium-sized and large-sized Transformers in both perplexity and throughput. 11 | 12 | 13 | ## Method 14 | 15 | ### Structured linear parametrization 16 | We consider the three structured parametrizations below to approximate a linear layer ($Wx$); all of them have demonstrated computational gains on existing hardware. 17 | 18 | * LowRank: $Wx \approx U^r(V^rx)$, where the superscript $^r$ indicates matrices projecting into or out of low-dimensional states. 19 | * BlockShuffle (two block-diagonal matrices, same as Monarch [1]): $Wx \approx f^{-1}(U^b f(V^bx))$, where $V^b$ and $U^b$ are block-diagonal matrices and the shuffle function $f(\cdot)$ enables global feature mixing by cycling features across blocks. 20 | * BlockDense (a block-diagonal matrix followed by a dense matrix): $Wx \approx U^r(V^bx)$. Technically, the second projection does not need to be low-rank to reduce the parameter count, but in practice we chose the low-rank one (superscript $r$) to limit our search space. 21 | 22 | The figure below shows how they operate, together with their reduced parameter counts and MACs. 23 | 24 | 25 | 26 | 27 | 28 | We then go deeper into their common challenges, namely efficiency and optimization.
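Before turning to those challenges, the three parametrizations can be made concrete with a minimal PyTorch sketch of how each one replaces a dense `nn.Linear`. This is an illustrative re-implementation rather than the code in `src/modules/layer/`; the class names, the `rank`/`nblocks` arguments, and the reshape-based shuffle are choices made for this example.

```python
import torch
import torch.nn as nn


class LowRankLinear(nn.Module):
    """W x ~= U^r (V^r x): a dense projection through a rank-r bottleneck."""

    def __init__(self, in_features: int, out_features: int, rank: int):
        super().__init__()
        self.V = nn.Linear(in_features, rank, bias=False)   # V^r: in_features -> rank
        self.U = nn.Linear(rank, out_features, bias=True)    # U^r: rank -> out_features

    def forward(self, x):
        return self.U(self.V(x))


class BlockDiagonalLinear(nn.Module):
    """A single block-diagonal matrix: each of `nblocks` input chunks has its own small dense map."""

    def __init__(self, in_features: int, out_features: int, nblocks: int):
        super().__init__()
        assert in_features % nblocks == 0 and out_features % nblocks == 0
        self.nblocks = nblocks
        self.weight = nn.Parameter(
            0.02 * torch.randn(nblocks, in_features // nblocks, out_features // nblocks)
        )

    def forward(self, x):
        *batch, d = x.shape
        x = x.reshape(-1, self.nblocks, d // self.nblocks).transpose(0, 1)  # (nblocks, N, d/b)
        y = torch.bmm(x, self.weight).transpose(0, 1)                       # (N, nblocks, out/b)
        return y.reshape(*batch, -1)


class BlockShuffleLinear(nn.Module):
    """W x ~= f^{-1}(U^b f(V^b x)): two block-diagonal matrices with a shuffle in between."""

    def __init__(self, features: int, nblocks: int):
        super().__init__()
        self.nblocks = nblocks
        self.V = BlockDiagonalLinear(features, features, nblocks)
        self.U = BlockDiagonalLinear(features, features, nblocks)

    def _shuffle(self, x):    # f: regroup features so every block mixes with every other block
        *batch, d = x.shape
        return x.reshape(*batch, self.nblocks, d // self.nblocks).transpose(-1, -2).reshape(*batch, d)

    def _unshuffle(self, x):  # f^{-1}: undo the regrouping
        *batch, d = x.shape
        return x.reshape(*batch, d // self.nblocks, self.nblocks).transpose(-1, -2).reshape(*batch, d)

    def forward(self, x):
        return self._unshuffle(self.U(self._shuffle(self.V(x))))


class BlockDenseLinear(nn.Module):
    """W x ~= U^r (V^b x): a block-diagonal projection followed by a dense low-rank one."""

    def __init__(self, in_features: int, out_features: int, rank: int, nblocks: int):
        super().__init__()
        self.V = BlockDiagonalLinear(in_features, rank, nblocks)
        self.U = nn.Linear(rank, out_features, bias=True)

    def forward(self, x):
        return self.U(self.V(x))
```

For the FFN's two projections ($d \to 4d$ and $4d \to d$), LowRank uses $5rd + 5rd = 10rd$ parameters instead of $8d^2$, so ranks of $d/2$ and $d/4$ give roughly the 63% and 32% FFN-parameter ratios referred to throughout this post.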
29 | 30 | ### Maintaining efficiency during online decoding 31 | 32 | Challenge: While these parametrizations deliver real computational gains, they face challenges in the practical online decoding scenario of LLMs, which may process only a few input tokens at a time, leading to under-utilization of computing resources and decreased efficiency due to the additional linear projection. 33 | 34 | Pre-merge technique: We address this with a pre-merge technique that restores the original dense efficiency when the total number of tokens is quite small (e.g., 16). Taking advantage of the fact that these parametrizations contain no non-linearity, we propose to merge the structured matrices into a single dense layer and to keep both the structured and the dense form for online decoding. We can then dynamically decide which parametrization to use based on the current batch size and setting. 35 | 36 | 37 | 38 | ### Addressing the optimization challenge 39 | 40 | Challenge: Using the efficient parametrization from initialization can suffer from optimization difficulties because the deep linear parametrization introduces additional symmetries, which are a source of proliferating saddle points and a generally less smooth loss function, as pointed out in [2]. Empirically, we show in the figure below that the deep linear form $U(Vx)$ leads to instability and loss spikes, or to slow convergence, compared to the dense linear projection. 41 | 42 | 43 | 44 | 45 | 46 | Self-guided training: Addressing poor training dynamics by tuning the learning rate and gradient clipping is costly and unstable. We propose a simpler, cost-effective approach called self-guided training, requiring minimal hyperparameter re-tuning. This method uses the dense parametrization to efficiently navigate the early stages of training, where the symmetries introduced by the structured parametrization impair feature specialization, and then transfers control to $U$ and $V$. The layer output is defined as: 47 | 48 | $o = \alpha \cdot W x + (1-\alpha) \cdot U(Vx)$, 49 | 50 | where $o$ is the layer's output and $\alpha$ decays following a cosine scheduler. As a residual component, learning $W$ is unaffected by the additional saddles and pathologies, allowing units to specialize. This *guides* the training of $U$ and $V$, which are slowly forced to take over, by providing them with the hidden-unit semantics learned by $W$. The loss curves above show that this method improves the training dynamics substantially. 51 | 52 | For more details, please check the paper. 53 | 54 | ## Experiments 55 | 56 | We conduct our experiments at scale on Transformers ranging from 110M to 1.3B parameters. We demonstrate the efficiency of these parametrizations, conduct a scaling analysis showing that structured matrices have steeper scaling curves than dense ones, and validate that self-guided training can boost the final performance efficiently. Finally, we design wide and structured networks by combining GQA [3], improving both perplexity and throughput. 57 | 58 | ### Evaluating latency results 59 | 60 | We investigate the efficiency of structured FFNs and consider different numbers of tokens to cover different scenarios. 61 | 62 | - Large number of tokens (usually concerning training, the prefill phase of inference, and extensive decoding cases) 63 | 64 | From width 1536 onward, LowRank and BlockDense begin to deliver about a 1.4$\times$ and a 2.5$\times$ speed-up with 63% and 32% of the FFN parameters, respectively.
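The pre-merge technique described above is straightforward to express in code: because there is no non-linearity between the two factors, they can be collapsed into a single dense weight once after training, and the serving code can pick whichever form is faster for the current number of tokens (the small-token regime discussed next is exactly where the merged dense layer wins). The snippet below is a minimal, illustrative sketch rather than this repository's implementation; the `LowRankLinear` class from the earlier sketch, the `premerge` helper, and the `token_threshold` of 16 (mirroring the example above) are all assumptions made for the example.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def premerge(lowrank_layer) -> nn.Linear:
    """Collapse U(Vx) into a single dense layer with weight W' = U @ V
    (valid because there is no non-linearity between the two factors)."""
    U, V = lowrank_layer.U, lowrank_layer.V             # from the LowRankLinear sketch above
    merged = nn.Linear(V.in_features, U.out_features, bias=U.bias is not None)
    merged.weight.copy_(U.weight @ V.weight)            # (out, r) @ (r, in) -> (out, in)
    if U.bias is not None:
        merged.bias.copy_(U.bias)
    return merged


class SwitchableLinear(nn.Module):
    """Keep both the structured and the pre-merged dense form, and dispatch per batch."""

    def __init__(self, lowrank_layer, token_threshold: int = 16):
        super().__init__()
        self.structured = lowrank_layer
        self.dense = premerge(lowrank_layer)
        self.token_threshold = token_threshold

    def forward(self, x):
        n_tokens = x.numel() // x.shape[-1]              # tokens currently in the batch
        if n_tokens <= self.token_threshold:
            return self.dense(x)                         # few tokens: merged dense matmul is faster
        return self.structured(x)                        # many tokens: structured form wins
```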
65 | 66 | 67 | 68 | 69 | 70 | - Small number of tokens (which may occur at the decoding stage, especially in the online case) 71 | 72 | We vary the number of tokens per batch to determine when to use the efficient alternatives and when to switch to the pre-merged dense matrices. For example, with a 2048-width FFN, it is difficult to fully utilize GPU resources with only a few tokens. The gains improve significantly at widths 5120 and 6144, e.g., a 2.63$\times$ speed-up for LowRank with 32% of the FFN parameters at a total of 2048 tokens, and a 2.81$\times$ speed-up for BlockDense with 32% of the parameters at 1536 tokens. 73 | 74 | 75 | 76 | 77 | 78 | ### Findings on efficient training 79 | 80 | - Comparison between structured FFNs 81 | 82 | With the model size and training FLOPs fixed, we show that LowRank and BlockDense can be better than BlockShuffle for the FFN in NLP tasks. However, we think this is task-dependent: in vision tasks, where block-diagonal matrices better capture local information, we find that they are a more suitable inductive bias (see the experiments in the appendix). 83 | 84 | ![](assets/gpt.png) 85 | 86 | 87 | 88 | - Scaling analysis 89 | 90 | As we scale the model size, we find steeper scaling curves for the structured matrices. The figure below is for LowRank, but the other two show similar curves. Specifically, 91 | 92 | *(i) The structured matrices exhibit steeper scaling curves compared to the dense networks, indicating significant potential for these efficient designs in LLMs.* 93 | 94 | *(ii) The scaling curve with 32\% of the FFN parameters is steeper than the one with 63\%, which highlights the scaling potential of highly structured large models.* 95 | 96 | *(iii) Given a fixed training-FLOPs budget, a wider and structured network trained on more tokens may achieve comparable or superior performance to dense networks at the optimal trade-off.* 97 | 98 | 99 | 100 | 101 | ### Self-guided training 102 | 103 | With self-guided training, our performance gets closer to that of the dense models. For example, with the same training FLOPs, our 1.3B model has a 0.4 perplexity loss vs. the dense one and enjoys about a 2.5$\times$ FFN speed-up for inference. Additionally, we compare our method with another strong baseline that trains the structured parametrizations on more tokens, showing that ours achieves comparable or superior results even with the same number of tokens. 104 | 105 | 106 | 107 | 108 | 109 | ### Wide and Structured network 110 | 111 | As maintaining the parameter ratio of attention to FFN can be important, in this section we use GQA to make attention efficient and LowRank for the FFN, designing wide and structured networks from Transformer-m and Transformer-l. To match the training FLOPs, we either train on more tokens or apply self-guided training. 112 | 113 | Our methods achieve an 8% and a 17% maximum-throughput boost, respectively, while maintaining or slightly improving perplexity. TP refers to the maximum throughput measured at a generation length of 256. 114 | 115 | 116 | 117 | 118 | 119 | ## Conclusion and Limitation 120 | 121 | Conclusion: In this paper, we conducted extensive experiments investigating the use of structured matrices to parameterize the FFN in Transformers, with models up to 1.3B parameters on the RefinedWeb dataset.
Our primary aim was not to determine which structured matrices perform best, as this can be task-dependent, but to explore common issues including efficiency and optimization challenges of existing structured matrices as well as BlockDense. 122 | 123 | Limitation: BlockDense and BlockShuffle are more complicated than LowRank. In this work, we only explored a limited range of hyperparameter settings of them. Also, we primarily focused on language modeling with limited vision experiments included in the appendix. Additionally, we did not explore the optimal scaling laws for structured matrices, which may further enhance performance. 124 | 125 | ## References 126 | 127 | [1]. Monarch: Expressive Structured Matrices for Efficient and Accurate Training. ICML2022 128 | 129 | [2]. Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. ICLR2016 130 | 131 | [3]. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP2023 132 | 133 | ## Useful Links 134 | 135 | Paper:https://arxiv.org/pdf/2406.16450 136 | 137 | Code: https://github.com/CLAIRE-Labo/StructuredFFN -------------------------------------------------------------------------------- /docs/_sass/_base.scss: -------------------------------------------------------------------------------- 1 | // Body resets 2 | // 3 | // Update the foundational and global aspects of the page. 4 | 5 | * { 6 | box-sizing: border-box; 7 | } 8 | 9 | body { 10 | margin: 0; 11 | font-family: var(--body-font); 12 | font-size: var(--body-font-size); 13 | line-height: var(--body-line-height); 14 | color: var(--body-color); 15 | background-color: var(--body-bg); 16 | -webkit-text-size-adjust: 100%; 17 | -ms-text-size-adjust: 100%; 18 | } 19 | 20 | // No `:visited` state is required by default (browsers will use `a`) 21 | a { 22 | color: var(--link-color); 23 | 24 | // `:focus` is linked to `:hover` for basic accessibility 25 | &:hover, 26 | &:focus { 27 | color: var(--link-hover-color); 28 | } 29 | 30 | strong { 31 | color: inherit; 32 | } 33 | } 34 | 35 | img { 36 | display: block; 37 | margin: auto; 38 | max-width: 100%; 39 | margin-bottom: var(--spacer); 40 | border-radius: var(--border-radius); 41 | } 42 | 43 | table { 44 | margin-bottom: 1rem; 45 | width: 100%; 46 | border: 0 solid var(--border-color); 47 | border-collapse: collapse; 48 | } 49 | 50 | td, 51 | th { 52 | padding: .25rem .5rem; 53 | border-color: inherit; 54 | border-style: solid; 55 | border-width: 0; 56 | border-bottom-width: 1px; 57 | } 58 | 59 | 60 | th { 61 | text-align: left; 62 | } 63 | 64 | thead th { 65 | border-bottom-color: currentColor; 66 | } 67 | 68 | mark { 69 | padding: .15rem; 70 | background-color: var(--yellow-100); 71 | border-radius: .125rem; 72 | } 73 | 74 | p { 75 | text-align: justify; 76 | } 77 | -------------------------------------------------------------------------------- /docs/_sass/_code.scss: -------------------------------------------------------------------------------- 1 | // Code 2 | // 3 | // Inline and block-level code snippets. Includes tweaks to syntax highlighted 4 | // snippets from Pygments/Rouge and Gist embeds. 
5 | 6 | code, 7 | pre { 8 | font-family: var(--code-font); 9 | } 10 | 11 | code { 12 | font-size: 85%; 13 | } 14 | 15 | pre { 16 | display: block; 17 | margin-top: 0; 18 | margin-bottom: var(--spacer-3); 19 | overflow: auto; 20 | } 21 | 22 | .highlight { 23 | padding: var(--spacer); 24 | margin-bottom: var(--spacer); 25 | background-color: var(--code-bg); 26 | border-radius: var(--border-radius); 27 | 28 | pre { 29 | margin-bottom: 0; 30 | } 31 | 32 | // Triple backticks (code fencing) doubles the .highlight elements 33 | .highlight { 34 | padding: 0; 35 | } 36 | } 37 | 38 | .rouge-table { 39 | margin-bottom: 0; 40 | font-size: 100%; 41 | 42 | &, 43 | td, 44 | th { 45 | border: 0; 46 | } 47 | 48 | .gutter { 49 | vertical-align: top; 50 | user-select: none; 51 | opacity: .25; 52 | } 53 | } 54 | 55 | // Gist via GitHub Pages 56 | .gist .markdown-body { 57 | padding: 15px !important; 58 | } 59 | -------------------------------------------------------------------------------- /docs/_sass/_layout.scss: -------------------------------------------------------------------------------- 1 | // Layout 2 | // 3 | // Styles for managing the structural hierarchy of the site. 4 | 5 | .container { 6 | max-width: 50%; 7 | padding-left: var(--spacer-2); 8 | padding-right: var(--spacer-2); 9 | margin-left: auto; 10 | margin-right: auto; 11 | } 12 | 13 | footer { 14 | margin-top: var(--spacer-3); 15 | margin-bottom: var(--spacer-3); 16 | } 17 | -------------------------------------------------------------------------------- /docs/_sass/_masthead.scss: -------------------------------------------------------------------------------- 1 | // Masthead 2 | // 3 | // Super small header above the content for site name and short description. 4 | 5 | .masthead { 6 | padding-top: var(--spacer); 7 | padding-bottom: var(--spacer); 8 | margin-bottom: var(--spacer-3); 9 | } 10 | 11 | .masthead-title { 12 | margin-bottom: 0; 13 | 14 | a { 15 | color: inherit; 16 | text-decoration: none; 17 | } 18 | 19 | small { 20 | font-weight: 400; 21 | opacity: 0.5; 22 | } 23 | } 24 | 25 | // Navbar styles 26 | .nav { 27 | float: right; 28 | line-height: 1.25rem; 29 | word-spacing: 1rem; 30 | } 31 | -------------------------------------------------------------------------------- /docs/_sass/_message.scss: -------------------------------------------------------------------------------- 1 | // Messages 2 | // 3 | // Show alert messages to users. You may add it to single elements like a `

    `, 4 | // or to a parent if there are multiple elements to show. 5 | 6 | .message { 7 | padding: var(--spacer); 8 | margin-bottom: var(--spacer); 9 | color: var(--gray-900); 10 | background-color: var(--yellow-100); 11 | border-radius: var(--border-radius); 12 | } 13 | -------------------------------------------------------------------------------- /docs/_sass/_pagination.scss: -------------------------------------------------------------------------------- 1 | // Pagination 2 | // 3 | // Super lightweight (HTML-wise) blog pagination. `span`s are provide for when 4 | // there are no more previous or next posts to show. 5 | 6 | .pagination { 7 | display: flex; 8 | margin: 0 -1.5rem var(--spacer); 9 | color: grey; 10 | text-align: center; 11 | } 12 | 13 | // Pagination items can be `span`s or `a`s 14 | .pagination-item { 15 | display: block; 16 | padding: var(--spacer); 17 | text-decoration: none; 18 | border: solid var(--border-color); 19 | border-width: 1px 0; 20 | 21 | &:first-child { 22 | margin-bottom: -1px; 23 | } 24 | } 25 | 26 | // Only provide a hover state for linked pagination items 27 | a.pagination-item:hover { 28 | background-color: var(--border-color); 29 | } 30 | 31 | @media (min-width: 30em) { 32 | .pagination { 33 | margin: var(--spacer-3) 0; 34 | } 35 | 36 | .pagination-item { 37 | float: left; 38 | width: 50%; 39 | border-width: 1px; 40 | 41 | &:first-child { 42 | margin-bottom: 0; 43 | border-top-left-radius: var(--border-radius); 44 | border-bottom-left-radius: var(--border-radius); 45 | } 46 | &:last-child { 47 | margin-left: -1px; 48 | border-top-right-radius: var(--border-radius); 49 | border-bottom-right-radius: var(--border-radius); 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /docs/_sass/_posts.scss: -------------------------------------------------------------------------------- 1 | // Posts and pages 2 | // 3 | // Each post is wrapped in `.post` and is used on default and post layouts. Each 4 | // page is wrapped in `.page` and is only used on the page layout. 
5 | 6 | .page, 7 | .post { 8 | margin-bottom: 4em; 9 | 10 | li + li { 11 | margin-top: .25rem; 12 | } 13 | } 14 | 15 | // Blog post or page title 16 | .page-title, 17 | .post-title { 18 | color: var(--heading-color); 19 | } 20 | .page-title, 21 | .post-title { 22 | margin-top: 0; 23 | } 24 | .post-title a { 25 | color: inherit; 26 | text-decoration: none; 27 | 28 | &:hover, 29 | &:focus { 30 | text-decoration: underline; 31 | } 32 | } 33 | 34 | // Meta data line below post title 35 | .post-date { 36 | display: block; 37 | margin-top: -.5rem; 38 | margin-bottom: var(--spacer); 39 | color: var(--gray-600); 40 | } 41 | 42 | 43 | // Related posts 44 | .related { 45 | padding-top: var(--spacer-2); 46 | padding-bottom: var(--spacer-2); 47 | margin-bottom: var(--spacer-2); 48 | border-top: 1px solid var(--border-color); 49 | border-bottom: 1px solid var(--border-color); 50 | } 51 | 52 | .related-posts { 53 | padding-left: 0; 54 | list-style: none; 55 | 56 | h3 { 57 | margin-top: 0; 58 | } 59 | 60 | a { 61 | text-decoration: none; 62 | 63 | small { 64 | color: var(--gray-600); 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /docs/_sass/_syntax.scss: -------------------------------------------------------------------------------- 1 | .highlight .hll { background-color: #ffc; } 2 | .highlight .c { color: #999; } /* Comment */ 3 | .highlight .err { color: #a00; background-color: #faa } /* Error */ 4 | .highlight .k { color: #069; } /* Keyword */ 5 | .highlight .o { color: #555 } /* Operator */ 6 | .highlight .cm { color: #09f; font-style: italic } /* Comment.Multiline */ 7 | .highlight .cp { color: #099 } /* Comment.Preproc */ 8 | .highlight .c1 { color: #999; } /* Comment.Single */ 9 | .highlight .cs { color: #999; } /* Comment.Special */ 10 | .highlight .gd { background-color: #fcc; border: 1px solid #c00 } /* Generic.Deleted */ 11 | .highlight .ge { font-style: italic } /* Generic.Emph */ 12 | .highlight .gr { color: #f00 } /* Generic.Error */ 13 | .highlight .gh { color: #030; } /* Generic.Heading */ 14 | .highlight .gi { background-color: #cfc; border: 1px solid #0c0 } /* Generic.Inserted */ 15 | .highlight .go { color: #aaa } /* Generic.Output */ 16 | .highlight .gp { color: #009; } /* Generic.Prompt */ 17 | .highlight .gs { } /* Generic.Strong */ 18 | .highlight .gu { color: #030; } /* Generic.Subheading */ 19 | .highlight .gt { color: #9c6 } /* Generic.Traceback */ 20 | .highlight .kc { color: #069; } /* Keyword.Constant */ 21 | .highlight .kd { color: #069; } /* Keyword.Declaration */ 22 | .highlight .kn { color: #069; } /* Keyword.Namespace */ 23 | .highlight .kp { color: #069 } /* Keyword.Pseudo */ 24 | .highlight .kr { color: #069; } /* Keyword.Reserved */ 25 | .highlight .kt { color: #078; } /* Keyword.Type */ 26 | .highlight .m { color: #f60 } /* Literal.Number */ 27 | .highlight .s { color: #d44950 } /* Literal.String */ 28 | .highlight .na { color: #4f9fcf } /* Name.Attribute */ 29 | .highlight .nb { color: #366 } /* Name.Builtin */ 30 | .highlight .nc { color: #0a8; } /* Name.Class */ 31 | .highlight .no { color: #360 } /* Name.Constant */ 32 | .highlight .nd { color: #99f } /* Name.Decorator */ 33 | .highlight .ni { color: #999; } /* Name.Entity */ 34 | .highlight .ne { color: #c00; } /* Name.Exception */ 35 | .highlight .nf { color: #c0f } /* Name.Function */ 36 | .highlight .nl { color: #99f } /* Name.Label */ 37 | .highlight .nn { color: #0cf; } /* Name.Namespace */ 38 | .highlight .nt { color: #2f6f9f; } /* Name.Tag */ 
39 | .highlight .nv { color: #033 } /* Name.Variable */ 40 | .highlight .ow { color: #000; } /* Operator.Word */ 41 | .highlight .w { color: #bbb } /* Text.Whitespace */ 42 | .highlight .mf { color: #f60 } /* Literal.Number.Float */ 43 | .highlight .mh { color: #f60 } /* Literal.Number.Hex */ 44 | .highlight .mi { color: #f60 } /* Literal.Number.Integer */ 45 | .highlight .mo { color: #f60 } /* Literal.Number.Oct */ 46 | .highlight .sb { color: #c30 } /* Literal.String.Backtick */ 47 | .highlight .sc { color: #c30 } /* Literal.String.Char */ 48 | .highlight .sd { color: #c30; font-style: italic } /* Literal.String.Doc */ 49 | .highlight .s2 { color: #c30 } /* Literal.String.Double */ 50 | .highlight .se { color: #c30; } /* Literal.String.Escape */ 51 | .highlight .sh { color: #c30 } /* Literal.String.Heredoc */ 52 | .highlight .si { color: #a00 } /* Literal.String.Interpol */ 53 | .highlight .sx { color: #c30 } /* Literal.String.Other */ 54 | .highlight .sr { color: #3aa } /* Literal.String.Regex */ 55 | .highlight .s1 { color: #c30 } /* Literal.String.Single */ 56 | .highlight .ss { color: #fc3 } /* Literal.String.Symbol */ 57 | .highlight .bp { color: #366 } /* Name.Builtin.Pseudo */ 58 | .highlight .vc { color: #033 } /* Name.Variable.Class */ 59 | .highlight .vg { color: #033 } /* Name.Variable.Global */ 60 | .highlight .vi { color: #033 } /* Name.Variable.Instance */ 61 | .highlight .il { color: #f60 } /* Literal.Number.Integer.Long */ 62 | 63 | .css .o, 64 | .css .o + .nt, 65 | .css .nt + .nt { color: #999; } 66 | -------------------------------------------------------------------------------- /docs/_sass/_toc.scss: -------------------------------------------------------------------------------- 1 | // Table of Contents 2 | 3 | #markdown-toc { 4 | padding: var(--spacer-2) var(--spacer-3); 5 | margin-bottom: var(--spacer-2); 6 | border: solid var(--border-color); 7 | border-width: 1px 0; 8 | 9 | &::before { 10 | display: block; 11 | margin-left: calc(var(--spacer-3) * -1); 12 | content: "Contents"; 13 | font-size: 85%; 14 | font-weight: 500; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /docs/_sass/_type.scss: -------------------------------------------------------------------------------- 1 | // Typography 2 | // 3 | // Headings, body text, lists, and other misc typographic elements. 
4 | 5 | h1, h2, h3, h4, h5, h6 { 6 | margin-bottom: .5rem; 7 | font-weight: 600; 8 | line-height: 1.25; 9 | color: var(--heading-color); 10 | } 11 | 12 | h1 { 13 | font-size: 2rem; 14 | } 15 | 16 | h2 { 17 | margin-top: 1rem; 18 | font-size: 1.5rem; 19 | } 20 | 21 | h3 { 22 | margin-top: 1.5rem; 23 | font-size: 1.25rem; 24 | } 25 | 26 | h4, h5, h6 { 27 | margin-top: 1rem; 28 | font-size: 1rem; 29 | } 30 | 31 | p { 32 | margin-top: 0; 33 | margin-bottom: 1rem; 34 | } 35 | 36 | ul, ol, dl { 37 | margin-top: 0; 38 | margin-bottom: 1rem; 39 | } 40 | 41 | dt { 42 | font-weight: bold; 43 | } 44 | 45 | dd { 46 | margin-bottom: .5rem; 47 | } 48 | 49 | hr { 50 | position: relative; 51 | margin: var(--spacer-2) 0; 52 | border: 0; 53 | border-top: 1px solid var(--border-color); 54 | } 55 | 56 | abbr { 57 | font-size: 85%; 58 | font-weight: bold; 59 | color: var(--gray-600); 60 | text-transform: uppercase; 61 | 62 | &[title] { 63 | cursor: help; 64 | border-bottom: 1px dotted var(--border-color); 65 | } 66 | } 67 | 68 | blockquote { 69 | padding: .5rem 1rem; 70 | margin: .8rem 0; 71 | color: var(--gray-500); 72 | border-left: .25rem solid var(--border-color); 73 | 74 | p:last-child { 75 | margin-bottom: 0; 76 | } 77 | 78 | @media (min-width: 30em) { 79 | padding-right: 5rem; 80 | padding-left: 1.25rem; 81 | } 82 | } 83 | 84 | figure { 85 | margin: 0; 86 | } 87 | 88 | 89 | // Markdown footnotes 90 | // 91 | // See the example content post for an example. 92 | 93 | // Footnote number within body text 94 | a[href^="#fn:"], 95 | // Back to footnote link 96 | a[href^="#fnref:"] { 97 | display: inline-block; 98 | margin-left: .1rem; 99 | font-weight: bold; 100 | } 101 | 102 | // List of footnotes 103 | .footnotes { 104 | margin-top: 2rem; 105 | font-size: 85%; 106 | } 107 | 108 | // Custom type 109 | // 110 | // Extend paragraphs with `.lead` for larger introductory text. 
111 | 112 | .lead { 113 | font-size: 1.25rem; 114 | font-weight: 300; 115 | } 116 | -------------------------------------------------------------------------------- /docs/_sass/_variables.scss: -------------------------------------------------------------------------------- 1 | :root { 2 | --gray-000: #f8f9fa; 3 | --gray-100: #f1f3f5; 4 | --gray-200: #e9ecef; 5 | --gray-300: #dee2e6; 6 | --gray-400: #ced4da; 7 | --gray-500: #adb5bd; 8 | --gray-600: #868e96; 9 | --gray-700: #495057; 10 | --gray-800: #343a40; 11 | --gray-900: #212529; 12 | --dark-poole-001: hsl(200, 3%, 12%); 13 | --dark-poole-002: hsl(0, 0%, 85%); 14 | --dark-poole-link-color: rgba(255, 255, 255, 0.75); 15 | --dark-poole-link-hover: #fff; 16 | 17 | --red: #fa5252; 18 | --pink: #e64980; 19 | --grape: #be4bdb; 20 | --purple: #7950f2; 21 | --indigo: #4c6ef5; 22 | --blue: #228be6; 23 | --cyan: #15aabf; 24 | --teal: #12b886; 25 | --green: #40c057; 26 | --yellow: #fab005; 27 | --orange: #fd7e14; 28 | 29 | --blue-300: #74c0fc; 30 | --blue-400: #4dabf7; 31 | --yellow-100: #fff3bf; 32 | 33 | --body-font: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, 34 | "Helvetica Neue", Arial, "Noto Sans", sans-serif, "Apple Color Emoji", 35 | "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji"; 36 | --body-font-size: 20px; 37 | --body-line-height: 1.5; 38 | --body-color: var(--gray-700); 39 | --body-bg: #fff; 40 | 41 | --link-color: var(--blue); 42 | --link-hover-color: #1c7ed6; 43 | 44 | --heading-color: var(--gray-900); 45 | 46 | --border-color: var(--gray-300); 47 | --border-radius: 0.25rem; 48 | 49 | --code-font: SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", 50 | "Courier New", monospace; 51 | --code-color: var(--grape); 52 | --code-bg: var(--gray-000); 53 | 54 | --spacer: 1rem; 55 | --spacer-2: calc(var(--spacer) * 1.5); 56 | --spacer-3: calc(var(--spacer) * 3); 57 | } 58 | 59 | @media (prefers-color-scheme: dark) { 60 | :root { 61 | --body-color: var(--gray-300); 62 | --body-bg: var(--gray-800); 63 | 64 | --heading-color: #fff; 65 | 66 | --link-color: var(--blue-300); 67 | --link-hover-color: var(--blue-400); 68 | 69 | --border-color: rgba(255, 255, 255, 0.15); 70 | 71 | --code-bg: var(--gray-900); 72 | } 73 | } 74 | 75 | // StructuredFFN theme 76 | [data-theme="dark-poole"] { 77 | --body-color: var(--dark-poole-002); 78 | --body-bg: var(--dark-poole-001); 79 | --heading-color: var(--dark-poole-002); 80 | --link-color: var(--dark-poole-link-color); 81 | --link-hover-color: var(--dark-poole-link-hover); 82 | --border-color: rgba(255, 255, 255, 0.15); 83 | --code-bg: var(--gray-900); 84 | } 85 | -------------------------------------------------------------------------------- /docs/assets/apple-touch-icon-precomposed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/apple-touch-icon-precomposed.png -------------------------------------------------------------------------------- /docs/assets/author.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/author.png -------------------------------------------------------------------------------- /docs/assets/favicon.ico: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/favicon.ico -------------------------------------------------------------------------------- /docs/assets/fig_sgt_lowrank.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/fig_sgt_lowrank.png -------------------------------------------------------------------------------- /docs/assets/gpt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/gpt.png -------------------------------------------------------------------------------- /docs/assets/latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/latency.png -------------------------------------------------------------------------------- /docs/assets/latency_bs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/latency_bs.png -------------------------------------------------------------------------------- /docs/assets/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/method.png -------------------------------------------------------------------------------- /docs/assets/scaling_law_lowrank.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/scaling_law_lowrank.png -------------------------------------------------------------------------------- /docs/assets/training_dynamic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/training_dynamic.png -------------------------------------------------------------------------------- /docs/assets/wide_structured.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/wide_structured.png -------------------------------------------------------------------------------- /docs/atom.xml: -------------------------------------------------------------------------------- 1 | --- 2 | layout: null 3 | --- 4 | 5 | 6 | 7 | 8 | {{ site.title }} 9 | 10 | 11 | {{ site.time | date_to_xmlschema }} 12 | {{ site.url }} 13 | 14 | {{ site.author.name }} 15 | {{ site.author.email }} 16 | 17 | 18 | {% for post in site.posts %} 19 | 20 | {{ post.title | xml_escape }} 21 | 22 | {{ post.date | date_to_xmlschema }} 23 | {{ site.url }}{{ post.id }} 24 | {{ post.content | xml_escape }} 25 | 26 | {% endfor %} 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | --- 2 | layout: default 3 | title: Home 4 | --- 5 | 6 
|

    7 | {% for post in paginator.posts %} 8 |
    9 |

    10 | 11 | {{ post.title }} 12 | 13 |

    14 | 15 | 16 | 17 | {{ post.content }} 18 |
    19 | {% endfor %} 20 |
    21 | 22 | 34 | -------------------------------------------------------------------------------- /docs/poole-for-jekyll.gemspec: -------------------------------------------------------------------------------- 1 | # frozen_string_literal: true 2 | 3 | Gem::Specification.new do |spec| 4 | spec.name = "poole-for-jekyll" 5 | spec.version = "3.0.0" 6 | spec.authors = ["Mark Otto"] 7 | spec.email = ["markdotto@gmail.com"] 8 | 9 | spec.summary = "The Jekyll Butler. A no frills responsive Jekyll blog theme." 10 | spec.homepage = "https://getpoole.com" 11 | spec.license = "MIT" 12 | 13 | spec.files = `git ls-files -z`.split("\x0").select { |f| f.match(%r!^(assets|_layouts|_includes|_sass|LICENSE|README)!i) } 14 | 15 | spec.add_runtime_dependency "jekyll", "~> 4.0" 16 | 17 | spec.add_development_dependency "bundler", "~> 1.16" 18 | spec.add_development_dependency "rake", "~> 12.0" 19 | end 20 | -------------------------------------------------------------------------------- /docs/styles.scss: -------------------------------------------------------------------------------- 1 | --- 2 | # Use a comment to ensure Jekyll reads the file to be transformed into CSS later 3 | # only main files contain this front matter, not partials. 4 | --- 5 | 6 | // 7 | // ___ 8 | // /\_ \ 9 | // _____ ___ ___\//\ \ __ 10 | // /\ '__`\ / __`\ / __`\\ \ \ /'__`\ 11 | // \ \ \_\ \/\ \_\ \/\ \_\ \\_\ \_/\ __/ 12 | // \ \ ,__/\ \____/\ \____//\____\ \____\ 13 | // \ \ \/ \/___/ \/___/ \/____/\/____/ 14 | // \ \_\ 15 | // \/_/ 16 | // 17 | // Designed, built, and released under MIT license by @mdo. Learn more at 18 | // https://github.com/poole/poole. 19 | 20 | @import "variables"; 21 | @import "base"; 22 | @import "type"; 23 | @import "syntax"; 24 | @import "code"; 25 | @import "layout"; 26 | @import "masthead"; 27 | @import "posts"; 28 | @import "pagination"; 29 | @import "message"; 30 | @import "toc"; 31 | 32 | // Sass for creating the swatches 33 | .colors { 34 | display: grid; 35 | grid-template-columns: max-content 1fr; 36 | 37 | dt { 38 | width: 3rem; 39 | height: 3rem; 40 | border-radius: var(--border-radius); 41 | box-shadow: inset 0 0 0 1px rgba(255,255,255,.15); 42 | } 43 | 44 | dd { 45 | margin-left: var(--spacer); 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /experiment/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:24.02-py3 2 | MAINTAINER Xiuying Wei 3 | 4 | 5 | ARG DEBIAN_FRONTEND=noninteractive 6 | 7 | # package install 8 | RUN apt-get update && apt-get install -y \ 9 | curl vim htop\ 10 | ca-certificates \ 11 | openssh-server \ 12 | cmake \ 13 | sudo \ 14 | git \ 15 | bzip2 \ 16 | libx11-6 \ 17 | zip \ 18 | unzip ssh \ 19 | tmux \ 20 | && rm -rf /var/lib/apt/lists/* 21 | 22 | 23 | # Install Python 3.8 with Miniconda 24 | #RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py38_4.10.3-Linux-x86_64.sh -O ~/miniconda.sh \ 25 | # && /bin/bash ~/miniconda.sh -b -p /opt/conda \ 26 | # && rm ~/miniconda.sh \ 27 | # && /opt/conda/bin/conda install mkl numpy scipy pandas openmpi ipython jupyter \ 28 | # && /opt/conda/bin/conda clean --all -y 29 | 30 | 31 | # ENV PATH="~/.local/bin:/opt/conda/bin:/usr/local/cuda/bin:${PATH}" \ 32 | # LD_LIBRARY_PATH="/usr/local/cuda/lib64" 33 | ENV PATH="~/.local/bin:/usr/local/cuda/bin:${PATH}" \ 34 | LD_LIBRARY_PATH="/usr/local/cuda/lib64" 35 | 36 | # Make $PATH and $LD_LIBRARY PATH available to all users 37 | RUN echo PATH="${PATH}" >> 
/etc/environment && \ 38 | echo LD_LIBRARY_PATH="${LD_LIBRARY_PATH}" >> /etc/environment 39 | 40 | # transformers==4.34.0 41 | # datasets 42 | # evaluate 43 | # accelerate 44 | # RUN pip uninstall transformer-engine --yes 45 | # The following two rows are for butterfly 46 | RUN pip --no-cache-dir install \ 47 | easydict \ 48 | h5py \ 49 | pyyaml \ 50 | tqdm \ 51 | pillow \ 52 | protobuf \ 53 | seaborn \ 54 | scipy \ 55 | scikit-learn \ 56 | wandb \ 57 | hydra-core \ 58 | transformers==4.34.0 \ 59 | datasets \ 60 | evaluate \ 61 | accelerate \ 62 | sentencepiece 63 | 64 | # RUN pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121 65 | # RUN pip3 install --upgrade flash-attn==2.4.2 --no-build-isolation 66 | # entrypoint 67 | RUN pip install --upgrade protobuf==3.20.0 68 | ENV ENTRYPOINTS_ROOT=/opt/entrypoints 69 | RUN mkdir -p ${ENTRYPOINTS_ROOT} 70 | 71 | 72 | # The entrypoint is run in an interactive shell so that the conda environment is activated before. 73 | # Don't overwrite the entrypoint, it is installing the project 74 | # and testing that you correctly mounted the project code and data and output directories. 75 | # It also performs some other important setup depending on the deployment platform. 76 | COPY --link entrypoint.sh ${ENTRYPOINTS_ROOT}/entrypoint.sh 77 | ENTRYPOINT ["/bin/bash", "-i", "/opt/entrypoints/entrypoint.sh"] 78 | CMD ["/bin/bash"] 79 | 80 | 81 | # userconfig 82 | # define your own config here 83 | 84 | -------------------------------------------------------------------------------- /experiment/basic.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | s_token=2200000000 3 | m_token=6700000000 4 | l_token=14580000000 5 | xl_token=25500000000 6 | s_lr=0.0006 7 | m_lr=0.0003 8 | l_lr=0.00025 9 | xl_lr=0.0002 10 | s_train_batch=64 11 | s_test_batch=64 12 | m_train_batch=32 13 | m_test_batch=32 14 | l_train_batch=16 15 | l_test_batch=32 16 | xl_train_batch=16 17 | xl_test_batch=16 18 | 19 | 20 | # dense 21 | ./run_gpt.sh "gpt2s" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=linear optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} data.test.test_batch=${s_test_batch} optimization.log_interval=100" 22 | ./run_gpt.sh "gpt2m" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=linear optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100" 23 | ./run_gpt.sh "gpt2l" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=linear optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100" 24 | ./run_gpt.sh "gpt2xl" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=linear optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100" 25 | 26 | # LowRank 27 | ./run_gpt.sh "gpt2s-lr-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=384 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} 
data.test.test_batch=${s_test_batch} optimization.log_interval=100" 28 | ./run_gpt.sh "gpt2s-lr-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=192 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} data.test.test_batch=${s_test_batch} optimization.log_interval=100" 29 | ./run_gpt.sh "gpt2m-lr-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowrank method.kwargs.rank=512 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100" 30 | ./run_gpt.sh "gpt2m-lr-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowrank method.kwargs.rank=256 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100" 31 | ./run_gpt.sh "gpt2l-lr-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=lowrank method.kwargs.rank=768 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100" 32 | ./run_gpt.sh "gpt2l-lr-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=lowrank method.kwargs.rank=384 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100" 33 | ./run_gpt.sh "gpt2xl-lr-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=lowrank method.kwargs.rank=1024 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100" 34 | ./run_gpt.sh "gpt2xl-lr-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=lowrank method.kwargs.rank=512 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100" 35 | 36 | # BlockDense 37 | ./run_gpt.sh "gpt2s-bld-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=512 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} data.test.test_batch=${s_test_batch} optimization.log_interval=100" 38 | ./run_gpt.sh "gpt2s-bld-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=256 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} data.test.test_batch=${s_test_batch} optimization.log_interval=100" 39 | ./run_gpt.sh "gpt2m-bld-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=768 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100" 40 | ./run_gpt.sh 
"gpt2m-bld-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=384 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100" 41 | ./run_gpt.sh "gpt2l-bld-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=1024 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100" 42 | ./run_gpt.sh "gpt2l-bld-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=512 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100" 43 | ./run_gpt.sh "gpt2xl-bld-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=1536 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100" 44 | ./run_gpt.sh "gpt2xl-bld-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=768 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100" 45 | 46 | # BlockShuffle 47 | ./run_gpt.sh "gpt2s-bls-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockshuffle method.kwargs.nblocks=2 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} data.test.test_batch=${s_test_batch} optimization.log_interval=100" 48 | ./run_gpt.sh "gpt2s-bls-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} data.test.test_batch=${s_test_batch} optimization.log_interval=100" 49 | ./run_gpt.sh "gpt2m-bls-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockshuffle method.kwargs.nblocks=2 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100" 50 | ./run_gpt.sh "gpt2m-bls-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100" 51 | ./run_gpt.sh "gpt2l-bls-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=blockshuffle method.kwargs.nblocks=2 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100" 52 | 
./run_gpt.sh "gpt2l-bls-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100" 53 | ./run_gpt.sh "gpt2xl-bls-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=blockshuffle method.kwargs.nblocks=2 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100" 54 | ./run_gpt.sh "gpt2xl-bls-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100" 55 | -------------------------------------------------------------------------------- /experiment/sgd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | s_token=2200000000 3 | m_token=6700000000 4 | l_token=14580000000 5 | xl_token=25500000000 6 | s_lr=0.0006 7 | m_lr=0.0003 8 | l_lr=0.00025 9 | xl_lr=0.0002 10 | train_batch=8 11 | test_batch=8 12 | s_lr_ratio=0.30 13 | m_lr_ratio=0.38 14 | l_lr_ratio=0.41 15 | xl_lr_ratio=0.43 16 | # The parameters for BlockDense with about 32\% parameters are not matched with the LowRank and BlockShuffle exactly. Thus, we provide the max_step_ratio of self-guided training for BlockDense separately to exactly match the training FLOPs. 17 | 18 | s_bld_ratio=0.30 19 | m_bld_ratio=0.40 20 | l_bld_ratio=0.41 21 | xl_bld_ratio=0.45 22 | 23 | # apply self-guided training for the first half of training 24 | ./run_gpt.sh "gpt2s-lr-0x33-sgd-fixedstep" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=192 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedstep optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.5" 25 | ./run_gpt.sh "gpt2m-lr-0x33-sgd-fixedstep" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowrank method.kwargs.rank=256 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedstep optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.5" 26 | ./run_gpt.sh "gpt2s-bld-0x33-sgd-fixedstep" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=256 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedstep optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.5" 27 | ./run_gpt.sh 
"gpt2m-bld-0x33-sgd-fixedstep" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=384 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedstep optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.5" 28 | ./run_gpt.sh "gpt2s-bls-0x33-sgd-fixedstep" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedstep optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.5" 29 | ./run_gpt.sh "gpt2m-bls-0x33-sgd-fixedstep" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedstep optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.5" 30 | 31 | # to match the flops 32 | # LowRank 33 | ./run_gpt.sh "gpt2s-lr-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=192 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${s_lr_ratio}" 34 | ./run_gpt.sh "gpt2m-lr-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowrank method.kwargs.rank=256 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${m_lr_ratio}" 35 | ./run_gpt.sh "gpt2l-lr-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=lowrank method.kwargs.rank=384 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${l_lr_ratio}" 36 | ./run_gpt.sh "gpt2xl-lr-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=lowrank method.kwargs.rank=512 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} 
optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${xl_lr_ratio}" 37 | 38 | # BlockDense 39 | ./run_gpt.sh "gpt2s-bld-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=256 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${s_bld_ratio}" 40 | ./run_gpt.sh "gpt2m-bld-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=384 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${m_bld_ratio}" 41 | ./run_gpt.sh "gpt2l-bld-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=512 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${l_bld_ratio}" 42 | ./run_gpt.sh "gpt2xl-bld-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=768 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${xl_bld_ratio}" 43 | 44 | # BlockShuffle 45 | ./run_gpt.sh "gpt2s-bls-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${s_lr_ratio}" 46 | ./run_gpt.sh "gpt2m-bls-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true 
optimization.training.kwargs.max_step_ratio=${m_lr_ratio}" 47 | ./run_gpt.sh "gpt2l-bls-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${l_lr_ratio}" 48 | ./run_gpt.sh "gpt2xl-bls-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${xl_lr_ratio}" 49 | 50 | 51 | -------------------------------------------------------------------------------- /image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/image.png -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Building on Efficient Foundations: Effectively Training LLMs with Structured Feedforward Layers 2 | 3 | ## Introduction 4 | This repository contains the official implementation for our paper 5 | 6 | **Building on Efficient Foundations: Effectively Training LLMs with Structured Feedforward Layers** 7 | 8 | Xiuying Wei, Skander Moalla, Razvan Pascanu, Caglar Gulcehre 9 | 10 | > In this work, we investigate structured matrices for FFN blocks from the train-from-scratch aspect, first identifying their efficiency and optimization challenges and then presenting experimental results. We consider three efficient linear parametrizations: LowRank, BlockShuffle (comprising two block-diagonal matrices), and BlockDense (a combination of dense and block-diagonal matrices). We propose the pre-merge technique to solve their efficiency bottleneck at the online decoding stage. Then, a training strategy called self-guided training is proposed to improve their training dynamics. Experimental results include the steeper scaling curves of these structured matrices compared to the dense ones on FFN, the improvement brought by self-guided training, and the performance of wide and structured networks when combined with GQA for the attention block. 11 | 12 | ![alt text](image.png) 13 | ## File Organization 14 | ``` 15 | Structured/src/ 16 | ├── benchmark_acc/ [training and evaluation entry for different datasets] 17 | │ └── refinedweb_experiment.py [refinedweb entry] 18 | ├── benchmark_eff [efficiency entry] 19 | │ ├── bench_kernel.py [kernel efficiency] 20 | │ ├── benchmark_mlp_train.py [mlp efficiency] 21 | │ ├── benchmark_model_infer.py [decoding efficiency] 22 | │ └── benchmark_model_train.py [prefill/context efficiency] 23 | ├── configs [hydra config] 24 | │ ├── data [No use.
refinedweb is preprocessed in advance] 25 | │ ├── method [different efficient linear layers] 26 | │ ├── model [gpt and llama] 27 | │ ├── optimization [optimization including scheduler, optimizer, self-guided training, etc.] 28 | │ └── refinedweb_config.yaml 29 | ├── modules 30 | │ ├── __init__.py 31 | │ ├── op [fast ops. Common ones invoke others or are pasted from Megatron] 32 | │ ├── layer [efficient linear layers that invoke functions in the op dir] 33 | │ ├── mlp [efficient MLPs that invoke functions in the layer dir] 34 | │ └── model [supports layernorm or rmsnorm, bias or not, tied word embeddings or not, rotary or absolute position embeddings, gelu or swiglu] 35 | ├── optimization 36 | │ ├── __init__.py 37 | │ ├── scheduler.py [cosine with warmup] 38 | │ └── trainer.py [basic training functions including seeding, checkpointing, and info] 39 | └── utils 40 | └── refinedweb_llama.py [preprocess file] 41 | ``` 42 | 43 | ## Env 44 | We use a Docker container for the environment and A100 80GB GPUs for experiments. The Dockerfile is provided in the experiment folder, where the base image is from NVIDIA (nvcr.io/nvidia/pytorch:24.02-py3) with Transformer Engine, FlashAttention, and Apex pre-installed. The required Python packages include transformers, wandb, datasets, etc., as listed in the Dockerfile. 45 | 46 | ## Data preprocess 47 | ``` 48 | python refinedweb_llama.py --tokenizer llama --block_size 1025 --num_proc=32 49 | ``` 50 | 51 | RefinedWeb is quite large, so we shuffle, extract, and tokenize it into token ids in advance. The token ids are kept in an np.memmap file to avoid loading all the data into CPU memory at once. The above command randomly splits out about 0.7B validation tokens and 65B training tokens for later use. 52 | 53 | ## Experiments 54 | ### Structured linear parametrization (Table 1 and Table 9) 55 | We provide several examples below and put the full set of commands in experiment/basic.sh. 56 | ``` 57 | # gpt2 and linear 58 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=linear 59 | 60 | # LowRank 61 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=384 optimization.max_tokens=2200000000 optimization.optimizer.kwargs.lr=6.0e-4 data.train.train_batch=64 data.test.test_batch=64 62 | 63 | # BlockShuffle 64 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockshuffle method.kwargs.nblocks=2 optimization.max_tokens=2200000000 optimization.optimizer.kwargs.lr=6.0e-4 data.train.train_batch=64 data.test.test_batch=64 65 | 66 | # BlockDense 67 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockdense method.kwargs.rank=512 method.kwargs.nblocks=2 optimization.max_tokens=2200000000 optimization.optimizer.kwargs.lr=6.0e-4 data.train.train_batch=64 data.test.test_batch=64 68 | ``` 69 | 70 | ### Self-Guided Training (Table 3, 4, and 10) 71 | There are two modes: 72 | 73 | * an ablation study that applies the method to the first half of training and incurs 25% extra FFN FLOPs 74 | 75 | * experiments with the same training FLOPs to see the improvement directly. We use self-guided training at the beginning and repeat this portion of tokens at the end so that the structured matrices also learn from this data thoroughly. The amount of self-guided training is adjusted to match the training FLOPs; the resulting schedule is sketched below.
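As a reading aid, here is a minimal sketch of the resulting schedule, mirroring the three _train phases in src/benchmark_acc/refinedweb_experiment.py. The step counts and the train_phase helper are illustrative placeholders, not values or functions from this repository.
```
# Illustrative sketch of the self-guided training schedule (hypothetical step
# counts; mirrors the three _train phases in refinedweb_experiment.py).
max_step = 10_000     # total optimizer steps for the run
guided_steps = 3_000  # steps with self-guided training enabled (assumed ~ max_step_ratio * max_step)
repeat_steps = 3_000  # steps at the end that revisit the tokens seen during the guided phase


def train_phase(start_step, end_step, data_offset_row):
    # Placeholder for RefinedWebGPT._train(resume_batch, max_step, offset_row).
    print(f"steps [{start_step}, {end_step}) reading data from row {data_offset_row}")


# Phase 1: self-guided training on the first chunk of tokens.
train_phase(0, guided_steps, data_offset_row=-1)  # -1 = derive the data offset from the step count
# Phase 2: regular training on fresh tokens after self-guided training is closed.
train_phase(guided_steps, max_step - repeat_steps, data_offset_row=-1)
# Phase 3: repeat the early tokens so the structured matrices also learn from them.
train_phase(max_step - repeat_steps, max_step, data_offset_row=0)
```
In experiment/sgd.sh, the amount of self-guided training is controlled via optimization.training.kwargs.max_step_ratio.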
76 | 77 | We provide examples here and put all the reproducible commands in experiment/sgd.sh. 78 | ``` 79 | # Ablation 80 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=192 optimization.max_tokens=2200000000 optimization.optimizer.kwargs.lr=6.0e-4 data.train.train_batch=16 data.test.test_batch=16 optimization/training=self_guided_training optimization.training.kwargs.reduce_flop=true 81 | 82 | # to match the flops 83 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=192 optimization.max_tokens=2200000000 optimization.optimizer.kwargs.lr=6.0e-4 data.train.train_batch=32 data.test.test_batch=32 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.3 84 | ``` 85 | 86 | ### Wide and Structured network (Table 2) 87 | Motivated by the scaling curves, we make the wide model structured, using LowRank for the FFN and GQA for the attention block. 88 | 89 | Transformer-m 90 | ``` 91 | # GQA 92 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=linear model.kwargs.num_kv_heads=4 model.kwargs.ffn_dim=4864 data.train.train_batch=32 data.test.test_batch=32 optimization.max_tokens=6700000000 optimization.optimizer.kwargs.lr=3.0e-4 93 | 94 | # Ours 95 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowrank method.kwargs.rank=512 model.kwargs.hidden_dim=1024 model.kwargs.ffn_dim=4864 model.kwargs.attn_dim=512 model.kwargs.num_q_heads=8 model.kwargs.num_kv_heads=4 data.train.train_batch=32 data.test.test_batch=32 optimization.optimizer.kwargs.lr=3.0e-4 optimization.max_tokens=10580000000 96 | 97 | # Ours (self-guided training) 98 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowrank method.kwargs.rank=512 model.kwargs.hidden_dim=1024 model.kwargs.ffn_dim=4864 model.kwargs.attn_dim=512 model.kwargs.num_q_heads=8 model.kwargs.num_kv_heads=4 optimization.optimizer.kwargs.lr=3.0e-4 optimization.max_tokens=6700000000 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.40 99 | ``` 100 | 101 | Transformer-l 102 | ``` 103 | # GQA 104 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=linear model.kwargs.num_kv_heads=2 model.kwargs.ffn_dim=7424 data.train.train_batch=8 data.test.test_batch=8 optimization.max_tokens=14580000000 optimization.optimizer.kwargs.lr=0.00025 105 | 106 | # Ours 107 | # we keep the KV channels at 256, aligning with what we used in GQA.
108 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=lowrank method.kwargs.rank=768 model.kwargs.hidden_dim=1536 model.kwargs.ffn_dim=7424 model.kwargs.attn_dim=768 model.kwargs.num_q_heads=12 model.kwargs.num_kv_heads=4 data.train.train_batch=16 data.test.test_batch=16 optimization.optimizer.kwargs.lr=2.5e-4 optimization.max_tokens=23360000000 109 | 110 | # Ours (self-guided training) 111 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=lowrank method.kwargs.rank=768 model.kwargs.hidden_dim=1536 model.kwargs.ffn_dim=7424 model.kwargs.attn_dim=768 model.kwargs.num_q_heads=12 model.kwargs.num_kv_heads=4 optimization.optimizer.kwargs.lr=2.5e-4 optimization.max_tokens=14580000000 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.395 112 | ``` 113 | 114 | ### Citation 115 | If you find this repo useful for your research, please consider citing the paper: 116 | ``` 117 | @article{wei2024building, 118 | title={Building on Efficient Foundations: Effectively Training LLMs with Structured Feedforward Layers}, 119 | author={Wei, Xiuying and Moalla, Skander and Pascanu, Razvan and Gulcehre, Caglar}, 120 | journal={arXiv preprint arXiv:2406.16450}, 121 | year={2024} 122 | } 123 | 124 | @article{wei2024investigating, 125 | title={Investigating Low-Rank Training in Transformer Language Models: Efficiency and Scaling Analysis}, 126 | author={Wei, Xiuying and Moalla, Skander and Pascanu, Razvan and Gulcehre, Caglar}, 127 | journal={arXiv preprint arXiv:2407.09835}, 128 | year={2024} 129 | } 130 | 131 | ``` 132 | -------------------------------------------------------------------------------- /src/benchmark_acc/refinedweb_experiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import numpy as np 5 | import time 6 | from tqdm import tqdm 7 | 8 | 9 | project_root = os.path.dirname( 10 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 11 | ) 12 | 13 | print(project_root) 14 | sys.path.insert(0, project_root) 15 | import wandb 16 | import hydra 17 | import torch 18 | import torch.distributed as dist 19 | from omegaconf import DictConfig, OmegaConf 20 | 21 | from src.optimization import ( 22 | get_optimizer, 23 | get_lr_scheduler, 24 | trainer, 25 | ) 26 | 27 | from src.modules import get_model, update_ratio 28 | 29 | 30 | class RefinedWebGPT(trainer.TrainableModel): 31 | # NLP tasks are dominated by step rather than epoch, because we need to consider gradient accumulation 32 | def __init__(self, config): 33 | super().__init__(config) 34 | # get dataset 35 | self.set_seed(self.config.optimization.seed) 36 | # get data files 37 | self.train_file_path = os.path.join( 38 | self.config.data.train.path, 39 | self.config.data.tokenizer.name + "-train-tmp.bin", 40 | ) 41 | self.val_file_path = os.path.join( 42 | self.config.data.test.path, 43 | self.config.data.tokenizer.name + "-val-tmp.bin", 44 | ) 45 | self.block_size = min( 46 | self.config.data.block_size, self.config.data.tokenizer.model_max_length 47 | ) 48 | validate_tokens = 512000 * 1024 49 | self.validate_samples = validate_tokens // self.block_size 50 | assert ( 51 | self.validate_samples % (self.ngpus * self.config.data.test.test_batch) == 0 52 | ) 53 | assert self.gpu_id != -1, "we only support torchrun in job submission" 54 | 55 | # get metric 56 | 
self.max_step = int( 57 | self.config.optimization.max_tokens 58 | / self.global_batch_size 59 | / self.block_size 60 | ) 61 | self.set_self_guided_training() 62 | self.config.optimization.lr_scheduler.kwargs.T_max = self.max_step 63 | if self.gpu_id in [-1, 0]: 64 | self.metric = { 65 | "train_loss": 0.0, 66 | "train_ppl": 0.0, 67 | "test_loss": 0.0, 68 | "test_ppl": 0.0, 69 | "step": 0, 70 | "lr": 0.0, 71 | "fwd+bwd": 0.0, 72 | } 73 | # get model 74 | self.set_seed(self.config.optimization.seed) 75 | self.model = get_model(self.config, self.device) 76 | self.get_info() 77 | 78 | # get optimizer 79 | self.optimizer = get_optimizer( 80 | self.config.optimization, self.get_optimize_param() 81 | ) 82 | if getattr(self.config.optimization, "lr_scheduler", None): 83 | self.lr_scheduler = get_lr_scheduler( 84 | self.config.optimization, self.optimizer 85 | ) 86 | 87 | # get wandb 88 | if self.gpu_id in [-1, 0] and self.config.wandb_use: 89 | self.wandblog = trainer.WandbLog( 90 | self.config.wandb, self.metric, x_axis="step" 91 | ) 92 | 93 | assert self.load_save_mode == "step" 94 | self.prepare_load_save() 95 | self.resume_kwargs = self.load_checkpoint() 96 | if self.gpu_id != -1: 97 | self.model = torch.nn.parallel.DistributedDataParallel( 98 | self.model, 99 | device_ids=[self.gpu_id], 100 | output_device=self.gpu_id, 101 | find_unused_parameters=self.special_training, 102 | ) 103 | if self.gpu_id in [-1, 0]: 104 | print(self.config) 105 | 106 | def get_batch(self, split, offset_row): 107 | if split == "train": 108 | arr = np.memmap( 109 | self.train_file_path, 110 | dtype=np.uint16, # we store in 2 bytes 111 | mode="r", 112 | offset=offset_row * (self.block_size + 1) * 2, 113 | shape=(self.config.data.train.train_batch, (self.block_size + 1)), 114 | ) 115 | elif split == "val": 116 | arr = np.memmap( 117 | self.val_file_path, 118 | dtype=np.uint16, # we store in 2 bytes 119 | mode="r", 120 | offset=offset_row * (self.block_size + 1) * 2, 121 | shape=(self.config.data.test.test_batch, (self.block_size + 1)), 122 | ) 123 | else: 124 | raise NotImplementedError 125 | 126 | x = torch.from_numpy(arr[:, :-1].astype(np.int64)) 127 | y = torch.from_numpy(arr[:, 1:].astype(np.int64)) 128 | x, y = x.pin_memory().to("cuda", non_blocking=True), y.pin_memory().to( 129 | "cuda", non_blocking=True 130 | ) 131 | return x, y 132 | 133 | def _validate(self): 134 | self.model.eval() 135 | ddp_loss = torch.tensor(0.0).to(self.device) 136 | ddp_samples = torch.tensor(0).to(self.device) 137 | samples_per_gpu = self.validate_samples // self.ngpus 138 | with torch.no_grad(): 139 | offset_row = self.gpu_id * samples_per_gpu 140 | for i in range(samples_per_gpu // self.config.data.test.test_batch): 141 | input_ids, labels = self.get_batch( 142 | split="val", offset_row=offset_row + ddp_samples.item() 143 | ) 144 | with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): 145 | loss = self.model( 146 | input_ids=input_ids, 147 | labels=labels, 148 | ) 149 | if i % 100 == 0 and self.gpu_id in [-1, 0]: 150 | print("the loss at batch {} is {}".format(i, loss)) 151 | ddp_loss += loss.item() * input_ids.shape[0] 152 | ddp_samples += input_ids.shape[0] 153 | print("The samples on rank {} is {}".format(self.gpu_id, ddp_samples)) 154 | dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) 155 | dist.all_reduce(ddp_samples, op=dist.ReduceOp.SUM) 156 | var_loss = (ddp_loss / ddp_samples).item() 157 | var_ppl = math.exp(var_loss) 158 | return var_loss, var_ppl 159 | 160 | def _train(self, resume_batch, max_step, 
offset_row=-1): 161 | if resume_batch >= max_step: 162 | return 163 | train_iterator = tqdm( 164 | range(resume_batch, max_step), 165 | desc="Steps", 166 | disable=self.gpu_id not in [-1, 0], 167 | ) 168 | samples_per_gpu = self.global_batch_size // self.ngpus 169 | self.model.train() 170 | self.optimizer.zero_grad() 171 | train_loss = 0.0 172 | train_samples = 0 173 | if offset_row == -1: 174 | offset_row = resume_batch * self.global_batch_size 175 | offset_row += self.gpu_id * samples_per_gpu 176 | for i in train_iterator: 177 | torch.cuda.synchronize() 178 | t0 = time.time() 179 | train_loss = 0.0 180 | train_samples = 0 181 | for micro_step in range(self.gradient_accumulation_steps): 182 | input_ids, labels = self.get_batch( 183 | split="train", offset_row=offset_row + train_samples 184 | ) 185 | with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): 186 | loss = self.model( 187 | input_ids=input_ids, 188 | labels=labels, 189 | ) 190 | train_samples += self.config.data.train.train_batch 191 | train_loss += loss.item() * self.config.data.train.train_batch 192 | loss = loss / self.gradient_accumulation_steps 193 | loss.backward() 194 | # finish the step 195 | if self.special_training: 196 | self.model.apply(lambda module: update_ratio(module=module)) 197 | self.set_gradient_clipping() 198 | self.optimizer.step() 199 | self.lr_scheduler.step() 200 | self.optimizer.zero_grad() 201 | torch.cuda.synchronize() 202 | t2 = time.time() 203 | self.step += 1 204 | offset_row += self.global_batch_size 205 | if self.gpu_id in [-1, 0] and (self.step + 1) % self.log_interval == 0: 206 | # test_loss, test_ppl = self._test() 207 | # self.model.train() 208 | self.metric.update( 209 | { 210 | "train_loss": train_loss / train_samples, 211 | "train_ppl": math.exp(train_loss / train_samples), 212 | "step": self.step, 213 | "lr": self.optimizer.param_groups[0]["lr"], 214 | "fwd+bwd": (t2 - t0), 215 | } 216 | ) 217 | if self.config.wandb_use: 218 | self.wandblog.record(self.metric) 219 | else: 220 | print(self.metric) 221 | 222 | self.save_checkpoint(**{"resume_batch": i + 1}) 223 | 224 | def train(self): 225 | self.set_seed(self.config.optimization.seed) 226 | print("***** Running training *****") 227 | num_examples = self.max_step * self.global_batch_size 228 | print("Num Examples = {}".format(num_examples)) 229 | # Note that epoch would always be zero here 230 | print("Num Tokens = {}".format(num_examples * self.block_size)) 231 | print("Num Steps = {}".format(self.max_step)) 232 | print("Global batch size = {}".format(self.global_batch_size)) 233 | print( 234 | "Gradient Accumulation steps = {}".format(self.gradient_accumulation_steps) 235 | ) 236 | resume_batch = self.resume_kwargs.get("resume_batch", 0) # next one 237 | print("resume from batch {}".format(resume_batch)) 238 | # train guided steps 239 | self._train(resume_batch, self.guided_steps, offset_row=-1) 240 | self.close_self_guided_training() 241 | self._train( 242 | max(self.guided_steps, resume_batch), 243 | self.max_step - self.repeat_steps, 244 | offset_row=-1, 245 | ) 246 | self._train( 247 | max(self.max_step - self.repeat_steps, resume_batch), 248 | self.max_step, 249 | offset_row=max(0, resume_batch + self.repeat_steps - self.max_step), 250 | ) 251 | 252 | 253 | @hydra.main( 254 | version_base=None, 255 | config_path="../configs", 256 | config_name="refinedweb_config", 257 | ) 258 | def main(config): 259 | OmegaConf.register_new_resolver("eval", eval) 260 | config.base_dir = os.path.join( 261 | config.base_dir, config.data.name 
+ "_" + config.model.name 262 | ) 263 | config.wandb.dir = config.base_dir 264 | config.wandb.dir = os.path.join(config.base_dir, config.method.name) 265 | gpu_id = int(os.getenv("RANK", -1)) 266 | if gpu_id in [-1, 0] and not os.path.exists(config.wandb.dir): 267 | os.makedirs(config.wandb.dir) 268 | 269 | if gpu_id in [-1, 0] and config.wandb_use: 270 | wandb.init( 271 | config=OmegaConf.to_container(config, resolve=True, throw_on_missing=True), 272 | entity=config.wandb.entity, 273 | project=config.wandb.project, 274 | resume=None if config.optimization.load_checkpoint else "allow", 275 | anonymous=config.wandb.anonymous, 276 | mode=config.wandb.mode, 277 | dir=config.wandb.dir, 278 | ) 279 | if gpu_id != -1: 280 | dist.barrier() 281 | model = RefinedWebGPT(config) 282 | model.train() 283 | 284 | if gpu_id != -1: 285 | dist.barrier() 286 | print("Finish Training!") 287 | print("Begin to validate!") 288 | var_loss, var_ppl = model._validate() 289 | print("The var loss is {:.4f} and var ppl is {:.4f}".format(var_loss, var_ppl)) 290 | if gpu_id in [-1, 0]: 291 | if config.wandb_use: 292 | wandb.finish() 293 | return var_loss, var_ppl 294 | 295 | 296 | if __name__ == "__main__": 297 | gpu_id = int(os.getenv("RANK", -1)) 298 | world_size = int(os.getenv("WORLD_SIZE", 1)) 299 | if gpu_id != -1: 300 | torch.cuda.set_device(gpu_id) 301 | dist.init_process_group( 302 | backend="nccl", world_size=world_size, rank=gpu_id, init_method="env://" 303 | ) 304 | 305 | main() 306 | -------------------------------------------------------------------------------- /src/benchmark_eff/bench_kernel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import triton 4 | import torch 5 | 6 | project_root = os.path.dirname( 7 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 8 | ) 9 | 10 | print(project_root) 11 | sys.path.insert(0, project_root) 12 | 13 | from src.modules.op import ( 14 | block_shuffle_bmm, 15 | block_shuffle_einsum, 16 | block_shuffle_custom, 17 | block_dense_bmm, 18 | block_dense_custom, 19 | low_rank_custom, 20 | ) 21 | 22 | 23 | @triton.testing.perf_report( 24 | triton.testing.Benchmark( 25 | x_names=["bs", "blocks", "in_blksz", "out_blksz"], 26 | x_vals=[ 27 | (16 * 1024, 4, 512, 512), 28 | (16 * 512, 16, 512, 512 * 4), 29 | (32 * 1024, 4, 1024, 1024), 30 | (32 * 1024, 2, 4096, 4096 * 4), 31 | (64 * 1024, 16, 256, 256 * 4), 32 | ], 33 | line_arg="provider", 34 | line_vals=["einsum", "bmm", "custom"], 35 | line_names=["einsum", "bmm", "custom"], 36 | styles=[("blue", "-"), ("green", "-"), ("green", "--")], 37 | ylabel="latency (ms)", 38 | plot_name="blockshuffle-performance", 39 | args={"torch_dtype": torch.float16}, 40 | ) 41 | ) 42 | def benchmark_blockshuffle(bs, blocks, in_blksz, out_blksz, torch_dtype, provider): 43 | input = torch.randn(bs, blocks * in_blksz, device="cuda", dtype=torch_dtype) * 0.02 44 | if in_blksz < out_blksz: 45 | w1 = ( 46 | torch.randn(blocks, in_blksz, in_blksz, device="cuda", dtype=torch_dtype) 47 | * 0.02 48 | ) 49 | w2 = ( 50 | torch.randn(blocks, out_blksz, in_blksz, device="cuda", dtype=torch_dtype) 51 | * 0.02 52 | ) 53 | else: 54 | w1 = ( 55 | torch.randn(blocks, out_blksz, in_blksz, device="cuda", dtype=torch_dtype) 56 | * 0.02 57 | ) 58 | w2 = ( 59 | torch.randn(blocks, out_blksz, out_blksz, device="cuda", dtype=torch_dtype) 60 | * 0.02 61 | ) 62 | quantiles = [0.5, 0.2, 0.8] 63 | if provider == "einsum": 64 | ms, min_ms, max_ms = triton.testing.do_bench( 65 | lambda: 
block_shuffle_einsum(input, w1, w2), quantiles=quantiles 66 | ) 67 | if provider == "bmm": 68 | ms, min_ms, max_ms = triton.testing.do_bench( 69 | lambda: block_shuffle_bmm(input, w1, w2), quantiles=quantiles 70 | ) 71 | if provider == "custom": 72 | ms, min_ms, max_ms = triton.testing.do_bench( 73 | lambda: block_shuffle_custom(input, w1, w2), quantiles=quantiles 74 | ) 75 | return ms, max_ms, min_ms 76 | 77 | 78 | @triton.testing.perf_report( 79 | triton.testing.Benchmark( 80 | x_names=["bs", "blocks", "in_blksz", "r_blksz", "out"], 81 | x_vals=[ 82 | (16 * 1024, 4, 512, 384, 512), 83 | (16 * 512, 16, 512, 384, 512 * 4), 84 | (32 * 1024, 4, 1024, 512, 1024), 85 | (32 * 1024, 2, 1024, 512, 4096 * 4), 86 | (64 * 1024, 16, 256, 128, 256 * 4), 87 | ], 88 | line_arg="provider", 89 | line_vals=["bmm", "custom"], 90 | line_names=["bmm", "custom"], 91 | styles=[("green", "-"), ("green", "--")], 92 | ylabel="latency (ms)", 93 | plot_name="block-linear-performance", 94 | args={"torch_dtype": torch.float16}, 95 | ) 96 | ) 97 | def benchmark_blockdense(bs, blocks, in_blksz, r_blksz, out, torch_dtype, provider): 98 | input = torch.randn(bs, in_blksz * blocks, device="cuda", dtype=torch_dtype) * 0.02 99 | w1 = ( 100 | torch.randn( 101 | blocks, 102 | r_blksz, 103 | in_blksz, 104 | device="cuda", 105 | dtype=torch_dtype, 106 | ) 107 | * 0.02 108 | ) 109 | w2 = torch.randn(out, r_blksz * blocks, device="cuda", dtype=torch_dtype) * 0.02 110 | 111 | quantiles = [0.5, 0.2, 0.8] 112 | if provider == "bmm": 113 | ms, min_ms, max_ms = triton.testing.do_bench( 114 | lambda: block_dense_bmm(input, w1, w2), quantiles=quantiles 115 | ) 116 | if provider == "custom": 117 | ms, min_ms, max_ms = triton.testing.do_bench( 118 | lambda: block_dense_custom(input, w1, w2), quantiles=quantiles 119 | ) 120 | return ms, max_ms, min_ms 121 | 122 | 123 | @triton.testing.perf_report( 124 | triton.testing.Benchmark( 125 | x_names=["bs", "seq", "d_in", "d_r", "d_out"], 126 | x_vals=[ 127 | (16, 1024, 4 * 512, 384 * 4, 512 * 4), 128 | (16, 512, 16 * 512, 384 * 16, 512 * 16), 129 | (32, 1024, 4 * 1024, 512 * 4, 1024 * 4), 130 | (32, 1024, 2 * 1024, 512 * 2, 4096 * 2), 131 | (64, 1024, 16 * 256, 128 * 16, 256 * 16), 132 | ], 133 | line_arg="provider", 134 | line_vals=["custom"], 135 | line_names=["custom"], 136 | styles=[("green", "-"), ("green", "--")], 137 | ylabel="latency (ms)", 138 | plot_name="block-linear-performance", 139 | args={"torch_dtype": torch.float16}, 140 | ) 141 | ) 142 | def benchmark_lowrank(bs, seq, d_in, d_r, d_out, torch_dtype, provider): 143 | input = torch.randn(bs, seq, d_in, device="cuda", dtype=torch_dtype) * 0.02 144 | w1 = torch.randn(d_r, d_in, device="cuda", dtype=torch_dtype) * 0.02 145 | w2 = torch.randn(d_out, d_r, device="cuda", dtype=torch_dtype) * 0.02 146 | 147 | quantiles = [0.5, 0.2, 0.8] 148 | if provider == "custom": 149 | ms, min_ms, max_ms = triton.testing.do_bench( 150 | lambda: low_rank_custom(input, w1, w2), quantiles=quantiles 151 | ) 152 | return ms, max_ms, min_ms 153 | 154 | 155 | benchmark_blockshuffle.run(show_plots=True, print_data=True) 156 | benchmark_blockdense.run(show_plots=True, print_data=True) 157 | benchmark_lowrank.run(show_plots=True, print_data=True) 158 | -------------------------------------------------------------------------------- /src/benchmark_eff/benchmark_mlp_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch 5 | import random 6 | 
import os 7 | import hydra 8 | import triton 9 | import time 10 | import sys 11 | from hide_warnings import hide_warnings 12 | 13 | project_root = os.path.dirname( 14 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 15 | ) 16 | 17 | print(project_root) 18 | sys.path.insert(0, project_root) 19 | 20 | from src.modules.mlp import ( 21 | FusedBlockDenseMLP, 22 | FusedLowRankMLP, 23 | FusedBlockShuffleMLP, 24 | FusedMLP, 25 | ) 26 | 27 | # pure bfloat16 efficiency 28 | 29 | name_to_method = { 30 | "lowrank": FusedLowRankMLP, 31 | "blockdense": FusedBlockDenseMLP, 32 | "blockshuffle": FusedBlockShuffleMLP, 33 | "linear": FusedMLP, 34 | } 35 | from omegaconf import DictConfig, OmegaConf 36 | 37 | torch_dtype = torch.bfloat16 38 | 39 | 40 | def set_seed(seed): 41 | random.seed(seed) 42 | os.environ["PYTHONHASHSEED"] = str(seed) 43 | np.random.seed(seed) 44 | torch.manual_seed(seed) 45 | torch.cuda.manual_seed(seed) 46 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 47 | torch.backends.cudnn.benchmark = False 48 | torch.backends.cudnn.deterministic = True 49 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 50 | 51 | 52 | def benchmark_train(net, inp): 53 | def fn(input): 54 | net(input) 55 | 56 | quantiles = [0.5, 0.2, 0.8] 57 | 58 | t, min_ms, max_ms = triton.testing.do_bench( 59 | lambda: fn(inp), quantiles=quantiles, warmup=50, rep=200 60 | ) 61 | latency = t 62 | throughput = inp.shape[0] * inp.shape[1] / latency * 10**3 63 | print("Latency (ms): {}, Throughput (token/s): {}".format(latency, throughput)) 64 | return latency, throughput 65 | 66 | 67 | @hide_warnings(out=False) 68 | @hydra.main( 69 | version_base=None, 70 | config_path="../configs", 71 | config_name="refinedweb_config", 72 | ) 73 | def main(config): 74 | OmegaConf.register_new_resolver("eval", eval) 75 | config_model = config.model 76 | config_method = config.method 77 | f = open("../../exp/logs/fig3.log", "a+") 78 | 79 | if config_method.name == "linear": 80 | model = ( 81 | FusedMLP( 82 | config_model.kwargs.hidden_dim, 83 | config_model.kwargs.ffn_dim, 84 | config_model.kwargs.bias, 85 | config_model.kwargs.act, 86 | ) 87 | .cuda() 88 | .to(torch_dtype) 89 | ) 90 | else: 91 | model = ( 92 | name_to_method[config.method.name.lower()]( 93 | config_model.kwargs.hidden_dim, 94 | config_model.kwargs.ffn_dim, 95 | config_model.kwargs.bias, 96 | config_model.kwargs.act, 97 | config_method.kwargs, 98 | config_model.kwargs.init, 99 | device="cuda", 100 | ) 101 | .cuda() 102 | .to(torch_dtype) 103 | ) 104 | model.eval() 105 | with torch.no_grad(): 106 | input = ( 107 | torch.randn( 108 | config.data.test.test_batch, 1024, config_model.kwargs.hidden_dim 109 | ) 110 | .cuda() 111 | .to(torch_dtype) 112 | ) 113 | latency, throughput = benchmark_train(model, input) 114 | if config_method.name == "linear": 115 | print( 116 | f"{config_model.kwargs.hidden_dim}, {config_model.kwargs.ffn_dim}", file=f 117 | ) 118 | else: 119 | print( 120 | f"{config_model.kwargs.hidden_dim}, {config_model.kwargs.ffn_dim}, {model.get_ckpt_name(config_method.kwargs)}", 121 | file=f, 122 | ) 123 | print( 124 | f"latency: {latency}, throughput: {throughput}, bs: {config.data.test.test_batch}, params: {sum(p.numel() for p in model.parameters())}", 125 | file=f, 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | set_seed(1005) 131 | main() 132 | print("******END*******") 133 | -------------------------------------------------------------------------------- /src/benchmark_eff/benchmark_model_infer.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch 5 | import random 6 | import os 7 | import hydra 8 | import time 9 | import sys 10 | 11 | project_root = os.path.dirname( 12 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 13 | ) 14 | 15 | print(project_root) 16 | sys.path.insert(0, project_root) 17 | from hide_warnings import hide_warnings 18 | from src.modules import get_model 19 | from omegaconf import DictConfig, OmegaConf 20 | 21 | torch_dtype = torch.bfloat16 22 | prefill = 0 23 | 24 | 25 | def set_seed(seed): 26 | random.seed(seed) 27 | os.environ["PYTHONHASHSEED"] = str(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | torch.cuda.manual_seed(seed) 31 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 32 | torch.backends.cudnn.benchmark = False 33 | torch.backends.cudnn.deterministic = True 34 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 35 | 36 | 37 | @torch.no_grad() 38 | def benchmark_infer(net, generation, inp=None): 39 | net.eval() 40 | net.to(torch_dtype) 41 | seq_len = inp.shape[1] 42 | bs = inp.shape[0] 43 | inference_params = net.prepare_inference_params( 44 | bs, 45 | mx_seq=seq_len + generation, 46 | ) 47 | tokens = bs * (seq_len + generation) 48 | repeat = 10 49 | warmup = 10 50 | for i in range(warmup): 51 | if seq_len > 0: 52 | inference_params.sequence_len_offset = 0 53 | net(inp, inference_params=inference_params) 54 | cur = torch.zeros(bs, 1).long().cuda() 55 | for j in range(seq_len, seq_len + generation): 56 | inference_params.sequence_len_offset = j 57 | net(input_ids=cur, inference_params=inference_params, use_cache=True) 58 | torch.cuda.synchronize() 59 | t0 = time.time() 60 | for i in range(repeat): 61 | if seq_len > 0: 62 | inference_params.sequence_len_offset = 0 63 | net(inp, inference_params=inference_params) 64 | cur = torch.zeros(bs, 1).long().cuda() 65 | for j in range(seq_len, seq_len + generation): 66 | inference_params.sequence_len_offset = j 67 | net(input_ids=cur, inference_params=inference_params, use_cache=True) 68 | torch.cuda.synchronize() 69 | t1 = time.time() 70 | latency = (t1 - t0) / repeat * (10**3) 71 | throughput = tokens * repeat / (t1 - t0) 72 | print("Latency (ms): {}, Throughput (token/s): {}".format(latency, throughput)) 73 | return latency, throughput 74 | 75 | 76 | @hide_warnings(out=False) 77 | @hydra.main( 78 | version_base=None, 79 | config_path="../configs", 80 | config_name="refinedweb_config", 81 | ) 82 | def main(config): 83 | OmegaConf.register_new_resolver("eval", eval) 84 | model = get_model(config) 85 | model.eval() 86 | f = open("../../exp/logs/arch_infer_latency.log", "a+") 87 | print( 88 | "h-f-a-nq-nkv", 89 | config.model.kwargs.hidden_dim, 90 | config.model.kwargs.ffn_dim, 91 | config.model.kwargs.attn_dim, 92 | config.model.kwargs.num_q_heads, 93 | config.model.kwargs.num_kv_heads, 94 | file=f, 95 | ) 96 | if config.method.name != "linear": 97 | print(config.method.kwargs, file=f) 98 | input_ids = torch.zeros(config.data.test.test_batch, prefill).long().cuda() 99 | latency, throughput = benchmark_infer( 100 | model, config.model.kwargs.max_position_embeddings, input_ids 101 | ) 102 | params = sum(p.numel() for p in model.parameters()) 103 | params_woemb = params - 32000 * config.model.kwargs.hidden_dim 104 | 105 | print( 106 | f"bs: {config.data.test.test_batch}, generation: {config.model.kwargs.max_position_embeddings}, latency: {latency}, params: 
{params}, params_woemb: {params_woemb}, throughput: {throughput}", 107 | file=f, 108 | ) 109 | 110 | 111 | if __name__ == "__main__": 112 | set_seed(1005) 113 | main() 114 | print("******END*******") 115 | -------------------------------------------------------------------------------- /src/benchmark_eff/benchmark_model_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch 5 | import random 6 | import os 7 | import hydra 8 | import time 9 | import sys 10 | from hide_warnings import hide_warnings 11 | 12 | project_root = os.path.dirname( 13 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 14 | ) 15 | 16 | print(project_root) 17 | sys.path.insert(0, project_root) 18 | 19 | from src.modules import get_model 20 | from omegaconf import DictConfig, OmegaConf 21 | 22 | 23 | torch_dtype = torch.bfloat16 24 | 25 | 26 | def set_seed(seed): 27 | random.seed(seed) 28 | os.environ["PYTHONHASHSEED"] = str(seed) 29 | np.random.seed(seed) 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed(seed) 32 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 33 | torch.backends.cudnn.benchmark = False 34 | torch.backends.cudnn.deterministic = True 35 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 36 | 37 | 38 | def benchmark_train(net, inp): 39 | tokens = inp.shape[0] * inp.shape[1] 40 | repeat = 50 41 | warmup = 10 42 | for i in range(warmup): 43 | net(inp) 44 | torch.cuda.synchronize() 45 | t0 = time.time() 46 | for i in range(repeat): 47 | net(inp) 48 | torch.cuda.synchronize() 49 | t1 = time.time() 50 | latency = (t1 - t0) / repeat * (10**3) 51 | throughput = tokens * repeat / (t1 - t0) 52 | print("Latency (ms): {}, Throughput (token/s): {}".format(latency, throughput)) 53 | return latency, throughput 54 | 55 | 56 | @hide_warnings(out=False) 57 | @hydra.main( 58 | version_base=None, 59 | config_path="../configs", 60 | config_name="refinedweb_config", 61 | ) 62 | def main(config): 63 | OmegaConf.register_new_resolver("eval", eval) 64 | model = get_model(config) 65 | model.eval() 66 | f = open("../../exp/logs/arch_train_latency.log", "a+") 67 | print( 68 | "h-f-a-nq-nkv", 69 | config.model.kwargs.hidden_dim, 70 | config.model.kwargs.ffn_dim, 71 | config.model.kwargs.attn_dim, 72 | config.model.kwargs.num_q_heads, 73 | config.model.kwargs.num_kv_heads, 74 | file=f, 75 | ) 76 | if config.method.name != "linear": 77 | print(config.method.kwargs, file=f) 78 | with torch.no_grad(): 79 | input_ids = torch.zeros(config.data.test.test_batch, 1024).long().cuda() 80 | model.to(torch_dtype) 81 | latency, throughput = benchmark_train(model, input_ids) 82 | params = sum(p.numel() for p in model.parameters()) 83 | params_woemb = params - 32000 * config.model.kwargs.hidden_dim 84 | print( 85 | f"bs: {config.data.test.test_batch}, latency: {latency}, params: {params}, params_woemb: {params_woemb}, throughput: {throughput}", 86 | file=f, 87 | ) 88 | 89 | 90 | if __name__ == "__main__": 91 | set_seed(1005) 92 | main() 93 | print("******END*******") 94 | -------------------------------------------------------------------------------- /src/benchmark_eff/cac_batch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch 5 | import random 6 | import os 7 | import hydra 8 | import time 9 | import sys 10 | 11 | project_root = os.path.dirname( 12 | 
os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 13 | ) 14 | 15 | print(project_root) 16 | sys.path.insert(0, project_root) 17 | 18 | from src.modules import get_model 19 | from omegaconf import DictConfig, OmegaConf 20 | from hide_warnings import hide_warnings 21 | 22 | torch_dtype = torch.bfloat16 23 | prefill = 0 24 | 25 | 26 | def set_seed(seed): 27 | random.seed(seed) 28 | os.environ["PYTHONHASHSEED"] = str(seed) 29 | np.random.seed(seed) 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed(seed) 32 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 33 | torch.backends.cudnn.benchmark = False 34 | torch.backends.cudnn.deterministic = True 35 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 36 | 37 | 38 | def func(batch_size, generation, model): 39 | try: 40 | torch.cuda.empty_cache() 41 | cur = torch.zeros(batch_size, 1).long().cuda() # test the last token directly 42 | 43 | inference_params = model.prepare_inference_params( 44 | batch_size, 45 | mx_seq=prefill + generation, 46 | ) 47 | inference_params.sequence_len_offset = prefill + generation - 1 48 | model(input_ids=cur, inference_params=inference_params, use_cache=True) 49 | except RuntimeError as e: 50 | return None 51 | return batch_size 52 | 53 | 54 | @torch.no_grad() 55 | def find_max_batch_size(model, generation): 56 | start = 256 57 | batch_size = start 58 | max_batch_size = start 59 | step = 256 60 | while True: 61 | if func(batch_size, generation, model): 62 | max_batch_size = batch_size 63 | batch_size += step 64 | else: 65 | break 66 | print(f"bs: {max_batch_size}") 67 | return max_batch_size 68 | 69 | 70 | @hide_warnings(out=False) 71 | @hydra.main( 72 | version_base=None, 73 | config_path="../configs", 74 | config_name="refinedweb_config", 75 | ) 76 | def main(config): 77 | OmegaConf.register_new_resolver("eval", eval) 78 | model = get_model(config) 79 | model.eval() 80 | model.to(torch_dtype) 81 | f = open("../../exp/logs/arch_bs.log", "a+") 82 | print( 83 | "h-f-a-nq-nkv", 84 | config.model.kwargs.hidden_dim, 85 | config.model.kwargs.ffn_dim, 86 | config.model.kwargs.attn_dim, 87 | config.model.kwargs.num_q_heads, 88 | config.model.kwargs.num_kv_heads, 89 | file=f, 90 | ) 91 | if config.method.name != "linear": 92 | print(config.method.kwargs, file=f) 93 | bs = find_max_batch_size(model, config.model.kwargs.max_position_embeddings) 94 | print( 95 | f"bs: {bs}, generation: {config.model.kwargs.max_position_embeddings}", 96 | file=f, 97 | ) 98 | 99 | 100 | if __name__ == "__main__": 101 | set_seed(1005) 102 | main() 103 | print("******END*******") 104 | -------------------------------------------------------------------------------- /src/configs/data/Aug/mixup.yaml: -------------------------------------------------------------------------------- 1 | mixup: 2 | name: mixup 3 | kwargs: 4 | alpha: 0.2 # 0.2 or 0.4 5 | -------------------------------------------------------------------------------- /src/configs/data/Aug/randomaugment.yaml: -------------------------------------------------------------------------------- 1 | randomaugment: 2 | name: RandomAugment 3 | kwargs: 4 | n: 2 5 | m: 14 6 | -------------------------------------------------------------------------------- /src/configs/data/cifar10.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - Aug: 3 | - randomaugment 4 | - mixup 5 | 6 | name: cifar10 7 | train: 8 | path: "/claire-rcp-scratch/shared/xwei/dataset" 9 | train_batch: 512 10 | # sweep: 11 | # values: [32, 64, 128, 
256] 12 | test: 13 | path: "/claire-rcp-scratch/shared/xwei/dataset" 14 | test_batch: 512 -------------------------------------------------------------------------------- /src/configs/data/refinedweb.yaml: -------------------------------------------------------------------------------- 1 | name: refinedweb 2 | train: 3 | path: "/claire-rcp-scratch/shared/xwei/dataset/refinedweb" 4 | train_batch: 16 5 | test: 6 | path: "/claire-rcp-scratch/shared/xwei/dataset/refinedweb" 7 | test_batch: 32 8 | overwrite_cache: false 9 | num_workers: 16 10 | block_size: 1024 11 | tokenizer: 12 | name: null 13 | model_max_length: 1024 -------------------------------------------------------------------------------- /src/configs/method/blockdense.yaml: -------------------------------------------------------------------------------- 1 | name: blockdense 2 | kwargs: 3 | first_layer: true 4 | nblocks: 4 5 | rank: 256 6 | training: 7 | enabled: false 8 | init: 9 | post_init: ortho 10 | -------------------------------------------------------------------------------- /src/configs/method/blockshuffle.yaml: -------------------------------------------------------------------------------- 1 | name: blockshuffle 2 | kwargs: 3 | first_layer: true 4 | nblocks: 2 5 | training: 6 | enabled: false 7 | init: 8 | post_init: ortho 9 | -------------------------------------------------------------------------------- /src/configs/method/linear.yaml: -------------------------------------------------------------------------------- 1 | name: linear 2 | -------------------------------------------------------------------------------- /src/configs/method/lowrank.yaml: -------------------------------------------------------------------------------- 1 | name: lowrank 2 | kwargs: 3 | first_layer: true 4 | rank: 256 5 | training: 6 | enabled: false 7 | init: 8 | post_init: svd 9 | -------------------------------------------------------------------------------- /src/configs/model/gpt2.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2 2 | kwargs: 3 | model_type: llama 4 | bos_token_id: 1 5 | eos_token_id: 2 6 | hidden_dim: 768 7 | attn_dim: 768 8 | ffn_dim: 3072 9 | num_q_heads: 12 10 | num_kv_heads: 12 11 | num_layers: 12 12 | hidden_drop: 0.0 13 | embd_drop: 0.0 14 | max_position_embeddings: 1024 15 | vocab_size: 32000 16 | tie_word_embeddings: true 17 | ln: layernorm 18 | act: gelu 19 | bias: true 20 | scale_attn_by_inverse_layer_idx: false 21 | pos_emb: 22 | name: rope 23 | rotary_interleaved: false 24 | seq_len_interpolation_factor: null 25 | rotary_base: 10000 26 | init: 27 | weight_init: fixed 28 | initializer_range: 0.02 29 | -------------------------------------------------------------------------------- /src/configs/model/gpt2l.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2 # to indicate Model Function 2 | kwargs: 3 | model_type: llama # to indicate tokenizer 4 | bos_token_id: 1 5 | eos_token_id: 2 6 | hidden_dim: 1536 7 | attn_dim: 1536 8 | ffn_dim: 6144 9 | num_q_heads: 12 10 | num_kv_heads: 12 11 | num_layers: 24 12 | hidden_drop: 0.0 13 | embd_drop: 0.0 14 | max_position_embeddings: 1024 15 | vocab_size: 32000 16 | tie_word_embeddings: true 17 | ln: layernorm 18 | act: gelu 19 | bias: true 20 | scale_attn_by_inverse_layer_idx: false 21 | pos_emb: 22 | name: rope 23 | rotary_interleaved: false 24 | seq_len_interpolation_factor: null 25 | rotary_base: 10000 26 | init: 27 | weight_init: fixed 28 | initializer_range: 0.02 29 | 
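# Usage sketch (illustrative, not part of the original config): gpt2l is a Hydra config
# group option, so with the defaults declared in refinedweb_config.yaml it should be
# selectable as an override from any of the @hydra.main entry points, e.g. (assuming the
# repo root as working directory; the exact launch command is an assumption):
#   python src/benchmark_eff/benchmark_model_train.py model=gpt2l method=lowrank method.kwargs.rank=512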
-------------------------------------------------------------------------------- /src/configs/model/gpt2m.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2 2 | kwargs: 3 | model_type: llama 4 | bos_token_id: 1 5 | eos_token_id: 2 6 | hidden_dim: 1024 7 | attn_dim: 1024 8 | ffn_dim: 4096 9 | num_q_heads: 16 10 | num_kv_heads: 16 11 | num_layers: 24 12 | hidden_drop: 0.0 13 | embd_drop: 0.0 14 | max_position_embeddings: 1024 15 | vocab_size: 32000 16 | tie_word_embeddings: true 17 | ln: layernorm 18 | act: gelu 19 | bias: true 20 | scale_attn_by_inverse_layer_idx: false 21 | pos_emb: 22 | name: rope 23 | rotary_interleaved: false 24 | seq_len_interpolation_factor: null 25 | rotary_base: 10000 26 | init: 27 | weight_init: fixed 28 | initializer_range: 0.02 29 | -------------------------------------------------------------------------------- /src/configs/model/gpt2xl.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2 2 | kwargs: 3 | model_type: llama 4 | bos_token_id: 1 5 | eos_token_id: 2 6 | hidden_dim: 2048 7 | attn_dim: 2048 8 | ffn_dim: 8192 9 | num_q_heads: 16 10 | num_kv_heads: 16 11 | num_layers: 24 12 | hidden_drop: 0.0 13 | embd_drop: 0.0 14 | max_position_embeddings: 1024 15 | vocab_size: 32000 16 | tie_word_embeddings: true 17 | ln: layernorm 18 | act: gelu 19 | bias: true 20 | scale_attn_by_inverse_layer_idx: false 21 | pos_emb: 22 | name: rope 23 | rotary_interleaved: false 24 | seq_len_interpolation_factor: null 25 | rotary_base: 10000 26 | init: 27 | weight_init: fixed 28 | initializer_range: 0.02 29 | -------------------------------------------------------------------------------- /src/configs/optimization/basic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - lr_scheduler: cosineannealinglr 3 | - optimizer: adamw 4 | - training: regular_training 5 | 6 | max_epoch: 1000 7 | device: cuda 8 | seed: 1005 9 | load_checkpoint: false 10 | save_checkpoint: false 11 | save_dir: /home/xwei/transformers/final_version/exp/ckpt/ 12 | load_save_mode: epoch 13 | check_gradient_norm: false 14 | check_weight_norm: false 15 | gradient_clipping: false -------------------------------------------------------------------------------- /src/configs/optimization/lr_scheduler/cosineannealinglr.yaml: -------------------------------------------------------------------------------- 1 | name: CosineAnnealingLR 2 | kwargs: 3 | warmup_iter: 0.02 4 | T_max: 1000 5 | eta_min: 1.0e-7 -------------------------------------------------------------------------------- /src/configs/optimization/lr_scheduler/multisteplr.yaml: -------------------------------------------------------------------------------- 1 | name: MultiStepLR 2 | kwargs: 3 | milestones: [0.3, 0.6, 0.8] 4 | gamma: 0.1 5 | T_max: 1000 6 | -------------------------------------------------------------------------------- /src/configs/optimization/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | name: adam 2 | kwargs: 3 | lr: 1.0e-4 4 | -------------------------------------------------------------------------------- /src/configs/optimization/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | name: adamw 2 | kwargs: 3 | lr: 1.0e-4 4 | weight_decay: 0.01 -------------------------------------------------------------------------------- /src/configs/optimization/optimizer/sgd.yaml: 
-------------------------------------------------------------------------------- 1 | name: sgd 2 | kwargs: 3 | lr: 0.1 4 | momentum: 0.9 5 | weight_decay: 5.0e-4 6 | -------------------------------------------------------------------------------- /src/configs/optimization/training/regular_training.yaml: -------------------------------------------------------------------------------- 1 | name: regular_training 2 | -------------------------------------------------------------------------------- /src/configs/optimization/training/self_guided_training.yaml: -------------------------------------------------------------------------------- 1 | name: self_guided_training 2 | kwargs: 3 | mode: fixedstep # fixedstep, fixedflops 4 | scheduler: cosine 5 | reduce_flop: false 6 | max_step: null 7 | max_step_ratio: 0.5 # ratio of the total steps -------------------------------------------------------------------------------- /src/configs/refinedweb_config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - method: &method linear 3 | - optimization: &optimization basic 4 | - data: &data refinedweb 5 | - model: &model gpt2 6 | - _self_ 7 | 8 | # rewrite optimization cifar here 9 | optimization: 10 | max_tokens: 2200000000 11 | global_batch_size: 512 12 | gradient_checkpointing: false 13 | gradient_clipping: 1.0 14 | log_interval: 20 15 | load_save_mode: step 16 | load_checkpoint: true 17 | save_checkpoint: true 18 | optimizer: 19 | kwargs: 20 | lr: &lr 6.0e-4 21 | weight_decay: 0.1 22 | betas: [0.9, 0.999] 23 | lr_scheduler: 24 | kwargs: 25 | warmup_iter: 0.1 26 | eta_min: ${eval:0.1 * ${optimization.optimizer.kwargs.lr}} 27 | 28 | data: 29 | tokenizer: 30 | name: ${model.kwargs.model_type} 31 | model_max_length: ${model.kwargs.max_position_embeddings} 32 | block_size: ${model.kwargs.max_position_embeddings} 33 | 34 | base_dir: &base_dir /home/xwei/transformers/final_version/exp/ 35 | 36 | wandb: 37 | entity: xiuying-wei 38 | project: gpt2reproduce 39 | mode: online 40 | anonymous: allow 41 | dir: *base_dir 42 | 43 | wandb_use: true 44 | hydra: 45 | run: 46 | dir: /home/xwei/transformers/final_version/exp/hydra/${now:%Y-%m-%d}/${now:%H-%M-%S} 47 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from easydict import EasyDict 3 | from .model import GPT2LMHeadModel 4 | from .mlp import ( 5 | FusedBlockDenseMLP, 6 | FusedLowRankMLP, 7 | FusedBlockShuffleMLP, 8 | FusedMLP, 9 | ) 10 | 11 | name_to_model = { 12 | "gpt2": GPT2LMHeadModel, 13 | } 14 | 15 | name_to_method = { 16 | "lowrank": FusedLowRankMLP, 17 | "blockdense": FusedBlockDenseMLP, 18 | "blockshuffle": FusedBlockShuffleMLP, 19 | } 20 | 21 | 22 | def replace_mlp(model, config_method, config_model, device="cuda"): 23 | first_layer = ( 24 | config_method.kwargs.first_layer 25 | ) # true: keep the original linear layer 26 | for i in range(config_model.kwargs.num_layers): 27 | if first_layer and i == 0: 28 | continue 29 | new_module = name_to_method[config_method.name.lower()]( 30 | config_model.kwargs.hidden_dim, 31 | config_model.kwargs.ffn_dim, 32 | config_model.kwargs.bias, 33 | config_model.kwargs.act, 34 | config_method.kwargs, 35 | config_model.kwargs.init, 36 | device=device, 37 | ) 38 | del model.model.layers[i].mlp 39 | model.model.layers[i].mlp = new_module 40 | 41 | 42 | def get_model(config, device="cuda"): 43 | config_model = 
config.model 44 | config_method = config.method 45 | model = name_to_model[config_model.name.lower()](config_model.get("kwargs", {})).to( 46 | device 47 | ) 48 | 49 | # replace here 50 | if config_method.name.lower() == "linear": 51 | return model 52 | replace_mlp(model, config_method, config_model, device) 53 | model.to(device) 54 | return model 55 | 56 | 57 | def get_ckpt_name(config): 58 | config_model = config.model 59 | config_method = config.method 60 | long_name = config_model.name + name_to_model[ 61 | config_model.name.lower() 62 | ].get_ckpt_name(config_model.get("kwargs", {})) 63 | if config_method.name != "linear": 64 | long_name += ( 65 | "-" 66 | + config_method.name 67 | + name_to_method[config_method.name.lower()].get_ckpt_name( 68 | config_method.get("kwargs", {}) 69 | ) 70 | ) 71 | return long_name 72 | 73 | 74 | def update_ratio(module): 75 | if hasattr(module, "_update_ratio"): 76 | module._update_ratio() 77 | -------------------------------------------------------------------------------- /src/modules/layer/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .customlinear import CustomLinear 3 | from .lowrank import LowRank 4 | from .blockdense import BlockDense 5 | from .blockshuffle import BlockShuffle 6 | -------------------------------------------------------------------------------- /src/modules/layer/basiclinear.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | from .util import LinearTempDecay, CosineTempDecay 5 | 6 | 7 | class BasicLinear(nn.Module): 8 | 9 | def __init__( 10 | self, in_features, out_features, bias, return_bias, config, init_config, device 11 | ): 12 | super().__init__() 13 | # config: method part, and model init 14 | self.device = device 15 | self.config = config 16 | self.init_config = init_config 17 | self.training_config = self.config.training 18 | # model part 19 | self.in_features = in_features 20 | self.out_features = out_features 21 | # otherwise, we need to fuse the bias into the ops 22 | assert return_bias is True 23 | if bias: 24 | self.bias = nn.Parameter(torch.empty(self.out_features, device=device)) 25 | else: 26 | self.bias = None 27 | 28 | if self.training_config.enabled: 29 | self.guide_linear = nn.Parameter( 30 | torch.empty(self.out_features, self.in_features, device=device) 31 | ) 32 | self.register_buffer("count", torch.tensor(0).cuda(), persistent=True) 33 | self.register_buffer("ratio", torch.tensor(1.0).cuda(), persistent=True) 34 | guide_scheduler = { 35 | "linear": LinearTempDecay, 36 | "cosine": CosineTempDecay, 37 | } 38 | self.guide_scheduler = guide_scheduler[self.training_config.scheduler]( 39 | t_max=self.training_config.max_step 40 | ) 41 | 42 | @torch.no_grad() 43 | def _update_ratio( 44 | self, 45 | ): 46 | self.count += 1 47 | self.ratio = self.guide_scheduler(self.count) 48 | 49 | def _check_guide_layer( 50 | self, 51 | ): 52 | if not self.training_config.enabled: 53 | return False 54 | if ( 55 | self.training_config.reduce_flop 56 | and torch.rand_like(self.ratio) >= self.ratio 57 | ): 58 | return False 59 | return True 60 | 61 | def forward_guide_layer(self, input, out): 62 | if self._check_guide_layer(): 63 | guide_out = torch.matmul(input, self.guide_linear.transpose(-1, -2)) 64 | out = self.ratio * guide_out + (1.0 - self.ratio) * out 65 | return out, self.bias 66 | 67 | def get_weights( 68 | self, 69 | ): 70 | pass 71 | 72 | @torch.no_grad() 
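    # _init_weights covers every factor returned by get_weights(): "xavier" draws from a
    # normal with std = fan_in ** -0.5 per factor, "fixed" uses init_config.initializer_range,
    # and any bias is zero-initialised.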
73 | def _init_weights( 74 | self, 75 | ): 76 | if self.bias is not None: 77 | nn.init.zeros_(self.bias) 78 | for para in self.get_weights(): 79 | if self.init_config.weight_init == "xavier": 80 | nn.init.normal_(para, mean=0.0, std=(para.shape[-1] ** -0.5)) 81 | elif self.init_config.weight_init == "fixed": 82 | nn.init.normal_(para, std=self.init_config.initializer_range) 83 | else: 84 | raise NotImplementedError 85 | -------------------------------------------------------------------------------- /src/modules/layer/blockdense.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from ..op import block_dense_custom 4 | from .basiclinear import BasicLinear 5 | 6 | 7 | class BlockDense(BasicLinear): 8 | 9 | def __init__( 10 | self, 11 | in_features, 12 | out_features, 13 | bias, 14 | return_bias, 15 | config, 16 | init_config, 17 | device="cuda", 18 | ): 19 | super().__init__( 20 | in_features, out_features, bias, return_bias, config, init_config, device 21 | ) 22 | self.rank = config["rank"] 23 | self.nblocks = config["nblocks"] 24 | assert self.in_features % self.nblocks == 0 25 | assert self.rank % self.nblocks == 0 26 | self.blkdiag = nn.Parameter( 27 | torch.empty( 28 | self.nblocks, 29 | self.rank // self.nblocks, 30 | self.in_features // self.nblocks, 31 | device=device, 32 | ) 33 | ) 34 | self.lr = nn.Parameter(torch.empty(self.out_features, self.rank, device=device)) 35 | 36 | self._init_weights() 37 | self.post_init() 38 | 39 | def get_weights( 40 | self, 41 | ): 42 | return [self.blkdiag, self.lr] 43 | 44 | @torch.no_grad() 45 | def post_init( 46 | self, 47 | ): 48 | if self.config.init.post_init == "ortho": 49 | for i in range(self.nblocks): 50 | U, S, Vh = torch.linalg.svd(self.blkdiag.data[i], full_matrices=False) 51 | self.blkdiag.data[i] = torch.mm(U, Vh) 52 | U, S, Vh = torch.linalg.svd(self.lr.data, full_matrices=False) 53 | self.lr.data = torch.mm(U, Vh) 54 | # init guide linear 55 | if hasattr(self, "guide_linear"): 56 | self.guide_linear.data = torch.mm( 57 | self.lr.data, torch.block_diag(*torch.unbind(self.blkdiag.data, dim=0)) 58 | ) 59 | 60 | def forward(self, input): 61 | out = block_dense_custom(input, self.blkdiag, self.lr) 62 | return self.forward_guide_layer(input, out) 63 | 64 | def extra_repr(self) -> str: 65 | return f"blockdiag1={self.blkdiag.shape}, linear={self.lr.shape}, bias={self.bias is not None}, guide={self.training_config.enabled}" 66 | -------------------------------------------------------------------------------- /src/modules/layer/blockshuffle.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from .basiclinear import BasicLinear 4 | from ..op import block_shuffle_bmm, block_shuffle_custom 5 | 6 | 7 | class BlockShuffle(BasicLinear): 8 | 9 | def __init__( 10 | self, 11 | in_features, 12 | out_features, 13 | bias, 14 | return_bias, 15 | config, 16 | init_config, 17 | device="cuda", 18 | ): 19 | super().__init__( 20 | in_features, out_features, bias, return_bias, config, init_config, device 21 | ) 22 | self.nblocks = config["nblocks"] 23 | assert self.in_features % self.nblocks == 0 24 | assert self.out_features % self.nblocks == 0 25 | 26 | in_blksz = self.in_features // self.nblocks 27 | out_blksz = self.out_features // self.nblocks 28 | 29 | if self.in_features < self.out_features: 30 | self.blkdiag1 = nn.Parameter( 31 | torch.empty(self.nblocks, in_blksz, in_blksz, device=device) 32 | 
) 33 | self.blkdiag2 = nn.Parameter( 34 | torch.empty(self.nblocks, out_blksz, in_blksz, device=device) 35 | ) 36 | else: 37 | self.blkdiag1 = nn.Parameter( 38 | torch.empty(self.nblocks, out_blksz, in_blksz, device=device) 39 | ) 40 | self.blkdiag2 = nn.Parameter( 41 | torch.empty(self.nblocks, out_blksz, out_blksz, device=device) 42 | ) 43 | self._init_weights() 44 | self.post_init() 45 | 46 | def get_weights( 47 | self, 48 | ): 49 | return [self.blkdiag1, self.blkdiag2] 50 | 51 | @torch.no_grad() 52 | def post_init( 53 | self, 54 | ): 55 | if self.config.init.post_init == "ortho": 56 | for i in range(self.nblocks): 57 | U, S, Vh = torch.linalg.svd(self.blkdiag1.data[i], full_matrices=False) 58 | self.blkdiag1.data[i] = torch.mm(U, Vh) 59 | U, S, Vh = torch.linalg.svd(self.blkdiag2.data[i], full_matrices=False) 60 | self.blkdiag2.data[i] = torch.mm(U, Vh) 61 | 62 | # init guide linear 63 | if hasattr(self, "guide_linear"): 64 | self.guide_linear.data = torch.mm( 65 | torch.block_diag(*torch.unbind(self.blkdiag2.data, dim=0)), 66 | torch.block_diag(*torch.unbind(self.blkdiag1.data, dim=0)), 67 | ) 68 | 69 | def forward(self, input): 70 | out = block_shuffle_custom(input, self.blkdiag1, self.blkdiag2) 71 | return self.forward_guide_layer(input, out) 72 | 73 | def extra_repr(self) -> str: 74 | return f"blockdiag1={self.blkdiag1.shape}, blockdiag2={self.blkdiag2.shape}, bias={self.bias is not None}, guide={self.training_config.enabled}" 75 | -------------------------------------------------------------------------------- /src/modules/layer/customlinear.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | # please do not inhere the basic linear here 6 | class CustomLinear(nn.Module): 7 | 8 | def __init__(self, in_features, out_features, bias, return_bias=True): 9 | super().__init__() 10 | self.in_features = in_features 11 | self.out_features = out_features 12 | self.weight = nn.Parameter( 13 | torch.empty( 14 | out_features, 15 | in_features, 16 | ) 17 | ) 18 | # otherwise, we need to fuse the bias into the ops 19 | assert return_bias is True 20 | 21 | if bias: 22 | self.bias = nn.Parameter(torch.empty(out_features)) 23 | else: 24 | self.bias = None 25 | 26 | def forward(self, inp): 27 | output = torch.matmul(inp, self.weight.transpose(-1, -2)) 28 | return output, self.bias 29 | 30 | def extra_repr(self) -> str: 31 | return f"linearshape={self.weight.shape}, bias={self.bias is not None}" 32 | -------------------------------------------------------------------------------- /src/modules/layer/lowrank.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from .basiclinear import BasicLinear 4 | from ..op import low_rank_custom 5 | 6 | 7 | class LowRank(BasicLinear): 8 | 9 | def __init__( 10 | self, 11 | in_features, 12 | out_features, 13 | bias, 14 | return_bias, 15 | config, 16 | init_config, 17 | device="cuda", 18 | ): 19 | super().__init__( 20 | in_features, 21 | out_features, 22 | bias, 23 | return_bias, 24 | config, 25 | init_config, 26 | device=device, 27 | ) 28 | self.rank = config["rank"] 29 | self.lr1 = nn.Parameter(torch.empty(self.rank, self.in_features, device=device)) 30 | self.lr2 = nn.Parameter( 31 | torch.empty(self.out_features, self.rank, device=device) 32 | ) 33 | self._init_weights() 34 | self.post_init() 35 | 36 | def get_weights( 37 | self, 38 | ): 39 | return [self.lr1, self.lr2] 40 | 41 | @torch.no_grad() 42 | 
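    # "svd" post-init: sample a dense (out_features, in_features) weight with the configured
    # init, take its truncated SVD, and split sqrt(S) across the two factors so that
    # lr2 @ lr1 is the best rank-`rank` approximation of that dense draw; the same product
    # also seeds guide_linear when self-guided training is enabled.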
def post_init( 43 | self, 44 | ): 45 | if self.config.init.post_init == "svd": 46 | org_linear = nn.Parameter( 47 | torch.empty(self.out_features, self.in_features, device=self.device) 48 | ) 49 | if self.init_config.weight_init == "xavier": 50 | nn.init.normal_( 51 | org_linear, mean=0.0, std=(org_linear.shape[-1] ** -0.5) 52 | ) 53 | elif self.init_config.weight_init == "fixed": 54 | nn.init.normal_(org_linear, std=self.init_config.initializer_range) 55 | else: 56 | raise NotImplementedError 57 | U, S, Vh = torch.linalg.svd(org_linear, full_matrices=False) 58 | sqrt_S = torch.sqrt(torch.diag_embed(S[: self.rank])) 59 | self.lr1.data = sqrt_S @ Vh[: self.rank, :] 60 | self.lr2.data = U[:, : self.rank] @ sqrt_S 61 | 62 | # init guide linear 63 | if hasattr(self, "guide_linear"): 64 | self.guide_linear.data = torch.mm(self.lr2.data, self.lr1.data) 65 | 66 | def forward(self, input): 67 | out = low_rank_custom(input, self.lr1, self.lr2) 68 | return self.forward_guide_layer(input, out) 69 | 70 | def extra_repr(self) -> str: 71 | return f"lr1={self.lr1.shape}, lr2={self.lr2.shape}, bias={self.bias is not None}, guide={self.training_config.enabled}" 72 | -------------------------------------------------------------------------------- /src/modules/layer/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | class LinearTempDecay: 6 | def __init__(self, t_max=20000, warm_up=0, start_b=1.0, end_b=0.0): 7 | self.t_max = t_max 8 | self.warmup = warm_up 9 | self.start_b = torch.tensor(start_b).cuda() 10 | self.end_b = torch.tensor(end_b).cuda() 11 | print( 12 | "linear scheduler for self-guided training in steps {} with warmup {}".format( 13 | self.t_max, self.warmup 14 | ) 15 | ) 16 | 17 | def __call__(self, t): 18 | if t < self.warmup: 19 | return self.start_b 20 | elif t > self.t_max: 21 | return self.end_b 22 | else: 23 | rel_t = (t - self.warmup) / (self.t_max - self.warmup) 24 | return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t)) 25 | 26 | 27 | class CosineTempDecay: 28 | def __init__(self, t_max=20000, warm_up=0, start_b=1.0, end_b=0.0): 29 | self.t_max = t_max 30 | self.warmup = warm_up 31 | self.start_b = torch.tensor(start_b).cuda() 32 | self.end_b = torch.tensor(end_b).cuda() 33 | print( 34 | "Cosine scheduler for self-guided training in steps {} with warmup {}".format( 35 | self.t_max, self.warmup 36 | ) 37 | ) 38 | 39 | def __call__(self, t): 40 | if t < self.warmup: 41 | return self.start_b 42 | elif t > self.t_max: 43 | return self.end_b 44 | else: 45 | rel_t = (t - self.warmup) / (self.t_max - self.warmup) 46 | return self.end_b + 0.5 * (self.start_b - self.end_b) * ( 47 | 1 + torch.cos(rel_t * math.pi) 48 | ) 49 | -------------------------------------------------------------------------------- /src/modules/mlp/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .lowrank_mlp import FusedLowRankMLP 3 | from .blockdense_mlp import FusedBlockDenseMLP 4 | from .blockshuffle_mlp import FusedBlockShuffleMLP 5 | from .mlp import FusedMLP 6 | -------------------------------------------------------------------------------- /src/modules/mlp/basic_mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformer_engine.pytorch.jit import set_jit_fusion_options 3 | from ..op import bias_gelu_impl, bias_swiglu_impl 4 | 5 | 6 | act_func_dict = { 7 | 
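    # fused (input, bias) -> activation kernels from ..op; swiglu consumes twice as many
    # features as it produces, which is why FusedBasicMLP doubles self.ffn_dim for it below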
"gelu": bias_gelu_impl, 8 | "swiglu": bias_swiglu_impl, 9 | } 10 | 11 | 12 | class FusedBasicMLP(nn.Module): 13 | def __init__(self, hidden_dim, ffn_dim, bias, act="gelu"): 14 | super().__init__() 15 | self.hidden_dim = hidden_dim 16 | self.ffn_dim = ffn_dim 17 | # fuse bias and gelu 18 | set_jit_fusion_options() 19 | self.fc1 = None 20 | self.fc2 = None 21 | self.act_func = act_func_dict.get(act, None) 22 | if act in ["swiglu"]: 23 | self.ffn_dim *= 2 24 | 25 | def forward(self, input): 26 | fc1_outs = self.fc1(input) 27 | gelu_out = self.act_func(*fc1_outs) 28 | fc2_outs = self.fc2(gelu_out) 29 | return fc2_outs 30 | -------------------------------------------------------------------------------- /src/modules/mlp/blockdense_mlp.py: -------------------------------------------------------------------------------- 1 | from .basic_mlp import FusedBasicMLP 2 | from ..layer import BlockDense 3 | 4 | 5 | class FusedBlockDenseMLP(FusedBasicMLP): 6 | 7 | def __init__( 8 | self, 9 | hidden_dim, 10 | ffn_dim, 11 | bias, 12 | act, 13 | config, 14 | init_config, 15 | device, 16 | ): 17 | super().__init__(hidden_dim, ffn_dim, bias, act=act) 18 | self.fc1 = BlockDense( 19 | hidden_dim, 20 | self.ffn_dim, 21 | bias=bias, 22 | return_bias=True, 23 | config=config, 24 | init_config=init_config, 25 | device=device, 26 | ) 27 | self.fc2 = BlockDense( 28 | ffn_dim, 29 | hidden_dim, 30 | bias=bias, 31 | return_bias=True, 32 | config=config, 33 | init_config=init_config, 34 | device=device, 35 | ) 36 | 37 | @staticmethod 38 | def get_ckpt_name(config_method): 39 | long_name = ( 40 | "r" 41 | + str(config_method.rank) 42 | + "b" 43 | + str(config_method.nblocks) 44 | + "-" 45 | + str(config_method.init.post_init) 46 | ) 47 | return long_name 48 | -------------------------------------------------------------------------------- /src/modules/mlp/blockshuffle_mlp.py: -------------------------------------------------------------------------------- 1 | from .basic_mlp import FusedBasicMLP 2 | from ..layer import BlockShuffle 3 | 4 | 5 | class FusedBlockShuffleMLP(FusedBasicMLP): 6 | 7 | def __init__( 8 | self, 9 | hidden_dim, 10 | ffn_dim, 11 | bias, 12 | act, 13 | config, 14 | init_config, 15 | device, 16 | ): 17 | super().__init__(hidden_dim, ffn_dim, bias, act=act) 18 | self.fc1 = BlockShuffle( 19 | hidden_dim, 20 | self.ffn_dim, 21 | bias=bias, 22 | return_bias=True, 23 | config=config, 24 | init_config=init_config, 25 | device=device, 26 | ) 27 | self.fc2 = BlockShuffle( 28 | ffn_dim, 29 | hidden_dim, 30 | bias=bias, 31 | return_bias=True, 32 | config=config, 33 | init_config=init_config, 34 | device=device, 35 | ) 36 | 37 | @staticmethod 38 | def get_ckpt_name(config_method): 39 | long_name = ( 40 | "b" + str(config_method.nblocks) + "-" + str(config_method.init.post_init) 41 | ) 42 | return long_name 43 | -------------------------------------------------------------------------------- /src/modules/mlp/lowrank_mlp.py: -------------------------------------------------------------------------------- 1 | from .basic_mlp import FusedBasicMLP 2 | from ..layer import LowRank 3 | 4 | 5 | class FusedLowRankMLP(FusedBasicMLP): 6 | 7 | def __init__( 8 | self, 9 | hidden_dim, 10 | ffn_dim, 11 | bias, 12 | act, 13 | config, 14 | init_config, 15 | device, 16 | ): 17 | super().__init__(hidden_dim, ffn_dim, bias, act=act) 18 | self.fc1 = LowRank( 19 | hidden_dim, 20 | self.ffn_dim, 21 | bias=bias, 22 | return_bias=True, 23 | config=config, 24 | init_config=init_config, 25 | device=device, 26 | ) 27 | self.fc2 = LowRank( 28 | 
ffn_dim, 29 | hidden_dim, 30 | bias=bias, 31 | return_bias=True, 32 | config=config, 33 | init_config=init_config, 34 | device=device, 35 | ) 36 | 37 | @staticmethod 38 | def get_ckpt_name(config_method): 39 | long_name = ( 40 | "r" + str(config_method.rank) + "-" + str(config_method.init.post_init) 41 | ) 42 | return long_name 43 | -------------------------------------------------------------------------------- /src/modules/mlp/mlp.py: -------------------------------------------------------------------------------- 1 | from .basic_mlp import FusedBasicMLP 2 | from ..layer import CustomLinear 3 | 4 | 5 | class FusedMLP(FusedBasicMLP): 6 | def __init__(self, hidden_dim, ffn_dim, bias, act="gelu"): 7 | super().__init__(hidden_dim, ffn_dim, bias, act) 8 | self.fc1 = CustomLinear(hidden_dim, self.ffn_dim, bias=bias, return_bias=True) 9 | self.fc2 = CustomLinear(ffn_dim, hidden_dim, bias=bias, return_bias=True) 10 | -------------------------------------------------------------------------------- /src/modules/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .gpt2 import GPT2LMHeadModel 2 | -------------------------------------------------------------------------------- /src/modules/model/gpt2.py: -------------------------------------------------------------------------------- 1 | """A fast version of gpt2 with flash attention and transformer engine""" 2 | 3 | import torch.nn as nn 4 | import torch 5 | import torch.nn.functional as F 6 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 7 | import transformer_engine.pytorch as te 8 | 9 | from transformer_engine.pytorch.jit import ( 10 | set_jit_fusion_options, 11 | ) 12 | from flash_attn import flash_attn_func, flash_attn_with_kvcache 13 | from ..layer import CustomLinear 14 | from ..mlp import FusedMLP 15 | from ..op import bias_dropout_add_impl, RotaryEmbedding, apply_rotary_pos_emb 16 | 17 | 18 | layernorm_func = { 19 | "layernorm": te.LayerNorm, 20 | "rmsnorm": te.RMSNorm, 21 | } 22 | 23 | 24 | class InferenceParams: 25 | 26 | def __init__(self, max_batch_size, max_sequence_length): 27 | self.max_sequence_length = max_sequence_length 28 | self.max_batch_size = max_batch_size 29 | self.sequence_len_offset = 0 30 | self.batch_size_offset = 0 31 | self.key_value_memory_dict = {} 32 | 33 | 34 | class TransformerLayer(nn.Module): 35 | 36 | def __init__(self, config, layer_number): 37 | super().__init__() 38 | self.hidden_dim = config.hidden_dim 39 | self.attn_dim = config.attn_dim 40 | self.ffn_dim = config.ffn_dim 41 | self.num_q_heads = config.num_q_heads 42 | assert self.attn_dim % self.num_q_heads == 0 43 | self.head_dim = self.attn_dim // self.num_q_heads 44 | self.num_kv_heads = config.num_kv_heads 45 | self.layer_number = layer_number 46 | self.hidden_dropout = config.hidden_drop 47 | 48 | self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx 49 | set_jit_fusion_options() 50 | self.ln1 = layernorm_func[config.ln](self.hidden_dim) 51 | self.qkv_linear = nn.Linear( 52 | self.hidden_dim, 53 | (self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim, 54 | bias=config.bias, 55 | ) 56 | self.o_linear = CustomLinear( 57 | self.attn_dim, 58 | self.hidden_dim, 59 | bias=config.bias, 60 | return_bias=True, 61 | ) 62 | self.ln2 = layernorm_func[config.ln](self.hidden_dim) 63 | self.mlp = FusedMLP( 64 | self.hidden_dim, self.ffn_dim, bias=config.bias, act=config.act 65 | ) 66 | 67 | def _bias_dropout_add(self, hidden_state, bias, residual): 68 | 
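        # bias_dropout_add_impl returns the jit-fused train- or eval-mode kernel that adds
        # the bias, applies dropout, and adds the residual stream in one fused op
        # (see op/common/fused_bias_dropout_add.py)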
bias_dropout_add_func = bias_dropout_add_impl(self.training) 69 | output = bias_dropout_add_func( 70 | (hidden_state, bias), residual, self.hidden_dropout 71 | ) 72 | return output 73 | 74 | def _adjust_key_value_for_inference( 75 | self, inference_params, k_out, v_out, rotary_pos_emb 76 | ): 77 | if inference_params is None: 78 | return k_out, v_out, rotary_pos_emb 79 | bs = k_out.shape[0] 80 | seq_len = k_out.shape[1] 81 | 82 | inference_key_memory, inference_value_memory = ( 83 | inference_params.key_value_memory_dict[self.layer_number] 84 | ) 85 | batch_start = inference_params.batch_size_offset 86 | batch_end = batch_start + bs 87 | assert batch_end <= inference_key_memory.size(0) 88 | sequence_start = inference_params.sequence_len_offset 89 | sequence_end = sequence_start + seq_len 90 | assert sequence_end <= inference_key_memory.size(1) 91 | inference_key_memory[ 92 | batch_start:batch_end, sequence_start:sequence_end, ... 93 | ] = k_out 94 | inference_value_memory[ 95 | batch_start:batch_end, sequence_start:sequence_end, ... 96 | ] = v_out 97 | key = inference_key_memory[batch_start:batch_end, :sequence_end, ...] 98 | value = inference_value_memory[batch_start:batch_end, :sequence_end, ...] 99 | 100 | # adjust the key rotary positional embedding 101 | if rotary_pos_emb is not None: 102 | q_pos_emb, k_pos_emb = rotary_pos_emb 103 | q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :] 104 | k_pos_emb = k_pos_emb[:sequence_end, :, :, :] 105 | rotary_pos_emb = (q_pos_emb, k_pos_emb) 106 | 107 | return key, value, rotary_pos_emb 108 | 109 | def forward( 110 | self, 111 | hidden_states, 112 | inference_params: Optional[InferenceParams] = None, 113 | use_cache=False, 114 | rotary_pos_emb: torch.Tensor = None, 115 | ): 116 | hidden_states = hidden_states.contiguous() 117 | bs, seq_len, _ = hidden_states.shape 118 | qkv_out = self.qkv_linear(self.ln1(hidden_states)) 119 | q_out = qkv_out[..., : (self.num_q_heads * self.head_dim)] 120 | kv_out = qkv_out[..., (self.num_q_heads * self.head_dim) :] 121 | k_out, v_out = kv_out.chunk(2, -1) 122 | q_out = q_out.reshape(bs, seq_len, self.num_q_heads, self.head_dim) 123 | k_out = k_out.reshape(bs, seq_len, self.num_kv_heads, self.head_dim) 124 | v_out = v_out.reshape(bs, seq_len, self.num_kv_heads, self.head_dim) 125 | 126 | if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): 127 | rotary_pos_emb = (rotary_pos_emb,) * 2 128 | 129 | k_out, v_out, rotary_pos_emb = self._adjust_key_value_for_inference( 130 | inference_params, k_out, v_out, rotary_pos_emb 131 | ) 132 | if rotary_pos_emb is not None: 133 | q_pos_emb, k_pos_emb = rotary_pos_emb 134 | q_out = apply_rotary_pos_emb( 135 | q_out, 136 | q_pos_emb, 137 | ) 138 | k_out = apply_rotary_pos_emb( 139 | k_out, 140 | k_pos_emb, 141 | ) 142 | softmax_scale = q_out.shape[-1] ** (-0.5) 143 | if self.scale_attn_by_inverse_layer_idx: 144 | softmax_scale /= float(self.layer_number + 1) 145 | if not use_cache: 146 | attention_out = flash_attn_func( 147 | q_out, k_out, v_out, softmax_scale=softmax_scale, causal=True 148 | ).reshape(bs, seq_len, self.attn_dim) 149 | else: 150 | attention_out = flash_attn_with_kvcache( 151 | q_out, 152 | k_out, 153 | v_out, 154 | softmax_scale=softmax_scale, 155 | cache_seqlens=inference_params.sequence_len_offset, 156 | causal=True, 157 | ).reshape(bs, seq_len, self.attn_dim) 158 | 159 | attention_out, attention_bias = self.o_linear(attention_out) 160 | hidden_states = self._bias_dropout_add( 161 | attention_out, attention_bias, hidden_states 162 | 
) 163 | ln2_out = self.ln2(hidden_states) 164 | fc2_out, fc2_bias = self.mlp(ln2_out) 165 | hidden_states = self._bias_dropout_add(fc2_out, fc2_bias, hidden_states) 166 | return hidden_states 167 | 168 | 169 | class BasicGPT2(nn.Module): 170 | 171 | def __init__( 172 | self, 173 | ): 174 | super().__init__() 175 | 176 | @torch.no_grad() 177 | def _init_weights(self, module, init_config): 178 | """initialize the weight""" 179 | if init_config.weight_init == "fixed": 180 | initializer_range = init_config.initializer_range 181 | if isinstance(module, (nn.Linear, CustomLinear)): 182 | module.weight.data.normal_(mean=0.0, std=initializer_range) 183 | if module.bias is not None: 184 | module.bias.data.zero_() 185 | elif isinstance(module, nn.Embedding): 186 | module.weight.data.normal_(mean=0.0, std=initializer_range) 187 | elif isinstance(module, (nn.LayerNorm, te.LayerNorm, te.RMSNorm)): 188 | if hasattr(module, "bias"): 189 | module.bias.data.zero_() 190 | module.weight.data.fill_(1.0) 191 | else: 192 | raise NotImplementedError 193 | 194 | 195 | class GPT2Model(BasicGPT2): 196 | 197 | def __init__(self, config): 198 | super().__init__() 199 | self.config = config 200 | self.embed_dim = config.hidden_dim 201 | self.wte = nn.Embedding(config.vocab_size, self.embed_dim) 202 | if config.pos_emb.name == "absolute": 203 | self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) 204 | elif config.pos_emb.name == "rope": 205 | self.rotary_pos_emb = RotaryEmbedding( 206 | kv_channels=config.attn_dim // config.num_q_heads, 207 | rotary_interleaved=config.pos_emb.rotary_interleaved, 208 | seq_len_interpolation_factor=config.pos_emb.seq_len_interpolation_factor, 209 | rotary_base=config.pos_emb.rotary_base, 210 | ) 211 | else: 212 | raise NotImplementedError 213 | self.drop = nn.Dropout(config.embd_drop) 214 | self.layers = nn.ModuleList( 215 | [TransformerLayer(config, i) for i in range(config.num_layers)] 216 | ) 217 | self.ln_f = layernorm_func[config.ln](self.embed_dim) 218 | 219 | def forward( 220 | self, 221 | input_ids: torch.LongTensor = None, 222 | inference_params: Optional[InferenceParams] = None, 223 | use_cache=False, 224 | ): 225 | bs, seq = input_ids.shape 226 | seq_start = ( 227 | inference_params.sequence_len_offset 228 | if use_cache and inference_params is not None 229 | else 0 230 | ) 231 | seq_end = seq_start + seq 232 | 233 | position_ids = ( 234 | torch.arange(seq_start, seq_end, dtype=torch.long, device=input_ids.device) 235 | .unsqueeze(0) 236 | .view(-1, seq) 237 | ) 238 | inputs_embeds = self.wte(input_ids) 239 | if self.config.pos_emb.name == "absolute": 240 | position_embeds = self.wpe(position_ids) 241 | hidden_states = inputs_embeds + position_embeds 242 | else: 243 | hidden_states = inputs_embeds 244 | hidden_states = self.drop(hidden_states) 245 | if self.config.pos_emb.name == "rope": 246 | rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( 247 | inference_params, hidden_states 248 | ) 249 | rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) 250 | else: 251 | rotary_pos_emb = None 252 | for layer in self.layers: 253 | hidden_states = layer( 254 | hidden_states, inference_params, use_cache, rotary_pos_emb 255 | ) 256 | hidden_states = self.ln_f(hidden_states) 257 | return hidden_states 258 | 259 | 260 | class GPT2LMHeadModel(BasicGPT2): 261 | 262 | def __init__(self, config): 263 | super().__init__() 264 | self.config = config 265 | self.model = GPT2Model(config) 266 | self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False) 267 | 
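        # with config.tie_word_embeddings, this projection shares its weight with
        # model.wte below, so no separate output-projection matrix is learned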
# init weight 268 | self.apply( 269 | lambda module: self._init_weights(module=module, init_config=config.init) 270 | ) 271 | # tie weight embedding 272 | if config.tie_word_embeddings: 273 | self.lm_head.weight = self.model.wte.weight 274 | 275 | @staticmethod 276 | def get_ckpt_name(config_model): 277 | return ( 278 | "h" 279 | + f"{config_model.hidden_dim}" 280 | + "a" 281 | + f"{config_model.attn_dim}" 282 | + "f" 283 | + f"{config_model.ffn_dim}" 284 | + "nkv" 285 | + f"{config_model.num_kv_heads}" 286 | + f"{config_model.act}" 287 | + f"{config_model.pos_emb.name}" 288 | + f"{config_model.ln}" 289 | ) 290 | 291 | def get_flops(self, bs, seq_len): 292 | attn_qo = 2 * bs * seq_len * self.config.attn_dim * self.config.hidden_dim 293 | attn_kv = ( 294 | 2 295 | * bs 296 | * seq_len 297 | * (self.config.attn_dim // self.config.num_q_heads) 298 | * self.config.num_kv_heads 299 | * self.config.hidden_dim 300 | ) 301 | sdp = 2 * bs * seq_len * seq_len * self.config.attn_dim 302 | return ( 303 | 2 * self.config.num_layers * (attn_qo + attn_kv + sdp) 304 | + self.get_flops_mlp(bs, seq_len) 305 | + 2 * bs * seq_len * self.config.vocab_size * self.config.hidden_dim 306 | ) 307 | 308 | def get_params( 309 | self, 310 | ): 311 | attn_qo = 2 * self.config.attn_dim * self.config.hidden_dim 312 | attn_kv = ( 313 | 2 314 | * (self.config.attn_dim // self.config.num_q_heads) 315 | * self.config.num_kv_heads 316 | * self.config.hidden_dim 317 | ) 318 | return ( 319 | self.config.num_layers * (attn_qo + attn_kv) 320 | + self.get_params_mlp() 321 | + self.config.vocab_size * self.config.hidden_dim 322 | ) 323 | 324 | def get_params_woembedding( 325 | self, 326 | ): 327 | attn_qo = 2 * self.config.attn_dim * self.config.hidden_dim 328 | attn_kv = ( 329 | 2 330 | * (self.config.attn_dim // self.config.num_q_heads) 331 | * self.config.num_kv_heads 332 | * self.config.hidden_dim 333 | ) 334 | return self.config.num_layers * (attn_qo + attn_kv) + self.get_params_mlp() 335 | 336 | def get_flops_mlp(self, bs, seq): 337 | # as they're all linear layers. 
The flops just scales with the parameters 338 | mlp = 0 339 | for layer in self.model.layers: 340 | for para in layer.mlp.parameters(): 341 | if len(para.shape) != 1: 342 | mlp += para.numel() 343 | return 2 * mlp * bs * seq 344 | 345 | def get_params_mlp( 346 | self, 347 | ): 348 | mlp = 0 349 | for layer in self.model.layers: 350 | for para in layer.mlp.parameters(): 351 | if len(para.shape) != 1: 352 | mlp += para.numel() 353 | return mlp 354 | 355 | def forward( 356 | self, 357 | input_ids: torch.LongTensor = None, 358 | labels: Optional[torch.LongTensor] = None, 359 | inference_params: Optional[InferenceParams] = None, 360 | use_cache=False, 361 | ): 362 | out = self.model(input_ids, inference_params, use_cache) 363 | lm_logits = self.lm_head(out) 364 | loss = None 365 | if labels is not None: 366 | loss = F.cross_entropy( 367 | lm_logits.view(-1, lm_logits.size(-1)).contiguous(), 368 | labels.view(-1).contiguous(), 369 | ) 370 | if loss is not None: 371 | return loss 372 | else: 373 | return lm_logits 374 | 375 | def prepare_inference_params( 376 | self, batch_size, mx_seq, torch_dtype=torch.bfloat16, device="cuda" 377 | ): 378 | # mx_seq is composed of prefill and generation length 379 | inference_params = InferenceParams(batch_size, mx_seq) 380 | inf_max_seq_len = inference_params.max_sequence_length 381 | inf_max_batch_size = inference_params.max_batch_size 382 | for i in range(self.config.num_layers): 383 | inference_key_memory = torch.empty( 384 | inf_max_batch_size, 385 | inf_max_seq_len, 386 | self.config.num_kv_heads, 387 | (self.config.attn_dim // self.config.num_q_heads), 388 | dtype=torch_dtype, 389 | device=device, 390 | ) 391 | inference_value_memory = torch.empty( 392 | inf_max_batch_size, 393 | inf_max_seq_len, 394 | self.config.num_kv_heads, 395 | (self.config.attn_dim // self.config.num_q_heads), 396 | dtype=torch_dtype, 397 | device=device, 398 | ) 399 | inference_params.key_value_memory_dict[i] = ( 400 | inference_key_memory, 401 | inference_value_memory, 402 | ) 403 | return inference_params 404 | -------------------------------------------------------------------------------- /src/modules/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .low_rank import low_rank_custom 2 | from .block_dense import block_dense_custom, block_dense_bmm 3 | from .block_shuffle import ( 4 | block_shuffle_custom, 5 | block_shuffle_bmm, 6 | block_shuffle_einsum, 7 | ) 8 | 9 | from .common.fused_gelu import bias_gelu_impl 10 | from .common.fused_swiglu import bias_swiglu_impl 11 | from .common.fused_bias_dropout_add import bias_dropout_add_impl 12 | from .common.rotary_embeddings import RotaryEmbedding, apply_rotary_pos_emb 13 | -------------------------------------------------------------------------------- /src/modules/op/block_dense.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def block_dense_bmm(input, blkdiag, linear): 6 | batch_shape, h = input.shape[:-1], input.shape[-1] 7 | batch_dim = np.prod(batch_shape) 8 | k, q, p = blkdiag.shape 9 | l, r = linear.shape 10 | assert k * p == h 11 | assert r == k * q 12 | input = input.reshape(batch_dim, k, p).transpose(0, 1) 13 | out1 = torch.bmm(input, blkdiag.transpose(-1, -2)) 14 | out1 = out1.transpose(0, 1).reshape(batch_dim, r) 15 | out2 = torch.mm(out1, linear.transpose(-1, -2)).reshape(*batch_shape, l) 16 | return out2 17 | 18 | 19 | class BlockDenseCustom(torch.autograd.Function): 20 | 
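    # Note: the docstring below appears to be carried over from the Monarch block-shuffle
    # kernel. In this variant the second factor is a plain dense matrix `linear` of shape
    # (l, r) with r == k * q, so the output feature dimension is l rather than l * s.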
"""This is a faster implementation, with careful memory copies for the fastest 21 | bmm performance. 22 | The backward pass is also written manually with careful memory copies. 23 | Arguments: 24 | x: (batch, n) 25 | w1_bfly: (k, q, p), where k = n / p 26 | w2_bfly: (l, s, r), where l = k * q / r = n * q / (p * r) 27 | Outputs: 28 | out: (batch, m), where m = l * s = n * s * q / (p * r) 29 | """ 30 | 31 | @staticmethod 32 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.bfloat16) 33 | def forward(ctx, x, w1_bfly, linear): 34 | # due to bugs in torch.bmm with specific out dtype, we need to change the weight dtype here by hand 35 | # note that this only changes the weight dtype in this scope 36 | batch_shape, n = x.shape[:-1], x.shape[-1] 37 | batch_dim = np.prod(batch_shape) 38 | k, q, p = w1_bfly.shape 39 | l, r = linear.shape 40 | assert k * p == n 41 | assert r == k * q 42 | x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1) 43 | out1 = torch.empty(batch_dim, k, q, device=x.device, dtype=x.dtype).transpose( 44 | 0, 1 45 | ) 46 | out1 = torch.bmm(x_reshaped, w1_bfly.transpose(-1, -2), out=out1) 47 | out1 = out1.transpose(0, 1).reshape(batch_dim, r) 48 | out2 = torch.mm(out1, linear.transpose(-1, -2)).reshape(*batch_shape, l) 49 | ctx.save_for_backward(x, w1_bfly, linear, out1) 50 | return out2 51 | 52 | @staticmethod 53 | @torch.cuda.amp.custom_bwd 54 | def backward(ctx, dout): 55 | x, w1_bfly, linear, out1 = ctx.saved_tensors 56 | batch_shape, n = x.shape[:-1], x.shape[-1] 57 | batch_dim = np.prod(batch_shape) 58 | k, q, p = w1_bfly.shape 59 | l, r = linear.shape 60 | 61 | dx, dw1_bfly, dw2_linear = None, None, None 62 | dout_reshaped = dout.reshape(batch_dim, l) 63 | if ctx.needs_input_grad[2]: 64 | dw2_linear = torch.mm(dout_reshaped.transpose(-1, -2), out1) 65 | if ctx.needs_input_grad[1] or ctx.needs_input_grad[0]: 66 | dout1 = ( 67 | torch.mm(dout_reshaped, linear).reshape(batch_dim, k, q).transpose(0, 1) 68 | ) 69 | if ctx.needs_input_grad[0]: 70 | dx = torch.empty( 71 | batch_dim, k, p, device=x.device, dtype=x.dtype 72 | ).transpose(0, 1) 73 | dx = ( 74 | torch.bmm(dout1, w1_bfly, out=dx) 75 | .transpose(0, 1) 76 | .reshape(*batch_shape, n) 77 | ) 78 | if ctx.needs_input_grad[1]: 79 | x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1) 80 | dw1_bfly = torch.bmm(dout1.transpose(-1, -2), x_reshaped) 81 | return dx, dw1_bfly, dw2_linear 82 | 83 | 84 | block_dense_custom = BlockDenseCustom.apply 85 | -------------------------------------------------------------------------------- /src/modules/op/block_shuffle.py: -------------------------------------------------------------------------------- 1 | # paste from Monarch by TriDao 2 | import torch 3 | import numpy as np 4 | from einops import rearrange 5 | 6 | 7 | def block_shuffle_einsum(input, blkdiag1, blkdiag2): 8 | batch_shape, h = input.shape[:-1], input.shape[-1] 9 | batch_dim = np.prod(batch_shape) 10 | k, q, p = blkdiag1.shape 11 | l, s, r = blkdiag2.shape 12 | assert k * p == h 13 | assert l * r == k * q 14 | input = input.reshape(batch_dim, k, p) 15 | out1 = torch.einsum("kqp,bkp->bkq", blkdiag1, input) 16 | out1 = rearrange(rearrange(out1, "b k q -> b (k q)"), "b (r l) -> b l r", l=l) 17 | return torch.einsum("lsr,blr->bsl", blkdiag2, out1).reshape(*batch_shape, s * l) 18 | 19 | 20 | def block_shuffle_bmm(input, blkdiag1, blkdiag2): 21 | batch_shape, h = input.shape[:-1], input.shape[-1] 22 | batch_dim = np.prod(batch_shape) 23 | k, q, p = blkdiag1.shape 24 | l, s, r = blkdiag2.shape 25 | assert k * p == h 26 | assert 
l * r == k * q 27 | input = input.reshape(batch_dim, k, p).transpose(0, 1) 28 | out1 = torch.bmm(input, blkdiag1.transpose(-1, -2)) 29 | out1 = ( 30 | out1.transpose(0, 1) 31 | .reshape(batch_dim, r, l) 32 | .transpose(-1, -2) 33 | .contiguous() 34 | .transpose(0, 1) 35 | ) 36 | out2 = torch.bmm(out1, blkdiag2.transpose(-1, -2)) 37 | out2 = out2.permute(1, 2, 0).reshape(*batch_shape, s * l) 38 | return out2 39 | 40 | 41 | class BlockShuffleCustom(torch.autograd.Function): 42 | # Paste from monarch repo 43 | """This is a faster implementation, with careful memory copies for the fastest 44 | bmm performance. 45 | The backward pass is also written manually with careful memory copies. 46 | Arguments: 47 | x: (batch, n) 48 | w1_bfly: (k, q, p), where k = n / p 49 | w2_bfly: (l, s, r), where l = k * q / r = n * q / (p * r) 50 | Outputs: 51 | out: (batch, m), where m = l * s = n * s * q / (p * r) 52 | """ 53 | 54 | @staticmethod 55 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.bfloat16) 56 | def forward(ctx, x, w1_bfly, w2_bfly): 57 | batch_shape, n = x.shape[:-1], x.shape[-1] 58 | batch_dim = np.prod(batch_shape) 59 | k, q, p = w1_bfly.shape 60 | l, s, r = w2_bfly.shape 61 | assert k * p == n 62 | assert l * r == k * q 63 | x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1) 64 | out1 = torch.empty(batch_dim, k, q, device=x.device, dtype=x.dtype).transpose( 65 | 0, 1 66 | ) 67 | out1 = torch.bmm(x_reshaped, w1_bfly.transpose(-1, -2), out=out1) 68 | out1 = ( 69 | out1.transpose(0, 1) 70 | .reshape(batch_dim, r, l) 71 | .transpose(-1, -2) 72 | .contiguous() 73 | .transpose(0, 1) 74 | ) 75 | out2 = torch.empty(batch_dim, l, s, device=x.device, dtype=x.dtype).transpose( 76 | 0, 1 77 | ) 78 | out2 = torch.bmm(out1, w2_bfly.transpose(-1, -2), out=out2) 79 | out2 = out2.permute(1, 2, 0).reshape(*batch_shape, s * l) 80 | ctx.save_for_backward(x, w1_bfly, w2_bfly, out1) 81 | return out2 82 | 83 | @staticmethod 84 | @torch.cuda.amp.custom_bwd 85 | def backward(ctx, dout): 86 | x, w1_bfly, w2_bfly, out1 = ctx.saved_tensors 87 | batch_shape, n = x.shape[:-1], x.shape[-1] 88 | batch_dim = np.prod(batch_shape) 89 | k, q, p = w1_bfly.shape 90 | l, s, r = w2_bfly.shape 91 | # assert k * p == n 92 | # assert l * r == k * q 93 | dx, dw1_bfly, dw2_bfly = None, None, None 94 | # dout_reshaped = dout.reshape(batch_dim, sqrtn, sqrtn).permute(2, 1, 0).contiguous() 95 | dout_reshaped = dout.reshape(batch_dim, s, l).transpose(-1, -2).contiguous() 96 | dout_reshaped = dout_reshaped.transpose(0, 1) 97 | if ctx.needs_input_grad[2]: 98 | # dw2_bfly = torch.empty(l, s, r, device=w2_bfly.device, dtype=w2_bfly.dtype) 99 | # dw2_bfly = torch.bmm(dout_reshaped.transpose(-1, -2), out1, out=dw2_bfly) 100 | dw2_bfly = torch.bmm(dout_reshaped.transpose(-1, -2), out1) 101 | if ctx.needs_input_grad[1] or ctx.needs_input_grad[0]: 102 | dout1 = torch.empty( 103 | batch_dim, l, r, device=x.device, dtype=x.dtype 104 | ).transpose(0, 1) 105 | dout1 = torch.bmm(dout_reshaped, w2_bfly, out=dout1) 106 | dout1 = ( 107 | dout1.transpose(0, 1) 108 | .transpose(-1, -2) 109 | .contiguous() 110 | .reshape(batch_dim, k, q) 111 | .transpose(0, 1) 112 | ) 113 | # dout1 = dout1.permute(1, 2, 0).contiguous().transpose(0, 1) 114 | if ctx.needs_input_grad[0]: 115 | dx = torch.empty(batch_dim, k, p, device=x.device, dtype=x.dtype) 116 | dx = ( 117 | torch.bmm(dout1, w1_bfly, out=dx.transpose(0, 1)) 118 | .transpose(0, 1) 119 | .reshape(*batch_shape, n) 120 | ) 121 | if ctx.needs_input_grad[1]: 122 | x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 
1) 123 | dw1_bfly = torch.bmm(dout1.transpose(-1, -2), x_reshaped) 124 | return dx, dw1_bfly, dw2_bfly 125 | 126 | 127 | block_shuffle_custom = BlockShuffleCustom.apply 128 | -------------------------------------------------------------------------------- /src/modules/op/common/fused_bias_dropout_add.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 2 | import os 3 | from typing import Optional, Tuple 4 | import torch 5 | 6 | 7 | jit_fuser = torch.jit.script 8 | if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): 9 | jit_fuser = torch.compile 10 | 11 | 12 | def _bias_dropout_add_func(x_with_bias, residual, prob, training): 13 | x, bias = x_with_bias # unpack 14 | 15 | # If we want to train mixed precision, then the output of this function 16 | # should be half precision. However, in AMP O1, the input (residual) is 17 | # in fp32, and it will up-cast the result to fp32, causing pipeline parallel 18 | # GPU communication to hang. Therefore, we need to cast residual to the same 19 | # dtype as x. 20 | residual = residual if residual.dtype == x.dtype else residual.to(x.dtype) 21 | 22 | if bias is not None: 23 | x = x + bias 24 | out = torch.nn.functional.dropout(x, p=prob, training=training) 25 | out = residual + out 26 | return out 27 | else: 28 | out = torch.nn.functional.dropout(x, p=prob, training=training) 29 | out = residual + out 30 | return out 31 | 32 | 33 | @jit_fuser 34 | def bias_dropout_add_fused_train( 35 | x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], 36 | residual: torch.Tensor, 37 | prob: float, 38 | ) -> torch.Tensor: 39 | return _bias_dropout_add_func(x_with_bias, residual, prob, True) 40 | 41 | 42 | @jit_fuser 43 | def bias_dropout_add_fused_inference( 44 | x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], 45 | residual: torch.Tensor, 46 | prob: float, 47 | ) -> torch.Tensor: 48 | return _bias_dropout_add_func(x_with_bias, residual, prob, False) 49 | 50 | 51 | def bias_dropout_add_impl(training): 52 | if training: 53 | return bias_dropout_add_fused_train 54 | else: 55 | return bias_dropout_add_fused_inference 56 | -------------------------------------------------------------------------------- /src/modules/op/common/fused_gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | # paste from megatron 3 | import os 4 | import torch 5 | from typing import Callable, Optional, Tuple 6 | 7 | 8 | jit_fuser = torch.jit.script 9 | if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): 10 | jit_fuser = torch.compile 11 | 12 | 13 | @jit_fuser 14 | def bias_gelu_fused_(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: 15 | """Bias-GeLU fused""" 16 | x = inp + bias 17 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) 18 | 19 | 20 | @jit_fuser 21 | def gelu_fused_(inp: torch.Tensor) -> torch.Tensor: 22 | """ 23 | GeLU fused, this is copy of bias_gelu_fused cause jit fusion doesn't allow conditioning. 
24 | """ 25 | x = inp 26 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) 27 | 28 | 29 | @jit_fuser 30 | def dgelu_bgrad_fused_( 31 | grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor 32 | ) -> Tuple[torch.Tensor, torch.Tensor]: 33 | """Bgrad-Dgelu fused""" 34 | x = inp + bias 35 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) 36 | # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 37 | ff = 0.5 * x * ( 38 | (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) 39 | ) + 0.5 * (1 + tanh_out) 40 | dgelu = ff * grad_output 41 | bgrad = dgelu.sum(dim=0) 42 | return dgelu, bgrad 43 | 44 | 45 | @jit_fuser 46 | def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor: 47 | """ 48 | Dgelu fused, this is copy of bgrad_dgelu_fused_ cause jit fusion doesn't allow conditioning. 49 | """ 50 | x = inp 51 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) 52 | # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 53 | ff = 0.5 * x * ( 54 | (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) 55 | ) + 0.5 * (1 + tanh_out) 56 | dgelu = ff * grad_output 57 | return dgelu 58 | 59 | 60 | class BiasGeLUFunction(torch.autograd.Function): 61 | @staticmethod 62 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 63 | def forward(ctx, input, bias): 64 | ctx.save_for_backward(input, bias) 65 | return bias_gelu_fused_(input, bias) 66 | 67 | @staticmethod 68 | @torch.cuda.amp.custom_bwd 69 | def backward(ctx, grad_output): 70 | input, bias = ctx.saved_tensors 71 | return dgelu_bgrad_fused_(grad_output, input, bias) 72 | 73 | 74 | class GeLUFunction(torch.autograd.Function): 75 | @staticmethod 76 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 77 | def forward(ctx, input): 78 | ctx.save_for_backward(input) 79 | return gelu_fused_(input) 80 | 81 | @staticmethod 82 | @torch.cuda.amp.custom_bwd 83 | def backward(ctx, grad_output): 84 | input = ctx.saved_tensors[0] 85 | return dgelu_fused_(grad_output, input) 86 | 87 | 88 | def bias_gelu_impl(input, bias): 89 | ori_shape = input.shape 90 | assert len(ori_shape) in [2, 3] 91 | input = input.view(-1, ori_shape[-1]) 92 | if bias is not None: 93 | output = BiasGeLUFunction.apply(input, bias) 94 | else: 95 | output = GeLUFunction.apply(input) 96 | return ( 97 | output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1) 98 | ) 99 | -------------------------------------------------------------------------------- /src/modules/op/common/fused_swiglu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
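# SwiGLU splits its input in half along the last dimension and computes
#   swiglu(y) = silu(y_1) * y_2,  where  y_1, y_2 = chunk(y, 2, dim=-1),
# so it returns half as many features as it receives; callers size the first MLP
# projection at 2 * ffn_dim accordingly (see FusedBasicMLP).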
2 | # paste from megatron 3 | import torch 4 | import torch.nn.functional as F 5 | import os 6 | 7 | jit_fuser = torch.jit.script 8 | if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): 9 | jit_fuser = torch.compile 10 | 11 | 12 | @jit_fuser 13 | def swiglu(y): 14 | y_1, y_2 = torch.chunk(y, 2, -1) 15 | return F.silu(y_1) * y_2 16 | 17 | 18 | @jit_fuser 19 | def bias_swiglu(y, bias): 20 | y = y + bias 21 | return swiglu(y) 22 | 23 | 24 | @jit_fuser 25 | def swiglu_back(g, y): 26 | y_1, y_2 = torch.chunk(y, 2, -1) 27 | return torch.cat( 28 | ( 29 | g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2, 30 | g * F.silu(y_1), 31 | ), 32 | -1, 33 | ) 34 | 35 | 36 | @jit_fuser 37 | def bias_swiglu_back(g, y, bias): 38 | y = y + bias 39 | dy = swiglu_back(g, y) 40 | bgrad = dy.sum(dim=0) 41 | return dy, bgrad 42 | 43 | 44 | class BiasSwiGLUFunction(torch.autograd.Function): 45 | @staticmethod 46 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 47 | def forward(ctx, input, bias): 48 | ctx.save_for_backward(input, bias) 49 | return bias_swiglu(input, bias) 50 | 51 | @staticmethod 52 | @torch.cuda.amp.custom_bwd 53 | def backward(ctx, grad_output): 54 | input, bias = ctx.saved_tensors 55 | return bias_swiglu_back(grad_output, input, bias) 56 | 57 | 58 | class SwiGLUFunction(torch.autograd.Function): 59 | @staticmethod 60 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 61 | def forward(ctx, input): 62 | ctx.save_for_backward(input) 63 | return swiglu(input) 64 | 65 | @staticmethod 66 | @torch.cuda.amp.custom_bwd 67 | def backward(ctx, grad_output): 68 | input = ctx.saved_tensors[0] 69 | return swiglu_back(grad_output, input) 70 | 71 | 72 | def bias_swiglu_impl(input, bias): 73 | ori_shape = input.shape 74 | assert len(ori_shape) in [2, 3] 75 | input = input.view(-1, ori_shape[-1]) 76 | if bias is not None: 77 | output = BiasSwiGLUFunction.apply(input, bias) 78 | else: 79 | output = SwiGLUFunction.apply(input) 80 | 81 | return ( 82 | output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1) 83 | ) 84 | -------------------------------------------------------------------------------- /src/modules/op/common/rotary_embeddings.py: -------------------------------------------------------------------------------- 1 | from apex.transformer.functional import fused_apply_rotary_pos_emb 2 | import torch.nn as nn 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | class RotaryEmbedding(nn.Module): 8 | """Rotary Embedding for language model. 9 | 10 | Args: 11 | kv_channels (int): Projection weights dimension in multi-head attention. Obtained from transformer config 12 | seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None 13 | rotary_base (int, optional): Base period for rotary position embeddings. Defaults to 10000. 
14 | """ 15 | 16 | def __init__( 17 | self, 18 | kv_channels: int, 19 | rotary_interleaved: bool = False, 20 | seq_len_interpolation_factor: float = None, 21 | rotary_base: int = 10000, 22 | ) -> None: 23 | super().__init__() 24 | 25 | dim = kv_channels 26 | self.rotary_interleaved = rotary_interleaved 27 | 28 | self.seq_len_interpolation_factor = seq_len_interpolation_factor 29 | self.inv_freq = 1.0 / ( 30 | rotary_base 31 | ** ( 32 | torch.arange( 33 | 0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device() 34 | ) 35 | / dim 36 | ) 37 | ) 38 | 39 | def forward(self, max_seq_len: int, offset: int = 0) -> Tensor: 40 | """Forward pass of RoPE embedding. 41 | 42 | Args: 43 | max_seq_len (int): Maximum size of sequence 44 | offset (int, optional): _description_. Defaults to 0. 45 | 46 | Returns: 47 | Tensor: Embeddings after applying RoPE. 48 | """ 49 | seq = ( 50 | torch.arange( 51 | max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype 52 | ) 53 | + offset 54 | ) 55 | 56 | if self.seq_len_interpolation_factor is not None: 57 | seq *= 1 / self.seq_len_interpolation_factor 58 | 59 | freqs = torch.outer(seq, self.inv_freq) 60 | # first part even vector components, second part odd vector components, 61 | # 2 * dim in dimension size 62 | if not self.rotary_interleaved: 63 | emb = torch.cat((freqs, freqs), dim=-1) 64 | else: 65 | emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view( 66 | freqs.shape[0], -1 67 | ) 68 | # emb [seq_length, .., dim] 69 | emb = emb[:, None, None, :] 70 | return emb 71 | 72 | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): 73 | state_dict.pop(f"{prefix}inv_freq", None) 74 | return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 75 | 76 | def get_rotary_seq_len( 77 | self, 78 | inference_params, 79 | transformer_input, 80 | ) -> float: 81 | 82 | if inference_params is not None: 83 | return inference_params.max_sequence_length 84 | return transformer_input.shape[1] 85 | 86 | 87 | def apply_rotary_pos_emb( 88 | t: Tensor, 89 | freqs: Tensor, 90 | ): 91 | # bshd -> sbhd 92 | return fused_apply_rotary_pos_emb( 93 | t.transpose(0, 1), freqs, transpose_output_memory=True 94 | ).transpose(0, 1) 95 | -------------------------------------------------------------------------------- /src/modules/op/low_rank.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def low_rank_custom(input, linear1, linear2): 6 | batch_shape, h = input.shape[:-1], input.shape[-1] 7 | batch_dim = np.prod(batch_shape) 8 | input = input.reshape(batch_dim, h) 9 | out2 = torch.mm( 10 | torch.mm(input, linear1.transpose(-1, -2)), linear2.transpose(-1, -2) 11 | ).reshape(*batch_shape, -1) 12 | return out2 13 | -------------------------------------------------------------------------------- /src/optimization/__init__.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | from .scheduler import * 3 | 4 | 5 | name_to_scheduler = { 6 | "multisteplr": lambda optimizer, kwargs: _MultiStepLR(optimizer, **kwargs), 7 | "cosineannealinglr": lambda optimizer, kwargs: _CosineAnnealingLR( 8 | optimizer, **kwargs 9 | ), 10 | } 11 | 12 | name_to_optimizer = { 13 | "adam": lambda params, kwargs: optim.Adam(params, **kwargs), 14 | "sgd": lambda params, kwargs: optim.SGD(params, **kwargs), 15 | "adamw": lambda params, kwargs: optim.AdamW(params, **kwargs), 16 | } 17 | 18 | 19 | def 
get_lr_scheduler(config_optimization, optimizer): 20 | name = config_optimization.lr_scheduler.name.lower() 21 | return name_to_scheduler[name](optimizer, config_optimization.lr_scheduler.kwargs) 22 | 23 | 24 | def get_optimizer(config_optimization, params): 25 | name = config_optimization.optimizer.name.lower() 26 | return name_to_optimizer[name](params, config_optimization.optimizer.kwargs) 27 | -------------------------------------------------------------------------------- /src/optimization/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import ( 3 | MultiStepLR, 4 | CosineAnnealingLR, 5 | ) 6 | 7 | 8 | __all__ = [ 9 | "_MultiStepLR", 10 | "_CosineAnnealingLR", 11 | ] 12 | 13 | 14 | class _MultiStepLR(MultiStepLR): 15 | 16 | def __init__(self, optimizer, **kwargs): 17 | kwargs["milestones"] = [ 18 | int(e * kwargs.pop("T_max")) for e in kwargs["milestones"] 19 | ] 20 | super(_MultiStepLR, self).__init__(optimizer, **kwargs) 21 | 22 | 23 | class _CosineAnnealingLR(CosineAnnealingLR): 24 | def __init__(self, optimizer, **kwargs): 25 | self.warmup_iter = 0 26 | if "warmup_iter" in kwargs: 27 | self.warmup_iter = int(kwargs.pop("warmup_iter") * kwargs["T_max"]) 28 | super(_CosineAnnealingLR, self).__init__(optimizer, **kwargs) 29 | 30 | def get_lr(self): 31 | if self.last_epoch < self.warmup_iter: 32 | return [ 33 | (self.last_epoch + 1) / self.warmup_iter * base_lr 34 | for base_lr in self.base_lrs 35 | ] 36 | 37 | return [ 38 | self.eta_min 39 | + (base_lr - self.eta_min) 40 | * ( 41 | 1 42 | + math.cos( 43 | math.pi 44 | * (self.last_epoch - self.warmup_iter) 45 | / (self.T_max - self.warmup_iter) 46 | ) 47 | ) 48 | / 2 49 | for base_lr in self.base_lrs 50 | ] 51 | -------------------------------------------------------------------------------- /src/optimization/trainer.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import glob 4 | import wandb 5 | import shutil 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from easydict import EasyDict 10 | from omegaconf import open_dict 11 | import torch.distributed as dist 12 | from src.modules import get_ckpt_name 13 | 14 | 15 | class WandbLog: 16 | def __init__(self, config, metric, x_axis="epoch"): 17 | self.config = config 18 | for k, v in metric.items(): 19 | if k == x_axis: 20 | wandb.define_metric(x_axis) 21 | else: 22 | wandb.define_metric(k, step_metric=x_axis) 23 | 24 | def record(self, item): 25 | wandb.log(item) 26 | 27 | 28 | class TrainableModel: 29 | 30 | def __init__(self, config): 31 | self.config = config 32 | self.epoch = -1 33 | self.step = -1 34 | self.max_epoch = self.config.optimization.max_epoch 35 | self.max_step = None # define in specific Trainer 36 | 37 | # gpu setting 38 | self.gpu_id = int(os.getenv("RANK", -1)) 39 | self.device = ( 40 | torch.device("cuda", self.gpu_id) 41 | if self.gpu_id != -1 42 | else torch.device("cuda") 43 | ) 44 | self.ngpus = dist.get_world_size() if self.gpu_id != -1 else 1 45 | print("The device is {} out of {}".format(self.device, self.ngpus)) 46 | 47 | self.global_batch_size = getattr( 48 | self.config.optimization, 49 | "global_batch_size", 50 | self.config.data.train.train_batch, 51 | ) 52 | assert ( 53 | self.global_batch_size % (self.ngpus * self.config.data.train.train_batch) 54 | == 0 55 | ) 56 | self.gradient_accumulation_steps = self.global_batch_size // ( 57 | self.ngpus * 
self.config.data.train.train_batch 58 | ) 59 | 60 | self.log_interval = getattr(self.config.optimization, "log_interval", False) 61 | self.check_gradient_norm = getattr( 62 | self.config.optimization, "check_gradient_norm", False 63 | ) 64 | self.check_weight_norm = getattr( 65 | self.config.optimization, "check_weight_norm", False 66 | ) 67 | self.gradient_clipping = getattr( 68 | self.config.optimization, "gradient_clipping", False 69 | ) 70 | self.special_training = ( 71 | self.config.optimization.training.name == "self_guided_training" 72 | ) 73 | # save 74 | self.is_save_checkpoint = getattr( 75 | self.config.optimization, "save_checkpoint", False 76 | ) 77 | self.is_load_checkpoint = getattr( 78 | self.config.optimization, "load_checkpoint", False 79 | ) 80 | self.load_save_mode = getattr( 81 | self.config.optimization, "load_save_mode", "epoch" 82 | ) 83 | 84 | def prepare_load_save( 85 | self, 86 | ): 87 | if self.is_save_checkpoint or self.is_load_checkpoint: 88 | long_name = get_ckpt_name(self.config) + "-" + str(self.special_training) 89 | if self.special_training: 90 | long_name += ( 91 | "-" 92 | + self.config.optimization.training.kwargs.mode 93 | + "-" 94 | + str(self.config.optimization.training.kwargs.reduce_flop) 95 | ) 96 | self.save_dir = os.path.join(self.config.optimization.save_dir, long_name) 97 | self.save_dir = os.path.join( 98 | self.save_dir, 99 | str(self.config.optimization.optimizer.kwargs.lr).replace(".", "x"), 100 | ) 101 | if self.load_save_mode == "epoch": 102 | self.save_interval = self.max_epoch // 10 103 | elif self.load_save_mode == "step": 104 | self.save_interval = self.max_step // 10 105 | else: 106 | raise NotImplementedError 107 | print( 108 | "plan to save or load checkpoint in {} for each {} in the mode {}".format( 109 | self.save_dir, self.save_interval, self.load_save_mode 110 | ) 111 | ) 112 | if not self.is_load_checkpoint: 113 | shutil.rmtree(self.save_dir) 114 | if not os.path.exists(self.save_dir): 115 | os.makedirs(self.save_dir) 116 | 117 | def set_gradient_clipping( 118 | self, 119 | ): 120 | if self.gradient_clipping is not False: 121 | torch.nn.utils.clip_grad_norm_( 122 | self.model.parameters(), self.gradient_clipping 123 | ) 124 | 125 | def get_info( 126 | self, 127 | ): 128 | nparam = self.get_nparam() 129 | nflops = self.model.get_flops( 130 | self.global_batch_size, 131 | self.block_size, 132 | ) # we consider all the matrix multiplication including the final logits in the model 133 | total_flops = nflops * self.max_step 134 | if self.special_training: 135 | guide_params = sum( 136 | [ 137 | p.guide_linear.numel() 138 | for p in self.model.modules() 139 | if hasattr(p, "guide_linear") 140 | ] 141 | ) 142 | # print("the number of guide parameters are {:.2f}".format(guide_params)) 143 | guide_flops = ( 144 | 2 * guide_params * self.global_batch_size * self.block_size 145 | ) # addition and multiplication 146 | total_flops -= guide_flops * (self.max_step - self.guided_steps) 147 | if self.config.optimization.training.kwargs.reduce_flop: 148 | total_flops -= 0.5 * guide_flops * self.guided_steps 149 | print("The total parameter is {:.2f} M".format(nparam / 10**6)) 150 | print( 151 | "FLOPs information: flops per forward step {:.2f}T, total flops {:.2f}T".format( 152 | nflops / 10**12, 153 | total_flops * 3 / 10**12, # backward and forward 154 | ) 155 | ) 156 | nparam_mlp = self.model.get_params_mlp() 157 | nflops_mlp = self.model.get_flops_mlp( 158 | self.global_batch_size, 159 | self.block_size, 160 | ) 161 | print( 162 | "MLP 
information: params {:.2f}M, flops per step {:.2f}T".format( 163 | nparam_mlp / 10**6, 164 | nflops_mlp / 10**12, 165 | ) 166 | ) 167 | 168 | print(self.model) 169 | 170 | def get_nparam( 171 | self, 172 | ): 173 | self.nparam = sum(param.numel() for param in self.model.parameters()) 174 | return self.nparam 175 | 176 | def set_seed(self, seed): 177 | random.seed(seed) 178 | os.environ["PYTHONHASHSEED"] = str(seed) 179 | np.random.seed(seed) 180 | torch.manual_seed(seed) 181 | torch.cuda.manual_seed(seed) 182 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 183 | torch.backends.cudnn.benchmark = False 184 | torch.backends.cudnn.deterministic = True 185 | 186 | def set_self_guided_training( 187 | self, 188 | ): 189 | self.repeat_steps = 0 190 | self.guided_steps = 0 191 | if self.special_training: 192 | self.guided_steps = int( 193 | self.max_step * self.config.optimization.training.kwargs.max_step_ratio 194 | ) 195 | with open_dict(self.config.method.kwargs) as f: 196 | f.training.enabled = True 197 | f.training.scheduler = ( 198 | self.config.optimization.training.kwargs.scheduler 199 | ) 200 | f.training.max_step = self.guided_steps 201 | f.training.reduce_flop = ( 202 | self.config.optimization.training.kwargs.reduce_flop 203 | ) 204 | if self.config.optimization.training.kwargs.mode == "fixedflop": 205 | self.repeat_steps = self.guided_steps 206 | self.max_step += self.repeat_steps 207 | elif self.config.method.name != "linear": 208 | with open_dict(self.config.method.kwargs) as f: 209 | f.training.enabled = False 210 | 211 | def close_self_guided_training( 212 | self, 213 | ): 214 | from src.modules.layer.basiclinear import BasicLinear 215 | 216 | self.special_training = False 217 | for name, module in self.model.named_modules(): 218 | if isinstance(module, BasicLinear): 219 | module.training_config.enabled = False 220 | 221 | def get_optimize_param( 222 | self, 223 | ): 224 | params = [{"params": self.model.parameters()}] 225 | return params 226 | 227 | def save_checkpoint(self, **resume_kwargs): 228 | # save checkpoint by epoch 229 | if not self.is_save_checkpoint or self.gpu_id not in [-1, 0]: 230 | return 231 | if self.load_save_mode == "epoch": 232 | cur = self.epoch 233 | cur_max = self.max_epoch 234 | elif self.load_save_mode == "step": 235 | cur = self.step 236 | cur_max = self.max_step 237 | if (cur + 1) % self.save_interval == 0 or cur + 1 == cur_max: 238 | ckpt_path = os.path.join( 239 | self.save_dir, 240 | f"{cur}.pth", 241 | ) 242 | ckpt = { 243 | "model": ( 244 | self.model.module.state_dict() 245 | if self.gpu_id == 0 246 | else self.model.state_dict() 247 | ), 248 | self.load_save_mode: cur, 249 | "config": self.config, 250 | "nparam": self.nparam, 251 | "optimizer": self.optimizer.state_dict(), 252 | "lr_scheduler": ( 253 | self.lr_scheduler.state_dict() 254 | if getattr(self, "lr_scheduler", None) 255 | else None 256 | ), 257 | "resume_kwargs": resume_kwargs, 258 | } 259 | torch.save(ckpt, ckpt_path) 260 | 261 | def load_checkpoint(self): 262 | if not self.is_load_checkpoint: 263 | return {} 264 | 265 | def find_latest_checkpoint(): 266 | checkpoint_files = glob.glob( 267 | os.path.join( 268 | self.save_dir, 269 | f"*.pth", 270 | ) 271 | ) 272 | if not checkpoint_files: 273 | return None 274 | 275 | latest_checkpoint_file = max(checkpoint_files, key=os.path.getctime) 276 | return latest_checkpoint_file 277 | 278 | latest_checkpoint = find_latest_checkpoint() 279 | if latest_checkpoint is not None: 280 | print("load checkpoint from 
{}".format(latest_checkpoint)) 281 | ckpt = torch.load(latest_checkpoint, map_location=self.device) 282 | self.model.load_state_dict(ckpt["model"]) 283 | self.optimizer.load_state_dict(ckpt["optimizer"]) 284 | if getattr(self, "lr_scheduler", None): 285 | self.lr_scheduler.load_state_dict(ckpt["lr_scheduler"]) 286 | if self.load_save_mode == "epoch": 287 | self.epoch = ckpt["epoch"] 288 | elif self.load_save_mode == "step": 289 | self.step = ckpt["step"] 290 | return ckpt["resume_kwargs"] 291 | return {} 292 | -------------------------------------------------------------------------------- /src/utils/refinedweb_llama.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, concatenate_datasets 2 | from transformers import AutoTokenizer 3 | from argparse import ArgumentParser 4 | import tiktoken 5 | import os 6 | from itertools import chain 7 | import numpy as np 8 | from tqdm import tqdm 9 | from transformers import LlamaTokenizer 10 | 11 | long_path = "/claire-rcp-scratch/shared/xwei/dataset/tiiuae___falcon-refinedweb/default-4033b99bd924aaad/0.0.0/0111277fb19b16f696664cde7f0cb90f833dec72db2cc73cfdf87e697f78fe02" 12 | cache_dir = "/claire-rcp-scratch/shared/xwei/dataset" 13 | 14 | 15 | def tokenize(tokenizer, num_proc, dataset): 16 | if tokenizer == "gpt2": 17 | enc = tiktoken.get_encoding("gpt2") 18 | 19 | def tokenize_process(example): 20 | ids = enc.encode_ordinary( 21 | example["text"] 22 | ) # encode_ordinary ignores any special tokens 23 | ids.append( 24 | enc.eot_token 25 | ) # add the end of text token, e.g. 50256 for gpt2 bpe 26 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 27 | out = {"ids": ids} 28 | return out 29 | 30 | elif tokenizer == "llama": 31 | enc = LlamaTokenizer.from_pretrained("huggyllama/llama-7b") 32 | eos_tokens = enc( 33 | "", truncation=False, padding=False, add_special_tokens=False 34 | )["input_ids"] 35 | 36 | def tokenize_process(example): 37 | ids = enc( 38 | example["text"], 39 | truncation=False, 40 | padding=False, 41 | add_special_tokens=False, 42 | )["input_ids"] 43 | ids = ids + eos_tokens 44 | out = {"ids": ids} 45 | return out 46 | 47 | else: 48 | raise NotImplementedError 49 | 50 | tokenized = dataset.map( 51 | tokenize_process, 52 | remove_columns=["text", "url", "timestamp", "dump", "segment", "image_urls"], 53 | desc="tokenizing the splits", 54 | num_proc=num_proc, 55 | ) 56 | print(tokenized) 57 | return tokenized 58 | 59 | 60 | def group_context(block_size, num_proc, dataset): 61 | 62 | def group_process(examples): 63 | # Concatenate all texts. 64 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 65 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 66 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. 67 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. 68 | total_length = (total_length // block_size) * block_size 69 | # Split by chunks of max_len. 
70 | result = { 71 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 72 | for k, t in concatenated_examples.items() 73 | } 74 | return result 75 | 76 | lm_datasets = dataset.map( 77 | group_process, 78 | batched=True, 79 | num_proc=num_proc, 80 | desc=f"Grouping texts in chunks of {block_size}", 81 | ) 82 | print(lm_datasets) 83 | return lm_datasets 84 | 85 | 86 | def save_to_npmemmap(split, dset, tokenizer, block_size): 87 | arr_len = dset.num_rows 88 | print(split, arr_len) 89 | filename = os.path.join( 90 | os.path.join(cache_dir, "refinedweb"), f"{tokenizer}-{split}-tmp.bin" 91 | ) 92 | dtype = np.uint16 # (can do since enc.max_token_value == 32000 is < 2**16) 93 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len, block_size)) 94 | total_batches = 1024 95 | 96 | idx = 0 97 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): 98 | # Batch together samples for faster write 99 | batch = dset.shard( 100 | num_shards=total_batches, index=batch_idx, contiguous=True 101 | ).with_format("numpy") 102 | # Write into mmap 103 | arr_batch = np.stack(batch["ids"]) 104 | arr[idx : idx + arr_batch.shape[0], :] = arr_batch 105 | idx += arr_batch.shape[0] 106 | arr.flush() 107 | 108 | 109 | def parse_args(): 110 | parser = ArgumentParser( 111 | description="Convert dataset into MDS format, optionally concatenating and tokenizing" 112 | ) 113 | parser.add_argument("--tokenizer", type=str, required=True) 114 | parser.add_argument( 115 | "--block_size", 116 | type=int, 117 | help="Convert text to tokens and concatenate up to this many tokens", 118 | ) 119 | 120 | parser.add_argument("--num_proc", type=int, required=True, default=None) 121 | return parser.parse_args() 122 | 123 | 124 | def main(args): 125 | print(args.num_proc) 126 | new_dataset = [] 127 | for i in range(6): 128 | for j in range(10): 129 | if i == 5 and j > 3: 130 | continue 131 | refinedweb_chunk = load_dataset( 132 | path=long_path, 133 | split="train", 134 | data_files=f"falcon-refinedweb-train-0{i}{j}*-of-05379.arrow", 135 | num_proc=args.num_proc, 136 | ).shuffle(seed=i * 10 + j) 137 | print(refinedweb_chunk) 138 | total_rows = refinedweb_chunk.num_rows 139 | selected_rows = int(0.1 * total_rows) 140 | cur_chunk = refinedweb_chunk.select(range(selected_rows)).rename_column( 141 | "content", "text" 142 | ) 143 | del refinedweb_chunk 144 | print("begin to tokenize!") 145 | # tokenization 146 | cur_chunk = tokenize(args.tokenizer, args.num_proc, cur_chunk) 147 | cur_chunk = group_context(args.block_size, args.num_proc, cur_chunk) 148 | new_dataset.append(cur_chunk) 149 | 150 | new_dataset = concatenate_datasets(new_dataset) 151 | new_dataset = new_dataset.train_test_split(test_size=0.01, seed=1005, shuffle=True) 152 | 153 | save_to_npmemmap("train", new_dataset["train"], args.tokenizer, args.block_size) 154 | save_to_npmemmap("val", new_dataset["test"], args.tokenizer, args.block_size) 155 | 156 | 157 | if __name__ == "__main__": 158 | main(parse_args()) 159 | --------------------------------------------------------------------------------