├── .gitignore
├── LICENSE
├── README.md
├── cmd
├── convert_checkpoint.py
└── gemma_demo
│ ├── generator.go
│ ├── main.go
│ └── ui.go
├── download
├── README.md
├── huggingface
│ └── huggingface.go
└── kaggle
│ ├── metadata.go
│ └── weights.go
├── go.mod
├── go.sum
├── samplers
└── samplers.go
├── sentencepiece
└── sentencepiece.go
├── transformers
├── attention.go
├── attentiontype_enumer.go
├── cache.go
├── config.go
├── gemmatype_enumer.go
├── layers.go
├── querypreattentionnormalisationtype_enumer.go
└── transformers.go
└── trees
├── trees.go
└── trees_test.go
/.gitignore:
--------------------------------------------------------------------------------
1 | # If you prefer the allow list template instead of the deny list, see community template:
2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
3 | #
4 | # Binaries for programs and plugins
5 | *.exe
6 | *.exe~
7 | *.dll
8 | *.so
9 | *.dylib
10 |
11 | # Test binary, built with `go test -c`
12 | *.test
13 |
14 | # Output of the go coverage tool, specifically when used with LiteIDE
15 | *.out
16 |
17 | # Dependency directories (remove the comment below to include it)
18 | # vendor/
19 |
20 | # Go workspace file
21 | go.work
22 | go.work.sum
23 |
24 | # env file
25 | .env
26 |
27 | # GoLand IDE
28 | .idea
29 |
30 | # Python script, local venv.
31 | venv
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # GoMLX Gemma
5 |
6 | GoMLX (for Go) port of Google Deepmind's Gemma GenAI/LLM model.
7 |
8 | ## 📖 About GoMLX Gemma
9 |
10 | An implementation of [Google DeepMind](https://deepmind.google)'s [Gemma model](https://github.com/google-deepmind/gemma?tab=readme-ov-file)
11 | using [GoMLX, a Machine Learning framework for Go](https://github.com/gomlx/gomlx).
12 |
13 | It is very "_fresh from the oven_", so use it at your own risk.
14 | At the same time, I'm happy to help if you need any specific features, it's a good time for feature requests.
15 |
16 | ## ✅ **What is done** already:
17 |
18 | * **Sampling** / **Generating**: it provides the `samplers.Sampler` object to easily generate text.
19 | See example below, or `cmd/gemma_demo/generator.go` for an example.
20 | * HuggingFace Weights Version:
21 | * Download weights from HuggingFace, using provided AuthToken -- a read-only token will suffice.
22 | * Kaggle Version
23 | * Requires manually downloading weights from Kaggle.
24 | * Use provided `cmd/convert_checkpoint.py` script to convert Jax weights -- requires Python installation.
25 | * A command-line demo `cmd/gemma_demo`, with a simple [Charm](https://charm.sh/) interface.
26 |
27 | ## ❌ **Not done** yet:
28 |
29 | * **Fine-tuning**: the model is there, and it just needs some wiring together. But there is no sample code yet.
30 |
31 | ## ⌨️ Sample Code
32 |
33 | This is an example of how a `Sampler` object is created (for the simpler HuggingFace version) -- it requires the
34 | HuggingFace token (read-only) used to download to be set in HF_TOKEN -- go to HuggingFace webpage to generate one for you.
35 |
36 | ```go
37 | package main
38 |
39 | import (
40 | ...
41 |
42 | hfd "github.com/gomlx/gemma/download/huggingface"
43 | "github.com/gomlx/gemma/samplers"
44 | "github.com/gomlx/gomlx/backends"
45 | "github.com/gomlx/gomlx/ml/context"
46 |
47 | _ "github.com/gomlx/gomlx/backends/xla"
48 | )
49 |
50 | var (
51 | flagModelID = flag.String("model", "google/gemma-2-2b-it", "HuggingFace Gemma model id")
52 | flagDataDir = flag.String("data", "~/work/gemma", "Directory to cache downloaded dataset files.")
53 | )
54 |
55 | func main() {
56 | flag.Parse()
57 | prompts := []string{
58 | "What is 1+1 ?",
59 | "What are the planets of the solar system?",
60 | "```\n// BubbleSort is a Go function that sorts the Bubble Sort algorithm\nfunc BubbleSort[S ~[]E, E cmp.Ordered](x S) {\n",
61 | }
62 | ctx := context.New()
63 | vocab, err := hfd.Download(ctx, *flagModelID, os.Getenv("HF_TOKEN"), path.Join(*flagDataDir, "huggingface"))
64 | if err != nil {
65 | log.Fatalf("%+v", err)
66 | }
67 | sampler, err := samplers.New(backends.New(), ctx, vocab, 1024)
68 | if err != nil {
69 | log.Fatalf("%+v", err)
70 | }
71 |
72 | start := time.Now()
73 | output, err := sampler.Sample([]string{
74 | "What is 1+1?",
75 | "What are the planets of the solar system?",
76 | // "// BubbleSort is a Go function that sorts the Bubble Sort algorithm\nfunc BubbleSort[S ~[]E, E cmp.Ordered](x S)",
77 | })
78 | if err != nil {
79 | log.Fatalf("%+v", err)
80 | }
81 | fmt.Printf("\tElapsed time: %s\n", time.Since(start))
82 | fmt.Printf("Generated text:\n%s\n", strings.Join(output, "\n\n"))
83 | }
84 | ```
85 |
86 | ## 🔗 Resources
87 |
88 | 1. [**github.com/google-deepmind/gemma**](https://github.com/google-deepmind/gemma):
89 | [Gemma](https://ai.google.dev/gemma) is a family of open-weights Large Language Model (LLM) by [Google DeepMind](https://deepmind.google/),
90 | based on Gemini research and technology.
91 | 1. [github.com/eliben/go-sentencepiece](https://github.com/eliben/go-sentencepiece):
92 | This is a pure Go implementation of encoding and decoding text with the [SentencePiece tokenizer](https://github.com/google/sentencepiece).
93 |
94 |
95 | ## 📝 TODO
96 |
97 | * Remove special symbols from sampling, like the `<end_of_turn>` control token.
98 | * Fine-tuning demo.
99 | * Benchmarking: how does it compare to Jax implementation ? Jax JIT-compile the main sampling loop during generation,
100 | which could be done with GoMLX, but it would require implementing some new features. Not sure it is needed yet.
101 | * At least in an old NVidia RTX 2080Ti, it works with GoMLX, and Jax reference implementation fails to sample,
102 | because it tries to JIT-compile the full sampling loop.
--------------------------------------------------------------------------------
/cmd/convert_checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Jax/Orbax Gemma Checkpoint to Raw Bytes Converter
6 |
7 | This script converts a Jax/Orbax Gemma model checkpoint into a directory structure
8 | containing raw byte representations of the model's parameter arrays, as well as their shape.
9 |
10 | Usage:
11 |     python convert_checkpoint.py <source_dir> [--target_dir <target_dir>]
12 |
13 | Arguments:
14 |     source_dir: Path to the directory containing the Jax/Orbax checkpoint.
15 |     target_dir: (Optional) Path to the directory where the raw bytes will be saved.
16 |                 Defaults to '<source_dir>/raw/' if not provided.
17 |
18 | It requires the following libraries to be installed, probably in a virtual environment (venv) or equivalent:
19 |
20 | ```
21 | pip install jax "git+https://github.com/google-deepmind/gemma.git"
22 | ```
23 | """
24 |
25 | import argparse
26 | import os
27 | import jax
28 | from gemma import params as params_lib
29 |
30 | def read_parameter(path):
31 | """
32 | Read model checkpoint parameters from path to directory.
33 |
34 | :param path: Path to directory holding the checkpoint to read the parameters from.
35 | :return: PyTree of jaxlib.xla_extension.ArrayImpl
36 | """
37 | path = os.path.expanduser(path)
38 | return params_lib.load_and_format_params(path)
39 |
40 |
41 | def write_params(params, base_dir):
42 | """Write parameters to structured to directory: each file correspond to one array written as raw bytes."""
43 | base_dir = os.path.expanduser(base_dir)
44 | for path, array in flatten_params(params):
45 | base_file_path = os.path.join(base_dir, *path)
46 |
47 | # Create necessary directories
48 | os.makedirs(os.path.dirname(base_file_path), exist_ok=True)
49 |
50 | # Save array.
51 | with open(base_file_path+".raw", 'wb') as f:
52 | f.write(array.tobytes())
53 |
54 | # Save shape.
55 | with open(base_file_path+".shape", 'w') as f:
56 | f.write(serialize_shape(array))
57 |
58 |
59 | def path_to_str_tuple(path):
60 | """Converts a PyTree path (tuple of jax.tree_util.DictKey) to a tuple of strings."""
61 | return [e.key for e in path]
62 |
63 |
64 | def flatten_params(params):
65 | """Convert PyTree of arrays to a list of pairs of (path, array), where path is itself a tuple of strings."""
66 | list = []
67 | def append_to_list(path, value):
68 | list.append((path_to_str_tuple(path), value))
69 |
70 | jax.tree_util.tree_map_with_path(append_to_list, params)
71 | return list
72 |
73 |
74 | def serialize_shape(array):
75 | """Return an encoding of the given array's shape (including dtype)."""
76 | return ",".join([f"{str(array.dtype)}"]+[str(i) for i in array.shape])
77 |
78 |
79 | def main():
80 | parser = argparse.ArgumentParser(description="Convert Gemma Jax/Orbax checkpoint to raw bytes.")
81 | parser.add_argument("source_dir", help="Path to the source directory containing the checkpoint.")
82 | parser.add_argument("--target_dir", help="Path to the target directory where the raw bytes will be saved. Defaults to source_dir + 'raw/' if not provided.")
83 |
84 | args = parser.parse_args()
85 |
86 | source_dir = os.path.abspath(os.path.expanduser(args.source_dir))
87 | target_dir = os.path.abspath(os.path.expanduser(args.target_dir)) if args.target_dir else os.path.join(source_dir, "raw")
88 |
89 | print("Conversion from Jax/Orbax Gemma checkpoint to raw arrays/shapes:")
90 | print(f"\tSource directory: {source_dir}")
91 | print(f"\tTarget directory: {target_dir}")
92 |
93 | # We don't want to use GPU memory for this.
94 | jax.config.update('jax_platform_name', 'cpu')
95 |
96 | params = read_parameter(source_dir)
97 | write_params(params, target_dir)
98 |
99 |
100 | if __name__ == "__main__":
101 | main()
--------------------------------------------------------------------------------
/cmd/gemma_demo/generator.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "flag"
5 | hfd "github.com/gomlx/gemma/download/huggingface"
6 | "github.com/gomlx/gemma/samplers"
7 | "github.com/gomlx/gomlx/backends"
8 | _ "github.com/gomlx/gomlx/backends/xla"
9 | "github.com/gomlx/gomlx/ml/context"
10 | "github.com/janpfeifer/must"
11 | "os"
12 | "path"
13 | )
14 |
15 | var (
16 | flagDataDir = flag.String("data", "~/work/gemma", "Directory to cache downloaded and generated dataset files.")
17 | flagModelID = flag.String("model", "google/gemma-2-2b-it", "HuggingFace Gemma model id")
18 | flagMaxGeneratedTokens = flag.Int("max_tokens", 1024, "Maximum number of tokens to generate.")
19 | )
20 |
21 | func BuildSampler() *samplers.Sampler {
22 | ctx := context.New()
23 | vocab := must.M1(hfd.Download(ctx, *flagModelID, os.Getenv("HF_TOKEN"), path.Join(*flagDataDir, "huggingface")))
24 | return must.M1(samplers.New(backends.New(), ctx, vocab, *flagMaxGeneratedTokens))
25 | }
26 |
--------------------------------------------------------------------------------
/cmd/gemma_demo/main.go:
--------------------------------------------------------------------------------
1 | // gemma_demo uses Gemma for GoMLX to generate text given a prompt.
2 | //
3 | // It also uses github.com/charmbracelet libraries to make for a pretty command-line UI.
4 | package main
5 |
6 | import (
7 | "flag"
8 | "fmt"
9 | tea "github.com/charmbracelet/bubbletea"
10 | "github.com/gomlx/exceptions"
11 | "os"
12 | )
13 |
14 | func main() {
15 | flag.Parse()
16 |
17 | var p *tea.Program
18 | err := exceptions.TryCatch[error](func() { p = tea.NewProgram(newUIModel()) })
19 | if err != nil {
20 | fmt.Fprintf(os.Stderr, "Alas, there's been an error: %+v", err)
21 | os.Exit(1)
22 | }
23 | _, err = p.Run()
24 | if err != nil {
25 | fmt.Fprintf(os.Stderr, "Alas, there's been an error: %+v", err)
26 | os.Exit(1)
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/cmd/gemma_demo/ui.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "github.com/charmbracelet/bubbles/textarea"
6 | "github.com/charmbracelet/bubbles/viewport"
7 | tea "github.com/charmbracelet/bubbletea"
8 | "github.com/charmbracelet/lipgloss"
9 | "github.com/gomlx/gemma/samplers"
10 | )
11 |
12 | type uiModel struct {
13 | textarea textarea.Model
14 | viewport viewport.Model
15 | submitted bool
16 | sampler *samplers.Sampler
17 | err error
18 | }
19 |
20 | func newUIModel() *uiModel {
21 | ta := textarea.New()
22 | ta.Placeholder = "Gemma Prompt:"
23 | ta.Focus()
24 |
25 | vp := viewport.New(0, 0)
26 | vp.Style = lipgloss.NewStyle().Margin(1, 2).
27 | Border(lipgloss.NormalBorder()).BorderForeground(lipgloss.Color("99"))
28 |
29 | return &uiModel{
30 | textarea: ta,
31 | viewport: vp,
32 | sampler: BuildSampler(),
33 | }
34 | }
35 |
36 | func (m *uiModel) Init() tea.Cmd {
37 | return textarea.Blink
38 | }
39 |
40 | func (m *uiModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
41 | var (
42 | taCmd tea.Cmd
43 | vpCmd tea.Cmd
44 | cmds []tea.Cmd
45 | resize bool
46 | )
47 |
48 | switch msg := msg.(type) {
49 | case tea.KeyMsg:
50 | switch {
51 | case msg.Type == tea.KeyCtrlC || msg.Type == tea.KeyEsc:
52 | return m, tea.Quit
53 | case msg.Type == tea.KeyCtrlL:
54 | m.textarea.Reset()
55 |
56 | case msg.Type == tea.KeyCtrlD && !m.submitted: // Ctrl+Enter to submit
57 | m.submitted = true
58 | generatedContent, err := m.Generate()
59 | if err != nil {
60 | m.err = err
61 | return m, tea.Quit
62 | }
63 | m.viewport.SetContent(generatedContent)
64 | m.textarea.Blur()
65 | m.textarea.SetValue(generatedContent)
66 |
67 | case m.submitted && msg.Type == tea.KeyEnter: // Enter while submitted to edit
68 | m.submitted = false
69 | m.textarea.Focus()
70 | }
71 |
72 | case tea.WindowSizeMsg:
73 | resize = true
74 | m.viewport.Width = msg.Width
75 | m.viewport.Height = msg.Height - 3 // Account for textarea and margins
76 | m.textarea.SetWidth(msg.Width - 4) // Account for textarea margins
77 | m.textarea.SetHeight(msg.Height - 8)
78 | }
79 |
80 | m.textarea, taCmd = m.textarea.Update(msg)
81 | m.viewport, vpCmd = m.viewport.Update(msg)
82 |
83 | if resize {
84 | cmds = append(cmds, vpCmd)
85 | }
86 |
87 | return m, tea.Batch(append(cmds, taCmd)...)
88 | }
89 |
90 | func (m *uiModel) Generate() (string, error) {
91 | outputs, err := m.sampler.Sample([]string{m.textarea.Value()})
92 | if err != nil {
93 | return "", err
94 | }
95 | return outputs[0], nil
96 | }
97 |
98 | func (m *uiModel) View() string {
99 | if m.submitted {
100 | return fmt.Sprintf("\n%s\n\nPress Enter to edit...", m.viewport.View())
101 | }
102 |
103 | return fmt.Sprintf(
104 | "\n%s\n\n"+
105 | "\t\u2022 Ctrl+C or ESC to quit;\n"+
106 | "\t• Ctrl+D to submit;\n"+
107 | "\t• Ctrl+L to clear the prompt.\n",
108 | m.textarea.View(),
109 | )
110 | }
111 |
--------------------------------------------------------------------------------
/download/README.md:
--------------------------------------------------------------------------------
1 | # Download
2 |
3 | The Gemma models can be downloaded from Google/Kaggle site or from HuggingFace.
4 |
5 | The recommendation is to use the HuggingFace version; it has a couple of advantages over downloading from Google (Kaggle):
8 |
9 | - With a HuggingFace token, the process is automatic (no need to manually navigate to the site).
10 | - No need for conversion of the model, the library reads directly from the HuggingFace ".safetensors" format into
11 | GoMLX context.
12 | - No Python dependency.
13 |
--------------------------------------------------------------------------------
/download/huggingface/huggingface.go:
--------------------------------------------------------------------------------
1 | // Package huggingface handles downloading Gemma model weights from HuggingFace.
2 | //
3 | // This has some advantages over downloading it from Google (Kaggle):
4 | //
5 | // - With a HuggingFace token, the process is automatic.
6 | // - No need for conversion of the model, the library reads directly from the HuggingFace ".safetensors" format into
7 | // GoMLX context.
8 | // - No Python dependency.
9 | //
10 | // See cmd/gemma_demo in this repository for a usage example.
11 | package huggingface
12 |
13 | import (
14 | "fmt"
15 | "github.com/gomlx/gemma/sentencepiece"
16 | "github.com/gomlx/gomlx/ml/context"
17 | "github.com/gomlx/gomlx/ml/data"
18 | gomlxhf "github.com/gomlx/gomlx/ml/data/huggingface"
19 | "github.com/gomlx/gomlx/types/xslices"
20 | "path"
21 | "strconv"
22 | "strings"
23 | )
24 |
25 | // Download will download (if needed) the Gemma model identified by hfID (it's a HuggingFace model id, e.g.: "google/gemma-2-2b-it"),
26 | // and save under the cacheDir (for future reuse).
27 | //
28 | // The hfAuthToken is a HuggingFace token -- read-only access -- that needs to be created once in HuggingFace site.
29 | //
30 | // It loads the weights into the given context and creates a sentencepiece tokenizer (vocab) that is returned.
31 | //
32 | // An error is returned if something fails.
33 | func Download(ctx *context.Context, hfID, hfAuthToken, cacheDir string) (vocab *sentencepiece.Tokenizer, err error) {
34 | cacheDir = data.ReplaceTildeInDir(cacheDir)
35 | var hfm *gomlxhf.Model
36 | hfm, err = gomlxhf.New(hfID, hfAuthToken, cacheDir)
37 | if err != nil {
38 | return
39 | }
40 | err = hfm.Download()
41 | if err != nil {
42 | return
43 | }
44 |
45 | vocab, err = sentencepiece.NewFromPath(path.Join(hfm.BaseDir, "tokenizer.model"))
46 | if err != nil {
47 | return
48 | }
49 |
50 | for entry, err2 := range hfm.EnumerateTensors() {
51 | if err2 != nil {
52 | err = err2
53 | return
54 | }
55 | scopeAndName := convertHuggingFaceNameToScopeAndName(entry.Name)
56 | if len(scopeAndName) == 0 {
57 | fmt.Printf("Skipping: %s -> %s\n", entry.Name, entry.Tensor.Shape())
58 | } else {
59 | ctxTmp := ctx.In("model")
60 | name, scope := xslices.Pop(scopeAndName)
61 | for _, p := range scope {
62 | ctxTmp = ctxTmp.In(p)
63 | }
64 | ctxTmp.VariableWithValue(name, entry.Tensor)
65 | }
66 | }
67 | return
68 | }
69 |
70 | func convertHuggingFaceNameToScopeAndName(name string) []string {
71 | if name == "model.embed_tokens.weight" {
72 | return []string{"embedder", "input_embedding"}
73 | } else if name == "model.norm.weight" {
74 | return []string{"final_norm", "scale"}
75 | }
76 |
77 | // Parse the layer number for the name prefixed as "model.layers.X.<...>"
78 | if strings.HasPrefix(name, "model.layers.") {
79 | parts := strings.Split(name, ".")
80 | if len(parts) < 5 || xslices.Last(parts) != "weight" {
81 | return nil
82 | }
83 | layerNumberStr := parts[2]
84 | layerNumber, err := strconv.Atoi(layerNumberStr)
85 | if err != nil {
86 | return nil
87 | }
88 | layerScope := fmt.Sprintf("layer_%d", layerNumber)
89 | switch parts[3] {
90 | case "input_layernorm":
91 | return append([]string{layerScope, "pre_attention_norm", "scale"})
92 | case "post_attention_layernorm":
93 | return append([]string{layerScope, "post_attention_norm", "scale"})
94 | case "post_feedforward_layernorm":
95 | return append([]string{layerScope, "post_ffw_norm", "scale"})
96 | case "pre_feedforward_layernorm":
97 | return append([]string{layerScope, "pre_ffw_norm", "scale"})
98 | case "mlp":
99 | // For the MLP (the GatedFeedForwardNetwork), the weights in HuggingFace are transposed/split differently,
100 | // so they take new variable names not matching those in Kaggle version.
101 | switch parts[4] {
102 | case "down_proj":
103 | return append([]string{layerScope, "mlp", "hf", "down_proj"})
104 | case "gate_proj":
105 | return append([]string{layerScope, "mlp", "hf", "gating_proj"})
106 | case "up_proj":
107 | return append([]string{layerScope, "mlp", "hf", "up_proj"})
108 | default:
109 | return nil
110 | }
111 | case "self_attn":
112 | return append([]string{layerScope, "attn", "hf", parts[4]})
113 | default:
114 | return nil
115 | }
116 | }
117 | return nil
118 | }
119 |
--------------------------------------------------------------------------------
/download/kaggle/metadata.go:
--------------------------------------------------------------------------------
1 | package kaggle
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "github.com/gomlx/gemma/trees"
7 | "github.com/gomlx/gomlx/ml/data"
8 | "github.com/pkg/errors"
9 | "os"
10 | "path"
11 | "regexp"
12 | "strings"
13 | )
14 |
const (
	// MetadataFileName is the name of the metadata file inside a checkpoint directory.
	MetadataFileName = "_METADATA"

	// Top-level keys of the metadata JSON (see fromJsonTreeMetaData):
	KeyUseZarr3 = "use_zarr3"
	KeyTreeMetadata = "tree_metadata"
	// Keys of each individual "tree_metadata" entry (see parseJsonMetadataEntry):
	KeyKeyMetadata = "key_metadata"
	KeyValueMetadata = "value_metadata"
)
22 |
23 | // ReadMetadata returns the metadata loaded from the given directory in the form of a tree.
24 | func ReadMetadata(checkpointDir string) (tree *trees.Tree[*Metadata], err error) {
25 | checkpointDir = data.ReplaceTildeInDir(checkpointDir)
26 | metadataPath := path.Join(checkpointDir, MetadataFileName)
27 | var f *os.File
28 | f, err = os.Open(metadataPath)
29 | if err != nil {
30 | err = errors.Wrapf(err, "failed to read aggregate checkpoint file from %q", metadataPath)
31 | return
32 | }
33 | defer func() { _ = f.Close() }()
34 |
35 | dec := json.NewDecoder(f)
36 | var jsonTree any
37 | err = dec.Decode(&jsonTree)
38 | if err != nil {
39 | return
40 | }
41 | tree, err = fromJsonTreeMetaData(jsonTree)
42 | return
43 | }
44 |
// Metadata of one checkpoint entry (usually weights or embedding tables of model).
type Metadata struct {
	// KeyPath is the path of this entry in the checkpoint tree (see parseKeyPath).
	KeyPath []string

	// KeyToKeyType maps keys (from the entry's "key_metadata") to their kind of indexing.
	KeyToKeyType map[string]MetadataKeyType
	// ValueType is the serialized value type string, taken verbatim from "value_metadata".
	ValueType string
	// SkipDeserialize indicates the entry's value should not be deserialized.
	SkipDeserialize bool
}

// String implements fmt.Stringer.
// Entries with SkipDeserialize set are suffixed with " [*]".
func (m *Metadata) String() string {
	deserialize := ""
	if m.SkipDeserialize {
		deserialize = " [*]"
	}
	return fmt.Sprintf("%s%s", m.ValueType, deserialize)
}

// MetadataKeyType describes how a key in a KeyPath indexes its parent node.
type MetadataKeyType int

const (
	// Fix: KeyTypeDict was previously an untyped constant while KeyTypeSequence was
	// typed MetadataKeyType; iota gives both the same type and the same values (1, 2).
	KeyTypeSequence MetadataKeyType = iota + 1
	KeyTypeDict
)
69 |
70 | func fromJsonTreeMetaData(jsonTree any) (tree *trees.Tree[*Metadata], err error) {
71 | mapAny, ok := jsonTree.(map[string]any)
72 | if !ok {
73 | err = errors.Errorf("expected json to be a map of strings, got %T instead", jsonTree)
74 | return
75 | }
76 | _ = mapAny
77 | tree = trees.New[*Metadata]()
78 | for key, value := range mapAny {
79 | switch key {
80 | case KeyUseZarr3:
81 | // Check for Zarr3: not supported.
82 | zarr3, ok := value.(bool)
83 | if !ok {
84 | err = errors.Errorf("metadata json value for key %q is not a bool, got %T instead", key, value)
85 | return
86 | }
87 | if zarr3 {
88 | err = errors.Errorf("%q set to true, but Zarr3 is not supported by this library", key)
89 | return
90 | }
91 | case KeyTreeMetadata:
92 | entries, ok := value.(map[string]any)
93 | if !ok {
94 | err = errors.Errorf("metadata json value for key %q is not a map[string]any, got %T instead", key, value)
95 | return
96 | }
97 | for keyPath, jsonEntryAny := range entries {
98 | jsonEntry, ok := jsonEntryAny.(map[string]any)
99 | if !ok {
100 | err = errors.Errorf("metadata json value for key %q/%q is not a map[string]any, got %T instead", key, keyPath, jsonEntryAny)
101 | return
102 | }
103 | err = parseJsonMetadataEntry(tree, keyPath, jsonEntry)
104 | if err != nil {
105 | err = errors.WithMessagef(err, "metadata json value for key %q/%q", key, keyPath)
106 | return
107 | }
108 | }
109 |
110 | default:
111 | err = errors.Errorf("metadata json key %q unknown, don't know how to proceed", key)
112 | return
113 | }
114 |
115 | }
116 | return
117 | }
118 |
119 | func parseJsonMetadataEntry(tree *trees.Tree[*Metadata], keyPath string, jsonEntry map[string]any) error {
120 | entry := &Metadata{}
121 | if err := parseKeyPath(entry, keyPath); err != nil {
122 | return err
123 | }
124 |
125 | keyMetadataJsonAny, found := jsonEntry[KeyKeyMetadata]
126 | if !found {
127 | return errors.Errorf("missing KeyMetadata (key %q)", KeyKeyMetadata)
128 | }
129 | keyMetadataJson, ok := keyMetadataJsonAny.([]any)
130 | if !ok {
131 | return errors.Errorf("invalid KeyMetadata (key %q) type %T, expected []any", KeyKeyMetadata, keyMetadataJsonAny)
132 | }
133 | parseKeyMetadata(entry, keyMetadataJson)
134 |
135 | valueMetadataJsonAny, found := jsonEntry[KeyValueMetadata]
136 | if !found {
137 | return errors.Errorf("missing ValueMetadata (key %q)", KeyValueMetadata)
138 | }
139 | valueMetadataJson, ok := valueMetadataJsonAny.(map[string]any)
140 | if !ok {
141 | return errors.Errorf("invalid ValueMetadata (key %q) type %T, expected map[string]any", KeyValueMetadata, valueMetadataJsonAny)
142 | }
143 | parseValueMetadata(entry, valueMetadataJson)
144 | tree.Set(entry.KeyPath, entry)
145 | return nil
146 | }
147 |
148 | var reParseKeyPath = regexp.MustCompile(`'(.*?)'\s*[,)]`)
149 |
150 | func parseKeyPath(metadata *Metadata, keyPathStr string) error {
151 | matches := reParseKeyPath.FindAllStringSubmatch(keyPathStr, -1)
152 | if len(matches) == 0 {
153 | return errors.Errorf("can't parse keypath from %q", keyPathStr)
154 | }
155 | for _, match := range matches {
156 | metadata.KeyPath = append(metadata.KeyPath, match[1])
157 | }
158 | return nil
159 | }
160 |
161 | func parseKeyMetadata(metadata *Metadata, keyMetadataJson []any) {
162 | if metadata.KeyToKeyType == nil {
163 | metadata.KeyToKeyType = make(map[string]MetadataKeyType)
164 | }
165 | for _, entryAny := range keyMetadataJson {
166 | entry := entryAny.(map[string]any)
167 | metadata.KeyToKeyType[entry["key"].(string)] = MetadataKeyType(entry["key_type"].(float64))
168 | }
169 | }
170 |
171 | func parseValueMetadata(metadata *Metadata, valueMetadataJson map[string]any) {
172 | metadata.ValueType = valueMetadataJson["value_type"].(string)
173 | metadata.SkipDeserialize = valueMetadataJson["skip_deserialize"].(bool)
174 | }
175 |
// ParamNames converts metadata to a tree of parameter names, where each leaf is
// its tree path joined with ".".
// (Review note: original comment said it's unclear where in Gemma this is used.)
func ParamNames(metadata *trees.Tree[*Metadata]) *trees.Tree[string] {
	return trees.Map(metadata, func(treePath trees.Path, metadata *Metadata) string {
		return strings.Join(treePath, ".")
	})
}
182 |
--------------------------------------------------------------------------------
/download/kaggle/weights.go:
--------------------------------------------------------------------------------
1 | // Package kaggle loads Gemma weights into tensors along with the matching metadata, after they
2 | // have been downloaded from kaggle and converted using the included cmd/convert_checkpoint.py.
3 | package kaggle
4 |
5 | import (
6 | "github.com/dustin/go-humanize"
7 | "github.com/gomlx/gemma/trees"
8 | "github.com/gomlx/gomlx/ml/context"
9 | "github.com/gomlx/gomlx/ml/data"
10 | "github.com/gomlx/gomlx/types/shapes"
11 | "github.com/gomlx/gomlx/types/tensors"
12 | "github.com/gomlx/gomlx/types/xslices"
13 | "github.com/gomlx/gopjrt/dtypes"
14 | "github.com/janpfeifer/must"
15 | "github.com/pkg/errors"
16 | "github.com/vmihailenco/msgpack"
17 | "io"
18 | "io/fs"
19 | "os"
20 | "path"
21 | "path/filepath"
22 | "strconv"
23 | "strings"
24 | )
25 |
const (
	// AggregateFileName is the name of the msgpack-encoded aggregate checkpoint file
	// read by PyReadAggregate. Not used by Gemma v2 checkpoints.
	AggregateFileName = "checkpoint"

	// OCDBTManifestFileName indicates usage of "Orbax Consistent Distributed Backend Tree" (OCDBT).
	OCDBTManifestFileName = "manifest.ocdbt"
)
32 |
33 | // ReadConvertedWeights from checkpointDir (under the "raw/" subdirectory).
34 | // It will read the weights and shape converted by the `convert_checkpoint.py` script
35 | // (see github.com/gomlx/gemma repository, under cmd/convert_checkpoint.py) and set them
36 | // in the given context, under its current scope.
37 | func ReadConvertedWeights(ctx *context.Context, checkpointDir string) error {
38 | weights, err := ReadConvertedWeightsToTree(checkpointDir)
39 | if err != nil {
40 | return err
41 | }
42 | UploadWeightsToContext(ctx.In("model"), weights)
43 | return nil
44 | }
45 |
// ReadConvertedWeightsToTree from checkpointDir (under the "raw/" subdirectory).
// It will read the weights and shape converted by the `convert_checkpoint.py` script
// (see github.com/gomlx/gemma repository, under cmd/convert_checkpoint.py)
//
// Each tensor is stored as a pair of files: "<name>.shape" (a comma-separated list
// of "dtype,dim1,dim2,...") and "<name>.raw" (the raw bytes, which must be exactly
// shape.Memory() bytes long).
//
// It returns a tree of tensors, with the path matching those of the original Jax checkpoint.
func ReadConvertedWeightsToTree(checkpointDir string) (tree *trees.Tree[*tensors.Tensor], err error) {
	rawDir := path.Join(checkpointDir, "raw")
	if !data.FileExists(rawDir) {
		err = errors.Errorf(
			"ReadConvertedWeights(%q), the given directory doesn't have a subdirectory 'raw/' with the converted files",
			checkpointDir)
		return
	}
	tree = trees.New[*tensors.Tensor]()
	err = fs.WalkDir(os.DirFS(rawDir), ".", func(filePath string, entry fs.DirEntry, err error) error {
		if err != nil {
			return errors.Wrapf(err, "failed to traverse %q", rawDir)
		}
		if entry.IsDir() {
			return nil
		}
		// Only ".raw" files drive the loading; their sibling ".shape" file is read alongside.
		ext := filepath.Ext(filePath)
		if ext != ".raw" {
			return nil
		}

		// Here we have the pair of files ".shape" and ".raw":
		base := strings.TrimSuffix(filePath, ext)
		basePath := path.Join(rawDir, base)
		shapeFilePath := basePath + ".shape"
		if !data.FileExists(shapeFilePath) {
			// A ".raw" without a ".shape" is silently skipped.
			return nil
		}
		shapeBytes, err := os.ReadFile(shapeFilePath)
		if err != nil {
			return errors.Wrapf(err, "failed to read shape from %q", shapeFilePath)
		}
		// Format: "dtype,dim1,dim2,...".
		shapeParts := strings.Split(string(shapeBytes), ",")
		dtype, err := dtypes.DTypeString(shapeParts[0])
		if err != nil {
			return errors.Wrapf(err, "unknown dtype read from %q", shapeFilePath)
		}
		// The closure captures the outer err: after the first failed Atoi the remaining
		// elements are left at 0 and err holds the first conversion error.
		shapeDims := xslices.Map(shapeParts[1:], func(s string) (v int) {
			if err != nil {
				return 0
			}
			v, err = strconv.Atoi(s)
			return
		})
		if err != nil {
			return errors.Wrapf(err, "failed to convert %q to a dimension, read from %q", shapeBytes, basePath+".shape")
		}
		shape := shapes.Make(dtype, shapeDims...)

		rawFilePath := basePath + ".raw"
		// Sanity-check the file size against the expected tensor byte size before reading.
		info, err := entry.Info()
		if err != nil {
			return errors.Wrapf(err, "failed to get info from %q", rawFilePath)
		}
		if info.Size() != int64(shape.Memory()) {
			return errors.Errorf("file %q has %d bytes, but shape %s (read from %q) requires %d bytes, something went wrong in the conversion",
				rawFilePath, info.Size(), shape, shapeFilePath, shape.Memory())
		}

		// The relative file path (without extension) becomes the tree path.
		treePath := strings.Split(base, "/")
		//fmt.Printf("%q -> %s\n", treePath, shape)

		// Read the raw bytes directly into the tensor's buffer.
		tensor := tensors.FromShape(shape)
		f, err := os.Open(rawFilePath)
		if err != nil {
			return errors.Wrapf(err, "failed to open raw data from %q", rawFilePath)
		}
		var n int
		tensor.MutableBytes(func(data []byte) {
			n, err = io.ReadFull(f, data)
		})
		_ = f.Close()
		if err != nil {
			return errors.Wrapf(err, "failed to read raw data from %q", rawFilePath)
		}
		if n != int(shape.Memory()) {
			return errors.Errorf("read %d bytes from %q, expected %d", n, rawFilePath, shape.Memory())
		}
		err = tree.Set(treePath, tensor)
		if err != nil {
			return errors.WithMessagef(err, "failed to set variable with %s", humanize.Bytes(uint64(n)))
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return
}
140 |
141 | // PyReadAggregate of Python checkpoint. Not used by Gemma v2.
142 | func PyReadAggregate(checkpointDir string) (results any, err error) {
143 | checkpointDir = data.ReplaceTildeInDir(checkpointDir)
144 | aggregatePath := path.Join(checkpointDir, AggregateFileName)
145 | var f *os.File
146 | f, err = os.Open(aggregatePath)
147 | if err != nil {
148 | err = errors.Wrapf(err, "failed to read aggregate checkpoint file from %q", aggregatePath)
149 | return
150 | }
151 |
152 | dec := msgpack.NewDecoder(f)
153 | results, err = dec.DecodeMap()
154 | defer func() { _ = f.Close() }()
155 | return
156 | }
157 |
158 | func isOCDBT(checkpointDir string) bool {
159 | checkpointDir = data.ReplaceTildeInDir(checkpointDir)
160 | ocdbtPath := path.Join(checkpointDir, OCDBTManifestFileName)
161 | return data.FileExists(ocdbtPath)
162 | }
163 |
// PyParamInfo describes one parameter entry of a Python checkpoint, derived from
// its Metadata. See PyReadParamInfo.
type PyParamInfo struct {
	// Name is the parameter's tree path joined with "."; Path is the corresponding
	// location under the checkpoint directory.
	Name, Path string
	// SkipDeserialize is copied from Metadata.SkipDeserialize.
	SkipDeserialize bool
}
168 |
169 | func PyReadParamInfo(checkpointDir string) *trees.Tree[*PyParamInfo] {
170 | checkpointDir = data.ReplaceTildeInDir(checkpointDir)
171 | metadata := must.M1(ReadMetadata(checkpointDir))
172 | return trees.Map(metadata, func(p trees.Path, meta *Metadata) *PyParamInfo {
173 | name := strings.Join(p, ".")
174 | pi := &PyParamInfo{
175 | Name: name,
176 | Path: path.Join(checkpointDir, name),
177 | SkipDeserialize: meta.SkipDeserialize,
178 | }
179 | return pi
180 | })
181 | }
182 |
// UploadWeightsToContext creates variables in ctx corresponding to the weights:
// for each leaf of the "transformer" subtree, all path elements but the last select
// nested context scopes and the last element becomes the variable name.
// (Note: the previous doc claimed it returns the ctx; it has no return value --
// it mutates the given ctx in place.)
//
// It's tightly coupled with the model building functions in this package.
// Meaning the modeling must match the naming here.
//
// NOTE(review): if the tree has no "transformer" child, weights.Map["transformer"]
// yields the zero value and the Leaves() call below presumably fails -- confirm
// callers always pass a tree rooted with "transformer".
func UploadWeightsToContext(ctx *context.Context, weights *trees.Tree[*tensors.Tensor]) {
	weights = weights.Map["transformer"]
	for treePath, tensor := range weights.Leaves() {
		scopedCtx := ctx
		// Descend into one scope per path element, except the last one (the variable name).
		scopeParts := treePath[:len(treePath)-1]
		for _, p := range scopeParts {
			scopedCtx = scopedCtx.In(p)
		}
		varName := treePath[len(treePath)-1]
		_ = scopedCtx.VariableWithValue(varName, tensor)
	}
}
200 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/gomlx/gemma
2 |
3 | go 1.23.0
4 |
5 | require (
6 | github.com/charmbracelet/bubbles v0.20.0
7 | github.com/charmbracelet/bubbletea v1.1.2
8 | github.com/charmbracelet/lipgloss v1.0.0
9 | github.com/dustin/go-humanize v1.0.1
10 | github.com/eliben/go-sentencepiece v0.6.0
11 | github.com/gomlx/exceptions v0.0.3
12 | github.com/gomlx/gomlx v0.15.0
13 | github.com/gomlx/gopjrt v0.4.4
14 | github.com/janpfeifer/must v0.2.0
15 | github.com/pkg/errors v0.9.1
16 | github.com/stretchr/testify v1.9.0
17 | github.com/vmihailenco/msgpack v4.0.4+incompatible
18 | golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
19 | k8s.io/klog/v2 v2.130.1
20 | )
21 |
22 | require (
23 | github.com/atotto/clipboard v0.1.4 // indirect
24 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
25 | github.com/charmbracelet/x/ansi v0.4.2 // indirect
26 | github.com/charmbracelet/x/term v0.2.0 // indirect
27 | github.com/davecgh/go-spew v1.1.1 // indirect
28 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
29 | github.com/go-logr/logr v1.4.2 // indirect
30 | github.com/golang/protobuf v1.5.4 // indirect
31 | github.com/google/uuid v1.6.0 // indirect
32 | github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
33 | github.com/mattn/go-isatty v0.0.20 // indirect
34 | github.com/mattn/go-localereader v0.0.1 // indirect
35 | github.com/mattn/go-runewidth v0.0.16 // indirect
36 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
37 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
38 | github.com/muesli/cancelreader v0.2.2 // indirect
39 | github.com/muesli/termenv v0.15.2 // indirect
40 | github.com/pmezard/go-difflib v1.0.0 // indirect
41 | github.com/rivo/uniseg v0.4.7 // indirect
42 | github.com/schollz/progressbar/v3 v3.17.0 // indirect
43 | github.com/x448/float16 v0.8.4 // indirect
44 | golang.org/x/sync v0.8.0 // indirect
45 | golang.org/x/sys v0.26.0 // indirect
46 | golang.org/x/term v0.25.0 // indirect
47 | golang.org/x/text v0.19.0 // indirect
48 | google.golang.org/appengine v1.6.8 // indirect
49 | google.golang.org/protobuf v1.35.1 // indirect
50 | gopkg.in/yaml.v3 v3.0.1 // indirect
51 | )
52 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
2 | github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
3 | github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
4 | github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
5 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
6 | github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
7 | github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
8 | github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
9 | github.com/charmbracelet/bubbletea v1.1.2 h1:naQXF2laRxyLyil/i7fxdpiz1/k06IKquhm4vBfHsIc=
10 | github.com/charmbracelet/bubbletea v1.1.2/go.mod h1:9HIU/hBV24qKjlehyj8z1r/tR9TYTQEag+cWZnuXo8E=
11 | github.com/charmbracelet/lipgloss v1.0.0 h1:O7VkGDvqEdGi93X+DeqsQ7PKHDgtQfF8j8/O2qFMQNg=
12 | github.com/charmbracelet/lipgloss v1.0.0/go.mod h1:U5fy9Z+C38obMs+T+tJqst9VGzlOYGj4ri9reL3qUlo=
13 | github.com/charmbracelet/x/ansi v0.4.2 h1:0JM6Aj/g/KC154/gOP4vfxun0ff6itogDYk41kof+qk=
14 | github.com/charmbracelet/x/ansi v0.4.2/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw=
15 | github.com/charmbracelet/x/term v0.2.0 h1:cNB9Ot9q8I711MyZ7myUR5HFWL/lc3OpU8jZ4hwm0x0=
16 | github.com/charmbracelet/x/term v0.2.0/go.mod h1:GVxgxAbjUrmpvIINHIQnJJKpMlHiZ4cktEQCN6GWyF0=
17 | github.com/chengxilo/virtualterm v1.0.4 h1:Z6IpERbRVlfB8WkOmtbHiDbBANU7cimRIof7mk9/PwM=
18 | github.com/chengxilo/virtualterm v1.0.4/go.mod h1:DyxxBZz/x1iqJjFxTFcr6/x+jSpqN0iwWCOK1q10rlY=
19 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
20 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
21 | github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
22 | github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
23 | github.com/eliben/go-sentencepiece v0.6.0 h1:wbnefMCxYyVYmeTVtiMJet+mS9CVwq5klveLpfQLsnk=
24 | github.com/eliben/go-sentencepiece v0.6.0/go.mod h1:nNYk4aMzgBoI6QFp4LUG8Eu1uO9fHD9L5ZEre93o9+c=
25 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
26 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
27 | github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
28 | github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
29 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
30 | github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
31 | github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
32 | github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
33 | github.com/gomlx/exceptions v0.0.3 h1:HKnTgEjj4jlmhr8zVFkTP9qmV1ey7ypYYosQ8GzXWuM=
34 | github.com/gomlx/exceptions v0.0.3/go.mod h1:uHL0TQwJ0xaV2/snJOJV6hSE4yRmhhfymuYgNredGxU=
35 | github.com/gomlx/gomlx v0.15.0 h1:+gRsUzT6O/kcx+wa8pZqGi1JeQKyoCN17IQ9d7d8+gc=
36 | github.com/gomlx/gomlx v0.15.0/go.mod h1:nHkPBFi4z8CyZ5pOmViS8cUbImEB19XH3CVUwlbWjfk=
37 | github.com/gomlx/gopjrt v0.4.4 h1:Pc3jcBTsWmkf58pEuGnTfsf3eFhnX7UHhm/u/6VKczU=
38 | github.com/gomlx/gopjrt v0.4.4/go.mod h1:KkKFWTOGpJ1gludkzITb5/VOxpr5sy7qsTHvehcjTQA=
39 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
40 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
41 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
42 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
43 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
44 | github.com/janpfeifer/must v0.2.0 h1:yWy1CE5gtk1i2ICBvqAcMMXrCMqil9CJPkc7x81fRdQ=
45 | github.com/janpfeifer/must v0.2.0/go.mod h1:S6c5Yg/YSMR43cJw4zhIq7HFMci90a7kPY9XA4c8UIs=
46 | github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
47 | github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
48 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
49 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
50 | github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
51 | github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
52 | github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
53 | github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
54 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
55 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
56 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
57 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
58 | github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
59 | github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
60 | github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo=
61 | github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8=
62 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
63 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
64 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
65 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
66 | github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
67 | github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
68 | github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
69 | github.com/schollz/progressbar/v3 v3.17.0 h1:Fv+vG6O6jnJwdjCelvfyYO7sF2jaUGQVmdH4CxcZdsQ=
70 | github.com/schollz/progressbar/v3 v3.17.0/go.mod h1:5H4fLgifX+KeQCsEJnZTOepgZLe1jFF1lpPXb68IJTA=
71 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
72 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
73 | github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
74 | github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
75 | github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
76 | github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
77 | github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
78 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
79 | golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
80 | golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
81 | golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
82 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
83 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
84 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
85 | golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
86 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
87 | golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
88 | golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
89 | golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
90 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
91 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
92 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
93 | golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
94 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
95 | golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
96 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
97 | golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
98 | golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
99 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
100 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
101 | golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
102 | golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
103 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
104 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
105 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
106 | golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
107 | golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
108 | golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
109 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
110 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
111 | golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
112 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
113 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
114 | google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=
115 | google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds=
116 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
117 | google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
118 | google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
119 | google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
120 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
121 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
122 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
123 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
124 | k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
125 | k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
126 |
--------------------------------------------------------------------------------
/samplers/samplers.go:
--------------------------------------------------------------------------------
1 | // Package samplers uses a transformer model to generate sentences based on prompts.
2 | package samplers
3 |
4 | import (
5 | "github.com/dustin/go-humanize"
6 | "github.com/gomlx/exceptions"
7 | "github.com/gomlx/gemma/transformers"
8 | "github.com/gomlx/gemma/trees"
9 | "github.com/gomlx/gomlx/backends"
10 | . "github.com/gomlx/gomlx/graph"
11 | "github.com/gomlx/gomlx/ml/context"
12 | "github.com/gomlx/gomlx/types/shapes"
13 | "github.com/gomlx/gomlx/types/tensors"
14 | "github.com/gomlx/gomlx/types/xslices"
15 | "github.com/gomlx/gopjrt/dtypes"
16 | klog "k8s.io/klog/v2"
17 | "slices"
18 | "time"
19 | )
20 |
// Vocabulary is the tokenizer interface needed by the Sampler
// (implemented, e.g., by the sentencepiece package).
type Vocabulary interface {
	// EncodeAsIDs converts text into a sequence of token ids.
	EncodeAsIDs(text string) []int
	// DecodeIDs converts a sequence of token ids back into text.
	DecodeIDs([]int) string

	// The methods below define the special ids for the model.

	BeginningOfSentenceID() int
	EndOfSentenceID() int
	UnknownID() int
	PadID() int
}
32 |
// Sampler has a transformer (LLM) model and a vocabulary (sentencepiece) configured and generates
// sentences based on prompts.
type Sampler struct {
	// Backend used to execute the model graph.
	Backend backends.Backend
	// Vocab converts between text and token ids.
	Vocab Vocabulary

	// MaxGeneratedTokens default for Sampler.Sample.
	MaxGeneratedTokens int

	// Context with the model weights, used to execute the model.
	Context *context.Context

	// SampleStep graph computation: executes one decoding step (see sampleStepGraphFn).
	SampleStep *context.Exec

	// Config of the Gemma model, created from the weights.
	Config *transformers.Config

	// CacheTreeStructure holds the structure of the tree used for caching: the tree structure (paths) is stable
	// across different calls to Sample.
	CacheTreeStructure *trees.Tree[struct{}]
}
55 |
56 | // New creates a new sampler with the registered vocabulary and model.
57 | func New(backend backends.Backend, ctx *context.Context, vocab Vocabulary, maxGeneratedTokens int) (*Sampler, error) {
58 | s := &Sampler{
59 | Backend: backend,
60 | Vocab: vocab,
61 | MaxGeneratedTokens: maxGeneratedTokens,
62 | Context: ctx,
63 | }
64 | var err error
65 | s.Config, err = transformers.NewConfigFromContext(s.Context.In("model"))
66 | if err != nil {
67 | return nil, err
68 | }
69 | s.Context = s.Context.Reuse()
70 | s.SampleStep = context.NewExec(backend, s.Context, s.sampleStepGraphFn())
71 | return s, nil
72 | }
73 |
// Sample the continuation from the given prompts.
// It generates at most s.MaxGeneratedTokens tokens per prompt; use SampleMaxTokens
// to override that limit per call.
func (s *Sampler) Sample(prompts []string) ([]string, error) {
	return s.SampleMaxTokens(prompts, s.MaxGeneratedTokens)
}
78 |
79 | // SampleMaxTokens is like Sample, but instead of using the default MaxGenerateTokens, uses the given maxTokens instead.
80 | func (s *Sampler) SampleMaxTokens(prompts []string, maxTokens int) ([]string, error) {
81 | promptIds := xslices.Map(prompts, s.Vocab.EncodeAsIDs)
82 | state, err := s.initialState(promptIds, maxTokens)
83 | if err != nil {
84 | return nil, err
85 | }
86 | err = exceptions.TryCatch[error](func() {
87 | state = s.sampleLoop(state)
88 | })
89 | if err != nil {
90 | return nil, err
91 | }
92 | return s.decode(state), nil
93 | }
94 |
// sampleLoop executes SampleStep repeatedly until all examples in the batch are
// finished, then parses the final outputs back into the returned state.
func (s *Sampler) sampleLoop(state samplingState) samplingState {
	// Prepare inputs as slice:
	// * If you change this order, change the parsing order in sampleStepGraphFn below.
	inputs := []any{
		state.InputBuffer,
		state.StepNum,
		state.Done,
	}
	// * Append cache values, flattened in the stable order given by the cache tree.
	cacheValues := trees.ValuesAsList(state.Cache.Data)
	inputs = append(inputs, xslices.Map(cacheValues, func(t *tensors.Tensor) any { return t })...)
	numMutableInputs := len(inputs)

	// Append constant inputs (not donated, not updated across steps).
	start := time.Now() // Total wall time, reported when klog verbosity >= 1.
	inputs = append(inputs,
		state.Positions,
	)
	var outputs []*tensors.Tensor
	var execTime, inputsPrepTime time.Duration
	var count int
	for {
		inputPrepStart := time.Now()
		// We donate all the inputs, since they are all going to be updated (saves some GPU memory).
		for ii := range numMutableInputs {
			inputs[ii] = DonateTensorBuffer(inputs[ii].(*tensors.Tensor), s.Backend)
		}
		inputsPrepTime += time.Since(inputPrepStart)

		// Execute a step.
		execStart := time.Now()
		outputs = s.SampleStep.Call(inputs...)
		execTime += time.Since(execStart)
		count++

		// Update states (the output has the same order as the input).
		for ii := range numMutableInputs {
			inputs[ii] = outputs[ii]
		}
		extraOutputs := outputs[numMutableInputs:] // Separate the transient outputs.
		done := tensors.ToScalar[bool](extraOutputs[0])

		// End-of-sampling: first transient output is an all-done scalar flag.
		if done {
			break
		}
	}
	if klog.V(1).Enabled() {
		elapsed := time.Since(start)
		klog.Infof("Sample execution time (%d steps): %s", count, elapsed)
		klog.Infof("> Graph execution time: %s", execTime)
		klog.Infof("> Inputs preparation time: %s", inputsPrepTime)
	}
	// NOTE(review): the indexing below disagrees with the loop above. The mutable inputs
	// were fed as [InputBuffer, StepNum, Done, cache...], so if outputs truly have "the
	// same order as the input", outputs[1] is StepNum and outputs[2] is Done; here
	// outputs[1] is taken as Positions, outputs[2] as StepNum, outputs[3] as Done and the
	// cache starts at outputs[4]. Only one ordering can match what sampleStepGraphFn
	// returns (not visible here) -- confirm against it.
	state.InputBuffer = outputs[0]
	state.Positions = outputs[1]
	state.StepNum = outputs[2]
	state.Done = outputs[3]
	updatedCache := trees.FromValuesAndTree(outputs[4:4+s.CacheTreeStructure.NumLeaves()], s.CacheTreeStructure)
	state.Cache.Data = updatedCache
	return state
}
157 |
// sampleStepGraphFn returns the computation graph building function for this sampler,
// implementing one decoding step for the whole batch.
// The returned function can be used by context.NewExec.
//
// The inputs (state) and outputs follow the order set up in sampleLoop:
// mutable fields first (inputBuffer, stepNum, done, cache values), then
// constant fields (positions); the outputs are the updated mutable fields
// followed by the scalar "allDone" flag.
func (s *Sampler) sampleStepGraphFn() func(*context.Context, []*Node) []*Node {
	return func(ctx *context.Context, state []*Node) []*Node {
		g := state[0].Graph() // Reference to the underlying graph, it could be from any of the inputs.
		_ = ctx

		// Extract state parts:
		stateFieldsIdx := 0
		nextState := func() *Node {
			field := state[stateFieldsIdx]
			stateFieldsIdx++
			return field
		}
		// This order has to match the order fed in sampleLoop.
		// - Mutable fields, to be updated.
		inputBuffer := nextState()
		stepNum := nextState()
		done := nextState()
		numCacheValues := s.CacheTreeStructure.NumLeaves()
		cache := trees.FromValuesAndTree(state[stateFieldsIdx:stateFieldsIdx+numCacheValues], s.CacheTreeStructure)
		stateFieldsIdx += numCacheValues

		// - Constant fields.
		positions := nextState()

		// Take the current step token for all examples of the batch.
		batchSize := inputBuffer.Shape().Dimensions[0]
		zeroIdx := ScalarZero(g, dtypes.Int32)
		currentTokens := DynamicSlice(inputBuffer, []*Node{zeroIdx, stepNum}, []int{batchSize, 1})
		currentTokens.AssertDims(batchSize, 1)
		currentPositions := DynamicSlice(positions, []*Node{zeroIdx, stepNum}, []int{batchSize, 1})
		currentPositions.AssertDims(batchSize, 1)

		// Attention to all positions < current step number (stepNum).
		// Notice that the cache rotates, so once stepNum > Config.MaxCacheLength, the mask will be
		// true everywhere.
		cacheAttentionMask := Iota(g, shapes.Make(dtypes.Int32, batchSize, 1, s.Config.MaxCacheLength), -1)
		cacheAttentionMask = LessOrEqual(cacheAttentionMask, stepNum)

		// Run one transformer step, predicting logits for the next token.
		logits := transformers.GemmaWithCache(ctx.In("model"), s.Config,
			currentTokens, currentPositions, cache, cacheAttentionMask)
		logits.AssertDims(batchSize, 1, s.Config.VocabularySize)

		// Pick the next token greedily (ArgMax), but only where the buffer holds
		// padding or the example is already done; otherwise keep the prompt token
		// already present in the buffer ("teacher forcing" over the prompt).
		nextTokenNum := OnePlus(stepNum)
		nextPredictedTokens := ArgMax(logits, -1)
		nextPredictedTokens.AssertDims(batchSize, 1)
		nextTokenStartIdx := []*Node{zeroIdx, nextTokenNum}
		nextTokens := DynamicSlice(inputBuffer, nextTokenStartIdx, []int{batchSize, 1})
		nextTokens.AssertDims(batchSize, 1)
		nextTokens = Where(
			Or(
				Equal(nextTokens, Const(g, int32(s.Vocab.PadID()))),
				ExpandAxes(done, -1),
			),
			nextPredictedTokens,
			nextTokens,
		)
		inputBuffer = DynamicUpdateSlice(inputBuffer, nextTokens, nextTokenStartIdx)
		eosToken := Scalar(g, dtypes.Int32, s.Vocab.EndOfSentenceID())
		nextTokenIsEOS := Squeeze(Equal(nextTokens, eosToken), -1)
		done = Or(done, nextTokenIsEOS)

		// Prepare next step: are we done ?
		stepNum = nextTokenNum
		// -2 so the next iteration's DynamicSlice at stepNum+1 stays within the
		// buffer -- presumably also reserving room for a final <eos>; TODO confirm.
		maxSteps := inputBuffer.Shape().Dimensions[1] - 2
		allDone := Or(
			LogicalAll(done),
			GreaterOrEqual(stepNum, Const(g, int32(maxSteps))),
		)

		// Outputs: updated mutable values first (including cache):
		outputs := []*Node{inputBuffer, stepNum, done}
		outputs = append(outputs, trees.ValuesAsList(cache)...)
		// - Other results:
		outputs = append(outputs, allDone)
		return outputs
	}
}
237 |
// samplingState holds the state of the sampling loop plus some constants for the loop.
type samplingState struct {
	// BatchSize, MaxTokens, TotalLength are constants for one sampling.
	BatchSize, MaxTokens, TotalLength int

	// InputBuffer holds the ids with a prepended <bos> (beginning-of-sentence) token, <pad> padding, and extra space for
	// an <eos> (end-of-sentence) token. Shaped int32[BatchSize, TotalLength].
	InputBuffer *tensors.Tensor

	// NumInputTokens is the number of tokens on the original input per example (including <bos>): shaped int32[batch_size].
	NumInputTokens *tensors.Tensor

	// Positions for each token, see transformers.BuildPositionsFromMask
	Positions *tensors.Tensor

	// StepNum is a scalar counter of the steps sampled (decoded) so far.
	StepNum *tensors.Tensor

	// Done is a vector of the inputs who are done with the generation: shaped bool[batch_size].
	Done *tensors.Tensor

	// Cache used during the sampling.
	Cache *transformers.Cache
}
262 |
// initialState creates the samplingState for one sampling run: its InputBuffer is a
// tensor shaped int32[batchSize, totalLength] filled with the Vocab.PadID and then
// filled (left to right) with the given promptIds, with totalLength = longest
// prompt (incl. <bos>) + maxTokens + 1 (for <eos>).
//
// It also prepends a <bos> (beginning-of-sentence) token to each prompt, and
// initializes Positions, StepNum, Done and the transformer Cache.
func (s *Sampler) initialState(promptIds [][]int, maxTokens int) (state samplingState, err error) {
	state.MaxTokens = maxTokens
	state.BatchSize = len(promptIds)
	batchSize := state.BatchSize

	lengths := xslices.Map(promptIds, func(seq []int) int32 { return int32(len(seq)) + 1 }) // +1 for the <bos> (beginning-of-sentence) token.
	state.NumInputTokens = tensors.FromValue(lengths) // Shape [batchSize]
	maxInputLength := int(slices.Max(lengths))
	state.TotalLength = maxInputLength + maxTokens + 1 // +1 for the <eos> (end-of-sentence) token.
	totalLength := state.TotalLength

	state.StepNum = tensors.FromScalar(int32(0))
	state.InputBuffer = tensors.FromScalarAndDimensions(int32(s.Vocab.PadID()), batchSize, totalLength)
	bos := int32(s.Vocab.BeginningOfSentenceID())

	// Copy over "ragged" promptIds to the dense InputBuffer (pre-filled with <pad>),
	// prepending <bos> to each example.
	tensors.MutableFlatData(state.InputBuffer, func(flatIDs []int32) {
		for exampleIdx := range batchSize {
			exampleIds := flatIDs[exampleIdx*totalLength : (exampleIdx+1)*totalLength]
			exampleIds[0] = bos
			for ii, value := range promptIds[exampleIdx] {
				exampleIds[1+ii] = int32(value)
			}
		}
	})

	// Notice that the convoluted code in https://github.com/google-deepmind/gemma/blob/main/gemma/sampler.py
	// (see Sampler.init_sample_state()) in the end simply does the same as Iota() -- except if the input
	// has pad symbols (not the padding added by filling the ragged input) inside it -- which is not doable when
	// converting from string.
	//
	// Probably there is a bug in the original code ... it's not documented what they intended to do with it, but
	// we simply take the iota here.
	state.Positions = ExecOnce(s.Backend, func(g *Graph) *Node {
		return Iota(g, shapes.Make(dtypes.Int32, batchSize, totalLength), -1)
	})

	// Done starts all-false (zero value of a bool tensor).
	state.Done = tensors.FromShape(shapes.Make(dtypes.Bool, batchSize))

	// Setup cache, and if not yet setup, configure cache structure.
	var start time.Time
	if klog.V(1).Enabled() {
		start = time.Now()
	}
	state.Cache, err = transformers.NewCache(s.Config, batchSize)
	if err != nil {
		return
	}
	if s.CacheTreeStructure == nil {
		// Record the cache tree shape once; reused to (de)serialize cache values in the sampling loop.
		s.CacheTreeStructure = trees.Map(state.Cache.Data, func(_ trees.Path, _ *tensors.Tensor) (empty struct{}) { return })
	}

	if klog.V(1).Enabled() {
		elapsed := time.Since(start)
		var cacheMem uintptr
		for _, t := range state.Cache.Data.Leaves() {
			cacheMem += t.Memory()
		}
		klog.Infof("cache: elapsed %s, memory used %s\n", elapsed, humanize.Bytes(uint64(cacheMem)))
	}
	return
}
332 |
333 | // decode converts the state's InputBuffer with the sampled tokens to actual text.
334 | func (s *Sampler) decode(state samplingState) []string {
335 | text := make([]string, state.BatchSize)
336 | totalLength := state.TotalLength
337 | tensors.ConstFlatData(state.InputBuffer, func(flatIds []int32) {
338 | for exampleIdx := range state.BatchSize {
339 | exampleIds := flatIds[exampleIdx*totalLength : (exampleIdx+1)*totalLength]
340 | ids := xslices.Map(exampleIds, func(id int32) int { return int(id) })
341 | nonPad := 0
342 | for _, id := range ids {
343 | if id == s.Vocab.EndOfSentenceID() || id == s.Vocab.PadID() {
344 | break
345 | }
346 | nonPad++
347 | }
348 | text[exampleIdx] = s.Vocab.DecodeIDs(ids[:nonPad]) // Notice , and are converted to empty strings.
349 | //fmt.Printf("tokens: %#v\n", ids[:nonPad])
350 | }
351 | })
352 | return text
353 | }
354 |
// UpdateCacheAttentionMaskGraph given an inputMask (on the whole batch of example token ids), a currentStep and
// attentionLen (static).
//
// It's based on _compute_attention_mask in https://github.com/google-deepmind/gemma/blob/main/gemma/sampler.py#L32:
// the inputs and outputs here are very cryptic ... my best guess (also with some help from Gemini) what the original
// authors meant to generate is a mask that is False to where they can attend, except if it is in the "future"
// ("future" means positions > currentStep).
//
// - currentStep: scalar with current step.
// - attentionLen: length of the attention mask (it's ok to be larger than inputMask, it will pad the output accordingly)
// - inputMask: mask of valid tokens in the input, shaped [batchSize, inputLen]
//
// NOTE(review): this is an unimplemented stub -- it always returns nil, despite the
// contract described above. Confirm whether any caller depends on it; it should
// either be implemented per the reference _compute_attention_mask or removed.
func UpdateCacheAttentionMaskGraph(currentStep *Node, attentionLen int, inputMask *Node) *Node {
	return nil
}
369 |
--------------------------------------------------------------------------------
/sentencepiece/sentencepiece.go:
--------------------------------------------------------------------------------
1 | // Package sentencepiece fills some missing functionality from github.com/eliben/go-sentencepiece
2 | //
3 | // Hopefully it's temporary.
4 | package sentencepiece
5 |
6 | import (
7 | esentencepiece "github.com/eliben/go-sentencepiece"
8 | "github.com/gomlx/gomlx/types/xslices"
9 | "github.com/pkg/errors"
10 | )
11 |
// Tokenizer is able to encode/decode tokens from/to text.
//
// It embeds the go-sentencepiece Processor, so all of its methods are also
// available directly on Tokenizer.
type Tokenizer struct {
	*esentencepiece.Processor
	// Info holds the model's special token ids (<bos>, <eos>, <unk>, <pad>),
	// exposed through the accessor methods below.
	Info *esentencepiece.ModelInfo
}
17 |
18 | func NewFromPath(vocabPath string) (*Tokenizer, error) {
19 | proc, err := esentencepiece.NewProcessorFromPath(vocabPath)
20 | if err != nil {
21 | return nil, errors.Wrapf(err, "can't create sentencepiece")
22 | }
23 | return &Tokenizer{
24 | Processor: proc,
25 | Info: proc.ModelInfo(),
26 | }, nil
27 | }
28 |
29 | type Token = esentencepiece.Token
30 |
31 | // EncodeAsIDs returns the text encoded into a sequence of ids.
32 | // It implements sampler.Vocabulary.
33 | func (p *Tokenizer) EncodeAsIDs(text string) []int {
34 | tokens := p.Processor.Encode(text)
35 | return xslices.Map(tokens, func(t Token) int { return t.ID })
36 | }
37 |
38 | // DecodeIDs returns the text from a sequence of ids.
39 | // It implements sampler.Vocabulary.
40 | func (p *Tokenizer) DecodeIDs(ids []int) string {
41 | return p.Processor.Decode(ids)
42 | }
43 |
44 | // BeginningOfSentenceID implements sampler.Vocabulary.
45 | func (p *Tokenizer) BeginningOfSentenceID() int {
46 | return p.Info.BeginningOfSentenceID
47 | }
48 |
49 | // EndOfSentenceID implements sampler.Vocabulary.
50 | func (p *Tokenizer) EndOfSentenceID() int {
51 | return p.Info.EndOfSentenceID
52 | }
53 |
54 | // UnknownID implements sampler.Vocabulary.
55 | func (p *Tokenizer) UnknownID() int {
56 | return p.Info.UnknownID
57 | }
58 |
59 | // PadID implements sampler.Vocabulary.
60 | func (p *Tokenizer) PadID() int {
61 | return p.Info.PadID
62 | }
63 |
--------------------------------------------------------------------------------
/transformers/attention.go:
--------------------------------------------------------------------------------
1 | package transformers
2 |
3 | import (
4 | "github.com/gomlx/exceptions"
5 | "github.com/gomlx/gemma/trees"
6 | . "github.com/gomlx/gomlx/graph"
7 | "github.com/gomlx/gomlx/ml/context"
8 | "github.com/gomlx/gomlx/types/shapes"
9 | "github.com/gomlx/gomlx/types/tensors"
10 | "github.com/gomlx/gopjrt/dtypes"
11 | "github.com/pkg/errors"
12 | )
13 |
14 | // createAttentionCache creates the attention cache for the attention layer under treePath.
15 | func createAttentionCache(data *trees.Tree[*tensors.Tensor], treePath trees.Path, dtype dtypes.DType,
16 | batchSize, maxCacheLength, numHeads, headDim int) error {
17 | // Value cache:
18 | err := data.Set(append(treePath, "v"),
19 | tensors.FromShape(shapes.Make(dtype, batchSize, maxCacheLength, numHeads, headDim)))
20 | if err != nil {
21 | return errors.WithMessage(err, "in createAttentionCache()")
22 | }
23 |
24 | // Keys cache:
25 | err = data.Set(append(treePath, "k"),
26 | tensors.FromShape(shapes.Make(dtype, batchSize, maxCacheLength, numHeads, headDim)))
27 | if err != nil {
28 | return errors.WithMessage(err, "in createAttentionCache()")
29 | }
30 |
31 | // Index where to insert new values, in a rotating cache.
32 | err = data.Set(append(treePath, "end_index"), tensors.FromScalar(int32(0)))
33 | if err != nil {
34 | return errors.WithMessage(err, "in createAttentionCache()")
35 | }
36 | return nil
37 | }
38 |
// Must panics if the error is not nil.
func Must(err error) {
	if err == nil {
		return
	}
	panic(err)
}
45 |
// Must1 panics in case of error, otherwise returns the one return value.
func Must1[T any](v T, err error) T {
	if err == nil {
		return v
	}
	panic(err)
}
53 |
54 | // Attention builds an attention layer, optionally using cache to store a limited amount of context.
55 | //
56 | // - attentionIdx indexes attention configuration (in config) parameters, like config.AttentionTypes.
57 | // - x is the operand shaped [batchSize, sequenceLength, embedDim]. If using cache, typically the sequenceLength will be 1.
58 | // - positions are the positions of the sequence in x, shaped int32[batchSize, sequenceLength].
59 | // - cache: if set, x is only used for the current token (so sequenceLength will be 1), and the x's key and value projections
60 | // are set in the cache. After that, cache is used instead of x for the attention.
61 | // - attentionMask: shaped bool[batchSize, sequenceLength, sequenceLength] (if cache is nil) or bool[batchSize, sequenceLength==1, config.MaxCacheLength] if
62 | // cache is being used.
63 | func Attention(ctx *context.Context, config *Config, attentionIdx int, x, positions *Node, cache *trees.Tree[*Node], attentionMask *Node) *Node {
64 | g := x.Graph()
65 | dtype := x.DType()
66 |
67 | // Calculates projections used in the attention.
68 | var queryProjection, keyProjection, valueProjection *Node
69 |
70 | // Glossary of keys for einsum:
71 | // B = batchSize
72 | // T = sequenceLength
73 | // D = config.EmbedDim
74 | // N = config.NumHeads
75 | // H = config.HeadDim
76 | // K = config.NumKVHeads
77 | if config.HuggingFaceVersion {
78 | // HuggingFace version has separate variables per projection.
79 | keyProjectionWeights := ctx.In("hf").
80 | VariableWithShape("k_proj", shapes.Make(dtype, config.NumKVHeads*config.HeadDim, config.EmbedDim)).
81 | ValueGraph(g)
82 | keyProjectionWeights = Reshape(keyProjectionWeights, config.NumKVHeads, config.HeadDim, config.EmbedDim)
83 | keyProjection = Einsum("BSD,KHD->BSKH", x, keyProjectionWeights)
84 |
85 | valueProjectionWeights := ctx.In("hf").
86 | VariableWithShape("v_proj", shapes.Make(dtype, config.NumKVHeads*config.HeadDim, config.EmbedDim)).
87 | ValueGraph(g)
88 | valueProjectionWeights = Reshape(valueProjectionWeights, config.NumKVHeads, config.HeadDim, config.EmbedDim)
89 | valueProjection = Einsum("BSD,KHD->BSKH", x, valueProjectionWeights)
90 |
91 | queryProjectionWeights := ctx.In("hf").
92 | VariableWithShape("q_proj", shapes.Make(dtype, config.NumHeads*config.HeadDim, config.EmbedDim)).
93 | ValueGraph(g)
94 | queryProjectionWeights = Reshape(queryProjectionWeights, config.NumHeads, config.HeadDim, config.EmbedDim)
95 | queryProjection = Einsum("BSD,NHD->BSNH", x, queryProjectionWeights)
96 |
97 | } else if config.UseQKV {
98 | // S = 3, one extra dimensions for query, key, value projections
99 | qkvProjections := KernelEinsum(ctx.In("qkv_einsum"), "BTD,SNDH->SBTNH", x,
100 | shapes.Make(dtype /* k, q, v = 3 */, 3, config.NumHeads, config.EmbedDim, config.HeadDim))
101 | queryProjection = Squeeze(Slice(qkvProjections, AxisElem(0)), 0)
102 | keyProjection = Squeeze(Slice(qkvProjections, AxisElem(1)), 0)
103 | valueProjection = Squeeze(Slice(qkvProjections, AxisElem(2)), 0)
104 | } else {
105 | queryProjection = KernelEinsum(ctx.In("q_einsum"), "BTD,NDH->BTNH", x,
106 | shapes.Make(dtype, config.NumHeads, config.EmbedDim, config.HeadDim))
107 | // C = 2, one dimension for key, the other for value.
108 | kvProjections := KernelEinsum(ctx.In("kv_einsum"), "BSD,CKDH->CBSKH", x,
109 | shapes.Make(dtype, 2, config.NumKVHeads, config.EmbedDim, config.HeadDim))
110 | keyProjection = Squeeze(Slice(kvProjections, AxisElem(0)), 0)
111 | valueProjection = Squeeze(Slice(kvProjections, AxisElem(1)), 0)
112 | }
113 |
114 | queryProjection = ApplyRotaryPositionEncoding(queryProjection, positions, RoPEDefaultMaxWaveLength)
115 | queryScaled := MulScalar(queryProjection, config.QueryPreAttentionScalar())
116 | keyProjection = ApplyRotaryPositionEncoding(keyProjection, positions, RoPEDefaultMaxWaveLength)
117 |
118 | // If cache is set, update it with the projections of the slice of the sequence given, and then take the
119 | // projections of the whole cache.
120 | if cache != nil {
121 | // Insert calculated projections in cache: cached projections are shaped [batchSize, maxCacheLength, numHeads, headDim]
122 | endIndex, err := cache.Get("end_index")
123 | if err != nil {
124 | panic(err)
125 | }
126 | zeroIdx := ScalarZero(g, dtypes.Int32)
127 | cacheSequencePosition := Mod(endIndex, Scalar(g, endIndex.DType(), config.MaxCacheLength))
128 | updateSliceIndices := []*Node{zeroIdx, cacheSequencePosition, zeroIdx, zeroIdx}
129 |
130 | valueProjection = DynamicUpdateSlice(Must1(cache.Get("v")), valueProjection, updateSliceIndices)
131 | keyProjection = DynamicUpdateSlice(Must1(cache.Get("k")), keyProjection, updateSliceIndices)
132 | Must(cache.Set(trees.Path{"v"}, valueProjection))
133 | Must(cache.Set(trees.Path{"k"}, keyProjection))
134 | // Bump end_index the length of tokens provided at this step: typically, this will be only 1. If > 1
135 | // this will probably not work if the cache wraps around.
136 | Must(cache.Set(trees.Path{"end_index"}, AddScalar(endIndex, positions.Shape().Dim(-1))))
137 | }
138 |
139 | batchSize := queryScaled.Shape().Dim(0) // B
140 | seqLength := queryScaled.Shape().Dim(1) // T
141 | numQueryHeads := queryScaled.Shape().Dim(2) // N
142 | headDim := queryScaled.Shape().Dim(3) // H
143 | numKVHeads := config.NumKVHeads // K
144 | attentionTargetLength := keyProjection.Shape().Dim(1) // S = config.MaxCacheLength if cache != nil, or seqLength.
145 |
146 | var logits *Node
147 | if config.UseGroupQueryAttention {
148 | // There are fewer key (and value) projections than query projections,
149 | // reshape matrices accordingly and adjust Einsum.
150 | queryPerKVHeads := numQueryHeads / numKVHeads // G
151 | queryScaled = Reshape(queryScaled, batchSize, seqLength, numKVHeads, queryPerKVHeads, headDim)
152 | logits = Einsum("BTKGH,BSKH->BTKGS", queryScaled, keyProjection)
153 | logits = Reshape(logits, batchSize, seqLength, numQueryHeads, attentionTargetLength)
154 | } else {
155 | // Same number of query/key projections.
156 | // N = numQueryHeads == numKVHeads.
157 | logits = Einsum("BTNH,BSNH->BTNS", queryScaled, keyProjection)
158 | }
159 | logits.AssertDims(batchSize, seqLength, numQueryHeads, config.MaxCacheLength)
160 | logits = SoftCap(logits, config.AttentionLogitsSoftCap) // No-op if config.AttentionLogitsSoftCap is 0.
161 |
162 | if config.AttentionTypes[attentionIdx] == AttentionTypeLocalSliding {
163 | // Create a sliding mask: a mask that has a band (2*config.SlidingWindowSize) around the diagonal.
164 | // Issue: this will not work when using cache, and the cache loops around its config.MaxCacheLength, since
165 | // the sliding mask "band" won't wrap around.
166 | if config.SlidingWindowSize <= 0 {
167 | exceptions.Panicf("Config.SlidingWindowSize must be set for AttentionTypeLocalSliding")
168 | }
169 | allOnes := OnesLike(attentionMask)
170 | slidingMask := And(
171 | TakeUpperTriangular(allOnes, 1-config.SlidingWindowSize),
172 | TakeLowerTriangular(allOnes, config.SlidingWindowSize-1),
173 | )
174 | attentionMask = And(attentionMask, slidingMask)
175 | }
176 |
177 | // Calculate attention weights.
178 | const logitsMask = -2.3819763e38
179 | paddedLogits := Where(
180 | BroadcastToShape(ExpandAxes(attentionMask, -2), logits.Shape()),
181 | logits,
182 | Scalar(g, logits.DType(), logitsMask),
183 | )
184 | attentionWeights := Softmax(paddedLogits, -1)
185 |
186 | // Weighted sum of the values:
187 | var encoded *Node
188 | if config.UseGroupQueryAttention {
189 | // Reshape matrices to enable Einsums over groups of queries.
190 | queryPerKVHeads := numQueryHeads / numKVHeads // G
191 | attentionWeights = Reshape(attentionWeights, batchSize, seqLength, numKVHeads, queryPerKVHeads, attentionTargetLength)
192 | encoded = Einsum("BTKGS,BSKH->BTKGH", attentionWeights, valueProjection)
193 | encoded = Reshape(encoded, batchSize, seqLength, numQueryHeads, headDim)
194 | } else {
195 | // Plain attention: same number of query, keys and values projections.
196 | encoded = Einsum("BTNS,BSNH->BTNH", attentionWeights, valueProjection)
197 | encoded.AssertDims(batchSize, seqLength, numQueryHeads, headDim)
198 | }
199 |
200 | // Finally, a linear transformation on the result, merging all the heads.
201 | var output *Node
202 | if config.HuggingFaceVersion {
203 | outputProjectionWeights := ctx.In("hf").
204 | VariableWithShape("o_proj", shapes.Make(dtype, config.EmbedDim, config.NumHeads*config.HeadDim)).
205 | ValueGraph(g)
206 | outputProjectionWeights = Reshape(outputProjectionWeights, config.EmbedDim, numQueryHeads, config.HeadDim)
207 | output = Einsum("BTNH,DNH->BTD", encoded, outputProjectionWeights)
208 |
209 | } else {
210 | output = KernelEinsum(ctx.In("attn_vec_einsum"), "BTNH,NHD->BTD",
211 | encoded,
212 | shapes.Make(encoded.DType(), numQueryHeads, config.HeadDim, config.EmbedDim))
213 | }
214 | return output
215 | }
216 |
--------------------------------------------------------------------------------
/transformers/attentiontype_enumer.go:
--------------------------------------------------------------------------------
1 | // Code generated by "enumer -type=AttentionType -trimprefix=AttentionType -transform=snake -values -text -json -yaml config.go"; DO NOT EDIT.
2 |
3 | package transformers
4 |
5 | import (
6 | "encoding/json"
7 | "fmt"
8 | "strings"
9 | )
10 |
// _AttentionTypeName concatenates all enum names; _AttentionTypeIndex holds the
// offsets delimiting each name inside it.
const _AttentionTypeName = "unknowngloballocal_sliding"

var _AttentionTypeIndex = [...]uint8{0, 7, 13, 26}

// _AttentionTypeLowerName is the lower-cased version, used for case-insensitive lookup.
const _AttentionTypeLowerName = "unknowngloballocal_sliding"
16 |
// String returns the snake_case name of the AttentionType, or a
// "AttentionType(%d)" fallback for out-of-range values.
func (i AttentionType) String() string {
	if i < 0 || i >= AttentionType(len(_AttentionTypeIndex)-1) {
		return fmt.Sprintf("AttentionType(%d)", i)
	}
	return _AttentionTypeName[_AttentionTypeIndex[i]:_AttentionTypeIndex[i+1]]
}
23 |
// Values returns the string names of all AttentionType values.
func (AttentionType) Values() []string {
	return AttentionTypeStrings()
}
27 |
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
// (Compile-time check only; this function is never called.)
func _AttentionTypeNoOp() {
	var x [1]struct{}
	_ = x[AttentionTypeUnknown-(0)]
	_ = x[AttentionTypeGlobal-(1)]
	_ = x[AttentionTypeLocalSliding-(2)]
}
36 |
// _AttentionTypeValues lists every valid AttentionType, in declaration order.
var _AttentionTypeValues = []AttentionType{AttentionTypeUnknown, AttentionTypeGlobal, AttentionTypeLocalSliding}

// _AttentionTypeNameToValueMap maps both original and lower-cased names back to their value.
var _AttentionTypeNameToValueMap = map[string]AttentionType{
	_AttentionTypeName[0:7]: AttentionTypeUnknown,
	_AttentionTypeLowerName[0:7]: AttentionTypeUnknown,
	_AttentionTypeName[7:13]: AttentionTypeGlobal,
	_AttentionTypeLowerName[7:13]: AttentionTypeGlobal,
	_AttentionTypeName[13:26]: AttentionTypeLocalSliding,
	_AttentionTypeLowerName[13:26]: AttentionTypeLocalSliding,
}

// _AttentionTypeNames lists the string name of each value, aligned with _AttentionTypeValues.
var _AttentionTypeNames = []string{
	_AttentionTypeName[0:7],
	_AttentionTypeName[7:13],
	_AttentionTypeName[13:26],
}
53 |
// AttentionTypeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
// The lookup is tried first with the name as given, then lower-cased.
func AttentionTypeString(s string) (AttentionType, error) {
	if val, ok := _AttentionTypeNameToValueMap[s]; ok {
		return val, nil
	}

	if val, ok := _AttentionTypeNameToValueMap[strings.ToLower(s)]; ok {
		return val, nil
	}
	return 0, fmt.Errorf("%s does not belong to AttentionType values", s)
}
66 |
// AttentionTypeValues returns all values of the enum, in declaration order.
func AttentionTypeValues() []AttentionType {
	return _AttentionTypeValues
}
71 |
// AttentionTypeStrings returns a slice of all String values of the enum.
// It returns a fresh copy, so callers may modify it freely.
func AttentionTypeStrings() []string {
	strs := make([]string, len(_AttentionTypeNames))
	copy(strs, _AttentionTypeNames)
	return strs
}
78 |
// IsAAttentionType returns "true" if the value is listed in the enum definition. "false" otherwise
func (i AttentionType) IsAAttentionType() bool {
	for _, v := range _AttentionTypeValues {
		if i == v {
			return true
		}
	}
	return false
}
88 |
// MarshalJSON implements the json.Marshaler interface for AttentionType,
// encoding the value as its string name.
func (i AttentionType) MarshalJSON() ([]byte, error) {
	return json.Marshal(i.String())
}
93 |
// UnmarshalJSON implements the json.Unmarshaler interface for AttentionType,
// decoding from the string name of the value.
func (i *AttentionType) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return fmt.Errorf("AttentionType should be a string, got %s", data)
	}

	var err error
	*i, err = AttentionTypeString(s)
	return err
}
105 |
// MarshalText implements the encoding.TextMarshaler interface for AttentionType,
// encoding the value as its string name.
func (i AttentionType) MarshalText() ([]byte, error) {
	return []byte(i.String()), nil
}
110 |
// UnmarshalText implements the encoding.TextUnmarshaler interface for AttentionType,
// decoding from the string name of the value.
func (i *AttentionType) UnmarshalText(text []byte) error {
	var err error
	*i, err = AttentionTypeString(string(text))
	return err
}
117 |
// MarshalYAML implements a YAML Marshaler for AttentionType,
// encoding the value as its string name.
func (i AttentionType) MarshalYAML() (interface{}, error) {
	return i.String(), nil
}
122 |
// UnmarshalYAML implements a YAML Unmarshaler for AttentionType,
// decoding from the string name of the value.
func (i *AttentionType) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var s string
	if err := unmarshal(&s); err != nil {
		return err
	}

	var err error
	*i, err = AttentionTypeString(s)
	return err
}
134 |
--------------------------------------------------------------------------------
/transformers/cache.go:
--------------------------------------------------------------------------------
1 | package transformers
2 |
3 | import (
4 | "fmt"
5 | "github.com/gomlx/gemma/trees"
6 | "github.com/gomlx/gomlx/types/tensors"
7 | )
8 |
// Cache is a state cache of a (batch of) sequence being encoded/decoded.
//
// It has a fixed size (so typical cached values will be prefixed with the dimensions [BatchSize, Cache.Length]),
// and the current position (where each decode step is stored) is rotating: CurrentStep = (CurrentStep+1)%Cache.Length.
//
// It's stored as a trees.Tree[*tensors.Tensor].
//
// For the Gemma2 model, the first level of the tree being the layer names,
// and the second level hold the "keys" and "values" embedding caches for each transformer layer.
type Cache struct {
	// Config of the model.
	Config *Config

	// BatchSize for this cache.
	BatchSize int

	// Length (in number of steps) of the cache. The cache itself is rotating on this size.
	// It comes from config.MaxCacheLength.
	Length int

	// Data holds the cached data, organized as a trees.Tree[*tensors.Tensor].
	Data *trees.Tree[*tensors.Tensor]
}
32 |
33 | func NewCache(config *Config, batchSize int) (*Cache, error) {
34 | c := &Cache{
35 | Config: config,
36 | BatchSize: batchSize,
37 | Length: config.MaxCacheLength,
38 | Data: trees.New[*tensors.Tensor](),
39 | }
40 |
41 | for layerIdx := range config.NumLayers {
42 | treePath := []string{fmt.Sprintf("layer_%d", layerIdx)}
43 | err := createAttentionCache(c.Data, treePath, config.DType, batchSize, config.MaxCacheLength,
44 | config.NumKVHeads, config.HeadDim)
45 | if err != nil {
46 | return nil, err
47 | }
48 | }
49 | return c, nil
50 | }
51 |
--------------------------------------------------------------------------------
/transformers/config.go:
--------------------------------------------------------------------------------
1 | package transformers
2 |
3 | import (
4 | "github.com/gomlx/exceptions"
5 | "github.com/gomlx/gomlx/ml/context"
6 | "github.com/gomlx/gopjrt/dtypes"
7 | "github.com/pkg/errors"
8 | "math"
9 | )
10 |
// GemmaType identifies the model generation (Gemma v1 or v2) and size (2B..27B).
type GemmaType int

const (
	UnknownGemmaType GemmaType = iota
	Gemma_2B
	Gemma_7B
	Gemma2_2B
	Gemma2_9B
	Gemma2_27B
)

//go:generate enumer -type=GemmaType -transform=snake -values -text -json -yaml config.go

// numLayersToGemmaClass infers the model type from the number of transformer
// layers found in a checkpoint; each supported model has a distinct layer count.
var numLayersToGemmaClass = map[int]GemmaType{
	18: Gemma_2B,
	28: Gemma_7B,
	26: Gemma2_2B,
	42: Gemma2_9B,
	46: Gemma2_27B,
}
31 |
// AttentionType is the attention variant used by a given layer: global, or
// local sliding-window (see Config.AttentionTypes and Attention()).
type AttentionType int

//go:generate enumer -type=AttentionType -trimprefix=AttentionType -transform=snake -values -text -json -yaml config.go

const (
	AttentionTypeUnknown AttentionType = iota
	AttentionTypeGlobal
	AttentionTypeLocalSliding
)
41 |
// QueryPreAttentionNormalisationType defines how to normalize (scale) the query before attention.
type QueryPreAttentionNormalisationType int

//go:generate enumer -type=QueryPreAttentionNormalisationType -trimprefix=QueryNormType -transform=snake -values -text -json -yaml config.go

const (
	// QueryNormTypeByOneOverSqrtHeadDim indicates whether to scale the query by 1/sqrt(head_dim).
	// This is the default in NewConfigFromContext.
	QueryNormTypeByOneOverSqrtHeadDim QueryPreAttentionNormalisationType = iota

	// QueryNormTypeByEmbedDimDivNumHeads indicates whether to scale the query by `embed_dim // num_heads`
	QueryNormTypeByEmbedDimDivNumHeads

	// QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads indicates whether to scale the query by `1/sqrt(embed_dim // num_heads)`
	QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads
)
57 |
// Config Gemma transformer model: the architecture hyperparameters, typically
// auto-detected from a checkpoint by NewConfigFromContext.
type Config struct {
	// Type is the model variant, detected from the number of layers (see numLayersToGemmaClass).
	Type GemmaType

	// DType of the model weights, read from the checkpoint's embedding table.
	DType dtypes.DType

	// VocabularySize is the number of rows of the embedding table.
	VocabularySize      int
	NumLayers, NumEmbed int

	// HuggingFaceVersion has different shapes for some of the variables.
	HuggingFaceVersion bool

	// EmbedDim is also called "features" in the original code. It is the representation size (last dimension) of the output of the attention layers.
	EmbedDim int
	// NumHeads/HeadDim: number and per-head size of the attention heads.
	NumHeads, HeadDim int
	// HiddenDim is the intermediate size of the gated feed-forward layer.
	HiddenDim int
	// NumKVHeads is the number of key/value heads (fewer than NumHeads for grouped-query attention).
	NumKVHeads int
	// FinalLogitSoftCap soft-caps the final logits (see SoftCap); values <= 0 disable it.
	FinalLogitSoftCap float64
	// UseQKV and UseGroupQueryAttention are derived in NewConfigFromContext from NumKVHeads vs NumHeads.
	UseQKV, UseGroupQueryAttention bool
	UsePostAttentionNorm, UsePostFFWNorm bool

	// AttentionTypes holds one AttentionType per layer.
	AttentionTypes []AttentionType
	// MaxCacheLength is the key/value cache length; NewConfigFromContext defaults it to 1024.
	MaxCacheLength        int
	QueryPreAttentionNorm QueryPreAttentionNormalisationType

	// AttentionLogitsSoftCap limits the attention logits (logits = AttentionLogitsSoftCap * tanh(logits/AttentionLogitsSoftCap)).
	// Enabled if > 0.
	AttentionLogitsSoftCap float64
	// SlidingWindowSize is presumably the window of the local-sliding attention layers -- confirm in attention.go.
	SlidingWindowSize int
	// TransposeGatingEinsum indicates the gating weights are stored transposed (see GatedFeedForward).
	TransposeGatingEinsum bool
}
87 |
// NewConfigFromContext creates a transformers config model, based on the structure of the variables in the given context -- the scope
// has to be set directly to the model variables.
//
// It detects the model variant from the number of transformer layers found; it returns an
// error for variants whose hyperparameters are not implemented (currently only Gemma2_2B is).
func NewConfigFromContext(ctx *context.Context) (*Config, error) {
	c := &Config{
		MaxCacheLength:        1024,
		QueryPreAttentionNorm: QueryNormTypeByOneOverSqrtHeadDim,
	}

	// The embedding table is required: its shape gives us the dtype and the vocabulary size.
	embedTable := ctx.In("embedder").GetVariable("input_embedding")
	if embedTable == nil {
		return nil, errors.New("context given doesn't have an embedding table defined in \"embedder/input_embedding\"")
	}

	c.DType = embedTable.Shape().DType
	c.VocabularySize = embedTable.Shape().Dim(0)
	c.HuggingFaceVersion = c.VocabularySize == 256000 // Kaggle version is 256128.

	// Find number of layers: probe "layer_<n>/pre_attention_norm/scale" until one is missing.
	for {
		v := ctx.Inf("layer_%d", c.NumLayers).In("pre_attention_norm").GetVariable("scale")
		if v == nil {
			break
		}
		c.NumLayers++
	}
	if t, found := numLayersToGemmaClass[c.NumLayers]; found {
		c.Type = t
	}

	switch c.Type {
	case Gemma2_2B:
		c.setGemma2_2B()
	default:
		return nil, errors.Errorf("unknown or not implemented for Gemma model type %q", c.Type)
	}

	// Derived flags consumed by the attention implementation: UseQKV when key/value heads
	// match query heads, grouped-query attention when there are fewer (but >1) KV heads.
	// NOTE(review): exact semantics of these flags live in attention.go -- not in view here.
	c.UseQKV = c.NumKVHeads == c.NumHeads
	c.UseGroupQueryAttention = (c.NumKVHeads != c.NumHeads) && c.NumKVHeads > 1
	return c, nil
}
128 |
129 | func (c *Config) setGemma2_2B() {
130 | c.NumLayers = 26
131 | c.NumEmbed = 256128
132 | c.EmbedDim = 2304
133 | c.HiddenDim = 9216
134 | c.NumHeads = 8
135 | c.HeadDim = 256
136 | c.NumKVHeads = 4
137 | c.FinalLogitSoftCap = 30.0
138 | c.AttentionTypes = nil
139 | for _ = range c.NumLayers / 2 {
140 | c.AttentionTypes = append(c.AttentionTypes, []AttentionType{AttentionTypeLocalSliding, AttentionTypeGlobal}...)
141 | }
142 | c.UsePostAttentionNorm = true
143 | c.UsePostFFWNorm = true
144 | c.QueryPreAttentionNorm = QueryNormTypeByOneOverSqrtHeadDim
145 | c.AttentionLogitsSoftCap = 50.0
146 | c.SlidingWindowSize = 4096
147 | }
148 |
149 | // QueryPreAttentionScalar is a multiplier to the query projections.
150 | func (c *Config) QueryPreAttentionScalar() float64 {
151 | switch c.QueryPreAttentionNorm {
152 | case QueryNormTypeByEmbedDimDivNumHeads:
153 | return float64(c.EmbedDim / c.NumHeads)
154 | case QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads:
155 | return 1.0 / math.Sqrt(float64(c.EmbedDim/c.NumHeads))
156 | case QueryNormTypeByOneOverSqrtHeadDim:
157 | return 1.0 / math.Sqrt(float64(c.HeadDim))
158 | default:
159 | exceptions.Panicf("invalid value of QueryPreAttentionNorm = %d, expected one of the valid enum values", c.QueryPreAttentionNorm)
160 | panic(nil) // Quiet lint.
161 | }
162 | }
163 |
--------------------------------------------------------------------------------
/transformers/gemmatype_enumer.go:
--------------------------------------------------------------------------------
1 | // Code generated by "enumer -type=GemmaType -transform=snake -values -text -json -yaml config.go"; DO NOT EDIT.
2 |
3 | package transformers
4 |
5 | import (
6 | "encoding/json"
7 | "fmt"
8 | "strings"
9 | )
10 |
// NOTE(review): everything below is machine-generated by the enumer tool (see the
// //go:generate directive in config.go). To change it, edit the GemmaType enum in
// config.go and re-run `go generate` -- do not hand-edit this file.

