├── .gitignore ├── LICENSE ├── README.md ├── cmd ├── convert_checkpoint.py └── gemma_demo │ ├── generator.go │ ├── main.go │ └── ui.go ├── download ├── README.md ├── huggingface │ └── huggingface.go └── kaggle │ ├── metadata.go │ └── weights.go ├── go.mod ├── go.sum ├── samplers └── samplers.go ├── sentencepiece └── sentencepiece.go ├── transformers ├── attention.go ├── attentiontype_enumer.go ├── cache.go ├── config.go ├── gemmatype_enumer.go ├── layers.go ├── querypreattentionnormalisationtype_enumer.go └── transformers.go └── trees ├── trees.go └── trees_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | go.work.sum 23 | 24 | # env file 25 | .env 26 | 27 | # GoLand IDE 28 | .idea 29 | 30 | # Python script, local venv. 31 | venv -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | GoMLX Gopher 3 | 4 | # GoMLX Gemma 5 | 6 | GoMLX (for Go) port of Google Deepmind's Gemma GenAI/LLM model. 7 | 8 | ## 📖 About GoMLX Gemma 9 | 10 | An implementation of [Google DeepMind](deepmind.google)'s [Gemma model](https://github.com/google-deepmind/gemma?tab=readme-ov-file) 11 | using [GoMLX, a Machine Learning framework for Go](https://github.com/gomlx/gomlx). 12 | 13 | It is very "_fresh from the oven_", so use it at your own risk. 14 | At the same time, I'm happy to help if you need any specific features, it's a good time for feature requests. 15 | 16 | ## ✅ **What is done** already: 17 | 18 | * **Sampling** / **Generating**: it provides the `samplers.Sampler` object to easily generate text. 19 | See example below, or `cmd/gemma_demo/generator.go` for an example. 20 | * HuggingFace Weights Version: 21 | * Download weights from HuggingFace, using provided AuthToken -- a read-only token will suffice. 
22 | * Kaggle Version 23 | * Requires manually downloading weights from Kaggle. 24 | * Use provided `cmd/convert_checkpoint.py` script to convert Jax weights -- requires Python installation. 25 | * A command-line demo `cmd/gemma_demo`, with a simple [Charm](https://charm.sh/) interface. 26 | 27 | ## ❌ **Not done** yet: 28 | 29 | * **Fine-tuning**: the model is there, and it just needs some wiring together. But there is no sample code yet. 30 | 31 | ## ⌨️ Sample Code 32 | 33 | This is an example of how a `Sampler` object is created (for the simpler HuggingFace version) -- it requires the 34 | HuggingFace token (read-only) used to download to be set in HF_TOKEN -- go to HuggingFace webpage to generate one for you. 35 | 36 | ```go 37 | package main 38 | 39 | import ( 40 | ... 41 | 42 | hfd "github.com/gomlx/gemma/download/huggingface" 43 | "github.com/gomlx/gemma/samplers" 44 | "github.com/gomlx/gomlx/backends" 45 | "github.com/gomlx/gomlx/ml/context" 46 | 47 | _ "github.com/gomlx/gomlx/backends/xla" 48 | ) 49 | 50 | var ( 51 | flagModelID = flag.String("model", "google/gemma-2-2b-it", "HuggingFace Gemma model id") 52 | flagDataDir = flag.String("data", "~/work/gemma", "Directory to cache downloaded dataset files.") 53 | ) 54 | 55 | func main() { 56 | flag.Parse() 57 | prompts := []string{ 58 | "What is 1+1 ?", 59 | "What are the planets of the solar system?", 60 | "```\n// BubbleSort is a Go function that sorts the Bubble Sort algorithm\nfunc BubbleSort[S ~[]E, E cmp.Ordered](x S) {\n", 61 | } 62 | ctx := context.New() 63 | vocab, err := hfd.Download(ctx, *flagModelID, os.Getenv("HF_TOKEN"), path.Join(*flagDataDir, "huggingface")) 64 | if err != nil { 65 | log.Fatalf("%+v", err) 66 | } 67 | sampler, err := samplers.New(backends.New(), ctx, vocab, 1024) 68 | if err != nil { 69 | log.Fatalf("%+v", err) 70 | } 71 | 72 | start := time.Now() 73 | output, err := sampler.Sample([]string{ 74 | "What is 1+1?", 75 | "What are the planets of the solar system?", 76 | // "// 
BubbleSort is a Go function that sorts the Bubble Sort algorithm\nfunc BubbleSort[S ~[]E, E cmp.Ordered](x S)", 77 | }) 78 | if err != nil { 79 | log.Fatalf("%+v", err) 80 | } 81 | fmt.Printf("\tElapsed time: %s\n", time.Since(start)) 82 | fmt.Printf("Generated text:\n%s\n", strings.Join(output, "\n\n")) 83 | } 84 | ``` 85 | 86 | ## 🔗 Resources 87 | 88 | 1. [**github.com/google-deepmind/gemma**](https://github.com/google-deepmind/gemma): 89 | [Gemma](https://ai.google.dev/gemma) is a family of open-weights Large Language Model (LLM) by [Google DeepMind](https://deepmind.google/), 90 | based on Gemini research and technology. 91 | 1. [github.com/eliben/go-sentencepiece](https://github.com/eliben/go-sentencepiece): 92 | This is a pure Go implementation of encoding and decoding text with the [SentencePiece tokenizer](https://github.com/google/sentencepiece). 93 | 94 | 95 | ## 📝 TODO 96 | 97 | * Remove special symbols from sampling, like "". 98 | * Fine-tuning demo. 99 | * Benchmarking: how does it compare to Jax implementation ? Jax JIT-compile the main sampling loop during generation, 100 | which could be done with GoMLX, but it would require implementing some new features. Not sure it is needed yet. 101 | * At least in an old NVidia RTX 2080Ti, it works with GoMLX, and Jax reference implementation fails to sample, 102 | because it tries to JIT-compile the full sampling loop. -------------------------------------------------------------------------------- /cmd/convert_checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Jax/Orbax Gemma Checkpoint to Raw Bytes Converter 6 | 7 | This script converts a Jax/Orbax Gemma model checkpoint into a directory structure 8 | containing raw byte representations of the model's parameter arrays, as well as their shape. 
9 | 10 | Usage: 11 | python convert_checkpoint.py <source_dir> [--target_dir <target_dir>] 12 | 13 | Arguments: 14 | source_dir: Path to the directory containing the Jax/Orbax checkpoint. 15 | target_dir: (Optional) Path to the directory where the raw bytes will be saved. 16 | Defaults to '<source_dir>/raw/' if not provided. 17 | 18 | It requires the following libraries installed, probably in a virtual environment (venv) or equivalent: 19 | 20 | ``` 21 | pip install jax "git+https://github.com/google-deepmind/gemma.git" 22 | ``` 23 | """ 24 | 25 | import argparse 26 | import os 27 | import jax 28 | from gemma import params as params_lib 29 | 30 | def read_parameter(path): 31 | """ 32 | Read model checkpoint parameters from path to directory. 33 | 34 | :param path: Path to directory holding the checkpoint to read the parameters from. 35 | :return: PyTree of jaxlib.xla_extension.ArrayImpl 36 | """ 37 | path = os.path.expanduser(path) 38 | return params_lib.load_and_format_params(path) 39 | 40 | 41 | def write_params(params, base_dir): 42 | """Write parameters to a structured directory: each file corresponds to one array written as raw bytes.""" 43 | base_dir = os.path.expanduser(base_dir) 44 | for path, array in flatten_params(params): 45 | base_file_path = os.path.join(base_dir, *path) 46 | 47 | # Create necessary directories 48 | os.makedirs(os.path.dirname(base_file_path), exist_ok=True) 49 | 50 | # Save array. 51 | with open(base_file_path+".raw", 'wb') as f: 52 | f.write(array.tobytes()) 53 | 54 | # Save shape. 
55 | with open(base_file_path+".shape", 'w') as f: 56 | f.write(serialize_shape(array)) 57 | 58 | 59 | def path_to_str_tuple(path): 60 | """Converts a PyTree path (tuple of jax.tree_util.DictKey) to a list of strings.""" 61 | return [e.key for e in path] 62 | 63 | 64 | def flatten_params(params): 65 | """Convert PyTree of arrays to a list of pairs of (path, array), where path is itself a list of strings.""" 66 | list = [] 67 | def append_to_list(path, value): 68 | list.append((path_to_str_tuple(path), value)) 69 | 70 | jax.tree_util.tree_map_with_path(append_to_list, params) 71 | return list 72 | 73 | 74 | def serialize_shape(array): 75 | """Return an encoding of the given array's shape (including dtype).""" 76 | return ",".join([f"{str(array.dtype)}"]+[str(i) for i in array.shape]) 77 | 78 | 79 | def main(): 80 | parser = argparse.ArgumentParser(description="Convert Gemma Jax/Orbax checkpoint to raw bytes.") 81 | parser.add_argument("source_dir", help="Path to the source directory containing the checkpoint.") 82 | parser.add_argument("--target_dir", help="Path to the target directory where the raw bytes will be saved. Defaults to source_dir + 'raw/' if not provided.") 83 | 84 | args = parser.parse_args() 85 | 86 | source_dir = os.path.abspath(os.path.expanduser(args.source_dir)) 87 | target_dir = os.path.abspath(os.path.expanduser(args.target_dir)) if args.target_dir else os.path.join(source_dir, "raw") 88 | 89 | print("Conversion from Jax/Orbax Gemma checkpoint to raw arrays/shapes:") 90 | print(f"\tSource directory: {source_dir}") 91 | print(f"\tTarget directory: {target_dir}") 92 | 93 | # We don't want to use GPU memory for this. 
94 | jax.config.update('jax_platform_name', 'cpu') 95 | 96 | params = read_parameter(source_dir) 97 | write_params(params, target_dir) 98 | 99 | 100 | if __name__ == "__main__": 101 | main() -------------------------------------------------------------------------------- /cmd/gemma_demo/generator.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | hfd "github.com/gomlx/gemma/download/huggingface" 6 | "github.com/gomlx/gemma/samplers" 7 | "github.com/gomlx/gomlx/backends" 8 | _ "github.com/gomlx/gomlx/backends/xla" 9 | "github.com/gomlx/gomlx/ml/context" 10 | "github.com/janpfeifer/must" 11 | "os" 12 | "path" 13 | ) 14 | 15 | var ( 16 | flagDataDir = flag.String("data", "~/work/gemma", "Directory to cache downloaded and generated dataset files.") 17 | flagModelID = flag.String("model", "google/gemma-2-2b-it", "HuggingFace Gemma model id") 18 | flagMaxGeneratedTokens = flag.Int("max_tokens", 1024, "Maximum number of tokens to generate.") 19 | ) 20 | 21 | func BuildSampler() *samplers.Sampler { 22 | ctx := context.New() 23 | vocab := must.M1(hfd.Download(ctx, *flagModelID, os.Getenv("HF_TOKEN"), path.Join(*flagDataDir, "huggingface"))) 24 | return must.M1(samplers.New(backends.New(), ctx, vocab, *flagMaxGeneratedTokens)) 25 | } 26 | -------------------------------------------------------------------------------- /cmd/gemma_demo/main.go: -------------------------------------------------------------------------------- 1 | // gemma_demo uses Gemma for GoMLX to generate text given a prompt. 2 | // 3 | // It also uses github.com/charmbracelet libraries to make for a pretty command-line UI. 
4 | package main 5 | 6 | import ( 7 | "flag" 8 | "fmt" 9 | tea "github.com/charmbracelet/bubbletea" 10 | "github.com/gomlx/exceptions" 11 | "os" 12 | ) 13 | 14 | func main() { 15 | flag.Parse() 16 | 17 | var p *tea.Program 18 | err := exceptions.TryCatch[error](func() { p = tea.NewProgram(newUIModel()) }) 19 | if err != nil { 20 | fmt.Fprintf(os.Stderr, "Alas, there's been an error: %+v", err) 21 | os.Exit(1) 22 | } 23 | _, err = p.Run() 24 | if err != nil { 25 | fmt.Fprintf(os.Stderr, "Alas, there's been an error: %+v", err) 26 | os.Exit(1) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /cmd/gemma_demo/ui.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "github.com/charmbracelet/bubbles/textarea" 6 | "github.com/charmbracelet/bubbles/viewport" 7 | tea "github.com/charmbracelet/bubbletea" 8 | "github.com/charmbracelet/lipgloss" 9 | "github.com/gomlx/gemma/samplers" 10 | ) 11 | 12 | type uiModel struct { 13 | textarea textarea.Model 14 | viewport viewport.Model 15 | submitted bool 16 | sampler *samplers.Sampler 17 | err error 18 | } 19 | 20 | func newUIModel() *uiModel { 21 | ta := textarea.New() 22 | ta.Placeholder = "Gemma Prompt:" 23 | ta.Focus() 24 | 25 | vp := viewport.New(0, 0) 26 | vp.Style = lipgloss.NewStyle().Margin(1, 2). 
27 | Border(lipgloss.NormalBorder()).BorderForeground(lipgloss.Color("99")) 28 | 29 | return &uiModel{ 30 | textarea: ta, 31 | viewport: vp, 32 | sampler: BuildSampler(), 33 | } 34 | } 35 | 36 | func (m *uiModel) Init() tea.Cmd { 37 | return textarea.Blink 38 | } 39 | 40 | func (m *uiModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { 41 | var ( 42 | taCmd tea.Cmd 43 | vpCmd tea.Cmd 44 | cmds []tea.Cmd 45 | resize bool 46 | ) 47 | 48 | switch msg := msg.(type) { 49 | case tea.KeyMsg: 50 | switch { 51 | case msg.Type == tea.KeyCtrlC || msg.Type == tea.KeyEsc: 52 | return m, tea.Quit 53 | case msg.Type == tea.KeyCtrlL: 54 | m.textarea.Reset() 55 | 56 | case msg.Type == tea.KeyCtrlD && !m.submitted: // Ctrl+Enter to submit 57 | m.submitted = true 58 | generatedContent, err := m.Generate() 59 | if err != nil { 60 | m.err = err 61 | return m, tea.Quit 62 | } 63 | m.viewport.SetContent(generatedContent) 64 | m.textarea.Blur() 65 | m.textarea.SetValue(generatedContent) 66 | 67 | case m.submitted && msg.Type == tea.KeyEnter: // Enter while submitted to edit 68 | m.submitted = false 69 | m.textarea.Focus() 70 | } 71 | 72 | case tea.WindowSizeMsg: 73 | resize = true 74 | m.viewport.Width = msg.Width 75 | m.viewport.Height = msg.Height - 3 // Account for textarea and margins 76 | m.textarea.SetWidth(msg.Width - 4) // Account for textarea margins 77 | m.textarea.SetHeight(msg.Height - 8) 78 | } 79 | 80 | m.textarea, taCmd = m.textarea.Update(msg) 81 | m.viewport, vpCmd = m.viewport.Update(msg) 82 | 83 | if resize { 84 | cmds = append(cmds, vpCmd) 85 | } 86 | 87 | return m, tea.Batch(append(cmds, taCmd)...) 
88 | } 89 | 90 | func (m *uiModel) Generate() (string, error) { 91 | outputs, err := m.sampler.Sample([]string{m.textarea.Value()}) 92 | if err != nil { 93 | return "", err 94 | } 95 | return outputs[0], nil 96 | } 97 | 98 | func (m *uiModel) View() string { 99 | if m.submitted { 100 | return fmt.Sprintf("\n%s\n\nPress Enter to edit...", m.viewport.View()) 101 | } 102 | 103 | return fmt.Sprintf( 104 | "\n%s\n\n"+ 105 | "\t\u2022 Ctrl+C or ESC to quit;\n"+ 106 | "\t• Ctrl+D to submit;\n"+ 107 | "\t• Ctrl+L to clear the prompt.\n", 108 | m.textarea.View(), 109 | ) 110 | } 111 | -------------------------------------------------------------------------------- /download/README.md: -------------------------------------------------------------------------------- 1 | # Download 2 | 3 | The Gemma models can be downloaded from Google/Kaggle site or from HuggingFace. 4 | 5 | The recommendation is to use the huggingface version, it has a couple of advantages: 6 | 7 | This has some advantages from downloading it from Google (Kaggle): 8 | 9 | - With a HuggingFace token, the process is automatic (no need to manually navigate to the site). 10 | - No need for conversion of the model, the library reads directly from the HuggingFace ".safetensors" format into 11 | GoMLX context. 12 | - No Python dependency. 13 | -------------------------------------------------------------------------------- /download/huggingface/huggingface.go: -------------------------------------------------------------------------------- 1 | // Package huggingface handles downloading Gemma model weights from HuggingFace. 2 | // 3 | // This has some advantages from downloading it from Google (Kaggle): 4 | // 5 | // - With a HuggingFace token, the process is automatic. 6 | // - No need for conversion of the model, the library reads directly from the HuggingFace ".safetensors" format into 7 | // GoMLX context. 8 | // - No Python dependency. 
9 | // 10 | // Example: 11 | package huggingface 12 | 13 | import ( 14 | "fmt" 15 | "github.com/gomlx/gemma/sentencepiece" 16 | "github.com/gomlx/gomlx/ml/context" 17 | "github.com/gomlx/gomlx/ml/data" 18 | gomlxhf "github.com/gomlx/gomlx/ml/data/huggingface" 19 | "github.com/gomlx/gomlx/types/xslices" 20 | "path" 21 | "strconv" 22 | "strings" 23 | ) 24 | 25 | // Download will download (if needed) the Gemma model identified by hfID (it's a HuggingFace model id, e.g.: "google/gemma-2-2b-it"), 26 | // and save under the cacheDir (for future reuse). 27 | // 28 | // The hfAuthToken is a HuggingFace token -- read-only access -- that needs to be created once in HuggingFace site. 29 | // 30 | // It loads the weights into the given context and creates a sentencepiece tokenizer (vocab) that is returned. 31 | // 32 | // An error is returned if something fails. 33 | func Download(ctx *context.Context, hfID, hfAuthToken, cacheDir string) (vocab *sentencepiece.Tokenizer, err error) { 34 | cacheDir = data.ReplaceTildeInDir(cacheDir) 35 | var hfm *gomlxhf.Model 36 | hfm, err = gomlxhf.New(hfID, hfAuthToken, cacheDir) 37 | if err != nil { 38 | return 39 | } 40 | err = hfm.Download() 41 | if err != nil { 42 | return 43 | } 44 | 45 | vocab, err = sentencepiece.NewFromPath(path.Join(hfm.BaseDir, "tokenizer.model")) 46 | if err != nil { 47 | return 48 | } 49 | 50 | for entry, err2 := range hfm.EnumerateTensors() { 51 | if err2 != nil { 52 | err = err2 53 | return 54 | } 55 | scopeAndName := convertHuggingFaceNameToScopeAndName(entry.Name) 56 | if len(scopeAndName) == 0 { 57 | fmt.Printf("Skipping: %s -> %s\n", entry.Name, entry.Tensor.Shape()) 58 | } else { 59 | ctxTmp := ctx.In("model") 60 | name, scope := xslices.Pop(scopeAndName) 61 | for _, p := range scope { 62 | ctxTmp = ctxTmp.In(p) 63 | } 64 | ctxTmp.VariableWithValue(name, entry.Tensor) 65 | } 66 | } 67 | return 68 | } 69 | 70 | func convertHuggingFaceNameToScopeAndName(name string) []string { 71 | if name == 
"model.embed_tokens.weight" { 72 | return []string{"embedder", "input_embedding"} 73 | } else if name == "model.norm.weight" { 74 | return []string{"final_norm", "scale"} 75 | } 76 | 77 | // Parse the layer number for the name prefixed as "model.layers.X.<...>" 78 | if strings.HasPrefix(name, "model.layers.") { 79 | parts := strings.Split(name, ".") 80 | if len(parts) < 5 || xslices.Last(parts) != "weight" { 81 | return nil 82 | } 83 | layerNumberStr := parts[2] 84 | layerNumber, err := strconv.Atoi(layerNumberStr) 85 | if err != nil { 86 | return nil 87 | } 88 | layerScope := fmt.Sprintf("layer_%d", layerNumber) 89 | switch parts[3] { 90 | case "input_layernorm": 91 | return append([]string{layerScope, "pre_attention_norm", "scale"}) 92 | case "post_attention_layernorm": 93 | return append([]string{layerScope, "post_attention_norm", "scale"}) 94 | case "post_feedforward_layernorm": 95 | return append([]string{layerScope, "post_ffw_norm", "scale"}) 96 | case "pre_feedforward_layernorm": 97 | return append([]string{layerScope, "pre_ffw_norm", "scale"}) 98 | case "mlp": 99 | // For the MLP (the GatedFeedForwardNetwork), the weights in HuggingFace are transposed/split differently, 100 | // so they take new variable names not matching those in Kaggle version. 
101 | switch parts[4] { 102 | case "down_proj": 103 | return append([]string{layerScope, "mlp", "hf", "down_proj"}) 104 | case "gate_proj": 105 | return append([]string{layerScope, "mlp", "hf", "gating_proj"}) 106 | case "up_proj": 107 | return append([]string{layerScope, "mlp", "hf", "up_proj"}) 108 | default: 109 | return nil 110 | } 111 | case "self_attn": 112 | return append([]string{layerScope, "attn", "hf", parts[4]}) 113 | default: 114 | return nil 115 | } 116 | } 117 | return nil 118 | } 119 | -------------------------------------------------------------------------------- /download/kaggle/metadata.go: -------------------------------------------------------------------------------- 1 | package kaggle 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "github.com/gomlx/gemma/trees" 7 | "github.com/gomlx/gomlx/ml/data" 8 | "github.com/pkg/errors" 9 | "os" 10 | "path" 11 | "regexp" 12 | "strings" 13 | ) 14 | 15 | const ( 16 | MetadataFileName = "_METADATA" 17 | KeyUseZarr3 = "use_zarr3" 18 | KeyTreeMetadata = "tree_metadata" 19 | KeyKeyMetadata = "key_metadata" 20 | KeyValueMetadata = "value_metadata" 21 | ) 22 | 23 | // ReadMetadata returns the metadata loaded from the given directory in the form of a tree. 
24 | func ReadMetadata(checkpointDir string) (tree *trees.Tree[*Metadata], err error) { 25 | checkpointDir = data.ReplaceTildeInDir(checkpointDir) 26 | metadataPath := path.Join(checkpointDir, MetadataFileName) 27 | var f *os.File 28 | f, err = os.Open(metadataPath) 29 | if err != nil { 30 | err = errors.Wrapf(err, "failed to read aggregate checkpoint file from %q", metadataPath) 31 | return 32 | } 33 | defer func() { _ = f.Close() }() 34 | 35 | dec := json.NewDecoder(f) 36 | var jsonTree any 37 | err = dec.Decode(&jsonTree) 38 | if err != nil { 39 | return 40 | } 41 | tree, err = fromJsonTreeMetaData(jsonTree) 42 | return 43 | } 44 | 45 | // Metadata of one checkpoint entry (usually weights or embedding tables of model) 46 | type Metadata struct { 47 | KeyPath []string 48 | 49 | KeyToKeyType map[string]MetadataKeyType 50 | ValueType string 51 | SkipDeserialize bool 52 | } 53 | 54 | // String implements fmt.Stringer. 55 | func (m *Metadata) String() string { 56 | deserialize := "" 57 | if m.SkipDeserialize { 58 | deserialize = " [*]" 59 | } 60 | return fmt.Sprintf("%s%s", m.ValueType, deserialize) 61 | } 62 | 63 | type MetadataKeyType int 64 | 65 | const ( 66 | KeyTypeSequence MetadataKeyType = 1 67 | KeyTypeDict = 2 68 | ) 69 | 70 | func fromJsonTreeMetaData(jsonTree any) (tree *trees.Tree[*Metadata], err error) { 71 | mapAny, ok := jsonTree.(map[string]any) 72 | if !ok { 73 | err = errors.Errorf("expected json to be a map of strings, got %T instead", jsonTree) 74 | return 75 | } 76 | _ = mapAny 77 | tree = trees.New[*Metadata]() 78 | for key, value := range mapAny { 79 | switch key { 80 | case KeyUseZarr3: 81 | // Check for Zarr3: not supported. 
82 | zarr3, ok := value.(bool) 83 | if !ok { 84 | err = errors.Errorf("metadata json value for key %q is not a bool, got %T instead", key, value) 85 | return 86 | } 87 | if zarr3 { 88 | err = errors.Errorf("%q set to true, but Zarr3 is not supported by this library", key) 89 | return 90 | } 91 | case KeyTreeMetadata: 92 | entries, ok := value.(map[string]any) 93 | if !ok { 94 | err = errors.Errorf("metadata json value for key %q is not a map[string]any, got %T instead", key, value) 95 | return 96 | } 97 | for keyPath, jsonEntryAny := range entries { 98 | jsonEntry, ok := jsonEntryAny.(map[string]any) 99 | if !ok { 100 | err = errors.Errorf("metadata json value for key %q/%q is not a map[string]any, got %T instead", key, keyPath, jsonEntryAny) 101 | return 102 | } 103 | err = parseJsonMetadataEntry(tree, keyPath, jsonEntry) 104 | if err != nil { 105 | err = errors.WithMessagef(err, "metadata json value for key %q/%q", key, keyPath) 106 | return 107 | } 108 | } 109 | 110 | default: 111 | err = errors.Errorf("metadata json key %q unknown, don't know how to proceed", key) 112 | return 113 | } 114 | 115 | } 116 | return 117 | } 118 | 119 | func parseJsonMetadataEntry(tree *trees.Tree[*Metadata], keyPath string, jsonEntry map[string]any) error { 120 | entry := &Metadata{} 121 | if err := parseKeyPath(entry, keyPath); err != nil { 122 | return err 123 | } 124 | 125 | keyMetadataJsonAny, found := jsonEntry[KeyKeyMetadata] 126 | if !found { 127 | return errors.Errorf("missing KeyMetadata (key %q)", KeyKeyMetadata) 128 | } 129 | keyMetadataJson, ok := keyMetadataJsonAny.([]any) 130 | if !ok { 131 | return errors.Errorf("invalid KeyMetadata (key %q) type %T, expected []any", KeyKeyMetadata, keyMetadataJsonAny) 132 | } 133 | parseKeyMetadata(entry, keyMetadataJson) 134 | 135 | valueMetadataJsonAny, found := jsonEntry[KeyValueMetadata] 136 | if !found { 137 | return errors.Errorf("missing ValueMetadata (key %q)", KeyValueMetadata) 138 | } 139 | valueMetadataJson, ok := 
valueMetadataJsonAny.(map[string]any) 140 | if !ok { 141 | return errors.Errorf("invalid ValueMetadata (key %q) type %T, expected map[string]any", KeyValueMetadata, valueMetadataJsonAny) 142 | } 143 | parseValueMetadata(entry, valueMetadataJson) 144 | tree.Set(entry.KeyPath, entry) 145 | return nil 146 | } 147 | 148 | var reParseKeyPath = regexp.MustCompile(`'(.*?)'\s*[,)]`) 149 | 150 | func parseKeyPath(metadata *Metadata, keyPathStr string) error { 151 | matches := reParseKeyPath.FindAllStringSubmatch(keyPathStr, -1) 152 | if len(matches) == 0 { 153 | return errors.Errorf("can't parse keypath from %q", keyPathStr) 154 | } 155 | for _, match := range matches { 156 | metadata.KeyPath = append(metadata.KeyPath, match[1]) 157 | } 158 | return nil 159 | } 160 | 161 | func parseKeyMetadata(metadata *Metadata, keyMetadataJson []any) { 162 | if metadata.KeyToKeyType == nil { 163 | metadata.KeyToKeyType = make(map[string]MetadataKeyType) 164 | } 165 | for _, entryAny := range keyMetadataJson { 166 | entry := entryAny.(map[string]any) 167 | metadata.KeyToKeyType[entry["key"].(string)] = MetadataKeyType(entry["key_type"].(float64)) 168 | } 169 | } 170 | 171 | func parseValueMetadata(metadata *Metadata, valueMetadataJson map[string]any) { 172 | metadata.ValueType = valueMetadataJson["value_type"].(string) 173 | metadata.SkipDeserialize = valueMetadataJson["skip_deserialize"].(bool) 174 | } 175 | 176 | // ParamNames convert metadata to the paramNames (?? 
// ParamNames converts the metadata tree to the corresponding parameter names:
// each leaf becomes the dot-joined tree path (it's unclear where in Gemma this
// is used — TODO confirm with callers).
func ParamNames(metadata *trees.Tree[*Metadata]) *trees.Tree[string] {
	return trees.Map(metadata, func(treePath trees.Path, metadata *Metadata) string {
		return strings.Join(treePath, ".")
	})
}
--------------------------------------------------------------------------------
/download/kaggle/weights.go:
--------------------------------------------------------------------------------
// Package kaggle loads Gemma weights into tensors along with the matching metadata, after they
// have been downloaded from kaggle and converted using the included cmd/convert_checkpoint.py.
package kaggle

