123 |
124 | ## License
125 |
126 | Open sourced under the [MIT license](LICENSE.md).
127 |
128 | <3
129 |
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | # Setup
2 | title: StructuredFFN
3 | url: https://claire-labo.github.io
4 | paginate: 1
5 | baseurl: "/StructuredFFN"
6 | permalink: pretty
7 |
8 | # Gems
9 | plugins:
10 | - jekyll-gist
11 | - jekyll-paginate
12 | - jekyll-seo-tag
13 |
14 | # Optimize Jekyll
15 | exclude:
16 | - .editorconfig
17 | - .git
18 | - .jekyll-cache
19 | - Gemfile
20 | - Gemfile.lock
21 | - LICENSE.md
22 | - README.md
23 |
24 | sass:
25 | sass_dir: _sass
26 | style: :compressed
27 |
28 | # Options
29 |
30 | # Replace this value and uncomment to enable Google Analytics tracking
31 | # ga_analytics: UA-000000-0
32 |
33 | # Specify the author for blog posts
34 | author:
35 | name: Mark Otto
36 | url: https://twitter.com/mdo
37 | email: markdotto@gmail.com
38 |
39 | # Custom vars
40 | version: 3.0.0
41 |
42 | # # Navbar page list
43 | # nav:
44 | # - title: Blog
45 | # url: /archive
46 |
47 | # - title: About
48 | # url: /about
49 |
--------------------------------------------------------------------------------
/docs/_includes/head.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | {% if page.title == "Home" %}
7 | {{ site.title }}{% if site.tagline %} · {{ site.tagline }}{% endif %}
8 | {% else %}
9 | {{ page.title }} · {{ site.title }}
10 | {% endif %}
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | {% seo title=false %}
19 |
20 |
--------------------------------------------------------------------------------
/docs/_includes/mathjax_support.html:
--------------------------------------------------------------------------------
1 |
21 |
24 |
--------------------------------------------------------------------------------
/docs/_layouts/default.html:
--------------------------------------------------------------------------------
1 |
2 |
3 | {% include head.html %}
4 | {% include mathjax_support.html %}
5 |
6 |
7 |
8 |
19 |
20 |
21 | {{ content }}
22 |
23 |
24 |
32 |
33 |
34 | {% if site.ga_analytics %}
35 |
58 | {% endif %}
59 |
60 |
61 |
--------------------------------------------------------------------------------
/docs/_layouts/page.html:
--------------------------------------------------------------------------------
1 | ---
2 | layout: default
3 | ---
4 |
5 |
6 | {{ page.title }}
7 | {{ content }}
8 |
9 |
--------------------------------------------------------------------------------
/docs/_layouts/post.html:
--------------------------------------------------------------------------------
1 | ---
2 | layout: default
3 | ---
4 |
5 |
6 | {{ page.title }}
7 | {{ page.date | date_to_string }}
8 | {{ content }}
9 |
10 |
11 | {% if site.related_posts != empty %}
12 |
25 | {% endif %}
26 |
--------------------------------------------------------------------------------
/docs/_posts/2024-06-28-StructuredFFN.md:
--------------------------------------------------------------------------------
1 | ---
2 | layout: post
3 | title: "Building on Efficient Foundations: Effectively Training LLMs with Structured Feedforward Layers"
4 | ---
5 |
6 | **Author list: Xiuying Wei (CLAIRE, EPFL), Skander Moalla (CLAIRE, EPFL), Razvan Pascanu (Google DeepMind), Caglar Gulcehre (CLAIRE, EPFL)**
7 |
8 | ## Abstract
9 |
10 | State-of-the-art results in large language models (LLMs) often rely on scale, which becomes computationally expensive. This has sparked a research agenda to reduce these models' parameter count and computational costs without significantly impacting their performance. Our study focuses on transformer-based LLMs, specifically targeting the computationally intensive feedforward networks (FFN), which are less studied than attention blocks. We consider three candidate linear layer approximations in the FFN by combining efficient low-rank and block-diagonal matrices. In contrast to many previous works that examined these approximations, our study i) explores these structures from the training-from-scratch perspective, ii) scales up to 1.3B parameters, and iii) is conducted within recent Transformer-based LLMs rather than convolutional architectures. We first demonstrate that they can lead to actual computational gains in various scenarios, including online decoding when using a pre-merge technique. Additionally, we propose a novel training regime, called *self-guided training*, aimed at improving the poor training dynamics that these approximations exhibit when used from initialization. Experiments on the large RefinedWeb dataset show that our methods are both efficient and effective for training and inference. Interestingly, these structured FFNs exhibit steeper scaling curves than the original models. Further applying self-guided training to the structured matrices with 32% of the FFN parameters and a 2.5$\times$ speed-up results in only a 0.4 perplexity increase under the same training FLOPs. Finally, we develop wide and structured networks that surpass the current medium-sized and large-sized Transformers in both perplexity and throughput.
11 |
12 |
13 | ## Method
14 |
15 | ### Structured linear parametrization
16 | We consider three structured parametrizations that approximate a linear layer ($Wx$), listed below, all of which have demonstrated computational gains on existing hardware. A minimal code sketch follows the list.
17 |
18 | * LowRank: $Wx \approx U^r(V^rx)$, where the superscript $^r$ indicates matrices projecting into or out of a low-dimensional space.
19 | * BlockShuffle (two block-diagonal matrices, same as Monarch [1]): $Wx \approx f^{-1}(U^b f(V^bx))$, where $V^b$ and $U^b$ are block-diagonal matrices and the shuffle function $f(\cdot)$ enables global feature mixing by cycling features across blocks.
20 | * BlockDense (a block-diagonal matrix followed by a dense matrix): $Wx \approx U^r(V^bx)$. Technically, the second projection does not need to be low-rank to reduce the parameter count, but in practice we chose the low-rank one (superscript $r$) to limit our search space.
21 |
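To make the three parametrizations concrete, here is a minimal PyTorch sketch. The module names, shapes, and initializations are our own illustration, not the repository's API; each module approximates a single dense linear layer.

```python
import torch
import torch.nn as nn


class LowRank(nn.Module):
    # W x ~ U^r (V^r x): two dense factors through a small rank r.
    def __init__(self, d_in, d_out, r):
        super().__init__()
        self.V = nn.Linear(d_in, r, bias=False)
        self.U = nn.Linear(r, d_out, bias=False)

    def forward(self, x):
        return self.U(self.V(x))


class BlockShuffle(nn.Module):
    # W x ~ f^{-1}(U^b f(V^b x)): two block-diagonal matrices with a shuffle
    # f(.) in between so that features can mix across blocks (Monarch-style).
    def __init__(self, d, nblocks):
        super().__init__()
        assert d % nblocks == 0
        nb, bs = nblocks, d // nblocks
        self.nb, self.bs = nb, bs
        self.V = nn.Parameter(torch.randn(nb, bs, bs) / bs**0.5)  # nb blocks of bs x bs
        self.U = nn.Parameter(torch.randn(bs, nb, nb) / nb**0.5)  # bs blocks of nb x nb

    def forward(self, x):                               # x: (..., d)
        t = x.reshape(-1, self.nb, self.bs)
        t = torch.einsum("tbi,bio->tbo", t, self.V)     # block-diagonal V^b
        t = t.transpose(1, 2)                           # shuffle f: regroup features
        t = torch.einsum("tbi,bio->tbo", t, self.U)     # block-diagonal U^b
        t = t.transpose(1, 2).reshape(x.shape)          # inverse shuffle f^{-1}
        return t


class BlockDense(nn.Module):
    # W x ~ U^r (V^b x): a block-diagonal projection followed by a dense one.
    def __init__(self, d, nblocks, r):
        super().__init__()
        assert d % nblocks == 0 and r % nblocks == 0
        self.nb, self.in_bs, self.out_bs = nblocks, d // nblocks, r // nblocks
        self.V = nn.Parameter(torch.randn(self.nb, self.in_bs, self.out_bs) / self.in_bs**0.5)
        self.U = nn.Linear(r, d, bias=False)

    def forward(self, x):                               # x: (..., d)
        t = x.reshape(-1, self.nb, self.in_bs)
        t = torch.einsum("tbi,bio->tbo", t, self.V).reshape(-1, self.nb * self.out_bs)
        return self.U(t).reshape(x.shape)
```

The parameter and MAC savings come from the rank $r$ or the number of blocks, which are the knobs varied in the experiments.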
22 | The figure below shows how these parametrizations operate, together with their reductions in parameters and MACs.
23 |
24 |
25 |
26 |
27 |
28 | We then investigate their common challenges, namely efficiency and optimization.
29 |
30 | ### Maintaining efficiency during online decoding
31 |
32 | Challenge: While these parametrizations deliver real computational gains, they struggle in the practical online-decoding scenario of LLMs, which may process only a few input tokens at a time. This under-utilizes the compute resources, and the additional linear projection then decreases efficiency.
33 |
34 | Pre-merge technique: We address this with a pre-merge technique that restores the original dense efficiency when the total number of tokens is very small (e.g., 16). Since these parametrizations contain no non-linearity, we can multiply the structured matrices into a single dense layer ahead of time and keep both the structured and the dense versions for online decoding. We then dynamically choose which parametrization to use based on the current batch size and setting, as sketched below.
35 |
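As a rough illustration of the pre-merge idea, the following sketch (a hypothetical inference-time wrapper reusing the LowRank module from the sketch above; the 16-token threshold follows the example in the text) multiplies the two factors once and switches between the dense and structured forms based on the number of tokens:

```python
import torch
import torch.nn as nn


class PreMergedLowRank(nn.Module):
    """Inference-time wrapper: pick the dense or the structured form per batch."""

    def __init__(self, lowrank, merge_threshold=16):
        super().__init__()
        self.lowrank = lowrank
        self.merge_threshold = merge_threshold
        # Pre-merge once: W_dense = U @ V, valid because U(Vx) has no non-linearity.
        self.register_buffer("w_dense", lowrank.U.weight @ lowrank.V.weight)

    @torch.no_grad()
    def forward(self, x):                        # x: (total_tokens, d_in)
        if x.shape[0] <= self.merge_threshold:
            # Very few tokens: a single dense matmul utilizes the GPU better
            # than two smaller structured matmuls.
            return x @ self.w_dense.t()
        # Many tokens: the structured form needs fewer FLOPs and is faster.
        return self.lowrank(x)
```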
36 |
37 |
38 | ### Addressing the optimization challenge
39 |
40 | Challenge: Using the efficient parametrization from initialization can suffer from optimization difficulties, because the deep linear parametrization introduces additional symmetries, which are a source of proliferating saddle points and of a generally less smooth loss landscape, as pointed out in [2]. Empirically, the figure below shows that the deep linear form $U(Vx)$ leads to instability and loss spikes, or to slower convergence, compared to the dense linear projection.
41 |
42 |
43 |
44 |
45 |
46 | Self-guided training: Addressing poor training dynamics by tuning the learning rate and gradient clipping is costly and unstable. We propose a simpler, cost-effective approach called self-guided training, which requires minimal hyperparameter re-tuning. This method uses the dense parametrization to navigate the early stages efficiently, where the symmetries introduced by the structured parametrization hurt feature specialization, and then transfers control to $U$ and $V$. It is defined as:
47 |
48 | $o = \alpha \cdot W x + (1-\alpha) \cdot U(Vx)$,
49 |
50 | where $o$ is the layer's output and $\alpha$ decays following a cosine schedule. As a residual component, the learning of $W$ is unaffected by the additional saddles and pathologies, allowing units to specialize. This *guides* the training of $U$ and $V$, which are slowly forced to take over while inheriting the hidden-unit semantics learned by $W$. The loss curves above show that this method substantially improves the training dynamics. A minimal sketch of the schedule is given below.
51 |
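The following sketch shows self-guided training for a single low-rank layer (hypothetical names, not the repository's implementation); the dense guide $W$ contributes through the cosine-decayed $\alpha$ and can be discarded once $\alpha$ reaches zero:

```python
import math
import torch
import torch.nn as nn


class SelfGuidedLowRank(nn.Module):
    def __init__(self, d, r, guide_steps):
        super().__init__()
        self.W = nn.Linear(d, d, bias=False)   # dense guide, dropped after guide_steps
        self.V = nn.Linear(d, r, bias=False)
        self.U = nn.Linear(r, d, bias=False)
        self.guide_steps = guide_steps
        self.register_buffer("step", torch.zeros((), dtype=torch.long))

    def alpha(self):
        # Cosine decay from 1 to 0 over guide_steps, then stays at 0.
        t = min(self.step.item(), self.guide_steps) / self.guide_steps
        return 0.5 * (1.0 + math.cos(math.pi * t))

    def forward(self, x):
        a = self.alpha()
        if self.training:
            self.step.add_(1)
        # o = alpha * W x + (1 - alpha) * U(V x)
        return a * self.W(x) + (1.0 - a) * self.U(self.V(x))
```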
52 | For more details, please check the paper.
53 |
54 | ## Experiments
55 |
56 | We conduct our experiments at scale on Transformers ranging from 110M to 1.3B parameters. We demonstrate the efficiency of these parametrizations, conduct a scaling analysis showing that structured matrices have steeper scaling curves than dense ones, and validate that self-guided training boosts the final performance efficiently. Finally, we design wide and structured networks by combining GQA [3], improving both perplexity and throughput.
57 |
58 | ### Evaluating latency results
59 |
60 | We investigate the efficiency of structured FFNs and consider different numbers of tokens to cover different scenarios.
61 |
62 | - Large number of tokens (usually concerning training, the prefill phase of inference, and extensive decoding cases)
63 |
64 | Starting from width 1536, LowRank and BlockDense begin to deliver about a 1.4$\times$ and a 2.5$\times$ speed-up with 63% and 32% of the FFN parameters, respectively.
65 |
66 |
67 |
68 |
69 |
70 | - Small number of tokens (may happen at the decoding stage, especially for the online case)
71 |
72 | We vary the number of tokens in a batch to determine when to use the efficient alternatives and when to fall back to the pre-merged dense matrices. For example, with a 2048-width FFN, it is difficult to fully utilize the GPU with few tokens. The gains improve significantly at widths 5120 and 6144, e.g., a 2.63$\times$ speed-up for LowRank with 32% of the FFN parameters at 2048 total tokens and a 2.81$\times$ speed-up for BlockDense with 32% parameters at 1536 tokens.
73 |
74 |
75 |
76 |
77 |
78 | ### Findings on efficient training
79 |
80 | - Comparison between structured FFNs
81 |
82 | With the model size and training FLOPs fixed, we show that LowRank and BlockDense can outperform BlockShuffle for the FFN on NLP tasks. However, we believe this is task-dependent: on vision tasks, where local information matters more, we find block-diagonal matrices to be a more suitable inductive bias (see experiments in the appendix).
83 |
84 | 
85 |
86 |
87 |
88 | - Scaling analysis
89 |
90 | As we scale the model size, we find steeper scaling curves for structured matrices. The figure below is for LowRank, but the other two parametrizations show similar curves. Specifically,
91 |
92 | *(i) The structured matrices exhibit steeper scaling curves compared to the dense networks, indicating significant potential for these efficient designs in LLMs.*
93 |
94 | *(ii) The scaling curve of the FFN with 32\% of the parameters is steeper than that with 63\%, which highlights the scaling potential of highly structured large models.*
95 |
96 | *(iii) Given a fixed training FLOPs budget, a wider and structured network trained on more tokens may achieve comparable or superior performance to dense networks at the optimal trade-off.*
97 |
98 |
99 |
100 |
101 | ### Self-guided training
102 |
103 | With self-guided training, our performance gets closer to that of dense models. For example, with the same training FLOPs, our 1.3B model incurs only a 0.4 perplexity increase vs. the dense one while enjoying about a 2.5$\times$ FFN speed-up at inference. Additionally, we compare our method with a stronger baseline that trains the structured parametrizations on more tokens, and show that ours achieves comparable or superior results even with the same number of tokens.
104 |
105 |
106 |
107 |
108 |
109 | ### Wide and Structured network
110 |
111 | As maintaining the parameter ratio of attention to FFN can be important, in this section we use GQA [3] to make attention efficient and LowRank for the FFN, designing wide and structured networks from Transformer-m and Transformer-l. To match the training FLOPs, we either train on more tokens or apply self-guided training. A minimal sketch of one such block is given below.
112 |
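For illustration, here is a simplified sketch of one block of such a wide-and-structured design (our own module, not the repository's implementation), combining GQA-style attention with fewer key/value heads and a wider FFN kept cheap by the LowRank module from the first sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WideStructuredBlock(nn.Module):
    def __init__(self, d, n_heads, n_kv_heads, ffn_width, ffn_rank):
        super().__init__()
        assert d % n_heads == 0 and n_heads % n_kv_heads == 0
        self.hd, self.nh, self.nkv = d // n_heads, n_heads, n_kv_heads
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.q = nn.Linear(d, n_heads * self.hd, bias=False)
        self.kv = nn.Linear(d, 2 * n_kv_heads * self.hd, bias=False)   # fewer K/V heads (GQA)
        self.proj = nn.Linear(d, d, bias=False)
        # Wider FFN, parameterized with LowRank to keep it cheap.
        self.ffn = nn.Sequential(
            LowRank(d, ffn_width, ffn_rank), nn.GELU(), LowRank(ffn_width, d, ffn_rank)
        )

    def forward(self, x):                                   # x: (B, T, d)
        B, T, _ = x.shape
        h = self.ln1(x)
        q = self.q(h).view(B, T, self.nh, self.hd).transpose(1, 2)
        k, v = self.kv(h).view(B, T, 2, self.nkv, self.hd).unbind(dim=2)
        # Share each K/V head across a group of query heads.
        k = k.transpose(1, 2).repeat_interleave(self.nh // self.nkv, dim=1)
        v = v.transpose(1, 2).repeat_interleave(self.nh // self.nkv, dim=1)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        x = x + self.proj(attn.transpose(1, 2).reshape(B, T, -1))
        return x + self.ffn(self.ln2(x))
```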
113 | Our methods achieve an 8% and a 17% maximum-throughput boost, respectively, while maintaining or slightly improving perplexity. TP refers to the maximum throughput measured at a generation length of 256.
114 |
115 |
116 |
117 |
118 |
119 | ## Conclusion and Limitation
120 |
121 | Conclusion: In this paper, we conducted extensive experiments investigating the use of structured matrices to parameterize the FFN in Transformers, with models up to 1.3B parameters trained on the RefinedWeb dataset. Our primary aim was not to determine which structured matrix performs best, as this can be task-dependent, but to explore issues common to existing structured matrices as well as BlockDense, namely their efficiency and optimization challenges.
122 |
123 | Limitation: BlockDense and BlockShuffle are more complicated than LowRank, and in this work we only explored a limited range of their hyperparameter settings. We also primarily focused on language modeling, with limited vision experiments included in the appendix. Additionally, we did not explore the optimal scaling laws for structured matrices, which may further enhance performance.
124 |
125 | ## References
126 |
127 | [1] Monarch: Expressive Structured Matrices for Efficient and Accurate Training. ICML 2022.
128 |
129 | [2] Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. ICLR 2014.
130 |
131 | [3] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023.
132 |
133 | ## Useful Links
134 |
135 | Paper: https://arxiv.org/pdf/2406.16450
136 |
137 | Code: https://github.com/CLAIRE-Labo/StructuredFFN
--------------------------------------------------------------------------------
/docs/_sass/_base.scss:
--------------------------------------------------------------------------------
1 | // Body resets
2 | //
3 | // Update the foundational and global aspects of the page.
4 |
5 | * {
6 | box-sizing: border-box;
7 | }
8 |
9 | body {
10 | margin: 0;
11 | font-family: var(--body-font);
12 | font-size: var(--body-font-size);
13 | line-height: var(--body-line-height);
14 | color: var(--body-color);
15 | background-color: var(--body-bg);
16 | -webkit-text-size-adjust: 100%;
17 | -ms-text-size-adjust: 100%;
18 | }
19 |
20 | // No `:visited` state is required by default (browsers will use `a`)
21 | a {
22 | color: var(--link-color);
23 |
24 | // `:focus` is linked to `:hover` for basic accessibility
25 | &:hover,
26 | &:focus {
27 | color: var(--link-hover-color);
28 | }
29 |
30 | strong {
31 | color: inherit;
32 | }
33 | }
34 |
35 | img {
36 | display: block;
37 | margin: auto;
38 | max-width: 100%;
39 | margin-bottom: var(--spacer);
40 | border-radius: var(--border-radius);
41 | }
42 |
43 | table {
44 | margin-bottom: 1rem;
45 | width: 100%;
46 | border: 0 solid var(--border-color);
47 | border-collapse: collapse;
48 | }
49 |
50 | td,
51 | th {
52 | padding: .25rem .5rem;
53 | border-color: inherit;
54 | border-style: solid;
55 | border-width: 0;
56 | border-bottom-width: 1px;
57 | }
58 |
59 |
60 | th {
61 | text-align: left;
62 | }
63 |
64 | thead th {
65 | border-bottom-color: currentColor;
66 | }
67 |
68 | mark {
69 | padding: .15rem;
70 | background-color: var(--yellow-100);
71 | border-radius: .125rem;
72 | }
73 |
74 | p {
75 | text-align: justify;
76 | }
77 |
--------------------------------------------------------------------------------
/docs/_sass/_code.scss:
--------------------------------------------------------------------------------
1 | // Code
2 | //
3 | // Inline and block-level code snippets. Includes tweaks to syntax highlighted
4 | // snippets from Pygments/Rouge and Gist embeds.
5 |
6 | code,
7 | pre {
8 | font-family: var(--code-font);
9 | }
10 |
11 | code {
12 | font-size: 85%;
13 | }
14 |
15 | pre {
16 | display: block;
17 | margin-top: 0;
18 | margin-bottom: var(--spacer-3);
19 | overflow: auto;
20 | }
21 |
22 | .highlight {
23 | padding: var(--spacer);
24 | margin-bottom: var(--spacer);
25 | background-color: var(--code-bg);
26 | border-radius: var(--border-radius);
27 |
28 | pre {
29 | margin-bottom: 0;
30 | }
31 |
32 | // Triple backticks (code fencing) doubles the .highlight elements
33 | .highlight {
34 | padding: 0;
35 | }
36 | }
37 |
38 | .rouge-table {
39 | margin-bottom: 0;
40 | font-size: 100%;
41 |
42 | &,
43 | td,
44 | th {
45 | border: 0;
46 | }
47 |
48 | .gutter {
49 | vertical-align: top;
50 | user-select: none;
51 | opacity: .25;
52 | }
53 | }
54 |
55 | // Gist via GitHub Pages
56 | .gist .markdown-body {
57 | padding: 15px !important;
58 | }
59 |
--------------------------------------------------------------------------------
/docs/_sass/_layout.scss:
--------------------------------------------------------------------------------
1 | // Layout
2 | //
3 | // Styles for managing the structural hierarchy of the site.
4 |
5 | .container {
6 | max-width: 50%;
7 | padding-left: var(--spacer-2);
8 | padding-right: var(--spacer-2);
9 | margin-left: auto;
10 | margin-right: auto;
11 | }
12 |
13 | footer {
14 | margin-top: var(--spacer-3);
15 | margin-bottom: var(--spacer-3);
16 | }
17 |
--------------------------------------------------------------------------------
/docs/_sass/_masthead.scss:
--------------------------------------------------------------------------------
1 | // Masthead
2 | //
3 | // Super small header above the content for site name and short description.
4 |
5 | .masthead {
6 | padding-top: var(--spacer);
7 | padding-bottom: var(--spacer);
8 | margin-bottom: var(--spacer-3);
9 | }
10 |
11 | .masthead-title {
12 | margin-bottom: 0;
13 |
14 | a {
15 | color: inherit;
16 | text-decoration: none;
17 | }
18 |
19 | small {
20 | font-weight: 400;
21 | opacity: 0.5;
22 | }
23 | }
24 |
25 | // Navbar styles
26 | .nav {
27 | float: right;
28 | line-height: 1.25rem;
29 | word-spacing: 1rem;
30 | }
31 |
--------------------------------------------------------------------------------
/docs/_sass/_message.scss:
--------------------------------------------------------------------------------
1 | // Messages
2 | //
3 | // Show alert messages to users. You may add it to single elements like a ``,
4 | // or to a parent if there are multiple elements to show.
5 |
6 | .message {
7 | padding: var(--spacer);
8 | margin-bottom: var(--spacer);
9 | color: var(--gray-900);
10 | background-color: var(--yellow-100);
11 | border-radius: var(--border-radius);
12 | }
13 |
--------------------------------------------------------------------------------
/docs/_sass/_pagination.scss:
--------------------------------------------------------------------------------
1 | // Pagination
2 | //
3 | // Super lightweight (HTML-wise) blog pagination. `span`s are provide for when
4 | // there are no more previous or next posts to show.
5 |
6 | .pagination {
7 | display: flex;
8 | margin: 0 -1.5rem var(--spacer);
9 | color: grey;
10 | text-align: center;
11 | }
12 |
13 | // Pagination items can be `span`s or `a`s
14 | .pagination-item {
15 | display: block;
16 | padding: var(--spacer);
17 | text-decoration: none;
18 | border: solid var(--border-color);
19 | border-width: 1px 0;
20 |
21 | &:first-child {
22 | margin-bottom: -1px;
23 | }
24 | }
25 |
26 | // Only provide a hover state for linked pagination items
27 | a.pagination-item:hover {
28 | background-color: var(--border-color);
29 | }
30 |
31 | @media (min-width: 30em) {
32 | .pagination {
33 | margin: var(--spacer-3) 0;
34 | }
35 |
36 | .pagination-item {
37 | float: left;
38 | width: 50%;
39 | border-width: 1px;
40 |
41 | &:first-child {
42 | margin-bottom: 0;
43 | border-top-left-radius: var(--border-radius);
44 | border-bottom-left-radius: var(--border-radius);
45 | }
46 | &:last-child {
47 | margin-left: -1px;
48 | border-top-right-radius: var(--border-radius);
49 | border-bottom-right-radius: var(--border-radius);
50 | }
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/docs/_sass/_posts.scss:
--------------------------------------------------------------------------------
1 | // Posts and pages
2 | //
3 | // Each post is wrapped in `.post` and is used on default and post layouts. Each
4 | // page is wrapped in `.page` and is only used on the page layout.
5 |
6 | .page,
7 | .post {
8 | margin-bottom: 4em;
9 |
10 | li + li {
11 | margin-top: .25rem;
12 | }
13 | }
14 |
15 | // Blog post or page title
16 | .page-title,
17 | .post-title {
18 | color: var(--heading-color);
19 | }
20 | .page-title,
21 | .post-title {
22 | margin-top: 0;
23 | }
24 | .post-title a {
25 | color: inherit;
26 | text-decoration: none;
27 |
28 | &:hover,
29 | &:focus {
30 | text-decoration: underline;
31 | }
32 | }
33 |
34 | // Meta data line below post title
35 | .post-date {
36 | display: block;
37 | margin-top: -.5rem;
38 | margin-bottom: var(--spacer);
39 | color: var(--gray-600);
40 | }
41 |
42 |
43 | // Related posts
44 | .related {
45 | padding-top: var(--spacer-2);
46 | padding-bottom: var(--spacer-2);
47 | margin-bottom: var(--spacer-2);
48 | border-top: 1px solid var(--border-color);
49 | border-bottom: 1px solid var(--border-color);
50 | }
51 |
52 | .related-posts {
53 | padding-left: 0;
54 | list-style: none;
55 |
56 | h3 {
57 | margin-top: 0;
58 | }
59 |
60 | a {
61 | text-decoration: none;
62 |
63 | small {
64 | color: var(--gray-600);
65 | }
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/docs/_sass/_syntax.scss:
--------------------------------------------------------------------------------
1 | .highlight .hll { background-color: #ffc; }
2 | .highlight .c { color: #999; } /* Comment */
3 | .highlight .err { color: #a00; background-color: #faa } /* Error */
4 | .highlight .k { color: #069; } /* Keyword */
5 | .highlight .o { color: #555 } /* Operator */
6 | .highlight .cm { color: #09f; font-style: italic } /* Comment.Multiline */
7 | .highlight .cp { color: #099 } /* Comment.Preproc */
8 | .highlight .c1 { color: #999; } /* Comment.Single */
9 | .highlight .cs { color: #999; } /* Comment.Special */
10 | .highlight .gd { background-color: #fcc; border: 1px solid #c00 } /* Generic.Deleted */
11 | .highlight .ge { font-style: italic } /* Generic.Emph */
12 | .highlight .gr { color: #f00 } /* Generic.Error */
13 | .highlight .gh { color: #030; } /* Generic.Heading */
14 | .highlight .gi { background-color: #cfc; border: 1px solid #0c0 } /* Generic.Inserted */
15 | .highlight .go { color: #aaa } /* Generic.Output */
16 | .highlight .gp { color: #009; } /* Generic.Prompt */
17 | .highlight .gs { } /* Generic.Strong */
18 | .highlight .gu { color: #030; } /* Generic.Subheading */
19 | .highlight .gt { color: #9c6 } /* Generic.Traceback */
20 | .highlight .kc { color: #069; } /* Keyword.Constant */
21 | .highlight .kd { color: #069; } /* Keyword.Declaration */
22 | .highlight .kn { color: #069; } /* Keyword.Namespace */
23 | .highlight .kp { color: #069 } /* Keyword.Pseudo */
24 | .highlight .kr { color: #069; } /* Keyword.Reserved */
25 | .highlight .kt { color: #078; } /* Keyword.Type */
26 | .highlight .m { color: #f60 } /* Literal.Number */
27 | .highlight .s { color: #d44950 } /* Literal.String */
28 | .highlight .na { color: #4f9fcf } /* Name.Attribute */
29 | .highlight .nb { color: #366 } /* Name.Builtin */
30 | .highlight .nc { color: #0a8; } /* Name.Class */
31 | .highlight .no { color: #360 } /* Name.Constant */
32 | .highlight .nd { color: #99f } /* Name.Decorator */
33 | .highlight .ni { color: #999; } /* Name.Entity */
34 | .highlight .ne { color: #c00; } /* Name.Exception */
35 | .highlight .nf { color: #c0f } /* Name.Function */
36 | .highlight .nl { color: #99f } /* Name.Label */
37 | .highlight .nn { color: #0cf; } /* Name.Namespace */
38 | .highlight .nt { color: #2f6f9f; } /* Name.Tag */
39 | .highlight .nv { color: #033 } /* Name.Variable */
40 | .highlight .ow { color: #000; } /* Operator.Word */
41 | .highlight .w { color: #bbb } /* Text.Whitespace */
42 | .highlight .mf { color: #f60 } /* Literal.Number.Float */
43 | .highlight .mh { color: #f60 } /* Literal.Number.Hex */
44 | .highlight .mi { color: #f60 } /* Literal.Number.Integer */
45 | .highlight .mo { color: #f60 } /* Literal.Number.Oct */
46 | .highlight .sb { color: #c30 } /* Literal.String.Backtick */
47 | .highlight .sc { color: #c30 } /* Literal.String.Char */
48 | .highlight .sd { color: #c30; font-style: italic } /* Literal.String.Doc */
49 | .highlight .s2 { color: #c30 } /* Literal.String.Double */
50 | .highlight .se { color: #c30; } /* Literal.String.Escape */
51 | .highlight .sh { color: #c30 } /* Literal.String.Heredoc */
52 | .highlight .si { color: #a00 } /* Literal.String.Interpol */
53 | .highlight .sx { color: #c30 } /* Literal.String.Other */
54 | .highlight .sr { color: #3aa } /* Literal.String.Regex */
55 | .highlight .s1 { color: #c30 } /* Literal.String.Single */
56 | .highlight .ss { color: #fc3 } /* Literal.String.Symbol */
57 | .highlight .bp { color: #366 } /* Name.Builtin.Pseudo */
58 | .highlight .vc { color: #033 } /* Name.Variable.Class */
59 | .highlight .vg { color: #033 } /* Name.Variable.Global */
60 | .highlight .vi { color: #033 } /* Name.Variable.Instance */
61 | .highlight .il { color: #f60 } /* Literal.Number.Integer.Long */
62 |
63 | .css .o,
64 | .css .o + .nt,
65 | .css .nt + .nt { color: #999; }
66 |
--------------------------------------------------------------------------------
/docs/_sass/_toc.scss:
--------------------------------------------------------------------------------
1 | // Table of Contents
2 |
3 | #markdown-toc {
4 | padding: var(--spacer-2) var(--spacer-3);
5 | margin-bottom: var(--spacer-2);
6 | border: solid var(--border-color);
7 | border-width: 1px 0;
8 |
9 | &::before {
10 | display: block;
11 | margin-left: calc(var(--spacer-3) * -1);
12 | content: "Contents";
13 | font-size: 85%;
14 | font-weight: 500;
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/docs/_sass/_type.scss:
--------------------------------------------------------------------------------
1 | // Typography
2 | //
3 | // Headings, body text, lists, and other misc typographic elements.
4 |
5 | h1, h2, h3, h4, h5, h6 {
6 | margin-bottom: .5rem;
7 | font-weight: 600;
8 | line-height: 1.25;
9 | color: var(--heading-color);
10 | }
11 |
12 | h1 {
13 | font-size: 2rem;
14 | }
15 |
16 | h2 {
17 | margin-top: 1rem;
18 | font-size: 1.5rem;
19 | }
20 |
21 | h3 {
22 | margin-top: 1.5rem;
23 | font-size: 1.25rem;
24 | }
25 |
26 | h4, h5, h6 {
27 | margin-top: 1rem;
28 | font-size: 1rem;
29 | }
30 |
31 | p {
32 | margin-top: 0;
33 | margin-bottom: 1rem;
34 | }
35 |
36 | ul, ol, dl {
37 | margin-top: 0;
38 | margin-bottom: 1rem;
39 | }
40 |
41 | dt {
42 | font-weight: bold;
43 | }
44 |
45 | dd {
46 | margin-bottom: .5rem;
47 | }
48 |
49 | hr {
50 | position: relative;
51 | margin: var(--spacer-2) 0;
52 | border: 0;
53 | border-top: 1px solid var(--border-color);
54 | }
55 |
56 | abbr {
57 | font-size: 85%;
58 | font-weight: bold;
59 | color: var(--gray-600);
60 | text-transform: uppercase;
61 |
62 | &[title] {
63 | cursor: help;
64 | border-bottom: 1px dotted var(--border-color);
65 | }
66 | }
67 |
68 | blockquote {
69 | padding: .5rem 1rem;
70 | margin: .8rem 0;
71 | color: var(--gray-500);
72 | border-left: .25rem solid var(--border-color);
73 |
74 | p:last-child {
75 | margin-bottom: 0;
76 | }
77 |
78 | @media (min-width: 30em) {
79 | padding-right: 5rem;
80 | padding-left: 1.25rem;
81 | }
82 | }
83 |
84 | figure {
85 | margin: 0;
86 | }
87 |
88 |
89 | // Markdown footnotes
90 | //
91 | // See the example content post for an example.
92 |
93 | // Footnote number within body text
94 | a[href^="#fn:"],
95 | // Back to footnote link
96 | a[href^="#fnref:"] {
97 | display: inline-block;
98 | margin-left: .1rem;
99 | font-weight: bold;
100 | }
101 |
102 | // List of footnotes
103 | .footnotes {
104 | margin-top: 2rem;
105 | font-size: 85%;
106 | }
107 |
108 | // Custom type
109 | //
110 | // Extend paragraphs with `.lead` for larger introductory text.
111 |
112 | .lead {
113 | font-size: 1.25rem;
114 | font-weight: 300;
115 | }
116 |
--------------------------------------------------------------------------------
/docs/_sass/_variables.scss:
--------------------------------------------------------------------------------
1 | :root {
2 | --gray-000: #f8f9fa;
3 | --gray-100: #f1f3f5;
4 | --gray-200: #e9ecef;
5 | --gray-300: #dee2e6;
6 | --gray-400: #ced4da;
7 | --gray-500: #adb5bd;
8 | --gray-600: #868e96;
9 | --gray-700: #495057;
10 | --gray-800: #343a40;
11 | --gray-900: #212529;
12 | --dark-poole-001: hsl(200, 3%, 12%);
13 | --dark-poole-002: hsl(0, 0%, 85%);
14 | --dark-poole-link-color: rgba(255, 255, 255, 0.75);
15 | --dark-poole-link-hover: #fff;
16 |
17 | --red: #fa5252;
18 | --pink: #e64980;
19 | --grape: #be4bdb;
20 | --purple: #7950f2;
21 | --indigo: #4c6ef5;
22 | --blue: #228be6;
23 | --cyan: #15aabf;
24 | --teal: #12b886;
25 | --green: #40c057;
26 | --yellow: #fab005;
27 | --orange: #fd7e14;
28 |
29 | --blue-300: #74c0fc;
30 | --blue-400: #4dabf7;
31 | --yellow-100: #fff3bf;
32 |
33 | --body-font: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto,
34 | "Helvetica Neue", Arial, "Noto Sans", sans-serif, "Apple Color Emoji",
35 | "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji";
36 | --body-font-size: 20px;
37 | --body-line-height: 1.5;
38 | --body-color: var(--gray-700);
39 | --body-bg: #fff;
40 |
41 | --link-color: var(--blue);
42 | --link-hover-color: #1c7ed6;
43 |
44 | --heading-color: var(--gray-900);
45 |
46 | --border-color: var(--gray-300);
47 | --border-radius: 0.25rem;
48 |
49 | --code-font: SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono",
50 | "Courier New", monospace;
51 | --code-color: var(--grape);
52 | --code-bg: var(--gray-000);
53 |
54 | --spacer: 1rem;
55 | --spacer-2: calc(var(--spacer) * 1.5);
56 | --spacer-3: calc(var(--spacer) * 3);
57 | }
58 |
59 | @media (prefers-color-scheme: dark) {
60 | :root {
61 | --body-color: var(--gray-300);
62 | --body-bg: var(--gray-800);
63 |
64 | --heading-color: #fff;
65 |
66 | --link-color: var(--blue-300);
67 | --link-hover-color: var(--blue-400);
68 |
69 | --border-color: rgba(255, 255, 255, 0.15);
70 |
71 | --code-bg: var(--gray-900);
72 | }
73 | }
74 |
75 | // StructuredFFN theme
76 | [data-theme="dark-poole"] {
77 | --body-color: var(--dark-poole-002);
78 | --body-bg: var(--dark-poole-001);
79 | --heading-color: var(--dark-poole-002);
80 | --link-color: var(--dark-poole-link-color);
81 | --link-hover-color: var(--dark-poole-link-hover);
82 | --border-color: rgba(255, 255, 255, 0.15);
83 | --code-bg: var(--gray-900);
84 | }
85 |
--------------------------------------------------------------------------------
/docs/assets/apple-touch-icon-precomposed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/apple-touch-icon-precomposed.png
--------------------------------------------------------------------------------
/docs/assets/author.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/author.png
--------------------------------------------------------------------------------
/docs/assets/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/favicon.ico
--------------------------------------------------------------------------------
/docs/assets/fig_sgt_lowrank.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/fig_sgt_lowrank.png
--------------------------------------------------------------------------------
/docs/assets/gpt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/gpt.png
--------------------------------------------------------------------------------
/docs/assets/latency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/latency.png
--------------------------------------------------------------------------------
/docs/assets/latency_bs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/latency_bs.png
--------------------------------------------------------------------------------
/docs/assets/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/method.png
--------------------------------------------------------------------------------
/docs/assets/scaling_law_lowrank.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/scaling_law_lowrank.png
--------------------------------------------------------------------------------
/docs/assets/training_dynamic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/training_dynamic.png
--------------------------------------------------------------------------------
/docs/assets/wide_structured.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/docs/assets/wide_structured.png
--------------------------------------------------------------------------------
/docs/atom.xml:
--------------------------------------------------------------------------------
1 | ---
2 | layout: null
3 | ---
4 |
5 |
6 |
7 |
8 | {{ site.title }}
9 |
10 |
11 | {{ site.time | date_to_xmlschema }}
12 | {{ site.url }}
13 |
14 | {{ site.author.name }}
15 | {{ site.author.email }}
16 |
17 |
18 | {% for post in site.posts %}
19 |
20 | {{ post.title | xml_escape }}
21 |
22 | {{ post.date | date_to_xmlschema }}
23 | {{ site.url }}{{ post.id }}
24 | {{ post.content | xml_escape }}
25 |
26 | {% endfor %}
27 |
28 |
29 |
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
1 | ---
2 | layout: default
3 | title: Home
4 | ---
5 |
6 |
7 | {% for post in paginator.posts %}
8 |
9 |
10 |
11 | {{ post.title }}
12 |
13 |
14 |
15 | {{ post.date | date_to_string }}
16 |
17 | {{ post.content }}
18 |
19 | {% endfor %}
20 |
21 |
22 |
34 |
--------------------------------------------------------------------------------
/docs/poole-for-jekyll.gemspec:
--------------------------------------------------------------------------------
1 | # frozen_string_literal: true
2 |
3 | Gem::Specification.new do |spec|
4 | spec.name = "poole-for-jekyll"
5 | spec.version = "3.0.0"
6 | spec.authors = ["Mark Otto"]
7 | spec.email = ["markdotto@gmail.com"]
8 |
9 | spec.summary = "The Jekyll Butler. A no frills responsive Jekyll blog theme."
10 | spec.homepage = "https://getpoole.com"
11 | spec.license = "MIT"
12 |
13 | spec.files = `git ls-files -z`.split("\x0").select { |f| f.match(%r!^(assets|_layouts|_includes|_sass|LICENSE|README)!i) }
14 |
15 | spec.add_runtime_dependency "jekyll", "~> 4.0"
16 |
17 | spec.add_development_dependency "bundler", "~> 1.16"
18 | spec.add_development_dependency "rake", "~> 12.0"
19 | end
20 |
--------------------------------------------------------------------------------
/docs/styles.scss:
--------------------------------------------------------------------------------
1 | ---
2 | # Use a comment to ensure Jekyll reads the file to be transformed into CSS later
3 | # only main files contain this front matter, not partials.
4 | ---
5 |
6 | //
7 | // ___
8 | // /\_ \
9 | // _____ ___ ___\//\ \ __
10 | // /\ '__`\ / __`\ / __`\\ \ \ /'__`\
11 | // \ \ \_\ \/\ \_\ \/\ \_\ \\_\ \_/\ __/
12 | // \ \ ,__/\ \____/\ \____//\____\ \____\
13 | // \ \ \/ \/___/ \/___/ \/____/\/____/
14 | // \ \_\
15 | // \/_/
16 | //
17 | // Designed, built, and released under MIT license by @mdo. Learn more at
18 | // https://github.com/poole/poole.
19 |
20 | @import "variables";
21 | @import "base";
22 | @import "type";
23 | @import "syntax";
24 | @import "code";
25 | @import "layout";
26 | @import "masthead";
27 | @import "posts";
28 | @import "pagination";
29 | @import "message";
30 | @import "toc";
31 |
32 | // Sass for creating the swatches
33 | .colors {
34 | display: grid;
35 | grid-template-columns: max-content 1fr;
36 |
37 | dt {
38 | width: 3rem;
39 | height: 3rem;
40 | border-radius: var(--border-radius);
41 | box-shadow: inset 0 0 0 1px rgba(255,255,255,.15);
42 | }
43 |
44 | dd {
45 | margin-left: var(--spacer);
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/experiment/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:24.02-py3
2 | LABEL maintainer="Xiuying Wei"
3 |
4 |
5 | ARG DEBIAN_FRONTEND=noninteractive
6 |
7 | # package install
8 | RUN apt-get update && apt-get install -y \
9 | curl vim htop\
10 | ca-certificates \
11 | openssh-server \
12 | cmake \
13 | sudo \
14 | git \
15 | bzip2 \
16 | libx11-6 \
17 | zip \
18 | unzip ssh \
19 | tmux \
20 | && rm -rf /var/lib/apt/lists/*
21 |
22 |
23 | # Install Python 3.8 with Miniconda
24 | #RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py38_4.10.3-Linux-x86_64.sh -O ~/miniconda.sh \
25 | # && /bin/bash ~/miniconda.sh -b -p /opt/conda \
26 | # && rm ~/miniconda.sh \
27 | # && /opt/conda/bin/conda install mkl numpy scipy pandas openmpi ipython jupyter \
28 | # && /opt/conda/bin/conda clean --all -y
29 |
30 |
31 | # ENV PATH="~/.local/bin:/opt/conda/bin:/usr/local/cuda/bin:${PATH}" \
32 | # LD_LIBRARY_PATH="/usr/local/cuda/lib64"
33 | ENV PATH="~/.local/bin:/usr/local/cuda/bin:${PATH}" \
34 | LD_LIBRARY_PATH="/usr/local/cuda/lib64"
35 |
36 | # Make $PATH and $LD_LIBRARY PATH available to all users
37 | RUN echo PATH="${PATH}" >> /etc/environment && \
38 | echo LD_LIBRARY_PATH="${LD_LIBRARY_PATH}" >> /etc/environment
39 |
40 | # transformers==4.34.0
41 | # datasets
42 | # evaluate
43 | # accelerate
44 | # RUN pip uninstall transformer-engine --yes
45 | # The following two rows are for butterfly
46 | RUN pip --no-cache-dir install \
47 | easydict \
48 | h5py \
49 | pyyaml \
50 | tqdm \
51 | pillow \
52 | protobuf \
53 | seaborn \
54 | scipy \
55 | scikit-learn \
56 | wandb \
57 | hydra-core \
58 | transformers==4.34.0 \
59 | datasets \
60 | evaluate \
61 | accelerate \
62 | sentencepiece
63 |
64 | # RUN pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121
65 | # RUN pip3 install --upgrade flash-attn==2.4.2 --no-build-isolation
66 | # entrypoint
67 | RUN pip install --upgrade protobuf==3.20.0
68 | ENV ENTRYPOINTS_ROOT=/opt/entrypoints
69 | RUN mkdir -p ${ENTRYPOINTS_ROOT}
70 |
71 |
72 | # The entrypoint is run in an interactive shell so that the conda environment is activated before.
73 | # Don't overwrite the entrypoint, it is installing the project
74 | # and testing that you correctly mounted the project code and data and output directories.
75 | # It also performs some other important setup depending on the deployment platform.
76 | COPY --link entrypoint.sh ${ENTRYPOINTS_ROOT}/entrypoint.sh
77 | ENTRYPOINT ["/bin/bash", "-i", "/opt/entrypoints/entrypoint.sh"]
78 | CMD ["/bin/bash"]
79 |
80 |
81 | # userconfig
82 | # define your own config here
83 |
84 |
--------------------------------------------------------------------------------
/experiment/basic.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | s_token=2200000000
3 | m_token=6700000000
4 | l_token=14580000000
5 | xl_token=25500000000
6 | s_lr=0.0006
7 | m_lr=0.0003
8 | l_lr=0.00025
9 | xl_lr=0.0002
10 | s_train_batch=64
11 | s_test_batch=64
12 | m_train_batch=32
13 | m_test_batch=32
14 | l_train_batch=16
15 | l_test_batch=32
16 | xl_train_batch=16
17 | xl_test_batch=16
18 |
19 |
20 | # dense
21 | ./run_gpt.sh "gpt2s" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=linear optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} data.test.test_batch=${s_test_batch} optimization.log_interval=100"
22 | ./run_gpt.sh "gpt2m" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=linear optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100"
23 | ./run_gpt.sh "gpt2l" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=linear optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100"
24 | ./run_gpt.sh "gpt2xl" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=linear optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100"
25 |
26 | # LowRank
27 | ./run_gpt.sh "gpt2s-lr-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=384 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} data.test.test_batch=${s_test_batch} optimization.log_interval=100"
28 | ./run_gpt.sh "gpt2s-lr-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=192 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} data.test.test_batch=${s_test_batch} optimization.log_interval=100"
29 | ./run_gpt.sh "gpt2m-lr-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowrank method.kwargs.rank=512 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100"
30 | ./run_gpt.sh "gpt2m-lr-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowrank method.kwargs.rank=256 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100"
31 | ./run_gpt.sh "gpt2l-lr-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=lowrank method.kwargs.rank=768 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100"
32 | ./run_gpt.sh "gpt2l-lr-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=lowrank method.kwargs.rank=384 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100"
33 | ./run_gpt.sh "gpt2xl-lr-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=lowrank method.kwargs.rank=1024 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100"
34 | ./run_gpt.sh "gpt2xl-lr-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=lowrank method.kwargs.rank=512 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100"
35 |
36 | # BlockDense
37 | ./run_gpt.sh "gpt2s-bld-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=512 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} data.test.test_batch=${s_test_batch} optimization.log_interval=100"
38 | ./run_gpt.sh "gpt2s-bld-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=256 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} data.test.test_batch=${s_test_batch} optimization.log_interval=100"
39 | ./run_gpt.sh "gpt2m-bld-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=768 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100"
40 | ./run_gpt.sh "gpt2m-bld-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=384 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100"
41 | ./run_gpt.sh "gpt2l-bld-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=1024 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100"
42 | ./run_gpt.sh "gpt2l-bld-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=512 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100"
43 | ./run_gpt.sh "gpt2xl-bld-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=1536 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100"
44 | ./run_gpt.sh "gpt2xl-bld-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=768 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100"
45 |
46 | # BlockShuffle
47 | ./run_gpt.sh "gpt2s-bls-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockshuffle method.kwargs.nblocks=2 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} data.test.test_batch=${s_test_batch} optimization.log_interval=100"
48 | ./run_gpt.sh "gpt2s-bls-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${s_train_batch} data.test.test_batch=${s_test_batch} optimization.log_interval=100"
49 | ./run_gpt.sh "gpt2m-bls-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockshuffle method.kwargs.nblocks=2 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100"
50 | ./run_gpt.sh "gpt2m-bls-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${m_train_batch} data.test.test_batch=${m_test_batch} optimization.log_interval=100"
51 | ./run_gpt.sh "gpt2l-bls-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=blockshuffle method.kwargs.nblocks=2 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100"
52 | ./run_gpt.sh "gpt2l-bls-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${l_train_batch} data.test.test_batch=${l_test_batch} optimization.log_interval=100"
53 | ./run_gpt.sh "gpt2xl-bls-0x63" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=blockshuffle method.kwargs.nblocks=2 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100"
54 | ./run_gpt.sh "gpt2xl-bls-0x33" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${xl_train_batch} data.test.test_batch=${xl_test_batch} optimization.log_interval=100"
55 |
--------------------------------------------------------------------------------
/experiment/sgd.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | s_token=2200000000
3 | m_token=6700000000
4 | l_token=14580000000
5 | xl_token=25500000000
6 | s_lr=0.0006
7 | m_lr=0.0003
8 | l_lr=0.00025
9 | xl_lr=0.0002
10 | train_batch=8
11 | test_batch=8
12 | s_lr_ratio=0.30
13 | m_lr_ratio=0.38
14 | l_lr_ratio=0.41
15 | xl_lr_ratio=0.43
16 | # The parameters for BlockDense with about 32\% parameters are not matched with the LowRank and BlockShuffle exactly. Thus, we provide the max_step_ratio of self-guided training for BlockDense separately to exactly match the training FLOPs.
17 |
18 | s_bld_ratio=0.30
19 | m_bld_ratio=0.40
20 | l_bld_ratio=0.41
21 | xl_bld_ratio=0.45
22 |
23 | # apply self-guided training for the first half of training
24 | ./run_gpt.sh "gpt2s-lr-0x33-sgd-fixedstep" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=192 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedstep optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.5"
25 | ./run_gpt.sh "gpt2m-lr-0x33-sgd-fixedstep" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowrank method.kwargs.rank=256 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedstep optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.5"
26 | ./run_gpt.sh "gpt2s-bld-0x33-sgd-fixedstep" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=256 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedstep optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.5"
27 | ./run_gpt.sh "gpt2m-bld-0x33-sgd-fixedstep" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=384 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedstep optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.5"
28 | ./run_gpt.sh "gpt2s-bls-0x33-sgd-fixedstep" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedstep optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.5"
29 | ./run_gpt.sh "gpt2m-bls-0x33-sgd-fixedstep" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedstep optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.5"
30 |
31 | # to match the flops
32 | # LowRank
33 | ./run_gpt.sh "gpt2s-lr-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=192 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${s_lr_ratio}"
34 | ./run_gpt.sh "gpt2m-lr-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowrank method.kwargs.rank=256 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${m_lr_ratio}"
35 | ./run_gpt.sh "gpt2l-lr-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=lowrank method.kwargs.rank=384 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${l_lr_ratio}"
36 | ./run_gpt.sh "gpt2xl-lr-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=lowrank method.kwargs.rank=512 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${xl_lr_ratio}"
37 |
38 | # BlockDense
39 | ./run_gpt.sh "gpt2s-bld-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=256 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${s_bld_ratio}"
40 | ./run_gpt.sh "gpt2m-bld-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=384 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${m_bld_ratio}"
41 | ./run_gpt.sh "gpt2l-bld-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=blockdense method.kwargs.nblocks=2 method.kwargs.rank=512 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${l_bld_ratio}"
42 | ./run_gpt.sh "gpt2xl-bld-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=blockdense method.kwargs.nblocks=4 method.kwargs.rank=768 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${xl_bld_ratio}"
43 |
44 | # BlockShuffle
45 | ./run_gpt.sh "gpt2s-bls-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${s_lr} optimization.max_tokens=${s_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${s_lr_ratio}"
46 | ./run_gpt.sh "gpt2m-bls-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${m_lr} optimization.max_tokens=${m_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${m_lr_ratio}"
47 | ./run_gpt.sh "gpt2l-bls-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${l_lr} optimization.max_tokens=${l_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${l_lr_ratio}"
48 | ./run_gpt.sh "gpt2xl-bls-0x33-sgd" 1 "torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2xl method=blockshuffle method.kwargs.nblocks=4 optimization.optimizer.kwargs.lr=${xl_lr} optimization.max_tokens=${xl_token} data.train.train_batch=${train_batch} data.test.test_batch=${test_batch} optimization.log_interval=100 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=${xl_lr_ratio}"
49 |
50 |
51 |
--------------------------------------------------------------------------------
/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLAIRE-Labo/StructuredFFN/2442a14c57be177b50bd900f1fba6fcec4c96d93/image.png
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Building on Efficient Foundations: Effectively Training LLMs with Structured Feedforward Layers
2 |
3 | ## Introduction
4 | This repository contains the official implementation of our paper
5 |
6 | **Building on Efficient Foundations: Effectively Training LLMs with Structured Feedforward Layers**
7 |
8 | Xiuying Wei, Skander Moalla, Razvan Pascanu, Caglar Gulcehre
9 |
10 | > In this work, we investigate structured matrices for the FFN blocks from the training-from-scratch perspective, first identifying their efficiency and optimization challenges and then presenting experimental results. We consider three efficient linear parametrizations: LowRank, BlockShuffle (comprising two block-diagonal matrices), and BlockDense (a combination of dense and block-diagonal matrices). We propose a pre-merge technique to remove their efficiency bottleneck at the online decoding stage, and a training strategy called self-guided training to improve their training dynamics. Our experiments show steeper scaling curves for these structured FFNs compared to dense ones, the improvement brought by self-guided training, and the performance of wide and structured networks when combined with GQA for the attention block.
11 |
12 | 
13 | ## File Organization
14 | ```
15 | Structured/src/
16 | ├── benchmark_acc/ [training and evaluation entries for different datasets]
17 | │ └── refinedweb_experiment.py [refinedweb entry]
18 | ├── benchmark_eff [efficiency entries]
19 | │   ├── bench_kernel.py [kernel efficiency]
20 | │   ├── benchmark_mlp_train.py [MLP efficiency]
21 | │   ├── benchmark_model_infer.py [decoding efficiency]
22 | │   └── benchmark_model_train.py [prefill/context efficiency]
23 | ├── configs [hydra config]
24 | │   ├── data [not used; RefinedWeb is preprocessed in advance]
25 | │   ├── method [the different efficient linear layers]
26 | │ ├── model [gpt and llama]
27 | │ ├── optimization [optimization including scheduler, optimizer, self-guided training etc.]
28 | │ └── refinedweb_config.yaml
29 | ├── modules
30 | │ ├── __init__.py
31 | │   ├── op [fast ops; common ones invoke others or are adapted from Megatron]
32 | │   ├── layer [efficient linear layers that invoke functions in the op dir]
33 | │   ├── mlp [efficient MLPs that invoke functions in the layer dir]
34 | │   └── model [supports LayerNorm or RMSNorm, bias or not, tied word embeddings or not, rotary or absolute position embeddings, GELU or SwiGLU]
35 | ├── optimization
36 | │ ├── __init__.py
37 | │ ├── scheduler.py [cosine with warmup]
38 | │   └── trainer.py [basic training utilities including seeding, checkpointing, and model info]
39 | └── utils
40 | └── refinedweb_llama.py [preprocess file]
41 | ```
42 |
43 | ## Env
44 | We use a Docker container for the environment and A100 80GB GPUs for the experiments. The Dockerfile is provided in the experiments folder; the base image is from NVIDIA (nvcr.io/nvidia/pytorch:24.02-py3) and comes with Transformer Engine, FlashAttention, and Apex pre-installed. The required Python packages (transformers, wandb, datasets, etc.) are listed in the Dockerfile.
45 |
46 | ## Data preprocess
47 | ```
48 | python refinedweb_llama.py --tokenizer llama --block_size 1025 --num_proc=32
49 | ```
50 |
51 | RefinedWeb is quite large, so we shuffle, extract, and tokenize it into token IDs in advance. The token IDs are stored in an np.memmap file so that the data never has to be loaded into CPU memory all at once. The above command randomly splits out about 0.7B validation tokens and 65B training tokens for later use.
52 |
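For reference, reading a training batch back from one of these memmap files follows the pattern below (a simplified sketch of `get_batch` in `refinedweb_experiment.py`; the file name and batch values here are illustrative):

```
import numpy as np
import torch

block_size = 1024          # tokens per sample; each stored row holds block_size + 1 ids
batch, offset_row = 16, 0  # illustrative values

arr = np.memmap(
    "llama-train-tmp.bin",                     # produced by refinedweb_llama.py
    dtype=np.uint16,                           # token ids are stored in 2 bytes
    mode="r",
    offset=offset_row * (block_size + 1) * 2,  # byte offset of the first requested row
    shape=(batch, block_size + 1),
)
x = torch.from_numpy(arr[:, :-1].astype(np.int64))  # inputs
y = torch.from_numpy(arr[:, 1:].astype(np.int64))   # next-token labels
```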
53 | ## Experiments
54 | ### Structured linear parametrization (Table 1 and Table 9)
55 | We provide several examples below; the full set of commands is in basic.sh. A minimal sketch of the three parametrizations follows the example commands.
56 | ```
57 | # gpt2 and linear
58 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=linear
59 |
60 | # LowRank
61 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=384 optimization.max_tokens=2200000000 optimization.optimizer.kwargs.lr=6.0e-4 data.train.train_batch=64 data.test.test_batch=64
62 |
63 | # BlockShuffle
64 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockshuffle method.kwargs.nblocks=2 optimization.max_tokens=2200000000 optimization.optimizer.kwargs.lr=6.0e-4 data.train.train_batch=64 data.test.test_batch=64
65 |
66 | # BlockDense
67 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=blockdense method.kwargs.rank=512 method.kwargs.nblocks=2 optimization.max_tokens=2200000000 optimization.optimizer.kwargs.lr=6.0e-4 data.train.train_batch=64 data.test.test_batch=64
68 | ```
69 |
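For intuition, the three parametrizations can be written as plain PyTorch matmuls, as in the sketch below (shapes and the exact shuffle layout are illustrative; the repository uses the fused ops in `src/modules/op`):

```
import torch

def low_rank(x, u, v):
    # W (d_out x d_in) is factorized as u @ v with u: (d_out, r), v: (r, d_in)
    return (x @ v.t()) @ u.t()

def block_dense(x, blk, dense):
    # blk: (nblocks, r // nblocks, d_in // nblocks) block-diagonal factor
    # dense: (d_out, r) dense factor
    b = x.shape[0]
    xb = x.reshape(b, blk.shape[0], -1)        # split features into blocks
    h = torch.einsum("bnd,nrd->bnr", xb, blk)  # per-block projection
    return h.reshape(b, -1) @ dense.t()

def block_shuffle(x, w1, w2):
    # two block-diagonal matmuls with a feature shuffle in between
    # w1: (nblocks, k, in_blk), w2: (k, out_blk, nblocks)
    b = x.shape[0]
    xb = x.reshape(b, w1.shape[0], -1)
    h = torch.einsum("bni,nki->bnk", xb, w1)   # first block-diagonal matmul
    h = h.transpose(1, 2)                      # the shuffle: regroup features across blocks
    y = torch.einsum("bkn,kon->bko", h, w2)    # second block-diagonal matmul
    return y.reshape(b, -1)
```

LowRank with rank r reduces a d_out × d_in matrix to r(d_in + d_out) parameters, while the block variants reduce parameters and FLOPs roughly by the number of blocks in their block-diagonal factors.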
70 | ### Self-Guided Training (Table 3, 4, and 10)
71 | There are two modes:
72 |
73 | * an ablation study that applies the method to the first half of training and incurs roughly 25% extra FFN FLOPs (the dense guide branch runs with a probability that cosine-decays from 1 to 0 over that first half)
74 |
75 | * experiments under the same total training FLOPs to measure the improvement directly. We use self-guided training at the beginning and repeat those tokens at the end so that the structured matrices also learn from this data thoroughly; the amount of self-guided training is adjusted so that the total training FLOPs match. A minimal sketch of the guided forward pass is given after the example commands below.
76 |
77 | We provide examples here; all reproducible commands are in experiments/sgd.sh
78 | ```
79 | # Ablation
80 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=192 optimization.max_tokens=2200000000 optimization.optimizer.kwargs.lr=6.0e-4 data.train.train_batch=16 data.test.test_batch=16 optimization/training=self_guided_training optimization.training.kwargs.reduce_flop=true
81 |
82 | # to match the flops
83 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowrank method.kwargs.rank=192 optimization.max_tokens=2200000000 optimization.optimizer.kwargs.lr=6.0e-4 data.train.train_batch=32 data.test.test_batch=32 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.3
84 | ```
85 |
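For intuition, the guided forward pass blends a dense "guide" branch with the structured branch; the mixing ratio cosine-decays from 1 to 0, and with `reduce_flop=true` the guide branch is executed stochastically with probability equal to that ratio. Below is a minimal sketch of the idea (the actual implementation lives in `src/modules/layer/basiclinear.py`; class and parameter names here are illustrative):

```
import math
import torch
import torch.nn as nn

class SelfGuidedLowRank(nn.Module):
    def __init__(self, d_in, d_out, rank, max_step, reduce_flop=True):
        super().__init__()
        self.u = nn.Parameter(torch.randn(d_out, rank) * 0.02)
        self.v = nn.Parameter(torch.randn(rank, d_in) * 0.02)
        self.guide = nn.Parameter(torch.randn(d_out, d_in) * 0.02)  # dense guide weight
        self.max_step, self.reduce_flop = max_step, reduce_flop
        self.register_buffer("step", torch.tensor(0))

    def ratio(self):
        # cosine decay from 1 (start) to 0 (after max_step guided steps)
        t = min(self.step.item() / self.max_step, 1.0)
        return 0.5 * (1.0 + math.cos(math.pi * t))

    def advance(self):
        # called once per optimizer step by the trainer
        self.step += 1

    def forward(self, x):
        out = (x @ self.v.t()) @ self.u.t()  # structured branch
        r = self.ratio()
        run_guide = (not self.reduce_flop) or torch.rand(()).item() < r
        if r > 0.0 and run_guide:
            out = r * (x @ self.guide.t()) + (1.0 - r) * out
        return out
```

In the repository, the ratio is advanced once per optimizer step via `update_ratio` (`src/modules/__init__.py`); after the guided steps, self-guided training is switched off (`close_self_guided_training` in `refinedweb_experiment.py`) and only the structured parameters keep training.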
86 | ### Wide and Structured network (Table 2)
87 | Motivated by the scaling curves, we build wide models that are structured, using LowRank for the FFN and GQA for the attention block.
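For a rough sense of the savings: with hidden_dim 1024, ffn_dim 4864, and rank 512 (the Transformer-m setting below), a dense FFN projection has about 1024 × 4864 ≈ 5.0M parameters, while its LowRank factorization has about 512 × (1024 + 4864) ≈ 3.0M.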
88 |
89 | Transformer-m
90 | ```
91 | # GQA
92 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=linear model.kwargs.num_kv_heads=4 model.kwargs.ffn_dim=4864 data.train.train_batch=32 data.test.test_batch=32 optimization.max_tokens=6700000000 optimization.optimizer.kwargs.lr=3.0e-4
93 |
94 | # Ours
95 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowrank method.kwargs.rank=512 model.kwargs.hidden_dim=1024 model.kwargs.ffn_dim=4864 model.kwargs.attn_dim=512 model.kwargs.num_q_heads=8 model.kwargs.num_kv_heads=4 data.train.train_batch=32 data.test.test_batch=32 optimization.optimizer.kwargs.lr=3.0e-4 optimization.max_tokens=10580000000
96 |
97 | # Ours (self-guided training)
98 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowrank method.kwargs.rank=512 model.kwargs.hidden_dim=1024 model.kwargs.ffn_dim=4864 model.kwargs.attn_dim=512 model.kwargs.num_q_heads=8 model.kwargs.num_kv_heads=4 optimization.optimizer.kwargs.lr=3.0e-4 optimization.max_tokens=6700000000 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.40
99 | ```
100 |
101 | Transformer-l
102 | ```
103 | # GQA
104 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=linear model.kwargs.num_kv_heads=2 model.kwargs.ffn_dim=7424 data.train.train_batch=8 data.test.test_batch=8 optimization.max_tokens=14580000000 optimization.optimizer.kwargs.lr=0.00025
105 |
106 | # Ours
107 | # we keep the KV channels at 256, aligned with what we used in GQA.
108 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=lowrank method.kwargs.rank=768 model.kwargs.hidden_dim=1536 model.kwargs.ffn_dim=7424 model.kwargs.attn_dim=768 model.kwargs.num_q_heads=12 model.kwargs.num_kv_heads=4 data.train.train_batch=16 data.test.test_batch=16 optimization.optimizer.kwargs.lr=2.5e-4 optimization.max_tokens=23360000000
109 |
110 | # Ours (self-guided training)
111 | torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2l method=lowrank method.kwargs.rank=768 model.kwargs.hidden_dim=1536 model.kwargs.ffn_dim=7424 model.kwargs.attn_dim=768 model.kwargs.num_q_heads=12 model.kwargs.num_kv_heads=4 optimization.optimizer.kwargs.lr=2.5e-4 optimization.max_tokens=14580000000 optimization/training=self_guided_training optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.reduce_flop=true optimization.training.kwargs.max_step_ratio=0.395
112 | ```
113 |
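GQA itself is enabled purely through the model config (num_kv_heads smaller than num_q_heads). For intuition, a minimal sketch of the key/value sharing in grouped-query attention (illustrative only; not the repository's fused attention kernel):

```
import torch

def grouped_query_attention(q, k, v):
    # q: (batch, n_q_heads, seq, head_dim); k, v: (batch, n_kv_heads, seq, head_dim)
    group = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(group, dim=1)  # each KV head serves `group` query heads
    v = v.repeat_interleave(group, dim=1)
    attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return attn @ v
```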
114 | ## Citation
115 | If you find this repo useful for your research, please consider citing our papers:
116 | ```
117 | @article{wei2024building,
118 | title={Building on Efficient Foundations: Effectively Training LLMs with Structured Feedforward Layers},
119 | author={Wei, Xiuying and Moalla, Skander and Pascanu, Razvan and Gulcehre, Caglar},
120 | journal={arXiv preprint arXiv:2406.16450},
121 | year={2024}
122 | }
123 |
124 | @article{wei2024investigating,
125 | title={Investigating Low-Rank Training in Transformer Language Models: Efficiency and Scaling Analysis},
126 | author={Wei, Xiuying and Moalla, Skander and Pascanu, Razvan and Gulcehre, Caglar},
127 | journal={arXiv preprint arXiv:2407.09835},
128 | year={2024}
129 | }
130 |
131 | ```
132 |
--------------------------------------------------------------------------------
/src/benchmark_acc/refinedweb_experiment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import numpy as np
5 | import time
6 | from tqdm import tqdm
7 |
8 |
9 | project_root = os.path.dirname(
10 | os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11 | )
12 |
13 | print(project_root)
14 | sys.path.insert(0, project_root)
15 | import wandb
16 | import hydra
17 | import torch
18 | import torch.distributed as dist
19 | from omegaconf import DictConfig, OmegaConf
20 |
21 | from src.optimization import (
22 | get_optimizer,
23 | get_lr_scheduler,
24 | trainer,
25 | )
26 |
27 | from src.modules import get_model, update_ratio
28 |
29 |
30 | class RefinedWebGPT(trainer.TrainableModel):
31 |     # training progress is tracked in steps rather than epochs, since gradient accumulation must be accounted for
32 | def __init__(self, config):
33 | super().__init__(config)
34 | # get dataset
35 | self.set_seed(self.config.optimization.seed)
36 | # get data files
37 | self.train_file_path = os.path.join(
38 | self.config.data.train.path,
39 | self.config.data.tokenizer.name + "-train-tmp.bin",
40 | )
41 | self.val_file_path = os.path.join(
42 | self.config.data.test.path,
43 | self.config.data.tokenizer.name + "-val-tmp.bin",
44 | )
45 | self.block_size = min(
46 | self.config.data.block_size, self.config.data.tokenizer.model_max_length
47 | )
48 | validate_tokens = 512000 * 1024
49 | self.validate_samples = validate_tokens // self.block_size
50 | assert (
51 | self.validate_samples % (self.ngpus * self.config.data.test.test_batch) == 0
52 | )
53 | assert self.gpu_id != -1, "we only support torchrun in job submission"
54 |
55 | # get metric
56 | self.max_step = int(
57 | self.config.optimization.max_tokens
58 | / self.global_batch_size
59 | / self.block_size
60 | )
61 | self.set_self_guided_training()
62 | self.config.optimization.lr_scheduler.kwargs.T_max = self.max_step
63 | if self.gpu_id in [-1, 0]:
64 | self.metric = {
65 | "train_loss": 0.0,
66 | "train_ppl": 0.0,
67 | "test_loss": 0.0,
68 | "test_ppl": 0.0,
69 | "step": 0,
70 | "lr": 0.0,
71 | "fwd+bwd": 0.0,
72 | }
73 | # get model
74 | self.set_seed(self.config.optimization.seed)
75 | self.model = get_model(self.config, self.device)
76 | self.get_info()
77 |
78 | # get optimizer
79 | self.optimizer = get_optimizer(
80 | self.config.optimization, self.get_optimize_param()
81 | )
82 | if getattr(self.config.optimization, "lr_scheduler", None):
83 | self.lr_scheduler = get_lr_scheduler(
84 | self.config.optimization, self.optimizer
85 | )
86 |
87 | # get wandb
88 | if self.gpu_id in [-1, 0] and self.config.wandb_use:
89 | self.wandblog = trainer.WandbLog(
90 | self.config.wandb, self.metric, x_axis="step"
91 | )
92 |
93 | assert self.load_save_mode == "step"
94 | self.prepare_load_save()
95 | self.resume_kwargs = self.load_checkpoint()
96 | if self.gpu_id != -1:
97 | self.model = torch.nn.parallel.DistributedDataParallel(
98 | self.model,
99 | device_ids=[self.gpu_id],
100 | output_device=self.gpu_id,
101 | find_unused_parameters=self.special_training,
102 | )
103 | if self.gpu_id in [-1, 0]:
104 | print(self.config)
105 |
106 | def get_batch(self, split, offset_row):
107 | if split == "train":
108 | arr = np.memmap(
109 | self.train_file_path,
110 | dtype=np.uint16, # we store in 2 bytes
111 | mode="r",
112 | offset=offset_row * (self.block_size + 1) * 2,
113 | shape=(self.config.data.train.train_batch, (self.block_size + 1)),
114 | )
115 | elif split == "val":
116 | arr = np.memmap(
117 | self.val_file_path,
118 | dtype=np.uint16, # we store in 2 bytes
119 | mode="r",
120 | offset=offset_row * (self.block_size + 1) * 2,
121 | shape=(self.config.data.test.test_batch, (self.block_size + 1)),
122 | )
123 | else:
124 | raise NotImplementedError
125 |
126 | x = torch.from_numpy(arr[:, :-1].astype(np.int64))
127 | y = torch.from_numpy(arr[:, 1:].astype(np.int64))
128 | x, y = x.pin_memory().to("cuda", non_blocking=True), y.pin_memory().to(
129 | "cuda", non_blocking=True
130 | )
131 | return x, y
132 |
133 | def _validate(self):
134 | self.model.eval()
135 | ddp_loss = torch.tensor(0.0).to(self.device)
136 | ddp_samples = torch.tensor(0).to(self.device)
137 | samples_per_gpu = self.validate_samples // self.ngpus
138 | with torch.no_grad():
139 | offset_row = self.gpu_id * samples_per_gpu
140 | for i in range(samples_per_gpu // self.config.data.test.test_batch):
141 | input_ids, labels = self.get_batch(
142 | split="val", offset_row=offset_row + ddp_samples.item()
143 | )
144 | with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
145 | loss = self.model(
146 | input_ids=input_ids,
147 | labels=labels,
148 | )
149 | if i % 100 == 0 and self.gpu_id in [-1, 0]:
150 | print("the loss at batch {} is {}".format(i, loss))
151 | ddp_loss += loss.item() * input_ids.shape[0]
152 | ddp_samples += input_ids.shape[0]
153 | print("The samples on rank {} is {}".format(self.gpu_id, ddp_samples))
154 | dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
155 | dist.all_reduce(ddp_samples, op=dist.ReduceOp.SUM)
156 | var_loss = (ddp_loss / ddp_samples).item()
157 | var_ppl = math.exp(var_loss)
158 | return var_loss, var_ppl
159 |
160 | def _train(self, resume_batch, max_step, offset_row=-1):
161 | if resume_batch >= max_step:
162 | return
163 | train_iterator = tqdm(
164 | range(resume_batch, max_step),
165 | desc="Steps",
166 | disable=self.gpu_id not in [-1, 0],
167 | )
168 | samples_per_gpu = self.global_batch_size // self.ngpus
169 | self.model.train()
170 | self.optimizer.zero_grad()
171 | train_loss = 0.0
172 | train_samples = 0
173 | if offset_row == -1:
174 | offset_row = resume_batch * self.global_batch_size
175 | offset_row += self.gpu_id * samples_per_gpu
176 | for i in train_iterator:
177 | torch.cuda.synchronize()
178 | t0 = time.time()
179 | train_loss = 0.0
180 | train_samples = 0
181 | for micro_step in range(self.gradient_accumulation_steps):
182 | input_ids, labels = self.get_batch(
183 | split="train", offset_row=offset_row + train_samples
184 | )
185 | with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
186 | loss = self.model(
187 | input_ids=input_ids,
188 | labels=labels,
189 | )
190 | train_samples += self.config.data.train.train_batch
191 | train_loss += loss.item() * self.config.data.train.train_batch
192 | loss = loss / self.gradient_accumulation_steps
193 | loss.backward()
194 | # finish the step
195 | if self.special_training:
196 | self.model.apply(lambda module: update_ratio(module=module))
197 | self.set_gradient_clipping()
198 | self.optimizer.step()
199 | self.lr_scheduler.step()
200 | self.optimizer.zero_grad()
201 | torch.cuda.synchronize()
202 | t2 = time.time()
203 | self.step += 1
204 | offset_row += self.global_batch_size
205 | if self.gpu_id in [-1, 0] and (self.step + 1) % self.log_interval == 0:
206 | # test_loss, test_ppl = self._test()
207 | # self.model.train()
208 | self.metric.update(
209 | {
210 | "train_loss": train_loss / train_samples,
211 | "train_ppl": math.exp(train_loss / train_samples),
212 | "step": self.step,
213 | "lr": self.optimizer.param_groups[0]["lr"],
214 | "fwd+bwd": (t2 - t0),
215 | }
216 | )
217 | if self.config.wandb_use:
218 | self.wandblog.record(self.metric)
219 | else:
220 | print(self.metric)
221 |
222 | self.save_checkpoint(**{"resume_batch": i + 1})
223 |
224 | def train(self):
225 | self.set_seed(self.config.optimization.seed)
226 | print("***** Running training *****")
227 | num_examples = self.max_step * self.global_batch_size
228 | print("Num Examples = {}".format(num_examples))
229 | # Note that epoch would always be zero here
230 | print("Num Tokens = {}".format(num_examples * self.block_size))
231 | print("Num Steps = {}".format(self.max_step))
232 | print("Global batch size = {}".format(self.global_batch_size))
233 | print(
234 | "Gradient Accumulation steps = {}".format(self.gradient_accumulation_steps)
235 | )
236 | resume_batch = self.resume_kwargs.get("resume_batch", 0) # next one
237 | print("resume from batch {}".format(resume_batch))
238 | # train guided steps
239 | self._train(resume_batch, self.guided_steps, offset_row=-1)
240 | self.close_self_guided_training()
241 | self._train(
242 | max(self.guided_steps, resume_batch),
243 | self.max_step - self.repeat_steps,
244 | offset_row=-1,
245 | )
246 | self._train(
247 | max(self.max_step - self.repeat_steps, resume_batch),
248 | self.max_step,
249 | offset_row=max(0, resume_batch + self.repeat_steps - self.max_step),
250 | )
251 |
252 |
253 | @hydra.main(
254 | version_base=None,
255 | config_path="../configs",
256 | config_name="refinedweb_config",
257 | )
258 | def main(config):
259 | OmegaConf.register_new_resolver("eval", eval)
260 | config.base_dir = os.path.join(
261 | config.base_dir, config.data.name + "_" + config.model.name
262 | )
263 | config.wandb.dir = config.base_dir
264 | config.wandb.dir = os.path.join(config.base_dir, config.method.name)
265 | gpu_id = int(os.getenv("RANK", -1))
266 | if gpu_id in [-1, 0] and not os.path.exists(config.wandb.dir):
267 | os.makedirs(config.wandb.dir)
268 |
269 | if gpu_id in [-1, 0] and config.wandb_use:
270 | wandb.init(
271 | config=OmegaConf.to_container(config, resolve=True, throw_on_missing=True),
272 | entity=config.wandb.entity,
273 | project=config.wandb.project,
274 | resume=None if config.optimization.load_checkpoint else "allow",
275 | anonymous=config.wandb.anonymous,
276 | mode=config.wandb.mode,
277 | dir=config.wandb.dir,
278 | )
279 | if gpu_id != -1:
280 | dist.barrier()
281 | model = RefinedWebGPT(config)
282 | model.train()
283 |
284 | if gpu_id != -1:
285 | dist.barrier()
286 | print("Finish Training!")
287 | print("Begin to validate!")
288 | var_loss, var_ppl = model._validate()
289 | print("The var loss is {:.4f} and var ppl is {:.4f}".format(var_loss, var_ppl))
290 | if gpu_id in [-1, 0]:
291 | if config.wandb_use:
292 | wandb.finish()
293 | return var_loss, var_ppl
294 |
295 |
296 | if __name__ == "__main__":
297 | gpu_id = int(os.getenv("RANK", -1))
298 | world_size = int(os.getenv("WORLD_SIZE", 1))
299 | if gpu_id != -1:
300 | torch.cuda.set_device(gpu_id)
301 | dist.init_process_group(
302 | backend="nccl", world_size=world_size, rank=gpu_id, init_method="env://"
303 | )
304 |
305 | main()
306 |
--------------------------------------------------------------------------------
/src/benchmark_eff/bench_kernel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import triton
4 | import torch
5 |
6 | project_root = os.path.dirname(
7 | os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
8 | )
9 |
10 | print(project_root)
11 | sys.path.insert(0, project_root)
12 |
13 | from src.modules.op import (
14 | block_shuffle_bmm,
15 | block_shuffle_einsum,
16 | block_shuffle_custom,
17 | block_dense_bmm,
18 | block_dense_custom,
19 | low_rank_custom,
20 | )
21 |
22 |
23 | @triton.testing.perf_report(
24 | triton.testing.Benchmark(
25 | x_names=["bs", "blocks", "in_blksz", "out_blksz"],
26 | x_vals=[
27 | (16 * 1024, 4, 512, 512),
28 | (16 * 512, 16, 512, 512 * 4),
29 | (32 * 1024, 4, 1024, 1024),
30 | (32 * 1024, 2, 4096, 4096 * 4),
31 | (64 * 1024, 16, 256, 256 * 4),
32 | ],
33 | line_arg="provider",
34 | line_vals=["einsum", "bmm", "custom"],
35 | line_names=["einsum", "bmm", "custom"],
36 | styles=[("blue", "-"), ("green", "-"), ("green", "--")],
37 | ylabel="latency (ms)",
38 | plot_name="blockshuffle-performance",
39 | args={"torch_dtype": torch.float16},
40 | )
41 | )
42 | def benchmark_blockshuffle(bs, blocks, in_blksz, out_blksz, torch_dtype, provider):
43 | input = torch.randn(bs, blocks * in_blksz, device="cuda", dtype=torch_dtype) * 0.02
44 | if in_blksz < out_blksz:
45 | w1 = (
46 | torch.randn(blocks, in_blksz, in_blksz, device="cuda", dtype=torch_dtype)
47 | * 0.02
48 | )
49 | w2 = (
50 | torch.randn(blocks, out_blksz, in_blksz, device="cuda", dtype=torch_dtype)
51 | * 0.02
52 | )
53 | else:
54 | w1 = (
55 | torch.randn(blocks, out_blksz, in_blksz, device="cuda", dtype=torch_dtype)
56 | * 0.02
57 | )
58 | w2 = (
59 | torch.randn(blocks, out_blksz, out_blksz, device="cuda", dtype=torch_dtype)
60 | * 0.02
61 | )
62 | quantiles = [0.5, 0.2, 0.8]
63 | if provider == "einsum":
64 | ms, min_ms, max_ms = triton.testing.do_bench(
65 | lambda: block_shuffle_einsum(input, w1, w2), quantiles=quantiles
66 | )
67 | if provider == "bmm":
68 | ms, min_ms, max_ms = triton.testing.do_bench(
69 | lambda: block_shuffle_bmm(input, w1, w2), quantiles=quantiles
70 | )
71 | if provider == "custom":
72 | ms, min_ms, max_ms = triton.testing.do_bench(
73 | lambda: block_shuffle_custom(input, w1, w2), quantiles=quantiles
74 | )
75 | return ms, max_ms, min_ms
76 |
77 |
78 | @triton.testing.perf_report(
79 | triton.testing.Benchmark(
80 | x_names=["bs", "blocks", "in_blksz", "r_blksz", "out"],
81 | x_vals=[
82 | (16 * 1024, 4, 512, 384, 512),
83 | (16 * 512, 16, 512, 384, 512 * 4),
84 | (32 * 1024, 4, 1024, 512, 1024),
85 | (32 * 1024, 2, 1024, 512, 4096 * 4),
86 | (64 * 1024, 16, 256, 128, 256 * 4),
87 | ],
88 | line_arg="provider",
89 | line_vals=["bmm", "custom"],
90 | line_names=["bmm", "custom"],
91 | styles=[("green", "-"), ("green", "--")],
92 | ylabel="latency (ms)",
93 | plot_name="block-linear-performance",
94 | args={"torch_dtype": torch.float16},
95 | )
96 | )
97 | def benchmark_blockdense(bs, blocks, in_blksz, r_blksz, out, torch_dtype, provider):
98 | input = torch.randn(bs, in_blksz * blocks, device="cuda", dtype=torch_dtype) * 0.02
99 | w1 = (
100 | torch.randn(
101 | blocks,
102 | r_blksz,
103 | in_blksz,
104 | device="cuda",
105 | dtype=torch_dtype,
106 | )
107 | * 0.02
108 | )
109 | w2 = torch.randn(out, r_blksz * blocks, device="cuda", dtype=torch_dtype) * 0.02
110 |
111 | quantiles = [0.5, 0.2, 0.8]
112 | if provider == "bmm":
113 | ms, min_ms, max_ms = triton.testing.do_bench(
114 | lambda: block_dense_bmm(input, w1, w2), quantiles=quantiles
115 | )
116 | if provider == "custom":
117 | ms, min_ms, max_ms = triton.testing.do_bench(
118 | lambda: block_dense_custom(input, w1, w2), quantiles=quantiles
119 | )
120 | return ms, max_ms, min_ms
121 |
122 |
123 | @triton.testing.perf_report(
124 | triton.testing.Benchmark(
125 | x_names=["bs", "seq", "d_in", "d_r", "d_out"],
126 | x_vals=[
127 | (16, 1024, 4 * 512, 384 * 4, 512 * 4),
128 | (16, 512, 16 * 512, 384 * 16, 512 * 16),
129 | (32, 1024, 4 * 1024, 512 * 4, 1024 * 4),
130 | (32, 1024, 2 * 1024, 512 * 2, 4096 * 2),
131 | (64, 1024, 16 * 256, 128 * 16, 256 * 16),
132 | ],
133 | line_arg="provider",
134 | line_vals=["custom"],
135 | line_names=["custom"],
136 | styles=[("green", "-"), ("green", "--")],
137 | ylabel="latency (ms)",
138 | plot_name="block-linear-performance",
139 | args={"torch_dtype": torch.float16},
140 | )
141 | )
142 | def benchmark_lowrank(bs, seq, d_in, d_r, d_out, torch_dtype, provider):
143 | input = torch.randn(bs, seq, d_in, device="cuda", dtype=torch_dtype) * 0.02
144 | w1 = torch.randn(d_r, d_in, device="cuda", dtype=torch_dtype) * 0.02
145 | w2 = torch.randn(d_out, d_r, device="cuda", dtype=torch_dtype) * 0.02
146 |
147 | quantiles = [0.5, 0.2, 0.8]
148 | if provider == "custom":
149 | ms, min_ms, max_ms = triton.testing.do_bench(
150 | lambda: low_rank_custom(input, w1, w2), quantiles=quantiles
151 | )
152 | return ms, max_ms, min_ms
153 |
154 |
155 | benchmark_blockshuffle.run(show_plots=True, print_data=True)
156 | benchmark_blockdense.run(show_plots=True, print_data=True)
157 | benchmark_lowrank.run(show_plots=True, print_data=True)
158 |
--------------------------------------------------------------------------------
/src/benchmark_eff/benchmark_mlp_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch
5 | import random
6 | import os
7 | import hydra
8 | import triton
9 | import time
10 | import sys
11 | from hide_warnings import hide_warnings
12 |
13 | project_root = os.path.dirname(
14 | os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
15 | )
16 |
17 | print(project_root)
18 | sys.path.insert(0, project_root)
19 |
20 | from src.modules.mlp import (
21 | FusedBlockDenseMLP,
22 | FusedLowRankMLP,
23 | FusedBlockShuffleMLP,
24 | FusedMLP,
25 | )
26 |
27 | # pure bfloat16 efficiency
28 |
29 | name_to_method = {
30 | "lowrank": FusedLowRankMLP,
31 | "blockdense": FusedBlockDenseMLP,
32 | "blockshuffle": FusedBlockShuffleMLP,
33 | "linear": FusedMLP,
34 | }
35 | from omegaconf import DictConfig, OmegaConf
36 |
37 | torch_dtype = torch.bfloat16
38 |
39 |
40 | def set_seed(seed):
41 | random.seed(seed)
42 | os.environ["PYTHONHASHSEED"] = str(seed)
43 | np.random.seed(seed)
44 | torch.manual_seed(seed)
45 | torch.cuda.manual_seed(seed)
46 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
47 | torch.backends.cudnn.benchmark = False
48 | torch.backends.cudnn.deterministic = True
49 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
50 |
51 |
52 | def benchmark_train(net, inp):
53 | def fn(input):
54 | net(input)
55 |
56 | quantiles = [0.5, 0.2, 0.8]
57 |
58 | t, min_ms, max_ms = triton.testing.do_bench(
59 | lambda: fn(inp), quantiles=quantiles, warmup=50, rep=200
60 | )
61 | latency = t
62 | throughput = inp.shape[0] * inp.shape[1] / latency * 10**3
63 | print("Latency (ms): {}, Throughput (token/s): {}".format(latency, throughput))
64 | return latency, throughput
65 |
66 |
67 | @hide_warnings(out=False)
68 | @hydra.main(
69 | version_base=None,
70 | config_path="../configs",
71 | config_name="refinedweb_config",
72 | )
73 | def main(config):
74 | OmegaConf.register_new_resolver("eval", eval)
75 | config_model = config.model
76 | config_method = config.method
77 | f = open("../../exp/logs/fig3.log", "a+")
78 |
79 | if config_method.name == "linear":
80 | model = (
81 | FusedMLP(
82 | config_model.kwargs.hidden_dim,
83 | config_model.kwargs.ffn_dim,
84 | config_model.kwargs.bias,
85 | config_model.kwargs.act,
86 | )
87 | .cuda()
88 | .to(torch_dtype)
89 | )
90 | else:
91 | model = (
92 | name_to_method[config.method.name.lower()](
93 | config_model.kwargs.hidden_dim,
94 | config_model.kwargs.ffn_dim,
95 | config_model.kwargs.bias,
96 | config_model.kwargs.act,
97 | config_method.kwargs,
98 | config_model.kwargs.init,
99 | device="cuda",
100 | )
101 | .cuda()
102 | .to(torch_dtype)
103 | )
104 | model.eval()
105 | with torch.no_grad():
106 | input = (
107 | torch.randn(
108 | config.data.test.test_batch, 1024, config_model.kwargs.hidden_dim
109 | )
110 | .cuda()
111 | .to(torch_dtype)
112 | )
113 | latency, throughput = benchmark_train(model, input)
114 | if config_method.name == "linear":
115 | print(
116 | f"{config_model.kwargs.hidden_dim}, {config_model.kwargs.ffn_dim}", file=f
117 | )
118 | else:
119 | print(
120 | f"{config_model.kwargs.hidden_dim}, {config_model.kwargs.ffn_dim}, {model.get_ckpt_name(config_method.kwargs)}",
121 | file=f,
122 | )
123 | print(
124 | f"latency: {latency}, throughput: {throughput}, bs: {config.data.test.test_batch}, params: {sum(p.numel() for p in model.parameters())}",
125 | file=f,
126 | )
127 |
128 |
129 | if __name__ == "__main__":
130 | set_seed(1005)
131 | main()
132 | print("******END*******")
133 |
--------------------------------------------------------------------------------
/src/benchmark_eff/benchmark_model_infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch
5 | import random
6 | import os
7 | import hydra
8 | import time
9 | import sys
10 |
11 | project_root = os.path.dirname(
12 | os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13 | )
14 |
15 | print(project_root)
16 | sys.path.insert(0, project_root)
17 | from hide_warnings import hide_warnings
18 | from src.modules import get_model
19 | from omegaconf import DictConfig, OmegaConf
20 |
21 | torch_dtype = torch.bfloat16
22 | prefill = 0
23 |
24 |
25 | def set_seed(seed):
26 | random.seed(seed)
27 | os.environ["PYTHONHASHSEED"] = str(seed)
28 | np.random.seed(seed)
29 | torch.manual_seed(seed)
30 | torch.cuda.manual_seed(seed)
31 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
32 | torch.backends.cudnn.benchmark = False
33 | torch.backends.cudnn.deterministic = True
34 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
35 |
36 |
37 | @torch.no_grad()
38 | def benchmark_infer(net, generation, inp=None):
39 | net.eval()
40 | net.to(torch_dtype)
41 | seq_len = inp.shape[1]
42 | bs = inp.shape[0]
43 | inference_params = net.prepare_inference_params(
44 | bs,
45 | mx_seq=seq_len + generation,
46 | )
47 | tokens = bs * (seq_len + generation)
48 | repeat = 10
49 | warmup = 10
50 | for i in range(warmup):
51 | if seq_len > 0:
52 | inference_params.sequence_len_offset = 0
53 | net(inp, inference_params=inference_params)
54 | cur = torch.zeros(bs, 1).long().cuda()
55 | for j in range(seq_len, seq_len + generation):
56 | inference_params.sequence_len_offset = j
57 | net(input_ids=cur, inference_params=inference_params, use_cache=True)
58 | torch.cuda.synchronize()
59 | t0 = time.time()
60 | for i in range(repeat):
61 | if seq_len > 0:
62 | inference_params.sequence_len_offset = 0
63 | net(inp, inference_params=inference_params)
64 | cur = torch.zeros(bs, 1).long().cuda()
65 | for j in range(seq_len, seq_len + generation):
66 | inference_params.sequence_len_offset = j
67 | net(input_ids=cur, inference_params=inference_params, use_cache=True)
68 | torch.cuda.synchronize()
69 | t1 = time.time()
70 | latency = (t1 - t0) / repeat * (10**3)
71 | throughput = tokens * repeat / (t1 - t0)
72 | print("Latency (ms): {}, Throughput (token/s): {}".format(latency, throughput))
73 | return latency, throughput
74 |
75 |
76 | @hide_warnings(out=False)
77 | @hydra.main(
78 | version_base=None,
79 | config_path="../configs",
80 | config_name="refinedweb_config",
81 | )
82 | def main(config):
83 | OmegaConf.register_new_resolver("eval", eval)
84 | model = get_model(config)
85 | model.eval()
86 | f = open("../../exp/logs/arch_infer_latency.log", "a+")
87 | print(
88 | "h-f-a-nq-nkv",
89 | config.model.kwargs.hidden_dim,
90 | config.model.kwargs.ffn_dim,
91 | config.model.kwargs.attn_dim,
92 | config.model.kwargs.num_q_heads,
93 | config.model.kwargs.num_kv_heads,
94 | file=f,
95 | )
96 | if config.method.name != "linear":
97 | print(config.method.kwargs, file=f)
98 | input_ids = torch.zeros(config.data.test.test_batch, prefill).long().cuda()
99 | latency, throughput = benchmark_infer(
100 | model, config.model.kwargs.max_position_embeddings, input_ids
101 | )
102 | params = sum(p.numel() for p in model.parameters())
103 | params_woemb = params - 32000 * config.model.kwargs.hidden_dim
104 |
105 | print(
106 | f"bs: {config.data.test.test_batch}, generation: {config.model.kwargs.max_position_embeddings}, latency: {latency}, params: {params}, params_woemb: {params_woemb}, throughput: {throughput}",
107 | file=f,
108 | )
109 |
110 |
111 | if __name__ == "__main__":
112 | set_seed(1005)
113 | main()
114 | print("******END*******")
115 |
--------------------------------------------------------------------------------
/src/benchmark_eff/benchmark_model_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch
5 | import random
6 | import os
7 | import hydra
8 | import time
9 | import sys
10 | from hide_warnings import hide_warnings
11 |
12 | project_root = os.path.dirname(
13 | os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14 | )
15 |
16 | print(project_root)
17 | sys.path.insert(0, project_root)
18 |
19 | from src.modules import get_model
20 | from omegaconf import DictConfig, OmegaConf
21 |
22 |
23 | torch_dtype = torch.bfloat16
24 |
25 |
26 | def set_seed(seed):
27 | random.seed(seed)
28 | os.environ["PYTHONHASHSEED"] = str(seed)
29 | np.random.seed(seed)
30 | torch.manual_seed(seed)
31 | torch.cuda.manual_seed(seed)
32 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
33 | torch.backends.cudnn.benchmark = False
34 | torch.backends.cudnn.deterministic = True
35 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
36 |
37 |
38 | def benchmark_train(net, inp):
39 | tokens = inp.shape[0] * inp.shape[1]
40 | repeat = 50
41 | warmup = 10
42 | for i in range(warmup):
43 | net(inp)
44 | torch.cuda.synchronize()
45 | t0 = time.time()
46 | for i in range(repeat):
47 | net(inp)
48 | torch.cuda.synchronize()
49 | t1 = time.time()
50 | latency = (t1 - t0) / repeat * (10**3)
51 | throughput = tokens * repeat / (t1 - t0)
52 | print("Latency (ms): {}, Throughput (token/s): {}".format(latency, throughput))
53 | return latency, throughput
54 |
55 |
56 | @hide_warnings(out=False)
57 | @hydra.main(
58 | version_base=None,
59 | config_path="../configs",
60 | config_name="refinedweb_config",
61 | )
62 | def main(config):
63 | OmegaConf.register_new_resolver("eval", eval)
64 | model = get_model(config)
65 | model.eval()
66 | f = open("../../exp/logs/arch_train_latency.log", "a+")
67 | print(
68 | "h-f-a-nq-nkv",
69 | config.model.kwargs.hidden_dim,
70 | config.model.kwargs.ffn_dim,
71 | config.model.kwargs.attn_dim,
72 | config.model.kwargs.num_q_heads,
73 | config.model.kwargs.num_kv_heads,
74 | file=f,
75 | )
76 | if config.method.name != "linear":
77 | print(config.method.kwargs, file=f)
78 | with torch.no_grad():
79 | input_ids = torch.zeros(config.data.test.test_batch, 1024).long().cuda()
80 | model.to(torch_dtype)
81 | latency, throughput = benchmark_train(model, input_ids)
82 | params = sum(p.numel() for p in model.parameters())
83 | params_woemb = params - 32000 * config.model.kwargs.hidden_dim
84 | print(
85 | f"bs: {config.data.test.test_batch}, latency: {latency}, params: {params}, params_woemb: {params_woemb}, throughput: {throughput}",
86 | file=f,
87 | )
88 |
89 |
90 | if __name__ == "__main__":
91 | set_seed(1005)
92 | main()
93 | print("******END*******")
94 |
--------------------------------------------------------------------------------
/src/benchmark_eff/cac_batch.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch
5 | import random
6 | import os
7 | import hydra
8 | import time
9 | import sys
10 |
11 | project_root = os.path.dirname(
12 | os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13 | )
14 |
15 | print(project_root)
16 | sys.path.insert(0, project_root)
17 |
18 | from src.modules import get_model
19 | from omegaconf import DictConfig, OmegaConf
20 | from hide_warnings import hide_warnings
21 |
22 | torch_dtype = torch.bfloat16
23 | prefill = 0
24 |
25 |
26 | def set_seed(seed):
27 | random.seed(seed)
28 | os.environ["PYTHONHASHSEED"] = str(seed)
29 | np.random.seed(seed)
30 | torch.manual_seed(seed)
31 | torch.cuda.manual_seed(seed)
32 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
33 | torch.backends.cudnn.benchmark = False
34 | torch.backends.cudnn.deterministic = True
35 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
36 |
37 |
38 | def func(batch_size, generation, model):
39 | try:
40 | torch.cuda.empty_cache()
41 | cur = torch.zeros(batch_size, 1).long().cuda() # test the last token directly
42 |
43 | inference_params = model.prepare_inference_params(
44 | batch_size,
45 | mx_seq=prefill + generation,
46 | )
47 | inference_params.sequence_len_offset = prefill + generation - 1
48 | model(input_ids=cur, inference_params=inference_params, use_cache=True)
49 | except RuntimeError as e:
50 | return None
51 | return batch_size
52 |
53 |
54 | @torch.no_grad()
55 | def find_max_batch_size(model, generation):
56 | start = 256
57 | batch_size = start
58 | max_batch_size = start
59 | step = 256
60 | while True:
61 | if func(batch_size, generation, model):
62 | max_batch_size = batch_size
63 | batch_size += step
64 | else:
65 | break
66 | print(f"bs: {max_batch_size}")
67 | return max_batch_size
68 |
69 |
70 | @hide_warnings(out=False)
71 | @hydra.main(
72 | version_base=None,
73 | config_path="../configs",
74 | config_name="refinedweb_config",
75 | )
76 | def main(config):
77 | OmegaConf.register_new_resolver("eval", eval)
78 | model = get_model(config)
79 | model.eval()
80 | model.to(torch_dtype)
81 | f = open("../../exp/logs/arch_bs.log", "a+")
82 | print(
83 | "h-f-a-nq-nkv",
84 | config.model.kwargs.hidden_dim,
85 | config.model.kwargs.ffn_dim,
86 | config.model.kwargs.attn_dim,
87 | config.model.kwargs.num_q_heads,
88 | config.model.kwargs.num_kv_heads,
89 | file=f,
90 | )
91 | if config.method.name != "linear":
92 | print(config.method.kwargs, file=f)
93 | bs = find_max_batch_size(model, config.model.kwargs.max_position_embeddings)
94 | print(
95 | f"bs: {bs}, generation: {config.model.kwargs.max_position_embeddings}",
96 | file=f,
97 | )
98 |
99 |
100 | if __name__ == "__main__":
101 | set_seed(1005)
102 | main()
103 | print("******END*******")
104 |
--------------------------------------------------------------------------------
/src/configs/data/Aug/mixup.yaml:
--------------------------------------------------------------------------------
1 | mixup:
2 | name: mixup
3 | kwargs:
4 | alpha: 0.2 # 0.2 or 0.4
5 |
--------------------------------------------------------------------------------
/src/configs/data/Aug/randomaugment.yaml:
--------------------------------------------------------------------------------
1 | randomaugment:
2 | name: RandomAugment
3 | kwargs:
4 | n: 2
5 | m: 14
6 |
--------------------------------------------------------------------------------
/src/configs/data/cifar10.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - Aug:
3 | - randomaugment
4 | - mixup
5 |
6 | name: cifar10
7 | train:
8 | path: "/claire-rcp-scratch/shared/xwei/dataset"
9 | train_batch: 512
10 | # sweep:
11 | # values: [32, 64, 128, 256]
12 | test:
13 | path: "/claire-rcp-scratch/shared/xwei/dataset"
14 | test_batch: 512
--------------------------------------------------------------------------------
/src/configs/data/refinedweb.yaml:
--------------------------------------------------------------------------------
1 | name: refinedweb
2 | train:
3 | path: "/claire-rcp-scratch/shared/xwei/dataset/refinedweb"
4 | train_batch: 16
5 | test:
6 | path: "/claire-rcp-scratch/shared/xwei/dataset/refinedweb"
7 | test_batch: 32
8 | overwrite_cache: false
9 | num_workers: 16
10 | block_size: 1024
11 | tokenizer:
12 | name: null
13 | model_max_length: 1024
--------------------------------------------------------------------------------
/src/configs/method/blockdense.yaml:
--------------------------------------------------------------------------------
1 | name: blockdense
2 | kwargs:
3 | first_layer: true
4 | nblocks: 4
5 | rank: 256
6 | training:
7 | enabled: false
8 | init:
9 | post_init: ortho
10 |
--------------------------------------------------------------------------------
/src/configs/method/blockshuffle.yaml:
--------------------------------------------------------------------------------
1 | name: blockshuffle
2 | kwargs:
3 | first_layer: true
4 | nblocks: 2
5 | training:
6 | enabled: false
7 | init:
8 | post_init: ortho
9 |
--------------------------------------------------------------------------------
/src/configs/method/linear.yaml:
--------------------------------------------------------------------------------
1 | name: linear
2 |
--------------------------------------------------------------------------------
/src/configs/method/lowrank.yaml:
--------------------------------------------------------------------------------
1 | name: lowrank
2 | kwargs:
3 | first_layer: true
4 | rank: 256
5 | training:
6 | enabled: false
7 | init:
8 | post_init: svd
9 |
--------------------------------------------------------------------------------
/src/configs/model/gpt2.yaml:
--------------------------------------------------------------------------------
1 | name: gpt2
2 | kwargs:
3 | model_type: llama
4 | bos_token_id: 1
5 | eos_token_id: 2
6 | hidden_dim: 768
7 | attn_dim: 768
8 | ffn_dim: 3072
9 | num_q_heads: 12
10 | num_kv_heads: 12
11 | num_layers: 12
12 | hidden_drop: 0.0
13 | embd_drop: 0.0
14 | max_position_embeddings: 1024
15 | vocab_size: 32000
16 | tie_word_embeddings: true
17 | ln: layernorm
18 | act: gelu
19 | bias: true
20 | scale_attn_by_inverse_layer_idx: false
21 | pos_emb:
22 | name: rope
23 | rotary_interleaved: false
24 | seq_len_interpolation_factor: null
25 | rotary_base: 10000
26 | init:
27 | weight_init: fixed
28 | initializer_range: 0.02
29 |
--------------------------------------------------------------------------------
/src/configs/model/gpt2l.yaml:
--------------------------------------------------------------------------------
1 | name: gpt2 # to indicate Model Function
2 | kwargs:
3 | model_type: llama # to indicate tokenizer
4 | bos_token_id: 1
5 | eos_token_id: 2
6 | hidden_dim: 1536
7 | attn_dim: 1536
8 | ffn_dim: 6144
9 | num_q_heads: 12
10 | num_kv_heads: 12
11 | num_layers: 24
12 | hidden_drop: 0.0
13 | embd_drop: 0.0
14 | max_position_embeddings: 1024
15 | vocab_size: 32000
16 | tie_word_embeddings: true
17 | ln: layernorm
18 | act: gelu
19 | bias: true
20 | scale_attn_by_inverse_layer_idx: false
21 | pos_emb:
22 | name: rope
23 | rotary_interleaved: false
24 | seq_len_interpolation_factor: null
25 | rotary_base: 10000
26 | init:
27 | weight_init: fixed
28 | initializer_range: 0.02
29 |
--------------------------------------------------------------------------------
/src/configs/model/gpt2m.yaml:
--------------------------------------------------------------------------------
1 | name: gpt2
2 | kwargs:
3 | model_type: llama
4 | bos_token_id: 1
5 | eos_token_id: 2
6 | hidden_dim: 1024
7 | attn_dim: 1024
8 | ffn_dim: 4096
9 | num_q_heads: 16
10 | num_kv_heads: 16
11 | num_layers: 24
12 | hidden_drop: 0.0
13 | embd_drop: 0.0
14 | max_position_embeddings: 1024
15 | vocab_size: 32000
16 | tie_word_embeddings: true
17 | ln: layernorm
18 | act: gelu
19 | bias: true
20 | scale_attn_by_inverse_layer_idx: false
21 | pos_emb:
22 | name: rope
23 | rotary_interleaved: false
24 | seq_len_interpolation_factor: null
25 | rotary_base: 10000
26 | init:
27 | weight_init: fixed
28 | initializer_range: 0.02
29 |
--------------------------------------------------------------------------------
/src/configs/model/gpt2xl.yaml:
--------------------------------------------------------------------------------
1 | name: gpt2
2 | kwargs:
3 | model_type: llama
4 | bos_token_id: 1
5 | eos_token_id: 2
6 | hidden_dim: 2048
7 | attn_dim: 2048
8 | ffn_dim: 8192
9 | num_q_heads: 16
10 | num_kv_heads: 16
11 | num_layers: 24
12 | hidden_drop: 0.0
13 | embd_drop: 0.0
14 | max_position_embeddings: 1024
15 | vocab_size: 32000
16 | tie_word_embeddings: true
17 | ln: layernorm
18 | act: gelu
19 | bias: true
20 | scale_attn_by_inverse_layer_idx: false
21 | pos_emb:
22 | name: rope
23 | rotary_interleaved: false
24 | seq_len_interpolation_factor: null
25 | rotary_base: 10000
26 | init:
27 | weight_init: fixed
28 | initializer_range: 0.02
29 |
--------------------------------------------------------------------------------
/src/configs/optimization/basic.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - lr_scheduler: cosineannealinglr
3 | - optimizer: adamw
4 | - training: regular_training
5 |
6 | max_epoch: 1000
7 | device: cuda
8 | seed: 1005
9 | load_checkpoint: false
10 | save_checkpoint: false
11 | save_dir: /home/xwei/transformers/final_version/exp/ckpt/
12 | load_save_mode: epoch
13 | check_gradient_norm: false
14 | check_weight_norm: false
15 | gradient_clipping: false
--------------------------------------------------------------------------------
/src/configs/optimization/lr_scheduler/cosineannealinglr.yaml:
--------------------------------------------------------------------------------
1 | name: CosineAnnealingLR
2 | kwargs:
3 | warmup_iter: 0.02
4 | T_max: 1000
5 | eta_min: 1.0e-7
--------------------------------------------------------------------------------
/src/configs/optimization/lr_scheduler/multisteplr.yaml:
--------------------------------------------------------------------------------
1 | name: MultiStepLR
2 | kwargs:
3 | milestones: [0.3, 0.6, 0.8]
4 | gamma: 0.1
5 | T_max: 1000
6 |
--------------------------------------------------------------------------------
/src/configs/optimization/optimizer/adam.yaml:
--------------------------------------------------------------------------------
1 | name: adam
2 | kwargs:
3 | lr: 1.0e-4
4 |
--------------------------------------------------------------------------------
/src/configs/optimization/optimizer/adamw.yaml:
--------------------------------------------------------------------------------
1 | name: adamw
2 | kwargs:
3 | lr: 1.0e-4
4 | weight_decay: 0.01
--------------------------------------------------------------------------------
/src/configs/optimization/optimizer/sgd.yaml:
--------------------------------------------------------------------------------
1 | name: sgd
2 | kwargs:
3 | lr: 0.1
4 | momentum: 0.9
5 | weight_decay: 5.0e-4
6 |
--------------------------------------------------------------------------------
/src/configs/optimization/training/regular_training.yaml:
--------------------------------------------------------------------------------
1 | name: regular_training
2 |
--------------------------------------------------------------------------------
/src/configs/optimization/training/self_guided_training.yaml:
--------------------------------------------------------------------------------
1 | name: self_guided_training
2 | kwargs:
3 |   mode: fixedstep # fixedstep or fixedflop
4 | scheduler: cosine
5 | reduce_flop: false
6 | max_step: null
7 | max_step_ratio: 0.5 # ratio of the total steps
--------------------------------------------------------------------------------
/src/configs/refinedweb_config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - method: &method linear
3 | - optimization: &optimization basic
4 | - data: &data refinedweb
5 | - model: &model gpt2
6 | - _self_
7 |
8 | # rewrite optimization cifar here
9 | optimization:
10 | max_tokens: 2200000000
11 | global_batch_size: 512
12 | gradient_checkpointing: false
13 | gradient_clipping: 1.0
14 | log_interval: 20
15 | load_save_mode: step
16 | load_checkpoint: true
17 | save_checkpoint: true
18 | optimizer:
19 | kwargs:
20 | lr: &lr 6.0e-4
21 | weight_decay: 0.1
22 | betas: [0.9, 0.999]
23 | lr_scheduler:
24 | kwargs:
25 | warmup_iter: 0.1
26 | eta_min: ${eval:0.1 * ${optimization.optimizer.kwargs.lr}}
27 |
28 | data:
29 | tokenizer:
30 | name: ${model.kwargs.model_type}
31 | model_max_length: ${model.kwargs.max_position_embeddings}
32 | block_size: ${model.kwargs.max_position_embeddings}
33 |
34 | base_dir: &base_dir /home/xwei/transformers/final_version/exp/
35 |
36 | wandb:
37 | entity: xiuying-wei
38 | project: gpt2reproduce
39 | mode: online
40 | anonymous: allow
41 | dir: *base_dir
42 |
43 | wandb_use: true
44 | hydra:
45 | run:
46 | dir: /home/xwei/transformers/final_version/exp/hydra/${now:%Y-%m-%d}/${now:%H-%M-%S}
47 |
--------------------------------------------------------------------------------
/src/modules/__init__.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from easydict import EasyDict
3 | from .model import GPT2LMHeadModel
4 | from .mlp import (
5 | FusedBlockDenseMLP,
6 | FusedLowRankMLP,
7 | FusedBlockShuffleMLP,
8 | FusedMLP,
9 | )
10 |
11 | name_to_model = {
12 | "gpt2": GPT2LMHeadModel,
13 | }
14 |
15 | name_to_method = {
16 | "lowrank": FusedLowRankMLP,
17 | "blockdense": FusedBlockDenseMLP,
18 | "blockshuffle": FusedBlockShuffleMLP,
19 | }
20 |
21 |
22 | def replace_mlp(model, config_method, config_model, device="cuda"):
23 | first_layer = (
24 | config_method.kwargs.first_layer
25 | ) # true: keep the original linear layer
26 | for i in range(config_model.kwargs.num_layers):
27 | if first_layer and i == 0:
28 | continue
29 | new_module = name_to_method[config_method.name.lower()](
30 | config_model.kwargs.hidden_dim,
31 | config_model.kwargs.ffn_dim,
32 | config_model.kwargs.bias,
33 | config_model.kwargs.act,
34 | config_method.kwargs,
35 | config_model.kwargs.init,
36 | device=device,
37 | )
38 | del model.model.layers[i].mlp
39 | model.model.layers[i].mlp = new_module
40 |
41 |
42 | def get_model(config, device="cuda"):
43 | config_model = config.model
44 | config_method = config.method
45 | model = name_to_model[config_model.name.lower()](config_model.get("kwargs", {})).to(
46 | device
47 | )
48 |
49 | # replace here
50 | if config_method.name.lower() == "linear":
51 | return model
52 | replace_mlp(model, config_method, config_model, device)
53 | model.to(device)
54 | return model
55 |
56 |
57 | def get_ckpt_name(config):
58 | config_model = config.model
59 | config_method = config.method
60 | long_name = config_model.name + name_to_model[
61 | config_model.name.lower()
62 | ].get_ckpt_name(config_model.get("kwargs", {}))
63 | if config_method.name != "linear":
64 | long_name += (
65 | "-"
66 | + config_method.name
67 | + name_to_method[config_method.name.lower()].get_ckpt_name(
68 | config_method.get("kwargs", {})
69 | )
70 | )
71 | return long_name
72 |
73 |
74 | def update_ratio(module):
75 | if hasattr(module, "_update_ratio"):
76 | module._update_ratio()
77 |
--------------------------------------------------------------------------------
/src/modules/layer/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .customlinear import CustomLinear
3 | from .lowrank import LowRank
4 | from .blockdense import BlockDense
5 | from .blockshuffle import BlockShuffle
6 |
--------------------------------------------------------------------------------
/src/modules/layer/basiclinear.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | from .util import LinearTempDecay, CosineTempDecay
5 |
6 |
7 | class BasicLinear(nn.Module):
8 |
9 | def __init__(
10 | self, in_features, out_features, bias, return_bias, config, init_config, device
11 | ):
12 | super().__init__()
13 | # config: method part, and model init
14 | self.device = device
15 | self.config = config
16 | self.init_config = init_config
17 | self.training_config = self.config.training
18 | # model part
19 | self.in_features = in_features
20 | self.out_features = out_features
21 | # otherwise, we need to fuse the bias into the ops
22 | assert return_bias is True
23 | if bias:
24 | self.bias = nn.Parameter(torch.empty(self.out_features, device=device))
25 | else:
26 | self.bias = None
27 |
28 | if self.training_config.enabled:
29 | self.guide_linear = nn.Parameter(
30 | torch.empty(self.out_features, self.in_features, device=device)
31 | )
32 | self.register_buffer("count", torch.tensor(0).cuda(), persistent=True)
33 | self.register_buffer("ratio", torch.tensor(1.0).cuda(), persistent=True)
34 | guide_scheduler = {
35 | "linear": LinearTempDecay,
36 | "cosine": CosineTempDecay,
37 | }
38 | self.guide_scheduler = guide_scheduler[self.training_config.scheduler](
39 | t_max=self.training_config.max_step
40 | )
41 |
42 | @torch.no_grad()
43 | def _update_ratio(
44 | self,
45 | ):
46 | self.count += 1
47 | self.ratio = self.guide_scheduler(self.count)
48 |
49 | def _check_guide_layer(
50 | self,
51 | ):
52 | if not self.training_config.enabled:
53 | return False
54 | if (
55 | self.training_config.reduce_flop
56 | and torch.rand_like(self.ratio) >= self.ratio
57 | ):
58 | return False
59 | return True
60 |
61 | def forward_guide_layer(self, input, out):
62 | if self._check_guide_layer():
63 | guide_out = torch.matmul(input, self.guide_linear.transpose(-1, -2))
64 | out = self.ratio * guide_out + (1.0 - self.ratio) * out
65 | return out, self.bias
66 |
67 | def get_weights(
68 | self,
69 | ):
70 | pass
71 |
72 | @torch.no_grad()
73 | def _init_weights(
74 | self,
75 | ):
76 | if self.bias is not None:
77 | nn.init.zeros_(self.bias)
78 | for para in self.get_weights():
79 | if self.init_config.weight_init == "xavier":
80 | nn.init.normal_(para, mean=0.0, std=(para.shape[-1] ** -0.5))
81 | elif self.init_config.weight_init == "fixed":
82 | nn.init.normal_(para, std=self.init_config.initializer_range)
83 | else:
84 | raise NotImplementedError
85 |
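
For reference, `forward_guide_layer` above is the self-guided training step from the paper: the structured output is blended with a dense guide layer whose contribution decays over training,

$$y = \alpha_t\,(W_{\text{guide}}\,x) + (1 - \alpha_t)\,f_{\text{struct}}(x),$$

where $\alpha_t$ is the `ratio` buffer, annealed from 1 to 0 by the scheduler in `util.py`, so only the structured matrices remain once the guided phase ends.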
--------------------------------------------------------------------------------
/src/modules/layer/blockdense.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from ..op import block_dense_custom
4 | from .basiclinear import BasicLinear
5 |
6 |
7 | class BlockDense(BasicLinear):
8 |
9 | def __init__(
10 | self,
11 | in_features,
12 | out_features,
13 | bias,
14 | return_bias,
15 | config,
16 | init_config,
17 | device="cuda",
18 | ):
19 | super().__init__(
20 | in_features, out_features, bias, return_bias, config, init_config, device
21 | )
22 | self.rank = config["rank"]
23 | self.nblocks = config["nblocks"]
24 | assert self.in_features % self.nblocks == 0
25 | assert self.rank % self.nblocks == 0
26 | self.blkdiag = nn.Parameter(
27 | torch.empty(
28 | self.nblocks,
29 | self.rank // self.nblocks,
30 | self.in_features // self.nblocks,
31 | device=device,
32 | )
33 | )
34 | self.lr = nn.Parameter(torch.empty(self.out_features, self.rank, device=device))
35 |
36 | self._init_weights()
37 | self.post_init()
38 |
39 | def get_weights(
40 | self,
41 | ):
42 | return [self.blkdiag, self.lr]
43 |
44 | @torch.no_grad()
45 | def post_init(
46 | self,
47 | ):
48 | if self.config.init.post_init == "ortho":
49 | for i in range(self.nblocks):
50 | U, S, Vh = torch.linalg.svd(self.blkdiag.data[i], full_matrices=False)
51 | self.blkdiag.data[i] = torch.mm(U, Vh)
52 | U, S, Vh = torch.linalg.svd(self.lr.data, full_matrices=False)
53 | self.lr.data = torch.mm(U, Vh)
54 | # init guide linear
55 | if hasattr(self, "guide_linear"):
56 | self.guide_linear.data = torch.mm(
57 | self.lr.data, torch.block_diag(*torch.unbind(self.blkdiag.data, dim=0))
58 | )
59 |
60 | def forward(self, input):
61 | out = block_dense_custom(input, self.blkdiag, self.lr)
62 | return self.forward_guide_layer(input, out)
63 |
64 | def extra_repr(self) -> str:
65 | return f"blockdiag1={self.blkdiag.shape}, linear={self.lr.shape}, bias={self.bias is not None}, guide={self.training_config.enabled}"
66 |
--------------------------------------------------------------------------------
/src/modules/layer/blockshuffle.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from .basiclinear import BasicLinear
4 | from ..op import block_shuffle_bmm, block_shuffle_custom
5 |
6 |
7 | class BlockShuffle(BasicLinear):
8 |
9 | def __init__(
10 | self,
11 | in_features,
12 | out_features,
13 | bias,
14 | return_bias,
15 | config,
16 | init_config,
17 | device="cuda",
18 | ):
19 | super().__init__(
20 | in_features, out_features, bias, return_bias, config, init_config, device
21 | )
22 | self.nblocks = config["nblocks"]
23 | assert self.in_features % self.nblocks == 0
24 | assert self.out_features % self.nblocks == 0
25 |
26 | in_blksz = self.in_features // self.nblocks
27 | out_blksz = self.out_features // self.nblocks
28 |
29 | if self.in_features < self.out_features:
30 | self.blkdiag1 = nn.Parameter(
31 | torch.empty(self.nblocks, in_blksz, in_blksz, device=device)
32 | )
33 | self.blkdiag2 = nn.Parameter(
34 | torch.empty(self.nblocks, out_blksz, in_blksz, device=device)
35 | )
36 | else:
37 | self.blkdiag1 = nn.Parameter(
38 | torch.empty(self.nblocks, out_blksz, in_blksz, device=device)
39 | )
40 | self.blkdiag2 = nn.Parameter(
41 | torch.empty(self.nblocks, out_blksz, out_blksz, device=device)
42 | )
43 | self._init_weights()
44 | self.post_init()
45 |
46 | def get_weights(
47 | self,
48 | ):
49 | return [self.blkdiag1, self.blkdiag2]
50 |
51 | @torch.no_grad()
52 | def post_init(
53 | self,
54 | ):
55 | if self.config.init.post_init == "ortho":
56 | for i in range(self.nblocks):
57 | U, S, Vh = torch.linalg.svd(self.blkdiag1.data[i], full_matrices=False)
58 | self.blkdiag1.data[i] = torch.mm(U, Vh)
59 | U, S, Vh = torch.linalg.svd(self.blkdiag2.data[i], full_matrices=False)
60 | self.blkdiag2.data[i] = torch.mm(U, Vh)
61 |
62 | # init guide linear
63 | if hasattr(self, "guide_linear"):
64 | self.guide_linear.data = torch.mm(
65 | torch.block_diag(*torch.unbind(self.blkdiag2.data, dim=0)),
66 | torch.block_diag(*torch.unbind(self.blkdiag1.data, dim=0)),
67 | )
68 |
69 | def forward(self, input):
70 | out = block_shuffle_custom(input, self.blkdiag1, self.blkdiag2)
71 | return self.forward_guide_layer(input, out)
72 |
73 | def extra_repr(self) -> str:
74 | return f"blockdiag1={self.blkdiag1.shape}, blockdiag2={self.blkdiag2.shape}, bias={self.bias is not None}, guide={self.training_config.enabled}"
75 |
--------------------------------------------------------------------------------
/src/modules/layer/customlinear.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | # note: do not inherit from BasicLinear here
6 | class CustomLinear(nn.Module):
7 |
8 | def __init__(self, in_features, out_features, bias, return_bias=True):
9 | super().__init__()
10 | self.in_features = in_features
11 | self.out_features = out_features
12 | self.weight = nn.Parameter(
13 | torch.empty(
14 | out_features,
15 | in_features,
16 | )
17 | )
18 | # otherwise, we need to fuse the bias into the ops
19 | assert return_bias is True
20 |
21 | if bias:
22 | self.bias = nn.Parameter(torch.empty(out_features))
23 | else:
24 | self.bias = None
25 |
26 | def forward(self, inp):
27 | output = torch.matmul(inp, self.weight.transpose(-1, -2))
28 | return output, self.bias
29 |
30 | def extra_repr(self) -> str:
31 | return f"linearshape={self.weight.shape}, bias={self.bias is not None}"
32 |
--------------------------------------------------------------------------------
/src/modules/layer/lowrank.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from .basiclinear import BasicLinear
4 | from ..op import low_rank_custom
5 |
6 |
7 | class LowRank(BasicLinear):
8 |
9 | def __init__(
10 | self,
11 | in_features,
12 | out_features,
13 | bias,
14 | return_bias,
15 | config,
16 | init_config,
17 | device="cuda",
18 | ):
19 | super().__init__(
20 | in_features,
21 | out_features,
22 | bias,
23 | return_bias,
24 | config,
25 | init_config,
26 | device=device,
27 | )
28 | self.rank = config["rank"]
29 | self.lr1 = nn.Parameter(torch.empty(self.rank, self.in_features, device=device))
30 | self.lr2 = nn.Parameter(
31 | torch.empty(self.out_features, self.rank, device=device)
32 | )
33 | self._init_weights()
34 | self.post_init()
35 |
36 | def get_weights(
37 | self,
38 | ):
39 | return [self.lr1, self.lr2]
40 |
41 | @torch.no_grad()
42 | def post_init(
43 | self,
44 | ):
45 | if self.config.init.post_init == "svd":
46 | org_linear = nn.Parameter(
47 | torch.empty(self.out_features, self.in_features, device=self.device)
48 | )
49 | if self.init_config.weight_init == "xavier":
50 | nn.init.normal_(
51 | org_linear, mean=0.0, std=(org_linear.shape[-1] ** -0.5)
52 | )
53 | elif self.init_config.weight_init == "fixed":
54 | nn.init.normal_(org_linear, std=self.init_config.initializer_range)
55 | else:
56 | raise NotImplementedError
57 | U, S, Vh = torch.linalg.svd(org_linear, full_matrices=False)
58 | sqrt_S = torch.sqrt(torch.diag_embed(S[: self.rank]))
59 | self.lr1.data = sqrt_S @ Vh[: self.rank, :]
60 | self.lr2.data = U[:, : self.rank] @ sqrt_S
61 |
62 | # init guide linear
63 | if hasattr(self, "guide_linear"):
64 | self.guide_linear.data = torch.mm(self.lr2.data, self.lr1.data)
65 |
66 | def forward(self, input):
67 | out = low_rank_custom(input, self.lr1, self.lr2)
68 | return self.forward_guide_layer(input, out)
69 |
70 | def extra_repr(self) -> str:
71 | return f"lr1={self.lr1.shape}, lr2={self.lr2.shape}, bias={self.bias is not None}, guide={self.training_config.enabled}"
72 |
--------------------------------------------------------------------------------
/src/modules/layer/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 |
5 | class LinearTempDecay:
6 | def __init__(self, t_max=20000, warm_up=0, start_b=1.0, end_b=0.0):
7 | self.t_max = t_max
8 | self.warmup = warm_up
9 | self.start_b = torch.tensor(start_b).cuda()
10 | self.end_b = torch.tensor(end_b).cuda()
11 | print(
12 | "linear scheduler for self-guided training in steps {} with warmup {}".format(
13 | self.t_max, self.warmup
14 | )
15 | )
16 |
17 | def __call__(self, t):
18 | if t < self.warmup:
19 | return self.start_b
20 | elif t > self.t_max:
21 | return self.end_b
22 | else:
23 | rel_t = (t - self.warmup) / (self.t_max - self.warmup)
24 | return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t))
25 |
26 |
27 | class CosineTempDecay:
28 | def __init__(self, t_max=20000, warm_up=0, start_b=1.0, end_b=0.0):
29 | self.t_max = t_max
30 | self.warmup = warm_up
31 | self.start_b = torch.tensor(start_b).cuda()
32 | self.end_b = torch.tensor(end_b).cuda()
33 | print(
34 | "Cosine scheduler for self-guided training in steps {} with warmup {}".format(
35 | self.t_max, self.warmup
36 | )
37 | )
38 |
39 | def __call__(self, t):
40 | if t < self.warmup:
41 | return self.start_b
42 | elif t > self.t_max:
43 | return self.end_b
44 | else:
45 | rel_t = (t - self.warmup) / (self.t_max - self.warmup)
46 | return self.end_b + 0.5 * (self.start_b - self.end_b) * (
47 | 1 + torch.cos(rel_t * math.pi)
48 | )
49 |
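
In closed form, with $\tau = \frac{t - t_{\text{warmup}}}{t_{\max} - t_{\text{warmup}}}$, the two schedules above decay the guide ratio from $\alpha_{\text{start}}$ to $\alpha_{\text{end}}$ as

$$\alpha_{\text{linear}}(t) = \alpha_{\text{end}} + (\alpha_{\text{start}} - \alpha_{\text{end}})\max(0,\,1-\tau), \qquad \alpha_{\text{cosine}}(t) = \alpha_{\text{end}} + \tfrac{1}{2}(\alpha_{\text{start}} - \alpha_{\text{end}})\big(1 + \cos(\pi\tau)\big).$$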
--------------------------------------------------------------------------------
/src/modules/mlp/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .lowrank_mlp import FusedLowRankMLP
3 | from .blockdense_mlp import FusedBlockDenseMLP
4 | from .blockshuffle_mlp import FusedBlockShuffleMLP
5 | from .mlp import FusedMLP
6 |
--------------------------------------------------------------------------------
/src/modules/mlp/basic_mlp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformer_engine.pytorch.jit import set_jit_fusion_options
3 | from ..op import bias_gelu_impl, bias_swiglu_impl
4 |
5 |
6 | act_func_dict = {
7 | "gelu": bias_gelu_impl,
8 | "swiglu": bias_swiglu_impl,
9 | }
10 |
11 |
12 | class FusedBasicMLP(nn.Module):
13 | def __init__(self, hidden_dim, ffn_dim, bias, act="gelu"):
14 | super().__init__()
15 | self.hidden_dim = hidden_dim
16 | self.ffn_dim = ffn_dim
17 | # fuse bias and gelu
18 | set_jit_fusion_options()
19 | self.fc1 = None
20 | self.fc2 = None
21 | self.act_func = act_func_dict.get(act, None)
22 | if act in ["swiglu"]:
23 | self.ffn_dim *= 2  # SwiGLU splits fc1's output into two halves, so fc1 needs twice the width
24 |
25 | def forward(self, input):
26 | fc1_outs = self.fc1(input)
27 | gelu_out = self.act_func(*fc1_outs)
28 | fc2_outs = self.fc2(gelu_out)
29 | return fc2_outs
30 |
--------------------------------------------------------------------------------
/src/modules/mlp/blockdense_mlp.py:
--------------------------------------------------------------------------------
1 | from .basic_mlp import FusedBasicMLP
2 | from ..layer import BlockDense
3 |
4 |
5 | class FusedBlockDenseMLP(FusedBasicMLP):
6 |
7 | def __init__(
8 | self,
9 | hidden_dim,
10 | ffn_dim,
11 | bias,
12 | act,
13 | config,
14 | init_config,
15 | device,
16 | ):
17 | super().__init__(hidden_dim, ffn_dim, bias, act=act)
18 | self.fc1 = BlockDense(
19 | hidden_dim,
20 | self.ffn_dim,
21 | bias=bias,
22 | return_bias=True,
23 | config=config,
24 | init_config=init_config,
25 | device=device,
26 | )
27 | self.fc2 = BlockDense(
28 | ffn_dim,
29 | hidden_dim,
30 | bias=bias,
31 | return_bias=True,
32 | config=config,
33 | init_config=init_config,
34 | device=device,
35 | )
36 |
37 | @staticmethod
38 | def get_ckpt_name(config_method):
39 | long_name = (
40 | "r"
41 | + str(config_method.rank)
42 | + "b"
43 | + str(config_method.nblocks)
44 | + "-"
45 | + str(config_method.init.post_init)
46 | )
47 | return long_name
48 |
--------------------------------------------------------------------------------
/src/modules/mlp/blockshuffle_mlp.py:
--------------------------------------------------------------------------------
1 | from .basic_mlp import FusedBasicMLP
2 | from ..layer import BlockShuffle
3 |
4 |
5 | class FusedBlockShuffleMLP(FusedBasicMLP):
6 |
7 | def __init__(
8 | self,
9 | hidden_dim,
10 | ffn_dim,
11 | bias,
12 | act,
13 | config,
14 | init_config,
15 | device,
16 | ):
17 | super().__init__(hidden_dim, ffn_dim, bias, act=act)
18 | self.fc1 = BlockShuffle(
19 | hidden_dim,
20 | self.ffn_dim,
21 | bias=bias,
22 | return_bias=True,
23 | config=config,
24 | init_config=init_config,
25 | device=device,
26 | )
27 | self.fc2 = BlockShuffle(
28 | ffn_dim,
29 | hidden_dim,
30 | bias=bias,
31 | return_bias=True,
32 | config=config,
33 | init_config=init_config,
34 | device=device,
35 | )
36 |
37 | @staticmethod
38 | def get_ckpt_name(config_method):
39 | long_name = (
40 | "b" + str(config_method.nblocks) + "-" + str(config_method.init.post_init)
41 | )
42 | return long_name
43 |
--------------------------------------------------------------------------------
/src/modules/mlp/lowrank_mlp.py:
--------------------------------------------------------------------------------
1 | from .basic_mlp import FusedBasicMLP
2 | from ..layer import LowRank
3 |
4 |
5 | class FusedLowRankMLP(FusedBasicMLP):
6 |
7 | def __init__(
8 | self,
9 | hidden_dim,
10 | ffn_dim,
11 | bias,
12 | act,
13 | config,
14 | init_config,
15 | device,
16 | ):
17 | super().__init__(hidden_dim, ffn_dim, bias, act=act)
18 | self.fc1 = LowRank(
19 | hidden_dim,
20 | self.ffn_dim,
21 | bias=bias,
22 | return_bias=True,
23 | config=config,
24 | init_config=init_config,
25 | device=device,
26 | )
27 | self.fc2 = LowRank(
28 | ffn_dim,
29 | hidden_dim,
30 | bias=bias,
31 | return_bias=True,
32 | config=config,
33 | init_config=init_config,
34 | device=device,
35 | )
36 |
37 | @staticmethod
38 | def get_ckpt_name(config_method):
39 | long_name = (
40 | "r" + str(config_method.rank) + "-" + str(config_method.init.post_init)
41 | )
42 | return long_name
43 |
--------------------------------------------------------------------------------
/src/modules/mlp/mlp.py:
--------------------------------------------------------------------------------
1 | from .basic_mlp import FusedBasicMLP
2 | from ..layer import CustomLinear
3 |
4 |
5 | class FusedMLP(FusedBasicMLP):
6 | def __init__(self, hidden_dim, ffn_dim, bias, act="gelu"):
7 | super().__init__(hidden_dim, ffn_dim, bias, act)
8 | self.fc1 = CustomLinear(hidden_dim, self.ffn_dim, bias=bias, return_bias=True)
9 | self.fc2 = CustomLinear(ffn_dim, hidden_dim, bias=bias, return_bias=True)
10 |
--------------------------------------------------------------------------------
/src/modules/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .gpt2 import GPT2LMHeadModel
2 |
--------------------------------------------------------------------------------
/src/modules/model/gpt2.py:
--------------------------------------------------------------------------------
1 | """A fast version of gpt2 with flash attention and transformer engine"""
2 |
3 | import torch.nn as nn
4 | import torch
5 | import torch.nn.functional as F
6 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
7 | import transformer_engine.pytorch as te
8 |
9 | from transformer_engine.pytorch.jit import (
10 | set_jit_fusion_options,
11 | )
12 | from flash_attn import flash_attn_func, flash_attn_with_kvcache
13 | from ..layer import CustomLinear
14 | from ..mlp import FusedMLP
15 | from ..op import bias_dropout_add_impl, RotaryEmbedding, apply_rotary_pos_emb
16 |
17 |
18 | layernorm_func = {
19 | "layernorm": te.LayerNorm,
20 | "rmsnorm": te.RMSNorm,
21 | }
22 |
23 |
24 | class InferenceParams:
25 |
26 | def __init__(self, max_batch_size, max_sequence_length):
27 | self.max_sequence_length = max_sequence_length
28 | self.max_batch_size = max_batch_size
29 | self.sequence_len_offset = 0
30 | self.batch_size_offset = 0
31 | self.key_value_memory_dict = {}
32 |
33 |
34 | class TransformerLayer(nn.Module):
35 |
36 | def __init__(self, config, layer_number):
37 | super().__init__()
38 | self.hidden_dim = config.hidden_dim
39 | self.attn_dim = config.attn_dim
40 | self.ffn_dim = config.ffn_dim
41 | self.num_q_heads = config.num_q_heads
42 | assert self.attn_dim % self.num_q_heads == 0
43 | self.head_dim = self.attn_dim // self.num_q_heads
44 | self.num_kv_heads = config.num_kv_heads
45 | self.layer_number = layer_number
46 | self.hidden_dropout = config.hidden_drop
47 |
48 | self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
49 | set_jit_fusion_options()
50 | self.ln1 = layernorm_func[config.ln](self.hidden_dim)
51 | self.qkv_linear = nn.Linear(
52 | self.hidden_dim,
53 | (self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,
54 | bias=config.bias,
55 | )
56 | self.o_linear = CustomLinear(
57 | self.attn_dim,
58 | self.hidden_dim,
59 | bias=config.bias,
60 | return_bias=True,
61 | )
62 | self.ln2 = layernorm_func[config.ln](self.hidden_dim)
63 | self.mlp = FusedMLP(
64 | self.hidden_dim, self.ffn_dim, bias=config.bias, act=config.act
65 | )
66 |
67 | def _bias_dropout_add(self, hidden_state, bias, residual):
68 | bias_dropout_add_func = bias_dropout_add_impl(self.training)
69 | output = bias_dropout_add_func(
70 | (hidden_state, bias), residual, self.hidden_dropout
71 | )
72 | return output
73 |
74 | def _adjust_key_value_for_inference(
75 | self, inference_params, k_out, v_out, rotary_pos_emb
76 | ):
77 | if inference_params is None:
78 | return k_out, v_out, rotary_pos_emb
79 | bs = k_out.shape[0]
80 | seq_len = k_out.shape[1]
81 |
82 | inference_key_memory, inference_value_memory = (
83 | inference_params.key_value_memory_dict[self.layer_number]
84 | )
85 | batch_start = inference_params.batch_size_offset
86 | batch_end = batch_start + bs
87 | assert batch_end <= inference_key_memory.size(0)
88 | sequence_start = inference_params.sequence_len_offset
89 | sequence_end = sequence_start + seq_len
90 | assert sequence_end <= inference_key_memory.size(1)
91 | inference_key_memory[
92 | batch_start:batch_end, sequence_start:sequence_end, ...
93 | ] = k_out
94 | inference_value_memory[
95 | batch_start:batch_end, sequence_start:sequence_end, ...
96 | ] = v_out
97 | key = inference_key_memory[batch_start:batch_end, :sequence_end, ...]
98 | value = inference_value_memory[batch_start:batch_end, :sequence_end, ...]
99 |
100 | # adjust the key rotary positional embedding
101 | if rotary_pos_emb is not None:
102 | q_pos_emb, k_pos_emb = rotary_pos_emb
103 | q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
104 | k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
105 | rotary_pos_emb = (q_pos_emb, k_pos_emb)
106 |
107 | return key, value, rotary_pos_emb
108 |
109 | def forward(
110 | self,
111 | hidden_states,
112 | inference_params: Optional[InferenceParams] = None,
113 | use_cache=False,
114 | rotary_pos_emb: torch.Tensor = None,
115 | ):
116 | hidden_states = hidden_states.contiguous()
117 | bs, seq_len, _ = hidden_states.shape
118 | qkv_out = self.qkv_linear(self.ln1(hidden_states))
119 | q_out = qkv_out[..., : (self.num_q_heads * self.head_dim)]
120 | kv_out = qkv_out[..., (self.num_q_heads * self.head_dim) :]
121 | k_out, v_out = kv_out.chunk(2, -1)
122 | q_out = q_out.reshape(bs, seq_len, self.num_q_heads, self.head_dim)
123 | k_out = k_out.reshape(bs, seq_len, self.num_kv_heads, self.head_dim)
124 | v_out = v_out.reshape(bs, seq_len, self.num_kv_heads, self.head_dim)
125 |
126 | if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
127 | rotary_pos_emb = (rotary_pos_emb,) * 2
128 |
129 | k_out, v_out, rotary_pos_emb = self._adjust_key_value_for_inference(
130 | inference_params, k_out, v_out, rotary_pos_emb
131 | )
132 | if rotary_pos_emb is not None:
133 | q_pos_emb, k_pos_emb = rotary_pos_emb
134 | q_out = apply_rotary_pos_emb(
135 | q_out,
136 | q_pos_emb,
137 | )
138 | k_out = apply_rotary_pos_emb(
139 | k_out,
140 | k_pos_emb,
141 | )
142 | softmax_scale = q_out.shape[-1] ** (-0.5)
143 | if self.scale_attn_by_inverse_layer_idx:
144 | softmax_scale /= float(self.layer_number + 1)
145 | if not use_cache:
146 | attention_out = flash_attn_func(
147 | q_out, k_out, v_out, softmax_scale=softmax_scale, causal=True
148 | ).reshape(bs, seq_len, self.attn_dim)
149 | else:
150 | attention_out = flash_attn_with_kvcache(
151 | q_out,
152 | k_out,
153 | v_out,
154 | softmax_scale=softmax_scale,
155 | cache_seqlens=inference_params.sequence_len_offset,
156 | causal=True,
157 | ).reshape(bs, seq_len, self.attn_dim)
158 |
159 | attention_out, attention_bias = self.o_linear(attention_out)
160 | hidden_states = self._bias_dropout_add(
161 | attention_out, attention_bias, hidden_states
162 | )
163 | ln2_out = self.ln2(hidden_states)
164 | fc2_out, fc2_bias = self.mlp(ln2_out)
165 | hidden_states = self._bias_dropout_add(fc2_out, fc2_bias, hidden_states)
166 | return hidden_states
167 |
168 |
169 | class BasicGPT2(nn.Module):
170 |
171 | def __init__(
172 | self,
173 | ):
174 | super().__init__()
175 |
176 | @torch.no_grad()
177 | def _init_weights(self, module, init_config):
178 | """initialize the weight"""
179 | if init_config.weight_init == "fixed":
180 | initializer_range = init_config.initializer_range
181 | if isinstance(module, (nn.Linear, CustomLinear)):
182 | module.weight.data.normal_(mean=0.0, std=initializer_range)
183 | if module.bias is not None:
184 | module.bias.data.zero_()
185 | elif isinstance(module, nn.Embedding):
186 | module.weight.data.normal_(mean=0.0, std=initializer_range)
187 | elif isinstance(module, (nn.LayerNorm, te.LayerNorm, te.RMSNorm)):
188 | if hasattr(module, "bias"):
189 | module.bias.data.zero_()
190 | module.weight.data.fill_(1.0)
191 | else:
192 | raise NotImplementedError
193 |
194 |
195 | class GPT2Model(BasicGPT2):
196 |
197 | def __init__(self, config):
198 | super().__init__()
199 | self.config = config
200 | self.embed_dim = config.hidden_dim
201 | self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
202 | if config.pos_emb.name == "absolute":
203 | self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
204 | elif config.pos_emb.name == "rope":
205 | self.rotary_pos_emb = RotaryEmbedding(
206 | kv_channels=config.attn_dim // config.num_q_heads,
207 | rotary_interleaved=config.pos_emb.rotary_interleaved,
208 | seq_len_interpolation_factor=config.pos_emb.seq_len_interpolation_factor,
209 | rotary_base=config.pos_emb.rotary_base,
210 | )
211 | else:
212 | raise NotImplementedError
213 | self.drop = nn.Dropout(config.embd_drop)
214 | self.layers = nn.ModuleList(
215 | [TransformerLayer(config, i) for i in range(config.num_layers)]
216 | )
217 | self.ln_f = layernorm_func[config.ln](self.embed_dim)
218 |
219 | def forward(
220 | self,
221 | input_ids: torch.LongTensor = None,
222 | inference_params: Optional[InferenceParams] = None,
223 | use_cache=False,
224 | ):
225 | bs, seq = input_ids.shape
226 | seq_start = (
227 | inference_params.sequence_len_offset
228 | if use_cache and inference_params is not None
229 | else 0
230 | )
231 | seq_end = seq_start + seq
232 |
233 | position_ids = (
234 | torch.arange(seq_start, seq_end, dtype=torch.long, device=input_ids.device)
235 | .unsqueeze(0)
236 | .view(-1, seq)
237 | )
238 | inputs_embeds = self.wte(input_ids)
239 | if self.config.pos_emb.name == "absolute":
240 | position_embeds = self.wpe(position_ids)
241 | hidden_states = inputs_embeds + position_embeds
242 | else:
243 | hidden_states = inputs_embeds
244 | hidden_states = self.drop(hidden_states)
245 | if self.config.pos_emb.name == "rope":
246 | rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
247 | inference_params, hidden_states
248 | )
249 | rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
250 | else:
251 | rotary_pos_emb = None
252 | for layer in self.layers:
253 | hidden_states = layer(
254 | hidden_states, inference_params, use_cache, rotary_pos_emb
255 | )
256 | hidden_states = self.ln_f(hidden_states)
257 | return hidden_states
258 |
259 |
260 | class GPT2LMHeadModel(BasicGPT2):
261 |
262 | def __init__(self, config):
263 | super().__init__()
264 | self.config = config
265 | self.model = GPT2Model(config)
266 | self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
267 | # init weight
268 | self.apply(
269 | lambda module: self._init_weights(module=module, init_config=config.init)
270 | )
271 | # tie weight embedding
272 | if config.tie_word_embeddings:
273 | self.lm_head.weight = self.model.wte.weight
274 |
275 | @staticmethod
276 | def get_ckpt_name(config_model):
277 | return (
278 | "h"
279 | + f"{config_model.hidden_dim}"
280 | + "a"
281 | + f"{config_model.attn_dim}"
282 | + "f"
283 | + f"{config_model.ffn_dim}"
284 | + "nkv"
285 | + f"{config_model.num_kv_heads}"
286 | + f"{config_model.act}"
287 | + f"{config_model.pos_emb.name}"
288 | + f"{config_model.ln}"
289 | )
290 |
291 | def get_flops(self, bs, seq_len):
292 | attn_qo = 2 * bs * seq_len * self.config.attn_dim * self.config.hidden_dim
293 | attn_kv = (
294 | 2
295 | * bs
296 | * seq_len
297 | * (self.config.attn_dim // self.config.num_q_heads)
298 | * self.config.num_kv_heads
299 | * self.config.hidden_dim
300 | )
301 | sdp = 2 * bs * seq_len * seq_len * self.config.attn_dim
302 | return (
303 | 2 * self.config.num_layers * (attn_qo + attn_kv + sdp)
304 | + self.get_flops_mlp(bs, seq_len)
305 | + 2 * bs * seq_len * self.config.vocab_size * self.config.hidden_dim
306 | )
307 |
308 | def get_params(
309 | self,
310 | ):
311 | attn_qo = 2 * self.config.attn_dim * self.config.hidden_dim
312 | attn_kv = (
313 | 2
314 | * (self.config.attn_dim // self.config.num_q_heads)
315 | * self.config.num_kv_heads
316 | * self.config.hidden_dim
317 | )
318 | return (
319 | self.config.num_layers * (attn_qo + attn_kv)
320 | + self.get_params_mlp()
321 | + self.config.vocab_size * self.config.hidden_dim
322 | )
323 |
324 | def get_params_woembedding(
325 | self,
326 | ):
327 | attn_qo = 2 * self.config.attn_dim * self.config.hidden_dim
328 | attn_kv = (
329 | 2
330 | * (self.config.attn_dim // self.config.num_q_heads)
331 | * self.config.num_kv_heads
332 | * self.config.hidden_dim
333 | )
334 | return self.config.num_layers * (attn_qo + attn_kv) + self.get_params_mlp()
335 |
336 | def get_flops_mlp(self, bs, seq):
337 | # since these are all linear layers, the FLOPs scale directly with the parameter count
338 | mlp = 0
339 | for layer in self.model.layers:
340 | for para in layer.mlp.parameters():
341 | if len(para.shape) != 1:
342 | mlp += para.numel()
343 | return 2 * mlp * bs * seq
344 |
345 | def get_params_mlp(
346 | self,
347 | ):
348 | mlp = 0
349 | for layer in self.model.layers:
350 | for para in layer.mlp.parameters():
351 | if len(para.shape) != 1:
352 | mlp += para.numel()
353 | return mlp
354 |
355 | def forward(
356 | self,
357 | input_ids: torch.LongTensor = None,
358 | labels: Optional[torch.LongTensor] = None,
359 | inference_params: Optional[InferenceParams] = None,
360 | use_cache=False,
361 | ):
362 | out = self.model(input_ids, inference_params, use_cache)
363 | lm_logits = self.lm_head(out)
364 | loss = None
365 | if labels is not None:
366 | loss = F.cross_entropy(
367 | lm_logits.view(-1, lm_logits.size(-1)).contiguous(),
368 | labels.view(-1).contiguous(),
369 | )
370 | if loss is not None:
371 | return loss
372 | else:
373 | return lm_logits
374 |
375 | def prepare_inference_params(
376 | self, batch_size, mx_seq, torch_dtype=torch.bfloat16, device="cuda"
377 | ):
378 | # mx_seq is the sum of the prefill length and the generation length
379 | inference_params = InferenceParams(batch_size, mx_seq)
380 | inf_max_seq_len = inference_params.max_sequence_length
381 | inf_max_batch_size = inference_params.max_batch_size
382 | for i in range(self.config.num_layers):
383 | inference_key_memory = torch.empty(
384 | inf_max_batch_size,
385 | inf_max_seq_len,
386 | self.config.num_kv_heads,
387 | (self.config.attn_dim // self.config.num_q_heads),
388 | dtype=torch_dtype,
389 | device=device,
390 | )
391 | inference_value_memory = torch.empty(
392 | inf_max_batch_size,
393 | inf_max_seq_len,
394 | self.config.num_kv_heads,
395 | (self.config.attn_dim // self.config.num_q_heads),
396 | dtype=torch_dtype,
397 | device=device,
398 | )
399 | inference_params.key_value_memory_dict[i] = (
400 | inference_key_memory,
401 | inference_value_memory,
402 | )
403 | return inference_params
404 |
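
For reference, `get_flops` above counts two FLOPs per multiply-accumulate over every matmul in a forward pass,

$$\text{FLOPs}_{\text{fwd}} = 2L\left(2BSd_{a}d + 2BS\frac{d_{a}}{n_{q}}n_{kv}d + 2BS^{2}d_{a}\right) + 2BS\,P_{\text{FFN}} + 2BSVd,$$

with $B$ the batch size, $S$ the sequence length, $d$ the hidden size, $d_{a}$ the attention width, $n_{q}$/$n_{kv}$ the query/KV head counts, $P_{\text{FFN}}$ the FFN parameter count, $V$ the vocabulary size, and $L$ the number of layers; the trainer later multiplies this by 3 to also account for the backward pass.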
--------------------------------------------------------------------------------
/src/modules/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .low_rank import low_rank_custom
2 | from .block_dense import block_dense_custom, block_dense_bmm
3 | from .block_shuffle import (
4 | block_shuffle_custom,
5 | block_shuffle_bmm,
6 | block_shuffle_einsum,
7 | )
8 |
9 | from .common.fused_gelu import bias_gelu_impl
10 | from .common.fused_swiglu import bias_swiglu_impl
11 | from .common.fused_bias_dropout_add import bias_dropout_add_impl
12 | from .common.rotary_embeddings import RotaryEmbedding, apply_rotary_pos_emb
13 |
--------------------------------------------------------------------------------
/src/modules/op/block_dense.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def block_dense_bmm(input, blkdiag, linear):
6 | batch_shape, h = input.shape[:-1], input.shape[-1]
7 | batch_dim = np.prod(batch_shape)
8 | k, q, p = blkdiag.shape
9 | l, r = linear.shape
10 | assert k * p == h
11 | assert r == k * q
12 | input = input.reshape(batch_dim, k, p).transpose(0, 1)
13 | out1 = torch.bmm(input, blkdiag.transpose(-1, -2))
14 | out1 = out1.transpose(0, 1).reshape(batch_dim, r)
15 | out2 = torch.mm(out1, linear.transpose(-1, -2)).reshape(*batch_shape, l)
16 | return out2
17 |
18 |
19 | class BlockDenseCustom(torch.autograd.Function):
20 | """This is a faster implementation, with careful memory copies for the fastest
21 | bmm performance.
22 | The backward pass is also written manually with careful memory copies.
23 | Arguments:
24 | x: (batch, n)
25 | w1_bfly: (k, q, p), where k * p = n (block-diagonal factor)
26 | linear: (l, r), where r = k * q (dense low-rank factor)
27 | Outputs:
28 | out: (batch, l)
29 | """
30 |
31 | @staticmethod
32 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.bfloat16)
33 | def forward(ctx, x, w1_bfly, linear):
34 | # due to a torch.bmm issue with a specific out dtype, the weight dtype has to be changed by hand here;
35 | # note that this only changes the weight dtype within this scope
36 | batch_shape, n = x.shape[:-1], x.shape[-1]
37 | batch_dim = np.prod(batch_shape)
38 | k, q, p = w1_bfly.shape
39 | l, r = linear.shape
40 | assert k * p == n
41 | assert r == k * q
42 | x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1)
43 | out1 = torch.empty(batch_dim, k, q, device=x.device, dtype=x.dtype).transpose(
44 | 0, 1
45 | )
46 | out1 = torch.bmm(x_reshaped, w1_bfly.transpose(-1, -2), out=out1)
47 | out1 = out1.transpose(0, 1).reshape(batch_dim, r)
48 | out2 = torch.mm(out1, linear.transpose(-1, -2)).reshape(*batch_shape, l)
49 | ctx.save_for_backward(x, w1_bfly, linear, out1)
50 | return out2
51 |
52 | @staticmethod
53 | @torch.cuda.amp.custom_bwd
54 | def backward(ctx, dout):
55 | x, w1_bfly, linear, out1 = ctx.saved_tensors
56 | batch_shape, n = x.shape[:-1], x.shape[-1]
57 | batch_dim = np.prod(batch_shape)
58 | k, q, p = w1_bfly.shape
59 | l, r = linear.shape
60 |
61 | dx, dw1_bfly, dw2_linear = None, None, None
62 | dout_reshaped = dout.reshape(batch_dim, l)
63 | if ctx.needs_input_grad[2]:
64 | dw2_linear = torch.mm(dout_reshaped.transpose(-1, -2), out1)
65 | if ctx.needs_input_grad[1] or ctx.needs_input_grad[0]:
66 | dout1 = (
67 | torch.mm(dout_reshaped, linear).reshape(batch_dim, k, q).transpose(0, 1)
68 | )
69 | if ctx.needs_input_grad[0]:
70 | dx = torch.empty(
71 | batch_dim, k, p, device=x.device, dtype=x.dtype
72 | ).transpose(0, 1)
73 | dx = (
74 | torch.bmm(dout1, w1_bfly, out=dx)
75 | .transpose(0, 1)
76 | .reshape(*batch_shape, n)
77 | )
78 | if ctx.needs_input_grad[1]:
79 | x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1)
80 | dw1_bfly = torch.bmm(dout1.transpose(-1, -2), x_reshaped)
81 | return dx, dw1_bfly, dw2_linear
82 |
83 |
84 | block_dense_custom = BlockDenseCustom.apply
85 |
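
As a sanity check, `block_dense_custom` should match the dense matrix it factorizes, i.e. a block-diagonal matmul followed by a dense low-rank projection (this is exactly how `BlockDense.post_init` builds its guide weight). A minimal sketch, assuming the repository root is on `PYTHONPATH` and using purely illustrative shapes:

```python
import torch
from src.modules.op import block_dense_custom

x = torch.randn(2, 10, 64)       # (batch, seq, n)
blkdiag = torch.randn(4, 8, 16)  # (k, q, p) with k * p = n = 64
linear = torch.randn(128, 32)    # (l, r) with r = k * q = 32

out = block_dense_custom(x, blkdiag, linear)                      # (2, 10, 128)
dense = linear @ torch.block_diag(*torch.unbind(blkdiag, dim=0))  # (l, n), same product as the guide init
ref = x @ dense.T
assert torch.allclose(out, ref, rtol=1e-3, atol=1e-3)
```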
--------------------------------------------------------------------------------
/src/modules/op/block_shuffle.py:
--------------------------------------------------------------------------------
1 | # pasted from the Monarch implementation by Tri Dao
2 | import torch
3 | import numpy as np
4 | from einops import rearrange
5 |
6 |
7 | def block_shuffle_einsum(input, blkdiag1, blkdiag2):
8 | batch_shape, h = input.shape[:-1], input.shape[-1]
9 | batch_dim = np.prod(batch_shape)
10 | k, q, p = blkdiag1.shape
11 | l, s, r = blkdiag2.shape
12 | assert k * p == h
13 | assert l * r == k * q
14 | input = input.reshape(batch_dim, k, p)
15 | out1 = torch.einsum("kqp,bkp->bkq", blkdiag1, input)
16 | out1 = rearrange(rearrange(out1, "b k q -> b (k q)"), "b (r l) -> b l r", l=l)
17 | return torch.einsum("lsr,blr->bsl", blkdiag2, out1).reshape(*batch_shape, s * l)
18 |
19 |
20 | def block_shuffle_bmm(input, blkdiag1, blkdiag2):
21 | batch_shape, h = input.shape[:-1], input.shape[-1]
22 | batch_dim = np.prod(batch_shape)
23 | k, q, p = blkdiag1.shape
24 | l, s, r = blkdiag2.shape
25 | assert k * p == h
26 | assert l * r == k * q
27 | input = input.reshape(batch_dim, k, p).transpose(0, 1)
28 | out1 = torch.bmm(input, blkdiag1.transpose(-1, -2))
29 | out1 = (
30 | out1.transpose(0, 1)
31 | .reshape(batch_dim, r, l)
32 | .transpose(-1, -2)
33 | .contiguous()
34 | .transpose(0, 1)
35 | )
36 | out2 = torch.bmm(out1, blkdiag2.transpose(-1, -2))
37 | out2 = out2.permute(1, 2, 0).reshape(*batch_shape, s * l)
38 | return out2
39 |
40 |
41 | class BlockShuffleCustom(torch.autograd.Function):
42 | # Pasted from the Monarch repo
43 | """This is a faster implementation, with careful memory copies for the fastest
44 | bmm performance.
45 | The backward pass is also written manually with careful memory copies.
46 | Arguments:
47 | x: (batch, n)
48 | w1_bfly: (k, q, p), where k = n / p
49 | w2_bfly: (l, s, r), where l = k * q / r = n * q / (p * r)
50 | Outputs:
51 | out: (batch, m), where m = l * s = n * s * q / (p * r)
52 | """
53 |
54 | @staticmethod
55 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.bfloat16)
56 | def forward(ctx, x, w1_bfly, w2_bfly):
57 | batch_shape, n = x.shape[:-1], x.shape[-1]
58 | batch_dim = np.prod(batch_shape)
59 | k, q, p = w1_bfly.shape
60 | l, s, r = w2_bfly.shape
61 | assert k * p == n
62 | assert l * r == k * q
63 | x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1)
64 | out1 = torch.empty(batch_dim, k, q, device=x.device, dtype=x.dtype).transpose(
65 | 0, 1
66 | )
67 | out1 = torch.bmm(x_reshaped, w1_bfly.transpose(-1, -2), out=out1)
68 | out1 = (
69 | out1.transpose(0, 1)
70 | .reshape(batch_dim, r, l)
71 | .transpose(-1, -2)
72 | .contiguous()
73 | .transpose(0, 1)
74 | )
75 | out2 = torch.empty(batch_dim, l, s, device=x.device, dtype=x.dtype).transpose(
76 | 0, 1
77 | )
78 | out2 = torch.bmm(out1, w2_bfly.transpose(-1, -2), out=out2)
79 | out2 = out2.permute(1, 2, 0).reshape(*batch_shape, s * l)
80 | ctx.save_for_backward(x, w1_bfly, w2_bfly, out1)
81 | return out2
82 |
83 | @staticmethod
84 | @torch.cuda.amp.custom_bwd
85 | def backward(ctx, dout):
86 | x, w1_bfly, w2_bfly, out1 = ctx.saved_tensors
87 | batch_shape, n = x.shape[:-1], x.shape[-1]
88 | batch_dim = np.prod(batch_shape)
89 | k, q, p = w1_bfly.shape
90 | l, s, r = w2_bfly.shape
91 | # assert k * p == n
92 | # assert l * r == k * q
93 | dx, dw1_bfly, dw2_bfly = None, None, None
94 | # dout_reshaped = dout.reshape(batch_dim, sqrtn, sqrtn).permute(2, 1, 0).contiguous()
95 | dout_reshaped = dout.reshape(batch_dim, s, l).transpose(-1, -2).contiguous()
96 | dout_reshaped = dout_reshaped.transpose(0, 1)
97 | if ctx.needs_input_grad[2]:
98 | # dw2_bfly = torch.empty(l, s, r, device=w2_bfly.device, dtype=w2_bfly.dtype)
99 | # dw2_bfly = torch.bmm(dout_reshaped.transpose(-1, -2), out1, out=dw2_bfly)
100 | dw2_bfly = torch.bmm(dout_reshaped.transpose(-1, -2), out1)
101 | if ctx.needs_input_grad[1] or ctx.needs_input_grad[0]:
102 | dout1 = torch.empty(
103 | batch_dim, l, r, device=x.device, dtype=x.dtype
104 | ).transpose(0, 1)
105 | dout1 = torch.bmm(dout_reshaped, w2_bfly, out=dout1)
106 | dout1 = (
107 | dout1.transpose(0, 1)
108 | .transpose(-1, -2)
109 | .contiguous()
110 | .reshape(batch_dim, k, q)
111 | .transpose(0, 1)
112 | )
113 | # dout1 = dout1.permute(1, 2, 0).contiguous().transpose(0, 1)
114 | if ctx.needs_input_grad[0]:
115 | dx = torch.empty(batch_dim, k, p, device=x.device, dtype=x.dtype)
116 | dx = (
117 | torch.bmm(dout1, w1_bfly, out=dx.transpose(0, 1))
118 | .transpose(0, 1)
119 | .reshape(*batch_shape, n)
120 | )
121 | if ctx.needs_input_grad[1]:
122 | x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1)
123 | dw1_bfly = torch.bmm(dout1.transpose(-1, -2), x_reshaped)
124 | return dx, dw1_bfly, dw2_bfly
125 |
126 |
127 | block_shuffle_custom = BlockShuffleCustom.apply
128 |
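
The fast autograd path can likewise be checked against the plain `block_shuffle_bmm` reference defined above. A minimal sketch with illustrative Monarch-style shapes:

```python
import torch
from src.modules.op import block_shuffle_custom, block_shuffle_bmm

x = torch.randn(6, 64)       # (batch, n)
w1 = torch.randn(4, 16, 16)  # (k, q, p) with k * p = n = 64
w2 = torch.randn(8, 24, 8)   # (l, s, r) with l * r = k * q = 64

out_fast = block_shuffle_custom(x, w1, w2)  # (6, l * s) = (6, 192)
out_ref = block_shuffle_bmm(x, w1, w2)      # same result via plain bmm
assert torch.allclose(out_fast, out_ref, rtol=1e-3, atol=1e-3)
```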
--------------------------------------------------------------------------------
/src/modules/op/common/fused_bias_dropout_add.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2 | import os
3 | from typing import Optional, Tuple
4 | import torch
5 |
6 |
7 | jit_fuser = torch.jit.script
8 | if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
9 | jit_fuser = torch.compile
10 |
11 |
12 | def _bias_dropout_add_func(x_with_bias, residual, prob, training):
13 | x, bias = x_with_bias # unpack
14 |
15 | # If we want to train mixed precision, then the output of this function
16 | # should be half precision. However, in AMP O1, the input (residual) is
17 | # in fp32, and it will up-cast the result to fp32, causing pipeline parallel
18 | # GPU communication to hang. Therefore, we need to cast residual to the same
19 | # dtype as x.
20 | residual = residual if residual.dtype == x.dtype else residual.to(x.dtype)
21 |
22 | if bias is not None:
23 | x = x + bias
24 | out = torch.nn.functional.dropout(x, p=prob, training=training)
25 | out = residual + out
26 | return out
27 | else:
28 | out = torch.nn.functional.dropout(x, p=prob, training=training)
29 | out = residual + out
30 | return out
31 |
32 |
33 | @jit_fuser
34 | def bias_dropout_add_fused_train(
35 | x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]],
36 | residual: torch.Tensor,
37 | prob: float,
38 | ) -> torch.Tensor:
39 | return _bias_dropout_add_func(x_with_bias, residual, prob, True)
40 |
41 |
42 | @jit_fuser
43 | def bias_dropout_add_fused_inference(
44 | x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]],
45 | residual: torch.Tensor,
46 | prob: float,
47 | ) -> torch.Tensor:
48 | return _bias_dropout_add_func(x_with_bias, residual, prob, False)
49 |
50 |
51 | def bias_dropout_add_impl(training):
52 | if training:
53 | return bias_dropout_add_fused_train
54 | else:
55 | return bias_dropout_add_fused_inference
56 |
--------------------------------------------------------------------------------
/src/modules/op/common/fused_gelu.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | # pasted from Megatron-LM
3 | import os
4 | import torch
5 | from typing import Callable, Optional, Tuple
6 |
7 |
8 | jit_fuser = torch.jit.script
9 | if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
10 | jit_fuser = torch.compile
11 |
12 |
13 | @jit_fuser
14 | def bias_gelu_fused_(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
15 | """Bias-GeLU fused"""
16 | x = inp + bias
17 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
18 |
19 |
20 | @jit_fuser
21 | def gelu_fused_(inp: torch.Tensor) -> torch.Tensor:
22 | """
23 | GeLU fused; a copy of bias_gelu_fused_ without the bias, because JIT fusion does not allow conditioning.
24 | """
25 | x = inp
26 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
27 |
28 |
29 | @jit_fuser
30 | def dgelu_bgrad_fused_(
31 | grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
32 | ) -> Tuple[torch.Tensor, torch.Tensor]:
33 | """Bgrad-Dgelu fused"""
34 | x = inp + bias
35 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
36 | # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
37 | ff = 0.5 * x * (
38 | (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
39 | ) + 0.5 * (1 + tanh_out)
40 | dgelu = ff * grad_output
41 | bgrad = dgelu.sum(dim=0)
42 | return dgelu, bgrad
43 |
44 |
45 | @jit_fuser
46 | def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
47 | """
48 | Dgelu fused; a copy of dgelu_bgrad_fused_ without the bias gradient, because JIT fusion does not allow conditioning.
49 | """
50 | x = inp
51 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
52 | # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
53 | ff = 0.5 * x * (
54 | (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
55 | ) + 0.5 * (1 + tanh_out)
56 | dgelu = ff * grad_output
57 | return dgelu
58 |
59 |
60 | class BiasGeLUFunction(torch.autograd.Function):
61 | @staticmethod
62 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
63 | def forward(ctx, input, bias):
64 | ctx.save_for_backward(input, bias)
65 | return bias_gelu_fused_(input, bias)
66 |
67 | @staticmethod
68 | @torch.cuda.amp.custom_bwd
69 | def backward(ctx, grad_output):
70 | input, bias = ctx.saved_tensors
71 | return dgelu_bgrad_fused_(grad_output, input, bias)
72 |
73 |
74 | class GeLUFunction(torch.autograd.Function):
75 | @staticmethod
76 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
77 | def forward(ctx, input):
78 | ctx.save_for_backward(input)
79 | return gelu_fused_(input)
80 |
81 | @staticmethod
82 | @torch.cuda.amp.custom_bwd
83 | def backward(ctx, grad_output):
84 | input = ctx.saved_tensors[0]
85 | return dgelu_fused_(grad_output, input)
86 |
87 |
88 | def bias_gelu_impl(input, bias):
89 | ori_shape = input.shape
90 | assert len(ori_shape) in [2, 3]
91 | input = input.view(-1, ori_shape[-1])
92 | if bias is not None:
93 | output = BiasGeLUFunction.apply(input, bias)
94 | else:
95 | output = GeLUFunction.apply(input)
96 | return (
97 | output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
98 | )
99 |
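
The magic constants above come from the standard tanh approximation of GeLU: $0.79788456 \approx \sqrt{2/\pi}$ and $0.1070322243 \approx 3 \cdot 0.044715 \cdot \sqrt{2/\pi}$, so the fused kernels compute

$$\mathrm{GeLU}(x) \approx 0.5\,x\left(1 + \tanh\!\big(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\big)\right)$$

and its derivative.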
--------------------------------------------------------------------------------
/src/modules/op/common/fused_swiglu.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | # pasted from Megatron-LM
3 | import torch
4 | import torch.nn.functional as F
5 | import os
6 |
7 | jit_fuser = torch.jit.script
8 | if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
9 | jit_fuser = torch.compile
10 |
11 |
12 | @jit_fuser
13 | def swiglu(y):
14 | y_1, y_2 = torch.chunk(y, 2, -1)
15 | return F.silu(y_1) * y_2
16 |
17 |
18 | @jit_fuser
19 | def bias_swiglu(y, bias):
20 | y = y + bias
21 | return swiglu(y)
22 |
23 |
24 | @jit_fuser
25 | def swiglu_back(g, y):
26 | y_1, y_2 = torch.chunk(y, 2, -1)
27 | return torch.cat(
28 | (
29 | g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2,
30 | g * F.silu(y_1),
31 | ),
32 | -1,
33 | )
34 |
35 |
36 | @jit_fuser
37 | def bias_swiglu_back(g, y, bias):
38 | y = y + bias
39 | dy = swiglu_back(g, y)
40 | bgrad = dy.sum(dim=0)
41 | return dy, bgrad
42 |
43 |
44 | class BiasSwiGLUFunction(torch.autograd.Function):
45 | @staticmethod
46 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
47 | def forward(ctx, input, bias):
48 | ctx.save_for_backward(input, bias)
49 | return bias_swiglu(input, bias)
50 |
51 | @staticmethod
52 | @torch.cuda.amp.custom_bwd
53 | def backward(ctx, grad_output):
54 | input, bias = ctx.saved_tensors
55 | return bias_swiglu_back(grad_output, input, bias)
56 |
57 |
58 | class SwiGLUFunction(torch.autograd.Function):
59 | @staticmethod
60 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
61 | def forward(ctx, input):
62 | ctx.save_for_backward(input)
63 | return swiglu(input)
64 |
65 | @staticmethod
66 | @torch.cuda.amp.custom_bwd
67 | def backward(ctx, grad_output):
68 | input = ctx.saved_tensors[0]
69 | return swiglu_back(grad_output, input)
70 |
71 |
72 | def bias_swiglu_impl(input, bias):
73 | ori_shape = input.shape
74 | assert len(ori_shape) in [2, 3]
75 | input = input.view(-1, ori_shape[-1])
76 | if bias is not None:
77 | output = BiasSwiGLUFunction.apply(input, bias)
78 | else:
79 | output = SwiGLUFunction.apply(input)
80 |
81 | return (
82 | output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
83 | )
84 |
--------------------------------------------------------------------------------
/src/modules/op/common/rotary_embeddings.py:
--------------------------------------------------------------------------------
1 | from apex.transformer.functional import fused_apply_rotary_pos_emb
2 | import torch.nn as nn
3 | import torch
4 | from torch import Tensor
5 |
6 |
7 | class RotaryEmbedding(nn.Module):
8 | """Rotary Embedding for language model.
9 |
10 | Args:
11 | kv_channels (int): Projection weights dimension in multi-head attention. Obtained from transformer config
12 | seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None
13 | rotary_base (int, optional): Base period for rotary position embeddings. Defaults to 10000.
14 | """
15 |
16 | def __init__(
17 | self,
18 | kv_channels: int,
19 | rotary_interleaved: bool = False,
20 | seq_len_interpolation_factor: float = None,
21 | rotary_base: int = 10000,
22 | ) -> None:
23 | super().__init__()
24 |
25 | dim = kv_channels
26 | self.rotary_interleaved = rotary_interleaved
27 |
28 | self.seq_len_interpolation_factor = seq_len_interpolation_factor
29 | self.inv_freq = 1.0 / (
30 | rotary_base
31 | ** (
32 | torch.arange(
33 | 0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()
34 | )
35 | / dim
36 | )
37 | )
38 |
39 | def forward(self, max_seq_len: int, offset: int = 0) -> Tensor:
40 | """Forward pass of RoPE embedding.
41 |
42 | Args:
43 | max_seq_len (int): Maximum size of sequence
44 | offset (int, optional): _description_. Defaults to 0.
45 |
46 | Returns:
47 | Tensor: Embeddings after applying RoPE.
48 | """
49 | seq = (
50 | torch.arange(
51 | max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype
52 | )
53 | + offset
54 | )
55 |
56 | if self.seq_len_interpolation_factor is not None:
57 | seq *= 1 / self.seq_len_interpolation_factor
58 |
59 | freqs = torch.outer(seq, self.inv_freq)
60 | # first part even vector components, second part odd vector components,
61 | # 2 * dim in dimension size
62 | if not self.rotary_interleaved:
63 | emb = torch.cat((freqs, freqs), dim=-1)
64 | else:
65 | emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
66 | freqs.shape[0], -1
67 | )
68 | # emb [seq_length, .., dim]
69 | emb = emb[:, None, None, :]
70 | return emb
71 |
72 | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
73 | state_dict.pop(f"{prefix}inv_freq", None)
74 | return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
75 |
76 | def get_rotary_seq_len(
77 | self,
78 | inference_params,
79 | transformer_input,
80 | ) -> float:
81 |
82 | if inference_params is not None:
83 | return inference_params.max_sequence_length
84 | return transformer_input.shape[1]
85 |
86 |
87 | def apply_rotary_pos_emb(
88 | t: Tensor,
89 | freqs: Tensor,
90 | ):
91 | # bshd -> sbhd
92 | return fused_apply_rotary_pos_emb(
93 | t.transpose(0, 1), freqs, transpose_output_memory=True
94 | ).transpose(0, 1)
95 |
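
The frequencies built in `__init__` are the standard RoPE spectrum $\theta_i = \text{rotary\_base}^{-2i/d}$ for $i = 0, \dots, d/2 - 1$; the forward pass takes their outer product with the (optionally interpolated) position indices before `fused_apply_rotary_pos_emb` rotates the query/key channels.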
--------------------------------------------------------------------------------
/src/modules/op/low_rank.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def low_rank_custom(input, linear1, linear2):
6 | batch_shape, h = input.shape[:-1], input.shape[-1]
7 | batch_dim = np.prod(batch_shape)
8 | input = input.reshape(batch_dim, h)
9 | out2 = torch.mm(
10 | torch.mm(input, linear1.transpose(-1, -2)), linear2.transpose(-1, -2)
11 | ).reshape(*batch_shape, -1)
12 | return out2
13 |
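
`low_rank_custom` is just the two-factor product $x W_1^{\top} W_2^{\top}$, so it agrees with the dense weight $W = W_2 W_1$ that it approximates (the same product used to initialize the guide layer). A minimal sketch, assuming the repository root is importable and with illustrative shapes:

```python
import torch
from src.modules.op import low_rank_custom

x = torch.randn(2, 10, 512)   # (batch, seq, in_features)
lr1 = torch.randn(128, 512)   # (rank, in_features)
lr2 = torch.randn(2048, 128)  # (out_features, rank)

out = low_rank_custom(x, lr1, lr2)  # (2, 10, 2048)
ref = x @ (lr2 @ lr1).T             # dense equivalent of the two factors
assert torch.allclose(out, ref, rtol=1e-3, atol=1e-2)
```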
--------------------------------------------------------------------------------
/src/optimization/__init__.py:
--------------------------------------------------------------------------------
1 | from torch import optim
2 | from .scheduler import *
3 |
4 |
5 | name_to_scheduler = {
6 | "multisteplr": lambda optimizer, kwargs: _MultiStepLR(optimizer, **kwargs),
7 | "cosineannealinglr": lambda optimizer, kwargs: _CosineAnnealingLR(
8 | optimizer, **kwargs
9 | ),
10 | }
11 |
12 | name_to_optimizer = {
13 | "adam": lambda params, kwargs: optim.Adam(params, **kwargs),
14 | "sgd": lambda params, kwargs: optim.SGD(params, **kwargs),
15 | "adamw": lambda params, kwargs: optim.AdamW(params, **kwargs),
16 | }
17 |
18 |
19 | def get_lr_scheduler(config_optimization, optimizer):
20 | name = config_optimization.lr_scheduler.name.lower()
21 | return name_to_scheduler[name](optimizer, config_optimization.lr_scheduler.kwargs)
22 |
23 |
24 | def get_optimizer(config_optimization, params):
25 | name = config_optimization.optimizer.name.lower()
26 | return name_to_optimizer[name](params, config_optimization.optimizer.kwargs)
27 |
--------------------------------------------------------------------------------
/src/optimization/scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import (
3 | MultiStepLR,
4 | CosineAnnealingLR,
5 | )
6 |
7 |
8 | __all__ = [
9 | "_MultiStepLR",
10 | "_CosineAnnealingLR",
11 | ]
12 |
13 |
14 | class _MultiStepLR(MultiStepLR):
15 |
16 | def __init__(self, optimizer, **kwargs):
17 | kwargs["milestones"] = [
18 | int(e * kwargs.pop("T_max")) for e in kwargs["milestones"]
19 | ]
20 | super(_MultiStepLR, self).__init__(optimizer, **kwargs)
21 |
22 |
23 | class _CosineAnnealingLR(CosineAnnealingLR):
24 | def __init__(self, optimizer, **kwargs):
25 | self.warmup_iter = 0
26 | if "warmup_iter" in kwargs:
27 | self.warmup_iter = int(kwargs.pop("warmup_iter") * kwargs["T_max"])
28 | super(_CosineAnnealingLR, self).__init__(optimizer, **kwargs)
29 |
30 | def get_lr(self):
31 | if self.last_epoch < self.warmup_iter:
32 | return [
33 | (self.last_epoch + 1) / self.warmup_iter * base_lr
34 | for base_lr in self.base_lrs
35 | ]
36 |
37 | return [
38 | self.eta_min
39 | + (base_lr - self.eta_min)
40 | * (
41 | 1
42 | + math.cos(
43 | math.pi
44 | * (self.last_epoch - self.warmup_iter)
45 | / (self.T_max - self.warmup_iter)
46 | )
47 | )
48 | / 2
49 | for base_lr in self.base_lrs
50 | ]
51 |
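
`_CosineAnnealingLR` treats `warmup_iter` as a fraction of `T_max` (as in the config above, `warmup_iter: 0.1`), ramps the learning rate linearly over that many steps, then follows the usual cosine decay down to `eta_min`. A minimal sketch with a hypothetical toy parameter, just to visualize the schedule:

```python
import torch
from src.optimization.scheduler import _CosineAnnealingLR

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=3e-4)
sched = _CosineAnnealingLR(opt, T_max=1000, warmup_iter=0.1, eta_min=3e-5)

lrs = []
for _ in range(1000):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])
# lrs ramps linearly for the first ~100 steps, then decays with a cosine to eta_min
```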
--------------------------------------------------------------------------------
/src/optimization/trainer.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import glob
4 | import wandb
5 | import shutil
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from easydict import EasyDict
10 | from omegaconf import open_dict
11 | import torch.distributed as dist
12 | from src.modules import get_ckpt_name
13 |
14 |
15 | class WandbLog:
16 | def __init__(self, config, metric, x_axis="epoch"):
17 | self.config = config
18 | for k, v in metric.items():
19 | if k == x_axis:
20 | wandb.define_metric(x_axis)
21 | else:
22 | wandb.define_metric(k, step_metric=x_axis)
23 |
24 | def record(self, item):
25 | wandb.log(item)
26 |
27 |
28 | class TrainableModel:
29 |
30 | def __init__(self, config):
31 | self.config = config
32 | self.epoch = -1
33 | self.step = -1
34 | self.max_epoch = self.config.optimization.max_epoch
35 | self.max_step = None # define in specific Trainer
36 |
37 | # gpu setting
38 | self.gpu_id = int(os.getenv("RANK", -1))
39 | self.device = (
40 | torch.device("cuda", self.gpu_id)
41 | if self.gpu_id != -1
42 | else torch.device("cuda")
43 | )
44 | self.ngpus = dist.get_world_size() if self.gpu_id != -1 else 1
45 | print("The device is {} out of {}".format(self.device, self.ngpus))
46 |
47 | self.global_batch_size = getattr(
48 | self.config.optimization,
49 | "global_batch_size",
50 | self.config.data.train.train_batch,
51 | )
52 | assert (
53 | self.global_batch_size % (self.ngpus * self.config.data.train.train_batch)
54 | == 0
55 | )
56 | self.gradient_accumulation_steps = self.global_batch_size // (
57 | self.ngpus * self.config.data.train.train_batch
58 | )
59 |
60 | self.log_interval = getattr(self.config.optimization, "log_interval", False)
61 | self.check_gradient_norm = getattr(
62 | self.config.optimization, "check_gradient_norm", False
63 | )
64 | self.check_weight_norm = getattr(
65 | self.config.optimization, "check_weight_norm", False
66 | )
67 | self.gradient_clipping = getattr(
68 | self.config.optimization, "gradient_clipping", False
69 | )
70 | self.special_training = (
71 | self.config.optimization.training.name == "self_guided_training"
72 | )
73 | # save
74 | self.is_save_checkpoint = getattr(
75 | self.config.optimization, "save_checkpoint", False
76 | )
77 | self.is_load_checkpoint = getattr(
78 | self.config.optimization, "load_checkpoint", False
79 | )
80 | self.load_save_mode = getattr(
81 | self.config.optimization, "load_save_mode", "epoch"
82 | )
83 |
84 | def prepare_load_save(
85 | self,
86 | ):
87 | if self.is_save_checkpoint or self.is_load_checkpoint:
88 | long_name = get_ckpt_name(self.config) + "-" + str(self.special_training)
89 | if self.special_training:
90 | long_name += (
91 | "-"
92 | + self.config.optimization.training.kwargs.mode
93 | + "-"
94 | + str(self.config.optimization.training.kwargs.reduce_flop)
95 | )
96 | self.save_dir = os.path.join(self.config.optimization.save_dir, long_name)
97 | self.save_dir = os.path.join(
98 | self.save_dir,
99 | str(self.config.optimization.optimizer.kwargs.lr).replace(".", "x"),
100 | )
101 | if self.load_save_mode == "epoch":
102 | self.save_interval = self.max_epoch // 10
103 | elif self.load_save_mode == "step":
104 | self.save_interval = self.max_step // 10
105 | else:
106 | raise NotImplementedError
107 | print(
108 |                 "Checkpoints will be saved/loaded under {} every {} {}(s)".format(
109 | self.save_dir, self.save_interval, self.load_save_mode
110 | )
111 | )
112 | if not self.is_load_checkpoint:
113 |             shutil.rmtree(self.save_dir, ignore_errors=True)  # fresh run: drop any stale checkpoints
114 | if not os.path.exists(self.save_dir):
115 | os.makedirs(self.save_dir)
116 |
117 | def set_gradient_clipping(
118 | self,
119 | ):
120 | if self.gradient_clipping is not False:
121 | torch.nn.utils.clip_grad_norm_(
122 | self.model.parameters(), self.gradient_clipping
123 | )
124 |
125 | def get_info(
126 | self,
127 | ):
128 | nparam = self.get_nparam()
129 | nflops = self.model.get_flops(
130 | self.global_batch_size,
131 | self.block_size,
132 |         )  # counts all matrix multiplications, including the final logit projection
133 | total_flops = nflops * self.max_step
134 | if self.special_training:
135 | guide_params = sum(
136 | [
137 | p.guide_linear.numel()
138 | for p in self.model.modules()
139 | if hasattr(p, "guide_linear")
140 | ]
141 | )
142 | # print("the number of guide parameters are {:.2f}".format(guide_params))
143 | guide_flops = (
144 | 2 * guide_params * self.global_batch_size * self.block_size
145 | ) # addition and multiplication
146 |             total_flops -= guide_flops * (self.max_step - self.guided_steps)  # remove guide-layer cost after the guided phase ends
147 | if self.config.optimization.training.kwargs.reduce_flop:
148 | total_flops -= 0.5 * guide_flops * self.guided_steps
149 |         print("Total parameters: {:.2f}M".format(nparam / 10**6))
150 | print(
151 | "FLOPs information: flops per forward step {:.2f}T, total flops {:.2f}T".format(
152 | nflops / 10**12,
153 |                 total_flops * 3 / 10**12,  # forward + backward (backward counted as ~2x forward)
154 | )
155 | )
156 | nparam_mlp = self.model.get_params_mlp()
157 | nflops_mlp = self.model.get_flops_mlp(
158 | self.global_batch_size,
159 | self.block_size,
160 | )
161 | print(
162 | "MLP information: params {:.2f}M, flops per step {:.2f}T".format(
163 | nparam_mlp / 10**6,
164 | nflops_mlp / 10**12,
165 | )
166 | )
167 |
168 | print(self.model)
169 |
170 | def get_nparam(
171 | self,
172 | ):
173 | self.nparam = sum(param.numel() for param in self.model.parameters())
174 | return self.nparam
175 |
176 | def set_seed(self, seed):
177 | random.seed(seed)
178 | os.environ["PYTHONHASHSEED"] = str(seed)
179 | np.random.seed(seed)
180 | torch.manual_seed(seed)
181 | torch.cuda.manual_seed(seed)
182 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
183 | torch.backends.cudnn.benchmark = False
184 | torch.backends.cudnn.deterministic = True
185 |
186 | def set_self_guided_training(
187 | self,
188 | ):
189 | self.repeat_steps = 0
190 | self.guided_steps = 0
191 | if self.special_training:
192 | self.guided_steps = int(
193 | self.max_step * self.config.optimization.training.kwargs.max_step_ratio
194 | )
195 | with open_dict(self.config.method.kwargs) as f:
196 | f.training.enabled = True
197 | f.training.scheduler = (
198 | self.config.optimization.training.kwargs.scheduler
199 | )
200 | f.training.max_step = self.guided_steps
201 | f.training.reduce_flop = (
202 | self.config.optimization.training.kwargs.reduce_flop
203 | )
204 | if self.config.optimization.training.kwargs.mode == "fixedflop":
205 | self.repeat_steps = self.guided_steps
206 | self.max_step += self.repeat_steps
207 | elif self.config.method.name != "linear":
208 | with open_dict(self.config.method.kwargs) as f:
209 | f.training.enabled = False
210 |
211 | def close_self_guided_training(
212 | self,
213 | ):
214 | from src.modules.layer.basiclinear import BasicLinear
215 |
216 | self.special_training = False
217 | for name, module in self.model.named_modules():
218 | if isinstance(module, BasicLinear):
219 | module.training_config.enabled = False
220 |
221 | def get_optimize_param(
222 | self,
223 | ):
224 | params = [{"params": self.model.parameters()}]
225 | return params
226 |
227 | def save_checkpoint(self, **resume_kwargs):
228 | # save checkpoint by epoch
229 | if not self.is_save_checkpoint or self.gpu_id not in [-1, 0]:
230 | return
231 | if self.load_save_mode == "epoch":
232 | cur = self.epoch
233 | cur_max = self.max_epoch
234 | elif self.load_save_mode == "step":
235 | cur = self.step
236 | cur_max = self.max_step
237 | if (cur + 1) % self.save_interval == 0 or cur + 1 == cur_max:
238 | ckpt_path = os.path.join(
239 | self.save_dir,
240 | f"{cur}.pth",
241 | )
242 | ckpt = {
243 | "model": (
244 | self.model.module.state_dict()
245 | if self.gpu_id == 0
246 | else self.model.state_dict()
247 | ),
248 | self.load_save_mode: cur,
249 | "config": self.config,
250 | "nparam": self.nparam,
251 | "optimizer": self.optimizer.state_dict(),
252 | "lr_scheduler": (
253 | self.lr_scheduler.state_dict()
254 | if getattr(self, "lr_scheduler", None)
255 | else None
256 | ),
257 | "resume_kwargs": resume_kwargs,
258 | }
259 | torch.save(ckpt, ckpt_path)
260 |
261 | def load_checkpoint(self):
262 | if not self.is_load_checkpoint:
263 | return {}
264 |
265 | def find_latest_checkpoint():
266 | checkpoint_files = glob.glob(
267 | os.path.join(
268 | self.save_dir,
269 |                     "*.pth",
270 | )
271 | )
272 | if not checkpoint_files:
273 | return None
274 |
275 | latest_checkpoint_file = max(checkpoint_files, key=os.path.getctime)
276 | return latest_checkpoint_file
277 |
278 | latest_checkpoint = find_latest_checkpoint()
279 | if latest_checkpoint is not None:
280 | print("load checkpoint from {}".format(latest_checkpoint))
281 | ckpt = torch.load(latest_checkpoint, map_location=self.device)
282 | self.model.load_state_dict(ckpt["model"])
283 | self.optimizer.load_state_dict(ckpt["optimizer"])
284 | if getattr(self, "lr_scheduler", None):
285 | self.lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
286 | if self.load_save_mode == "epoch":
287 | self.epoch = ckpt["epoch"]
288 | elif self.load_save_mode == "step":
289 | self.step = ckpt["step"]
290 | return ckpt["resume_kwargs"]
291 | return {}
292 |
--------------------------------------------------------------------------------
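
A note on the batch bookkeeping in `TrainableModel.__init__` above: the global batch is split across data-parallel ranks, and whatever does not fit in a single per-GPU micro-batch is recovered through gradient accumulation. A small worked example with made-up numbers (none of these values come from the repository's configs):

```python
# Global batch = ngpus * per-GPU micro-batch * gradient accumulation steps.
global_batch_size = 512   # config.optimization.global_batch_size
ngpus = 8                 # dist.get_world_size()
train_batch = 16          # config.data.train.train_batch (per-GPU micro-batch)

assert global_batch_size % (ngpus * train_batch) == 0
gradient_accumulation_steps = global_batch_size // (ngpus * train_batch)
print(gradient_accumulation_steps)  # 4 forward/backward passes per optimizer step
```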
/src/utils/refinedweb_llama.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset, concatenate_datasets
2 | from transformers import AutoTokenizer
3 | from argparse import ArgumentParser
4 | import tiktoken
5 | import os
6 | from itertools import chain
7 | import numpy as np
8 | from tqdm import tqdm
9 | from transformers import LlamaTokenizer
10 |
11 | long_path = "/claire-rcp-scratch/shared/xwei/dataset/tiiuae___falcon-refinedweb/default-4033b99bd924aaad/0.0.0/0111277fb19b16f696664cde7f0cb90f833dec72db2cc73cfdf87e697f78fe02"
12 | cache_dir = "/claire-rcp-scratch/shared/xwei/dataset"
13 |
14 |
15 | def tokenize(tokenizer, num_proc, dataset):
16 | if tokenizer == "gpt2":
17 | enc = tiktoken.get_encoding("gpt2")
18 |
19 | def tokenize_process(example):
20 | ids = enc.encode_ordinary(
21 | example["text"]
22 | ) # encode_ordinary ignores any special tokens
23 | ids.append(
24 | enc.eot_token
25 | ) # add the end of text token, e.g. 50256 for gpt2 bpe
26 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
27 | out = {"ids": ids}
28 | return out
29 |
30 | elif tokenizer == "llama":
31 | enc = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
32 | eos_tokens = enc(
33 |             "</s>", truncation=False, padding=False, add_special_tokens=False  # LLaMA end-of-sequence token
34 | )["input_ids"]
35 |
36 | def tokenize_process(example):
37 | ids = enc(
38 | example["text"],
39 | truncation=False,
40 | padding=False,
41 | add_special_tokens=False,
42 | )["input_ids"]
43 | ids = ids + eos_tokens
44 | out = {"ids": ids}
45 | return out
46 |
47 | else:
48 | raise NotImplementedError
49 |
50 | tokenized = dataset.map(
51 | tokenize_process,
52 | remove_columns=["text", "url", "timestamp", "dump", "segment", "image_urls"],
53 | desc="tokenizing the splits",
54 | num_proc=num_proc,
55 | )
56 | print(tokenized)
57 | return tokenized
58 |
59 |
60 | def group_context(block_size, num_proc, dataset):
61 |
62 | def group_process(examples):
63 | # Concatenate all texts.
64 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
65 | total_length = len(concatenated_examples[list(examples.keys())[0]])
66 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
67 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
68 | total_length = (total_length // block_size) * block_size
69 | # Split by chunks of max_len.
70 | result = {
71 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
72 | for k, t in concatenated_examples.items()
73 | }
74 | return result
75 |
76 | lm_datasets = dataset.map(
77 | group_process,
78 | batched=True,
79 | num_proc=num_proc,
80 | desc=f"Grouping texts in chunks of {block_size}",
81 | )
82 | print(lm_datasets)
83 | return lm_datasets
84 |
85 |
86 | def save_to_npmemmap(split, dset, tokenizer, block_size):
87 | arr_len = dset.num_rows
88 | print(split, arr_len)
89 | filename = os.path.join(
90 | os.path.join(cache_dir, "refinedweb"), f"{tokenizer}-{split}-tmp.bin"
91 | )
92 |     dtype = np.uint16  # uint16 suffices: both vocabularies (GPT-2: 50257, LLaMA: 32000) are < 2**16
93 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len, block_size))
94 | total_batches = 1024
95 |
96 | idx = 0
97 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
98 | # Batch together samples for faster write
99 | batch = dset.shard(
100 | num_shards=total_batches, index=batch_idx, contiguous=True
101 | ).with_format("numpy")
102 | # Write into mmap
103 | arr_batch = np.stack(batch["ids"])
104 | arr[idx : idx + arr_batch.shape[0], :] = arr_batch
105 | idx += arr_batch.shape[0]
106 | arr.flush()
107 |
108 |
109 | def parse_args():
110 | parser = ArgumentParser(
111 | description="Convert dataset into MDS format, optionally concatenating and tokenizing"
112 | )
113 | parser.add_argument("--tokenizer", type=str, required=True)
114 | parser.add_argument(
115 | "--block_size",
116 | type=int,
117 | help="Convert text to tokens and concatenate up to this many tokens",
118 | )
119 |
120 | parser.add_argument("--num_proc", type=int, required=True, default=None)
121 | return parser.parse_args()
122 |
123 |
124 | def main(args):
125 | print(args.num_proc)
126 | new_dataset = []
127 | for i in range(6):
128 | for j in range(10):
129 | if i == 5 and j > 3:
130 | continue
131 | refinedweb_chunk = load_dataset(
132 | path=long_path,
133 | split="train",
134 | data_files=f"falcon-refinedweb-train-0{i}{j}*-of-05379.arrow",
135 | num_proc=args.num_proc,
136 | ).shuffle(seed=i * 10 + j)
137 | print(refinedweb_chunk)
138 | total_rows = refinedweb_chunk.num_rows
139 | selected_rows = int(0.1 * total_rows)
140 | cur_chunk = refinedweb_chunk.select(range(selected_rows)).rename_column(
141 | "content", "text"
142 | )
143 | del refinedweb_chunk
144 | print("begin to tokenize!")
145 | # tokenization
146 | cur_chunk = tokenize(args.tokenizer, args.num_proc, cur_chunk)
147 | cur_chunk = group_context(args.block_size, args.num_proc, cur_chunk)
148 | new_dataset.append(cur_chunk)
149 |
150 | new_dataset = concatenate_datasets(new_dataset)
151 | new_dataset = new_dataset.train_test_split(test_size=0.01, seed=1005, shuffle=True)
152 |
153 | save_to_npmemmap("train", new_dataset["train"], args.tokenizer, args.block_size)
154 | save_to_npmemmap("val", new_dataset["test"], args.tokenizer, args.block_size)
155 |
156 |
157 | if __name__ == "__main__":
158 | main(parse_args())
159 |
--------------------------------------------------------------------------------
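
The `.bin` files produced by `save_to_npmemmap` are plain uint16 token matrices, so a training-time loader only needs to re-open the memmap with the same dtype and row width. A minimal reader sketch (the filename and `block_size` below are placeholders and must match what was passed to this script):

```python
import numpy as np
import torch

block_size = 2048  # must match the --block_size used when the file was written
data = np.memmap("llama-train-tmp.bin", dtype=np.uint16, mode="r")
data = data.reshape(-1, block_size)  # one packed token sequence per row

# One language-modelling example: inputs are the sequence, targets the same
# sequence shifted by one position.
row = torch.from_numpy(data[0].astype(np.int64))
x, y = row[:-1], row[1:]
```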