├── .gitignore
├── LICENSE
├── README.md
├── block.go
├── data
│   ├── data.go
│   ├── data_test.go
│   ├── fairy_tales.txt
│   ├── fairy_vocab
│   ├── jules_verne.txt
│   └── vocab
├── go.mod
├── go.sum
├── head.go
├── layer.go
├── main.go
├── main_test.go
└── pkg
    ├── adamw.go
    ├── cat.go
    ├── cat_test.go
    ├── functions.go
    ├── matmul.go
    ├── matmul_test.go
    ├── mean.go
    ├── mean_test.go
    ├── params.go
    ├── softmax_test.go
    ├── variance.go
    └── variance_test.go
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/**
2 | .idea/**
3 | **/.DS_Store
4 | /model-*
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Artem Zakirullin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # gpt-go
4 | A simple GPT implementation in pure Go, trained on my favourite Jules Verne books.
5 |
6 | The kind of response you can expect from the model:
7 | ```
8 | Mysterious Island.
9 | Well.
10 | My days must follow
11 | ```
12 |
13 | Or this:
14 | ```
15 | Captain Nemo, in two hundred thousand feet weary in
16 | the existence of the world.
17 | ```
18 |
19 | ## How to run
20 | ```shell
21 | $ go run .
22 | ```
23 |
24 | It takes about 40 minutes to train on a MacBook Air M3. You can train on your own dataset by pointing the `data.dataset` variable at your text corpus, as shown below.
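The corpus is wired in through a `go:embed` directive in `data/data.go`; point it at your own file (a sketch, `my_corpus.txt` is a placeholder name):
```go
//go:embed my_corpus.txt
dataset string
```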
25 |
26 | To run in chat-only mode once the training is done:
27 | ```shell
28 | $ go run . -chat
29 | ```
30 |
31 | ## How to understand
32 | You can use this repository as a companion to the [Neural Networks: Zero to Hero](https://karpathy.ai/zero-to-hero.html) course. Use `git checkout <tag>` to see how the model has evolved over time: `naive`, `bigram`, `multihead`, `block`, `residual`, `full`.
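For example, to inspect the multi-head attention stage:
```shell
$ git checkout multihead
$ go run .
```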
33 |
34 | In [main_test.go](https://github.com/zakirullin/gpt-go/blob/main/main_test.go) you will find explanations starting from a basic neuron example:
35 | ```go
36 | // Our neuron has 2 inputs and 1 output (number of columns in weight matrix).
37 | // Its goal is to predict the next number in the sequence.
38 | input := V{1, 2} // {x1, x2}
39 | weight := M{
40 | {2}, // how much x1 contributes to the output
41 | {3}, // how much x2 contributes to the output
42 | }
43 | ```
44 |
45 | All the way to self-attention mechanism:
46 | ```go
47 | // To calculate the sum of all previous tokens, we can multiply by this triangular matrix:
48 | tril := M{
49 | {1, 0, 0, 0}, // first token attends only to itself ("cat"), it can't look into the future
50 | {1, 1, 0, 0}, // second token attends to itself and the previous token ("cat" + ", ")
51 | {1, 1, 1, 0}, // third token attends to itself and the two previous tokens ("cat" + ", " + "dog")
52 | {1, 1, 1, 1}, // fourth token attends to itself and all the previous tokens ("cat" + ", " + "dog" + " and")
53 | }.Var()
54 | // So, at this point each embedding is enriched with the information from all the previous tokens.
55 | // That's the crux of self-attention.
56 | enrichedEmbeds := MatMul(tril, inputEmbeds)
57 | ```
58 |
59 | ## Design choices
60 | No batches.
61 | I've given up the complexity of the batch dimension for the sake of understanding. It's far easier to build intuition with 2D matrices than with 3D tensors. Besides, batches aren't inherent to the transformer architecture. As an alternative, gradient accumulation was tried; the effect was negligible, so it was removed as well.
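For the curious, the accumulation variant looked roughly like this (a hypothetical sketch, not in the codebase; `forward` and `accumSteps` are illustrative names):
```go
accumSteps := 8 // illustrative value
for i := 0; i < steps; i++ {
    input, targets := data.Sample(dataset, blockSize)
    loss := forward(input, targets) // hypothetical helper wrapping the forward pass
    loss.Backward()                 // gradients keep accumulating across mini-steps
    if (i+1)%accumSteps == 0 {
        optimizer.Update(params) // one update per accumSteps samples
        params.ZeroGrad()        // clear gradients only after the update
    }
}
```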
62 |
63 | Removed `gonum`.
64 | `gonum`'s matmul gave us a ~30% performance boost, but it brought in an additional dependency. We're not striving for maximum efficiency here, but for radical simplicity. The current matmul implementation is quite efficient, and it's only 40 lines of plain, readable code.
65 |
66 | ## Papers
67 | You don't need to read them to understand the code :)
68 |
69 | - [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
70 | - [Deep Residual Learning](https://arxiv.org/abs/1512.03385)
71 | - [DeepMind WaveNet](https://arxiv.org/abs/1609.03499)
72 | - [Batch Normalization](https://arxiv.org/abs/1502.03167)
73 | - [Deep NN + huge data = breakthrough performance](https://papers.nips.cc/paper_files/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html)
74 | - [OpenAI GPT-3 paper](https://arxiv.org/abs/2005.14165)
75 | - [Analyzing the Structure of Attention](https://arxiv.org/abs/1906.04284)
76 |
77 | ## Credits
78 | Many thanks to [Andrej Karpathy](https://github.com/karpathy) for his brilliant [Neural Networks: Zero to Hero](https://karpathy.ai/zero-to-hero.html) course.
79 |
80 | Thanks to [@itsubaki](https://github.com/itsubaki) for his elegant [autograd](https://github.com/itsubaki/autograd) package.
81 |
--------------------------------------------------------------------------------
/block.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "github.com/itsubaki/autograd/function"
5 | "github.com/itsubaki/autograd/layer"
6 | "github.com/itsubaki/autograd/variable"
7 |
8 | "github.com/zakirullin/gpt-go/pkg"
9 | )
10 |
11 | var (
12 | Zeros = variable.Zero
13 | Ones = pkg.Ones
14 | ReLU = function.ReLU
15 | Dropout = function.DropoutSimple
16 | MatMul = pkg.MatMul
17 | Add = variable.Add
18 | Sub = variable.Sub
19 | MulC = variable.MulC
20 | Transpose = variable.Transpose
21 | Softmax = function.Softmax
22 | SoftmaxCrossEntropy = function.SoftmaxCrossEntropy
23 | RandEmbeds = pkg.Normal
24 | Rows = pkg.Rows
25 | Val = pkg.Val
26 | Flat = pkg.Flat
27 | )
28 |
29 | type Block struct {
30 | embedSize int
31 | headCount int
32 | saHead *MultiHeadAttention
33 | mlp *Linear // multi-layer perceptron
34 | mlpProj *Linear // projects the output of the MLP back to the original embedding size
35 | norm1 *LayerNorm
36 | norm2 *LayerNorm
37 | }
38 |
39 | func NewBlock(embedSize, numHeads int) *Block {
40 | return &Block{
41 | embedSize: embedSize,
42 | headCount: numHeads,
43 | saHead: NewMultiHeadAttention(embedSize, numHeads),
44 | mlp: NewLinear(embedSize, embedSize*4),
45 | mlpProj: NewLinear(embedSize*4, embedSize),
46 | norm1: NewLayerNorm(embedSize),
47 | norm2: NewLayerNorm(embedSize),
48 | }
49 | }
50 |
51 | func (b *Block) Forward(input *variable.Variable) *variable.Variable {
52 | // Self-attention with residual connections. The input is our highway; we let the gradient flow back unimpeded.
53 | input = b.norm1.Forward(input) // Normalize input (mean=0, var=1), i.e. normalize every token's embed
54 | saOut := b.saHead.Forward(input) // Encode relationships between positions, (blockSize, embedSize)
55 | input = Add(input, saOut) // Add residual attention output back to main path
56 |
57 | // Feed-forward network with residual connection
58 | input = b.norm2.Forward(input) // Normalize input
59 | mlpExpanded := b.mlp.Forward(input) // Expand to higher dimension
60 | mlpActivated := ReLU(mlpExpanded) // Apply activation function
61 | mlpOutput := b.mlpProj.Forward(mlpActivated) // Project back to original dimension
62 | mlpOutput = Dropout(dropout)(mlpOutput) // Dropping out some neurons to prevent overfitting
63 | input = Add(input, mlpOutput) // Add feed-forward residual output to main path
64 |
65 | return input
66 | }
67 |
68 | func (b *Block) Params() []layer.Parameter {
69 | var params []layer.Parameter
70 | for _, param := range b.saHead.Params() {
71 | params = append(params, param)
72 | }
73 | params = append(params, b.mlp.Weight, b.mlp.Bias)
74 | params = append(params, b.mlpProj.Weight, b.mlpProj.Bias)
75 | params = append(params, b.norm1.Scale, b.norm1.Shift)
76 | params = append(params, b.norm2.Scale, b.norm2.Shift)
77 |
78 | return params
79 | }
80 |
--------------------------------------------------------------------------------
/data/data.go:
--------------------------------------------------------------------------------
1 | package data
2 |
3 | import (
4 | _ "embed"
5 | "fmt"
6 | "math/rand/v2"
7 | "regexp"
8 | "slices"
9 | "strings"
10 |
11 | "github.com/itsubaki/autograd/variable"
12 | )
13 |
14 | var (
15 | //go:embed jules_verne.txt
16 | dataset string
17 | //go:embed vocab
18 | vocab string
19 |
20 | tokenToID map[string]int
21 | idToToken map[int]string
22 | mergeRules map[int64]int
23 | rulesOrder []int64
24 |
25 | Dataset = func() string { return dataset }
26 | Vocab = func() string { return vocab }
27 | RandInt = rand.IntN
28 | )
29 |
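// Tokenize builds the vocabulary from the dataset's characters plus up to
// numMerges merge rules from the vocab file, and returns the encoded dataset
// along with the vocabulary size.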
30 | func Tokenize(numMerges int) ([]float64, int) {
31 | tokenToID = make(map[string]int)
32 | idToToken = make(map[int]string)
33 | mergeRules = make(map[int64]int)
34 | rulesOrder = nil
35 |
36 | normDataset := normNewLines(Dataset())
37 | addCharsToVocab(normDataset)
38 | createMergeRules(Vocab(), numMerges)
39 |
40 | return Encode(normDataset), VocabSize()
41 | }
42 |
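// Encode converts text into token ids: one id per character first, then the
// merge rules are applied in the order they were created.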
43 | func Encode(s string) []float64 {
44 | var tokens []float64
45 | for _, ch := range s {
46 | tok, ok := tokenToID[string(ch)]
47 | if !ok {
48 | panic(fmt.Sprintf("char '%s' is missing from vocabulary", string(ch)))
49 | }
50 | tokens = append(tokens, float64(tok))
51 | }
52 |
53 | for _, rule := range rulesOrder {
54 | var newTokens []float64
55 | tok1, tok2 := unzip(rule)
56 | // Try to apply rule on every pair of tokens
57 | for i := 0; i < len(tokens); {
58 | hasNextToken := i+1 < len(tokens)
59 | shouldMerge := hasNextToken && int(tokens[i]) == tok1 && int(tokens[i+1]) == tok2
60 | if shouldMerge {
61 | newTokens = append(newTokens, float64(mergeRules[rule]))
62 | i += 2 // eat two tokens
63 | } else {
64 | newTokens = append(newTokens, tokens[i])
65 | i++ // eat one token
66 | }
67 | }
68 | tokens = newTokens
69 | }
70 |
71 | return tokens
72 | }
73 |
74 | func Decode(indices ...float64) string {
75 | var result strings.Builder
76 |
77 | for _, idx := range indices {
78 | id := int(idx)
79 | if token, ok := idToToken[id]; ok {
80 | result.WriteString(token)
81 | } else {
82 | panic(fmt.Sprintf("unknown token id=%d", id))
83 | }
84 | }
85 |
86 | return result.String()
87 | }
88 |
89 | func VocabSize() int {
90 | return len(tokenToID)
91 | }
92 |
93 | // Sample returns a random sample of data of the given block size.
94 | func Sample(data []float64, blockSize int) (*variable.Variable, *variable.Variable) {
95 | dataLen := len(data) - (blockSize + 1)
96 | if dataLen < 0 {
97 | panic("not enough data for the given block size")
98 | }
99 |
100 | offset := RandInt(dataLen)
101 |
102 | x := make([]float64, blockSize)
103 | y := make([]float64, blockSize)
104 |
105 | for i := 0; i < blockSize; i++ {
106 | x[i] = data[i+offset]
107 | y[i] = data[i+offset+1]
108 | }
109 |
110 | return variable.New(x...), variable.New(y...)
111 | }
112 |
113 | func Chars() string {
114 | var tokens []string
115 | for token := range tokenToID {
116 | if len(token) == 1 {
117 | tokens = append(tokens, token)
118 | }
119 | }
120 | slices.Sort(tokens)
121 |
122 | return strings.Join(tokens, "")
123 | }
124 |
125 | func addCharsToVocab(text string) {
126 | var chars []string
127 | for _, ch := range text {
128 | chars = append(chars, string(ch))
129 | }
130 | addTokensToVocab(chars...)
131 | }
132 |
133 | func createMergeRules(rules string, numMerges int) {
134 | rules = strings.TrimSpace(rules)
135 | if len(rules) == 0 {
136 | return
137 | }
138 |
139 | merges := strings.Split(rules, "\n")
140 | merges = merges[:min(numMerges, len(merges))]
141 |
142 | // Mint new tokens, save merge rules.
143 | for _, m := range merges {
144 | re := regexp.MustCompile(`\[(.*?)\]\[(.*?)\] -> \[(.*?)\]`)
145 | matches := re.FindStringSubmatch(m)
146 | if len(matches) != 4 {
147 | panic(fmt.Sprintf("invalid vocab format: %s", m))
148 | }
149 |
150 | // Unescape newline sequences in all tokens.
151 | left := strings.ReplaceAll(matches[1], "\\n", "\n")
152 | right := strings.ReplaceAll(matches[2], "\\n", "\n")
153 | mergedToken := strings.ReplaceAll(matches[3], "\\n", "\n")
154 |
155 | addTokensToVocab(mergedToken)
156 |
157 | for _, token := range []string{left, right, mergedToken} {
158 | if _, ok := tokenToID[token]; !ok {
159 | panic(fmt.Sprintf("rule '%s' is malformed, token '%s' is missing from vocabulary", m, token))
160 | }
161 | }
162 | addRule(tokenToID[left], tokenToID[right], tokenToID[mergedToken])
163 | }
164 | }
165 |
166 | func addTokensToVocab(tokens ...string) {
167 | for _, token := range tokens {
168 | if _, ok := tokenToID[token]; ok {
169 | continue
170 | }
171 |
172 | tokenID := len(tokenToID)
173 | tokenToID[token] = tokenID
174 | idToToken[tokenID] = token
175 | }
176 | }
177 |
178 | func addRule(tok1, tok2, mergedTok int) {
179 | key := zip(tok1, tok2)
180 | mergeRules[key] = mergedTok
181 | rulesOrder = append(rulesOrder, key)
182 | }
183 |
184 | func normNewLines(text string) string {
185 | text = strings.Replace(text, "\r\n", "\n", -1) // replace Windows line endings
186 | return strings.Replace(text, "\r", "\n", -1) // replace remaining Mac line endings
187 | }
188 |
189 | // Zips two tokens into a single int64.
190 | func zip(tok1, tok2 int) int64 {
191 | return int64(tok1)<<32 | int64(tok2&0xFFFFFFFF)
192 | }
193 |
194 | // Unzips a single int64 into two tokens.
195 | func unzip(tok int64) (int, int) {
196 | tok1 := int(tok >> 32)
197 | tok2 := int(tok & 0xFFFFFFFF)
198 | return tok1, tok2
199 | }
200 |
--------------------------------------------------------------------------------
/data/data_test.go:
--------------------------------------------------------------------------------
1 | package data
2 |
3 | import (
4 | "math"
5 | "math/rand"
6 | "testing"
7 |
8 | "github.com/itsubaki/autograd/variable"
9 | )
10 |
11 | func TestTokenize(t *testing.T) {
12 | Dataset = func() string {
13 | return "abc"
14 | }
15 | Vocab = func() string {
16 | return "[a][b] -> [z]"
17 | }
18 | encoded, vocabSize := Tokenize(1)
19 |
20 | areSlicesEqual(t, []float64{3, 2}, encoded)
21 | areEqual(t, 4, vocabSize)
22 | }
23 |
24 | func TestEncode(t *testing.T) {
25 | Dataset = func() string {
26 | return "abcd"
27 | }
28 | Vocab = func() string {
29 | return "[a][a] -> [Z]\n[a][b] -> [Y]\n[Z][Y] -> [X]"
30 | }
31 |
32 | Tokenize(3)
33 |
34 | // aaabdaaabac
35 | // ZabdZabac
36 | // ZYdZYac
37 | // XdXac
38 | encoded := Encode("aaabdaaabac")
39 | areSlicesEqual(t, []float64{6, 3, 6, 0, 2}, encoded)
40 | }
41 |
42 | func TestDecode(t *testing.T) {
43 | Dataset = func() string {
44 | return "abcd"
45 | }
46 | Vocab = func() string {
47 | return "[a][a] -> [aa]\n[a][b] -> [ab]\n[aa][ab] -> [aaab]"
48 | }
49 |
50 | Tokenize(3)
51 |
52 | decoded := Decode([]float64{6, 3, 6, 0, 2}...)
53 |
54 | areEqual(t, "aaabdaaabac", decoded)
55 | }
56 |
57 | func TestEncodeDecodeDifferentNewLines(t *testing.T) {
58 | Dataset = func() string {
59 | return "a\nb\r\nc"
60 | }
61 | Vocab = func() string {
62 | return ""
63 | }
64 |
65 | encoded, _ := Tokenize(3)
66 | areSlicesEqual(t, []float64{0, 1, 2, 1, 3}, encoded)
67 |
68 | decoded := Decode([]float64{0, 1, 2, 1, 3}...)
69 | areEqual(t, "a\nb\nc", decoded)
70 | }
71 |
72 | func TestEncodeDecodeTokenizedNewLines(t *testing.T) {
73 | Dataset = func() string {
74 | return "a\nb\n\nc"
75 | }
76 | Vocab = func() string {
77 | return "[\\n][\\n] -> [\\n\\n]"
78 | }
79 |
80 | encoded, _ := Tokenize(1)
81 | areSlicesEqual(t, []float64{0, 1, 2, 4, 3}, encoded)
82 |
83 | decoded := Decode([]float64{0, 1, 2, 4, 3}...)
84 | areEqual(t, "a\nb\n\nc", decoded)
85 | }
86 |
87 | func TestZipUnzip(t *testing.T) {
88 | zipped := zip(1, 2)
89 | expected := int64(4294967298)
90 | areEqual(t, zipped, expected)
91 | x, y := unzip(expected)
92 | areEqual(t, 1, x)
93 | areEqual(t, 2, y)
94 | }
95 |
96 | func TestAddTokensFromText(t *testing.T) {
97 | idToToken = make(map[int]string)
98 | tokenToID = make(map[string]int)
99 |
100 | testText := "hello world"
101 | addCharsToVocab(testText)
102 |
103 | contains := true
104 | for _, token := range []string{"h", "e", "l", "o", " ", "w", "r", "d"} {
105 | if _, exists := tokenToID[token]; !exists {
106 | contains = false
107 | break
108 | }
109 | }
110 |
111 | areEqual(t, 8, len(tokenToID))
112 | areEqual(t, true, contains)
113 | }
114 |
115 | func TestSample(t *testing.T) {
116 | // Setup data
117 | testData := V{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
118 | blockSize := 3
119 |
120 | RandInt = func(_ int) int { return 0 }
121 | defer func() {
122 | RandInt = rand.IntN
123 | }()
124 |
125 | x, y := Sample(testData, blockSize)
126 | areMatricesEqual(t, M{
127 | {0, 1, 2},
128 | }, x)
129 | areMatricesEqual(t, M{
130 | {1, 2, 3},
131 | }, y)
132 | }
133 |
134 | func TestNormNewLinesEmptyString(t *testing.T) {
135 | input := ""
136 | expected := ""
137 | result := normNewLines(input)
138 | areEqual(t, expected, result)
139 | }
140 |
141 | func TestNormNewLinesSingleLine(t *testing.T) {
142 | input := "hello world"
143 | expected := "hello world"
144 | result := normNewLines(input)
145 | areEqual(t, expected, result)
146 | }
147 |
148 | func TestNormNewLinesLinuxLineEndings(t *testing.T) {
149 | input := "line1\nline2\nline3"
150 | expected := "line1\nline2\nline3"
151 | result := normNewLines(input)
152 | areEqual(t, expected, result)
153 | }
154 |
155 | func TestNormNewLinesMixedContent(t *testing.T) {
156 | input := "title\n\nsome content\nmore content\n\nfooter"
157 | expected := "title\n\nsome content\nmore content\n\nfooter"
158 | result := normNewLines(input)
159 | areEqual(t, expected, result)
160 | }
161 |
162 | func TestNormNewLinesWindowsLineEndings(t *testing.T) {
163 | input := "line1\r\nline2\r\nline3"
164 | expected := "line1\nline2\nline3"
165 | result := normNewLines(input)
166 | areEqual(t, expected, result)
167 | }
168 |
169 | func TestNormNewLinesOldMacLineEndings(t *testing.T) {
170 | input := "line1\rline2\rline3"
171 | expected := "line1\nline2\nline3"
172 | result := normNewLines(input)
173 | areEqual(t, expected, result)
174 | }
175 |
176 | func TestNormNewLinesMixedLineEndings(t *testing.T) {
177 | input := "line1\nline2\r\nline3\rline4"
178 | expected := "line1\nline2\nline3\nline4"
179 | result := normNewLines(input)
180 | areEqual(t, expected, result)
181 | }
182 |
183 | func areEqual[V comparable](t *testing.T, want, got V) {
184 | t.Helper()
185 | if want != got {
186 | t.Errorf("want: %v, got: %v", want, got)
187 | }
188 | }
189 |
190 | func areSlicesEqual[T comparable](t *testing.T, want, got []T) {
191 | t.Helper()
192 |
193 | if len(want) != len(got) {
194 | t.Errorf("length mismatch: want %d elements, got %d elements", len(want), len(got))
195 | t.Errorf("want: %v, got: %v", want, got)
196 | return
197 | }
198 |
199 | for i := range want {
200 | if want[i] != got[i] {
201 | t.Errorf("want: %v, got: %v", want, got)
202 | return
203 | }
204 | }
205 | }
206 |
207 | func areMatricesEqual(t *testing.T, want M, got *variable.Variable) {
208 | t.Helper()
209 | if len(want) != len(got.Data) {
210 | t.Errorf("matrix length mismatch: want length=%d, got length=%d", len(want), len(got.Data))
211 | return
212 | }
213 |
214 | for i := range want {
215 | if len(want[i]) != len(got.Data[i]) {
216 | t.Errorf("matrix row length mismatch at row %d: want length=%d, got length=%d", i, len(want[i]), len(got.Data[i]))
217 | return
218 | }
219 | }
220 |
221 | for i := range want {
222 | for j := range want[i] {
223 | if math.Abs(want[i][j]-got.Data[i][j]) > 1e-9 {
224 | t.Errorf("matrix mismatch at row %d, column %d: want %v, got %v", i, j, want[i][j], got.Data[i][j])
225 | }
226 | }
227 | }
228 | }
229 |
230 | type V []float64
231 |
232 | func (v V) Var() *variable.Variable {
233 | return variable.NewOf(v)
234 | }
235 |
236 | type M [][]float64
237 |
238 | func (m M) Var() *variable.Variable {
239 | return variable.NewOf(m...)
240 | }
241 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/zakirullin/gpt-go
2 |
3 | go 1.23.6
4 |
5 | require github.com/itsubaki/autograd v0.0.0-20250418093449-6f1c6692aa69
6 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/itsubaki/autograd v0.0.0-20250418093449-6f1c6692aa69 h1:7N9RwynBDvSMdrBCuoqKJrVw12cgQUIUSjbaQb1BrnQ=
2 | github.com/itsubaki/autograd v0.0.0-20250418093449-6f1c6692aa69/go.mod h1:KhCUsowk1ZUwjV10xjlMtCYucGJLwtQaBFGS43ogV8Y=
3 |
--------------------------------------------------------------------------------
/head.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "math"
5 |
6 | "github.com/itsubaki/autograd/layer"
7 | "github.com/itsubaki/autograd/variable"
8 |
9 | "github.com/zakirullin/gpt-go/pkg"
10 | )
11 |
12 | var (
13 | Tril = pkg.Tril
14 | MaskedInfFill = pkg.MaskedInfFill
15 | )
16 |
17 | type MultiHeadAttention struct {
18 | numHeads int
19 | embedSize int
20 | headSize int
21 | Heads []*Head
22 | proj *Linear
23 | }
24 |
25 | func NewMultiHeadAttention(embedSize, numHeads int) *MultiHeadAttention {
26 | heads := make([]*Head, numHeads)
27 | headSize := embedSize / numHeads
28 | for i := range heads {
29 | heads[i] = NewHead(embedSize, headSize)
30 | }
31 |
32 | return &MultiHeadAttention{
33 | Heads: heads,
34 | numHeads: numHeads,
35 | embedSize: embedSize,
36 | headSize: headSize,
37 | proj: NewLinear(embedSize, embedSize),
38 | }
39 | }
40 |
41 | func (mh *MultiHeadAttention) Forward(input *variable.Variable) *variable.Variable {
42 | var features []*variable.Variable
43 | for _, head := range mh.Heads {
44 | features = append(features, head.Forward(input))
45 | }
46 |
47 | out := pkg.Cat(features...)
48 | out = mh.proj.Forward(out) // Project back to (embedSize, embedSize)
49 | out = Dropout(dropout)(out) // Dropping out some neurons to prevent overfitting
50 |
51 | return out
52 | }
53 |
54 | func (mh *MultiHeadAttention) Params() []layer.Parameter {
55 | var params []layer.Parameter
56 | for _, head := range mh.Heads {
57 | params = append(params, head.Query.Weight, head.Key.Weight, head.Value.Weight)
58 | }
59 | params = append(params, mh.proj.Weight, mh.proj.Bias)
60 |
61 | return params
62 | }
63 |
64 | type Head struct {
65 | embedSize int
66 | headSize int
67 | Key *Linear
68 | Query *Linear
69 | Value *Linear
70 | }
71 |
72 | // NewHead creates a single attention head that projects embedSize-dim embeddings into headSize-dim queries, keys and values.
73 | func NewHead(embedSize, headSize int) *Head {
74 | key := NewLinear(embedSize, headSize, NoBias())
75 | query := NewLinear(embedSize, headSize, NoBias())
76 | value := NewLinear(embedSize, headSize, NoBias())
77 |
78 | return &Head{embedSize, headSize, key, query, value}
79 | }
80 |
81 | // Self-attention mechanism, see main_test.go for explanation.
82 | func (h *Head) Forward(input *variable.Variable) *variable.Variable {
83 | query := h.Query.Forward(input)
84 | key := h.Key.Forward(input)
85 | attentions := MatMul(query, Transpose(key))
86 |
87 | T := len(input.Data) // number of tokens
88 | tril := Tril(Ones(T, T))
89 | attentions = MaskedInfFill(attentions, tril)
90 | attentions = Softmax(attentions)
91 | attentions = Dropout(dropout)(attentions)
92 |
93 | v := h.Value.Forward(input)
94 | weightedSum := MatMul(attentions, v)
95 | normalizedSum := MulC(math.Pow(float64(h.embedSize), -0.5), weightedSum) // scale by 1/sqrt(embedSize) to keep values small
96 |
97 | return normalizedSum
98 | }
99 |
--------------------------------------------------------------------------------
/layer.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "github.com/itsubaki/autograd/layer"
5 | "github.com/itsubaki/autograd/variable"
6 |
7 | "github.com/zakirullin/gpt-go/pkg"
8 | )
9 |
10 | var (
11 | Mean = pkg.Mean
12 | Variance = pkg.Variance
13 | Div = pkg.Div
14 | Mul = variable.Mul
15 | Pow = variable.Pow
16 | RandWeights = pkg.Normal
17 | )
18 |
19 | type Linear struct {
20 | In, Out int
21 | Weight *variable.Variable
22 | Biased bool
23 | Bias *variable.Variable
24 | }
25 |
26 | func NewLinear(in, out int, opts ...LinearOption) *Linear {
27 | l := &Linear{
28 | In: in,
29 | Out: out,
30 | Weight: RandWeights(in, out),
31 | Biased: true,
32 | Bias: variable.Zero(1, out),
33 | }
34 |
35 | for _, opt := range opts {
36 | opt(l)
37 | }
38 |
39 | return l
40 | }
41 |
42 | // Forward computes the output based on the input (forward pass)
43 | func (l *Linear) Forward(input *variable.Variable) *variable.Variable {
44 | logits := MatMul(input, l.Weight)
45 |
46 | if l.Biased {
47 | logits = Add(logits, l.Bias)
48 | }
49 |
50 | return logits
51 | }
52 |
53 | func (l *Linear) Params() []layer.Parameter {
54 | params := []layer.Parameter{
55 | l.Weight,
56 | }
57 |
58 | if l.Biased {
59 | params = append(params, l.Bias)
60 | }
61 |
62 | return params
63 | }
64 |
65 | type LinearOption func(*Linear)
66 |
67 | func NoBias() LinearOption {
68 | return func(l *Linear) {
69 | l.Biased = false
70 | // Drop the bias tensor so it is neither applied nor trained.
71 | l.Bias = nil
72 | }
73 | }
74 |
75 | type LayerNorm struct {
76 | Scale *variable.Variable
77 | Shift *variable.Variable
78 | eps float64
79 | }
80 |
81 | func NewLayerNorm(dim int) *LayerNorm {
82 | return &LayerNorm{
83 | eps: 1e-05,
84 | Scale: Ones(1, dim),
85 | Shift: Zeros(1, dim),
86 | }
87 | }
88 |
89 | // Forward normalizes each row (xhat = (x - mean) / sqrt(var + eps)), then applies the learnable scale and shift. It is built from existing autograd primitives, so backpropagation works through it.
90 | func (ln *LayerNorm) Forward(x *variable.Variable) *variable.Variable {
91 | xmean := Mean(x)
92 | xvar := Variance(x)
93 | eps := variable.New(ln.eps)
94 | xhat := Div(Sub(x, xmean), Pow(0.5)(Add(xvar, eps)))
95 | out := Add(Mul(ln.Scale, xhat), ln.Shift)
96 |
97 | return out
98 | }
99 |
100 | func (ln *LayerNorm) Params() []layer.Parameter {
101 | return []layer.Parameter{
102 | ln.Scale,
103 | ln.Shift,
104 | }
105 | }
106 |
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "bufio"
5 | "flag"
6 | "fmt"
7 | "os"
8 | "strings"
9 |
10 | "github.com/zakirullin/gpt-go/data"
11 | "github.com/zakirullin/gpt-go/pkg"
12 | )
13 |
14 | // Hyperparameters
15 | const (
16 | blockSize = 32
17 | embedSize = 88
18 | heads = 4
19 | layers = 4
20 | learningRate = 0.0001
21 | steps = 80000 // number of training steps, increase for better results
22 | evalSteps = 1000 // evaluate loss once per every evalSteps
23 | dropout = 0.0 // fraction of neurons randomly disabled during training to prevent overfitting, helps the model generalize
24 | pretrainedTokens = 6000 // number of pretrained tokens to add on top of auto-detected characters
25 | maxTokens = 50 // tokens limit for generation
26 | )
27 |
28 | func main() {
29 | // Skip training if "-chat" flag is provided.
30 | steps := steps
31 | chat := flag.Bool("chat", false, "Skip training and jump straight to chat")
32 | flag.Parse()
33 | if *chat {
34 | steps = -1
35 | }
36 |
37 | // Loading dataset and building vocabulary.
38 | fmt.Println("Tokenizing dataset...")
39 | dataset, vocabSize := data.Tokenize(pretrainedTokens)
40 | fmt.Printf("First characters:\n%s\n", strings.TrimSpace(data.Decode(dataset[:45]...)))
41 | fmt.Printf("Vocabulary: %s\n", data.Chars())
42 | fmt.Printf("Tokens in dataset: %.3fM\n", pkg.Millions(len(dataset)))
43 |
44 | // Basic transformer components.
45 | tokEmbeds := RandEmbeds(vocabSize, embedSize)
46 | posEmbeds := RandEmbeds(blockSize, embedSize)
47 | var blocks []*Block
48 | for range layers {
49 | blocks = append(blocks, NewBlock(embedSize, heads))
50 | }
51 | norm := NewLayerNorm(embedSize)
52 | lmHead := NewLinear(embedSize, vocabSize)
53 |
54 | // Collecting all the parameters.
55 | params := pkg.NewParams()
56 | params.Add(tokEmbeds, posEmbeds)
57 | for _, block := range blocks {
58 | params.Add(block.Params()...)
59 | }
60 | params.Add(norm.Params()...)
61 | params.Add(lmHead.Params()...)
62 | params.TryLoadPretrained()
63 | fmt.Printf("Model size: %.3fM\n", pkg.Millions(params.Count()))
64 |
65 | // Training loop.
66 | losses := 0.0
67 | optimizer := pkg.NewAdamW(learningRate)
68 | fmt.Printf("bs=%d, es=%d, lr=%.4f, vs=%d, steps=%d\n", blockSize, embedSize, learningRate, vocabSize, steps)
69 | for i := 0; i < steps; i++ {
70 | // Targets contain the ground truth next token for each input token.
71 | input, targets := data.Sample(dataset, blockSize)
72 |
73 | // Forward pass, calculate predictions for every input token.
74 | embeds := Rows(tokEmbeds, Flat(input)...) // get embed for every input token
75 | embeds = Add(embeds, posEmbeds) // add positional embedding
76 | for _, block := range blocks { // self-attention and feed-forward
77 | embeds = block.Forward(embeds)
78 | }
79 | embeds = norm.Forward(embeds)
80 | logits := lmHead.Forward(embeds) // get scores for the next token for every context-enriched embed
81 |
82 | // Loss calculation, "how much our predicted targets differ from the ground truth targets?"
83 | // We average the loss over evalSteps iterations to smooth out fluctuations.
84 | loss := SoftmaxCrossEntropy(logits, targets)
85 | losses += Val(loss)
86 | fmt.Printf("\r%s", strings.Repeat("·", (i%evalSteps)*26/evalSteps)) // progress bar
87 | if i%evalSteps == 0 {
88 | avgLoss := losses / float64(min(i+1, evalSteps))
89 | fmt.Printf("\rstep: %5d, loss: %.4f\n", i, avgLoss)
90 | losses = 0
91 | }
92 |
93 | // Backward pass, calculate the gradients (how much each parameter contributes to the loss)
94 | // for all the parameters (weights, biases, embeds). Loss is the tail of a computation graph.
95 | loss.Backward()
96 | // Nudge the parameters in the direction of the gradients, so to minimize the loss.
97 | optimizer.Update(params)
98 | params.ZeroGrad()
99 | }
100 | params.Save()
101 | pkg.DisableDropout()
102 | // Training is done.
103 |
104 | // Predicts the next token based on the context of tokens.
105 | nextTok := func(context []float64) float64 {
106 | context = context[max(0, len(context)-blockSize):]
107 |
108 | // Feed context tokens to the model.
109 | embeds := Rows(tokEmbeds, context...)
110 | embeds = Add(embeds, posEmbeds)
111 | for _, block := range blocks {
112 | embeds = block.Forward(embeds)
113 | }
114 | embeds = norm.Forward(embeds)
115 | logits := lmHead.Forward(embeds) // get a list of final logits for the next token
116 |
117 | // We only care about the probabilities of the next token for the last token.
118 | logitsForNextToken := Rows(logits, -1)
119 | probs := Softmax(logitsForNextToken)
120 | tok := pkg.SampleTemp(probs, 0.8)
121 |
122 | return tok
123 | }
124 |
125 | // Sample from the model.
126 | prompt := " mysterious island"
127 | for {
128 | fmt.Printf("\n%s", prompt)
129 | context := data.Encode(prompt)
130 | for i := 0; i < maxTokens; i++ {
131 | nextToken := nextTok(context)
132 | fmt.Print(data.Decode(nextToken))
133 | context = append(context, nextToken)
134 | }
135 |
136 | fmt.Print("\n$ ")
137 | scanner := bufio.NewScanner(os.Stdin)
138 | scanner.Scan()
139 | prompt = scanner.Text()
140 | if prompt == "exit" {
141 | fmt.Println("Bye!")
142 | break
143 | }
144 | }
145 | }
146 |
--------------------------------------------------------------------------------
/main_test.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "math"
5 | "testing"
6 |
7 | "github.com/itsubaki/autograd/variable"
8 |
9 | "github.com/zakirullin/gpt-go/data"
10 | )
11 |
12 | func TestNeuron(t *testing.T) {
13 | // Our neuron has 2 inputs and 1 output (number of columns in weight matrix).
14 | // Its goal is to predict the next number in the sequence.
15 | input := V{1, 2} // {x1, x2}
16 | weight := M{
17 | {2}, // how much x1 contributes to the output
18 | {3}, // how much x2 contributes to the output
19 | }
20 |
21 | // We calculate the output by multiplying the input vector with the weight matrix.
22 | output := MatMul(input.Var(), weight.Var())
23 | // output[0] = 1*2 + 2*3 = 8
24 | areEqual(t, 8, output)
25 |
26 | // That's a bad prediction (not equal to 3), so we have to adjust the weight.
27 | weight = M{
28 | {1}, // nudge the first weight down by 1
29 | {1}, // nudge the second weight down by 2
30 | }
31 |
32 | output = MatMul(input.Var(), weight.Var())
33 | // output = 1*1 + 2*1 = 3
34 | // Now our neuron's prediction matches the target, so the weights are correct.
35 | // In reality, though, we don't tune the weights manually.
36 | areEqual(t, 3, output)
37 | }
38 |
39 | func TestLinear(t *testing.T) {
40 | // Linear layer is a collection of neurons.
41 | layer := NewLinear(2, 1)
42 | layer.Weight = M{
43 | {1},
44 | {1},
45 | }.Var()
46 | output := layer.Forward(V{1, 2}.Var())
47 | areEqual(t, 3, output)
48 | }
49 |
50 | func TestLinearWithTwoInputs(t *testing.T) {
51 | layer := NewLinear(2, 1)
52 | layer.Weight = M{
53 | {1},
54 | {1},
55 | }.Var()
56 |
57 | // For each input row the weighted output is calculated independently.
58 | input := M{
59 | {1, 2},
60 | {2, 3},
61 | }.Var()
62 | output := layer.Forward(input)
63 | areMatricesEqual(t, M{
64 | {3},
65 | {5},
66 | }, output)
67 | }
68 |
69 | func TestLoss(t *testing.T) {
70 | input := V{1, 2}.Var()
71 | weight := M{
72 | {2}, // w1
73 | {3}, // w2
74 | }.Var()
75 | output := MatMul(input, weight)
76 | target := V{3}.Var()
77 |
78 | // We need a single number to characterize how bad our prediction is.
79 | // For that we need a loss function.
80 | // Let's pick the simplest one:
81 | // loss = prediction - target
82 | loss := Sub(output, target)
83 | // 8 - 3 = 5
84 | areEqual(t, 5, loss)
85 | // So the loss is 5, which is a bad loss (we should strive to 0).
86 |
87 | // Both input (x1, x2) and weights (w1, w2) values contribute to the loss.
88 | // So, to minimize the loss we can tune the input or the weights.
89 | // Since we are training a model, we want to tune the weights. Inputs come from fixed dataset.
90 | // So we need to calculate how much w1 and w2 contribute to the loss.
91 | // "By what amount the loss would change if we change x1 and x2 by some tiny value".
92 | // That's derivative. The speed of change of the loss with respect to the weights.
93 |
94 | // dLoss/dw1 = d(x1*w1 + x2*w2)/dw1 = x1 = 1
95 | // It means that to make our loss bigger, we'd nudge w1 in the positive direction (+).
96 | // We want the opposite - to make the loss smaller, so we nudge w1 in the negative direction.
97 | // dLoss/dw2 = d(x1*w1 + x2*w2)/dw2 = x2 = 2
98 |
99 | // We calculate how much each weight contributes to the predicted value.
100 | // So that we know in which direction to nudge the weight to minimize the loss.
101 | loss.Backward()
102 | areMatricesEqual(t, M{{1}, {2}}, weight.Grad) // derivatives also called gradients
103 | }
104 |
105 | func TestGradientDescent(t *testing.T) {
106 | input := V{1, 2}.Var() // {x1, x2}
107 | weight := M{
108 | {2}, // w1
109 | {3}, // w2
110 | }
111 |
112 | // Now we know in which direction we should nudge the weights to minimize the loss.
113 | // Gradient for w1 is 1, it means that w1 contributes to loss proportionally.
114 | // Gradient for w2 is 2, it means that w2 contributes to loss twice as strongly.
115 | // I.e. if we nudge w2 by some tiny value 0.1, the loss will change by 0.2.
116 | // If we want to minimize the loss, we nudge the weights in the opposite direction.
117 | learningRate := 0.5
118 | weightGrad := V{1, 2}
119 | weight[0][0] -= learningRate * weightGrad[0] // w1 -= learningRate * w1Grad
120 | weight[1][0] -= learningRate * weightGrad[1] // w2 -= learningRate * w2Grad
121 |
122 | output := MatMul(input, weight.Var())
123 | // Previously the neuron predicted 8, now it predicts 5.5.
124 | areEqual(t, 5.5, output)
125 | // Previously the loss was 5, now it is 2.5, so the model has learned something.
126 | loss := Sub(output, V{3}.Var())
127 | areEqual(t, 2.5, loss)
128 |
129 | // Repeat the process.
130 | weight[0][0] -= learningRate * weightGrad[0]
131 | weight[1][0] -= learningRate * weightGrad[1]
132 | output = MatMul(input, weight.Var()) // MatMul({1, 2}, weights) = 3
133 |
134 | // The neuron predicts 3 now, which is exactly what follows after 1 and 2!
135 | areEqual(t, 3, output)
136 | loss = Sub(output, V{3}.Var())
137 | // The loss should be 0 now, because the neuron predicts the target value.
138 | areEqual(t, 0, loss)
139 |
140 | // Our simple model is now trained.
141 | // If the input is {1, 2}, the output is 3.
142 | // Our learned weights are:
143 | areMatricesEqual(t, M{
144 | {1},
145 | {1},
146 | }, weight.Var()) // w1 = 1, w2 = 1
147 | }
148 |
149 | func TestSelfAttention(t *testing.T) {
150 | // Suppose we have the following tokens (~words) in our vocabulary:
151 | // 0 - "cat"
152 | // 1 - ", "
153 | // 2 - "dog"
154 | // 3 - " and"
155 |
156 | // Embeddings are a way to represent words (tokens) as vectors. It encodes the meaning of
157 | // the word in a high-dimensional space. Similar words are close to each other.
158 | embeds := M{
159 | {4, 1, 6}, // embedding for "cat"
160 | {1, 8, 1}, // embedding for ", "
161 | {4, 1, 7}, // embedding for "dog"
162 | {1, 9, 3}, // embedding for " and"
163 | }.Var()
164 | // As we can see, embeddings for "cat" and "dog" are quite similar.
165 |
166 | // The input to transformer models are a sequence of token embeddings.
167 | // So if we feed "cat, dog and" string to our transformer, we must encode it.
168 | // First we tokenize it, i.e. split the sentence into the tokens:
169 | input := V{0, 1, 2, 3}.Var()
170 |
171 | // Then convert it to the list of embeddings:
172 | //{
173 | // {4, 1, 6}, // embedding for "cat"
174 | // {1, 8, 1}, // embedding for ", "
175 | // {4, 1, 7}, // embedding for "dog"
176 | // {1, 9, 3}, // embedding for " and"
177 | //}
178 | inputEmbeds := Rows(embeds, Flat(input)...)
179 | areMatricesEqual(t, M{
180 | {4, 1, 6}, // embedding for "cat"
181 | {1, 8, 1}, // embedding for ", "
182 | {4, 1, 7}, // embedding for "dog"
183 | {1, 9, 3}, // embedding for " and"
184 | }, inputEmbeds)
185 |
186 | // How do we predict next token from a given sequence of tokens?
187 | // We can naively predict the next token by looking only at the last token (bigram model does that).
188 | // However, by looking at the token "and" alone, we lose the context of the previous tokens (what does "and" refer to?).
189 |
190 | // So, we have to somehow combine the information from the current token and all the previous tokens.
191 | // "cat" -> "cat", no previous tokens to look at
192 | // ", " -> "cat" + ", "
193 | // "blue" -> "cat" + ", " + "dog"
194 | // " and" -> "cat" + ", " + "dog" + " and"
195 | // Since we're operating with numerical representations of the words (embeddings), we can just add them together.
196 | // I.e. for token " and" we'll do that:
197 | // {4, 1, 6} + {1, 8, 1} + {4, 1, 7} + {1, 9, 3} = {10, 19, 17}
198 | // Now our resulting vector " and" combines more information from the previous tokens. Now we can predict the next
199 | // token more accurately, because we have more context.
200 |
201 | // To calculate the sum of all previous tokens, we can multiply by this triangular matrix:
202 | tril := M{
203 | {1, 0, 0, 0}, // first token attends only to itself ("cat"), it can't look into the future
204 | {1, 1, 0, 0}, // second token attends to itself and the previous token ("cat" + ", ")
205 | {1, 1, 1, 0}, // third token attends to itself and the two previous tokens ("cat" + ", " + "dog")
206 | {1, 1, 1, 1}, // fourth token attends to itself and all the previous tokens ("cat" + ", " + "dog" + " and")
207 | }.Var()
208 |
209 | // So, at this point each embedding is enriched with the information from all the previous tokens.
210 | // That's the crux of self-attention.
211 | enrichedEmbeds := MatMul(tril, inputEmbeds)
212 | areMatricesEqual(t, M{
213 | {4, 1, 6},
214 | {5, 9, 7},
215 | {9, 10, 14},
216 | {10, 19, 17},
217 | }, enrichedEmbeds)
218 |
219 | }
220 |
221 | func TestWeightedSelfAttention(t *testing.T) {
222 | // In reality, though, we don't pay equal attention to all the previous tokens.
223 | // Some previous tokens are more interesting to us, some are less.
224 |
225 | // Let's look at a token and its embedding, say {1, 2, 3}.
226 | // So far we've treated the "1", "2" and "3" components in the same way (features we don't know about).
227 | // But let's split their roles into 3 categories:
228 | // Query - "what I am looking for"
229 | // Key - "what I can communicate"
230 | // Value - "what I give you"
231 |
232 | // Since we don't know how to split them, we can use a linear layer to learn that for us.
233 | // We introduce 3 linear layers: query, key and value.
234 | query := NewLinear(3, 3) // converts each embedding into a query vector "what I am looking for"
235 | query.Weight = Zeros(3, 3) // manually set values for a good example
236 | query.Weight.Data[1][0] = 10
237 |
238 | key := NewLinear(3, 3) // converts each embedding into a key vector "what I can communicate"
239 | key.Weight = Zeros(3, 3)
240 | key.Weight.Data[0][0] = 1 // first neuron is paying attention to the first component of the embedding
241 |
242 | value := NewLinear(3, 3) // converts each embedding into a value vector "what I give you"
243 | value.Weight = Ones(3, 3)
244 |
245 | embeds := M{
246 | {4, 1, 6}, // embedding for "cat"
247 | {1, 8, 1}, // embedding for ", "
248 | {4, 1, 7}, // embedding for "dog"
249 | {1, 9, 3}, // embedding for " and"
250 | }.Var()
251 |
252 | // Let's now extract the key and query vectors for each embed.
253 |
254 | k := key.Forward(embeds)
255 | // For our case, let's imagine that the first component of the key communicates
256 | // "I am an enumerable token". The bigger the value, the more enumerable the token is.
257 | areMatricesEqual(t, M{
258 | {4, 0, 0}, // "cat" is quite enumerable
259 | {1, 0, 0}, // ", " is not quite enumerable
260 | {4, 0, 0}, // "dog" is quite enumerable
261 | {1, 0, 0}, // " and" is not quite enumerable
262 | }, k)
263 |
264 | q := query.Forward(embeds)
265 | areMatricesEqual(t, M{
266 | {10, 0, 0}, // token "cat" is not looking for something enumerable
267 | {80, 0, 0}, // token ", " is looking for something enumerable a lot
268 | {10, 0, 0}, // token "dog" is not looking for something enumerable
269 | {90, 0, 0}, // token " and" is looking for something enumerable a lot
270 | }, q)
271 |
272 | // If we multiply q * k vectors, for each token we would answer to
273 | // a question "what tokens are interesting for me?".
274 | // Big values would indicate high interest.
275 | attentionScores := MatMul(q, Transpose(k))
276 | areMatricesEqual(t, M{
277 | {40, 10, 40, 10},
278 | {320, 80, 320, 80},
279 | {40, 10, 40, 10},
280 | {360, 90, 360, 90}, // token " and" is interested in tokens "cat" (score=360) and "dog" (score=360)
281 | }, attentionScores)
282 |
283 | tril := M{
284 | {1, 0, 0, 0}, // first token attends only to itself ("cat"), it can't look into the future
285 | {1, 1, 0, 0}, // second token attends to itself and the previous token ("cat" + ", ")
286 | {1, 1, 1, 0}, // third token attends to itself and the two previous tokens ("cat" + ", " + "dog")
287 | {1, 1, 1, 1}, // fourth token attends to itself and all the previous tokens ("cat" + ", " + "dog" + " and")
288 | }.Var()
289 | // Previously we attended to all the previous tokens with the help of this tril matrix.
290 | // Now we only attend to the tokens we are interested in, which is basically:
291 | attentionScores = MaskedInfFill(attentionScores, tril)
292 | no := math.Inf(-1)
293 | areMatricesEqual(t, M{
294 | {80, no, no, no},
295 | {640, 160, no, no},
296 | {80, 20, 80, no},
297 | {720, 180, 720, 180}, // token " and" is interested in "cat" and "dog", not so much in the others
298 | }, attentionScores)
299 |
300 | attentionScores = Softmax(attentionScores) // fancy trick to turn {1, 1, no, no} to {0.5, 0.5, 0, 0}
301 | areMatricesEqual(t, M{
302 | {1, 0, 0, 0},
303 | {1, 0, 0, 0},
304 | {0.5, 0, 0.5, 0},
305 | {0.5, 0, 0.5, 0}, // token " and" is interested in "cat" and "dog", not interested in the others at all
306 | }, attentionScores)
307 | }
308 |
309 | func TestTransformer(t *testing.T) {
310 | RandEmbeds = func(rows, cols int) *variable.Variable {
311 | return Zeros(rows, cols)
312 | }
313 | RandWeights = func(rows, cols int) *variable.Variable {
314 | return Ones(rows, cols)
315 | }
316 | data.RandInt = func(_ int) int {
317 | return 0
318 | }
319 |
320 | vocabSize := 10
321 | embedSize := 2
322 | blockSize := 2
323 |
324 | // Basic transformer components
325 | tokEmbeds := RandEmbeds(vocabSize, embedSize)
326 | posEmbeds := RandEmbeds(blockSize, embedSize)
327 | block := NewBlock(embedSize, 1)
328 | norm := NewLayerNorm(embedSize)
329 | lmHead := NewLinear(embedSize, vocabSize)
330 |
331 | // Input contains blockSize consecutive tokens.
332 | // Targets contain the expected next token for each input token.
333 | // Example: for input={0,1}, targets={1,2}, meaning
334 | // that next token (target) after 0 is 1, next after 1 is 2.
335 | input, targets := data.Sample([]float64{0, 1, 2}, blockSize)
336 |
337 | // {
338 | // {vector for tok0},
339 | // {vector for tok1},
340 | // ... other embeds
341 | // }
342 | embeds := Rows(tokEmbeds, Flat(input)...) // get embed for every input token
343 | embeds = Add(embeds, posEmbeds) // add positional embedding
344 | embeds = block.Forward(embeds)
345 | embeds = norm.Forward(embeds)
346 | // {
347 | // {score for tok0, ..., score for tokN}, // for input tok0
348 | // {score for tok0, ..., score for tokN}, // for input tok1
349 | // ... other logits
350 | // }
351 | logits := lmHead.Forward(embeds) // converts contextual embeddings to next-token predictions
352 |
353 | // Loss calculation, how much our predicted targets differ from the actual targets?
354 | loss := SoftmaxCrossEntropy(logits, targets)
355 |
356 | areEqual(t, 2.302585092994046, loss)
357 | }
358 |
359 | func areEqual(t *testing.T, want float64, got *variable.Variable) {
360 | t.Helper()
361 | if len(got.Data) != 1 {
362 | t.Errorf("expected a single value, got %d values", len(got.Data))
363 | return
364 | }
365 |
366 | if math.Abs(want-Val(got)) > 1e-9 {
367 | t.Errorf("value mismatch: want %v, got %v", want, Val(got))
368 | }
369 | }
370 |
371 | func areMatricesEqual(t *testing.T, want M, got *variable.Variable) {
372 | t.Helper()
373 | if len(want) != len(got.Data) {
374 | t.Errorf("matrix length mismatch: want length=%d, got length=%d", len(want), len(got.Data))
375 | return
376 | }
377 |
378 | for i := range want {
379 | if len(want[i]) != len(got.Data[i]) {
380 | t.Errorf("matrix row length mismatch at row %d: want length=%d, got length=%d", i, len(want[i]), len(got.Data[i]))
381 | return
382 | }
383 | }
384 |
385 | for i := range want {
386 | for j := range want[i] {
387 | if math.Abs(want[i][j]-got.Data[i][j]) > 1e-9 {
388 | t.Errorf("matrix mismatch at row %d, column %d: want %v, got %v", i, j, want[i][j], got.Data[i][j])
389 | }
390 | }
391 | }
392 | }
393 |
394 | type V []float64
395 |
396 | func (v V) Var() *variable.Variable {
397 | return variable.NewOf(v)
398 | }
399 |
400 | type M [][]float64
401 |
402 | func (m M) Var() *variable.Variable {
403 | return variable.NewOf(m...)
404 | }
405 |
--------------------------------------------------------------------------------
/pkg/adamw.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "math"
5 |
6 | "github.com/itsubaki/autograd/matrix"
7 | "github.com/itsubaki/autograd/optimizer"
8 | "github.com/itsubaki/autograd/variable"
9 | )
10 |
11 | type AdamW struct {
12 | Alpha float64 // Learning rate
13 | Beta1 float64 // Exponential decay rate for first moment
14 | Beta2 float64 // Exponential decay rate for second moment
15 | WeightDecay float64 // Weight decay coefficient
16 | Hook []optimizer.Hook
17 | iter int
18 | ms, vs map[*variable.Variable]matrix.Matrix
19 | }
20 |
21 | func NewAdamW(learningRate float64) AdamW {
22 | return AdamW{Alpha: learningRate, Beta1: 0.9, Beta2: 0.999, WeightDecay: 0.01}
23 | }
24 |
25 | func (o *AdamW) Update(model optimizer.Model) {
26 | params := optimizer.Params(model, o.Hook)
27 |
28 | if len(o.ms) == 0 {
29 | o.ms = make(map[*variable.Variable]matrix.Matrix)
30 | o.vs = make(map[*variable.Variable]matrix.Matrix)
31 | }
32 |
33 | o.iter++
34 | fix1 := 1.0 - math.Pow(o.Beta1, float64(o.iter))
35 | fix2 := 1.0 - math.Pow(o.Beta2, float64(o.iter))
36 | lr := o.Alpha * math.Sqrt(fix2) / fix1
37 |
38 | for _, p := range params {
39 | if _, ok := o.ms[p]; !ok {
40 | o.ms[p] = matrix.ZeroLike(p.Data)
41 | o.vs[p] = matrix.ZeroLike(p.Data)
42 | }
43 |
44 | // Update biased first moment estimate: m = beta1*m + (1-beta1)*grad, written incrementally
45 | o.ms[p] = matrix.F2(o.ms[p], p.Grad.Data, func(m, grad float64) float64 {
46 | return m + ((1 - o.Beta1) * (grad - m))
47 | })
48 |
49 | // Update biased second raw moment estimate: v = beta2*v + (1-beta2)*grad^2, written incrementally
50 | o.vs[p] = matrix.F2(o.vs[p], p.Grad.Data, func(v, grad float64) float64 {
51 | return v + ((1 - o.Beta2) * (grad*grad - v))
52 | })
53 |
54 | // The key difference for AdamW: apply weight decay directly to the weights
55 | // instead of incorporating it into the gradient
56 |
57 | // First compute the standard Adam update
58 | adamUpdate := matrix.F2(o.ms[p], o.vs[p], func(m, v float64) float64 {
59 | return lr * m / (math.Sqrt(v) + 1e-8)
60 | })
61 |
62 | // Then apply weight decay separately
63 | weightDecayUpdate := matrix.MulC(lr*o.WeightDecay, p.Data)
64 |
65 | // Update parameters: param = param - adamUpdate - weightDecayUpdate
66 | p.Data = matrix.Sub(p.Data, adamUpdate)
67 | p.Data = matrix.Sub(p.Data, weightDecayUpdate)
68 | }
69 | }
70 |
--------------------------------------------------------------------------------
/pkg/cat.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "github.com/itsubaki/autograd/variable"
5 | )
6 |
7 | // Cat concatenates matrices horizontally (along columns); all inputs must have the same shape.
8 | func Cat(x ...*variable.Variable) *variable.Variable {
9 | return (&variable.Function{Forwarder: &CatT{NumInputs: len(x)}}).First(x...)
10 | }
11 |
12 | type CatT struct {
13 | NumInputs int
14 | ColSize int
15 | }
16 |
17 | // Concatenate along the columns dimension (dim=1)
18 | func (f *CatT) Forward(x ...*variable.Variable) []*variable.Variable {
19 | rows := len(x[0].Data)
20 | f.ColSize = len(x[0].Data[0])
21 | totalCols := f.ColSize * len(x)
22 |
23 | result := make([][]float64, rows)
24 | for i := range result {
25 | result[i] = make([]float64, totalCols)
26 | colOffset := 0
27 | for _, v := range x {
28 | copy(result[i][colOffset:], v.Data[i])
29 | colOffset += f.ColSize
30 | }
31 | }
32 |
33 | return []*variable.Variable{
34 | variable.NewOf(result...),
35 | }
36 | }
37 |
38 | func (f *CatT) Backward(gy ...*variable.Variable) []*variable.Variable {
39 | grads := make([]*variable.Variable, f.NumInputs)
40 |
41 | // Split along columns
42 | for i := 0; i < f.NumInputs; i++ {
43 | colOffset := i * f.ColSize
44 | colData := make([][]float64, len(gy[0].Data))
45 |
46 | for j := range colData {
47 | colData[j] = make([]float64, f.ColSize)
48 | copy(colData[j], gy[0].Data[j][colOffset:colOffset+f.ColSize])
49 | }
50 |
51 | grads[i] = variable.NewOf(colData...)
52 | }
53 |
54 | return grads
55 | }
56 |
--------------------------------------------------------------------------------
/pkg/cat_test.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "fmt"
5 | )
6 |
7 | func ExampleCat_basic() {
8 | a := M{
9 | {1, 2},
10 | {3, 4},
11 | }.Var()
12 |
13 | b := M{
14 | {5, 6},
15 | {7, 8},
16 | }.Var()
17 |
18 | result := Cat(a, b)
19 | fmt.Println(result.Data)
20 |
21 | // Output: [[1 2 5 6] [3 4 7 8]]
22 | }
23 |
24 | func ExampleCat_singleInput() {
25 | a := M{
26 | {1, 2},
27 | {3, 4},
28 | }.Var()
29 |
30 | result := Cat(a)
31 | fmt.Println(result.Data)
32 |
33 | // Output: [[1 2] [3 4]]
34 | }
35 |
36 | func ExampleCat_threeMatrices() {
37 | a := M{
38 | {1, 2},
39 | {3, 4},
40 | }.Var()
41 |
42 | b := M{
43 | {5, 6},
44 | {7, 8},
45 | }.Var()
46 |
47 | c := M{
48 | {9, 10},
49 | {11, 12},
50 | }.Var()
51 |
52 | result := Cat(a, b, c)
53 | fmt.Println(result.Data)
54 |
55 | // Output: [[1 2 5 6 9 10] [3 4 7 8 11 12]]
56 | }
57 |
58 | func ExampleCat_gradient() {
59 | a := M{
60 | {1, 2},
61 | {3, 4},
62 | }.Var()
63 |
64 | b := M{
65 | {5, 6},
66 | {7, 8},
67 | }.Var()
68 |
69 | result := Cat(a, b)
70 | result.Grad = M{
71 | {0.1, 0.2, 0.3, 0.4},
72 | {0.5, 0.6, 0.7, 0.8},
73 | }.Var()
74 |
75 | result.Backward()
76 |
77 | fmt.Println(a.Grad.Data)
78 | fmt.Println(b.Grad.Data)
79 |
80 | // Output:
81 | // [[0.1 0.2] [0.5 0.6]]
82 | // [[0.3 0.4] [0.7 0.8]]
83 | }
84 |
85 | func ExampleCat_withThreeDifferentMatrices() {
86 | a := M{
87 | {1, 2},
88 | {3, 4},
89 | }.Var()
90 |
91 | b := M{
92 | {5, 6},
93 | {7, 8},
94 | }.Var()
95 |
96 | c := M{
97 | {9, 10},
98 | {11, 12},
99 | }.Var()
100 |
101 | result := Cat(a, b, c)
102 | result.Grad = M{
103 | {0.1, 0.2, 0.3, 0.4, 0.5, 0.6},
104 | {0.7, 0.8, 0.9, 1.0, 1.1, 1.2},
105 | }.Var()
106 |
107 | result.Backward()
108 |
109 | fmt.Println(a.Grad.Data)
110 | fmt.Println(b.Grad.Data)
111 | fmt.Println(c.Grad.Data)
112 |
113 | // Output:
114 | // [[0.1 0.2] [0.7 0.8]]
115 | // [[0.3 0.4] [0.9 1]]
116 | // [[0.5 0.6] [1.1 1.2]]
117 | }
118 |
119 | func ExampleCat_matrixMultiplicationWith() {
120 | a := M{
121 | {1, 2},
122 | {3, 4},
123 | }.Var()
124 |
125 | b := M{
126 | {5, 6},
127 | {7, 8},
128 | }.Var()
129 |
130 | c := Cat(a, b)
131 |
132 | d := M{
133 | {0.1},
134 | {0.2},
135 | {0.3},
136 | {0.4},
137 | }.Var()
138 |
139 | result := MatMul(c, d)
140 | fmt.Println(result.Data)
141 |
142 | result.Backward()
143 |
144 | fmt.Println(a.Grad.Data)
145 | fmt.Println(b.Grad.Data)
146 |
147 | // Output:
148 | // [[4.4] [6.4]]
149 | // [[0.1 0.2] [0.1 0.2]]
150 | // [[0.3 0.4] [0.3 0.4]]
151 | }
152 |
--------------------------------------------------------------------------------
/pkg/functions.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "math"
5 | "math/rand/v2"
6 |
7 | "github.com/itsubaki/autograd/matrix"
8 | "github.com/itsubaki/autograd/variable"
9 | )
10 |
11 | var (
12 | Add = variable.Add
13 | Div = variable.Div
14 | Zeros = variable.Zero
15 | )
16 |
17 | // Sample returns a random index based on the given probabilities.
18 | func Sample(probs *variable.Variable) float64 {
19 | r := rand.Float64()
20 |
21 | // Find the first index where cumulative probability exceeds r.
22 | cumulativeProb := 0.0
23 | for i, p := range probs.Data[0] {
24 | cumulativeProb += p
25 | if r < cumulativeProb {
26 | return float64(i)
27 | }
28 | }
29 |
30 | return float64(len(probs.Data[0]) - 1) // fallback to the last index (handles floating-point rounding)
31 | }
32 |
33 | // SampleTemp returns a random index based on the given probabilities and temperature.
34 | // The higher the temperature, the more random the sampling.
35 | // Usually, temperature is between 0.5 and 0.8.
36 | func SampleTemp(probs *variable.Variable, temperature float64) float64 {
37 | adjustedProbs := make([]float64, len(probs.Data[0]))
38 | copy(adjustedProbs, probs.Data[0])
39 | if temperature != 1.0 {
40 | // Lower temperature: higher probs amplified, lower reduced, more deterministic.
41 | // Higher temperature: probabilities become more uniform, more random.
42 | sum := 0.0
43 | for i, p := range adjustedProbs {
44 | // Apply temperature by raising to power of 1/temperature.
45 | adjustedProbs[i] = math.Pow(p, 1.0/temperature)
46 | sum += adjustedProbs[i]
47 | }
48 |
49 | for i := range adjustedProbs {
50 | adjustedProbs[i] /= sum
51 | }
52 | }
53 |
54 | return Sample(variable.NewOf(adjustedProbs))
55 | }
56 |
57 | // Rows returns the rows at the specified indexes. Negative indexes count from the end.
58 | func Rows(x *variable.Variable, indexes ...float64) *variable.Variable {
59 | size := len(x.Data)
60 |
61 | var intIndexes []int
62 | for _, index := range indexes {
63 | intIndex := int(index)
64 | if intIndex < 0 {
65 | intIndex = size + intIndex
66 | }
67 |
68 | intIndexes = append(intIndexes, intIndex)
69 | }
70 |
71 | return (&variable.Function{Forwarder: &variable.GetItemT{Slices: intIndexes}}).First(x)
72 | }
73 |
74 | // Normal returns a matrix of random values drawn from a normal distribution.
75 | func Normal(rows, cols int) *variable.Variable {
76 | rnd := func(_ float64) float64 {
77 | // Standard deviation = 0.02 is widely used in transformer models like GPT-2.
78 | 		// It keeps activations from blowing up at the start of training.
79 | std := 0.02
80 | return rand.NormFloat64() * std
81 | }
82 |
83 | m := matrix.Zero(rows, cols)
84 | m = matrix.F(m, rnd)
85 |
86 | return variable.NewOf(m...)
87 | }
88 |
89 | func Tril(m *variable.Variable) *variable.Variable {
90 | result := variable.ZeroLike(m)
91 | for i := 0; i < len(m.Data); i++ {
92 | for j := 0; j < len(m.Data[i]); j++ {
93 | if j <= i {
94 | result.Data[i][j] = m.Data[i][j]
95 | }
96 | }
97 | }
98 |
99 | return result
100 | }
101 |
102 | // MaskedInfFill fills positions where mask is 0 with -Inf, leaving the rest
103 | // of m unchanged. The result stays tied to m in the computation graph.
104 | func MaskedInfFill(m, mask *variable.Variable) *variable.Variable {
105 | 	negInfMaskedData := matrix.F2(m.Data, mask.Data, func(_, b float64) float64 {
106 | 		if b == 0 {
107 | 			return math.Inf(-1)
108 | 		}
109 | 		return 0 // unmasked entries come from m*mask below, not from here
110 | 	})
111 | 	mMasked := Add(variable.Mul(m, mask), variable.NewOf(negInfMaskedData...))
112 |
113 | 	return mMasked
114 | }
115 |
116 | func DivC(c float64, x *variable.Variable) *variable.Variable {
117 | return variable.MulC(1.0/c, x)
118 | }
119 |
120 | // Ones returns an m×n matrix of ones.
121 | func Ones(m, n int) *variable.Variable {
122 | out := make([][]float64, m)
123 | for i := range m {
124 | out[i] = make([]float64, n)
125 | for j := range n {
126 | out[i][j] = 1.0
127 | }
128 | }
129 |
130 | return variable.NewOf(out...)
131 | }
132 |
133 | // Val returns the first element of the variable.
134 | func Val(x *variable.Variable) float64 {
135 | return x.Data[0][0]
136 | }
137 |
138 | func Flat(x *variable.Variable) []float64 {
139 | return matrix.Flatten(x.Data)
140 | }
141 |
142 | func Millions(num int) float64 {
143 | return float64(num) / 1e6
144 | }
145 |
146 | func DisableDropout() {
147 | 	variable.Config.Train = false // dropout is only applied in train mode
148 | }
149 |
--------------------------------------------------------------------------------
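
To see what SampleTemp's `1/temperature` power does in practice, here is a minimal, self-contained sketch (plain Go; `withTemperature` is a hypothetical helper that mirrors the rescale-and-renormalize step above):

```go
package main

import (
	"fmt"
	"math"
)

// withTemperature rescales probabilities by p^(1/T) and renormalizes,
// mirroring what SampleTemp does before sampling.
func withTemperature(probs []float64, t float64) []float64 {
	out := make([]float64, len(probs))
	sum := 0.0
	for i, p := range probs {
		out[i] = math.Pow(p, 1.0/t)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

func main() {
	probs := []float64{0.5, 0.3, 0.2}
	fmt.Printf("%.3f\n", withTemperature(probs, 0.5)) // sharper: [0.658 0.237 0.105]
	fmt.Printf("%.3f\n", withTemperature(probs, 2.0)) // flatter: [0.415 0.322 0.263]
}
```

Raising probabilities to the power 1/T and renormalizing is equivalent to dividing the logits by T before the softmax, which is how temperature is usually described.
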
/pkg/matmul.go:
--------------------------------------------------------------------------------
1 | // MatMul performs parallelized matrix multiplication that minimizes
2 | // CPU cache misses. It divides the rows into contiguous chunks that
3 | // are processed by multiple goroutines in parallel.
4 | package pkg
5 |
6 | import (
7 | "runtime"
8 | "sync"
9 |
10 | "github.com/itsubaki/autograd/matrix"
11 | "github.com/itsubaki/autograd/variable"
12 | )
13 |
14 | const (
15 | blockSize = 32 // Group calculations by blocks for better CPU cache utilization
16 | )
17 |
18 | func MatMul(x ...*variable.Variable) *variable.Variable {
19 | return (&variable.Function{Forwarder: &MatMulT{}}).First(x...)
20 | }
21 |
22 | type MatMulT struct {
23 | x, w *variable.Variable
24 | }
25 |
26 | func (f *MatMulT) Forward(x ...*variable.Variable) []*variable.Variable {
27 | f.x, f.w = x[0], x[1]
28 |
29 | y := matmul(x[0].Data, x[1].Data)
30 | return []*variable.Variable{y}
31 | }
32 |
33 | func (f *MatMulT) Backward(gy ...*variable.Variable) []*variable.Variable {
34 | return []*variable.Variable{
35 | MatMul(gy[0], variable.Transpose(f.w)), // gy * w.T
36 | MatMul(variable.Transpose(f.x), gy[0]), // x.T * gy
37 | }
38 | }
39 |
40 | func matmul(m, n matrix.Matrix) *variable.Variable {
41 | mRows, mCols := len(m), len(m[0])
42 | _, nCols := len(n), len(n[0])
43 |
44 | result := Zeros(mRows, nCols)
45 | var wg sync.WaitGroup
46 | numCPU := runtime.NumCPU()
47 | // Create more chunks than CPUs for better load balancing
48 | // Adjust the multiplier to find the optimal balance
49 | chunkSize := max(1, mRows/(numCPU*4))
50 | for startRow := 0; startRow < mRows; startRow += chunkSize {
51 | wg.Add(1)
52 |
53 | go func(firstRow, lastRow int) {
54 | defer wg.Done()
55 |
56 | // Process this chunk of rows with blocking for better cache utilization
57 | for ii := firstRow; ii < lastRow; ii += blockSize {
58 | for kk := 0; kk < mCols; kk += blockSize {
59 | for jj := 0; jj < nCols; jj += blockSize {
60 | // Calculate bounds for current block
61 | iEnd := min(ii+blockSize, lastRow)
62 | kEnd := min(kk+blockSize, mCols)
63 | jEnd := min(jj+blockSize, nCols)
64 |
65 | // Process the current block with cache-friendly access
66 | for i := ii; i < iEnd; i++ {
67 | for k := kk; k < kEnd; k++ {
68 | aik := m[i][k]
69 | for j := jj; j < jEnd; j++ {
70 | result.Data[i][j] += aik * n[k][j]
71 | }
72 | }
73 | }
74 | }
75 | }
76 | }
77 | }(startRow, min(startRow+chunkSize, mRows))
78 | }
79 | wg.Wait()
80 |
81 | return result
82 | }
83 |
--------------------------------------------------------------------------------
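
Backward's two matmuls implement dL/dx = gy·wᵀ and dL/dw = xᵀ·gy. A quick way to trust them is a finite-difference check; here is a hypothetical example test (it reuses the `M` helper defined in matmul_test.go below, and the analytic value matches `ExampleMatMul_gradient`):

```go
// Finite-difference check of MatMul's Backward: nudge a[0][0] by eps and
// compare the change in the summed output against the analytic gradient.
func ExampleMatMul_finiteDifference() {
	sumOut := func(a00 float64) float64 {
		a := M{{a00, 2}, {3, 4}}.Var()
		b := M{{5, 6}, {7, 8}}.Var()
		out := MatMul(a, b)

		sum := 0.0
		for _, row := range out.Data {
			for _, v := range row {
				sum += v
			}
		}
		return sum
	}

	eps := 1e-6
	numeric := (sumOut(1+eps) - sumOut(1-eps)) / (2 * eps)

	a := M{{1, 2}, {3, 4}}.Var()
	b := M{{5, 6}, {7, 8}}.Var()
	out := MatMul(a, b)
	out.Grad = M{{1, 1}, {1, 1}}.Var() // gradient of sum(out) w.r.t. out
	out.Backward()

	fmt.Printf("numeric=%.4f analytic=%.4f\n", numeric, a.Grad.Data[0][0])

	// Output: numeric=11.0000 analytic=11.0000
}
```
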
/pkg/matmul_test.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/itsubaki/autograd/variable"
7 | )
8 |
9 | func ExampleMatMul_basic2x2() {
10 | a := M{
11 | {1, 2},
12 | {3, 4},
13 | }.Var()
14 |
15 | b := M{
16 | {5, 6},
17 | {7, 8},
18 | }.Var()
19 |
20 | result := MatMul(a, b)
21 | fmt.Println(result)
22 |
23 | // Output:
24 | // variable([[19 22] [43 50]])
25 | }
26 |
27 | func ExampleMatMul_nonSquare() {
28 | a := M{
29 | {1, 2, 3},
30 | {4, 5, 6},
31 | }.Var()
32 |
33 | b := M{
34 | {7, 8},
35 | {9, 10},
36 | {11, 12},
37 | }.Var()
38 |
39 | result := MatMul(a, b)
40 | fmt.Println(result)
41 |
42 | // Output:
43 | // variable([[58 64] [139 154]])
44 | }
45 |
46 | func ExampleMatMul_columnVector() {
47 | a := M{
48 | {1, 2},
49 | {3, 4},
50 | {5, 6},
51 | }.Var()
52 |
53 | b := M{
54 | {7},
55 | {8},
56 | }.Var()
57 |
58 | result := MatMul(a, b)
59 | fmt.Println(result)
60 |
61 | // Output:
62 | // variable([[23] [53] [83]])
63 | }
64 |
65 | func ExampleMatMul_rowVector() {
66 | a := V{1, 2}.Var()
67 |
68 | b := M{
69 | {3, 4},
70 | {5, 6},
71 | }.Var()
72 |
73 | result := MatMul(a, b)
74 | fmt.Println(result)
75 |
76 | // Output:
77 | // variable([13 16])
78 | }
79 |
80 | func ExampleMatMul_chain() {
81 | a := V{1, 2}.Var()
82 |
83 | b := M{
84 | {3, 4},
85 | {5, 6},
86 | }.Var()
87 |
88 | c := M{
89 | {7},
90 | {8},
91 | }.Var()
92 |
93 | result := MatMul(MatMul(a, b), c)
94 | fmt.Println(result)
95 |
96 | // Output:
97 | // variable([219])
98 | }
99 |
100 | func ExampleMatMul_zeroMatrix() {
101 | a := M{
102 | {0, 0},
103 | {0, 0},
104 | }.Var()
105 |
106 | b := M{
107 | {1, 2},
108 | {3, 4},
109 | }.Var()
110 |
111 | result := MatMul(a, b)
112 | fmt.Println(result)
113 |
114 | // Output:
115 | 	// variable([[0 0] [0 0]])
116 | }
117 |
118 | func ExampleMatMul_gradient() {
119 | a := M{
120 | {1, 2},
121 | {3, 4},
122 | }.Var()
123 |
124 | b := M{
125 | {5, 6},
126 | {7, 8},
127 | }.Var()
128 |
129 | result := MatMul(a, b)
130 |
131 | result.Grad = M{
132 | {1, 1},
133 | {1, 1},
134 | }.Var()
135 |
136 | result.Backward()
137 |
138 | fmt.Println(a.Grad)
139 | fmt.Println(b.Grad)
140 |
141 | // Output:
142 | 	// variable([[11 15] [11 15]])
143 | 	// variable([[4 4] [6 6]])
144 | }
145 |
146 | // Shortcut for building readable matrices:
147 | //
148 | // M{
149 | // {1, 2},
150 | // {3, 4},
151 | // }.Var()
152 | type M [][]float64
153 |
154 | func (m M) Var() *variable.Variable {
155 | return variable.NewOf(m...)
156 | }
157 |
158 | type V []float64
159 |
160 | func (v V) Var() *variable.Variable {
161 | return variable.NewOf(v)
162 | }
163 |
--------------------------------------------------------------------------------
/pkg/mean.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "github.com/itsubaki/autograd/variable"
5 | )
6 |
7 | func Mean(x ...*variable.Variable) *variable.Variable {
8 | return (&variable.Function{Forwarder: &MeanT{}}).First(x...)
9 | }
10 |
11 | type MeanT struct {
12 | x *variable.Variable
13 | n int
14 | }
15 |
16 | // Forward computes the mean of each row.
17 | func (m *MeanT) Forward(x ...*variable.Variable) []*variable.Variable {
18 | m.x = x[0]
19 | m.n = len(x[0].Data[0])
20 |
21 | means := variable.Zero(len(x[0].Data), 1)
22 | for i := range x[0].Data {
23 | means.Data[i][0] = mean(x[0].Data[i])
24 | }
25 |
26 | return []*variable.Variable{
27 | means,
28 | }
29 | }
30 |
31 | // Backward: the derivative of mean(x1..xn) with respect to each xi is 1/n.
32 | func (m *MeanT) Backward(gy ...*variable.Variable) []*variable.Variable {
33 | g := variable.ZeroLike(m.x)
34 | for i := range g.Data {
35 | for j := range g.Data[i] {
36 | g.Data[i][j] = gy[0].Data[i][0] / float64(m.n)
37 | }
38 | }
39 |
40 | return []*variable.Variable{
41 | g,
42 | }
43 | }
44 |
45 | func mean(values []float64) float64 {
46 | sum := 0.0
47 | for _, v := range values {
48 | sum += v
49 | }
50 |
51 | return sum / float64(len(values))
52 | }
53 |
--------------------------------------------------------------------------------
/pkg/mean_test.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/itsubaki/autograd/variable"
7 | )
8 |
9 | func ExampleMean_basic() {
10 | a := M{
11 | {1, 2, 3},
12 | {4, 5, 6},
13 | }.Var()
14 |
15 | result := Mean(a)
16 | fmt.Println(result.Data)
17 |
18 | // Output: [[2] [5]]
19 | }
20 |
21 | func ExampleMean_withZero() {
22 | a := M{
23 | {1, 0, 1},
24 | }.Var()
25 |
26 | result := Mean(a)
27 | fmt.Printf("%.6f\n", result.Data[0][0])
28 |
29 | // Output: 0.666667
30 | }
31 |
32 | func ExampleMean_withZeros() {
33 | a := M{
34 | {0, 0, 0},
35 | {0, 0, 0},
36 | }.Var()
37 |
38 | result := Mean(a)
39 | fmt.Println(result.Data)
40 |
41 | // Output: [[0] [0]]
42 | }
43 |
44 | func ExampleMean_withNegatives() {
45 | a := M{
46 | {-1, 2, -3},
47 | {4, -5, 6},
48 | }.Var()
49 |
50 | result := Mean(a)
51 | fmt.Printf("%.6f %.6f\n", result.Data[0][0], result.Data[1][0])
52 |
53 | // Output: -0.666667 1.666667
54 | }
55 |
56 | func ExampleMean_gradient() {
57 | a := M{
58 | {1, 2, 3},
59 | {4, 5, 6},
60 | }.Var()
61 |
62 | result := Mean(a)
63 | result.Grad = M{
64 | {0.1},
65 | {0.2},
66 | }.Var()
67 |
68 | result.Backward()
69 | fmt.Println(a.Grad.Data)
70 |
71 | // Output: [[0.03333333333333333 0.03333333333333333 0.03333333333333333] [0.06666666666666667 0.06666666666666667 0.06666666666666667]]
72 | }
73 |
74 | func ExampleMean_withScalarGradient() {
75 | a := M{
76 | {1, 2},
77 | {3, 4},
78 | }.Var()
79 |
80 | result := Mean(a)
81 | result.Grad = M{
82 | {1.0},
83 | {1.0},
84 | }.Var()
85 |
86 | result.Backward()
87 | fmt.Println(a.Grad.Data)
88 |
89 | // Output: [[0.5 0.5] [0.5 0.5]]
90 | }
91 |
92 | func ExampleMean_inComputationGraph() {
93 | a := M{
94 | {1, 3},
95 | {2, 4},
96 | }.Var()
97 |
98 | meanA := Mean(a)
99 | result := variable.Mul(meanA, M{
100 | {2},
101 | {2},
102 | }.Var())
103 |
104 | fmt.Println(result.Data)
105 | result.Backward()
106 | fmt.Println(a.Grad.Data)
107 |
108 | // Output:
109 | // [[4] [6]]
110 | // [[1 1] [1 1]]
111 | }
112 |
--------------------------------------------------------------------------------
/pkg/params.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "encoding/binary"
5 | "fmt"
6 | "hash/crc32"
7 | "os"
8 |
9 | "github.com/itsubaki/autograd/layer"
10 | )
11 |
12 | type Params struct {
13 | params layer.Parameters
14 | }
15 |
16 | func NewParams() *Params {
17 | return &Params{params: layer.Parameters{}}
18 | }
19 |
20 | func (p *Params) Add(params ...layer.Parameter) {
21 | for _, param := range params {
22 | p.params.Add(fmt.Sprintf("%d", len(p.params)), param)
23 | }
24 | }
25 |
26 | func (p *Params) Params() layer.Parameters {
27 | return p.params
28 | }
29 |
30 | func (p *Params) Count() int {
31 | numParams := 0
32 | for _, param := range p.params {
33 | numParams += len(param.Data) * len(param.Data[0])
34 | }
35 |
36 | return numParams
37 | }
38 |
39 | func (p *Params) ZeroGrad() {
40 | p.params.Cleargrads()
41 | }
42 |
43 | func (p *Params) Save() {
44 | file, err := os.Create(p.filename())
45 | if err != nil {
46 | panic(err)
47 | }
48 | defer file.Close()
49 |
50 | 	// Write params in ascending key order so the file layout is deterministic.
51 | hash := crc32.NewIEEE()
52 | for i := 0; i < len(p.params); i++ {
53 | key := fmt.Sprintf("%d", i)
54 | for _, row := range p.params[key].Data {
55 | if err := binary.Write(file, binary.LittleEndian, row); err != nil {
56 | panic(err)
57 | }
58 | }
59 | shape := fmt.Sprintf("%d:%d×%d", i, len(p.params[key].Data), len(p.params[key].Data[0]))
60 | hash.Write([]byte(shape))
61 | }
62 |
63 | // Write checksum at the end of the file.
64 | checksum := hash.Sum32()
65 | if err := binary.Write(file, binary.LittleEndian, checksum); err != nil {
66 | panic(err)
67 | }
68 | }
69 |
70 | func (p *Params) TryLoadPretrained() {
71 | file, err := os.Open(p.filename())
72 | if err != nil {
73 | return
74 | }
75 | defer file.Close()
76 |
77 | 	// Read params back in the same ascending key order they were saved in.
78 | hash := crc32.NewIEEE()
79 | for i := 0; i < len(p.params); i++ {
80 | key := fmt.Sprintf("%d", i)
81 | for j := range p.params[key].Data {
82 | if err := binary.Read(file, binary.LittleEndian, &p.params[key].Data[j]); err != nil {
83 | panic(fmt.Sprintf("model shapes mismatch, remove '%s' file", p.filename()))
84 | }
85 | }
86 | shape := fmt.Sprintf("%d:%d×%d", i, len(p.params[key].Data), len(p.params[key].Data[0]))
87 | hash.Write([]byte(shape))
88 | }
89 |
90 | var savedChecksum uint32
91 | if err := binary.Read(file, binary.LittleEndian, &savedChecksum); err != nil {
92 | panic(fmt.Errorf("failed to read shapes checksum: %v", err))
93 | }
94 | if savedChecksum != hash.Sum32() {
95 | panic(fmt.Sprintf("model shapes mismatch, remove '%s' file", p.filename()))
96 | }
97 |
98 | fmt.Printf("Loaded pretrained params: %s\n", p.filename())
99 | }
100 |
101 | func (p *Params) filename() string {
102 | return fmt.Sprintf("model-%.3fM", Millions(p.Count()))
103 | }
104 |
--------------------------------------------------------------------------------
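
A hedged sketch of how the two halves fit together (`saveLoadRoundTrip` is a hypothetical helper, and it assumes, like the rest of `pkg`, that `*variable.Variable` values can be passed as `layer.Parameter`):

```go
package pkg

// saveLoadRoundTrip sketches the on-disk format: each matrix is written as
// raw little-endian float64 rows, followed by a CRC32 checksum over the
// "index:rows×cols" shape strings. A fresh Params with the same shapes picks
// the weights back up; a mismatch fails the checksum and panics instead of
// silently loading garbage.
func saveLoadRoundTrip() {
	p := NewParams()
	p.Add(Normal(2, 3), Normal(3, 1)) // 9 parameters in total
	p.Save()                          // writes "model-0.000M" (the filename encodes the size)

	q := NewParams()
	q.Add(Normal(2, 3), Normal(3, 1)) // same shapes, fresh random values
	q.TryLoadPretrained()             // overwrites them with the saved weights
}
```
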
/pkg/softmax_test.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "strings"
7 |
8 | "github.com/itsubaki/autograd/function"
9 | "github.com/itsubaki/autograd/variable"
10 | )
11 |
12 | func ExampleSoftmax_basic() {
13 | a := M{
14 | {1, 2, 3},
15 | {4, 5, 6},
16 | }.Var()
17 |
18 | result := function.Softmax(a)
19 | fmt.Println(result)
20 |
21 | // Output:
22 | // variable([[0.09003057317038046 0.24472847105479764 0.6652409557748218] [0.09003057317038046 0.24472847105479764 0.6652409557748218]])
23 | }
24 |
25 | func ExampleSoftmax_largeValues() {
26 | a := V{100, 100.1, 100.2}.Var()
27 |
28 | result := function.Softmax(a)
29 | fmt.Println(result)
30 |
31 | // Output:
32 | // variable([0.30060960535572756 0.33222499353334567 0.3671654011109268])
33 | }
34 |
35 | func ExampleSoftmax_withMasking() {
36 | a := variable.NewOf(
37 | []float64{1, math.Inf(-1), 3},
38 | []float64{math.Inf(-1), 2, 3},
39 | []float64{1, 2, math.Inf(-1)},
40 | []float64{1, math.Inf(-1), math.Inf(-1)},
41 | )
42 | result := function.Softmax(a)
43 |
44 | for _, row := range result.Data {
45 | values := make([]string, len(row))
46 | for i, val := range row {
47 | values[i] = fmt.Sprintf("%.6f", val)
48 | }
49 | fmt.Println(strings.Join(values, " "))
50 | }
51 |
52 | // Output:
53 | // 0.119203 0.000000 0.880797
54 | // 0.000000 0.268941 0.731059
55 | // 0.268941 0.731059 0.000000
56 | // 1.000000 0.000000 0.000000
57 | }
58 |
59 | func ExampleSoftmax_allMasked() {
60 | a := variable.NewOf(
61 | []float64{0, math.Inf(-1), math.Inf(-1)},
62 | []float64{1, 2, 3},
63 | )
64 | result := function.Softmax(a)
65 | fmt.Println(result)
66 |
67 | // Output:
68 | // variable([[1 0 0] [0.09003057317038046 0.24472847105479764 0.6652409557748218]])
69 | }
70 |
71 | func ExampleSoftmax_gradient() {
72 | a := variable.NewOf([]float64{1, 2, 3})
73 | result := function.Softmax(a)
74 | result.Grad = variable.NewOf([]float64{1, 1, 1})
75 | result.Backward()
76 | fmt.Println(a.Grad)
77 |
78 | b := variable.NewOf([]float64{1, 2, 3})
79 | resultB := function.Softmax(b)
80 | resultB.Grad = variable.NewOf([]float64{1, 0, 0})
81 | resultB.Backward()
82 | fmt.Println(b.Grad)
83 |
84 | // Output:
85 | // variable([1.3877787807814457e-17 2.7755575615628914e-17 1.1102230246251565e-16])
86 | // variable([0.08192506906499324 -0.022033044520174298 -0.05989202454481893])
87 | }
88 |
--------------------------------------------------------------------------------
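
The -Inf patterns above are exactly what `MaskedInfFill` and `Tril` (pkg/functions.go) produce for causal attention. A hypothetical example test showing the full pipeline; it reuses the `M` helper from matmul_test.go and assumes it sits next to the softmax examples, where `fmt`, `strings` and `function` are already imported:

```go
func ExampleMaskedInfFill_causalAttention() {
	scores := M{
		{1, 2, 3},
		{4, 5, 6},
		{7, 8, 9},
	}.Var()

	// Token i may only attend to tokens 0..i.
	mask := Tril(Ones(3, 3))
	weights := function.Softmax(MaskedInfFill(scores, mask))

	for _, row := range weights.Data {
		values := make([]string, len(row))
		for i, val := range row {
			values[i] = fmt.Sprintf("%.3f", val)
		}
		fmt.Println(strings.Join(values, " "))
	}

	// Output:
	// 1.000 0.000 0.000
	// 0.269 0.731 0.000
	// 0.090 0.245 0.665
}
```

Each row still sums to 1, but all attention weight for "future" tokens is zeroed out.
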
/pkg/variance.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "github.com/itsubaki/autograd/variable"
5 | )
6 |
7 | func Variance(x ...*variable.Variable) *variable.Variable {
8 | // Calculate mean per row
9 | means := Mean(x[0])
10 |
11 | diffs := variable.Sub(x[0], means)
12 | squaredDiffs := variable.Pow(2)(diffs)
13 | variance := Mean(squaredDiffs)
14 |
15 | return variance
16 | }
17 |
--------------------------------------------------------------------------------
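
Differentiating through both Mean calls gives a tidy closed form: the terms involving the inner mean cancel (since the deviations sum to zero), leaving dVar/dx_i = 2*(x_i - mean)/n. A hypothetical example test cross-checking autograd against that formula (compare `ExampleVariance_gradient` below):

```go
func ExampleVariance_closedForm() {
	a := M{
		{1, 3, 5},
	}.Var()

	v := Variance(a)
	v.Backward()

	// mean = 3 and n = 3, so dVar/dx_i = 2*(x_i-3)/3.
	for i, x := range []float64{1, 3, 5} {
		fmt.Printf("%.4f %.4f\n", 2*(x-3)/3, a.Grad.Data[0][i])
	}

	// Output:
	// -1.3333 -1.3333
	// 0.0000 0.0000
	// 1.3333 1.3333
}
```
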
/pkg/variance_test.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/itsubaki/autograd/variable"
7 | )
8 |
9 | func ExampleVariance_basic() {
10 | a := M{
11 | {1, 2, 3},
12 | {4, 5, 6},
13 | }.Var()
14 |
15 | result := Variance(a)
16 |
17 | 	// Print with higher precision to show the repeating decimals
18 | fmt.Printf("%.10f\n", result.Data[0][0])
19 | fmt.Printf("%.10f\n", result.Data[1][0])
20 |
21 | // Output:
22 | // 0.6666666667
23 | // 0.6666666667
24 | }
25 |
26 | func ExampleVariance_constants() {
27 | a := M{
28 | {5, 5, 5},
29 | {-3, -3, -3},
30 | }.Var()
31 |
32 | result := Variance(a)
33 |
34 | fmt.Println(result.Data[0][0])
35 | fmt.Println(result.Data[1][0])
36 |
37 | // Output:
38 | // 0
39 | // 0
40 | }
41 |
42 | func ExampleVariance_withNegatives() {
43 | a := M{
44 | {-1, 0, 1},
45 | {-10, 0, 10},
46 | }.Var()
47 |
48 | result := Variance(a)
49 |
50 | fmt.Printf("%.10f\n", result.Data[0][0])
51 | fmt.Printf("%.10f\n", result.Data[1][0])
52 |
53 | // Output:
54 | // 0.6666666667
55 | // 66.6666666667
56 | }
57 |
58 | func ExampleVariance_gradient() {
59 | // Values [1, 3, 5] have a mean of 3
60 | a := M{
61 | {1, 3, 5},
62 | }.Var()
63 |
64 | result := Variance(a)
65 |
66 | // Print variance result
67 | fmt.Printf("Variance: %.10f\n", result.Data[0][0])
68 |
69 | // Set gradient to 1.0 and backpropagate
70 | result.Grad = M{
71 | {1.0},
72 | }.Var()
73 |
74 | result.Backward()
75 |
76 | // Print gradients with high precision
77 | fmt.Printf("Gradients: %.10f %.10f %.10f\n",
78 | a.Grad.Data[0][0], a.Grad.Data[0][1], a.Grad.Data[0][2])
79 |
80 | // Output:
81 | // Variance: 2.6666666667
82 | // Gradients: -1.3333333333 0.0000000000 1.3333333333
83 | }
84 |
85 | func ExampleVariance_inComputationGraph() {
86 | // Create input with a single row
87 | a := M{
88 | {2, 4, 6},
89 | }.Var()
90 |
91 | // Calculate variance
92 | v := Variance(a)
93 |
94 | // Multiply by scalar
95 | k := M{
96 | {0.5},
97 | }.Var()
98 |
99 | result := variable.Mul(v, k)
100 |
101 | // Print the result
102 | fmt.Printf("Result: %.10f\n", result.Data[0][0])
103 |
104 | // Backpropagate
105 | result.Backward()
106 |
107 | // Print gradients
108 | fmt.Printf("Gradients: %.10f %.10f %.10f\n",
109 | a.Grad.Data[0][0], a.Grad.Data[0][1], a.Grad.Data[0][2])
110 |
111 | // Output:
112 | // Result: 1.3333333333
113 | // Gradients: -0.6666666667 0.0000000000 0.6666666667
114 | }
115 |
--------------------------------------------------------------------------------