import (
	"github.com/dustin/go-humanize"
	"github.com/gomlx/gemma/trees"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/ml/data"
	"github.com/gomlx/gomlx/types/shapes"
	"github.com/gomlx/gomlx/types/tensors"
	"github.com/gomlx/gomlx/types/xslices"
	"github.com/gomlx/gopjrt/dtypes"
	"github.com/janpfeifer/must"
	"github.com/pkg/errors"
	"github.com/vmihailenco/msgpack"
	"io"
	"io/fs"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"
)

const (
	// AggregateFileName is the legacy (pre-OCDBT) msgpack checkpoint file.
	AggregateFileName = "checkpoint"

	// OCDBTManifestFileName indicates usage of "Orbax Consistent Distributed Backend Tree" (OCDBT).
	OCDBTManifestFileName = "manifest.ocdbt"
)

// ReadConvertedWeights from checkpointDir (under the "raw/" subdirectory).
// It will read the weights and shape converted by the `convert_checkpoint.py` script
// (see github.com/gomlx/gemma repository, under cmd/convert_checkpoint.py) and set them
// in the given context, under its current scope.
37 | func ReadConvertedWeights(ctx *context.Context, checkpointDir string) error { 38 | weights, err := ReadConvertedWeightsToTree(checkpointDir) 39 | if err != nil { 40 | return err 41 | } 42 | UploadWeightsToContext(ctx.In("model"), weights) 43 | return nil 44 | } 45 | 46 | // ReadConvertedWeightsToTree from checkpointDir (under the "raw/" subdirectory). 47 | // It will read the weights and shape converted by the `convert_checkpoint.py` script 48 | // (see github.com/gomlx/gemma repository, under cmd/convert_checkpoint.py) 49 | // 50 | // It returns a tree of tensors, with the path matching those of the original Jax checkpoint. 51 | func ReadConvertedWeightsToTree(checkpointDir string) (tree *trees.Tree[*tensors.Tensor], err error) { 52 | rawDir := path.Join(checkpointDir, "raw") 53 | if !data.FileExists(rawDir) { 54 | err = errors.Errorf( 55 | "ReadConvertedWeights(%q), the given directory doesn't have a subdirectory 'raw/' with the converted files", 56 | checkpointDir) 57 | return 58 | } 59 | tree = trees.New[*tensors.Tensor]() 60 | err = fs.WalkDir(os.DirFS(rawDir), ".", func(filePath string, entry fs.DirEntry, err error) error { 61 | if err != nil { 62 | return errors.Wrapf(err, "failed to traverse %q", rawDir) 63 | } 64 | if entry.IsDir() { 65 | return nil 66 | } 67 | ext := filepath.Ext(filePath) 68 | if ext != ".raw" { 69 | return nil 70 | } 71 | 72 | // Here we have teh pair of files ".shape" and ".raw": 73 | base := strings.TrimSuffix(filePath, ext) 74 | basePath := path.Join(rawDir, base) 75 | shapeFilePath := basePath + ".shape" 76 | if !data.FileExists(shapeFilePath) { 77 | return nil 78 | } 79 | shapeBytes, err := os.ReadFile(shapeFilePath) 80 | if err != nil { 81 | return errors.Wrapf(err, "failed to read shape from %q", shapeFilePath) 82 | } 83 | shapeParts := strings.Split(string(shapeBytes), ",") 84 | dtype, err := dtypes.DTypeString(shapeParts[0]) 85 | if err != nil { 86 | return errors.Wrapf(err, "unknown dtype read from %q", shapeFilePath) 87 
| } 88 | shapeDims := xslices.Map(shapeParts[1:], func(s string) (v int) { 89 | if err != nil { 90 | return 0 91 | } 92 | v, err = strconv.Atoi(s) 93 | return 94 | }) 95 | if err != nil { 96 | return errors.Wrapf(err, "failed to convert %q to a dimension, read from %q", shapeBytes, basePath+".shape") 97 | } 98 | shape := shapes.Make(dtype, shapeDims...) 99 | 100 | rawFilePath := basePath + ".raw" 101 | info, err := entry.Info() 102 | if err != nil { 103 | return errors.Wrapf(err, "failed to get info from %q", rawFilePath) 104 | } 105 | if info.Size() != int64(shape.Memory()) { 106 | return errors.Errorf("file %q has %d bytes, but shape %s (read from %q) requires %d bytes, something went wrong in the conversion", 107 | rawFilePath, info.Size(), shape, shapeFilePath, shape.Memory()) 108 | } 109 | 110 | treePath := strings.Split(base, "/") 111 | //fmt.Printf("%q -> %s\n", treePath, shape) 112 | 113 | tensor := tensors.FromShape(shape) 114 | f, err := os.Open(rawFilePath) 115 | if err != nil { 116 | return errors.Wrapf(err, "failed to open raw data from %q", rawFilePath) 117 | } 118 | var n int 119 | tensor.MutableBytes(func(data []byte) { 120 | n, err = io.ReadFull(f, data) 121 | }) 122 | _ = f.Close() 123 | if err != nil { 124 | return errors.Wrapf(err, "failed to read raw data from %q", rawFilePath) 125 | } 126 | if n != int(shape.Memory()) { 127 | return errors.Errorf("read %d bytes from %q, expected %d", n, rawFilePath, shape.Memory()) 128 | } 129 | err = tree.Set(treePath, tensor) 130 | if err != nil { 131 | return errors.WithMessagef(err, "failed to set variable with %s", humanize.Bytes(uint64(n))) 132 | } 133 | return nil 134 | }) 135 | if err != nil { 136 | return nil, err 137 | } 138 | return 139 | } 140 | 141 | // PyReadAggregate of Python checkpoint. Not used by Gemma v2. 
142 | func PyReadAggregate(checkpointDir string) (results any, err error) { 143 | checkpointDir = data.ReplaceTildeInDir(checkpointDir) 144 | aggregatePath := path.Join(checkpointDir, AggregateFileName) 145 | var f *os.File 146 | f, err = os.Open(aggregatePath) 147 | if err != nil { 148 | err = errors.Wrapf(err, "failed to read aggregate checkpoint file from %q", aggregatePath) 149 | return 150 | } 151 | 152 | dec := msgpack.NewDecoder(f) 153 | results, err = dec.DecodeMap() 154 | defer func() { _ = f.Close() }() 155 | return 156 | } 157 | 158 | func isOCDBT(checkpointDir string) bool { 159 | checkpointDir = data.ReplaceTildeInDir(checkpointDir) 160 | ocdbtPath := path.Join(checkpointDir, OCDBTManifestFileName) 161 | return data.FileExists(ocdbtPath) 162 | } 163 | 164 | type PyParamInfo struct { 165 | Name, Path string 166 | SkipDeserialize bool 167 | } 168 | 169 | func PyReadParamInfo(checkpointDir string) *trees.Tree[*PyParamInfo] { 170 | checkpointDir = data.ReplaceTildeInDir(checkpointDir) 171 | metadata := must.M1(ReadMetadata(checkpointDir)) 172 | return trees.Map(metadata, func(p trees.Path, meta *Metadata) *PyParamInfo { 173 | name := strings.Join(p, ".") 174 | pi := &PyParamInfo{ 175 | Name: name, 176 | Path: path.Join(checkpointDir, name), 177 | SkipDeserialize: meta.SkipDeserialize, 178 | } 179 | return pi 180 | }) 181 | } 182 | 183 | // UploadWeightsToContext creates variables corresponding to the weights. 184 | // It returns the ctx given, with the variables set. 185 | // 186 | // It's tightly coupled with the model building functions in this package. 187 | // Meaning the modeling must match the naming here. 
func UploadWeightsToContext(ctx *context.Context, weights *trees.Tree[*tensors.Tensor]) {
	// Drop the root "transformer" node: variable scopes are relative to it.
	// NOTE(review): assumes the tree always has a "transformer" child — a
	// missing key would make Leaves() panic on nil; confirm with callers.
	weights = weights.Map["transformer"]
	for treePath, tensor := range weights.Leaves() {
		// All path components but the last become nested context scopes...
		scopedCtx := ctx
		scopeParts := treePath[:len(treePath)-1]
		for _, p := range scopeParts {
			scopedCtx = scopedCtx.In(p)
		}
		// ...and the last component is the variable name.
		varName := treePath[len(treePath)-1]
		// Returned variable ignored: we only need it registered in the context.
		_ = scopedCtx.VariableWithValue(varName, tensor)
	}
}
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
module github.com/gomlx/gemma