// _GemmaTypeName concatenates all snake_case enum names; _GemmaTypeIndex holds the
// byte offset at which each name starts (plus the final end offset).
const _GemmaTypeName = "unknown_gemma_typegemma_2bgemma_7bgemma2_2bgemma2_9bgemma2_27b"

var _GemmaTypeIndex = [...]uint8{0, 18, 26, 34, 43, 52, 62}

// _GemmaTypeLowerName is the lower-cased copy, used for case-insensitive lookup.
const _GemmaTypeLowerName = "unknown_gemma_typegemma_2bgemma_7bgemma2_2bgemma2_9bgemma2_27b"

func (i GemmaType) String() string {
	if i < 0 || i >= GemmaType(len(_GemmaTypeIndex)-1) {
		return fmt.Sprintf("GemmaType(%d)", i)
	}
	return _GemmaTypeName[_GemmaTypeIndex[i]:_GemmaTypeIndex[i+1]]
}

func (GemmaType) Values() []string {
	return GemmaTypeStrings()
}

// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _GemmaTypeNoOp() {
	var x [1]struct{}
	_ = x[UnknownGemmaType-(0)]
	_ = x[Gemma_2B-(1)]
	_ = x[Gemma_7B-(2)]
	_ = x[Gemma2_2B-(3)]
	_ = x[Gemma2_9B-(4)]
	_ = x[Gemma2_27B-(5)]
}

