├── .gitignore ├── LICENSE ├── README.md ├── block.go ├── data ├── data.go ├── data_test.go ├── fairy_tales.txt ├── fairy_vocab ├── jules_verne.txt └── vocab ├── go.mod ├── go.sum ├── head.go ├── layer.go ├── main.go ├── main_test.go └── pkg ├── adamw.go ├── cat.go ├── cat_test.go ├── functions.go ├── matmul.go ├── matmul_test.go ├── mean.go ├── mean_test.go ├── params.go ├── softmax_test.go ├── variance.go └── variance_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/** 2 | .idea/** 3 | **/.DS_Store 4 | /model-* 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Artem Zakirullin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | gptgo 2 | 3 | # gpt-go 4 | A simple GPT implementation in pure Go, trained on favourite Jules Verne books. 5 | 6 | Here's the kind of response you can expect from the model: 7 | ``` 8 | Mysterious Island. 9 | Well. 10 | My days must follow 11 | ``` 12 | 13 | Or this: 14 | ``` 15 | Captain Nemo, in two hundred thousand feet weary in 16 | the existence of the world. 17 | ``` 18 | 19 | ## How to run 20 | ```shell 21 | $ go run . 22 | ``` 23 | 24 | It takes about 40 minutes to train on a MacBook Air M3. You can train on your own dataset by pointing the `//go:embed` directive above the `data.dataset` variable (in `data/data.go`) at your text corpus. 25 | 26 | To run in chat-only mode once the training is done: 27 | ```shell 28 | $ go run . -chat 29 | ``` 30 | 31 | ## How to understand 32 | You can use this repository as a companion to the [Neural Networks: Zero to Hero](https://karpathy.ai/zero-to-hero.html) course. Check out one of these tags with `git checkout` (e.g. `git checkout bigram`) to see how the model has evolved over time: `naive`, `bigram`, `multihead`, `block`, `residual`, `full`. 33 | 34 | In [main_test.go](https://github.com/zakirullin/gpt-go/blob/main/main_test.go) you will find explanations starting from a basic neuron example: 35 | ```go 36 | // Our neuron has 2 inputs and 1 output (number of columns in weight matrix). 37 | // Its goal is to predict next number in the sequence.
38 | input := V{1, 2} // {x1, x2} 39 | weight := M{ 40 | {2}, // how much x1 contributes to the output 41 | {3}, // how much x2 contributes to the output 42 | } 43 | ``` 44 | 45 | All the way to the self-attention mechanism: 46 | ```go 47 | // To calculate the sum of all previous tokens, we can multiply by this triangular matrix: 48 | tril := M{ 49 | {1, 0, 0, 0}, // first token attends only at itself ("cat"), it can't look into the future 50 | {1, 1, 0, 0}, // second token attends at itself and the previous token ( "cat" + ", ") 51 | {1, 1, 1, 0}, // third token attends at itself and the two previous tokens ("cat" + ", " + "dog") 52 | {1, 1, 1, 1}, // fourth token attends at itself and all the previous tokens ("cat" + ", " + "dog" + " and") 53 | }.Var() 54 | // So, at this point each embedding is enriched with the information from all the previous tokens. 55 | // That's the crux of self-attention. 56 | enrichedEmbeds := MatMul(tril, inputEmbeds) 57 | ``` 58 | 59 | ## Design choices 60 | No batches. 61 | I've given up the complexity of the batch dimension for the sake of clarity. It's far easier to build intuition with 2D matrices than with 3D tensors. Besides, batches aren't inherent to the transformer architecture. Gradient accumulation was tried as an alternative, but its effect was negligible, so it was removed as well. 62 | 63 | Removed `gonum`. 64 | Using `gonum` for matmul gave us a ~30% performance boost, but it brought an additional dependency. We're not striving for maximum efficiency here, but for radical simplicity. The current matmul implementation is quite efficient, and it's only 40 lines of plain, readable code.
65 | 66 | ## Papers 67 | You don't need to read them to understand the code :) 68 | 69 | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) 70 | [Deep Residual Learning](https://arxiv.org/abs/1512.03385) 71 | [DeepMind WaveNet](https://arxiv.org/abs/1609.03499) 72 | [Batch Normalization](https://arxiv.org/abs/1502.03167) 73 | [Deep NN + huge data = breakthrough performance](https://papers.nips.cc/paper_files/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html) 74 | [OpenAI GPT-3 paper](https://arxiv.org/abs/2005.14165) 75 | [Analyzing the Structure of Attention](https://arxiv.org/abs/1906.04284) 76 | 77 | ## Credits 78 | Many thanks to [Andrej Karpathy](https://github.com/karpathy) for his brilliant [Neural Networks: Zero to Hero](https://karpathy.ai/zero-to-hero.html) course. 79 | 80 | Thanks to [@itsubaki](https://github.com/itsubaki) for his elegant [autograd](https://github.com/itsubaki/autograd) package. 81 | -------------------------------------------------------------------------------- /block.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/itsubaki/autograd/function" 5 | "github.com/itsubaki/autograd/layer" 6 | "github.com/itsubaki/autograd/variable" 7 | 8 | "github.com/zakirullin/gpt-go/pkg" 9 | ) 10 | 11 | var ( 12 | Zeros = variable.Zero 13 | Ones = pkg.Ones 14 | ReLU = function.ReLU 15 | Dropout = function.DropoutSimple 16 | MatMul = pkg.MatMul 17 | Add = variable.Add 18 | Sub = variable.Sub 19 | MulC = variable.MulC 20 | Transpose = variable.Transpose 21 | Softmax = function.Softmax 22 | SoftmaxCrossEntropy = function.SoftmaxCrossEntropy 23 | RandEmbeds = pkg.Normal 24 | Rows = pkg.Rows 25 | Val = pkg.Val 26 | Flat = pkg.Flat 27 | ) 28 | 29 | // Block is one transformer layer: multi-head self-attention followed by a feed-forward MLP, 30 | // each wrapped in a pre-norm residual connection. 31 | type Block struct { 32 | embedSize int 33 | headCount int 34 | saHead *MultiHeadAttention 35 | mlp *Linear // multi-layer perceptron 36 | mlpProj *Linear // projects the output of the MLP back to the
original embedding size 37 | norm1 *LayerNorm 38 | norm2 *LayerNorm 39 | } 40 | 41 | // NewBlock creates a Block for embeds of width embedSize, with numHeads attention heads. 42 | // The MLP's hidden layer is 4x wider than the embedding. 43 | func NewBlock(embedSize, numHeads int) *Block { 44 | return &Block{ 45 | embedSize: embedSize, 46 | headCount: numHeads, 47 | saHead: NewMultiHeadAttention(embedSize, numHeads), 48 | mlp: NewLinear(embedSize, embedSize*4), 49 | mlpProj: NewLinear(embedSize*4, embedSize), 50 | norm1: NewLayerNorm(embedSize), 51 | norm2: NewLayerNorm(embedSize), 52 | } 53 | } 54 | 55 | // Forward runs the block on input and returns a variable of the same shape. 56 | // 57 | // Both sub-layers use pre-norm residuals, x = x + f(norm(x)): only the branch sees the normalized 58 | // embeds, while the un-normalized input is carried straight through Add. Input is our highway, 59 | // so the gradient flows back through the whole stack unimpeded. 60 | func (b *Block) Forward(input *variable.Variable) *variable.Variable { 61 | // Self-attention with a residual connection. 62 | normed := b.norm1.Forward(input) // Normalize every token's embed (mean=0, var=1), for the branch only 63 | saOut := b.saHead.Forward(normed) // Encode relationships between positions, (blockSize, embedSize) 64 | input = Add(input, saOut) // Add attention output to the un-normalized main path 65 | 66 | // Feed-forward network with a residual connection. 67 | normed = b.norm2.Forward(input) // Normalize again, for the branch only 68 | mlpExpanded := b.mlp.Forward(normed) // Expand to higher dimension (4x embedSize) 69 | mlpActivated := ReLU(mlpExpanded) // Apply activation function 70 | mlpOutput := b.mlpProj.Forward(mlpActivated) // Project back to original dimension 71 | mlpOutput = Dropout(dropout)(mlpOutput) // Dropping out some neurons to prevent overfitting 72 | input = Add(input, mlpOutput) // Add feed-forward output to the main path 73 | 74 | return input 75 | } 76 | 77 | // Params returns all trainable parameters of the block in a fixed order: attention, MLP 78 | // (weight, bias), MLP projection (weight, bias), then both layer norms (scale, shift). 79 | func (b *Block) Params() []layer.Parameter { 80 | var params []layer.Parameter 81 | params = append(params, b.saHead.Params()...) 82 | params = append(params, b.mlp.Weight, b.mlp.Bias) 83 | params = append(params, b.mlpProj.Weight, b.mlpProj.Bias) 84 | params = append(params, b.norm1.Scale, b.norm1.Shift) 85 | params = append(params, b.norm2.Scale, b.norm2.Shift) 86 | 87 | return params 88 | } 89 | 
-------------------------------------------------------------------------------- /data/data.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | _ "embed" 5 | "fmt" 6 | "math/rand/v2" 7 | "regexp" 8 | "slices" 9 | "strings" 10 | "unicode/utf8" 11 | 12 | "github.com/itsubaki/autograd/variable" 13 | ) 14 | 15 | var ( 16 | //go:embed jules_verne.txt 17 | dataset string 18 | //go:embed vocab 19 | vocab string 20 | 21 | tokenToID map[string]int 22 | idToToken map[int]string 23 | mergeRules map[int64]int 24 | rulesOrder []int64 25 | 26 | Dataset = func() string { return dataset } 27 | Vocab = func() string { return vocab } 28 | RandInt = rand.IntN 29 | 30 | // mergeRuleRe parses one line of the vocab file: [left][right] -> [merged]. Compiled once, not per rule. 31 | mergeRuleRe = regexp.MustCompile(`\[(.*?)\]\[(.*?)\] -> \[(.*?)\]`) 32 | ) 33 | 34 | // Tokenize rebuilds the vocabulary from scratch: every distinct character of the dataset becomes a 35 | // base token, then up to numMerges merge rules from the vocab file each add at most one more token. 36 | // It returns the encoded dataset and the resulting vocabulary size. 37 | func Tokenize(numMerges int) ([]float64, int) { 38 | tokenToID = make(map[string]int) 39 | idToToken = make(map[int]string) 40 | mergeRules = make(map[int64]int) 41 | rulesOrder = nil 42 | 43 | normDataset := normNewLines(Dataset()) 44 | addCharsToVocab(normDataset) 45 | createMergeRules(Vocab(), numMerges) 46 | 47 | return Encode(normDataset), VocabSize() 48 | } 49 | 50 | // Encode converts s into token IDs: one token per character first, then every merge rule is applied 51 | // in load order, replacing non-overlapping (left, right) pairs, scanned left to right, with the merged token. 52 | // It panics if s contains a character that is missing from the vocabulary. 53 | func Encode(s string) []float64 { 54 | var tokens []float64 55 | for _, ch := range s { 56 | tok, ok := tokenToID[string(ch)] 57 | if !ok { 58 | panic(fmt.Sprintf("char '%s' is missing from vocabulary", string(ch))) 59 | } 60 | tokens = append(tokens, float64(tok)) 61 | } 62 | 63 | for _, rule := range rulesOrder { 64 | newTokens := make([]float64, 0, len(tokens)) // a merge pass never yields more tokens than it reads 65 | tok1, tok2 := unzip(rule) 66 | // Try to apply rule on every pair of tokens 67 | for i := 0; i < len(tokens); { 68 | hasNextToken := i+1 < len(tokens) 69 | shouldMerge := hasNextToken && int(tokens[i]) == tok1 && int(tokens[i+1]) == tok2 70 | if shouldMerge { 71 | newTokens = append(newTokens, float64(mergeRules[rule])) 72 | i += 2 // eat two tokens 73 | } else { 74 | newTokens = append(newTokens, tokens[i]) 75 | i++ // eat one token 76 | } 77 | } 78 | tokens = newTokens 79 | } 80 | 81 | return tokens 82 | } 83 | 
84 | // Decode converts token IDs back into text. It panics on an ID that is not in the vocabulary. 85 | func Decode(indices ...float64) string { 86 | var result strings.Builder 87 | 88 | for _, idx := range indices { 89 | id := int(idx) 90 | token, ok := idToToken[id] 91 | if !ok { 92 | panic(fmt.Sprintf("unknown token id=%d", id)) 93 | } 94 | result.WriteString(token) 95 | } 96 | 97 | return result.String() 98 | } 99 | 100 | // VocabSize returns the number of tokens currently in the vocabulary. 101 | func VocabSize() int { 102 | return len(tokenToID) 103 | } 104 | 105 | // Sample returns a random sample of data of the given block size: x holds blockSize consecutive 106 | // tokens and y holds the same window shifted one token ahead, i.e. the next-token target for each x. 107 | // It panics if data has fewer than blockSize+1 tokens. 108 | func Sample(data []float64, blockSize int) (*variable.Variable, *variable.Variable) { 109 | // Every offset in [0, len(data)-blockSize) is valid, since y reads up to data[offset+blockSize]. 110 | numOffsets := len(data) - blockSize 111 | if numOffsets <= 0 { 112 | panic("not enough data for the given block size") 113 | } 114 | 115 | offset := RandInt(numOffsets) 116 | 117 | x := make([]float64, blockSize) 118 | y := make([]float64, blockSize) 119 | copy(x, data[offset:offset+blockSize]) 120 | copy(y, data[offset+1:offset+blockSize+1]) 121 | 122 | return variable.New(x...), variable.New(y...) 123 | } 124 | 125 | // Chars returns every single-character token of the vocabulary, sorted and concatenated. 126 | func Chars() string { 127 | var tokens []string 128 | for token := range tokenToID { 129 | // Count runes rather than bytes so multi-byte characters (e.g. accented letters) are listed too. 130 | if utf8.RuneCountInString(token) == 1 { 131 | tokens = append(tokens, token) 132 | } 133 | } 134 | slices.Sort(tokens) 135 | 136 | return strings.Join(tokens, "") 137 | } 138 | 139 | // addCharsToVocab adds every distinct character of text to the vocabulary, in order of first appearance. 140 | func addCharsToVocab(text string) { 141 | for _, ch := range text { 142 | addTokensToVocab(string(ch)) // one rune at a time: no slice holding a string per corpus character 143 | } 144 | } 145 | 
146 | // createMergeRules parses up to numMerges rules, one "[left][right] -> [merged]" per line, 147 | // adds each merged token to the vocabulary and records the rules in file order. 148 | // It panics on a malformed line or when a rule refers to a token missing from the vocabulary. 149 | func createMergeRules(rules string, numMerges int) { 150 | rules = strings.TrimSpace(rules) 151 | if len(rules) == 0 { 152 | return 153 | } 154 | 155 | merges := strings.Split(rules, "\n") 156 | merges = merges[:min(numMerges, len(merges))] 157 | 158 | // Mint new tokens, save merge rules. 159 | for _, m := range merges { 160 | matches := mergeRuleRe.FindStringSubmatch(m) 161 | if len(matches) != 4 { 162 | panic(fmt.Sprintf("invalid vocab format: %s", m)) 163 | } 164 | 165 | // Unescape the two-character sequence \n into a real newline in all tokens; 166 | // no other escape sequences are processed. 167 | left := strings.ReplaceAll(matches[1], "\\n", "\n") 168 | right := strings.ReplaceAll(matches[2], "\\n", "\n") 169 | mergedToken := strings.ReplaceAll(matches[3], "\\n", "\n") 170 | 171 | addTokensToVocab(mergedToken) 172 | 173 | for _, token := range []string{left, right, mergedToken} { 174 | if _, ok := tokenToID[token]; !ok { 175 | panic(fmt.Sprintf("rule '%s' is malformed, token '%s' is missing from vocabulary", m, token)) 176 | } 177 | } 178 | addRule(tokenToID[left], tokenToID[right], tokenToID[mergedToken]) 179 | } 180 | } 181 | 182 | // addTokensToVocab gives every token that is not yet in the vocabulary the next free ID. 183 | func addTokensToVocab(tokens ...string) { 184 | for _, token := range tokens { 185 | if _, ok := tokenToID[token]; ok { 186 | continue 187 | } 188 | 189 | tokenID := len(tokenToID) 190 | tokenToID[token] = tokenID 191 | idToToken[tokenID] = token 192 | } 193 | } 194 | 195 | // addRule records that the pair (tok1, tok2) merges into mergedTok; Encode applies rules in the order added. 196 | func addRule(tok1, tok2, mergedTok int) { 197 | key := zip(tok1, tok2) 198 | mergeRules[key] = mergedTok 199 | rulesOrder = append(rulesOrder, key) 200 | } 201 | 202 | // normNewLines converts Windows (\r\n) and old Mac (\r) line endings to \n. 203 | func normNewLines(text string) string { 204 | text = strings.ReplaceAll(text, "\r\n", "\n") // Windows first, so \r\n doesn't turn into two newlines 205 | return strings.ReplaceAll(text, "\r", "\n") // then any remaining old Mac line endings 206 | } 207 | 208 | // zip packs two token IDs into a single int64 key: tok1 in the high 32 bits, tok2 in the low 32 bits. 209 | func zip(tok1, tok2 int) int64 { 210 | return int64(tok1)<<32 | int64(tok2&0xFFFFFFFF) 211 | } 212 | 
213 | // unzip is the inverse of zip: it splits a key back into its two token IDs. 214 | func unzip(tok int64) (int, int) { 215 | tok1 := int(tok >> 32) 216 | tok2 := int(tok & 0xFFFFFFFF) 217 | return tok1, tok2 218 | } 219 | -------------------------------------------------------------------------------- /data/data_test.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | "testing" 7 | 8 | "github.com/itsubaki/autograd/variable" 9 | ) 10 | 11 | func TestTokenize(t *testing.T) { 12 | Dataset = func() string { 13 | return "abc" 14 | } 15 | Vocab = func() string { 16 | return "[a][b] -> [z]" 17 | } 18 | encoded, vocabSize := Tokenize(1) 19 | 20 | areSlicesEqual(t, []float64{3, 2}, encoded) 21 | areEqual(t, 4, vocabSize) 22 | } 23 | 24 | func TestEncode(t *testing.T) { 25 | Dataset = func() string { 26 | return "abcd" 27 | } 28 | Vocab = func() string { 29 | return "[a][a] -> [Z]\n[a][b] -> [Y]\n[Z][Y] -> [X]" 30 | } 31 | 32 | Tokenize(3) 33 | 34 | // aaabdaaabac 35 | // ZabdZabac 36 | // ZYdZYac 37 | // XdXac 38 | encoded := Encode("aaabdaaabac") 39 | areSlicesEqual(t, []float64{6, 3, 6, 0, 2}, encoded) 40 | } 41 | 42 | func TestDecode(t *testing.T) { 43 | Dataset = func() string { 44 | return "abcd" 45 | } 46 | Vocab = func() string { 47 | return "[a][a] -> [aa]\n[a][b] -> [ab]\n[aa][ab] -> [aaab]" 48 | } 49 | 50 | Tokenize(3) 51 | 52 | decoded := Decode([]float64{6, 3, 6, 0, 2}...) 53 | 54 | areEqual(t, "aaabdaaabac", decoded) 55 | } 56 | 57 | func TestEncodeDecodeDifferentNewLines(t *testing.T) { 58 | Dataset = func() string { 59 | return "a\nb\r\nc" 60 | } 61 | Vocab = func() string { 62 | return "" 63 | } 64 | 65 | encoded, _ := Tokenize(3) 66 | areSlicesEqual(t, []float64{0, 1, 2, 1, 3}, encoded) 67 | 68 | decoded := Decode([]float64{0, 1, 2, 1, 3}...)
69 | areEqual(t, "a\nb\nc", decoded) 70 | } 71 | 72 | func TestEncodeDecodeTokenizedNewLines(t *testing.T) { 73 | Dataset = func() string { 74 | return "a\nb\n\nc" 75 | } 76 | Vocab = func() string { 77 | return "[\\n][\\n] -> [\\n\\n]" // the vocab file spells a newline as the two characters \n 78 | } 79 | 80 | encoded, _ := Tokenize(1) 81 | areSlicesEqual(t, []float64{0, 1, 2, 4, 3}, encoded) // a=0, \n=1, b=2, c=3, merged \n\n=4 82 | 83 | decoded := Decode([]float64{0, 1, 2, 4, 3}...) 84 | areEqual(t, "a\nb\n\nc", decoded) 85 | } 86 | 87 | func TestZipUnzip(t *testing.T) { 88 | zipped := zip(1, 2) 89 | expected := int64(4294967298) // 1<<32 | 2 90 | areEqual(t, zipped, expected) 91 | x, y := unzip(expected) 92 | areEqual(t, 1, x) 93 | areEqual(t, 2, y) 94 | } 95 | 96 | func TestAddTokensFromText(t *testing.T) { 97 | idToToken = make(map[int]string) 98 | tokenToID = make(map[string]int) 99 | 100 | testText := "hello world" 101 | addCharsToVocab(testText) 102 | 103 | contains := true 104 | for _, token := range []string{"h", "e", "l", "o", " ", "w", "r", "d"} { 105 | if _, exists := tokenToID[token]; !exists { 106 | contains = false 107 | break 108 | } 109 | } 110 | 111 | areEqual(t, 8, len(tokenToID)) // "hello world" has 8 distinct characters, each added once 112 | areEqual(t, true, contains) 113 | } 114 | 115 | func TestSample(t *testing.T) { 116 | // Ten consecutive tokens; RandInt is stubbed below so the sampled window starts at offset 0. 117 | testData := V{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} 118 | blockSize := 3 119 | 120 | RandInt = func(_ int) int { return 0 } 121 | defer func() { 122 | RandInt = rand.Intn // NOTE(review): restores math/rand (v1) Intn, while data.go defaults to math/rand/v2 IntN 123 | }() 124 | 125 | x, y := Sample(testData, blockSize) 126 | areMatricesEqual(t, M{ 127 | {0, 1, 2}, 128 | }, x) 129 | areMatricesEqual(t, M{ 130 | {1, 2, 3}, // targets: the same window shifted one token ahead 131 | }, y) 132 | } 133 | 134 | func TestNormNewLinesEmptyString(t *testing.T) { 135 | input := "" 136 | expected := "" 137 | result := normNewLines(input) 138 | areEqual(t, expected, result) 139 | } 140 | 141 | func TestNormNewLinesSingleLine(t *testing.T) { 142 | input := "hello world" 143 | expected := "hello world" 144 | result := normNewLines(input) 145 | areEqual(t, expected, result) 146 | } 147 | 148 | func
TestNormNewLinesLinuxLineEndings(t *testing.T) { 149 | input := "line1\nline2\nline3" 150 | expected := "line1\nline2\nline3" 151 | result := normNewLines(input) 152 | areEqual(t, expected, result) 153 | } 154 | 155 | func TestNormNewLinesMixedContent(t *testing.T) { 156 | input := "title\n\nsome content\nmore content\n\nfooter" 157 | expected := "title\n\nsome content\nmore content\n\nfooter" 158 | result := normNewLines(input) 159 | areEqual(t, expected, result) 160 | } 161 | 162 | func TestNormNewLinesWindowsLineEndings(t *testing.T) { 163 | input := "line1\r\nline2\r\nline3" 164 | expected := "line1\nline2\nline3" 165 | result := normNewLines(input) 166 | areEqual(t, expected, result) 167 | } 168 | 169 | func TestNormNewLinesOldMacLineEndings(t *testing.T) { 170 | input := "line1\rline2\rline3" 171 | expected := "line1\nline2\nline3" 172 | result := normNewLines(input) 173 | areEqual(t, expected, result) 174 | } 175 | 176 | func TestNormNewLinesMixedLineEndings(t *testing.T) { 177 | input := "line1\nline2\r\nline3\rline4" 178 | expected := "line1\nline2\nline3\nline4" 179 | result := normNewLines(input) 180 | areEqual(t, expected, result) 181 | } 182 | // areEqual reports a test error (without stopping the test) when want != got. 183 | func areEqual[V comparable](t *testing.T, want, got V) { 184 | t.Helper() 185 | if want != got { 186 | t.Errorf("want: %v, got: %v", want, got) 187 | } 188 | } 189 | // areSlicesEqual reports a test error when want and got differ in length or in any element. 190 | func areSlicesEqual[T comparable](t *testing.T, want, got []T) { 191 | t.Helper() 192 | 193 | if len(want) != len(got) { 194 | t.Errorf("length mismatch: want %d elements, got %d elements", len(want), len(got)) 195 | t.Errorf("want: %v, got: %v", want, got) 196 | return 197 | } 198 | 199 | for i := range want { 200 | if want[i] != got[i] { 201 | t.Errorf("want: %v, got: %v", want, got) 202 | return 203 | } 204 | } 205 | } 206 | // areMatricesEqual reports a test error when got.Data differs from want in shape or by more than 1e-9 in any element. 207 | func areMatricesEqual(t *testing.T, want M, got *variable.Variable) { 208 | t.Helper() 209 | if len(want) != len(got.Data) { 210 | t.Errorf("matrix length mismatch: want length=%d, got length=%d", len(want),
len(got.Data)) 211 | return 212 | } 213 | 214 | for i := range want { 215 | if len(want[i]) != len(got.Data[i]) { 216 | t.Errorf("matrix row length mismatch at row %d: want length=%d, got length=%d", i, len(want[i]), len(got.Data[i])) 217 | return 218 | } 219 | } 220 | 221 | for i := range want { 222 | for j := range want[i] { 223 | if math.Abs(want[i][j]-got.Data[i][j]) > 1e-9 { 224 | t.Errorf("matrix mismatch at row %d, column %d: want %v, got %v", i, j, want[i][j], got.Data[i][j]) 225 | } 226 | } 227 | } 228 | } 229 | 230 | type V []float64 231 | 232 | func (v V) Var() *variable.Variable { 233 | return variable.NewOf(v) 234 | } 235 | 236 | type M [][]float64 237 | 238 | func (m M) Var() *variable.Variable { 239 | return variable.NewOf(m...) 240 | } 241 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/zakirullin/gpt-go 2 | 3 | go 1.23.6 4 | 5 | require github.com/itsubaki/autograd v0.0.0-20250418093449-6f1c6692aa69 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/itsubaki/autograd v0.0.0-20250418093449-6f1c6692aa69 h1:7N9RwynBDvSMdrBCuoqKJrVw12cgQUIUSjbaQb1BrnQ= 2 | github.com/itsubaki/autograd v0.0.0-20250418093449-6f1c6692aa69/go.mod h1:KhCUsowk1ZUwjV10xjlMtCYucGJLwtQaBFGS43ogV8Y= 3 | -------------------------------------------------------------------------------- /head.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/itsubaki/autograd/layer" 7 | "github.com/itsubaki/autograd/variable" 8 | 9 | "github.com/zakirullin/gpt-go/pkg" 10 | ) 11 | 12 | var ( 13 | Tril = pkg.Tril 14 | MaskedInfFill = pkg.MaskedInfFill 15 | ) 16 | 17 | type MultiHeadAttention struct { 18 | numHeads int 19 | 
embedSize int 20 | headSize int 21 | Heads []*Head 22 | proj *Linear 23 | } 24 | 25 | func NewMultiHeadAttention(embedSize, numHeads int) *MultiHeadAttention { 26 | heads := make([]*Head, numHeads) 27 | headSize := embedSize / numHeads 28 | for i := range heads { 29 | heads[i] = NewHead(embedSize, headSize) 30 | } 31 | 32 | return &MultiHeadAttention{ 33 | Heads: heads, 34 | numHeads: numHeads, 35 | embedSize: embedSize, 36 | headSize: headSize, 37 | proj: NewLinear(embedSize, embedSize), 38 | } 39 | } 40 | 41 | func (mh *MultiHeadAttention) Forward(input *variable.Variable) *variable.Variable { 42 | var features []*variable.Variable 43 | for _, head := range mh.Heads { 44 | features = append(features, head.Forward(input)) 45 | } 46 | 47 | out := pkg.Cat(features...) 48 | out = mh.proj.Forward(out) // Project back to (embedSize, embedSize) 49 | out = Dropout(dropout)(out) // Dropping out some neurons to prevent overfitting 50 | 51 | return out 52 | } 53 | 54 | func (mh *MultiHeadAttention) Params() []layer.Parameter { 55 | var params []layer.Parameter 56 | for _, head := range mh.Heads { 57 | params = append(params, head.Query.Weight, head.Key.Weight, head.Value.Weight) 58 | } 59 | params = append(params, mh.proj.Weight, mh.proj.Bias) 60 | 61 | return params 62 | } 63 | 64 | type Head struct { 65 | embedSize int 66 | headSize int 67 | Key *Linear 68 | Query *Linear 69 | Value *Linear 70 | } 71 | 72 | // Number of embeds 73 | func NewHead(embedSize, headSize int) *Head { 74 | key := NewLinear(embedSize, headSize, NoBias()) 75 | query := NewLinear(embedSize, headSize, NoBias()) 76 | value := NewLinear(embedSize, headSize, NoBias()) 77 | 78 | return &Head{embedSize, headSize, key, query, value} 79 | } 80 | 81 | // Self-attention mechanism, see main_test.go for explanation. 
82 | func (h *Head) Forward(input *variable.Variable) *variable.Variable { 83 | query := h.Query.Forward(input) 84 | key := h.Key.Forward(input) 85 | attentions := MatMul(query, Transpose(key)) 86 | 87 | T := len(input.Data) // number of tokens 88 | tril := Tril(Ones(T, T)) 89 | attentions = MaskedInfFill(attentions, tril) 90 | attentions = Softmax(attentions) 91 | attentions = Dropout(dropout)(attentions) 92 | 93 | v := h.Value.Forward(input) 94 | weightedSum := MatMul(attentions, v) 95 | normalizedSum := MulC(math.Pow(float64(h.embedSize), -0.5), weightedSum) 96 | 97 | return normalizedSum 98 | } 99 | -------------------------------------------------------------------------------- /layer.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/itsubaki/autograd/layer" 5 | "github.com/itsubaki/autograd/variable" 6 | 7 | "github.com/zakirullin/gpt-go/pkg" 8 | ) 9 | 10 | var ( 11 | Mean = pkg.Mean 12 | Variance = pkg.Variance 13 | Div = pkg.Div 14 | Mul = variable.Mul 15 | Pow = variable.Pow 16 | RandWeights = pkg.Normal 17 | ) 18 | 19 | type Linear struct { 20 | In, Out int 21 | Weight *variable.Variable 22 | Biased bool 23 | Bias *variable.Variable 24 | } 25 | 26 | func NewLinear(in, out int, opts ...LinearOption) *Linear { 27 | l := &Linear{ 28 | In: in, 29 | Out: out, 30 | Weight: RandWeights(in, out), 31 | Biased: true, 32 | Bias: variable.Zero(1, out), 33 | } 34 | 35 | for _, opt := range opts { 36 | opt(l) 37 | } 38 | 39 | return l 40 | } 41 | 42 | // Forward computes the output based on the input (forward pass) 43 | func (l *Linear) Forward(input *variable.Variable) *variable.Variable { 44 | logits := MatMul(input, l.Weight) 45 | 46 | if l.Biased { 47 | logits = Add(logits, l.Bias) 48 | } 49 | 50 | return logits 51 | } 52 | 53 | func (l *Linear) Params() []layer.Parameter { 54 | params := []layer.Parameter{ 55 | l.Weight, 56 | } 57 | 58 | if l.Biased { 59 | params = append(params, 
l.Bias) 60 | } 61 | 62 | return params 63 | } 64 | 65 | // LinearOption customizes a Linear layer created by NewLinear. 66 | type LinearOption func(*Linear) 67 | 68 | // NoBias makes NewLinear create the layer without a bias term. 69 | func NoBias() LinearOption { 70 | return func(l *Linear) { 71 | l.Biased = false 72 | // Drop the bias variable entirely: Forward and Params only touch it when Biased is true. 73 | l.Bias = nil 74 | } 75 | } 76 | 77 | // LayerNorm normalizes its input to zero mean and unit variance, then applies a learned 78 | // Scale (initialized to ones) and Shift (initialized to zeros), each of shape (1, dim). 79 | type LayerNorm struct { 80 | Scale *variable.Variable 81 | Shift *variable.Variable 82 | eps float64 83 | } 84 | 85 | // NewLayerNorm creates a LayerNorm for dim-wide inputs. 86 | func NewLayerNorm(dim int) *LayerNorm { 87 | return &LayerNorm{ 88 | eps: 1e-05, 89 | Scale: Ones(1, dim), 90 | Shift: Zeros(1, dim), 91 | } 92 | } 93 | 94 | // Forward computes Scale * (x - mean) / sqrt(var + eps) + Shift, with the statistics taken by 95 | // pkg.Mean and pkg.Variance (presumably per row, i.e. per token's embed - confirm in pkg). 96 | // It is implemented using existing primitives, so back propagation will work. 97 | func (ln *LayerNorm) Forward(x *variable.Variable) *variable.Variable { 98 | xmean := Mean(x) 99 | xvar := Variance(x) 100 | eps := variable.New(ln.eps) // keeps the denominator away from zero 101 | xhat := Div(Sub(x, xmean), Pow(0.5)(Add(xvar, eps))) 102 | out := Add(Mul(ln.Scale, xhat), ln.Shift) 103 | 104 | return out 105 | } 106 | 107 | // Params returns the learnable Scale and Shift. 108 | func (ln *LayerNorm) Params() []layer.Parameter { 109 | return []layer.Parameter{ 110 | ln.Scale, 111 | ln.Shift, 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "flag" 6 | "fmt" 7 | "os" 8 | "strings" 9 | 10 | "github.com/zakirullin/gpt-go/data" 11 | "github.com/zakirullin/gpt-go/pkg" 12 | ) 13 | 14 | // Hyperparameters 15 | const ( 16 | blockSize = 32 17 | embedSize = 88 18 | heads = 4 19 | layers = 4 20 | learningRate = 0.0001 21 | steps = 80000 // number of training steps, increase for better results 22 | evalSteps = 1000 // print the average loss once every evalSteps steps 23 | dropout = 0.0 // fraction of neurons to drop during training to prevent overfitting (0.0 means no dropout) 24 | pretrainedTokens = 6000 // number of merge rules to load from data/vocab; each adds at most one token on top of the auto-detected characters 25 | maxTokens = 50 // tokens limit for generation 26 | ) 27 | 28 | func main() { 29 | // Skip training if "-chat" flag is provided.
30 | steps := steps // shadows the constant so that -chat can switch training off 31 | chat := flag.Bool("chat", false, "Skip training and jump straight to chat") 32 | flag.Parse() 33 | if *chat { 34 | steps = -1 35 | } 36 | 37 | // Loading dataset and building vocabulary. 38 | fmt.Println("Tokenizing dataset...") 39 | dataset, vocabSize := data.Tokenize(pretrainedTokens) 40 | fmt.Printf("First characters:\n%s\n", strings.TrimSpace(data.Decode(dataset[:min(45, len(dataset))]...))) // min: don't panic on a tiny dataset 41 | fmt.Printf("Vocabulary: %s\n", data.Chars()) 42 | fmt.Printf("Tokens in dataset: %.3fM\n", pkg.Millions(len(dataset))) 43 | 44 | // Basic transformer components. 45 | tokEmbeds := RandEmbeds(vocabSize, embedSize) 46 | posEmbeds := RandEmbeds(blockSize, embedSize) 47 | var blocks []*Block 48 | for range layers { 49 | blocks = append(blocks, NewBlock(embedSize, heads)) 50 | } 51 | norm := NewLayerNorm(embedSize) 52 | lmHead := NewLinear(embedSize, vocabSize) 53 | 54 | // Collecting all the parameters. 55 | params := pkg.NewParams() 56 | params.Add(tokEmbeds, posEmbeds) 57 | for _, block := range blocks { 58 | params.Add(block.Params()...) 59 | } 60 | params.Add(norm.Params()...) 61 | params.Add(lmHead.Params()...) 62 | params.TryLoadPretrained() 63 | fmt.Printf("Model size: %.3fM\n", pkg.Millions(params.Count())) 64 | 65 | // Training loop. 66 | losses := 0.0 67 | optimizer := pkg.NewAdamW(learningRate) 68 | fmt.Printf("bs=%d, es=%d, lr=%.4f, vs=%d, steps=%d\n", blockSize, embedSize, learningRate, vocabSize, steps) 69 | for i := 0; i < steps; i++ { 70 | // Targets contain the ground truth next token for each input token. 71 | input, targets := data.Sample(dataset, blockSize) 72 | 73 | // Forward pass, calculate predictions for every input token. 74 | embeds := Rows(tokEmbeds, Flat(input)...) // get embed for every input token
75 | embeds = Add(embeds, posEmbeds) // add positional embedding 76 | for _, block := range blocks { // self-attention and feed-forward 77 | embeds = block.Forward(embeds) 78 | } 79 | embeds = norm.Forward(embeds) 80 | logits := lmHead.Forward(embeds) // get scores for the next token for every context-enriched embed 81 | 82 | // Loss calculation, "how much our predicted targets differ from the ground truth targets?" 83 | // We average the loss over evalSteps iterations to smooth out fluctuations. 84 | loss := SoftmaxCrossEntropy(logits, targets) 85 | losses += Val(loss) 86 | fmt.Printf("\r%s", strings.Repeat("·", (i%evalSteps)*26/evalSteps)) // progress bar 87 | if i%evalSteps == 0 { 88 | avgLoss := losses / float64(min(i+1, evalSteps)) 89 | fmt.Printf("\rstep: %5d, loss: %.4f\n", i, avgLoss) 90 | losses = 0 91 | } 92 | 93 | // Backward pass, calculate the gradients (how much each parameter contributes to the loss) 94 | // for all the parameters (weights, biases, embeds). Loss is the tail of a computation graph. 95 | loss.Backward() 96 | // Nudge the parameters against their gradients, so as to minimize the loss. 97 | optimizer.Update(params) 98 | params.ZeroGrad() 99 | } 100 | params.Save() 101 | pkg.DisableDropout() 102 | // Training is done. 103 | 104 | // nextTok samples the next token (temperature 0.8) given the context; only the last blockSize tokens are used. 105 | nextTok := func(context []float64) float64 { 106 | context = context[max(0, len(context)-blockSize):] 107 | 108 | // Feed context tokens to the model. 109 | embeds := Rows(tokEmbeds, context...) 110 | embeds = Add(embeds, posEmbeds) 111 | for _, block := range blocks { 112 | embeds = block.Forward(embeds) 113 | } 114 | embeds = norm.Forward(embeds) 115 | logits := lmHead.Forward(embeds) // get a list of final logits for the next token 116 | 117 | // We only care about the probabilities of the next token for the last token.
118 | logitsForNextToken := Rows(logits, -1) 119 | probs := Softmax(logitsForNextToken) 120 | tok := pkg.SampleTemp(probs, 0.8) 121 | 122 | return tok 123 | } 124 | 125 | // Sample from the model. 126 | prompt := " mysterious island" 127 | for { 128 | fmt.Printf("\n%s", prompt) 129 | context := data.Encode(prompt) 130 | for i := 0; i < maxTokens; i++ { 131 | nextToken := nextTok(context) 132 | fmt.Print(data.Decode(nextToken)) 133 | context = append(context, nextToken) 134 | } 135 | 136 | fmt.Print("\n$ ") 137 | scanner := bufio.NewScanner(os.Stdin) 138 | scanner.Scan() 139 | prompt = scanner.Text() 140 | if prompt == "exit" { 141 | fmt.Println("Bye!") 142 | break 143 | } 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | 7 | "github.com/itsubaki/autograd/variable" 8 | 9 | "github.com/zakirullin/gpt-go/data" 10 | ) 11 | 12 | func TestNeuron(t *testing.T) { 13 | // Our neuron has 2 inputs and 1 output (number of columns in weight matrix). 14 | // Its goal is to predict next number in the sequence. 15 | input := V{1, 2} // {x1, x2} 16 | weight := M{ 17 | {2}, // how much x1 contributes to the output 18 | {3}, // how much x2 contributes to the output 19 | } 20 | 21 | // We calculate the output by multiplying the input vector with the weight matrix. 22 | output := MatMul(input.Var(), weight.Var()) 23 | // output[0] = 1*2 + 2*3 = 8 24 | areEqual(t, 8, output) 25 | 26 | // That's a bad prediction (not equal to 3), so we have to adjust the weight. 27 | weight = M{ 28 | {1}, // nudge the first weight down by 1 29 | {1}, // nudge the second weight down by 2 30 | } 31 | 32 | output = MatMul(input.Var(), weight.Var()) 33 | // output = 1*-1 + 2*2 = 3 34 | // Now our neuron's prediction matches the target, so the weight are correct. 
35 | // In reality, though, we don't tune the weight manually. 36 | areEqual(t, 3, output) 37 | } 38 | 39 | func TestLinear(t *testing.T) { 40 | // Linear layer is a collection of neurons. 41 | layer := NewLinear(2, 1) 42 | layer.Weight = M{ 43 | {1}, 44 | {1}, 45 | }.Var() 46 | output := layer.Forward(V{1, 2}.Var()) 47 | areEqual(t, 3, output) 48 | } 49 | 50 | func TestLinearWithTwoInputs(t *testing.T) { 51 | layer := NewLinear(2, 1) 52 | layer.Weight = M{ 53 | {1}, 54 | {1}, 55 | }.Var() 56 | 57 | // For each input row the weighted output is calculated independently. 58 | input := M{ 59 | {1, 2}, 60 | {2, 3}, 61 | }.Var() 62 | output := layer.Forward(input) 63 | areMatricesEqual(t, M{ 64 | {3}, 65 | {5}, 66 | }, output) 67 | } 68 | 69 | func TestLoss(t *testing.T) { 70 | input := V{1, 2}.Var() 71 | weight := M{ 72 | {2}, // w1 73 | {3}, // w2 74 | }.Var() 75 | output := MatMul(input, weight) 76 | target := V{3}.Var() 77 | 78 | // We need a single number to characterize how bad our prediction is. 79 | // For that we need a loss function. 80 | // Let's pick the simplest one: 81 | // loss = prediction - target 82 | loss := Sub(output, target) 83 | // 8 - 3 = 5 84 | areEqual(t, 5, loss) 85 | // So the loss is 5, which is a bad loss (we should strive to 0). 86 | 87 | // Both input (x1, x2) and weights (w1, w2) values contribute to the loss. 88 | // So, to minimize the loss we can tune the input or the weights. 89 | // Since we are training a model, we want to tune the weights. Inputs come from fixed dataset. 90 | // So we need to calculate how much w1 and w2 contribute to the loss. 91 | // "By what amount the loss would change if we change x1 and x2 by some tiny value". 92 | // That's derivative. The speed of change of the loss with respect to the weights. 93 | 94 | // dLoss/dW1 = (x1 * w1 + x2 * w2) = x1 = 1 95 | // It means that to make our loss bigger, we need to move our x1 into positive direction (+), 1. 
96 | // We want the opposite - to make it smaller, so we need to - 1 from our weight. 97 | // dLoss/dW2 = (x1 * w1 + x2 * w2) = x2 = 2 98 | 99 | // We calculate how much each weight contributes to the predicted value. 100 | // So that we know in which direction to nudge the weight to minimize the loss. 101 | loss.Backward() 102 | areMatricesEqual(t, M{{1}, {2}}, weight.Grad) // derivatives also called gradients 103 | } 104 | 105 | func TestGradientDescent(t *testing.T) { 106 | input := V{1, 2}.Var() // {x1, x2} 107 | weight := M{ 108 | {2}, // w1 109 | {3}, // w2 110 | } 111 | 112 | // Now we know in which direction we should nudge the weights to minimize the loss. 113 | // Gradient for w1 is 1, it means that w1 contributes to loss proportionally. 114 | // Gradient for w2 is 2, it means that w2 contributes to loss twice as strongly. 115 | // I.e. if we nudge w2 by some tiny value 0.1, the loss will change by 0.2. 116 | // If we want to minimize the loss, we nudge the weights in the opposite direction. 117 | learningRate := 0.5 118 | weightGrad := V{1, 2} 119 | weight[0][0] -= learningRate * weightGrad[0] // w1 -= w1 * learningRate * w1Grad 120 | weight[1][0] -= learningRate * weightGrad[1] // w2 -= w2 * learningRate * w2Grad 121 | 122 | output := MatMul(input, weight.Var()) 123 | // Previously the neuron predicted 8, now it predicts 5.5. 124 | areEqual(t, 5.5, output) 125 | // Previously the loss was 5, now it is 2.5, so the model has learned something. 126 | loss := Sub(output, V{3}.Var()) 127 | areEqual(t, 2.5, loss) 128 | 129 | // Repeat the process. 130 | weight[0][0] -= learningRate * weightGrad[0] 131 | weight[1][0] -= learningRate * weightGrad[1] 132 | output = MatMul(input, weight.Var()) // MatMul({1, 2}, weights) = 3 133 | 134 | // The neuron predicts 3 now, which is exactly what follows after 1 and 2! 135 | areEqual(t, 3, output) 136 | loss = Sub(output, V{3}.Var()) 137 | // The loss should be 0 now, because the neuron predicts the target value. 
138 | areEqual(t, 0, loss) 139 | 140 | // Our simple model is now trained. 141 | // If the input is {1, 2}, the output is 3. 142 | // Our learning weights are: 143 | areMatricesEqual(t, M{ 144 | {1}, 145 | {1}, 146 | }, weight.Var()) // w1 = 1, w2 = 2 147 | } 148 | 149 | func TestSelfAttention(t *testing.T) { 150 | // Suppose we have the following tokens (~words) in our vocabulary: 151 | // 0 - "cat" 152 | // 1 - ", " 153 | // 2 - "dog" 154 | // 3 - " and" 155 | 156 | // Embeddings are a way to represent words (tokens) as vectors. It encodes the meaning of 157 | // the word in a high-dimensional space. Similar words are close to each other. 158 | embeds := M{ 159 | {4, 1, 6}, // embedding for "cat" 160 | {1, 8, 1}, // embedding for ", " 161 | {4, 1, 7}, // embedding for "dog" 162 | {1, 9, 3}, // embedding for " and" 163 | }.Var() 164 | // As we can see, embeddings for "cat" and "dog" are quite similar. 165 | 166 | // The input to transformer models are a sequence of token embeddings. 167 | // So if we feed "cat, dog and" string to our transformer, we must encode it. 168 | // First we tokenize it, i.e. split the sentence into the tokens: 169 | input := V{0, 1, 2, 3}.Var() 170 | 171 | // Then convert it to the list of embeddings: 172 | //{ 173 | // {4, 1, 6}, // embedding for "cat" 174 | // {1, 8, 3}, // embedding for ", " 175 | // {4, 1, 7}, // embedding for "dog" 176 | // {1, 9, 3}, // embedding for " and" 177 | //} 178 | inputEmbeds := Rows(embeds, Flat(input)...) 179 | areMatricesEqual(t, M{ 180 | {4, 1, 6}, // embedding for "cat" 181 | {1, 8, 1}, // embedding for ", " 182 | {4, 1, 7}, // embedding for "dog" 183 | {1, 9, 3}, // embedding for " and" 184 | }, inputEmbeds) 185 | 186 | // How do we predict next token from a given sequence of tokens? 187 | // We can naively predict the next token by looking only at the last token (bigram model does that). 
188 | // However, by looking at the token "and" alone, we lose the context of the previous tokens (what does "and" refer to?). 189 | 190 | // So, we have to somehow combine the information from the current token and all the previous tokens. 191 | // "cat" -> "cat", no previous tokens to look at 192 | // ", " -> "cat" + ", " 193 | // "blue" -> "cat" + ", " + "dog" 194 | // " and" -> "cat" + ", " + "dog" + " and" 195 | // Since we're operating with numerical representations of the words (embeddings), we can just add them together. 196 | // I.e. for token " and" we'll do that: 197 | // {4, 1, 6} + {1, 8, 1} + {4, 1, 7} + {1, 9, 3} = {10, 19, 17} 198 | // Now our resulting vector " and" combines more information from the previous tokens. Now we can predict the next 199 | // token more accurately, because we have more context. 200 | 201 | // To calculate the sum of all previous tokens, we can multiply by this triangular matrix: 202 | tril := M{ 203 | {1, 0, 0, 0}, // first token attends only at itself ("cat"), it can't look into the future 204 | {1, 1, 0, 0}, // second token attends at itself and the previous token ( "cat" + ", ") 205 | {1, 1, 1, 0}, // third token attends at itself and the two previous tokens ("cat" + ", " + "dog") 206 | {1, 1, 1, 1}, // fourth token attends at itself and all the previous tokens ("cat" + ", " + "dog" + " and") 207 | }.Var() 208 | 209 | // So, at this point each embedding is enriched with the information from all the previous tokens. 210 | // That's the crux of self-attention. 211 | enrichedEmbeds := MatMul(tril, inputEmbeds) 212 | areMatricesEqual(t, M{ 213 | {4, 1, 6}, 214 | {5, 9, 7}, 215 | {9, 10, 14}, 216 | {10, 19, 17}, 217 | }, enrichedEmbeds) 218 | 219 | } 220 | 221 | func TestWeightedSelfAttention(t *testing.T) { 222 | // In reality, though, we don't pay equal attention to all the previous tokens. 223 | // Some previous tokens are more interested to us, some are less. 
224 | 225 | // Let's look at token "and" and its {1, 2, 3} embedding. 226 | // We treat those "1", "2" and "3" components in same way (some features we don't know about). 227 | // But let's split them into 3 categories: 228 | // Query - "what I am looking for" 229 | // Key - "what I can communicate" 230 | // Value - "what I give you" 231 | 232 | // Since we don't know how to split them, we can use a linear layer to learn that for us. 233 | // We introduce 3 linear layers: query, key and value. 234 | query := NewLinear(3, 3) // converts each embedding into a query vector "what I am looking for" 235 | query.Weight = Zeros(3, 3) // manually set values for a good example 236 | query.Weight.Data[1][0] = 10 237 | 238 | key := NewLinear(3, 3) // converts each embedding into a key vector "what I can communicate" 239 | key.Weight = Zeros(3, 3) 240 | key.Weight.Data[0][0] = 1 // first neuron is paying attention to the first component of the embedding 241 | 242 | value := NewLinear(3, 3) // converts each embedding into a value vector "what I give you" 243 | value.Weight = Ones(3, 3) 244 | 245 | embeds := M{ 246 | {4, 1, 6}, // embedding for "cat" 247 | {1, 8, 1}, // embedding for ", " 248 | {4, 1, 7}, // embedding for "dog" 249 | {1, 9, 3}, // embedding for " and" 250 | }.Var() 251 | 252 | // Let's now extract the key and query vectors for each embed. 253 | 254 | k := key.Forward(embeds) 255 | // For our case, let's imagine that first component of our key is responsible for 256 | // I am "enumerable token". The bigger the value, "the more enumerable" the token is. 
257 | areMatricesEqual(t, M{ 258 | {4, 0, 0}, // "cat" is quite enumerable 259 | {1, 0, 0}, // ", " is not quite enumerable 260 | {4, 0, 0}, // "dog" is quite enumerable 261 | {1, 0, 0}, // " and" is not quite enumerable 262 | }, k) 263 | 264 | q := query.Forward(embeds) 265 | areMatricesEqual(t, M{ 266 | {10, 0, 0}, // token "cat" is not looking for something enumerable 267 | {80, 0, 0}, // token ", " is looking for something enumerable a lot 268 | {10, 0, 0}, // token "dog" is not looking for something enumerable 269 | {90, 0, 0}, // token " and" is looking for something enumerable a lot 270 | }, q) 271 | 272 | // If we multiply q * k vectors, for each token we would answer to 273 | // a question "what tokens are interesting for me?". 274 | // Big values would indicate high interest. 275 | attentionScores := MatMul(q, Transpose(k)) 276 | areMatricesEqual(t, M{ 277 | {40, 10, 40, 10}, 278 | {320, 80, 320, 80}, 279 | {40, 10, 40, 10}, 280 | {360, 90, 360, 90}, // token " and" is interested in tokens "cat" (score=360) and "dog" (score=360) 281 | }, attentionScores) 282 | 283 | tril := M{ 284 | {1, 0, 0, 0}, // first token attends only at itself ("cat"), it can't look into the future 285 | {1, 1, 0, 0}, // second token attends at itself and the previous token ( "cat" + ", ") 286 | {1, 1, 1, 0}, // third token attends at itself and the two previous tokens ("cat" + ", " + "dog") 287 | {1, 1, 1, 1}, // fourth token attends at itself and all the previous tokens ("cat" + ", " + "dog" + " and") 288 | }.Var() 289 | // Previously we attended to all the previous token with the help of this tril matrix. 
290 | // Now we only attend to those tokens in which we are interested, which is basically: 291 | attentionScores = MaskedInfFill(attentionScores, tril) 292 | no := math.Inf(-1) 293 | areMatricesEqual(t, M{ 294 | {80, no, no, no}, 295 | {640, 160, no, no}, 296 | {80, 20, 80, no}, 297 | {720, 180, 720, 180}, // token " and" is interested in "cat" and "dog", not so much in the others 298 | }, attentionScores) 299 | 300 | attentionScores = Softmax(attentionScores) // fancy trick to turn {1, 1, no, no} to {0.5, 0.5, 0, 0} 301 | areMatricesEqual(t, M{ 302 | {1, 0, 0, 0}, 303 | {1, 0, 0, 0}, 304 | {0.5, 0, 0.5, 0}, 305 | {0.5, 0, 0.5, 0}, // token " and" is interested in "cat" and "dog", not interested into others at all 306 | }, attentionScores) 307 | } 308 | 309 | func TestTransformer(t *testing.T) { 310 | RandEmbeds = func(rows, cols int) *variable.Variable { 311 | return Zeros(rows, cols) 312 | } 313 | RandWeights = func(rows, cols int) *variable.Variable { 314 | return Ones(rows, cols) 315 | } 316 | data.RandInt = func(_ int) int { 317 | return 0 318 | } 319 | 320 | vocabSize := 10 321 | embedSize := 2 322 | blockSize := 2 323 | 324 | // Basic transformer components 325 | tokEmbeds := RandEmbeds(vocabSize, embedSize) 326 | posEmbeds := RandEmbeds(blockSize, embedSize) 327 | block := NewBlock(embedSize, 1) 328 | norm := NewLayerNorm(embedSize) 329 | lmHead := NewLinear(embedSize, vocabSize) 330 | 331 | // Input contains blockSize consecutive tokens. 332 | // Targets contain the expected next token for each input token. 333 | // Example: for input={0,1}, targets={1,2}, meaning 334 | // that next token (target) after 0 is 1, next after 1 is 2. 335 | input, targets := data.Sample([]float64{0, 1, 2}, blockSize) 336 | 337 | // { 338 | // {vector for tok0}, 339 | // {vector for tok1}, 340 | // ... other embeds 341 | // } 342 | embeds := Rows(tokEmbeds, Flat(input)...) 
// get embed for every input token 343 | embeds = Add(embeds, posEmbeds) // add positional embedding 344 | embeds = block.Forward(embeds) 345 | embeds = norm.Forward(embeds) 346 | // { 347 | // {score for tok0, ..., score for tokN}, // for input tok0 348 | // {score for tok0, ..., score for tokN}, // for input tok1 349 | // ... other logits 350 | // } 351 | logits := lmHead.Forward(embeds) // converts contextual embeddings to next-token predictions 352 | 353 | // Loss calculation, how much our predicted targets differ from the actual targets? 354 | loss := SoftmaxCrossEntropy(logits, targets) 355 | 356 | areEqual(t, 2.302585092994046, loss) 357 | } 358 | 359 | func areEqual(t *testing.T, want float64, got *variable.Variable) { 360 | t.Helper() 361 | if len(got.Data) != 1 { 362 | t.Errorf("expected a single value, got %d values", len(got.Data)) 363 | return 364 | } 365 | 366 | if math.Abs(want-Val(got)) > 1e-9 { 367 | t.Errorf("value mismatch: want %v, got %v", want, Val(got)) 368 | } 369 | } 370 | 371 | func areMatricesEqual(t *testing.T, want M, got *variable.Variable) { 372 | t.Helper() 373 | if len(want) != len(got.Data) { 374 | t.Errorf("matrix length mismatch: want length=%d, got length=%d", len(want), len(got.Data)) 375 | return 376 | } 377 | 378 | for i := range want { 379 | if len(want[i]) != len(got.Data[i]) { 380 | t.Errorf("matrix row length mismatch at row %d: want length=%d, got length=%d", i, len(want[i]), len(got.Data[i])) 381 | return 382 | } 383 | } 384 | 385 | for i := range want { 386 | for j := range want[i] { 387 | if math.Abs(want[i][j]-got.Data[i][j]) > 1e-9 { 388 | t.Errorf("matrix mismatch at row %d, column %d: want %v, got %v", i, j, want[i][j], got.Data[i][j]) 389 | } 390 | } 391 | } 392 | } 393 | 394 | type V []float64 395 | 396 | func (v V) Var() *variable.Variable { 397 | return variable.NewOf(v) 398 | } 399 | 400 | type M [][]float64 401 | 402 | func (m M) Var() *variable.Variable { 403 | return variable.NewOf(m...) 
404 | } 405 | -------------------------------------------------------------------------------- /pkg/adamw.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/itsubaki/autograd/matrix" 7 | "github.com/itsubaki/autograd/optimizer" 8 | "github.com/itsubaki/autograd/variable" 9 | ) 10 | 11 | type AdamW struct { 12 | Alpha float64 // Learning rate 13 | Beta1 float64 // Exponential decay rate for first moment 14 | Beta2 float64 // Exponential decay rate for second moment 15 | WeightDecay float64 // Weight decay coefficient 16 | Hook []optimizer.Hook 17 | iter int 18 | ms, vs map[*variable.Variable]matrix.Matrix 19 | } 20 | 21 | func NewAdamW(learningRate float64) AdamW { 22 | return AdamW{Alpha: learningRate, Beta1: 0.9, Beta2: 0.999, WeightDecay: 0.01} 23 | } 24 | 25 | func (o *AdamW) Update(model optimizer.Model) { 26 | params := optimizer.Params(model, o.Hook) 27 | 28 | if len(o.ms) == 0 { 29 | o.ms = make(map[*variable.Variable]matrix.Matrix) 30 | o.vs = make(map[*variable.Variable]matrix.Matrix) 31 | } 32 | 33 | o.iter++ 34 | fix1 := 1.0 - math.Pow(o.Beta1, float64(o.iter)) 35 | fix2 := 1.0 - math.Pow(o.Beta2, float64(o.iter)) 36 | lr := o.Alpha * math.Sqrt(fix2) / fix1 37 | 38 | for _, p := range params { 39 | if _, ok := o.ms[p]; !ok { 40 | o.ms[p] = matrix.ZeroLike(p.Data) 41 | o.vs[p] = matrix.ZeroLike(p.Data) 42 | } 43 | 44 | // Update biased first moment estimate 45 | o.ms[p] = matrix.F2(o.ms[p], p.Grad.Data, func(m, grad float64) float64 { 46 | return m + ((1 - o.Beta1) * (grad - m)) 47 | }) 48 | 49 | // Update biased second raw moment estimate 50 | o.vs[p] = matrix.F2(o.vs[p], p.Grad.Data, func(v, grad float64) float64 { 51 | return v + ((1 - o.Beta2) * (grad*grad - v)) 52 | }) 53 | 54 | // The key difference for AdamW: apply weight decay directly to the weights 55 | // instead of incorporating it into the gradient 56 | 57 | // First compute the standard Adam 
update 58 | adamUpdate := matrix.F2(o.ms[p], o.vs[p], func(m, v float64) float64 { 59 | return lr * m / (math.Sqrt(v) + 1e-8) 60 | }) 61 | 62 | // Then apply weight decay separately 63 | weightDecayUpdate := matrix.MulC(lr*o.WeightDecay, p.Data) 64 | 65 | // Update parameters: param = param - adamUpdate - weightDecayUpdate 66 | p.Data = matrix.Sub(p.Data, adamUpdate) 67 | p.Data = matrix.Sub(p.Data, weightDecayUpdate) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /pkg/cat.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "github.com/itsubaki/autograd/variable" 5 | ) 6 | 7 | // Cat concatenates matrices horizontally 8 | func Cat(x ...*variable.Variable) *variable.Variable { 9 | return (&variable.Function{Forwarder: &CatT{NumInputs: len(x)}}).First(x...) 10 | } 11 | 12 | type CatT struct { 13 | NumInputs int 14 | ColSize int 15 | } 16 | 17 | // Concatenate along the columns dimension (dim=1) 18 | func (f *CatT) Forward(x ...*variable.Variable) []*variable.Variable { 19 | rows := len(x[0].Data) 20 | f.ColSize = len(x[0].Data[0]) 21 | totalCols := f.ColSize * len(x) 22 | 23 | result := make([][]float64, rows) 24 | for i := range result { 25 | result[i] = make([]float64, totalCols) 26 | colOffset := 0 27 | for _, v := range x { 28 | copy(result[i][colOffset:], v.Data[i]) 29 | colOffset += f.ColSize 30 | } 31 | } 32 | 33 | return []*variable.Variable{ 34 | variable.NewOf(result...), 35 | } 36 | } 37 | 38 | func (f *CatT) Backward(gy ...*variable.Variable) []*variable.Variable { 39 | grads := make([]*variable.Variable, f.NumInputs) 40 | 41 | // Split along columns 42 | for i := 0; i < f.NumInputs; i++ { 43 | colOffset := i * f.ColSize 44 | colData := make([][]float64, len(gy[0].Data)) 45 | 46 | for j := range colData { 47 | colData[j] = make([]float64, f.ColSize) 48 | copy(colData[j], gy[0].Data[j][colOffset:colOffset+f.ColSize]) 49 | } 50 | 51 | 
grads[i] = variable.NewOf(colData...) 52 | } 53 | 54 | return grads 55 | } 56 | -------------------------------------------------------------------------------- /pkg/cat_test.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | func ExampleCat_basic() { 8 | a := M{ 9 | {1, 2}, 10 | {3, 4}, 11 | }.Var() 12 | 13 | b := M{ 14 | {5, 6}, 15 | {7, 8}, 16 | }.Var() 17 | 18 | result := Cat(a, b) 19 | fmt.Println(result.Data) 20 | 21 | // Output: [[1 2 5 6] [3 4 7 8]] 22 | } 23 | 24 | func ExampleCat_singleInput() { 25 | a := M{ 26 | {1, 2}, 27 | {3, 4}, 28 | }.Var() 29 | 30 | result := Cat(a) 31 | fmt.Println(result.Data) 32 | 33 | // Output: [[1 2] [3 4]] 34 | } 35 | 36 | func ExampleCat_threeMatrices() { 37 | a := M{ 38 | {1, 2}, 39 | {3, 4}, 40 | }.Var() 41 | 42 | b := M{ 43 | {5, 6}, 44 | {7, 8}, 45 | }.Var() 46 | 47 | c := M{ 48 | {9, 10}, 49 | {11, 12}, 50 | }.Var() 51 | 52 | result := Cat(a, b, c) 53 | fmt.Println(result.Data) 54 | 55 | // Output: [[1 2 5 6 9 10] [3 4 7 8 11 12]] 56 | } 57 | 58 | func ExampleCat_gradient() { 59 | a := M{ 60 | {1, 2}, 61 | {3, 4}, 62 | }.Var() 63 | 64 | b := M{ 65 | {5, 6}, 66 | {7, 8}, 67 | }.Var() 68 | 69 | result := Cat(a, b) 70 | result.Grad = M{ 71 | {0.1, 0.2, 0.3, 0.4}, 72 | {0.5, 0.6, 0.7, 0.8}, 73 | }.Var() 74 | 75 | result.Backward() 76 | 77 | fmt.Println(a.Grad.Data) 78 | fmt.Println(b.Grad.Data) 79 | 80 | // Output: 81 | // [[0.1 0.2] [0.5 0.6]] 82 | // [[0.3 0.4] [0.7 0.8]] 83 | } 84 | 85 | func ExampleCat_withThreeDifferentMatrices() { 86 | a := M{ 87 | {1, 2}, 88 | {3, 4}, 89 | }.Var() 90 | 91 | b := M{ 92 | {5, 6}, 93 | {7, 8}, 94 | }.Var() 95 | 96 | c := M{ 97 | {9, 10}, 98 | {11, 12}, 99 | }.Var() 100 | 101 | result := Cat(a, b, c) 102 | result.Grad = M{ 103 | {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, 104 | {0.7, 0.8, 0.9, 1.0, 1.1, 1.2}, 105 | }.Var() 106 | 107 | result.Backward() 108 | 109 | fmt.Println(a.Grad.Data) 110 | 
fmt.Println(b.Grad.Data) 111 | fmt.Println(c.Grad.Data) 112 | 113 | // Output: 114 | // [[0.1 0.2] [0.7 0.8]] 115 | // [[0.3 0.4] [0.9 1]] 116 | // [[0.5 0.6] [1.1 1.2]] 117 | } 118 | 119 | func ExampleCat_matrixMultiplicationWith() { 120 | a := M{ 121 | {1, 2}, 122 | {3, 4}, 123 | }.Var() 124 | 125 | b := M{ 126 | {5, 6}, 127 | {7, 8}, 128 | }.Var() 129 | 130 | c := Cat(a, b) 131 | 132 | d := M{ 133 | {0.1}, 134 | {0.2}, 135 | {0.3}, 136 | {0.4}, 137 | }.Var() 138 | 139 | result := MatMul(c, d) 140 | fmt.Println(result.Data) 141 | 142 | result.Backward() 143 | 144 | fmt.Println(a.Grad.Data) 145 | fmt.Println(b.Grad.Data) 146 | 147 | // Output: 148 | // [[4.4] [6.4]] 149 | // [[0.1 0.2] [0.1 0.2]] 150 | // [[0.3 0.4] [0.3 0.4]] 151 | } 152 | -------------------------------------------------------------------------------- /pkg/functions.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "math" 5 | "math/rand/v2" 6 | 7 | "github.com/itsubaki/autograd/matrix" 8 | "github.com/itsubaki/autograd/variable" 9 | ) 10 | 11 | var ( 12 | Add = variable.Add 13 | Div = variable.Div 14 | Zeros = variable.Zero 15 | ) 16 | 17 | // Sample returns a random index based on the given probabilities. 18 | func Sample(probs *variable.Variable) float64 { 19 | r := rand.Float64() 20 | 21 | // Find the first index where cumulative probability exceeds r. 22 | cumulativeProb := 0.0 23 | for i, p := range probs.Data[0] { 24 | cumulativeProb += p 25 | if r < cumulativeProb { 26 | return float64(i) 27 | } 28 | } 29 | 30 | return float64(len(probs.Data)) - 1 31 | } 32 | 33 | // SampleTemp returns a random index based on the given probabilities and temperature. 34 | // The higher the temperature, the more random the sampling. 35 | // Usually, temperature is between 0.5 and 0.8. 
36 | // A temperature <= 0 disables randomness and returns the index of the largest probability (greedy decoding). 37 | func SampleTemp(probs *variable.Variable, temperature float64) float64 { 38 | row := probs.Data[0] 39 | if temperature <= 0 { 40 | // 1/temperature would be +Inf or negative, turning probabilities into 0, 1 or +Inf 41 | // and the normalization below into NaNs, so pick the most likely index instead. 42 | best := 0 43 | for i, p := range row { 44 | if p > row[best] { 45 | best = i 46 | } 47 | } 48 | 49 | return float64(best) 50 | } 51 | 52 | adjustedProbs := make([]float64, len(row)) 53 | copy(adjustedProbs, row) 54 | if temperature != 1.0 { 55 | // Lower temperature: higher probs amplified, lower reduced, more deterministic. 56 | // Higher temperature: probabilities become more uniform, more random. 57 | sum := 0.0 58 | for i, p := range adjustedProbs { 59 | // Apply temperature by raising to power of 1/temperature. 60 | adjustedProbs[i] = math.Pow(p, 1.0/temperature) 61 | sum += adjustedProbs[i] 62 | } 63 | 64 | for i := range adjustedProbs { 65 | adjustedProbs[i] /= sum 66 | } 67 | } 68 | 69 | return Sample(variable.NewOf(adjustedProbs)) 70 | } 71 | 72 | // Rows returns the rows of x at the given indexes; negative indexes count from the end, 73 | // so Rows(x, -1) is the last row. Indexes are float64 so token sequences kept as []float64 74 | // can be passed directly; any fractional part is truncated. The selection is recorded as a 75 | // variable.Function, so it participates in backpropagation. 76 | func Rows(x *variable.Variable, indexes ...float64) *variable.Variable { 77 | size := len(x.Data) 78 | 79 | intIndexes := make([]int, 0, len(indexes)) 80 | for _, index := range indexes { 81 | intIndex := int(index) 82 | if intIndex < 0 { 83 | intIndex += size 84 | } 85 | 86 | intIndexes = append(intIndexes, intIndex) 87 | } 88 | 89 | return (&variable.Function{Forwarder: &variable.GetItemT{Slices: intIndexes}}).First(x) 90 | } 91 | 92 | // Normal returns a rows x cols matrix whose entries are drawn from a normal distribution 93 | // with mean 0 and standard deviation 0.02. 94 | func Normal(rows, cols int) *variable.Variable { 95 | rnd := func(_ float64) float64 { 96 | // Standard deviation = 0.02 is widely used in transformer models like GPT-2. 97 | // It prevents too large values in the beginning of training. 98 | std := 0.02 99 | return rand.NormFloat64() * std 100 | } 101 | 102 | m := matrix.Zero(rows, cols) 103 | m = matrix.F(m, rnd) 104 | 105 | return variable.NewOf(m...)
106 | } 107 | 108 | // Tril returns a copy of m that keeps the main diagonal and everything below it; elements 109 | // above the diagonal are zero. The result is a fresh variable (built with variable.ZeroLike), 110 | // not connected to m in the computation graph. 111 | func Tril(m *variable.Variable) *variable.Variable { 112 | result := variable.ZeroLike(m) 113 | for i, row := range m.Data { 114 | for j := 0; j <= i && j < len(row); j++ { 115 | result.Data[i][j] = row[j] 116 | } 117 | } 118 | 119 | return result 120 | } 121 | 122 | // MaskedInfFill returns m with every element whose mask value is 0 replaced by -Inf, so that a 123 | // following Softmax assigns those positions zero probability. Elements where mask is 1 keep 124 | // their value from m. 125 | // 126 | // The result is computed as m*mask + fill, where fill is a constant holding -Inf at masked 127 | // positions and 0 elsewhere. This keeps the result in m's computation graph, and the gradient 128 | // w.r.t. m (mask) matches the forward values. 129 | // 130 | // NOTE(review): fill used to hold m's own value at unmasked positions, which returned 2*m there 131 | // while backpropagating a gradient of only 1. Expectations or saved weights that relied on the 132 | // doubled scores need to be updated. 133 | func MaskedInfFill(m, mask *variable.Variable) *variable.Variable { 134 | fill := matrix.F2(m.Data, mask.Data, func(_, b float64) float64 { 135 | if b == 0 { 136 | return math.Inf(-1) 137 | } 138 | 139 | return 0 // m already contributes through m*mask; adding it again would double it 140 | }) 141 | 142 | return Add(variable.Mul(m, mask), variable.NewOf(fill...)) 143 | } 144 | 145 | // DivC returns x divided element-wise by the constant c, computed as x * (1/c). 146 | func DivC(c float64, x *variable.Variable) *variable.Variable { 147 | return variable.MulC(1.0/c, x) 148 | } 149 | 150 | // Ones returns an m x n matrix filled with ones. 151 | func Ones(m, n int) *variable.Variable { 152 | out := make([][]float64, m) 153 | for i := range m { 154 | out[i] = make([]float64, n) 155 | for j := range n { 156 | out[i][j] = 1.0 157 | } 158 | } 159 | 160 | return variable.NewOf(out...) 161 | } 162 | 163 | // Val returns the element at row 0, column 0 of x, e.g. the value of a scalar loss. 164 | func Val(x *variable.Variable) float64 { 165 | return x.Data[0][0] 166 | } 167 | 168 | // Flat returns the elements of x as a single slice (see matrix.Flatten). 169 | func Flat(x *variable.Variable) []float64 { 170 | return matrix.Flatten(x.Data) 171 | } 172 | 173 | // Millions converts a count into millions, e.g. for printing model and dataset sizes. 174 | func Millions(num int) float64 { 175 | return float64(num) / 1e6 176 | } 177 | 178 | // DisableDropout switches off training mode (variable.Config.Train), which disables dropout; 179 | // call it once training is finished, before generating text. 180 | func DisableDropout() { 181 | variable.Config.Train = false // disables dropout 182 | } 183 | -------------------------------------------------------------------------------- /pkg/matmul.go: -------------------------------------------------------------------------------- 1 | // MatMul performs parallelized matrix multiplication that minimizes 2 | // CPU cache misses. It divides the computation into sequential chunks 3 | // processed by multiple goroutines in parallel.
4 | package pkg
5 |
6 | import (
7 | 	"fmt"
8 | 	"runtime"
9 | 	"sync"
10 |
11 | 	"github.com/itsubaki/autograd/matrix"
12 | 	"github.com/itsubaki/autograd/variable"
13 | )
14 |
15 | const (
16 | 	blockSize = 32 // Group calculations by blocks for better CPU cache utilization
17 | )
18 |
19 | // MatMul returns the matrix product x[0] · x[1] as a node of the autograd
20 | // computation graph, so gradients flow back to both operands.
21 | func MatMul(x ...*variable.Variable) *variable.Variable {
22 | 	return (&variable.Function{Forwarder: &MatMulT{}}).First(x...)
23 | }
24 |
25 | // MatMulT implements the forward and backward passes of MatMul. Forward
26 | // keeps references to both operands because Backward needs them.
27 | type MatMulT struct {
28 | 	x, w *variable.Variable
29 | }
30 |
31 | // Forward computes y = x[0] · x[1].
32 | func (f *MatMulT) Forward(x ...*variable.Variable) []*variable.Variable {
33 | 	f.x, f.w = x[0], x[1]
34 |
35 | 	y := matmul(x[0].Data, x[1].Data)
36 | 	return []*variable.Variable{y}
37 | }
38 |
39 | // Backward maps the upstream gradient gy to the gradients of both operands:
40 | // gy · wᵀ for x and xᵀ · gy for w.
41 | func (f *MatMulT) Backward(gy ...*variable.Variable) []*variable.Variable {
42 | 	return []*variable.Variable{
43 | 		MatMul(gy[0], variable.Transpose(f.w)), // gy * w.T
44 | 		MatMul(variable.Transpose(f.x), gy[0]), // x.T * gy
45 | 	}
46 | }
47 |
48 | // matmul multiplies m (mRows×mCols) by n (mCols×nCols) into a new
49 | // mRows×nCols variable. The result's rows are split into chunks, one
50 | // goroutine per chunk; each goroutine writes only its own rows, so no
51 | // locking is needed. Inside a chunk the loops are tiled by blockSize for
52 | // better cache reuse.
53 | //
54 | // It panics if the inner dimensions differ: otherwise a short n would fail
55 | // with an index error inside a goroutine, and a long n would have its extra
56 | // rows silently ignored, producing a wrong product.
57 | func matmul(m, n matrix.Matrix) *variable.Variable {
58 | 	mRows, mCols := len(m), len(m[0])
59 | 	nRows, nCols := len(n), len(n[0])
60 | 	if mCols != nRows {
61 | 		panic(fmt.Sprintf("matmul: shape mismatch: %d×%d · %d×%d", mRows, mCols, nRows, nCols))
62 | 	}
63 |
64 | 	result := Zeros(mRows, nCols)
65 | 	var wg sync.WaitGroup
66 | 	numCPU := runtime.NumCPU()
67 | 	// Create more chunks than CPUs for better load balancing
68 | 	// Adjust the multiplier to find the optimal balance
69 | 	chunkSize := max(1, mRows/(numCPU*4))
70 | 	for startRow := 0; startRow < mRows; startRow += chunkSize {
71 | 		wg.Add(1)
72 |
73 | 		go func(firstRow, lastRow int) {
74 | 			defer wg.Done()
75 |
76 | 			// Process this chunk of rows with blocking for better cache utilization
77 | 			for ii := firstRow; ii < lastRow; ii += blockSize {
78 | 				for kk := 0; kk < mCols; kk += blockSize {
79 | 					for jj := 0; jj < nCols; jj += blockSize {
80 | 						// Calculate bounds for current block
81 | 						iEnd := min(ii+blockSize, lastRow)
82 | 						kEnd := min(kk+blockSize, mCols)
83 | 						jEnd := min(jj+blockSize, nCols)
84 |
85 | 						// Process the current block with cache-friendly access
86 | 						for i := ii; i < iEnd; i++ {
87 | 							for k := kk; k < kEnd; k++ {
88 | 								aik := m[i][k]
89 | 								for j := jj; j < jEnd; j++ {
90 | 									result.Data[i][j] += aik * n[k][j]
91 | 								}
92 | 							}
93 | 						}
94 | 					}
95 | 				}
96 | 			}
97 | 		}(startRow, min(startRow+chunkSize, mRows))
98 | 	}
99 | 	wg.Wait()
100 |
101 | 	return result
102 | }
103 |
--------------------------------------------------------------------------------
/pkg/matmul_test.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/itsubaki/autograd/variable"
7 | )
8 |
9 | func ExampleMatMul_basic2x2() {
10 | 	a := M{
11 | 		{1, 2},
12 | 		{3, 4},
13 | 	}.Var()
14 |
15 | 	b := M{
16 | 		{5, 6},
17 | 		{7, 8},
18 | 	}.Var()
19 |
20 | 	result := MatMul(a, b)
21 | 	fmt.Println(result)
22 |
23 | 	// Output:
24 | 	// variable([[19 22] [43 50]])
25 | }
26 |
27 | func ExampleMatMul_nonSquare() {
28 | 	a := M{
29 | 		{1, 2, 3},
30 | 		{4, 5, 6},
31 | 	}.Var()
32 |
33 | 	b := M{
34 | 		{7, 8},
35 | 		{9, 10},
36 | 		{11, 12},
37 | 	}.Var()
38 |
39 | 	result := MatMul(a, b)
40 | 	fmt.Println(result)
41 |
42 | 	// Output:
43 | 	// variable([[58 64] [139 154]])
44 | }
45 |
46 | func ExampleMatMul_columnVector() {
47 | 	a := M{
48 | 		{1, 2},
49 | 		{3, 4},
50 | 		{5, 6},
51 | 	}.Var()
52 |
53 | 	b := M{
54 | 		{7},
55 | 		{8},
56 | 	}.Var()
57 |
58 | 	result := MatMul(a, b)
59 | 	fmt.Println(result)
60 |
61 | 	// Output:
62 | 	// variable([[23] [53] [83]])
63 | }
64 |
65 | func ExampleMatMul_rowVector() {
66 | 	a := V{1, 2}.Var()
67 |
68 | 	b := M{
69 | 		{3, 4},
70 | 		{5, 6},
71 | 	}.Var()
72 |
73 | 	result := MatMul(a, b)
74 | 	fmt.Println(result)
75 |
76 | 	// Output:
77 | 	// variable([13 16])
78 | }
79 |
80 | func ExampleMatMul_chain() {
81 | 	a := V{1, 2}.Var()
82 |
83 | 	b := M{
84 | 		{3, 4},
85 | 		{5, 6},
86 | 	}.Var()
87 |
88 | 	c := M{
89 | 		{7},
90 | 		{8},
91 | 	}.Var()
92 |
93 | 	result := MatMul(MatMul(a, b), c)
94 | 	fmt.Println(result)
95 |
96 | 	// Output:
97 | 	// variable([219])
98 | }
99 |
100 | func ExampleMatMul_zeroMatrix() {
101 | 	a := M{
102 | 		{0, 0},
103 | 		{0, 0},
104 | 	}.Var()
105 |
106 | 	b := M{
107 | 		{1, 2},
108 | 		{3, 4},
109 | 	}.Var()
110 |
111 | 	result := MatMul(a, b)
112 | 	fmt.Println(result)
113 |
114 | 	// Output:
115 | 	//variable([[0 0] [0 0]])
116 | }
117 |
118 | func ExampleMatMul_gradient() {
119 | 	a := M{
120 | 		{1, 2},
121 | 		{3, 4},
122 | 	}.Var()
123 |
124 | 	b := M{
125 | 		{5, 6},
126 | 		{7, 8},
127 | 	}.Var()
128 |
129 | 	result := MatMul(a, b)
130 |
131 | 	result.Grad = M{
132 | 		{1, 1},
133 | 		{1, 1},
134 | 	}.Var()
135 |
136 | 	result.Backward()
137 |
138 | 	fmt.Println(a.Grad)
139 | 	fmt.Println(b.Grad)
140 |
141 | 	// Output:
142 | 	//variable([[11 15] [11 15]])
143 | 	//variable([[4 4] [6 6]])
144 | }
145 |
146 | // Shortcut for building readable matrices:
147 | //
148 | //	M{
149 | //		{1, 2},
150 | //		{3, 4},
151 | //	}.Var()
152 | type M [][]float64
153 |
154 | func (m M) Var() *variable.Variable {
155 | 	return variable.NewOf(m...)
156 | }
157 |
158 | type V []float64
159 |
160 | func (v V) Var() *variable.Variable {
161 | 	return variable.NewOf(v)
162 | }
163 |
--------------------------------------------------------------------------------
/pkg/mean.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | 	"github.com/itsubaki/autograd/variable"
5 | )
6 |
7 | // Mean returns the mean of each row of x[0] as a column vector (one element
8 | // per row). Only x[0] is used; the operation is recorded in the autograd graph.
9 | func Mean(x ...*variable.Variable) *variable.Variable {
10 | 	return (&variable.Function{Forwarder: &MeanT{}}).First(x...)
11 | }
12 |
13 | // MeanT implements the forward and backward passes of Mean. It keeps the
14 | // input x, whose shape the gradient must match, and the row length n.
15 | type MeanT struct {
16 | 	x *variable.Variable
17 | 	n int
18 | }
19 |
20 | // Forward computes the mean of each row of x[0].
21 | func (m *MeanT) Forward(x ...*variable.Variable) []*variable.Variable {
22 | 	m.x = x[0]
23 | 	m.n = len(x[0].Data[0])
24 |
25 | 	means := variable.Zero(len(x[0].Data), 1)
26 | 	for i := range x[0].Data {
27 | 		means.Data[i][0] = mean(x[0].Data[i])
28 | 	}
29 |
30 | 	return []*variable.Variable{
31 | 		means,
32 | 	}
33 | }
34 |
35 | // Backward spreads each row's upstream gradient evenly over that row,
36 | // because ∂mean(x1, ..., xn)/∂xk = 1/n.
37 | func (m *MeanT) Backward(gy ...*variable.Variable) []*variable.Variable {
38 | 	g := variable.ZeroLike(m.x)
39 | 	for i := range g.Data {
40 | 		for j := range g.Data[i] {
41 | 			g.Data[i][j] = gy[0].Data[i][0] / float64(m.n)
42 | 		}
43 | 	}
44 |
45 | 	return []*variable.Variable{
46 | 		g,
47 | 	}
48 | }
49 |
50 | // mean returns the arithmetic mean of values (NaN for an empty slice).
51 | func mean(values []float64) float64 {
52 | 	sum := 0.0
53 | 	for _, v := range values {
54 | 		sum += v
55 | 	}
56 |
57 | 	return sum / float64(len(values))
58 | }
59 |
--------------------------------------------------------------------------------
/pkg/mean_test.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/itsubaki/autograd/variable"
7 | )
8 |
9 | func ExampleMean_basic() {
10 | 	a := M{
11 | 		{1, 2, 3},
12 | 		{4, 5, 6},
13 | 	}.Var()
14 |
15 | 	result := Mean(a)
16 | 	fmt.Println(result.Data)
17 |
18 | 	// Output: [[2] [5]]
19 | }
20 |
21 | func ExampleMean_withZero() {
22 | 	a := M{
23 | 		{1, 0, 1},
24 | 	}.Var()
25 |
26 | 	result := Mean(a)
27 | 	fmt.Printf("%.6f\n", result.Data[0][0])
28 |
29 | 	// Output: 0.666667
30 | }
31 |
32 | func ExampleMean_withZeros() {
33 | 	a := M{
34 | 		{0, 0, 0},
35 | 		{0, 0, 0},
36 | 	}.Var()
37 |
38 | 	result := Mean(a)
39 | 	fmt.Println(result.Data)
40 |
41 | 	// Output: [[0] [0]]
42 | }
43 |
44 | func ExampleMean_withNegatives() {
45 | 	a := M{
46 | 		{-1, 2, -3},
47 | 		{4, -5, 6},
48 | 	}.Var()
49 |
50 | 	result := Mean(a)
51 | 	fmt.Printf("%.6f %.6f\n", result.Data[0][0], result.Data[1][0])
52 |
53 | 	// Output: -0.666667 1.666667
54 | }
55 |
56 | func ExampleMean_gradient() {
57 | 	a := M{
58 | 		{1, 2, 3},
59 | 		{4, 5, 6},
60 | 	}.Var()
61 |
62 | 	result := Mean(a)
63 | 	result.Grad = M{
64 | 		{0.1},
65 | 		{0.2},
66 | 	}.Var()
67 |
68 | 	result.Backward()
69 | 	fmt.Println(a.Grad.Data)
70 |
71 | 	// Output: [[0.03333333333333333 0.03333333333333333 0.03333333333333333] [0.06666666666666667 0.06666666666666667 0.06666666666666667]]
72 | }
73 |
74 | func ExampleMean_withScalarGradient() {
75 | 	a := M{
76 | 		{1, 2},
77 | 		{3, 4},
78 | 	}.Var()
79 |
80 | 	result := Mean(a)
81 | 	result.Grad = M{
82 | 		{1.0},
83 | 		{1.0},
84 | 	}.Var()
85 |
86 | 	result.Backward()
87 | 	fmt.Println(a.Grad.Data)
88 |
89 | 	// Output: [[0.5 0.5] [0.5 0.5]]
90 | }
91 |
92 | func ExampleMean_inComputationGraph() {
93 | 	a := M{
94 | 		{1, 3},
95 | 		{2, 4},
96 | 	}.Var()
97 |
98 | 	meanA := Mean(a)
99 | 	result := variable.Mul(meanA, M{
100 | 		{2},
101 | 		{2},
102 | 	}.Var())
103 |
104 | 	fmt.Println(result.Data)
105 | 	result.Backward()
106 | 	fmt.Println(a.Grad.Data)
107 |
108 | 	// Output:
109 | 	// [[4] [6]]
110 | 	// [[1 1] [1 1]]
111 | }
112 |
--------------------------------------------------------------------------------
/pkg/params.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | 	"encoding/binary"
5 | 	"fmt"
6 | 	"hash/crc32"
7 | 	"os"
8 |
9 | 	"github.com/itsubaki/autograd/layer"
10 | )
11 |
12 | // Params is an ordered collection of model parameters. Parameters are keyed
13 | // "0", "1", ... in the order they were added; Save and TryLoadPretrained
14 | // walk them in that order.
15 | type Params struct {
16 | 	params layer.Parameters
17 | }
18 |
19 | // NewParams returns an empty parameter collection.
20 | func NewParams() *Params {
21 | 	return &Params{params: layer.Parameters{}}
22 | }
23 |
24 | // Add appends params, keying each one by the collection's current size.
25 | func (p *Params) Add(params ...layer.Parameter) {
26 | 	for _, param := range params {
27 | 		p.params.Add(fmt.Sprintf("%d", len(p.params)), param)
28 | 	}
29 | }
30 |
31 | // Params returns the underlying parameters, keyed by insertion index.
32 | func (p *Params) Params() layer.Parameters {
33 | 	return p.params
34 | }
35 |
36 | // Count returns the total number of scalar values across all parameters.
37 | // Every parameter is assumed to be a non-empty rectangular matrix.
38 | func (p *Params) Count() int {
39 | 	numParams := 0
40 | 	for _, param := range p.params {
41 | 		numParams += len(param.Data) * len(param.Data[0])
42 | 	}
43 |
44 | 	return numParams
45 | }
46 |
47 | // ZeroGrad clears the gradients of all parameters.
48 | func (p *Params) ZeroGrad() {
49 | 	p.params.Cleargrads()
50 | }
51 |
52 | // Save writes every parameter, in key order, as little-endian float64 rows
53 | // to the file named by filename, followed by a CRC-32 checksum of the
54 | // parameter shapes. It panics on any I/O error.
55 | func (p *Params) Save() {
56 | 	file, err := os.Create(p.filename())
57 | 	if err != nil {
58 | 		panic(err)
59 | 	}
60 | 	defer file.Close() // only matters if a write below panics
61 |
62 | 	// Save map of params in ordered fashion.
63 | 	hash := crc32.NewIEEE()
64 | 	for i := 0; i < len(p.params); i++ {
65 | 		key := fmt.Sprintf("%d", i)
66 | 		for _, row := range p.params[key].Data {
67 | 			if err := binary.Write(file, binary.LittleEndian, row); err != nil {
68 | 				panic(err)
69 | 			}
70 | 		}
71 | 		shape := fmt.Sprintf("%d:%d×%d", i, len(p.params[key].Data), len(p.params[key].Data[0]))
72 | 		hash.Write([]byte(shape))
73 | 	}
74 |
75 | 	// Write checksum at the end of the file.
76 | 	checksum := hash.Sum32()
77 | 	if err := binary.Write(file, binary.LittleEndian, checksum); err != nil {
78 | 		panic(err)
79 | 	}
80 |
81 | 	// Close explicitly and check the error: for a written file a failed Close
82 | 	// can mean the data never reached disk, and the deferred Close ignores it.
83 | 	if err := file.Close(); err != nil {
84 | 		panic(err)
85 | 	}
86 | }
87 |
88 | // TryLoadPretrained overwrites the parameters with values previously written
89 | // by Save, if the model file can be opened; otherwise it returns silently and
90 | // leaves the parameters untouched. It panics if the file does not match the
91 | // current parameter shapes, in which case some parameters may already have
92 | // been overwritten.
93 | func (p *Params) TryLoadPretrained() {
94 | 	file, err := os.Open(p.filename())
95 | 	if err != nil {
96 | 		return // best effort: no usable model file, keep the current params
97 | 	}
98 | 	defer file.Close()
99 |
100 | 	// Load map of params in ordered fashion.
101 | 	hash := crc32.NewIEEE()
102 | 	for i := 0; i < len(p.params); i++ {
103 | 		key := fmt.Sprintf("%d", i)
104 | 		for j := range p.params[key].Data {
105 | 			// Read straight into the row slice; binary.Read fills it in place.
106 | 			if err := binary.Read(file, binary.LittleEndian, p.params[key].Data[j]); err != nil {
107 | 				panic(fmt.Sprintf("model shapes mismatch, remove '%s' file", p.filename()))
108 | 			}
109 | 		}
110 | 		shape := fmt.Sprintf("%d:%d×%d", i, len(p.params[key].Data), len(p.params[key].Data[0]))
111 | 		hash.Write([]byte(shape))
112 | 	}
113 |
114 | 	var savedChecksum uint32
115 | 	if err := binary.Read(file, binary.LittleEndian, &savedChecksum); err != nil {
116 | 		panic(fmt.Errorf("failed to read shapes checksum: %w", err))
117 | 	}
118 | 	if savedChecksum != hash.Sum32() {
119 | 		panic(fmt.Sprintf("model shapes mismatch, remove '%s' file", p.filename()))
120 | 	}
121 |
122 | 	fmt.Printf("Loaded pretrained params: %s\n", p.filename())
123 | }
124 |
125 | // filename returns the model file name, which encodes the parameter count
126 | // in millions (e.g. "model-1.234M"), so it changes with the model size.
127 | func (p *Params) filename() string {
128 | 	return fmt.Sprintf("model-%.3fM", Millions(p.Count()))
129 | }
130 |
--------------------------------------------------------------------------------
/pkg/softmax_test.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | 	"fmt"
5 | 	"math"
6 | 	"strings"
7 |
8 | 	"github.com/itsubaki/autograd/function"
9 | 	"github.com/itsubaki/autograd/variable"
10 | )
11 |
12 | func ExampleSoftmax_basic() {
13 | 	a := M{
14 | 		{1, 2, 3},
15 | 		{4, 5, 6},
16 | 	}.Var()
17 |
18 | 	result := function.Softmax(a)
19 | 	fmt.Println(result)
20 |
21 | 	// Output:
22 | 	// variable([[0.09003057317038046 0.24472847105479764 0.6652409557748218] [0.09003057317038046 0.24472847105479764 0.6652409557748218]])
23 | }
24 |
25 | func ExampleSoftmax_largeValues() {
26 | 	a := V{100, 100.1, 100.2}.Var()
27 |
28 | 	result := function.Softmax(a)
29 | 	fmt.Println(result)
30 |
31 | 	// Output:
32 | 	// variable([0.30060960535572756 0.33222499353334567 0.3671654011109268])
33 | }
34 |
35 | func ExampleSoftmax_withMasking() {
36 | 	a := variable.NewOf(
37 | 		[]float64{1, math.Inf(-1), 3},
38 | 		[]float64{math.Inf(-1), 2, 3},
39 | 		[]float64{1, 2, math.Inf(-1)},
40 | 		[]float64{1, math.Inf(-1), math.Inf(-1)},
41 | 	)
42 | 	result := function.Softmax(a)
43 |
44 | 	for _, row := range result.Data {
45 | 		values := make([]string, len(row))
46 | 		for i, val := range row {
47 | 			values[i] = fmt.Sprintf("%.6f", val)
48 | 		}
49 | 		fmt.Println(strings.Join(values, " "))
50 | 	}
51 |
52 | 	// Output:
53 | 	// 0.119203 0.000000 0.880797
54 | 	// 0.000000 0.268941 0.731059
55 | 	// 0.268941 0.731059 0.000000
56 | 	// 1.000000 0.000000 0.000000
57 | }
58 |
59 | func ExampleSoftmax_allMasked() {
60 | 	a := variable.NewOf(
61 | 		[]float64{0, math.Inf(-1), math.Inf(-1)},
62 | 		[]float64{1, 2, 3},
63 | 	)
64 | 	result := function.Softmax(a)
65 | 	fmt.Println(result)
66 |
67 | 	// Output:
68 | 	// variable([[1 0 0] [0.09003057317038046 0.24472847105479764 0.6652409557748218]])
69 | }
70 |
71 | func ExampleSoftmax_gradient() {
72 | 	a := variable.NewOf([]float64{1, 2, 3})
73 | 	result := function.Softmax(a)
74 | 	result.Grad = variable.NewOf([]float64{1, 1, 1})
75 | 	result.Backward()
76 | 	fmt.Println(a.Grad)
77 |
78 | 	b := variable.NewOf([]float64{1, 2, 3})
79 | 	resultB := function.Softmax(b)
80 | 	resultB.Grad = variable.NewOf([]float64{1, 0, 0})
81 | 	resultB.Backward()
82 | 	fmt.Println(b.Grad)
83 |
84 | 	// Output:
85 | 	// variable([1.3877787807814457e-17 2.7755575615628914e-17 1.1102230246251565e-16])
86 | 	// variable([0.08192506906499324 -0.022033044520174298 -0.05989202454481893])
87 | }
88 |
--------------------------------------------------------------------------------
/pkg/variance.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | 	"github.com/itsubaki/autograd/variable"
5 | )
6 |
7 | // Variance returns the population variance of each row of x[0] (the mean
8 | // of squared deviations from the row mean, dividing by n rather than n-1)
9 | // as a column vector. Only x[0] is used. It is built from differentiable
10 | // ops, so gradients flow back to x[0].
11 | func Variance(x ...*variable.Variable) *variable.Variable {
12 | 	// Calculate mean per row
13 | 	means := Mean(x[0])
14 |
15 | 	diffs := variable.Sub(x[0], means)
16 | 	squaredDiffs := variable.Pow(2)(diffs)
17 | 	variance := Mean(squaredDiffs)
18 |
19 | 	return variance
20 | }
--------------------------------------------------------------------------------
/pkg/variance_test.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/itsubaki/autograd/variable"
7 | )
8 |
9 | func ExampleVariance_basic() {
10 | 	a := M{
11 | 		{1, 2, 3},
12 | 		{4, 5, 6},
13 | 	}.Var()
14 |
15 | 	result := Variance(a)
16 |
17 | 	// Print with higher precision to show exact values
18 | 	fmt.Printf("%.10f\n", result.Data[0][0])
19 | 	fmt.Printf("%.10f\n", result.Data[1][0])
20 |
21 | 	// Output:
22 | 	// 0.6666666667
23 | 	// 0.6666666667
24 | }
25 |
26 | func ExampleVariance_constants() {
27 | 	a := M{
28 | 		{5, 5, 5},
29 | 		{-3, -3, -3},
30 | 	}.Var()
31 |
32 | 	result := Variance(a)
33 |
34 | 	fmt.Println(result.Data[0][0])
35 | 	fmt.Println(result.Data[1][0])
36 |
37 | 	// Output:
38 | 	// 0
39 | 	// 0
40 | }
41 |
42 | func ExampleVariance_withNegatives() {
43 | 	a := M{
44 | 		{-1, 0, 1},
45 | 		{-10, 0, 10},
46 | 	}.Var()
47 |
48 | 	result := Variance(a)
49 |
50 | 	fmt.Printf("%.10f\n", result.Data[0][0])
51 | 	fmt.Printf("%.10f\n", result.Data[1][0])
52 |
53 | 	// Output:
54 | 	// 0.6666666667
55 | 	// 66.6666666667
56 | }
57 |
58 | func ExampleVariance_gradient() {
59 | 	// Values [1, 3, 5] have a mean of 3
60 | 	a := M{
61 | 		{1, 3, 5},
62 | 	}.Var()
63 |
64 | 	result := Variance(a)
65 |
66 | 	// Print variance result
67 | 	fmt.Printf("Variance: %.10f\n", result.Data[0][0])
68 |
69 | 	// Set gradient to 1.0 and backpropagate
70 | 	result.Grad = M{
71 | 		{1.0},
72 | 	}.Var()
73 |
74 | 	result.Backward()
75 |
76 | 	// Print gradients with high precision
77 | 	fmt.Printf("Gradients: %.10f %.10f %.10f\n",
78 | 		a.Grad.Data[0][0], a.Grad.Data[0][1], a.Grad.Data[0][2])
79 |
80 | 	// Output:
81 | 	// Variance: 2.6666666667
82 | 	// Gradients: -1.3333333333 0.0000000000 1.3333333333
83 | }
84 |
85 | func ExampleVariance_inComputationGraph() {
86 | 	// Create input with a single row
87 | 	a := M{
88 | 		{2, 4, 6},
89 | 	}.Var()
90 |
91 | 	// Calculate variance
92 | 	v := Variance(a)
93 |
94 | 	// Multiply by scalar
95 | 	k := M{
96 | 		{0.5},
97 | 	}.Var()
98 |
99 | 	result := variable.Mul(v, k)
100 |
101 | 	// Print the result
102 | 	fmt.Printf("Result: %.10f\n", result.Data[0][0])
103 |
104 | 	// Backpropagate
105 | 	result.Backward()
106 |
107 | 	// Print gradients
108 | 	fmt.Printf("Gradients: %.10f %.10f %.10f\n",
109 | 		a.Grad.Data[0][0], a.Grad.Data[0][1], a.Grad.Data[0][2])
110 |
111 | 	// Output:
112 | 	// Result: 1.3333333333
113 | 	// Gradients: -0.6666666667 0.0000000000 0.6666666667
114 | }
115 |
--------------------------------------------------------------------------------