go 1.23.0

require (
	github.com/charmbracelet/bubbles v0.20.0
	github.com/charmbracelet/bubbletea v1.1.2
	github.com/charmbracelet/lipgloss v1.0.0
	github.com/dustin/go-humanize v1.0.1
	github.com/eliben/go-sentencepiece v0.6.0
	github.com/gomlx/exceptions v0.0.3
	github.com/gomlx/gomlx v0.15.0
	github.com/gomlx/gopjrt v0.4.4
	github.com/janpfeifer/must v0.2.0
	github.com/pkg/errors v0.9.1
	github.com/stretchr/testify v1.9.0
	github.com/vmihailenco/msgpack v4.0.4+incompatible
	golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
	k8s.io/klog/v2 v2.130.1
)

require (
	github.com/atotto/clipboard v0.1.4 // indirect
	github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
	github.com/charmbracelet/x/ansi v0.4.2 // indirect
	github.com/charmbracelet/x/term v0.2.0 // indirect
	github.com/davecgh/go-spew v1.1.1 // indirect
	github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
	github.com/go-logr/logr v1.4.2 // indirect
	github.com/golang/protobuf v1.5.4 // indirect
	github.com/google/uuid v1.6.0 // indirect
	github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
	github.com/mattn/go-isatty v0.0.20 // indirect
	github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect 36 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect 37 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect 38 | github.com/muesli/cancelreader v0.2.2 // indirect 39 | github.com/muesli/termenv v0.15.2 // indirect 40 | github.com/pmezard/go-difflib v1.0.0 // indirect 41 | github.com/rivo/uniseg v0.4.7 // indirect 42 | github.com/schollz/progressbar/v3 v3.17.0 // indirect 43 | github.com/x448/float16 v0.8.4 // indirect 44 | golang.org/x/sync v0.8.0 // indirect 45 | golang.org/x/sys v0.26.0 // indirect 46 | golang.org/x/term v0.25.0 // indirect 47 | golang.org/x/text v0.19.0 // indirect 48 | google.golang.org/appengine v1.6.8 // indirect 49 | google.golang.org/protobuf v1.35.1 // indirect 50 | gopkg.in/yaml.v3 v3.0.1 // indirect 51 | ) 52 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= 2 | github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= 3 | github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= 4 | github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= 5 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= 6 | github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= 7 | github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE= 8 | github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU= 9 | github.com/charmbracelet/bubbletea v1.1.2 h1:naQXF2laRxyLyil/i7fxdpiz1/k06IKquhm4vBfHsIc= 10 | github.com/charmbracelet/bubbletea v1.1.2/go.mod h1:9HIU/hBV24qKjlehyj8z1r/tR9TYTQEag+cWZnuXo8E= 11 | 
github.com/charmbracelet/lipgloss v1.0.0 h1:O7VkGDvqEdGi93X+DeqsQ7PKHDgtQfF8j8/O2qFMQNg= 12 | github.com/charmbracelet/lipgloss v1.0.0/go.mod h1:U5fy9Z+C38obMs+T+tJqst9VGzlOYGj4ri9reL3qUlo= 13 | github.com/charmbracelet/x/ansi v0.4.2 h1:0JM6Aj/g/KC154/gOP4vfxun0ff6itogDYk41kof+qk= 14 | github.com/charmbracelet/x/ansi v0.4.2/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw= 15 | github.com/charmbracelet/x/term v0.2.0 h1:cNB9Ot9q8I711MyZ7myUR5HFWL/lc3OpU8jZ4hwm0x0= 16 | github.com/charmbracelet/x/term v0.2.0/go.mod h1:GVxgxAbjUrmpvIINHIQnJJKpMlHiZ4cktEQCN6GWyF0= 17 | github.com/chengxilo/virtualterm v1.0.4 h1:Z6IpERbRVlfB8WkOmtbHiDbBANU7cimRIof7mk9/PwM= 18 | github.com/chengxilo/virtualterm v1.0.4/go.mod h1:DyxxBZz/x1iqJjFxTFcr6/x+jSpqN0iwWCOK1q10rlY= 19 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 20 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 21 | github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= 22 | github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= 23 | github.com/eliben/go-sentencepiece v0.6.0 h1:wbnefMCxYyVYmeTVtiMJet+mS9CVwq5klveLpfQLsnk= 24 | github.com/eliben/go-sentencepiece v0.6.0/go.mod h1:nNYk4aMzgBoI6QFp4LUG8Eu1uO9fHD9L5ZEre93o9+c= 25 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= 26 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= 27 | github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= 28 | github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 29 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 30 | github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= 31 | github.com/golang/protobuf v1.5.4 
h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= 32 | github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= 33 | github.com/gomlx/exceptions v0.0.3 h1:HKnTgEjj4jlmhr8zVFkTP9qmV1ey7ypYYosQ8GzXWuM= 34 | github.com/gomlx/exceptions v0.0.3/go.mod h1:uHL0TQwJ0xaV2/snJOJV6hSE4yRmhhfymuYgNredGxU= 35 | github.com/gomlx/gomlx v0.15.0 h1:+gRsUzT6O/kcx+wa8pZqGi1JeQKyoCN17IQ9d7d8+gc= 36 | github.com/gomlx/gomlx v0.15.0/go.mod h1:nHkPBFi4z8CyZ5pOmViS8cUbImEB19XH3CVUwlbWjfk= 37 | github.com/gomlx/gopjrt v0.4.4 h1:Pc3jcBTsWmkf58pEuGnTfsf3eFhnX7UHhm/u/6VKczU= 38 | github.com/gomlx/gopjrt v0.4.4/go.mod h1:KkKFWTOGpJ1gludkzITb5/VOxpr5sy7qsTHvehcjTQA= 39 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 40 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 41 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 42 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 43 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 44 | github.com/janpfeifer/must v0.2.0 h1:yWy1CE5gtk1i2ICBvqAcMMXrCMqil9CJPkc7x81fRdQ= 45 | github.com/janpfeifer/must v0.2.0/go.mod h1:S6c5Yg/YSMR43cJw4zhIq7HFMci90a7kPY9XA4c8UIs= 46 | github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= 47 | github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= 48 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 49 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 50 | github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= 51 | github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= 52 | github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= 53 | github.com/mattn/go-runewidth 
v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= 54 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= 55 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= 56 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= 57 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= 58 | github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= 59 | github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= 60 | github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= 61 | github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= 62 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 63 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 64 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 65 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 66 | github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= 67 | github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= 68 | github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= 69 | github.com/schollz/progressbar/v3 v3.17.0 h1:Fv+vG6O6jnJwdjCelvfyYO7sF2jaUGQVmdH4CxcZdsQ= 70 | github.com/schollz/progressbar/v3 v3.17.0/go.mod h1:5H4fLgifX+KeQCsEJnZTOepgZLe1jFF1lpPXb68IJTA= 71 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 72 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 73 | github.com/vmihailenco/msgpack v4.0.4+incompatible 
h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI= 74 | github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk= 75 | github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= 76 | github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= 77 | github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= 78 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 79 | golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 80 | golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY= 81 | golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8= 82 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= 83 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 84 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 85 | golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= 86 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 87 | golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 88 | golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= 89 | golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 90 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 91 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 92 | golang.org/x/sys 
v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 93 | golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 94 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 95 | golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 96 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 97 | golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= 98 | golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 99 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 100 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 101 | golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= 102 | golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= 103 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 104 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 105 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 106 | golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= 107 | golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= 108 | golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= 109 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 110 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 111 | golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= 112 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod 
h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 113 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 114 | google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= 115 | google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= 116 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 117 | google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 118 | google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= 119 | google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= 120 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 121 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 122 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 123 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 124 | k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= 125 | k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= 126 | -------------------------------------------------------------------------------- /samplers/samplers.go: -------------------------------------------------------------------------------- 1 | // Package samplers uses a transformer model to generate senteces based on prompts. 2 | package samplers 3 | 4 | import ( 5 | "github.com/dustin/go-humanize" 6 | "github.com/gomlx/exceptions" 7 | "github.com/gomlx/gemma/transformers" 8 | "github.com/gomlx/gemma/trees" 9 | "github.com/gomlx/gomlx/backends" 10 | . 
"github.com/gomlx/gomlx/graph"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/types/shapes"
	"github.com/gomlx/gomlx/types/tensors"
	"github.com/gomlx/gomlx/types/xslices"
	"github.com/gomlx/gopjrt/dtypes"
	klog "k8s.io/klog/v2"
	"slices"
	"time"
)

// Vocabulary is the tokenizer functionality the Sampler needs: text<->id
// conversion plus the special token ids.
type Vocabulary interface {
	// EncodeAsIDs tokenizes text into vocabulary ids.
	EncodeAsIDs(text string) []int
	// DecodeIDs converts ids back into text.
	DecodeIDs([]int) string

	// The methods below define the special ids for the model.

	BeginningOfSentenceID() int
	EndOfSentenceID() int
	UnknownID() int
	PadID() int
}

// Sampler has a transformer (LLM) model and a vocabulary (sentencepiece) configured and generates
// sentences based on prompts.
type Sampler struct {
	// Backend used to execute the model graph.
	Backend backends.Backend
	// Vocab used to encode prompts and decode generated ids.
	Vocab Vocabulary

	// MaxGeneratedTokens default for Sampler.Sample.
	MaxGeneratedTokens int

	// Context with the model weights, used to execute the model.
	Context *context.Context

	// SampleStep graph computation.
	SampleStep *context.Exec

	// Config of the Gemma model, created from the weights.
	Config *transformers.Config

	// CacheTreeStructure holds the structure of the tree used for caching: the tree structure (paths) is stable
	// across different calls to Sample.
	CacheTreeStructure *trees.Tree[struct{}]
}