var _GemmaTypeValues = []GemmaType{UnknownGemmaType, Gemma_2B, Gemma_7B, Gemma2_2B, Gemma2_9B, Gemma2_27B}

var _GemmaTypeNameToValueMap = map[string]GemmaType{
	_GemmaTypeName[0:18]:       UnknownGemmaType,
	_GemmaTypeLowerName[0:18]:  UnknownGemmaType,
	_GemmaTypeName[18:26]:      Gemma_2B,
	_GemmaTypeLowerName[18:26]: Gemma_2B,
	_GemmaTypeName[26:34]:      Gemma_7B,
	_GemmaTypeLowerName[26:34]: Gemma_7B,
	_GemmaTypeName[34:43]:      Gemma2_2B,
	_GemmaTypeLowerName[34:43]: Gemma2_2B,
	_GemmaTypeName[43:52]:      Gemma2_9B,
	_GemmaTypeLowerName[43:52]: Gemma2_9B,
	_GemmaTypeName[52:62]:      Gemma2_27B,
	_GemmaTypeLowerName[52:62]: Gemma2_27B,
}

var _GemmaTypeNames = []string{
	_GemmaTypeName[0:18],
	_GemmaTypeName[18:26],
	_GemmaTypeName[26:34],
	_GemmaTypeName[34:43],
	_GemmaTypeName[43:52],
	_GemmaTypeName[52:62],
}

// GemmaTypeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func GemmaTypeString(s string) (GemmaType, error) {
	if val, ok := _GemmaTypeNameToValueMap[s]; ok {
		return val, nil
	}

	if val, ok := _GemmaTypeNameToValueMap[strings.ToLower(s)]; ok {
		return val, nil
	}
	return 0, fmt.Errorf("%s does not belong to GemmaType values", s)
}

// GemmaTypeValues returns all values of the enum
func GemmaTypeValues() []GemmaType {
	return _GemmaTypeValues
}

// GemmaTypeStrings returns a slice of all String values of the enum
func GemmaTypeStrings() []string {
	strs := make([]string, len(_GemmaTypeNames))
	copy(strs, _GemmaTypeNames)
	return strs
}

// IsAGemmaType returns "true" if the value is listed in the enum definition. "false" otherwise
func (i GemmaType) IsAGemmaType() bool {
	for _, v := range _GemmaTypeValues {
		if i == v {
			return true
		}
	}
	return false
}

// MarshalJSON implements the json.Marshaler interface for GemmaType
func (i GemmaType) MarshalJSON() ([]byte, error) {
	return json.Marshal(i.String())
}

// UnmarshalJSON implements the json.Unmarshaler interface for GemmaType
func (i *GemmaType) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return fmt.Errorf("GemmaType should be a string, got %s", data)
	}

	var err error
	*i, err = GemmaTypeString(s)
	return err
}