// New creates a new sampler with the registered vocabulary and model.
57 | func New(backend backends.Backend, ctx *context.Context, vocab Vocabulary, maxGeneratedTokens int) (*Sampler, error) { 58 | s := &Sampler{ 59 | Backend: backend, 60 | Vocab: vocab, 61 | MaxGeneratedTokens: maxGeneratedTokens, 62 | Context: ctx, 63 | } 64 | var err error 65 | s.Config, err = transformers.NewConfigFromContext(s.Context.In("model")) 66 | if err != nil { 67 | return nil, err 68 | } 69 | s.Context = s.Context.Reuse() 70 | s.SampleStep = context.NewExec(backend, s.Context, s.sampleStepGraphFn()) 71 | return s, nil 72 | } 73 | 74 | // Sample the continuation from the given prompts. 75 | func (s *Sampler) Sample(prompts []string) ([]string, error) { 76 | return s.SampleMaxTokens(prompts, s.MaxGeneratedTokens) 77 | } 78 | 79 | // SampleMaxTokens is like Sample, but instead of using the default MaxGenerateTokens, uses the given maxTokens instead. 80 | func (s *Sampler) SampleMaxTokens(prompts []string, maxTokens int) ([]string, error) { 81 | promptIds := xslices.Map(prompts, s.Vocab.EncodeAsIDs) 82 | state, err := s.initialState(promptIds, maxTokens) 83 | if err != nil { 84 | return nil, err 85 | } 86 | err = exceptions.TryCatch[error](func() { 87 | state = s.sampleLoop(state) 88 | }) 89 | if err != nil { 90 | return nil, err 91 | } 92 | return s.decode(state), nil 93 | } 94 | 95 | // sampleLoop, executes a sampleStep until all examples in the batch are finished. 96 | func (s *Sampler) sampleLoop(state samplingState) samplingState { 97 | // Prepare inputs as slice: 98 | // * If you change this order, change the parsing order in sampleStepGraphFn below. 99 | inputs := []any{ 100 | state.InputBuffer, 101 | state.StepNum, 102 | state.Done, 103 | } 104 | // * Append cache values. 105 | cacheValues := trees.ValuesAsList(state.Cache.Data) 106 | inputs = append(inputs, xslices.Map(cacheValues, func(t *tensors.Tensor) any { return t })...) 107 | numMutableInputs := len(inputs) 108 | 109 | // Append constant inputs. 
110 | start := time.Now() 111 | inputs = append(inputs, 112 | state.Positions, 113 | ) 114 | var outputs []*tensors.Tensor 115 | var execTime, inputsPrepTime time.Duration 116 | var count int 117 | for { 118 | inputPrepStart := time.Now() 119 | // We donate all the inputs, since they are all going to be updated (saves some GPU memory). 120 | for ii := range numMutableInputs { 121 | inputs[ii] = DonateTensorBuffer(inputs[ii].(*tensors.Tensor), s.Backend) 122 | } 123 | inputsPrepTime += time.Since(inputPrepStart) 124 | 125 | // Execute a step. 126 | execStart := time.Now() 127 | outputs = s.SampleStep.Call(inputs...) 128 | execTime += time.Since(execStart) 129 | count++ 130 | 131 | // Update states (the output has the same order as the input). 132 | for ii := range numMutableInputs { 133 | inputs[ii] = outputs[ii] 134 | } 135 | extraOutputs := outputs[numMutableInputs:] // Separate the transient outputs. 136 | done := tensors.ToScalar[bool](extraOutputs[0]) 137 | 138 | // End-of-sampling: 139 | if done { 140 | break 141 | } 142 | } 143 | if klog.V(1).Enabled() { 144 | elapsed := time.Since(start) 145 | klog.Infof("Sample execution time (%d steps): %s", count, elapsed) 146 | klog.Infof("> Graph execution time: %s", execTime) 147 | klog.Infof("> Inputs preparation time: %s", inputsPrepTime) 148 | } 149 | state.InputBuffer = outputs[0] 150 | state.Positions = outputs[1] 151 | state.StepNum = outputs[2] 152 | state.Done = outputs[3] 153 | updatedCache := trees.FromValuesAndTree(outputs[4:4+s.CacheTreeStructure.NumLeaves()], s.CacheTreeStructure) 154 | state.Cache.Data = updatedCache 155 | return state 156 | } 157 | 158 | // buildSampleStepGraphFn returns the computation graph building function for this sampler. 159 | // The returned function can be used by context.NewExec. 
func (s *Sampler) sampleStepGraphFn() func(*context.Context, []*Node) []*Node {
	// The returned graph performs one decoding step for the whole batch: it feeds the
	// token at the current step through the transformer (updating the rotating cache),
	// greedily (ArgMax) picks the next token, writes it into the input buffer and
	// updates the per-example "done" flags.
	return func(ctx *context.Context, state []*Node) []*Node {
		g := state[0].Graph() // Reference to the underlying graph, it could be from any of the inputs.
		_ = ctx

		// Extract state parts:
		stateFieldsIdx := 0
		nextState := func() *Node {
			field := state[stateFieldsIdx]
			stateFieldsIdx++
			return field
		}
		// This order has to match the order fed in sampleLoop.
		// - Mutable fields, to be updated.
		inputBuffer := nextState()
		stepNum := nextState()
		done := nextState()
		numCacheValues := s.CacheTreeStructure.NumLeaves()
		cache := trees.FromValuesAndTree(state[stateFieldsIdx:stateFieldsIdx+numCacheValues], s.CacheTreeStructure)
		stateFieldsIdx += numCacheValues

		// - Constant fields.
		positions := nextState()

		// Take the current step token for all examples of the batch.
		batchSize := inputBuffer.Shape().Dimensions[0]
		zeroIdx := ScalarZero(g, dtypes.Int32)
		currentTokens := DynamicSlice(inputBuffer, []*Node{zeroIdx, stepNum}, []int{batchSize, 1})
		currentTokens.AssertDims(batchSize, 1)
		currentPositions := DynamicSlice(positions, []*Node{zeroIdx, stepNum}, []int{batchSize, 1})
		currentPositions.AssertDims(batchSize, 1)

		// Attention to all positions < current step number (stepNum).
		// Notice that the cache rotates, so once stepNum > Config.MaxCacheLength, the mask will be
		// true everywhere.
		cacheAttentionMask := Iota(g, shapes.Make(dtypes.Int32, batchSize, 1, s.Config.MaxCacheLength), -1)
		cacheAttentionMask = LessOrEqual(cacheAttentionMask, stepNum)

		logits := transformers.GemmaWithCache(ctx.In("model"), s.Config,
			currentTokens, currentPositions, cache, cacheAttentionMask)
		logits.AssertDims(batchSize, 1, s.Config.VocabularySize)

		// Greedy sampling: the ArgMax prediction only replaces the buffer content where
		// the buffer holds padding (i.e., past the prompt) or where the example is
		// already marked done; prompt tokens already in the buffer are kept as-is.
		nextTokenNum := OnePlus(stepNum)
		nextPredictedTokens := ArgMax(logits, -1)
		nextPredictedTokens.AssertDims(batchSize, 1)
		nextTokenStartIdx := []*Node{zeroIdx, nextTokenNum}
		nextTokens := DynamicSlice(inputBuffer, nextTokenStartIdx, []int{batchSize, 1})
		nextTokens.AssertDims(batchSize, 1)
		nextTokens = Where(
			Or(
				Equal(nextTokens, Const(g, int32(s.Vocab.PadID()))),
				ExpandAxes(done, -1),
			),
			nextPredictedTokens,
			nextTokens,
		)
		inputBuffer = DynamicUpdateSlice(inputBuffer, nextTokens, nextTokenStartIdx)
		// An example is done once it emits (or already contains) the end-of-sentence token.
		eosToken := Scalar(g, dtypes.Int32, s.Vocab.EndOfSentenceID())
		nextTokenIsEOS := Squeeze(Equal(nextTokens, eosToken), -1)
		done = Or(done, nextTokenIsEOS)

		// Prepare next step: are we done ?
		stepNum = nextTokenNum
		// maxSteps caps generation so the write at position stepNum+1 stays inside the buffer.
		maxSteps := inputBuffer.Shape().Dimensions[1] - 2
		allDone := Or(
			LogicalAll(done),
			GreaterOrEqual(stepNum, Const(g, int32(maxSteps))),
		)

		// Outputs: updated mutable values first (including cache):
		outputs := []*Node{inputBuffer, stepNum, done}
		outputs = append(outputs, trees.ValuesAsList(cache)...)
		// - Other results:
		outputs = append(outputs, allDone)
		return outputs
	}
}

// samplingState holds the state of the sampling loop plus some constants for the loop.
type samplingState struct {
	// BatchSize, MaxTokens, TotalLength are constants for one sampling.
	BatchSize, MaxTokens, TotalLength int

	// InputBuffer holds the ids with a prepended <bos> (beginning-of-sentence), <pad> padding, and extra space for
	// an <eos> (end-of-sentence).
	InputBuffer *tensors.Tensor

	// NumInputTokens is the number of tokens on the original input per example: shaped int32[batch_size].
	NumInputTokens *tensors.Tensor

	// Positions for each token, see transformers.BuildPositionsFromMask.
	Positions *tensors.Tensor

	// StepNum is a scalar counter of the steps sampled (decoded) so far.
	StepNum *tensors.Tensor

	// Done is a vector of the inputs who are done with the generation: shaped bool[batch_size].
	Done *tensors.Tensor

	// Cache used during the sampling.
	Cache *transformers.Cache
}

// initialState creates the initial samplingState: an InputBuffer shaped
// int32[batchSize, TotalLength] filled with Vocab.PadID and then (left to right) with a
// prepended <bos> followed by each example's promptIds; plus the positions (iota),
// the zeroed step counter, the all-false Done vector and a freshly created cache.
func (s *Sampler) initialState(promptIds [][]int, maxTokens int) (state samplingState, err error) {
	state.MaxTokens = maxTokens
	state.BatchSize = len(promptIds)
	batchSize := state.BatchSize

	lengths := xslices.Map(promptIds, func(seq []int) int32 { return int32(len(seq)) + 1 }) // +1 for the <bos> (beginning-of-sentence) token.
	state.NumInputTokens = tensors.FromValue(lengths)                                      // Shape [batchSize]
	maxInputLength := int(slices.Max(lengths))
	state.TotalLength = maxInputLength + maxTokens + 1 // +1 for <eos>.
	totalLength := state.TotalLength

	state.StepNum = tensors.FromScalar(int32(0))
	state.InputBuffer = tensors.FromScalarAndDimensions(int32(s.Vocab.PadID()), batchSize, totalLength)
	bos := int32(s.Vocab.BeginningOfSentenceID())

	// Copy over "ragged" promptIds to the dense InputBuffer (pre-filled with <pad>),
	// prepending <bos> to each example.
	tensors.MutableFlatData(state.InputBuffer, func(flatIDs []int32) {
		for exampleIdx := range batchSize {
			exampleIds := flatIDs[exampleIdx*totalLength : (exampleIdx+1)*totalLength]
			exampleIds[0] = bos
			for ii, value := range promptIds[exampleIdx] {
				exampleIds[1+ii] = int32(value)
			}
		}
	})

	// Notice that the convoluted code in https://github.com/google-deepmind/gemma/blob/main/gemma/sampler.py
	// (see Sampler.init_sample_state()) in the end simply does the same as Iota() -- except if the input
	// has pad symbols (not the padding added by filling the ragged input) inside it -- which is not doable when
	// converting from string.
	//
	// Probably there is a bug in the original code ... it's not documented what they intended to do with it, but
	// we simply take the iota here.
	state.Positions = ExecOnce(s.Backend, func(g *Graph) *Node {
		return Iota(g, shapes.Make(dtypes.Int32, batchSize, totalLength), -1)
	})

	// Done starts as the zero value (all false): no example has finished yet.
	state.Done = tensors.FromShape(shapes.Make(dtypes.Bool, batchSize))

	// Setup cache, and if not yet setup, configure cache structure.
	var start time.Time
	if klog.V(1).Enabled() {
		start = time.Now()
	}
	state.Cache, err = transformers.NewCache(s.Config, batchSize)
	if err != nil {
		return
	}
	if s.CacheTreeStructure == nil {
		// Record the tree shape (paths only, no values) of the cache: it is used to
		// flatten/unflatten the cache tensors to/from the list of graph inputs/outputs.
		s.CacheTreeStructure = trees.Map(state.Cache.Data, func(_ trees.Path, _ *tensors.Tensor) (empty struct{}) { return })
	}

	if klog.V(1).Enabled() {
		elapsed := time.Since(start)
		var cacheMem uintptr
		for _, t := range state.Cache.Data.Leaves() {
			cacheMem += t.Memory()
		}
		klog.Infof("cache: elapsed %s, memory used %s\n", elapsed, humanize.Bytes(uint64(cacheMem)))
	}
	return
}

// decode converts the state's InputBuffer with the sampled tokens to actual text.
// For each example it decodes the tokens up to (and excluding) the first <eos> or <pad>.
func (s *Sampler) decode(state samplingState) []string {
	text := make([]string, state.BatchSize)
	totalLength := state.TotalLength
	tensors.ConstFlatData(state.InputBuffer, func(flatIds []int32) {
		for exampleIdx := range state.BatchSize {
			exampleIds := flatIds[exampleIdx*totalLength : (exampleIdx+1)*totalLength]
			ids := xslices.Map(exampleIds, func(id int32) int { return int(id) })
			nonPad := 0
			for _, id := range ids {
				if id == s.Vocab.EndOfSentenceID() || id == s.Vocab.PadID() {
					break
				}
				nonPad++
			}
			text[exampleIdx] = s.Vocab.DecodeIDs(ids[:nonPad]) // Notice <bos> is converted to an empty string.
			//fmt.Printf("tokens: %#v\n", ids[:nonPad])
		}
	})
	return text
}

// UpdateCacheAttentionMaskGraph given an inputMask (on the whole batch of example token ids), a currentStep and
// attentionLen (static).
//
// It's based on _compute_attention_mask in https://github.com/google-deepmind/gemma/blob/main/gemma/sampler.py#L32:
// the inputs and outputs here are very cryptic ...
// my best guess (also with some help from Gemini) what the original
// authors meant to generate is a mask that is False to where they can attend, except if it is in the "future"
// ("future" means positions > currentStep).
//
// - currentStep: scalar with current step.
// - attentionLen: length of the attention mask (it's ok to be larger than inputMask, it will pad the output accordingly)
// - inputMask: mask of valid tokens in the input, shaped [batchSize, inputLen]
func UpdateCacheAttentionMaskGraph(currentStep *Node, attentionLen int, inputMask *Node) *Node {
	// NOTE(review): unimplemented stub — it always returns nil. Nothing visible in this
	// file calls it; either port _compute_attention_mask from the reference
	// implementation or remove the function. TODO: confirm there are no external callers.
	return nil
}
--------------------------------------------------------------------------------
/sentencepiece/sentencepiece.go:
--------------------------------------------------------------------------------
// Package sentencepiece fills some missing functionality from github.com/eliben/go-sentencepiece
//
// Hopefully it's temporary.
package sentencepiece

import (
	esentencepiece "github.com/eliben/go-sentencepiece"
	"github.com/gomlx/gomlx/types/xslices"
	"github.com/pkg/errors"
)

// Tokenizer is able to encode/decode tokens from/to text.
type Tokenizer struct {
	*esentencepiece.Processor
	// Info holds the model metadata (special token ids) read from the processor.
	Info *esentencepiece.ModelInfo
}

// NewFromPath creates a Tokenizer from a SentencePiece vocabulary file at vocabPath.
func NewFromPath(vocabPath string) (*Tokenizer, error) {
	proc, err := esentencepiece.NewProcessorFromPath(vocabPath)
	if err != nil {
		return nil, errors.Wrapf(err, "can't create sentencepiece")
	}
	return &Tokenizer{
		Processor: proc,
		Info:      proc.ModelInfo(),
	}, nil
}

// Token is an alias to the underlying library's Token type.
type Token = esentencepiece.Token

// EncodeAsIDs returns the text encoded into a sequence of ids.
// It implements sampler.Vocabulary.
33 | func (p *Tokenizer) EncodeAsIDs(text string) []int { 34 | tokens := p.Processor.Encode(text) 35 | return xslices.Map(tokens, func(t Token) int { return t.ID }) 36 | } 37 | 38 | // DecodeIDs returns the text from a sequence of ids. 39 | // It implements sampler.Vocabulary. 40 | func (p *Tokenizer) DecodeIDs(ids []int) string { 41 | return p.Processor.Decode(ids) 42 | } 43 | 44 | // BeginningOfSentenceID implements sampler.Vocabulary. 45 | func (p *Tokenizer) BeginningOfSentenceID() int { 46 | return p.Info.BeginningOfSentenceID 47 | } 48 | 49 | // EndOfSentenceID implements sampler.Vocabulary. 50 | func (p *Tokenizer) EndOfSentenceID() int { 51 | return p.Info.EndOfSentenceID 52 | } 53 | 54 | // UnknownID implements sampler.Vocabulary. 55 | func (p *Tokenizer) UnknownID() int { 56 | return p.Info.UnknownID 57 | } 58 | 59 | // PadID implements sampler.Vocabulary. 60 | func (p *Tokenizer) PadID() int { 61 | return p.Info.PadID 62 | } 63 | -------------------------------------------------------------------------------- /transformers/attention.go: -------------------------------------------------------------------------------- 1 | package transformers 2 | 3 | import ( 4 | "github.com/gomlx/exceptions" 5 | "github.com/gomlx/gemma/trees" 6 | . "github.com/gomlx/gomlx/graph" 7 | "github.com/gomlx/gomlx/ml/context" 8 | "github.com/gomlx/gomlx/types/shapes" 9 | "github.com/gomlx/gomlx/types/tensors" 10 | "github.com/gomlx/gopjrt/dtypes" 11 | "github.com/pkg/errors" 12 | ) 13 | 14 | // createAttentionCache creates the attention cache for the attention layer under treePath. 
func createAttentionCache(data *trees.Tree[*tensors.Tensor], treePath trees.Path, dtype dtypes.DType,
	batchSize, maxCacheLength, numHeads, headDim int) error {
	// Value cache:
	err := data.Set(append(treePath, "v"),
		tensors.FromShape(shapes.Make(dtype, batchSize, maxCacheLength, numHeads, headDim)))
	if err != nil {
		return errors.WithMessage(err, "in createAttentionCache()")
	}

	// Keys cache:
	err = data.Set(append(treePath, "k"),
		tensors.FromShape(shapes.Make(dtype, batchSize, maxCacheLength, numHeads, headDim)))
	if err != nil {
		return errors.WithMessage(err, "in createAttentionCache()")
	}

	// Index where to insert new values, in a rotating cache.
	err = data.Set(append(treePath, "end_index"), tensors.FromScalar(int32(0)))
	if err != nil {
		return errors.WithMessage(err, "in createAttentionCache()")
	}
	return nil
}

// Must panics if the error is not nil.
func Must(err error) {
	if err != nil {
		panic(err)
	}
}

// Must1 panics in case of error, otherwise returns the one return value.
func Must1[T any](v T, err error) T {
	if err != nil {
		panic(err)
	}
	return v
}

// Attention builds an attention layer, optionally using cache to store a limited amount of context.
//
// - attentionIdx indexes attention configuration (in config) parameters, like config.AttentionTypes.
// - x is the operand shaped [batchSize, sequenceLength, embedDim]. If using cache, typically the sequenceLength will be 1.
// - positions are the positions of the sequence in x, shaped int32[batchSize, sequenceLength].
// - cache: if set, x is only used for the current token (so sequenceLength will be 1), and the x's key and value projections
//   are set in the cache. After that, cache is used instead of x for the attention.
// - attentionMask: shaped bool[batchSize, sequenceLength, sequenceLength] (if cache is nil) or bool[batchSize, sequenceLength==1, config.MaxCacheLength] if
//   cache is being used.
func Attention(ctx *context.Context, config *Config, attentionIdx int, x, positions *Node, cache *trees.Tree[*Node], attentionMask *Node) *Node {
	g := x.Graph()
	dtype := x.DType()

	// Calculates projections used in the attention.
	var queryProjection, keyProjection, valueProjection *Node

	// Glossary of keys for einsum:
	//  B = batchSize
	//  T = sequenceLength
	//  D = config.EmbedDim
	//  N = config.NumHeads
	//  H = config.HeadDim
	//  K = config.NumKVHeads
	if config.HuggingFaceVersion {
		// HuggingFace version has separate variables per projection.
		keyProjectionWeights := ctx.In("hf").
			VariableWithShape("k_proj", shapes.Make(dtype, config.NumKVHeads*config.HeadDim, config.EmbedDim)).
			ValueGraph(g)
		keyProjectionWeights = Reshape(keyProjectionWeights, config.NumKVHeads, config.HeadDim, config.EmbedDim)
		keyProjection = Einsum("BSD,KHD->BSKH", x, keyProjectionWeights)

		valueProjectionWeights := ctx.In("hf").
			VariableWithShape("v_proj", shapes.Make(dtype, config.NumKVHeads*config.HeadDim, config.EmbedDim)).
			ValueGraph(g)
		valueProjectionWeights = Reshape(valueProjectionWeights, config.NumKVHeads, config.HeadDim, config.EmbedDim)
		valueProjection = Einsum("BSD,KHD->BSKH", x, valueProjectionWeights)

		queryProjectionWeights := ctx.In("hf").
			VariableWithShape("q_proj", shapes.Make(dtype, config.NumHeads*config.HeadDim, config.EmbedDim)).
			ValueGraph(g)
		queryProjectionWeights = Reshape(queryProjectionWeights, config.NumHeads, config.HeadDim, config.EmbedDim)
		queryProjection = Einsum("BSD,NHD->BSNH", x, queryProjectionWeights)

	} else if config.UseQKV {
		// S = 3, one extra dimensions for query, key, value projections
		qkvProjections := KernelEinsum(ctx.In("qkv_einsum"), "BTD,SNDH->SBTNH", x,
			shapes.Make(dtype /* k, q, v = 3 */, 3, config.NumHeads, config.EmbedDim, config.HeadDim))
		queryProjection = Squeeze(Slice(qkvProjections, AxisElem(0)), 0)
		keyProjection = Squeeze(Slice(qkvProjections, AxisElem(1)), 0)
		valueProjection = Squeeze(Slice(qkvProjections, AxisElem(2)), 0)
	} else {
		queryProjection = KernelEinsum(ctx.In("q_einsum"), "BTD,NDH->BTNH", x,
			shapes.Make(dtype, config.NumHeads, config.EmbedDim, config.HeadDim))
		// C = 2, one dimension for key, the other for value.
		kvProjections := KernelEinsum(ctx.In("kv_einsum"), "BSD,CKDH->CBSKH", x,
			shapes.Make(dtype, 2, config.NumKVHeads, config.EmbedDim, config.HeadDim))
		keyProjection = Squeeze(Slice(kvProjections, AxisElem(0)), 0)
		valueProjection = Squeeze(Slice(kvProjections, AxisElem(1)), 0)
	}

	// Rotary position embeddings (RoPE) are applied to queries and keys; the query is
	// additionally pre-scaled (see Config.QueryPreAttentionScalar).
	queryProjection = ApplyRotaryPositionEncoding(queryProjection, positions, RoPEDefaultMaxWaveLength)
	queryScaled := MulScalar(queryProjection, config.QueryPreAttentionScalar())
	keyProjection = ApplyRotaryPositionEncoding(keyProjection, positions, RoPEDefaultMaxWaveLength)

	// If cache is set, update it with the projections of the slice of the sequence given, and then take the
	// projections of the whole cache.
	if cache != nil {
		// Insert calculated projections in cache: cached projections are shaped [batchSize, maxCacheLength, numHeads, headDim]
		endIndex, err := cache.Get("end_index")
		if err != nil {
			panic(err)
		}
		zeroIdx := ScalarZero(g, dtypes.Int32)
		// The cache rotates: the write position wraps around at config.MaxCacheLength.
		cacheSequencePosition := Mod(endIndex, Scalar(g, endIndex.DType(), config.MaxCacheLength))
		updateSliceIndices := []*Node{zeroIdx, cacheSequencePosition, zeroIdx, zeroIdx}

		valueProjection = DynamicUpdateSlice(Must1(cache.Get("v")), valueProjection, updateSliceIndices)
		keyProjection = DynamicUpdateSlice(Must1(cache.Get("k")), keyProjection, updateSliceIndices)
		Must(cache.Set(trees.Path{"v"}, valueProjection))
		Must(cache.Set(trees.Path{"k"}, keyProjection))
		// Bump end_index the length of tokens provided at this step: typically, this will be only 1. If > 1
		// this will probably not work if the cache wraps around.
		Must(cache.Set(trees.Path{"end_index"}, AddScalar(endIndex, positions.Shape().Dim(-1))))
	}

	batchSize := queryScaled.Shape().Dim(0)               // B
	seqLength := queryScaled.Shape().Dim(1)               // T
	numQueryHeads := queryScaled.Shape().Dim(2)           // N
	headDim := queryScaled.Shape().Dim(3)                 // H
	numKVHeads := config.NumKVHeads                       // K
	attentionTargetLength := keyProjection.Shape().Dim(1) // S = config.MaxCacheLength if cache != nil, or seqLength.

	var logits *Node
	if config.UseGroupQueryAttention {
		// There are fewer key (and value) projections than query projections,
		// reshape matrices accordingly and adjust Einsum.
		queryPerKVHeads := numQueryHeads / numKVHeads // G
		queryScaled = Reshape(queryScaled, batchSize, seqLength, numKVHeads, queryPerKVHeads, headDim)
		logits = Einsum("BTKGH,BSKH->BTKGS", queryScaled, keyProjection)
		logits = Reshape(logits, batchSize, seqLength, numQueryHeads, attentionTargetLength)
	} else {
		// Same number of query/key projections.
		// N = numQueryHeads == numKVHeads.
		logits = Einsum("BTNH,BSNH->BTNS", queryScaled, keyProjection)
	}
	// NOTE(review): this assert assumes the cached path (S == config.MaxCacheLength);
	// when cache == nil, S == seqLength instead — confirm Attention is only used with a cache.
	logits.AssertDims(batchSize, seqLength, numQueryHeads, config.MaxCacheLength)
	logits = SoftCap(logits, config.AttentionLogitsSoftCap) // No-op if config.AttentionLogitsSoftCap is 0.

	if config.AttentionTypes[attentionIdx] == AttentionTypeLocalSliding {
		// Create a sliding mask: a mask that has a band (2*config.SlidingWindowSize) around the diagonal.
		// Issue: this will not work when using cache, and the cache loops around its config.MaxCacheLength, since
		// the sliding mask "band" won't wrap around.
		if config.SlidingWindowSize <= 0 {
			exceptions.Panicf("Config.SlidingWindowSize must be set for AttentionTypeLocalSliding")
		}
		allOnes := OnesLike(attentionMask)
		slidingMask := And(
			TakeUpperTriangular(allOnes, 1-config.SlidingWindowSize),
			TakeLowerTriangular(allOnes, config.SlidingWindowSize-1),
		)
		attentionMask = And(attentionMask, slidingMask)
	}

	// Calculate attention weights: masked-out positions get a very large negative
	// logit so they contribute ~0 after the softmax.
	const logitsMask = -2.3819763e38
	paddedLogits := Where(
		BroadcastToShape(ExpandAxes(attentionMask, -2), logits.Shape()),
		logits,
		Scalar(g, logits.DType(), logitsMask),
	)
	attentionWeights := Softmax(paddedLogits, -1)

	// Weighted sum of the values:
	var encoded *Node
	if config.UseGroupQueryAttention {
		// Reshape matrices to enable Einsums over groups of queries.
		queryPerKVHeads := numQueryHeads / numKVHeads // G
		attentionWeights = Reshape(attentionWeights, batchSize, seqLength, numKVHeads, queryPerKVHeads, attentionTargetLength)
		encoded = Einsum("BTKGS,BSKH->BTKGH", attentionWeights, valueProjection)
		encoded = Reshape(encoded, batchSize, seqLength, numQueryHeads, headDim)
	} else {
		// Plain attention: same number of query, keys and values projections.
		encoded = Einsum("BTNS,BSNH->BTNH", attentionWeights, valueProjection)
		encoded.AssertDims(batchSize, seqLength, numQueryHeads, headDim)
	}

	// Finally, a linear transformation on the result, merging all the heads.
	var output *Node
	if config.HuggingFaceVersion {
		outputProjectionWeights := ctx.In("hf").
			VariableWithShape("o_proj", shapes.Make(dtype, config.EmbedDim, config.NumHeads*config.HeadDim)).
			ValueGraph(g)
		outputProjectionWeights = Reshape(outputProjectionWeights, config.EmbedDim, numQueryHeads, config.HeadDim)
		output = Einsum("BTNH,DNH->BTD", encoded, outputProjectionWeights)

	} else {
		output = KernelEinsum(ctx.In("attn_vec_einsum"), "BTNH,NHD->BTD",
			encoded,
			shapes.Make(encoded.DType(), numQueryHeads, config.HeadDim, config.EmbedDim))
	}
	return output
}
--------------------------------------------------------------------------------
/transformers/attentiontype_enumer.go:
--------------------------------------------------------------------------------
// Code generated by "enumer -type=AttentionType -trimprefix=AttentionType -transform=snake -values -text -json -yaml config.go"; DO NOT EDIT.

package transformers

import (
	"encoding/json"
	"fmt"
	"strings"
)

const _AttentionTypeName = "unknowngloballocal_sliding"

var _AttentionTypeIndex = [...]uint8{0, 7, 13, 26}

const _AttentionTypeLowerName = "unknowngloballocal_sliding"

func (i AttentionType) String() string {
	if i < 0 || i >= AttentionType(len(_AttentionTypeIndex)-1) {
		return fmt.Sprintf("AttentionType(%d)", i)
	}
	return _AttentionTypeName[_AttentionTypeIndex[i]:_AttentionTypeIndex[i+1]]
}

func (AttentionType) Values() []string {
	return AttentionTypeStrings()
}

// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _AttentionTypeNoOp() {
	var x [1]struct{}
	_ = x[AttentionTypeUnknown-(0)]
	_ = x[AttentionTypeGlobal-(1)]
	_ = x[AttentionTypeLocalSliding-(2)]
}

var _AttentionTypeValues = []AttentionType{AttentionTypeUnknown, AttentionTypeGlobal, AttentionTypeLocalSliding}

var _AttentionTypeNameToValueMap = map[string]AttentionType{
	_AttentionTypeName[0:7]:        AttentionTypeUnknown,
	_AttentionTypeLowerName[0:7]:   AttentionTypeUnknown,
	_AttentionTypeName[7:13]:       AttentionTypeGlobal,
	_AttentionTypeLowerName[7:13]:  AttentionTypeGlobal,
	_AttentionTypeName[13:26]:      AttentionTypeLocalSliding,
	_AttentionTypeLowerName[13:26]: AttentionTypeLocalSliding,
}

var _AttentionTypeNames = []string{
	_AttentionTypeName[0:7],
	_AttentionTypeName[7:13],
	_AttentionTypeName[13:26],
}

// AttentionTypeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func AttentionTypeString(s string) (AttentionType, error) {
	if val, ok := _AttentionTypeNameToValueMap[s]; ok {
		return val, nil
	}

	if val, ok := _AttentionTypeNameToValueMap[strings.ToLower(s)]; ok {
		return val, nil
	}
	return 0, fmt.Errorf("%s does not belong to AttentionType values", s)
}

// AttentionTypeValues returns all values of the enum
func AttentionTypeValues() []AttentionType {
	return _AttentionTypeValues
}

// AttentionTypeStrings returns a slice of all String values of the enum
func AttentionTypeStrings() []string {
	strs := make([]string, len(_AttentionTypeNames))
	copy(strs, _AttentionTypeNames)
	return strs
}

// IsAAttentionType returns "true" if the value is listed in the enum definition. "false" otherwise
func (i AttentionType) IsAAttentionType() bool {
	for _, v := range _AttentionTypeValues {
		if i == v {
			return true
		}
	}
	return false
}

// MarshalJSON implements the json.Marshaler interface for AttentionType
func (i AttentionType) MarshalJSON() ([]byte, error) {
	return json.Marshal(i.String())
}

// UnmarshalJSON implements the json.Unmarshaler interface for AttentionType
func (i *AttentionType) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return fmt.Errorf("AttentionType should be a string, got %s", data)
	}

	var err error
	*i, err = AttentionTypeString(s)
	return err
}

// MarshalText implements the encoding.TextMarshaler interface for AttentionType
func (i AttentionType) MarshalText() ([]byte, error) {
	return []byte(i.String()), nil
}

// UnmarshalText implements the encoding.TextUnmarshaler interface for AttentionType
func (i *AttentionType) UnmarshalText(text []byte) error {
	var err error
	*i, err = AttentionTypeString(string(text))
	return err
}

// MarshalYAML implements a YAML Marshaler for AttentionType
func (i AttentionType) MarshalYAML() (interface{}, error) {
	return i.String(), nil
}

// UnmarshalYAML implements a YAML Unmarshaler for AttentionType
func (i *AttentionType) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var s string
	if err := unmarshal(&s); err != nil {
		return err
	}

	var err error
	*i, err = AttentionTypeString(s)
	return err
}
--------------------------------------------------------------------------------
/transformers/cache.go:
--------------------------------------------------------------------------------
package transformers

import (
	"fmt"
	"github.com/gomlx/gemma/trees"
	"github.com/gomlx/gomlx/types/tensors"
)

// Cache is a state cache of a (batch of) sequence being encoded/decoded.
//
// It has a fixed size (so typical cached values will be prefixed with the dimensions [BatchSize, Cache.Length]),
// and the current position (where each decode step is stored) is rotating: CurrentStep = (CurrentStep+1)%Cache.Length.
//
// It's stored as a trees.Tree[*tensors.Tensor].
//
// For the Gemma2 model, the first level of the tree being the layer names,
// and the second level hold the "keys" and "values" embedding caches for each transformer layer.
type Cache struct {
	// Config of the model.
	Config *Config

	// BatchSize for this cache.
	BatchSize int

	// Length (in number of steps) of the cache. The cache itself is rotating on this size.
	// It comes from config.MaxCacheLength.
	Length int

	// Data holds the cached data, organized as a trees.Tree[*tensors.Tensor].
	Data *trees.Tree[*tensors.Tensor]
}

// NewCache creates a rotating cache of length config.MaxCacheLength for a batch of the
// given size, with one "k"/"v"/"end_index" entry per transformer layer (under "layer_<i>").
func NewCache(config *Config, batchSize int) (*Cache, error) {
	c := &Cache{
		Config:    config,
		BatchSize: batchSize,
		Length:    config.MaxCacheLength,
		Data:      trees.New[*tensors.Tensor](),
	}

	for layerIdx := range config.NumLayers {
		treePath := []string{fmt.Sprintf("layer_%d", layerIdx)}
		err := createAttentionCache(c.Data, treePath, config.DType, batchSize, config.MaxCacheLength,
			config.NumKVHeads, config.HeadDim)
		if err != nil {
			return nil, err
		}
	}
	return c, nil
}
--------------------------------------------------------------------------------
/transformers/config.go:
--------------------------------------------------------------------------------
package transformers

import (
	"github.com/gomlx/exceptions"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gopjrt/dtypes"
	"github.com/pkg/errors"
	"math"
)

// GemmaType enumerates the supported Gemma model sizes/versions.
type GemmaType int

const (
	UnknownGemmaType GemmaType = iota
	Gemma_2B
	Gemma_7B
	Gemma2_2B
	Gemma2_9B
	Gemma2_27B
)

//go:generate enumer -type=GemmaType -transform=snake -values -text -json -yaml config.go

// numLayersToGemmaClass maps the number of transformer layers found in the
// checkpoint to the corresponding model type.
var numLayersToGemmaClass = map[int]GemmaType{
	18: Gemma_2B,
	28: Gemma_7B,
	26: Gemma2_2B,
	42: Gemma2_9B,
	46: Gemma2_27B,
}