// MarshalText implements the encoding.TextMarshaler interface for GemmaType
func (i GemmaType) MarshalText() ([]byte, error) {
	return []byte(i.String()), nil
}

// UnmarshalText implements the encoding.TextUnmarshaler interface for GemmaType
func (i *GemmaType) UnmarshalText(text []byte) error {
	var err error
	*i, err = GemmaTypeString(string(text))
	return err
}

// MarshalYAML implements a YAML Marshaler for GemmaType
func (i GemmaType) MarshalYAML() (interface{}, error) {
	return i.String(), nil
}

// UnmarshalYAML implements a YAML Unmarshaler for GemmaType
func (i *GemmaType) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var s string
	if err := unmarshal(&s); err != nil {
		return err
	}

	var err error
	*i, err = GemmaTypeString(s)
	return err
}
146 |
--------------------------------------------------------------------------------
/transformers/layers.go:
--------------------------------------------------------------------------------
1 | package transformers
2 |
3 | import (
4 | "github.com/gomlx/exceptions"
5 | . "github.com/gomlx/gomlx/graph"
6 | "github.com/gomlx/gomlx/ml/context"
7 | "github.com/gomlx/gomlx/ml/context/initializers"
8 | "github.com/gomlx/gomlx/ml/layers/activations"
9 | "github.com/gomlx/gomlx/types/shapes"
10 | )
11 |
12 | // SoftCap using Tanh, so values won't go beyond +/- cap. If cap <= 0, it is a no-op.
13 | //
14 | // SoftCap(x) = Tanh(x/cap) * cap
15 | func SoftCap(x *Node, cap float64) *Node {
16 | if cap <= 0 {
17 | return x
18 | }
19 | return MulScalar(Tanh(DivScalar(x, cap)), cap)
20 | }
21 |
22 | // KernelEinsum multiplies the input by a kernel of the given shape, using the given graph.EinSum equation.
23 | func KernelEinsum(ctx *context.Context, equation string, x *Node, kernelShape shapes.Shape) *Node {
24 | g := x.Graph()
25 | kernelVar := ctx.VariableWithShape("w", kernelShape)
26 | kernel := kernelVar.ValueGraph(g)
27 | return Einsum(equation, x, kernel)
28 | }
29 |
// RMSNorm normalizes by its root-mean-square: x = x / √(mean(x², axis=-1) + epsilon),
// then applies a learned per-feature scale (variable "scale", centered at 1.0).
func RMSNorm(ctx *context.Context, x *Node) *Node {
	g := x.Graph()
	// Mean of the squared values over the last ("features") axis, keeping the axis for broadcasting.
	variance := ReduceAndKeep(Square(x), ReduceMean, -1)
	const epsilon = 1e-6
	normalizedX := Mul(x, Rsqrt(AddScalar(variance, epsilon)))

	// Now apply a learned scale.
	scaleVar := ctx.WithInitializer(initializers.Zero).
		VariableWithShape("scale", shapes.Make(x.DType(), x.Shape().Dim(-1)))
	scale := scaleVar.ValueGraph(g)
	scale = ExpandLeftToRank(scale, normalizedX.Rank()) // Expand rank of scale to match normalizedX.
	// Scale centered on 1.0, so the zero-initialized variable has no effect.
	scale = OnePlus(scale)
	normalizedX = Mul(scale, normalizedX)
	return normalizedX
}
47 |
// RoPEDefaultMaxWaveLength is a default value to use for rotary positional encoding:
// wave lengths are geometrically spaced from 1 up to this maximum.
// See ApplyRotaryPositionEncoding.
const RoPEDefaultMaxWaveLength = 10_000
51 |
52 | // ApplyRotaryPositionEncoding (aka. RoPE) applies the positional encoding to the operand, given the positions
53 | // (integer numbers of the position).
54 | //
55 | // - operand: the last axis ("features" or "embedding" axis) must be divisible by 2. The shape usually is [batchSize, sequenceSize, numHeads, headDim].
56 | // - positions: its shape must be a prefix to operand. Typically, it's shaped [batchSize, sequenceSize].
57 | // - maxWaveLength: it uses wave lengths in a power scale, up to maxWaveLength -- see RoPEDefaultMaxWaveLength for a reasonable value.
58 | //
59 | // Reference: https://arxiv.org/abs/2104.09864
60 | func ApplyRotaryPositionEncoding(operand, positions *Node, maxWaveLength int) *Node {
61 | g := operand.Graph()
62 | dtype := operand.DType()
63 | featuresDim := operand.Shape().Dim(-1)
64 | if featuresDim <= 0 || featuresDim%2 != 0 {
65 | exceptions.Panicf("ApplyRotaryPositionEncoding(operand=%s, position=%s) requires operand's last "+
66 | "dimension to be >= 0 and divisible by 2", operand.Shape(), positions.Shape())
67 | }
68 |
69 | transientDType := dtype
70 | fraction := Iota(g, shapes.Make(transientDType, featuresDim/2), 0)
71 | fraction = MulScalar(fraction, 2.0/float64(featuresDim))
72 | timeScale := Pow(Scalar(g, transientDType, float64(maxWaveLength)), fraction)
73 | timeScale = ExpandLeftToRank(timeScale, positions.Rank()+1)
74 |
75 | // Angles shape will add a rank to positions: we will take each position at a different wave length (or timeScale).
76 | angles := ConvertDType(ExpandAxes(positions, -1), transientDType)
77 | angles = Div(angles, timeScale)
78 |
79 | // Insert an axis just before the last until it matches the operand's shape.
80 | for angles.Rank() < operand.Rank() {
81 | angles = ExpandDims(angles, -2)
82 | }
83 | sines := Sin(angles)
84 | cosines := Cos(angles)
85 |
86 | // Split first/second half of operands features (the last dimension), and apply rotation at the various wave lengths.
87 | firstHalf := Slice(operand, AxisRange().Spacer(), AxisRange(0, featuresDim/2))
88 | secondHalf := Slice(operand, AxisRange().Spacer(), AxisRangeToEnd(featuresDim/2))
89 | firstHalfUpdate := Sub(
90 | Mul(firstHalf, cosines),
91 | Mul(secondHalf, sines),
92 | )
93 | secondHalfUpdate := Add(
94 | Mul(secondHalf, cosines),
95 | Mul(firstHalf, sines),
96 | )
97 | return ConvertDType(Concatenate([]*Node{firstHalfUpdate, secondHalfUpdate}, -1), dtype)
98 | }
99 |
// GatedFeedForward layer for Gemma:
//   - hiddenDim: one intermediary layer.
//   - transposeGatingEinsum: for some versions of Gemma, the gating (hidden) weights have the axes transposed.
//   - It uses Gelu as activation function for the gating signal (multiplied by the up-projected values).
func GatedFeedForward(ctx *context.Context, x *Node, hiddenDim int, transposeGatingEinsum bool) *Node {
	g := x.Graph()
	featuresDim := x.Shape().Dim(-1)

	// The "gating_einsum" variable packs two matrices along axis 0: index 0 is the gate
	// projection, index 1 the up-projection, each mapping featuresDim -> hiddenDim.
	var gatingWeights *Node
	if transposeGatingEinsum {
		// Some versions of Gemma use an alternate parameter ordering that transposes hiddenDim and outputDim.
		gatingVar := ctx.WithInitializer(initializers.Zero).
			VariableWithShape("gating_einsum", shapes.Make(x.DType(), 2, hiddenDim, featuresDim))
		gatingWeights = gatingVar.ValueGraph(g)
		gatingWeights = Transpose(gatingWeights, 1, 2) // Back to [2, featuresDim, hiddenDim].
	} else {
		// Standard shape of the gating weights.
		gatingVar := ctx.WithInitializer(initializers.Zero).
			VariableWithShape("gating_einsum", shapes.Make(x.DType(), 2, featuresDim, hiddenDim))
		gatingWeights = gatingVar.ValueGraph(g)
	}
	gatingWeights0 := Squeeze(Slice(gatingWeights, AxisElem(0)), 0) // Gate projection.
	gatingWeights1 := Squeeze(Slice(gatingWeights, AxisElem(1)), 0) // Up-projection.

	// Gate signal: Gelu(x @ W0), shaped [..., hiddenDim].
	gateValue := DotGeneral(x, []int{-1}, nil, gatingWeights0, []int{0}, nil)
	gateValue = activations.Gelu(gateValue)

	upProjection := DotGeneral(x, []int{-1}, nil, gatingWeights1, []int{0}, nil)
	upProjection = Mul(gateValue, upProjection) // Gate upProjection.

	// Down-projection ("linear") back to featuresDim.
	downProjectionVar := ctx.WithInitializer(initializers.Zero).
		VariableWithShape("linear", shapes.Make(x.DType(), hiddenDim, featuresDim))
	downProjectionWeights := downProjectionVar.ValueGraph(g)
	output := DotGeneral(upProjection, []int{-1}, nil, downProjectionWeights, []int{0}, nil)
	return output
}
136 |
// HuggingFaceGatedFeedForward layer for Gemma, the HuggingFace version, with transposed weights:
//   - hiddenDim: one intermediary layer.
//   - transposeGatingEinsum: for some versions of Gemma, the gating (hidden) weights have the axes transposed.
//   - It uses Gelu as activation function for the gating signal (multiplied by the up-projected values).
//
// Unlike GatedFeedForward, the gate/up/down weights live in three separate variables
// ("gating_proj", "up_proj", "down_proj"), stored transposed as [outputDim, inputDim],
// under an extra "hf" context scope.
//
// NOTE(review): transposeGatingEinsum is accepted but never used in this variant --
// confirm whether that is intentional (kept for signature symmetry with GatedFeedForward).
func HuggingFaceGatedFeedForward(ctx *context.Context, x *Node, hiddenDim int, transposeGatingEinsum bool) *Node {
	ctx = ctx.In("hf") // extra-scope for HuggingFace version.
	g := x.Graph()
	featuresDim := x.Shape().Dim(-1)

	gatingProjVar := ctx.WithInitializer(initializers.Zero).
		VariableWithShape("gating_proj", shapes.Make(x.DType(), hiddenDim, featuresDim))
	gatingWeights := gatingProjVar.ValueGraph(g)
	upProjectionVar := ctx.WithInitializer(initializers.Zero).
		VariableWithShape("up_proj", shapes.Make(x.DType(), hiddenDim, featuresDim))
	upProjectionWeights := upProjectionVar.ValueGraph(g)

	// Contract x's features axis against axis 1 (the input axis) of the transposed weights.
	gateValue := DotGeneral(x, []int{-1}, nil, gatingWeights, []int{1}, nil)
	gateValue = activations.Gelu(gateValue)

	upProjection := DotGeneral(x, []int{-1}, nil, upProjectionWeights, []int{1}, nil)
	upProjection = Mul(gateValue, upProjection) // Gate upProjection.

	// Down-projection back to featuresDim, also stored transposed ([featuresDim, hiddenDim]).
	downProjectionVar := ctx.WithInitializer(initializers.Zero).
		VariableWithShape("down_proj", shapes.Make(x.DType(), featuresDim, hiddenDim))
	downProjectionWeights := downProjectionVar.ValueGraph(g)
	output := DotGeneral(upProjection, []int{-1}, nil, downProjectionWeights, []int{1}, nil)
	return output
}
165 |
--------------------------------------------------------------------------------
/transformers/querypreattentionnormalisationtype_enumer.go:
--------------------------------------------------------------------------------
1 | // Code generated by "enumer -type=QueryPreAttentionNormalisationType -trimprefix=QueryNormType -transform=snake -values -text -json -yaml config.go"; DO NOT EDIT.
2 |
3 | package transformers
4 |
5 | import (
6 | "encoding/json"
7 | "fmt"
8 | "strings"
9 | )
10 |
// NOTE(review): everything below is machine-generated by the enumer tool (see the
// //go:generate directive in config.go). To change it, edit the
// QueryPreAttentionNormalisationType enum in config.go and re-run `go generate` --
// do not hand-edit this file.