// AttentionType selects the attention variant used by a layer (global vs. local sliding).
type AttentionType int

//go:generate enumer -type=AttentionType -trimprefix=AttentionType -transform=snake -values -text -json -yaml config.go

const (
	AttentionTypeUnknown AttentionType = iota
	AttentionTypeGlobal
	AttentionTypeLocalSliding
)

// QueryPreAttentionNormalisationType defines how to normalize query before attention.
type QueryPreAttentionNormalisationType int

//go:generate enumer -type=QueryPreAttentionNormalisationType -trimprefix=QueryNormType -transform=snake -values -text -json -yaml config.go

const (
	// QueryNormTypeByOneOverSqrtHeadDim indicates whether to scale the query by 1/sqrt(head_dim)
	QueryNormTypeByOneOverSqrtHeadDim QueryPreAttentionNormalisationType = iota

	// QueryNormTypeByEmbedDimDivNumHeads indicates whether to scale the query by `embed_dim // num_heads`
	QueryNormTypeByEmbedDimDivNumHeads

	// QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads indicates whether to scale the query by `1/sqrt(embed_dim // num_heads)`
	QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads
)

// Config Gemma transformer model.
type Config struct {
	Type                GemmaType
	DType               dtypes.DType
	VocabularySize      int
	NumLayers, NumEmbed int

	// HuggingFaceVersion has different shapes for some of the variables.
	HuggingFaceVersion bool

	// EmbedDim is also called "features" in the original code. It is the representation size (last dimension) of the output of the attention layers.
	EmbedDim                       int
	NumHeads, HeadDim              int
	HiddenDim                      int
	NumKVHeads                     int
	FinalLogitSoftCap              float64
	UseQKV, UseGroupQueryAttention bool
	UsePostAttentionNorm, UsePostFFWNorm bool

	AttentionTypes        []AttentionType
	MaxCacheLength        int
	QueryPreAttentionNorm QueryPreAttentionNormalisationType

	// AttentionLogitsSoftCap limits the attention logits (logits = AttentionLogitsSoftCap * tanh(logits/AttentionLogitsSoftCap)).
	// Enabled if > 0.
	AttentionLogitsSoftCap float64
	SlidingWindowSize      int
	TransposeGatingEinsum  bool
}

// NewConfigFromContext creates a transformers config model, based on the structure of the variables in the given context -- the scope
// has to be set directly to the model variables.
90 | func NewConfigFromContext(ctx *context.Context) (*Config, error) { 91 | c := &Config{ 92 | MaxCacheLength: 1024, 93 | QueryPreAttentionNorm: QueryNormTypeByOneOverSqrtHeadDim, 94 | } 95 | 96 | embedTable := ctx.In("embedder").GetVariable("input_embedding") 97 | if embedTable == nil { 98 | return nil, errors.New("context given doesn't have an embedding table defined in \"embedder/input_embedding\"") 99 | } 100 | 101 | c.DType = embedTable.Shape().DType 102 | c.VocabularySize = embedTable.Shape().Dim(0) 103 | c.HuggingFaceVersion = c.VocabularySize == 256000 // Kaggle version is 256128. 104 | 105 | // Find number of layers. 106 | for { 107 | v := ctx.Inf("layer_%d", c.NumLayers).In("pre_attention_norm").GetVariable("scale") 108 | if v == nil { 109 | break 110 | } 111 | c.NumLayers++ 112 | } 113 | if t, found := numLayersToGemmaClass[c.NumLayers]; found { 114 | c.Type = t 115 | } 116 | 117 | switch c.Type { 118 | case Gemma2_2B: 119 | c.setGemma2_2B() 120 | default: 121 | return nil, errors.Errorf("unknown or not implemented for Gemma model type %q", c.Type) 122 | } 123 | 124 | c.UseQKV = c.NumKVHeads == c.NumHeads 125 | c.UseGroupQueryAttention = (c.NumKVHeads != c.NumHeads) && c.NumKVHeads > 1 126 | return c, nil 127 | } 128 | 129 | func (c *Config) setGemma2_2B() { 130 | c.NumLayers = 26 131 | c.NumEmbed = 256128 132 | c.EmbedDim = 2304 133 | c.HiddenDim = 9216 134 | c.NumHeads = 8 135 | c.HeadDim = 256 136 | c.NumKVHeads = 4 137 | c.FinalLogitSoftCap = 30.0 138 | c.AttentionTypes = nil 139 | for _ = range c.NumLayers / 2 { 140 | c.AttentionTypes = append(c.AttentionTypes, []AttentionType{AttentionTypeLocalSliding, AttentionTypeGlobal}...) 141 | } 142 | c.UsePostAttentionNorm = true 143 | c.UsePostFFWNorm = true 144 | c.QueryPreAttentionNorm = QueryNormTypeByOneOverSqrtHeadDim 145 | c.AttentionLogitsSoftCap = 50.0 146 | c.SlidingWindowSize = 4096 147 | } 148 | 149 | // QueryPreAttentionScalar is a multiplier to the query projections. 
150 | func (c *Config) QueryPreAttentionScalar() float64 { 151 | switch c.QueryPreAttentionNorm { 152 | case QueryNormTypeByEmbedDimDivNumHeads: 153 | return float64(c.EmbedDim / c.NumHeads) 154 | case QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads: 155 | return 1.0 / math.Sqrt(float64(c.EmbedDim/c.NumHeads)) 156 | case QueryNormTypeByOneOverSqrtHeadDim: 157 | return 1.0 / math.Sqrt(float64(c.HeadDim)) 158 | default: 159 | exceptions.Panicf("invalid value of QueryPreAttentionNorm = %d, expected one of the valid enum values", c.QueryPreAttentionNorm) 160 | panic(nil) // Quiet lint. 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /transformers/gemmatype_enumer.go: -------------------------------------------------------------------------------- 1 | // Code generated by "enumer -type=GemmaType -transform=snake -values -text -json -yaml config.go"; DO NOT EDIT. 2 | 3 | package transformers 4 | 5 | import ( 6 | "encoding/json" 7 | "fmt" 8 | "strings" 9 | ) 10 | 11 | const _GemmaTypeName = "unknown_gemma_typegemma_2bgemma_7bgemma2_2bgemma2_9bgemma2_27b" 12 | 13 | var _GemmaTypeIndex = [...]uint8{0, 18, 26, 34, 43, 52, 62} 14 | 15 | const _GemmaTypeLowerName = "unknown_gemma_typegemma_2bgemma_7bgemma2_2bgemma2_9bgemma2_27b" 16 | 17 | func (i GemmaType) String() string { 18 | if i < 0 || i >= GemmaType(len(_GemmaTypeIndex)-1) { 19 | return fmt.Sprintf("GemmaType(%d)", i) 20 | } 21 | return _GemmaTypeName[_GemmaTypeIndex[i]:_GemmaTypeIndex[i+1]] 22 | } 23 | 24 | func (GemmaType) Values() []string { 25 | return GemmaTypeStrings() 26 | } 27 | 28 | // An "invalid array index" compiler error signifies that the constant values have changed. 29 | // Re-run the stringer command to generate them again. 
30 | func _GemmaTypeNoOp() { 31 | var x [1]struct{} 32 | _ = x[UnknownGemmaType-(0)] 33 | _ = x[Gemma_2B-(1)] 34 | _ = x[Gemma_7B-(2)] 35 | _ = x[Gemma2_2B-(3)] 36 | _ = x[Gemma2_9B-(4)] 37 | _ = x[Gemma2_27B-(5)] 38 | } 39 | 40 | var _GemmaTypeValues = []GemmaType{UnknownGemmaType, Gemma_2B, Gemma_7B, Gemma2_2B, Gemma2_9B, Gemma2_27B} 41 | 42 | var _GemmaTypeNameToValueMap = map[string]GemmaType{ 43 | _GemmaTypeName[0:18]: UnknownGemmaType, 44 | _GemmaTypeLowerName[0:18]: UnknownGemmaType, 45 | _GemmaTypeName[18:26]: Gemma_2B, 46 | _GemmaTypeLowerName[18:26]: Gemma_2B, 47 | _GemmaTypeName[26:34]: Gemma_7B, 48 | _GemmaTypeLowerName[26:34]: Gemma_7B, 49 | _GemmaTypeName[34:43]: Gemma2_2B, 50 | _GemmaTypeLowerName[34:43]: Gemma2_2B, 51 | _GemmaTypeName[43:52]: Gemma2_9B, 52 | _GemmaTypeLowerName[43:52]: Gemma2_9B, 53 | _GemmaTypeName[52:62]: Gemma2_27B, 54 | _GemmaTypeLowerName[52:62]: Gemma2_27B, 55 | } 56 | 57 | var _GemmaTypeNames = []string{ 58 | _GemmaTypeName[0:18], 59 | _GemmaTypeName[18:26], 60 | _GemmaTypeName[26:34], 61 | _GemmaTypeName[34:43], 62 | _GemmaTypeName[43:52], 63 | _GemmaTypeName[52:62], 64 | } 65 | 66 | // GemmaTypeString retrieves an enum value from the enum constants string name. 67 | // Throws an error if the param is not part of the enum. 
68 | func GemmaTypeString(s string) (GemmaType, error) { 69 | if val, ok := _GemmaTypeNameToValueMap[s]; ok { 70 | return val, nil 71 | } 72 | 73 | if val, ok := _GemmaTypeNameToValueMap[strings.ToLower(s)]; ok { 74 | return val, nil 75 | } 76 | return 0, fmt.Errorf("%s does not belong to GemmaType values", s) 77 | } 78 | 79 | // GemmaTypeValues returns all values of the enum 80 | func GemmaTypeValues() []GemmaType { 81 | return _GemmaTypeValues 82 | } 83 | 84 | // GemmaTypeStrings returns a slice of all String values of the enum 85 | func GemmaTypeStrings() []string { 86 | strs := make([]string, len(_GemmaTypeNames)) 87 | copy(strs, _GemmaTypeNames) 88 | return strs 89 | } 90 | 91 | // IsAGemmaType returns "true" if the value is listed in the enum definition. "false" otherwise 92 | func (i GemmaType) IsAGemmaType() bool { 93 | for _, v := range _GemmaTypeValues { 94 | if i == v { 95 | return true 96 | } 97 | } 98 | return false 99 | } 100 | 101 | // MarshalJSON implements the json.Marshaler interface for GemmaType 102 | func (i GemmaType) MarshalJSON() ([]byte, error) { 103 | return json.Marshal(i.String()) 104 | } 105 | 106 | // UnmarshalJSON implements the json.Unmarshaler interface for GemmaType 107 | func (i *GemmaType) UnmarshalJSON(data []byte) error { 108 | var s string 109 | if err := json.Unmarshal(data, &s); err != nil { 110 | return fmt.Errorf("GemmaType should be a string, got %s", data) 111 | } 112 | 113 | var err error 114 | *i, err = GemmaTypeString(s) 115 | return err 116 | } 117 | 118 | // MarshalText implements the encoding.TextMarshaler interface for GemmaType 119 | func (i GemmaType) MarshalText() ([]byte, error) { 120 | return []byte(i.String()), nil 121 | } 122 | 123 | // UnmarshalText implements the encoding.TextUnmarshaler interface for GemmaType 124 | func (i *GemmaType) UnmarshalText(text []byte) error { 125 | var err error 126 | *i, err = GemmaTypeString(string(text)) 127 | return err 128 | } 129 | 130 | // MarshalYAML implements a YAML 
Marshaler for GemmaType 131 | func (i GemmaType) MarshalYAML() (interface{}, error) { 132 | return i.String(), nil 133 | } 134 | 135 | // UnmarshalYAML implements a YAML Unmarshaler for GemmaType 136 | func (i *GemmaType) UnmarshalYAML(unmarshal func(interface{}) error) error { 137 | var s string 138 | if err := unmarshal(&s); err != nil { 139 | return err 140 | } 141 | 142 | var err error 143 | *i, err = GemmaTypeString(s) 144 | return err 145 | } 146 | -------------------------------------------------------------------------------- /transformers/layers.go: -------------------------------------------------------------------------------- 1 | package transformers 2 | 3 | import ( 4 | "github.com/gomlx/exceptions" 5 | . "github.com/gomlx/gomlx/graph" 6 | "github.com/gomlx/gomlx/ml/context" 7 | "github.com/gomlx/gomlx/ml/context/initializers" 8 | "github.com/gomlx/gomlx/ml/layers/activations" 9 | "github.com/gomlx/gomlx/types/shapes" 10 | ) 11 | 12 | // SoftCap using Tanh, so values won't go beyond +/- cap. If cap <= 0, it is a no-op. 13 | // 14 | // SoftCap(x) = Tanh(x/cap) * cap 15 | func SoftCap(x *Node, cap float64) *Node { 16 | if cap <= 0 { 17 | return x 18 | } 19 | return MulScalar(Tanh(DivScalar(x, cap)), cap) 20 | } 21 | 22 | // KernelEinsum multiplies the input by a kernel of the given shape, using the given graph.EinSum equation. 23 | func KernelEinsum(ctx *context.Context, equation string, x *Node, kernelShape shapes.Shape) *Node { 24 | g := x.Graph() 25 | kernelVar := ctx.VariableWithShape("w", kernelShape) 26 | kernel := kernelVar.ValueGraph(g) 27 | return Einsum(equation, x, kernel) 28 | } 29 | 30 | // RMSNorm normalizes by its root-mean-square x = x / √(mean(sqrt(x), axis=-1) + epsilon) and applies a learned scale. 
func RMSNorm(ctx *context.Context, x *Node) *Node {
	g := x.Graph()
	// variance = mean(x², axis=-1): the mean square of the activations over the last axis
	// (note: the normalization is x / √(mean(x²) + epsilon), i.e. division by the root-mean-square).
	variance := ReduceAndKeep(Square(x), ReduceMean, -1)
	const epsilon = 1e-6
	normalizedX := Mul(x, Rsqrt(AddScalar(variance, epsilon)))

	// Now apply a learned scale.
	scaleVar := ctx.WithInitializer(initializers.Zero).
		VariableWithShape("scale", shapes.Make(x.DType(), x.Shape().Dim(-1)))
	scale := scaleVar.ValueGraph(g)
	scale = ExpandLeftToRank(scale, normalizedX.Rank()) // Expand rank of scale to match normalizedX.
	// Scale centered on 1.0 (so 0.0 has no effect): the zero-initialized variable starts as an identity scale.
	scale = OnePlus(scale)
	normalizedX = Mul(scale, normalizedX)
	return normalizedX
}

// RoPEDefaultMaxWaveLength is a default value to use for rotary positional encoding.
// See ApplyRotaryPositionEncoding.
const RoPEDefaultMaxWaveLength = 10_000

// ApplyRotaryPositionEncoding (aka. RoPE) applies the positional encoding to the operand, given the positions
// (integer numbers of the position).
//
// - operand: the last axis ("features" or "embedding" axis) must be divisible by 2. The shape usually is [batchSize, sequenceSize, numHeads, headDim].
// - positions: its shape must be a prefix to operand. Typically, it's shaped [batchSize, sequenceSize].
// - maxWaveLength: it uses wave lengths in a power scale, up to maxWaveLength -- see RoPEDefaultMaxWaveLength for a reasonable value.
58 | // 59 | // Reference: https://arxiv.org/abs/2104.09864 60 | func ApplyRotaryPositionEncoding(operand, positions *Node, maxWaveLength int) *Node { 61 | g := operand.Graph() 62 | dtype := operand.DType() 63 | featuresDim := operand.Shape().Dim(-1) 64 | if featuresDim <= 0 || featuresDim%2 != 0 { 65 | exceptions.Panicf("ApplyRotaryPositionEncoding(operand=%s, position=%s) requires operand's last "+ 66 | "dimension to be >= 0 and divisible by 2", operand.Shape(), positions.Shape()) 67 | } 68 | 69 | transientDType := dtype 70 | fraction := Iota(g, shapes.Make(transientDType, featuresDim/2), 0) 71 | fraction = MulScalar(fraction, 2.0/float64(featuresDim)) 72 | timeScale := Pow(Scalar(g, transientDType, float64(maxWaveLength)), fraction) 73 | timeScale = ExpandLeftToRank(timeScale, positions.Rank()+1) 74 | 75 | // Angles shape will add a rank to positions: we will take each position at a different wave length (or timeScale). 76 | angles := ConvertDType(ExpandAxes(positions, -1), transientDType) 77 | angles = Div(angles, timeScale) 78 | 79 | // Insert an axis just before the last until it matches the operand's shape. 80 | for angles.Rank() < operand.Rank() { 81 | angles = ExpandDims(angles, -2) 82 | } 83 | sines := Sin(angles) 84 | cosines := Cos(angles) 85 | 86 | // Split first/second half of operands features (the last dimension), and apply rotation at the various wave lengths. 87 | firstHalf := Slice(operand, AxisRange().Spacer(), AxisRange(0, featuresDim/2)) 88 | secondHalf := Slice(operand, AxisRange().Spacer(), AxisRangeToEnd(featuresDim/2)) 89 | firstHalfUpdate := Sub( 90 | Mul(firstHalf, cosines), 91 | Mul(secondHalf, sines), 92 | ) 93 | secondHalfUpdate := Add( 94 | Mul(secondHalf, cosines), 95 | Mul(firstHalf, sines), 96 | ) 97 | return ConvertDType(Concatenate([]*Node{firstHalfUpdate, secondHalfUpdate}, -1), dtype) 98 | } 99 | 100 | // GatedFeedForward layer for Gemma: 101 | // - hiddenDim: one intermediary layer. 
102 | // - transposeGatingEinsum: for some versions of Gemma, the gating (hidden) weights have the axes transposed. 103 | // - It uses Gelu as activation function for the gating signal (multiplied by the up-projected values). 104 | func GatedFeedForward(ctx *context.Context, x *Node, hiddenDim int, transposeGatingEinsum bool) *Node { 105 | g := x.Graph() 106 | featuresDim := x.Shape().Dim(-1) 107 | 108 | var gatingWeights *Node 109 | if transposeGatingEinsum { 110 | // Some versions of Gemma use an alternate parameter ordering that transposes hiddenDim and outputDim. 111 | gatingVar := ctx.WithInitializer(initializers.Zero). 112 | VariableWithShape("gating_einsum", shapes.Make(x.DType(), 2, hiddenDim, featuresDim)) 113 | gatingWeights = gatingVar.ValueGraph(g) 114 | gatingWeights = Transpose(gatingWeights, 1, 2) 115 | } else { 116 | // Standard shape of the gating weights. 117 | gatingVar := ctx.WithInitializer(initializers.Zero). 118 | VariableWithShape("gating_einsum", shapes.Make(x.DType(), 2, featuresDim, hiddenDim)) 119 | gatingWeights = gatingVar.ValueGraph(g) 120 | } 121 | gatingWeights0 := Squeeze(Slice(gatingWeights, AxisElem(0)), 0) 122 | gatingWeights1 := Squeeze(Slice(gatingWeights, AxisElem(1)), 0) 123 | 124 | gateValue := DotGeneral(x, []int{-1}, nil, gatingWeights0, []int{0}, nil) 125 | gateValue = activations.Gelu(gateValue) 126 | 127 | upProjection := DotGeneral(x, []int{-1}, nil, gatingWeights1, []int{0}, nil) 128 | upProjection = Mul(gateValue, upProjection) // Gate upProjection. 129 | 130 | downProjectionVar := ctx.WithInitializer(initializers.Zero). 
131 | VariableWithShape("linear", shapes.Make(x.DType(), hiddenDim, featuresDim)) 132 | downProjectionWeights := downProjectionVar.ValueGraph(g) 133 | output := DotGeneral(upProjection, []int{-1}, nil, downProjectionWeights, []int{0}, nil) 134 | return output 135 | } 136 | 137 | // HuggingFaceGatedFeedForward layer for Gemma, the HuggingFace version, with transposed weights: 138 | // - hiddenDim: one intermediary layer. 139 | // - transposeGatingEinsum: for some versions of Gemma, the gating (hidden) weights have the axes transposed. 140 | // - It uses Gelu as activation function for the gating signal (multiplied by the up-projected values). 141 | func HuggingFaceGatedFeedForward(ctx *context.Context, x *Node, hiddenDim int, transposeGatingEinsum bool) *Node { 142 | ctx = ctx.In("hf") // extra-scope for HuggingFace version. 143 | g := x.Graph() 144 | featuresDim := x.Shape().Dim(-1) 145 | 146 | gatingProjVar := ctx.WithInitializer(initializers.Zero). 147 | VariableWithShape("gating_proj", shapes.Make(x.DType(), hiddenDim, featuresDim)) 148 | gatingWeights := gatingProjVar.ValueGraph(g) 149 | upProjectionVar := ctx.WithInitializer(initializers.Zero). 150 | VariableWithShape("up_proj", shapes.Make(x.DType(), hiddenDim, featuresDim)) 151 | upProjectionWeights := upProjectionVar.ValueGraph(g) 152 | 153 | gateValue := DotGeneral(x, []int{-1}, nil, gatingWeights, []int{1}, nil) 154 | gateValue = activations.Gelu(gateValue) 155 | 156 | upProjection := DotGeneral(x, []int{-1}, nil, upProjectionWeights, []int{1}, nil) 157 | upProjection = Mul(gateValue, upProjection) // Gate upProjection. 158 | 159 | downProjectionVar := ctx.WithInitializer(initializers.Zero). 
160 | VariableWithShape("down_proj", shapes.Make(x.DType(), featuresDim, hiddenDim)) 161 | downProjectionWeights := downProjectionVar.ValueGraph(g) 162 | output := DotGeneral(upProjection, []int{-1}, nil, downProjectionWeights, []int{1}, nil) 163 | return output 164 | } 165 | -------------------------------------------------------------------------------- /transformers/querypreattentionnormalisationtype_enumer.go: -------------------------------------------------------------------------------- 1 | // Code generated by "enumer -type=QueryPreAttentionNormalisationType -trimprefix=QueryNormType -transform=snake -values -text -json -yaml config.go"; DO NOT EDIT. 2 | 3 | package transformers 4 | 5 | import ( 6 | "encoding/json" 7 | "fmt" 8 | "strings" 9 | ) 10 | 11 | const _QueryPreAttentionNormalisationTypeName = "by_one_over_sqrt_head_dimby_embed_dim_div_num_headsby_one_over_sqrt_embed_dim_div_num_heads" 12 | 13 | var _QueryPreAttentionNormalisationTypeIndex = [...]uint8{0, 25, 51, 91} 14 | 15 | const _QueryPreAttentionNormalisationTypeLowerName = "by_one_over_sqrt_head_dimby_embed_dim_div_num_headsby_one_over_sqrt_embed_dim_div_num_heads" 16 | 17 | func (i QueryPreAttentionNormalisationType) String() string { 18 | if i < 0 || i >= QueryPreAttentionNormalisationType(len(_QueryPreAttentionNormalisationTypeIndex)-1) { 19 | return fmt.Sprintf("QueryPreAttentionNormalisationType(%d)", i) 20 | } 21 | return _QueryPreAttentionNormalisationTypeName[_QueryPreAttentionNormalisationTypeIndex[i]:_QueryPreAttentionNormalisationTypeIndex[i+1]] 22 | } 23 | 24 | func (QueryPreAttentionNormalisationType) Values() []string { 25 | return QueryPreAttentionNormalisationTypeStrings() 26 | } 27 | 28 | // An "invalid array index" compiler error signifies that the constant values have changed. 29 | // Re-run the stringer command to generate them again. 
30 | func _QueryPreAttentionNormalisationTypeNoOp() { 31 | var x [1]struct{} 32 | _ = x[QueryNormTypeByOneOverSqrtHeadDim-(0)] 33 | _ = x[QueryNormTypeByEmbedDimDivNumHeads-(1)] 34 | _ = x[QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads-(2)] 35 | } 36 | 37 | var _QueryPreAttentionNormalisationTypeValues = []QueryPreAttentionNormalisationType{QueryNormTypeByOneOverSqrtHeadDim, QueryNormTypeByEmbedDimDivNumHeads, QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads} 38 | 39 | var _QueryPreAttentionNormalisationTypeNameToValueMap = map[string]QueryPreAttentionNormalisationType{ 40 | _QueryPreAttentionNormalisationTypeName[0:25]: QueryNormTypeByOneOverSqrtHeadDim, 41 | _QueryPreAttentionNormalisationTypeLowerName[0:25]: QueryNormTypeByOneOverSqrtHeadDim, 42 | _QueryPreAttentionNormalisationTypeName[25:51]: QueryNormTypeByEmbedDimDivNumHeads, 43 | _QueryPreAttentionNormalisationTypeLowerName[25:51]: QueryNormTypeByEmbedDimDivNumHeads, 44 | _QueryPreAttentionNormalisationTypeName[51:91]: QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads, 45 | _QueryPreAttentionNormalisationTypeLowerName[51:91]: QueryNormTypeByOneOverSqrtEmbedDimDivNumHeads, 46 | } 47 | 48 | var _QueryPreAttentionNormalisationTypeNames = []string{ 49 | _QueryPreAttentionNormalisationTypeName[0:25], 50 | _QueryPreAttentionNormalisationTypeName[25:51], 51 | _QueryPreAttentionNormalisationTypeName[51:91], 52 | } 53 | 54 | // QueryPreAttentionNormalisationTypeString retrieves an enum value from the enum constants string name. 55 | // Throws an error if the param is not part of the enum. 
56 | func QueryPreAttentionNormalisationTypeString(s string) (QueryPreAttentionNormalisationType, error) { 57 | if val, ok := _QueryPreAttentionNormalisationTypeNameToValueMap[s]; ok { 58 | return val, nil 59 | } 60 | 61 | if val, ok := _QueryPreAttentionNormalisationTypeNameToValueMap[strings.ToLower(s)]; ok { 62 | return val, nil 63 | } 64 | return 0, fmt.Errorf("%s does not belong to QueryPreAttentionNormalisationType values", s) 65 | } 66 | 67 | // QueryPreAttentionNormalisationTypeValues returns all values of the enum 68 | func QueryPreAttentionNormalisationTypeValues() []QueryPreAttentionNormalisationType { 69 | return _QueryPreAttentionNormalisationTypeValues 70 | } 71 | 72 | // QueryPreAttentionNormalisationTypeStrings returns a slice of all String values of the enum 73 | func QueryPreAttentionNormalisationTypeStrings() []string { 74 | strs := make([]string, len(_QueryPreAttentionNormalisationTypeNames)) 75 | copy(strs, _QueryPreAttentionNormalisationTypeNames) 76 | return strs 77 | } 78 | 79 | // IsAQueryPreAttentionNormalisationType returns "true" if the value is listed in the enum definition. 
"false" otherwise 80 | func (i QueryPreAttentionNormalisationType) IsAQueryPreAttentionNormalisationType() bool { 81 | for _, v := range _QueryPreAttentionNormalisationTypeValues { 82 | if i == v { 83 | return true 84 | } 85 | } 86 | return false 87 | } 88 | 89 | // MarshalJSON implements the json.Marshaler interface for QueryPreAttentionNormalisationType 90 | func (i QueryPreAttentionNormalisationType) MarshalJSON() ([]byte, error) { 91 | return json.Marshal(i.String()) 92 | } 93 | 94 | // UnmarshalJSON implements the json.Unmarshaler interface for QueryPreAttentionNormalisationType 95 | func (i *QueryPreAttentionNormalisationType) UnmarshalJSON(data []byte) error { 96 | var s string 97 | if err := json.Unmarshal(data, &s); err != nil { 98 | return fmt.Errorf("QueryPreAttentionNormalisationType should be a string, got %s", data) 99 | } 100 | 101 | var err error 102 | *i, err = QueryPreAttentionNormalisationTypeString(s) 103 | return err 104 | } 105 | 106 | // MarshalText implements the encoding.TextMarshaler interface for QueryPreAttentionNormalisationType 107 | func (i QueryPreAttentionNormalisationType) MarshalText() ([]byte, error) { 108 | return []byte(i.String()), nil 109 | } 110 | 111 | // UnmarshalText implements the encoding.TextUnmarshaler interface for QueryPreAttentionNormalisationType 112 | func (i *QueryPreAttentionNormalisationType) UnmarshalText(text []byte) error { 113 | var err error 114 | *i, err = QueryPreAttentionNormalisationTypeString(string(text)) 115 | return err 116 | } 117 | 118 | // MarshalYAML implements a YAML Marshaler for QueryPreAttentionNormalisationType 119 | func (i QueryPreAttentionNormalisationType) MarshalYAML() (interface{}, error) { 120 | return i.String(), nil 121 | } 122 | 123 | // UnmarshalYAML implements a YAML Unmarshaler for QueryPreAttentionNormalisationType 124 | func (i *QueryPreAttentionNormalisationType) UnmarshalYAML(unmarshal func(interface{}) error) error { 125 | var s string 126 | if err := unmarshal(&s); err 
!= nil {
		return err
	}

	var err error
	*i, err = QueryPreAttentionNormalisationTypeString(s)
	return err
}
--------------------------------------------------------------------------------
/transformers/transformers.go:
--------------------------------------------------------------------------------
// Package transformers implements the various Gemma models.
// It is based on https://github.com/google-deepmind/gemma/blob/main/gemma/transformer.py
package transformers