// _QueryPreAttentionNormalisationTypeName concatenates all snake_case enum names; the
// index array below holds the byte offset at which each name starts (plus the final end offset).
const _QueryPreAttentionNormalisationTypeName = "by_one_over_sqrt_head_dimby_embed_dim_div_num_headsby_one_over_sqrt_embed_dim_div_num_heads"

var _QueryPreAttentionNormalisationTypeIndex = [...]uint8{0, 25, 51, 91}

// Lower-cased copy used for case-insensitive lookup.
const _QueryPreAttentionNormalisationTypeLowerName = "by_one_over_sqrt_head_dimby_embed_dim_div_num_headsby_one_over_sqrt_embed_dim_div_num_heads"

func (i QueryPreAttentionNormalisationType) String() string {
	if i < 0 || i >= QueryPreAttentionNormalisationType(len(_QueryPreAttentionNormalisationTypeIndex)-1) {
		return fmt.Sprintf("QueryPreAttentionNormalisationType(%d)", i)
	}
	return _QueryPreAttentionNormalisationTypeName[_QueryPreAttentionNormalisationTypeIndex[i]:_QueryPreAttentionNormalisationTypeIndex[i+1]]
}

func (QueryPreAttentionNormalisationType) Values() []string {
	return QueryPreAttentionNormalisationTypeStrings()
}

// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _QueryPreAttentionNormalisationTypeNoOp() {
	var x [1]struct{}
	_ = x[QueryNormTypeByOneOverSqrtHeadDim-(0)]
	_ = x[QueryNormTypeByEmbedDimDivNumHeads-(1)]
	_ = x[QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads-(2)]
}

var _QueryPreAttentionNormalisationTypeValues = []QueryPreAttentionNormalisationType{QueryNormTypeByOneOverSqrtHeadDim, QueryNormTypeByEmbedDimDivNumHeads, QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads}

var _QueryPreAttentionNormalisationTypeNameToValueMap = map[string]QueryPreAttentionNormalisationType{
	_QueryPreAttentionNormalisationTypeName[0:25]:       QueryNormTypeByOneOverSqrtHeadDim,
	_QueryPreAttentionNormalisationTypeLowerName[0:25]:  QueryNormTypeByOneOverSqrtHeadDim,
	_QueryPreAttentionNormalisationTypeName[25:51]:      QueryNormTypeByEmbedDimDivNumHeads,
	_QueryPreAttentionNormalisationTypeLowerName[25:51]: QueryNormTypeByEmbedDimDivNumHeads,
	_QueryPreAttentionNormalisationTypeName[51:91]:      QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads,
	_QueryPreAttentionNormalisationTypeLowerName[51:91]: QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads,
}

var _QueryPreAttentionNormalisationTypeNames = []string{
	_QueryPreAttentionNormalisationTypeName[0:25],
	_QueryPreAttentionNormalisationTypeName[25:51],
	_QueryPreAttentionNormalisationTypeName[51:91],
}

// QueryPreAttentionNormalisationTypeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func QueryPreAttentionNormalisationTypeString(s string) (QueryPreAttentionNormalisationType, error) {
	if val, ok := _QueryPreAttentionNormalisationTypeNameToValueMap[s]; ok {
		return val, nil
	}

	if val, ok := _QueryPreAttentionNormalisationTypeNameToValueMap[strings.ToLower(s)]; ok {
		return val, nil
	}
	return 0, fmt.Errorf("%s does not belong to QueryPreAttentionNormalisationType values", s)
}

// QueryPreAttentionNormalisationTypeValues returns all values of the enum
func QueryPreAttentionNormalisationTypeValues() []QueryPreAttentionNormalisationType {
	return _QueryPreAttentionNormalisationTypeValues
}

// QueryPreAttentionNormalisationTypeStrings returns a slice of all String values of the enum
func QueryPreAttentionNormalisationTypeStrings() []string {
	strs := make([]string, len(_QueryPreAttentionNormalisationTypeNames))
	copy(strs, _QueryPreAttentionNormalisationTypeNames)
	return strs
}

// IsAQueryPreAttentionNormalisationType returns "true" if the value is listed in the enum definition. "false" otherwise
func (i QueryPreAttentionNormalisationType) IsAQueryPreAttentionNormalisationType() bool {
	for _, v := range _QueryPreAttentionNormalisationTypeValues {
		if i == v {
			return true
		}
	}
	return false
}

// MarshalJSON implements the json.Marshaler interface for QueryPreAttentionNormalisationType
func (i QueryPreAttentionNormalisationType) MarshalJSON() ([]byte, error) {
	return json.Marshal(i.String())
}

// UnmarshalJSON implements the json.Unmarshaler interface for QueryPreAttentionNormalisationType
func (i *QueryPreAttentionNormalisationType) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return fmt.Errorf("QueryPreAttentionNormalisationType should be a string, got %s", data)
	}

	var err error
	*i, err = QueryPreAttentionNormalisationTypeString(s)
	return err
}

// MarshalText implements the encoding.TextMarshaler interface for QueryPreAttentionNormalisationType
func (i QueryPreAttentionNormalisationType) MarshalText() ([]byte, error) {
	return []byte(i.String()), nil
}

// UnmarshalText implements the encoding.TextUnmarshaler interface for QueryPreAttentionNormalisationType
func (i *QueryPreAttentionNormalisationType) UnmarshalText(text []byte) error {
	var err error
	*i, err = QueryPreAttentionNormalisationTypeString(string(text))
	return err
}

// MarshalYAML implements a YAML Marshaler for QueryPreAttentionNormalisationType
func (i QueryPreAttentionNormalisationType) MarshalYAML() (interface{}, error) {
	return i.String(), nil
}