import (
	"fmt"
	"github.com/gomlx/gemma/trees"
	. "github.com/gomlx/gomlx/graph"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/types/shapes"
	"github.com/gomlx/gopjrt/dtypes"
)

// GemmaWithCache creates a forward path on a Gemma model for one decoding step,
// using the weights in Config to initialize the variables.
//
// It takes as input the current tokens to decode, currentTokens (shape [batchSize, 1]), along
// with currentPositions (shape [batchSize, 1]) and the current cache of the key/values for each transformer
// layer (see Cache), whose elements are generally shaped [batchSize, MaxCacheLength, ...].
//
// It updates the Cache with the new step in-place, and returns the logits
// (shape [batchSize, sequenceLength, vocabularySize]) of the prediction of the next token.
func GemmaWithCache(ctx *context.Context, config *Config,
	currentTokens, currentPositions *Node, cache *trees.Tree[*Node], cacheAttentionMask *Node) *Node {
	batchSize := currentTokens.Shape().Dim(0)
	seqLength := currentTokens.Shape().Dim(1)

	// Embed.
	x := EmbedTokens(ctx.In("embedder"), config, currentTokens)

	// Run through numLayers blocks.
	for blockIdx := range config.NumLayers {
		blockName := fmt.Sprintf("layer_%d", blockIdx)
		blockCtx := ctx.In(blockName)
		// Each layer's key/value cache lives in the sub-tree keyed "layer_<n>".
		blockCache := cache.Map[blockName]
		x = Block(blockCtx, config, blockIdx, x, currentPositions, blockCache, cacheAttentionMask)
		//x.SetLogged(fmt.Sprintf("GemmaWithCache::x(%s)", blockName))
		x = Identity(x)
	}

	x = RMSNorm(ctx.In("final_norm"), x)
	// Decoding reuses the embedding table (ctx.Reuse() shares the "embedder" scope variables).
	logits := DecodeTokens(ctx.Reuse().In("embedder"), config, x)
	logits = SoftCap(logits, config.FinalLogitSoftCap)
	logits.AssertDims(batchSize, seqLength, config.VocabularySize)
	return logits
}

// EmbedTokens using weights in Config.
// Input: currentTokens: [batchSize, sequenceLength]
// Output: embeddings: [batchSize, sequenceLength, config.EmbedDim]
func EmbedTokens(ctx *context.Context, config *Config, currentTokens *Node) *Node {
	g := currentTokens.Graph()
	embedTableVar := ctx.VariableWithShape("input_embedding", shapes.Make(dtypes.BFloat16, config.VocabularySize, config.EmbedDim))
	embeddings := Gather(embedTableVar.ValueGraph(g), ExpandAxes(currentTokens, -1))
	// Scale embeddings by sqrt(EmbedDim).
	embeddings = Mul(embeddings, Sqrt(Scalar(g, embeddings.DType(), config.EmbedDim)))
	return embeddings
}

// DecodeTokens use the same table as EmbedTokens to convert embedding back to the tokens -- or to token logits.
// Input: current embeddings: [batchSize, sequenceLength, embedDim]
// Output: logits for each token: [batchSize, sequenceLength, vocabularySize]
func DecodeTokens(ctx *context.Context, config *Config, x *Node) *Node {
	g := x.Graph()
	// Same variable name and shape as EmbedTokens: callers pass a reused scope to share the table.
	embedTableVar := ctx.VariableWithShape("input_embedding", shapes.Make(dtypes.BFloat16, config.VocabularySize, config.EmbedDim))
	embedTable := embedTableVar.ValueGraph(g)
	return DotGeneral(x, []int{-1}, nil, embedTable, []int{-1}, nil)
}

// Block implements one transformer block for the Gemma model.
x is shaped [batchSize, sequenceLength], and if
// using cache (cache != nil), x will only contain the current token, shaped [batchSize, 1].
//
// The attentionIdx indexes attention configuration (in config) parameters, like config.AttentionTypes.
//
// If cache is given, attentionMask is relative to the cache. Otherwise, attentionMask is relative to the operand x.
func Block(ctx *context.Context, config *Config, attentionIdx int, x, positions *Node, cache *trees.Tree[*Node], attentionMask *Node) *Node {
	normalizedX := RMSNorm(ctx.In("pre_attention_norm"), x)

	// Attention
	attentionOut := Attention(ctx.In("attn"), config, attentionIdx, normalizedX, positions, cache, attentionMask)
	if config.UsePostAttentionNorm {
		attentionOut = RMSNorm(ctx.In("post_attention_norm"), attentionOut)
	}

	// Residual (or skip) connection.
	attentionOut = Add(attentionOut, x)

	// GatedFeedForward ("ffw") layer: 2 layers, with a gate.
	output := RMSNorm(ctx.In("pre_ffw_norm"), attentionOut)
	if config.HuggingFaceVersion {
		// HuggingFace checkpoints use a different weight layout (see HuggingFaceGatedFeedForward).
		output = HuggingFaceGatedFeedForward(ctx.In("mlp"), output, config.HiddenDim, config.TransposeGatingEinsum)
	} else {
		output = GatedFeedForward(ctx.In("mlp"), output, config.HiddenDim, config.TransposeGatingEinsum)
	}
	if config.UsePostFFWNorm {
		output = RMSNorm(ctx.In("post_ffw_norm"), output)
	}

	// Residual to attentionOut.
99 | output = Add(output, attentionOut) 100 | return output 101 | } 102 | -------------------------------------------------------------------------------- /trees/trees.go: -------------------------------------------------------------------------------- 1 | package trees 2 | 3 | import ( 4 | "fmt" 5 | "github.com/gomlx/exceptions" 6 | "github.com/gomlx/gomlx/types/xslices" 7 | "github.com/pkg/errors" 8 | "golang.org/x/exp/slices" 9 | "iter" 10 | "strings" 11 | ) 12 | 13 | // Tree represent both a root of a tree, and a tree node. 14 | // 15 | // It can either be a Value or a Map of its children -- but not both. 16 | type Tree[T any] struct { 17 | // Value is set for leaf nodes only. 18 | Value T 19 | 20 | // Map is set for non-leaf nodes (and nil in leaf nodes). 21 | Map map[string]*Tree[T] 22 | } 23 | 24 | func (n *Tree[T]) IsLeaf() bool { return n.Map == nil } 25 | 26 | // Path is usually used as the path from the root node. 27 | type Path []string 28 | 29 | // New creates a new empty tree. 30 | func New[T any]() *Tree[T] { 31 | return &Tree[T]{ 32 | Map: make(map[string]*Tree[T]), 33 | } 34 | } 35 | 36 | // NewLeaf creates a new leaf node with the given value. 37 | func NewLeaf[T any](value T) *Tree[T] { 38 | return &Tree[T]{Value: value} 39 | } 40 | 41 | // DefaultTreePath is used whenever an empty treePath is given. 42 | var DefaultTreePath = []string{"#root"} 43 | 44 | // Get value in treePath. 45 | // It returns an error if such a leaf node doesn't exist. 46 | // 47 | // Empty values in treePath are not used. 48 | func (tree *Tree[T]) Get(treePath ...string) (value T, err error) { 49 | // Remove empty ("") path components -- clone the slice, not to modify caller's slice. 
50 | if slices.Index(treePath, "") > 0 { 51 | treePath = slices.DeleteFunc(slices.Clone(treePath), 52 | func(s string) bool { 53 | return s == "" 54 | }) 55 | } 56 | remainingPath := treePath 57 | pathCount := 0 58 | for len(remainingPath) > 0 { 59 | if tree == nil { 60 | err = errors.Errorf("trees.Tree[%T].Get(%q) can't get to sub-path %q, tree ends in a nil node, can't go forward", 61 | value, treePath, treePath[:pathCount+1]) 62 | } 63 | if tree.IsLeaf() { 64 | err = errors.Errorf("trees.Tree[%T].Get(%q) the sub-path %q ends on a leaf-node, can't go forward", 65 | value, treePath, treePath[:pathCount+1]) 66 | return 67 | } 68 | tree = tree.Map[remainingPath[0]] 69 | remainingPath = remainingPath[1:] 70 | pathCount++ 71 | } 72 | if tree == nil { 73 | err = errors.Errorf("trees.Tree[%T].Get(%q) can't get to sub-path %q, tree ends in a nil node, can't go forward", 74 | value, treePath, treePath[:pathCount+1]) 75 | } 76 | if !tree.IsLeaf() { 77 | err = errors.Errorf("trees.Tree[%T].Get(%q) is not a leaf-node!?", value, treePath) 78 | return 79 | } 80 | value = tree.Value 81 | return 82 | } 83 | 84 | // Set value in treePath, populating intermediary nodes where needed. 85 | // 86 | // Empty values in treePath are not used. 87 | // 88 | // It returns an error if one is trying to set the value to an existing non-leaf node: nodes can either 89 | // be a leaf or a Map (non-leaf), but not both. 90 | func (tree *Tree[T]) Set(treePath Path, value T) error { 91 | // Remove empty ("") path components -- clone the slice, not to modify caller's slice. 
92 | if slices.Index(treePath, "") > 0 { 93 | treePath = slices.DeleteFunc(slices.Clone(treePath), 94 | func(s string) bool { 95 | return s == "" 96 | }) 97 | } 98 | remainingPath := treePath 99 | pathCount := 0 100 | node := tree 101 | for len(remainingPath) > 0 { 102 | pathElement := remainingPath[0] 103 | remainingPath = remainingPath[1:] 104 | if node.IsLeaf() { 105 | var t T 106 | return errors.Errorf("trees.Tree[%T].Set(%q) trying to create a path using an existing leaf node (%q) as a non-leaf node", 107 | t, treePath, treePath[:pathCount]) 108 | } 109 | newNode := node.Map[pathElement] 110 | if newNode == nil { 111 | if len(remainingPath) == 0 { 112 | newNode = NewLeaf[T](value) 113 | } else { 114 | newNode = New[T]() 115 | } 116 | node.Map[pathElement] = newNode 117 | } 118 | node = newNode 119 | pathCount++ 120 | } 121 | if !node.IsLeaf() { 122 | var t T 123 | return errors.Errorf("trees.Tree[%T].Set(%q) trying to set the value to a non-leaf node -- each node can either be a leaf node, or be a structural map of the tree", 124 | t, treePath) 125 | } 126 | node.Value = value 127 | return nil 128 | } 129 | 130 | // String implements fmt.String 131 | func (tree *Tree[T]) String() string { 132 | var parts []string 133 | parts = nodeToString(parts, "/", tree, 0) 134 | return strings.Join(parts, "\n") + "\n" 135 | } 136 | 137 | func nodeToString[T any](parts []string, name string, subTree *Tree[T], indent int) []string { 138 | indentSpaces := strings.Repeat(" ", indent) 139 | indent++ 140 | if len(subTree.Map) == 0 { 141 | // Leaf node. 142 | var valueAny any 143 | valueAny = subTree.Value 144 | if valueStr, ok := valueAny.(fmt.Stringer); ok { 145 | // T is a stringer: 146 | return append(parts, fmt.Sprintf("%s%q: %s", indentSpaces, name, valueStr)) 147 | } 148 | // If not a stringer, use %v. 
149 | return append(parts, fmt.Sprintf("%s%q: %v", indentSpaces, name, subTree.Value)) 150 | } 151 | parts = append(parts, fmt.Sprintf("%s%q: {", indentSpaces, name)) 152 | 153 | for _, key := range xslices.SortedKeys(subTree.Map) { 154 | parts = nodeToString(parts, key, subTree.Map[key], indent) 155 | } 156 | parts = append(parts, fmt.Sprintf("%s}", indentSpaces)) 157 | return parts 158 | } 159 | 160 | // Map converts a Tree[T1] to a Tree[T2] by calling mapFn at every element. 161 | func Map[T1, T2 any](tree1 *Tree[T1], mapFn func(Path, T1) T2) *Tree[T2] { 162 | if tree1.IsLeaf() { 163 | // Input tree1 is just a leaf node: 164 | return NewLeaf[T2](mapFn(nil, tree1.Value)) 165 | } 166 | 167 | tree2 := New[T2]() 168 | for p, t1 := range tree1.Leaves() { 169 | err := tree2.Set(p, mapFn(p, t1)) 170 | if err != nil { 171 | // Should never happen, since there can be no errors duplicating the structure of an existing valid tree. 172 | panic(err) 173 | } 174 | } 175 | return tree2 176 | } 177 | 178 | // Leaves returns an iterator that goes over all the leaf nodes of the Tree. 179 | // The key is a Path, and value is T. 180 | func (tree *Tree[T]) Leaves() iter.Seq2[Path, T] { 181 | return func(yield func(Path, T) bool) { 182 | recursiveLeaves(nil, tree, false, yield) 183 | } 184 | } 185 | 186 | // NumLeaves traverses the trees and returns the number of leaf nodes. 187 | func (tree *Tree[T]) NumLeaves() int { 188 | var count int 189 | for _, _ = range tree.Leaves() { 190 | count++ 191 | } 192 | return count 193 | } 194 | 195 | // OrderedLeaves returns an iterator that goes over all the leaf nodes of the Tree in alphabetical order of the 196 | // tree nodes (depth-first). 197 | // 198 | // The key is a Path, and value is T. 
199 | func (tree *Tree[T]) OrderedLeaves() iter.Seq2[Path, T] { 200 | return func(yield func(Path, T) bool) { 201 | recursiveLeaves(nil, tree, true, yield) 202 | } 203 | } 204 | 205 | func recursiveLeaves[T any](treePath Path, node *Tree[T], ordered bool, yield func(Path, T) bool) bool { 206 | if node.IsLeaf() { 207 | return yield(slices.Clone(treePath), node.Value) 208 | } 209 | if ordered { 210 | // Extract keys and sort first. 211 | for _, key := range xslices.SortedKeys(node.Map) { 212 | subNode := node.Map[key] 213 | ok := recursiveLeaves[T](append(treePath, key), subNode, ordered, yield) 214 | if !ok { 215 | return false 216 | } 217 | } 218 | } else { 219 | // Usual range over map, non-deterministic. 220 | for key, subNode := range node.Map { 221 | ok := recursiveLeaves(append(treePath, key), subNode, ordered, yield) 222 | if !ok { 223 | return false 224 | } 225 | } 226 | } 227 | return true 228 | } 229 | 230 | // ValuesAsList extracts the leaf values of Tree into a list. 231 | // 232 | // It's generated in alphabetical order -- see OrderedLeaves to see or generate the order. 233 | func ValuesAsList[T any](tree *Tree[T]) []T { 234 | results := make([]T, 0, tree.NumLeaves()) 235 | for _, values := range tree.OrderedLeaves() { 236 | results = append(results, values) 237 | } 238 | return results 239 | } 240 | 241 | // FromValuesAndTree creates a Tree[T1] with the given values, but borrowing the structure from the given tree (but 242 | // ignoring the tree's values). 
243 | func FromValuesAndTree[T1, T2 any](values []T1, tree *Tree[T2]) *Tree[T1] { 244 | numLeaves := tree.NumLeaves() 245 | if len(values) != numLeaves { 246 | exceptions.Panicf("%d values given, but the tree to be built has %d leaves.", len(values), numLeaves) 247 | } 248 | newTree := New[T1]() 249 | var idx int 250 | for treePath, _ := range tree.OrderedLeaves() { 251 | err := newTree.Set(treePath, values[idx]) 252 | if err != nil { 253 | // Should never happen, since there can be no errors duplicating the structure of an existing valid tree. 254 | panic(err) 255 | } 256 | idx++ 257 | } 258 | return newTree 259 | } 260 | -------------------------------------------------------------------------------- /trees/trees_test.go: -------------------------------------------------------------------------------- 1 | package trees 2 | 3 | import ( 4 | "fmt" 5 | "github.com/stretchr/testify/require" 6 | "testing" 7 | ) 8 | 9 | type expectedTreeValueType[T any] struct { 10 | p Path 11 | v T 12 | } 13 | 14 | func verifyTreeValues[T any](t *testing.T, tree *Tree[T], wantValues []expectedTreeValueType[T]) { 15 | for _, want := range wantValues { 16 | got, err := tree.Get(want.p...) 
17 | require.NoError(t, err) 18 | require.Equal(t, want.v, got) 19 | } 20 | count := 0 21 | for p, v := range tree.OrderedLeaves() { 22 | if count >= len(wantValues) { 23 | t.Fatalf("tree ranged over more leaves than the %d expected", len(wantValues)) 24 | } 25 | require.Equalf(t, wantValues[count].p, p, "Unexpected path %q -- maybe out-of-order?", p) 26 | require.Equalf(t, wantValues[count].v, v, "Unexpected value for path %q", p) 27 | count++ 28 | } 29 | if count != len(wantValues) { 30 | t.Fatalf("tree only ranged over %d leaf-values, but we expected %d values", count, len(wantValues)) 31 | } 32 | } 33 | 34 | func createTestTree(t *testing.T) *Tree[int] { 35 | tree := New[int]() 36 | require.NoError(t, tree.Set([]string{"a"}, 1)) 37 | require.NoError(t, tree.Set([]string{"b", "y"}, 3)) 38 | require.NoError(t, tree.Set([]string{"b", "x"}, 2)) 39 | return tree 40 | } 41 | 42 | func TestNewAndSet(t *testing.T) { 43 | tree := createTestTree(t) 44 | fmt.Printf("Tree:\n%v\n", tree) 45 | 46 | require.Equal(t, 1, tree.Map["a"].Value) 47 | require.Equal(t, 2, tree.Map["b"].Map["x"].Value) 48 | require.Equal(t, 3, tree.Map["b"].Map["y"].Value) 49 | 50 | err := tree.Set([]string{"b"}, 4) 51 | fmt.Printf("\texpected error trying to set non-leaf node: %v\n", err) 52 | require.ErrorContains(t, err, "trying to set the value to a non-leaf node") 53 | 54 | err = tree.Set([]string{"b", "x", "0"}, 5) 55 | fmt.Printf("\texpected error trying to use leaf node as structure: %v\n", err) 56 | require.ErrorContains(t, err, "trying to create a path using an existing leaf node") 57 | 58 | tree2 := NewLeaf(float32(7)) 59 | fmt.Printf("Tree:\n%v\n", tree2) 60 | require.NoError(t, tree2.Set(nil, float32(11))) 61 | require.Equal(t, float32(11), tree2.Value) 62 | } 63 | 64 | func TestOrderedLeaves(t *testing.T) { 65 | tree := createTestTree(t) 66 | fmt.Printf("Tree:\n%v\n", tree) 67 | // Test OrderedLeaves traversal and that the contents of the tree match. 
68 | verifyTreeValues(t, tree, []expectedTreeValueType[int]{ 69 | {Path{"a"}, 1}, 70 | {Path{"b", "x"}, 2}, 71 | {Path{"b", "y"}, 3}, 72 | }) 73 | } 74 | 75 | func TestMap(t *testing.T) { 76 | tree := createTestTree(t) 77 | fmt.Printf("Tree:\n%v\n", tree) 78 | treeFloat := Map(tree, func(_ Path, v int) float32 { return float32(v) }) 79 | verifyTreeValues(t, treeFloat, []expectedTreeValueType[float32]{ 80 | {Path{"a"}, 1}, 81 | {Path{"b", "x"}, 2}, 82 | {Path{"b", "y"}, 3}, 83 | }) 84 | 85 | tree2 := NewLeaf(float32(7)) 86 | fmt.Printf("Tree:\n%v\n", tree2) 87 | tree2Int := Map(tree2, func(_ Path, v float32) int { return int(v) }) 88 | verifyTreeValues(t, tree2Int, []expectedTreeValueType[int]{ 89 | {nil, 7}, 90 | }) 91 | } 92 | 93 | func TestValuesAsList(t *testing.T) { 94 | tree := createTestTree(t) 95 | fmt.Printf("Tree:\n%v\n", tree) 96 | require.Equal(t, []int{1, 2, 3}, ValuesAsList(tree)) 97 | 98 | tree2 := NewLeaf(float32(7)) 99 | fmt.Printf("Tree:\n%v\n", tree2) 100 | require.Equal(t, []float32{7}, ValuesAsList(tree2)) 101 | } 102 | 103 | func TestFromValuesAndTree(t *testing.T) { 104 | tree := createTestTree(t) 105 | newValues := []float64{1.01, 2.02, 3.03} 106 | newTree := FromValuesAndTree(newValues, tree) 107 | fmt.Printf("New Tree:\n%v\n", newTree) 108 | verifyTreeValues(t, newTree, []expectedTreeValueType[float64]{ 109 | {Path{"a"}, 1.01}, 110 | {Path{"b", "x"}, 2.02}, 111 | {Path{"b", "y"}, 3.03}, 112 | }) 113 | } 114 | --------------------------------------------------------------------------------