// UnmarshalYAML implements a YAML Unmarshaler for QueryPreAttentionNormalisationType
func (i *QueryPreAttentionNormalisationType) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var s string
	if err := unmarshal(&s); err != nil {
		return err
	}

	var err error
	*i, err = QueryPreAttentionNormalisationTypeString(s)
	return err
}
134 |
--------------------------------------------------------------------------------
/transformers/transformers.go:
--------------------------------------------------------------------------------
// Package transformers implements the various Gemma models.
2 | // It is based on https://github.com/google-deepmind/gemma/blob/main/gemma/transformer.py
3 | package transformers
4 |
5 | import (
6 | "fmt"
7 | "github.com/gomlx/gemma/trees"
8 | . "github.com/gomlx/gomlx/graph"
9 | "github.com/gomlx/gomlx/ml/context"
10 | "github.com/gomlx/gomlx/types/shapes"
11 | "github.com/gomlx/gopjrt/dtypes"
12 | )
13 |
// GemmaWithCache creates a forward path on a Gemma model for one decoding step,
// using the weights in Config to initialize the variables.
//
// It takes as input the current tokens to decode, currentTokens (shaped [batchSize, seqLength],
// typically seqLength==1 during generation), along with currentPositions (shaped like
// currentTokens) and the current cache of the key/values for each transformer
// layer (see Cache), whose elements are generally shaped [batchSize, MaxCacheLength, ...].
//
// It updates the Cache with the new step in-place, and returns the logits
// (shaped [batchSize, seqLength, VocabularySize]) of the prediction of the next token.
func GemmaWithCache(ctx *context.Context, config *Config,
	currentTokens, currentPositions *Node, cache *trees.Tree[*Node], cacheAttentionMask *Node) *Node {
	batchSize := currentTokens.Shape().Dim(0)
	seqLength := currentTokens.Shape().Dim(1)

	// Embed.
	x := EmbedTokens(ctx.In("embedder"), config, currentTokens)

	// Run through numLayers blocks; each block reads/updates its own cache entry ("layer_<n>").
	for blockIdx := range config.NumLayers {
		blockName := fmt.Sprintf("layer_%d", blockIdx)
		blockCtx := ctx.In(blockName)
		blockCache := cache.Map[blockName]
		x = Block(blockCtx, config, blockIdx, x, currentPositions, blockCache, cacheAttentionMask)
		//x.SetLogged(fmt.Sprintf("GemmaWithCache::x(%s)", blockName))
		// NOTE(review): Identity appears to be a no-op anchor kept so the SetLogged
		// debugging line above is easy to re-enable -- confirm before removing.
		x = Identity(x)
	}

	// Final normalization, then decode back to token logits -- ctx.Reuse() shares the
	// "embedder/input_embedding" variable with EmbedTokens (tied embedding weights).
	x = RMSNorm(ctx.In("final_norm"), x)
	logits := DecodeTokens(ctx.Reuse().In("embedder"), config, x)
	logits = SoftCap(logits, config.FinalLogitSoftCap)
	logits.AssertDims(batchSize, seqLength, config.VocabularySize)
	return logits
}
47 |
// EmbedTokens using weights in Config.
// Input: currentTokens: [batchSize, sequenceLength]
// Output: embeddings: [batchSize, sequenceLength, config.EmbedDim]
func EmbedTokens(ctx *context.Context, config *Config, currentTokens *Node) *Node {
	g := currentTokens.Graph()
	// NOTE(review): the table dtype is hard-coded to BFloat16 even though Config carries a
	// DType read from the checkpoint -- confirm whether non-BF16 checkpoints would need config.DType here.
	embedTableVar := ctx.VariableWithShape("input_embedding", shapes.Make(dtypes.BFloat16, config.VocabularySize, config.EmbedDim))
	// Gather one embedding row per token id.
	embeddings := Gather(embedTableVar.ValueGraph(g), ExpandAxes(currentTokens, -1))
	// Scale embeddings by sqrt(EmbedDim).
	embeddings = Mul(embeddings, Sqrt(Scalar(g, embeddings.DType(), config.EmbedDim)))
	return embeddings
}
58 |
// DecodeTokens use the same table as EmbedTokens to convert embedding back to the tokens -- or to token logits
// (i.e., the output projection is tied to the embedding weights).
// Input: current embeddings: [batchSize, sequenceLength, embedDim]
// Output: logits for each token: [batchSize, sequenceLength, vocabularySize]
func DecodeTokens(ctx *context.Context, config *Config, x *Node) *Node {
	g := x.Graph()
	// NOTE(review): dtype hard-coded to BFloat16, mirroring EmbedTokens -- see the note there.
	embedTableVar := ctx.VariableWithShape("input_embedding", shapes.Make(dtypes.BFloat16, config.VocabularySize, config.EmbedDim))
	embedTable := embedTableVar.ValueGraph(g)
	// Contract x's embedding axis against the table's embedding axis (its last axis).
	return DotGeneral(x, []int{-1}, nil, embedTable, []int{-1}, nil)
}
68 |
// Block implements one transformer block for the Gemma model. x holds the current embeddings,
// shaped [batchSize, sequenceLength, EmbedDim]; if using cache (cache != nil), x will only
// contain the current step, with sequenceLength == 1.
//
// The attentionIdx indexes attention configuration (in config) parameters, like config.AttentionTypes.
//
// If cache is given, attentionMask is relative to the cache. Otherwise, attentionMask is relative to the operand x.
//
// Structure: pre-norm -> attention -> (optional post-norm) -> residual ->
// pre-norm -> gated FFW -> (optional post-norm) -> residual.
func Block(ctx *context.Context, config *Config, attentionIdx int, x, positions *Node, cache *trees.Tree[*Node], attentionMask *Node) *Node {
	normalizedX := RMSNorm(ctx.In("pre_attention_norm"), x)

	// Attention
	attentionOut := Attention(ctx.In("attn"), config, attentionIdx, normalizedX, positions, cache, attentionMask)
	if config.UsePostAttentionNorm {
		attentionOut = RMSNorm(ctx.In("post_attention_norm"), attentionOut)
	}

	// Residual (or skip) connection.
	attentionOut = Add(attentionOut, x)

	// GatedFeedForward ("ffw") layer: 2 layers, with a gate.
	output := RMSNorm(ctx.In("pre_ffw_norm"), attentionOut)
	if config.HuggingFaceVersion {
		// HuggingFace checkpoints store the FFW weights under different variable names/layouts.
		output = HuggingFaceGatedFeedForward(ctx.In("mlp"), output, config.HiddenDim, config.TransposeGatingEinsum)
	} else {
		output = GatedFeedForward(ctx.In("mlp"), output, config.HiddenDim, config.TransposeGatingEinsum)
	}
	if config.UsePostFFWNorm {
		output = RMSNorm(ctx.In("post_ffw_norm"), output)
	}

	// Residual to attentionOut.
	output = Add(output, attentionOut)
	return output
}
102 |
--------------------------------------------------------------------------------
/trees/trees.go:
--------------------------------------------------------------------------------
1 | package trees
2 |
3 | import (
4 | "fmt"
5 | "github.com/gomlx/exceptions"
6 | "github.com/gomlx/gomlx/types/xslices"
7 | "github.com/pkg/errors"
8 | "golang.org/x/exp/slices"
9 | "iter"
10 | "strings"
11 | )
12 |
// Tree represent both a root of a tree, and a tree node.
//
// It can either be a Value or a Map of its children -- but not both.
type Tree[T any] struct {
	// Value is set for leaf nodes only.
	Value T

	// Map is set for non-leaf nodes (and nil in leaf nodes); it maps a path component
	// to the corresponding child sub-tree.
	Map map[string]*Tree[T]
}
23 |
// IsLeaf reports whether n is a leaf (value-holding) node: leaves have a nil Map.
// NOTE(review): calling it on a nil *Tree panics (nil pointer dereference).
func (n *Tree[T]) IsLeaf() bool { return n.Map == nil }
25 |
// Path is usually used as the path from the root node.
// Each element is one level of the tree (a key into Tree.Map).
type Path []string
28 |
29 | // New creates a new empty tree.
30 | func New[T any]() *Tree[T] {
31 | return &Tree[T]{
32 | Map: make(map[string]*Tree[T]),
33 | }
34 | }
35 |
36 | // NewLeaf creates a new leaf node with the given value.
37 | func NewLeaf[T any](value T) *Tree[T] {
38 | return &Tree[T]{Value: value}
39 | }
40 |
// DefaultTreePath is used whenever an empty treePath is given.
// NOTE(review): it is not referenced by Get/Set in this file -- presumably consumed by callers elsewhere.
var DefaultTreePath = []string{"#root"}
43 |
44 | // Get value in treePath.
45 | // It returns an error if such a leaf node doesn't exist.
46 | //
47 | // Empty values in treePath are not used.
48 | func (tree *Tree[T]) Get(treePath ...string) (value T, err error) {
49 | // Remove empty ("") path components -- clone the slice, not to modify caller's slice.
50 | if slices.Index(treePath, "") > 0 {
51 | treePath = slices.DeleteFunc(slices.Clone(treePath),
52 | func(s string) bool {
53 | return s == ""
54 | })
55 | }
56 | remainingPath := treePath
57 | pathCount := 0
58 | for len(remainingPath) > 0 {
59 | if tree == nil {
60 | err = errors.Errorf("trees.Tree[%T].Get(%q) can't get to sub-path %q, tree ends in a nil node, can't go forward",
61 | value, treePath, treePath[:pathCount+1])
62 | }
63 | if tree.IsLeaf() {
64 | err = errors.Errorf("trees.Tree[%T].Get(%q) the sub-path %q ends on a leaf-node, can't go forward",
65 | value, treePath, treePath[:pathCount+1])
66 | return
67 | }
68 | tree = tree.Map[remainingPath[0]]
69 | remainingPath = remainingPath[1:]
70 | pathCount++
71 | }
72 | if tree == nil {
73 | err = errors.Errorf("trees.Tree[%T].Get(%q) can't get to sub-path %q, tree ends in a nil node, can't go forward",
74 | value, treePath, treePath[:pathCount+1])
75 | }
76 | if !tree.IsLeaf() {
77 | err = errors.Errorf("trees.Tree[%T].Get(%q) is not a leaf-node!?", value, treePath)
78 | return
79 | }
80 | value = tree.Value
81 | return
82 | }
83 |
84 | // Set value in treePath, populating intermediary nodes where needed.
85 | //
86 | // Empty values in treePath are not used.
87 | //
88 | // It returns an error if one is trying to set the value to an existing non-leaf node: nodes can either
89 | // be a leaf or a Map (non-leaf), but not both.
90 | func (tree *Tree[T]) Set(treePath Path, value T) error {
91 | // Remove empty ("") path components -- clone the slice, not to modify caller's slice.
92 | if slices.Index(treePath, "") > 0 {
93 | treePath = slices.DeleteFunc(slices.Clone(treePath),
94 | func(s string) bool {
95 | return s == ""
96 | })
97 | }
98 | remainingPath := treePath
99 | pathCount := 0
100 | node := tree
101 | for len(remainingPath) > 0 {
102 | pathElement := remainingPath[0]
103 | remainingPath = remainingPath[1:]
104 | if node.IsLeaf() {
105 | var t T
106 | return errors.Errorf("trees.Tree[%T].Set(%q) trying to create a path using an existing leaf node (%q) as a non-leaf node",
107 | t, treePath, treePath[:pathCount])
108 | }
109 | newNode := node.Map[pathElement]
110 | if newNode == nil {
111 | if len(remainingPath) == 0 {
112 | newNode = NewLeaf[T](value)
113 | } else {
114 | newNode = New[T]()
115 | }
116 | node.Map[pathElement] = newNode
117 | }
118 | node = newNode
119 | pathCount++
120 | }
121 | if !node.IsLeaf() {
122 | var t T
123 | return errors.Errorf("trees.Tree[%T].Set(%q) trying to set the value to a non-leaf node -- each node can either be a leaf node, or be a structural map of the tree",
124 | t, treePath)
125 | }
126 | node.Value = value
127 | return nil
128 | }
129 |
130 | // String implements fmt.String
131 | func (tree *Tree[T]) String() string {
132 | var parts []string
133 | parts = nodeToString(parts, "/", tree, 0)
134 | return strings.Join(parts, "\n") + "\n"
135 | }
136 |
137 | func nodeToString[T any](parts []string, name string, subTree *Tree[T], indent int) []string {
138 | indentSpaces := strings.Repeat(" ", indent)
139 | indent++
140 | if len(subTree.Map) == 0 {
141 | // Leaf node.
142 | var valueAny any
143 | valueAny = subTree.Value
144 | if valueStr, ok := valueAny.(fmt.Stringer); ok {
145 | // T is a stringer:
146 | return append(parts, fmt.Sprintf("%s%q: %s", indentSpaces, name, valueStr))
147 | }
148 | // If not a stringer, use %v.
149 | return append(parts, fmt.Sprintf("%s%q: %v", indentSpaces, name, subTree.Value))
150 | }
151 | parts = append(parts, fmt.Sprintf("%s%q: {", indentSpaces, name))
152 |
153 | for _, key := range xslices.SortedKeys(subTree.Map) {
154 | parts = nodeToString(parts, key, subTree.Map[key], indent)
155 | }
156 | parts = append(parts, fmt.Sprintf("%s}", indentSpaces))
157 | return parts
158 | }
159 |
160 | // Map converts a Tree[T1] to a Tree[T2] by calling mapFn at every element.
161 | func Map[T1, T2 any](tree1 *Tree[T1], mapFn func(Path, T1) T2) *Tree[T2] {
162 | if tree1.IsLeaf() {
163 | // Input tree1 is just a leaf node:
164 | return NewLeaf[T2](mapFn(nil, tree1.Value))
165 | }
166 |
167 | tree2 := New[T2]()
168 | for p, t1 := range tree1.Leaves() {
169 | err := tree2.Set(p, mapFn(p, t1))
170 | if err != nil {
171 | // Should never happen, since there can be no errors duplicating the structure of an existing valid tree.
172 | panic(err)
173 | }
174 | }
175 | return tree2
176 | }
177 |
178 | // Leaves returns an iterator that goes over all the leaf nodes of the Tree.
179 | // The key is a Path, and value is T.
180 | func (tree *Tree[T]) Leaves() iter.Seq2[Path, T] {
181 | return func(yield func(Path, T) bool) {
182 | recursiveLeaves(nil, tree, false, yield)
183 | }
184 | }
185 |
186 | // NumLeaves traverses the trees and returns the number of leaf nodes.
187 | func (tree *Tree[T]) NumLeaves() int {
188 | var count int
189 | for _, _ = range tree.Leaves() {
190 | count++
191 | }
192 | return count
193 | }
194 |
195 | // OrderedLeaves returns an iterator that goes over all the leaf nodes of the Tree in alphabetical order of the
196 | // tree nodes (depth-first).
197 | //
198 | // The key is a Path, and value is T.
199 | func (tree *Tree[T]) OrderedLeaves() iter.Seq2[Path, T] {
200 | return func(yield func(Path, T) bool) {
201 | recursiveLeaves(nil, tree, true, yield)
202 | }
203 | }
204 |
205 | func recursiveLeaves[T any](treePath Path, node *Tree[T], ordered bool, yield func(Path, T) bool) bool {
206 | if node.IsLeaf() {
207 | return yield(slices.Clone(treePath), node.Value)
208 | }
209 | if ordered {
210 | // Extract keys and sort first.
211 | for _, key := range xslices.SortedKeys(node.Map) {
212 | subNode := node.Map[key]
213 | ok := recursiveLeaves[T](append(treePath, key), subNode, ordered, yield)
214 | if !ok {
215 | return false
216 | }
217 | }
218 | } else {
219 | // Usual range over map, non-deterministic.
220 | for key, subNode := range node.Map {
221 | ok := recursiveLeaves(append(treePath, key), subNode, ordered, yield)
222 | if !ok {
223 | return false
224 | }
225 | }
226 | }
227 | return true
228 | }
229 |
230 | // ValuesAsList extracts the leaf values of Tree into a list.
231 | //
232 | // It's generated in alphabetical order -- see OrderedLeaves to see or generate the order.
233 | func ValuesAsList[T any](tree *Tree[T]) []T {
234 | results := make([]T, 0, tree.NumLeaves())
235 | for _, values := range tree.OrderedLeaves() {
236 | results = append(results, values)
237 | }
238 | return results
239 | }
240 |
241 | // FromValuesAndTree creates a Tree[T1] with the given values, but borrowing the structure from the given tree (but
242 | // ignoring the tree's values).
243 | func FromValuesAndTree[T1, T2 any](values []T1, tree *Tree[T2]) *Tree[T1] {
244 | numLeaves := tree.NumLeaves()
245 | if len(values) != numLeaves {
246 | exceptions.Panicf("%d values given, but the tree to be built has %d leaves.", len(values), numLeaves)
247 | }
248 | newTree := New[T1]()
249 | var idx int
250 | for treePath, _ := range tree.OrderedLeaves() {
251 | err := newTree.Set(treePath, values[idx])
252 | if err != nil {
253 | // Should never happen, since there can be no errors duplicating the structure of an existing valid tree.
254 | panic(err)
255 | }
256 | idx++
257 | }
258 | return newTree
259 | }
260 |
--------------------------------------------------------------------------------
/trees/trees_test.go:
--------------------------------------------------------------------------------
1 | package trees
2 |
3 | import (
4 | "fmt"
5 | "github.com/stretchr/testify/require"
6 | "testing"
7 | )
8 |
// expectedTreeValueType pairs a tree path with the value expected at that leaf.
type expectedTreeValueType[T any] struct {
	p Path // path to the leaf node.
	v T    // value expected at that leaf.
}
13 |
14 | func verifyTreeValues[T any](t *testing.T, tree *Tree[T], wantValues []expectedTreeValueType[T]) {
15 | for _, want := range wantValues {
16 | got, err := tree.Get(want.p...)
17 | require.NoError(t, err)
18 | require.Equal(t, want.v, got)
19 | }
20 | count := 0
21 | for p, v := range tree.OrderedLeaves() {
22 | if count >= len(wantValues) {
23 | t.Fatalf("tree ranged over more leaves than the %d expected", len(wantValues))
24 | }
25 | require.Equalf(t, wantValues[count].p, p, "Unexpected path %q -- maybe out-of-order?", p)
26 | require.Equalf(t, wantValues[count].v, v, "Unexpected value for path %q", p)
27 | count++
28 | }
29 | if count != len(wantValues) {
30 | t.Fatalf("tree only ranged over %d leaf-values, but we expected %d values", count, len(wantValues))
31 | }
32 | }
33 |
34 | func createTestTree(t *testing.T) *Tree[int] {
35 | tree := New[int]()
36 | require.NoError(t, tree.Set([]string{"a"}, 1))
37 | require.NoError(t, tree.Set([]string{"b", "y"}, 3))
38 | require.NoError(t, tree.Set([]string{"b", "x"}, 2))
39 | return tree
40 | }
41 |
42 | func TestNewAndSet(t *testing.T) {
43 | tree := createTestTree(t)
44 | fmt.Printf("Tree:\n%v\n", tree)
45 |
46 | require.Equal(t, 1, tree.Map["a"].Value)
47 | require.Equal(t, 2, tree.Map["b"].Map["x"].Value)
48 | require.Equal(t, 3, tree.Map["b"].Map["y"].Value)
49 |
50 | err := tree.Set([]string{"b"}, 4)
51 | fmt.Printf("\texpected error trying to set non-leaf node: %v\n", err)
52 | require.ErrorContains(t, err, "trying to set the value to a non-leaf node")
53 |
54 | err = tree.Set([]string{"b", "x", "0"}, 5)
55 | fmt.Printf("\texpected error trying to use leaf node as structure: %v\n", err)
56 | require.ErrorContains(t, err, "trying to create a path using an existing leaf node")
57 |
58 | tree2 := NewLeaf(float32(7))
59 | fmt.Printf("Tree:\n%v\n", tree2)
60 | require.NoError(t, tree2.Set(nil, float32(11)))
61 | require.Equal(t, float32(11), tree2.Value)
62 | }
63 |
64 | func TestOrderedLeaves(t *testing.T) {
65 | tree := createTestTree(t)
66 | fmt.Printf("Tree:\n%v\n", tree)
67 | // Test OrderedLeaves traversal and that the contents of the tree match.
68 | verifyTreeValues(t, tree, []expectedTreeValueType[int]{
69 | {Path{"a"}, 1},
70 | {Path{"b", "x"}, 2},
71 | {Path{"b", "y"}, 3},
72 | })
73 | }
74 |
75 | func TestMap(t *testing.T) {
76 | tree := createTestTree(t)
77 | fmt.Printf("Tree:\n%v\n", tree)
78 | treeFloat := Map(tree, func(_ Path, v int) float32 { return float32(v) })
79 | verifyTreeValues(t, treeFloat, []expectedTreeValueType[float32]{
80 | {Path{"a"}, 1},
81 | {Path{"b", "x"}, 2},
82 | {Path{"b", "y"}, 3},
83 | })
84 |
85 | tree2 := NewLeaf(float32(7))
86 | fmt.Printf("Tree:\n%v\n", tree2)
87 | tree2Int := Map(tree2, func(_ Path, v float32) int { return int(v) })
88 | verifyTreeValues(t, tree2Int, []expectedTreeValueType[int]{
89 | {nil, 7},
90 | })
91 | }
92 |
93 | func TestValuesAsList(t *testing.T) {
94 | tree := createTestTree(t)
95 | fmt.Printf("Tree:\n%v\n", tree)
96 | require.Equal(t, []int{1, 2, 3}, ValuesAsList(tree))
97 |
98 | tree2 := NewLeaf(float32(7))
99 | fmt.Printf("Tree:\n%v\n", tree2)
100 | require.Equal(t, []float32{7}, ValuesAsList(tree2))
101 | }
102 |
103 | func TestFromValuesAndTree(t *testing.T) {
104 | tree := createTestTree(t)
105 | newValues := []float64{1.01, 2.02, 3.03}
106 | newTree := FromValuesAndTree(newValues, tree)
107 | fmt.Printf("New Tree:\n%v\n", newTree)
108 | verifyTreeValues(t, newTree, []expectedTreeValueType[float64]{
109 | {Path{"a"}, 1.01},
110 | {Path{"b", "x"}, 2.02},
111 | {Path{"b", "y"}, 3.03},
112 | })
113 | }
114 |
--------------------------------------------------------------------------------