├── .gitignore
├── LICENSE
├── README.md
├── attention
│   ├── attention.go
│   └── multihead.go
├── examples
│   ├── api_examples.go
│   ├── data
│   │   └── imdb_1000.json
│   ├── sentiment_trainer.go
│   ├── serverless_search.go
│   ├── transformer_example.go
│   └── vector_search.go
├── go.mod
└── transformer
    └── transformer.go
/.gitignore:
--------------------------------------------------------------------------------
1 | # If you prefer the allow list template instead of the deny list, see community template:
2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
3 | #
4 | # Binaries for programs and plugins
5 | *.exe
6 | *.exe~
7 | *.dll
8 | *.so
9 | *.dylib
10 |
11 | # Test binary, built with `go test -c`
12 | *.test
13 |
14 | # Output of the go coverage tool, specifically when used with LiteIDE
15 | *.out
16 |
17 | # Dependency directories (remove the comment below to include it)
18 | # vendor/
19 |
20 | # Go workspace file
21 | go.work
22 | go.work.sum
23 |
24 | # env file
25 | .env
26 | .DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Takara AI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # go-attention
2 |
3 |
4 |
5 | From the Frontier Research Team at takara.ai we present the first pure Go implementation of attention mechanisms and transformer layers, designed for high performance and ease of use.
6 |
7 | ## Quick Start
8 |
9 | Run our comprehensive examples:
10 |
11 | ```bash
12 | # Get the module
13 | go get github.com/takara-ai/go-attention
14 |
15 | # Run the examples (from a clone of this repository)
16 | go run examples/api_examples.go
17 | ```
18 |
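The `examples` directory also contains `vector_search.go` (ranking documents by attention weight), `serverless_search.go` (an HTTP semantic-search handler), `sentiment_trainer.go` (a small trainable sentiment model over IMDB reviews), and `transformer_example.go` (a full transformer-layer walkthrough).
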
19 | ## API Documentation
20 |
21 | ### Core Types
22 |
23 | ```go
24 | type Vector []float64 // Represents a 1D vector of float64 values
25 | type Matrix []Vector // Represents a 2D matrix of float64 values
26 | ```
27 |
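Both types are ordinary Go slices, so there are no constructors; you allocate them with `make`. A minimal sketch (the 2×3 shape is arbitrary):

```go
m := make(attention.Matrix, 2) // 2 rows
for i := range m {
    m[i] = make(attention.Vector, 3) // 3 columns, zero-initialized
}
m[0] = attention.Vector{0.1, 0.2, 0.3}
```
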
28 | ### 1. Basic Dot-Product Attention
29 |
30 | The simplest form of attention mechanism. Useful for basic sequence processing tasks.
31 |
32 | ```go
33 | import "github.com/takara-ai/go-attention/attention"
34 |
35 | // Create query-key-value setup
36 | query := attention.Vector{1.0, 0.0, 1.0, 0.0} // Pattern to search for
37 | keys := attention.Matrix{
38 | {1.0, 0.0, 1.0, 0.0}, // Similar to query
39 | {0.0, 1.0, 0.0, 1.0}, // Different from query
40 | {0.5, 0.5, 0.5, 0.5}, // Neutral pattern
41 | }
42 | values := attention.Matrix{
43 | {1.0, 2.0}, // Value for similar key
44 | {3.0, 4.0}, // Value for different key
45 | {5.0, 6.0}, // Value for neutral key
46 | }
47 |
48 | // Compute attention
49 | output, weights, err := attention.DotProductAttention(query, keys, values)
50 | if err != nil {
51 | log.Fatal(err)
52 | }
53 |
54 | // Output will be a weighted combination of values based on query-key similarity
55 | // Weights will show how much attention each key received
56 | ```
57 |
58 | ### 2. Multi-Head Attention
59 |
60 | More sophisticated attention mechanism that can capture different types of relationships in parallel.
61 |
62 | ```go
63 | import "github.com/takara-ai/go-attention/attention"
64 |
65 | // Configure multi-head attention
66 | config := attention.MultiHeadConfig{
67 | NumHeads: 4, // Number of parallel attention heads
68 | DModel: 64, // Size of input/output embeddings
69 | DKey: 16, // Size per head (DModel/NumHeads)
70 | DValue: 16, // Size per head (DModel/NumHeads)
71 | DropoutRate: 0.1, // For regularization
72 | }
73 |
74 | // Create the attention module
75 | mha, err := attention.NewMultiHeadAttention(config)
76 | if err != nil {
77 | log.Fatal(err)
78 | }
79 |
80 | // Process sequences (batched input)
81 | batchSize, seqLen := 2, 3 // Process 2 sequences, each with 3 tokens
82 |
83 | // Create input matrices, flattened to [batchSize*seqLen][DModel]
84 | queries := make(attention.Matrix, batchSize*seqLen)
85 | keys := make(attention.Matrix, batchSize*seqLen)
86 | values := make(attention.Matrix, batchSize*seqLen)
87 |
88 | // Initialize your matrices with actual data...
89 |
90 | // Process through multi-head attention
91 | output, err := mha.Forward(queries, keys, values)
92 | if err != nil {
93 | log.Fatal(err)
94 | }
95 | ```
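
Note that `Matrix` is two-dimensional, so batched input is passed flattened as `[batchSize*seqLen][DModel]` rows; in this implementation, each row of `queries` attends over all rows of `keys` and `values`.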
96 |
97 | ### 3. Full Transformer Layer
98 |
99 | Complete transformer layer with self-attention and feed-forward network.
100 |
101 | ```go
102 | import (
103 | "github.com/takara-ai/go-attention/transformer"
104 | "github.com/takara-ai/go-attention/attention"
105 | )
106 |
107 | // Configure transformer layer
108 | config := transformer.TransformerConfig{
109 | DModel: 64, // Size of token embeddings
110 | NumHeads: 4, // Number of attention heads
111 | DHidden: 256, // Size of feed-forward hidden layer
112 | DropoutRate: 0.1, // For regularization
113 | }
114 |
115 | // Create transformer layer
116 | layer, err := transformer.NewTransformerLayer(config)
117 | if err != nil {
118 | log.Fatal(err)
119 | }
120 |
121 | // Create input sequence [seq_len × d_model]
122 | seqLen := 3
123 | input := make(attention.Matrix, seqLen)
124 | for i := range input {
125 | input[i] = make(attention.Vector, config.DModel)
126 | // Fill with your embedding data...
127 | }
128 |
129 | // Process through transformer
130 | output, err := layer.Forward(input)
131 | if err != nil {
132 | log.Fatal(err)
133 | }
134 | ```
135 |
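Internally, `Forward` applies self-attention (the input serves as queries, keys, and values) followed by a position-wise feed-forward network, with residual connections and layer normalization around each sub-layer.
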
136 | ## Example Output
137 |
138 | When running the examples, you'll see:
139 |
140 | 1. **Dot-Product Attention**:
141 |
142 | ```
143 | Query: [1 0 1 0]
144 | Attention Weights: [0.506 0.186 0.307] // Shows focus on similar patterns
145 | Output: [2.601 3.601] // Weighted combination of values
146 | ```
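
The weights sum to 1, and the output is their weighted mix of the value rows: 0.506·[1 2] + 0.186·[3 4] + 0.307·[5 6] ≈ [2.60 3.60], matching the printed output up to rounding.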
147 |
148 | 2. **Multi-Head Attention**:
149 |
150 | ```
151 | Input dimensions: [2 batches × 3 tokens × 64 features]
152 | Output shape: [6×64]
153 | ```
154 |
155 | 3. **Transformer Layer**:
156 | ```
157 | Input shape: [3×64]
158 | Output shape: [3×64]
159 | ```
160 |
161 | ## Common Use Cases
162 |
163 | 1. **Text Processing**:
164 |
165 | - Sequence-to-sequence translation
166 | - Document summarization
167 | - Sentiment analysis
168 |
169 | 2. **Time Series**:
170 |
171 | - Financial forecasting
172 | - Sensor data analysis
173 | - Anomaly detection
174 |
175 | 3. **Structured Data**:
176 | - Graph node embedding
177 | - Feature interaction modeling
178 | - Recommendation systems
179 |
180 | ## Performance Considerations
181 |
182 | - Matrix operations are optimized for CPU
183 | - Memory allocations are minimized
184 | - Batch processing for better throughput
185 | - No external dependencies
186 |
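To verify these characteristics on your own hardware, here is a minimal benchmark sketch (a hypothetical `attention/bench_test.go`, not shipped with the module; the sizes are arbitrary):

```go
package attention_test

import (
	"testing"

	"github.com/takara-ai/go-attention/attention"
)

func BenchmarkDotProductAttention(b *testing.B) {
	const n, dK, dV = 128, 64, 64 // sequence length and dimensions (arbitrary)
	query := make(attention.Vector, dK)
	keys := make(attention.Matrix, n)
	values := make(attention.Matrix, n)
	for i := 0; i < n; i++ {
		keys[i] = make(attention.Vector, dK)
		values[i] = make(attention.Vector, dV)
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, _, err := attention.DotProductAttention(query, keys, values); err != nil {
			b.Fatal(err)
		}
	}
}
```

Run it with `go test -bench=DotProductAttention ./attention`.
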
187 | For more detailed examples, see the `examples` directory in the repository.
188 |
189 | ## Why go-attention?
190 |
191 | This module was created to provide a clean, efficient, and dependency-free implementation of attention mechanisms in Go. It's particularly useful for:
192 |
193 | - **Edge Computing**: Zero external dependencies makes it perfect for edge devices where dependency management is crucial
194 | - **Real-time Processing**: Pure Go implementation ensures predictable performance for real-time applications
195 | - **Cloud-native Applications**: Efficient batched operations support high-throughput scaling in cloud environments
196 | - **Embedded Systems**: Predictable resource usage and minimal memory allocations
197 | - **Production Systems**: Comprehensive error handling and type safety for robust production deployments
198 |
199 | ## Features
200 |
201 | - Efficient scaled dot-product attention, upgraded with Scalable-Softmax (SSMax, s=1) for improved long-context performance (see the note after this list)
202 | - Multi-head attention support
203 | - Full transformer layer implementation with:
204 | - Layer normalization
205 | - Position-wise feed-forward networks
206 | - Residual connections
207 | - Batched operations for improved performance
208 |
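Concretely, SSMax with s=1 (see `attention.DotProductAttention`) multiplies each scaled score by the natural log of the sequence length before the softmax, i.e. `weights = Softmax((q·kᵢ/√d_k)·ln n)`; for n ≤ 1 the scores are left unscaled, which reduces to standard scaled dot-product attention.
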
209 | ## Roadmap
210 |
211 | Future improvements may include:
212 |
213 | - Positional encoding implementations
214 | - Dropout support
215 | - CUDA acceleration support
216 | - Additional transformer variants
217 | - Pre-trained models
218 | - Training utilities
219 |
220 | ## Contributing
221 |
222 | Contributions are welcome! Please feel free to submit a Pull Request.
223 |
224 | ## License
225 |
226 | MIT License - see LICENSE file for details
227 |
228 | ---
229 |
230 | For research inquiries and press, please reach out to research@takara.ai
231 |
232 | > 人類を変革する ("Transforming humanity")
233 |
--------------------------------------------------------------------------------
/attention/attention.go:
--------------------------------------------------------------------------------
1 | package attention
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | )
7 |
8 | // Vector represents a slice of float64 values
9 | type Vector []float64
10 |
11 | // Matrix represents a 2D slice of float64 values
12 | type Matrix []Vector
13 |
14 | // In the attention operation below, the query is compared against the keys,
15 | // and the resulting weights form a convex combination of the values.
17 |
18 | // DotProduct computes the dot product of two vectors
19 | func DotProduct(v1, v2 Vector) (float64, error) {
20 | if len(v1) != len(v2) {
21 | return 0, fmt.Errorf("vector dimensions mismatch: %d != %d", len(v1), len(v2))
22 | }
23 |
24 | sum := 0.0
25 | for i := range v1 {
26 | sum += v1[i] * v2[i]
27 | }
28 | return sum, nil
29 | }
30 |
31 | // Softmax applies the softmax function to a vector
32 | func Softmax(x Vector) Vector {
33 | if len(x) == 0 {
34 | return Vector{}
35 | }
36 |
37 | maxVal := x[0]
38 | // Find max value for numerical stability
39 | for _, v := range x[1:] {
40 | if v > maxVal {
41 | maxVal = v
42 | }
43 | }
44 |
45 | exps := make(Vector, len(x))
46 | sumExp := 0.0
47 | for i, v := range x {
48 | // Subtracting maxVal before exponentiation
49 | exps[i] = math.Exp(v - maxVal)
50 | sumExp += exps[i]
51 | }
52 |
53 | // Handle sumExp == 0 case to prevent NaN (division by zero)
54 | // This can happen if all inputs are extremely small negative numbers.
55 | if sumExp == 0 {
56 | // Return a zero vector. It's already initialized with zeros implicitly by make
57 | return exps
58 | }
59 |
60 | // Normalize: Use direct division instead of multiplying by inverse
61 | for i := range exps {
62 | exps[i] /= sumExp
63 | }
64 |
65 | return exps
66 | }
67 |
68 | // ScaleVector multiplies a vector by a scalar
69 | func ScaleVector(v Vector, scale float64) Vector {
70 | result := make(Vector, len(v))
71 | for i, val := range v {
72 | result[i] = val * scale
73 | }
74 | return result
75 | }
76 |
77 | // AddVectors adds two vectors element-wise
78 | func AddVectors(v1, v2 Vector) (Vector, error) {
79 | if len(v1) != len(v2) {
80 | return nil, fmt.Errorf("vector dimensions mismatch: %d != %d", len(v1), len(v2))
81 | }
82 |
83 | result := make(Vector, len(v1))
84 | for i := range v1 {
85 | result[i] = v1[i] + v2[i]
86 | }
87 | return result, nil
88 | }
89 |
90 | // DotProductAttention computes scaled dot-product attention
91 | // query: [d_k], keys: [n, d_k], values: [n, d_v]
92 | // Returns: attended vector [d_v] and attention weights [n]
93 | func DotProductAttention(query Vector, keys, values Matrix) (Vector, Vector, error) {
94 | n := len(keys)
95 | if n == 0 {
96 | // If keys are empty, check if values exist to determine output dimension d_v
97 | if len(values) > 0 && len(values[0]) > 0 {
98 | d_v := len(values[0])
99 | // Return empty weights and a zero vector of the correct value dimension
100 | return make(Vector, d_v), Vector{}, nil
101 | }
102 | // If both keys and values are empty (or values have zero dimension), return error or nil vectors
103 | return nil, nil, fmt.Errorf("empty keys and values provided")
104 | }
105 |
106 | // Basic dimension validation before proceeding
107 | if len(values) != n {
108 | return nil, nil, fmt.Errorf("number of keys (%d) must match number of values (%d)", n, len(values))
109 | }
110 | if len(values[0]) == 0 {
111 | return nil, nil, fmt.Errorf("value dimension (d_v) cannot be zero")
112 | }
113 | d_v := len(values[0]) // Dimension of value vectors
114 |
115 | // Determine key dimension (d_k) safely
116 | d_k := len(query)
117 | if n > 0 && len(keys[0]) != d_k {
118 | // Check the first key's dimension against the query dimension
119 | return nil, nil, fmt.Errorf("query dimension (%d) must match key dimension (%d)", d_k, len(keys[0]))
120 | }
121 |
122 |
123 | // Compute attention scores
124 | scores := make(Vector, n)
125 |
126 | // Pre-compute the 1/sqrt(d_k) scaling factor, guarding against d_k == 0
127 | scale := 1.0
128 | if d_k > 0 {
129 | scale = 1.0 / math.Sqrt(float64(d_k))
130 | } // If d_k == 0 the query is empty and every dot product is 0, so the scale never matters
131 |
132 | for i, key := range keys {
133 | // Ensure consistent key dimensions within the loop
134 | if len(key) != d_k {
135 | return nil, nil, fmt.Errorf("key dimension mismatch at index %d: expected %d, got %d", i, d_k, len(key))
136 | }
137 | score, err := DotProduct(query, key) // DotProduct already checks len(query) == len(key)
138 | if err != nil {
139 | // This error should theoretically not happen if the outer checks pass, but handle defensively.
140 | return nil, nil, fmt.Errorf("error computing dot product for key %d: %w", i, err)
141 | }
142 | // Scale by 1/sqrt(d_k) so score magnitudes stay independent of the key dimension
143 | scores[i] = score * scale // Use pre-calculated scale
144 | }
145 |
146 | // --- SSMax (Scalable-Softmax, s=1) ---
147 | // Scale the scores by ln(n) before the softmax. For n == 1, ln(1) = 0 and the
148 | // softmax of a single element is 1 regardless of scaling, so scores are only
149 | // scaled when n > 1; short inputs reduce to standard softmax attention.
150 | if n > 1 {
151 | logN := math.Log(float64(n))
152 | for i := range scores {
153 | scores[i] *= logN
154 | }
155 | }
156 | // --- End SSMax ---
170 |
171 | // Apply softmax to get attention weights (now using SSMax-modified scores)
172 | weights := Softmax(scores)
173 |
174 | // Compute weighted sum of values
175 | attended := make(Vector, d_v) // Use d_v determined earlier
176 | for i, weight := range weights {
177 | // Ensure consistent value dimensions within the loop
178 | if len(values[i]) != d_v {
179 | return nil, nil, fmt.Errorf("value dimension mismatch at index %d: expected %d, got %d", i, d_v, len(values[i]))
180 | }
181 |
182 | // Optimization: Skip summation if weight is zero
183 | if weight == 0 {
184 | continue
185 | }
186 |
187 | // Fuse the scaling by weight into the accumulation loop
188 | valueVec := values[i] // Local ref might help optimizer, maybe minor
189 | for j := 0; j < d_v; j++ { // Iterate up to d_v
190 | attended[j] += weight * valueVec[j]
191 | }
192 | }
193 |
194 | return attended, weights, nil
195 | }
196 |
197 | // TODO: kv caching
198 | // TODO: tokenization support
199 | // TODO: rotary positional embedding support RoPE
200 | // TODO: SwarmFormer Local-Global Hierarchical Attention Support
--------------------------------------------------------------------------------
/attention/multihead.go:
--------------------------------------------------------------------------------
1 | package attention
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "math/rand"
7 | )
8 |
9 | // MultiHeadConfig holds configuration for multi-head attention
10 | type MultiHeadConfig struct {
11 | NumHeads int // Number of attention heads
12 | DModel int // Model dimension
13 | DKey int // Key dimension per head
14 | DValue int // Value dimension per head
15 | DropoutRate float64 // Dropout rate (not implemented in this version for simplicity)
16 | }
17 |
18 | // MultiHeadAttention implements multi-head attention mechanism
19 | type MultiHeadAttention struct {
20 | config MultiHeadConfig
21 |
22 | // Linear projections for each head
23 | queryProj []Matrix // [num_heads][d_model][d_k]
24 | keyProj []Matrix // [num_heads][d_model][d_k]
25 | valueProj []Matrix // [num_heads][d_model][d_v]
26 |
27 | // Output projection
28 | outputProj Matrix // [d_model][num_heads * d_v]
29 | }
30 |
31 | // NewMultiHeadAttention creates a new multi-head attention module
32 | func NewMultiHeadAttention(config MultiHeadConfig) (*MultiHeadAttention, error) {
33 | if config.NumHeads <= 0 {
34 | return nil, fmt.Errorf("number of heads must be positive, got %d", config.NumHeads)
35 | }
36 | if config.DModel <= 0 {
37 | return nil, fmt.Errorf("model dimension must be positive, got %d", config.DModel)
38 | }
39 | if config.DModel%config.NumHeads != 0 {
40 | return nil, fmt.Errorf("model dimension (%d) must be divisible by number of heads (%d)", config.DModel, config.NumHeads)
41 | }
42 |
43 | mha := &MultiHeadAttention{
44 | config: config,
45 | queryProj: make([]Matrix, config.NumHeads),
46 | keyProj: make([]Matrix, config.NumHeads),
47 | valueProj: make([]Matrix, config.NumHeads),
48 | outputProj: make(Matrix, config.DModel),
49 | }
50 |
51 | // Initialize projections with random weights
52 | for h := 0; h < config.NumHeads; h++ {
53 | mha.queryProj[h] = randomMatrix(config.DModel, config.DKey)
54 | mha.keyProj[h] = randomMatrix(config.DModel, config.DKey)
55 | mha.valueProj[h] = randomMatrix(config.DModel, config.DValue)
56 | }
57 |
58 | // Initialize output projection
59 | for i := range mha.outputProj {
60 | mha.outputProj[i] = make(Vector, config.NumHeads*config.DValue)
61 | for j := range mha.outputProj[i] {
62 | mha.outputProj[i][j] = (rand.Float64() - 0.5) / math.Sqrt(float64(config.DValue))
63 | }
64 | }
65 |
66 | return mha, nil
67 | }
68 |
69 | // Forward computes multi-head attention.
70 | // query, key, value: [rows][d_model], where rows is a flattened batch
71 | // (e.g. batch_size*seq_len); each query row attends over all key/value rows.
71 | func (mha *MultiHeadAttention) Forward(query, key, value Matrix) (Matrix, error) {
72 | batchSize := len(query)
73 | if batchSize != len(key) || batchSize != len(value) {
74 | return nil, fmt.Errorf("batch size mismatch: query(%d), key(%d), value(%d)", batchSize, len(key), len(value))
75 | }
76 |
77 | // Compute attention for each head
78 | headOutputs := make([]Matrix, mha.config.NumHeads)
79 | for h := 0; h < mha.config.NumHeads; h++ {
80 | // Project inputs
81 | projQuery, err := projectBatch(query, mha.queryProj[h])
82 | if err != nil {
83 | return nil, fmt.Errorf("projecting query for head %d: %w", h, err)
84 | }
85 |
86 | projKey, err := projectBatch(key, mha.keyProj[h])
87 | if err != nil {
88 | return nil, fmt.Errorf("projecting key for head %d: %w", h, err)
89 | }
90 |
91 | projValue, err := projectBatch(value, mha.valueProj[h])
92 | if err != nil {
93 | return nil, fmt.Errorf("projecting value for head %d: %w", h, err)
94 | }
95 |
96 | // Initialize head output
97 | headOutputs[h] = make(Matrix, batchSize)
98 |
99 | // Compute attention for each item in the batch
100 | for b := 0; b < batchSize; b++ {
101 | attended, _, err := DotProductAttention(projQuery[b], projKey, projValue)
102 | if err != nil {
103 | return nil, fmt.Errorf("computing attention for batch %d, head %d: %w", b, h, err)
104 | }
105 | headOutputs[h][b] = attended
106 | }
107 | }
108 |
109 | // Concatenate and project heads
110 | output := make(Matrix, batchSize)
111 | for b := 0; b < batchSize; b++ {
112 | // Concatenate all head outputs
113 | concat := make(Vector, 0, mha.config.NumHeads*mha.config.DValue)
114 | for h := 0; h < mha.config.NumHeads; h++ {
115 | concat = append(concat, headOutputs[h][b]...)
116 | }
117 |
118 | // Project concatenated output
119 | output[b] = make(Vector, mha.config.DModel)
120 | for i := range output[b] {
121 | for j, v := range concat {
122 | output[b][i] += v * mha.outputProj[i][j]
123 | }
124 | }
125 | }
126 |
127 | return output, nil
128 | }
129 |
130 | // Helper functions
131 |
132 | func randomMatrix(rows, cols int) Matrix {
133 | mat := make(Matrix, rows)
134 | scale := math.Sqrt(2.0 / float64(rows+cols)) // Xavier initialization
135 | for i := range mat {
136 | mat[i] = make(Vector, cols)
137 | for j := range mat[i] {
138 | mat[i][j] = (rand.Float64() - 0.5) * scale
139 | }
140 | }
141 | return mat
142 | }
143 |
144 | func projectBatch(input Matrix, weights Matrix) (Matrix, error) {
145 | output := make(Matrix, len(input))
146 | for i, vec := range input {
147 | projected, err := projectVector(vec, weights)
148 | if err != nil {
149 | return nil, err
150 | }
151 | output[i] = projected
152 | }
153 | return output, nil
154 | }
155 |
156 | func projectVector(input Vector, weights Matrix) (Vector, error) {
157 | if len(weights) == 0 || len(weights[0]) == 0 {
158 | return nil, fmt.Errorf("empty weight matrix")
159 | }
160 | if len(input) != len(weights) {
161 | return nil, fmt.Errorf("input dimension (%d) does not match weights (%d)", len(input), len(weights))
162 | }
163 |
164 | output := make(Vector, len(weights[0]))
165 | for i := range weights[0] {
166 | for j, w := range weights {
167 | output[i] += input[j] * w[i]
168 | }
169 | }
170 | return output, nil
171 | }
--------------------------------------------------------------------------------
/examples/api_examples.go:
--------------------------------------------------------------------------------
1 | // This script demonstrates all key APIs of the go-attention module
2 | package main
3 |
4 | import (
5 | "fmt"
6 | "log"
7 | "github.com/takara-ai/go-attention/attention"
8 | "github.com/takara-ai/go-attention/transformer"
9 | )
10 |
11 | func main() {
12 | // 1. Basic Dot-Product Attention
13 | fmt.Println("\n=== 1. Basic Dot-Product Attention ===")
14 | testDotProductAttention()
15 |
16 | // 2. Multi-Head Attention
17 | fmt.Println("\n=== 2. Multi-Head Attention ===")
18 | testMultiHeadAttention()
19 |
20 | // 3. Full Transformer Layer
21 | fmt.Println("\n=== 3. Full Transformer Layer ===")
22 | testTransformerLayer()
23 | }
24 |
25 | func testDotProductAttention() {
26 | // Create a simple query-key-value setup
27 | query := attention.Vector{1.0, 0.0, 1.0, 0.0} // Looking for patterns similar to [1,0,1,0]
28 | keys := attention.Matrix{
29 | {1.0, 0.0, 1.0, 0.0}, // Similar to query
30 | {0.0, 1.0, 0.0, 1.0}, // Different from query
31 | {0.5, 0.5, 0.5, 0.5}, // Neutral pattern
32 | }
33 | values := attention.Matrix{
34 | {1.0, 2.0}, // Value for similar key
35 | {3.0, 4.0}, // Value for different key
36 | {5.0, 6.0}, // Value for neutral key
37 | }
38 |
39 | // Compute attention
40 | output, weights, err := attention.DotProductAttention(query, keys, values)
41 | if err != nil {
42 | log.Fatal(err)
43 | }
44 |
45 | fmt.Println("Query:", query)
46 | fmt.Println("Attention Weights:", weights)
47 | fmt.Println("Output:", output)
48 | }
49 |
50 | func testMultiHeadAttention() {
51 | // Configure multi-head attention
52 | config := attention.MultiHeadConfig{
53 | NumHeads: 4,
54 | DModel: 64,
55 | DKey: 16, // DModel / NumHeads
56 | DValue: 16, // DModel / NumHeads
57 | DropoutRate: 0.1,
58 | }
59 |
60 | // Create multi-head attention module
61 | mha, err := attention.NewMultiHeadAttention(config)
62 | if err != nil {
63 | log.Fatal(err)
64 | }
65 |
66 | // Create sample input with dimensions:
67 | // - batch_size: number of sequences to process in parallel
68 | // - seq_len: number of tokens in each sequence
69 | // - d_model: dimension of each token's embedding
70 | batchSize, seqLen := 2, 3
71 |
72 | // Create input matrices, flattened to shape [batchSize*seqLen][DModel]
73 | queries := make(attention.Matrix, batchSize*seqLen)
74 | keys := make(attention.Matrix, batchSize*seqLen)
75 | values := make(attention.Matrix, batchSize*seqLen)
76 |
77 | // Initialize with a deterministic pattern
78 | for i := 0; i < batchSize*seqLen; i++ {
79 | queries[i] = make(attention.Vector, config.DModel)
80 | keys[i] = make(attention.Vector, config.DModel)
81 | values[i] = make(attention.Vector, config.DModel)
82 | for j := 0; j < config.DModel; j++ {
83 | val := float64(i*j) / float64(config.DModel)
84 | queries[i][j] = val
85 | keys[i][j] = val
86 | values[i][j] = val
87 | }
88 | }
89 |
90 | // Process through multi-head attention
91 | output, err := mha.Forward(queries, keys, values)
92 | if err != nil {
93 | log.Fatal(err)
94 | }
95 |
96 | fmt.Printf("Input dimensions: [%d batches × %d tokens × %d features]\n",
97 | batchSize, seqLen, config.DModel)
98 | fmt.Printf("Output shape: [%d×%d]\n", len(output), len(output[0]))
99 | fmt.Println("First few output values:", output[0][:4])
100 | }
101 |
102 | func testTransformerLayer() {
103 | // Configure transformer layer
104 | config := transformer.TransformerConfig{
105 | DModel: 64,
106 | NumHeads: 4,
107 | DHidden: 256,
108 | DropoutRate: 0.1,
109 | }
110 |
111 | // Create transformer layer
112 | layer, err := transformer.NewTransformerLayer(config)
113 | if err != nil {
114 | log.Fatal(err)
115 | }
116 |
117 | // Create sample sequence (seq_len=3, d_model=64)
118 | input := createRandomMatrix(3, config.DModel)
119 |
120 | // Process through transformer
121 | output, err := layer.Forward(input)
122 | if err != nil {
123 | log.Fatal(err)
124 | }
125 |
126 | fmt.Printf("Input shape: [%d×%d]\n", len(input), len(input[0]))
127 | fmt.Printf("Output shape: [%d×%d]\n", len(output), len(output[0]))
128 | fmt.Println("First token before:", input[0][:4])
129 | fmt.Println("First token after:", output[0][:4])
130 | }
131 |
132 | // Helper function to create random matrices
133 | func createRandomMatrix(rows, cols int) attention.Matrix {
134 | matrix := make(attention.Matrix, rows)
135 | for i := range matrix {
136 | matrix[i] = make(attention.Vector, cols)
137 | for j := range matrix[i] {
138 | matrix[i][j] = float64(i+j) / float64(cols) // Deterministic pattern for testing
139 | }
140 | }
141 | return matrix
142 | }
--------------------------------------------------------------------------------
/examples/sentiment_trainer.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "log"
7 | "math"
8 | "math/rand"
9 | "os"
10 | "strings"
11 | "time"
12 |
13 | "github.com/takara-ai/go-attention/attention"
14 | )
15 |
16 | // IMDBReview represents a single review from the IMDB dataset
17 | type IMDBReview struct {
18 | Text string `json:"text"`
19 | Label float64 `json:"label"`
20 | }
21 |
22 | // TrainingExample represents a single training instance
23 | type TrainingExample struct {
24 | Text string `json:"text"`
25 | Sentiment float64 `json:"sentiment"`
26 | }
27 |
28 | // Model represents our sentiment analysis model
29 | type Model struct {
30 | WordEmbedding []attention.Vector `json:"word_embedding"` // Simple word embedding lookup
31 | QueryWeights attention.Matrix `json:"query_weights"` // Query projection weights
32 | KeyWeights attention.Matrix `json:"key_weights"` // Key projection weights
33 | ValueWeights attention.Matrix `json:"value_weights"` // Value projection weights
34 | OutputWeights attention.Vector `json:"output_weights"` // Output layer weights
35 | VocabSize int `json:"vocab_size"`
36 | EmbedDim int `json:"embed_dim"`
37 | AttnDim int `json:"attn_dim"`
38 | }
39 |
40 | // NewModel creates a new sentiment analysis model
41 | func NewModel(vocabSize, embedDim, attnDim int) *Model {
42 | // Initialize word embeddings with Xavier initialization
43 | scale := math.Sqrt(2.0 / float64(embedDim))
44 | wordEmbedding := make([]attention.Vector, vocabSize)
45 | for i := range wordEmbedding {
46 | wordEmbedding[i] = make(attention.Vector, embedDim)
47 | for j := range wordEmbedding[i] {
48 | wordEmbedding[i][j] = rand.NormFloat64() * scale
49 | }
50 | }
51 |
52 | // Initialize attention weights
53 | queryWeights := make(attention.Matrix, embedDim)
54 | keyWeights := make(attention.Matrix, embedDim)
55 | valueWeights := make(attention.Matrix, embedDim)
56 | for i := 0; i < embedDim; i++ {
57 | queryWeights[i] = make(attention.Vector, attnDim)
58 | keyWeights[i] = make(attention.Vector, attnDim)
59 | valueWeights[i] = make(attention.Vector, attnDim)
60 | for j := 0; j < attnDim; j++ {
61 | queryWeights[i][j] = rand.NormFloat64() * scale
62 | keyWeights[i][j] = rand.NormFloat64() * scale
63 | valueWeights[i][j] = rand.NormFloat64() * scale
64 | }
65 | }
66 |
67 | // Initialize output weights
68 | outputWeights := make(attention.Vector, embedDim)
69 | for i := range outputWeights {
70 | outputWeights[i] = rand.NormFloat64() * scale
71 | }
72 |
73 | return &Model{
74 | WordEmbedding: wordEmbedding,
75 | QueryWeights: queryWeights,
76 | KeyWeights: keyWeights,
77 | ValueWeights: valueWeights,
78 | OutputWeights: outputWeights,
79 | VocabSize: vocabSize,
80 | EmbedDim: embedDim,
81 | AttnDim: attnDim,
82 | }
83 | }
84 |
85 | // sigmoid computes the sigmoid function
86 | func sigmoid(x float64) float64 {
87 | return 1.0 / (1.0 + math.Exp(-x))
88 | }
89 |
90 | // dotProduct computes the dot product of two vectors
91 | func dotProduct(a, b attention.Vector) float64 {
92 | sum := 0.0
93 | for i := range a {
94 | sum += a[i] * b[i]
95 | }
96 | return sum
97 | }
98 |
99 | // hashWord maps a word to a vocabulary index using a polynomial rolling hash; distinct words may collide and share an embedding
100 | func hashWord(word string) int {
101 | hash := 0
102 | for _, c := range word {
103 | hash = (hash*31 + int(c)) % 10000
104 | }
105 | return hash
106 | }
107 |
108 | // Batch represents a batch of training examples
109 | type Batch struct {
110 | Embeddings []attention.Matrix
111 | Targets []float64
112 | }
113 |
114 | // createBatch creates a batch from examples
115 | func (m *Model) createBatch(examples []TrainingExample, batchSize int) []Batch {
116 | var batches []Batch
117 | for i := 0; i < len(examples); i += batchSize {
118 | end := i + batchSize
119 | if end > len(examples) {
120 | end = len(examples)
121 | }
122 |
123 | batch := Batch{
124 | Embeddings: make([]attention.Matrix, end-i),
125 | Targets: make([]float64, end-i),
126 | }
127 |
128 | for j, example := range examples[i:end] {
129 | words := strings.Fields(strings.ToLower(example.Text))
130 | if len(words) > 100 {
131 | words = words[:100]
132 | }
133 |
134 | embeddings := make(attention.Matrix, len(words))
135 | for k, word := range words {
136 | wordIdx := hashWord(word) % m.VocabSize
137 | embeddings[k] = m.WordEmbedding[wordIdx]
138 | }
139 | batch.Embeddings[j] = embeddings
140 | batch.Targets[j] = example.Sentiment
141 | }
142 | batches = append(batches, batch)
143 | }
144 | return batches
145 | }
146 |
147 | // clipGradients clips gradients to prevent explosion
148 | func clipGradients(grads attention.Vector, maxNorm float64) {
149 | var norm float64
150 | for _, g := range grads {
151 | norm += g * g
152 | }
153 | norm = math.Sqrt(norm)
154 | if norm > maxNorm {
155 | scale := maxNorm / norm
156 | for i := range grads {
157 | grads[i] *= scale
158 | }
159 | }
160 | }
161 |
162 | // Train trains the model on a dataset
163 | func (m *Model) Train(examples []TrainingExample, epochs int, learningRate float64) error {
164 | rand.Seed(time.Now().UnixNano())
165 |
166 | // Training hyperparameters
167 | batchSize := 32
168 | maxGradNorm := 5.0
169 | l2Reg := 0.01
170 | lrDecay := 0.95
171 |
172 | // Split into training and validation sets (80-20 split)
173 | rand.Shuffle(len(examples), func(i, j int) {
174 | examples[i], examples[j] = examples[j], examples[i]
175 | })
176 | splitIdx := int(float64(len(examples)) * 0.8)
177 | trainExamples := examples[:splitIdx]
178 | valExamples := examples[splitIdx:]
179 |
180 | bestValAcc := 0.0
181 | noImprovementCount := 0
182 | currentLR := learningRate
183 |
184 | for epoch := 0; epoch < epochs; epoch++ {
185 | // Create batches
186 | batches := m.createBatch(trainExamples, batchSize)
187 |
188 | // Training phase
189 | totalLoss := 0.0
190 | correct := 0
191 | totalExamples := 0
192 |
193 | for batchIdx, batch := range batches {
194 | batchLoss := 0.0
195 | batchCorrect := 0
196 |
197 | // Accumulate gradients for the batch
198 | queryGrads := make(attention.Matrix, m.EmbedDim)
199 | keyGrads := make(attention.Matrix, m.EmbedDim)
200 | valueGrads := make(attention.Matrix, m.EmbedDim)
201 | outputGrads := make(attention.Vector, m.EmbedDim)
202 |
203 | // Initialize gradient matrices
204 | for i := 0; i < m.EmbedDim; i++ {
205 | queryGrads[i] = make(attention.Vector, m.AttnDim)
206 | keyGrads[i] = make(attention.Vector, m.AttnDim)
207 | valueGrads[i] = make(attention.Vector, m.AttnDim)
208 | }
209 |
210 | for i := range batch.Embeddings {
211 | // Forward pass
212 | queries := m.project(batch.Embeddings[i], m.QueryWeights)
213 | keys := m.project(batch.Embeddings[i], m.KeyWeights)
214 | values := m.project(batch.Embeddings[i], m.ValueWeights)
215 |
216 | // Compute attention scores
217 | scores := make([]float64, len(queries))
218 | maxScore := -math.MaxFloat64
219 | for j := range queries {
220 | scores[j] = 0
221 | for k := 0; k < m.AttnDim; k++ {
222 | scores[j] += queries[j][k] * keys[j][k]
223 | }
224 | scores[j] /= math.Sqrt(float64(m.AttnDim))
225 | if scores[j] > maxScore {
226 | maxScore = scores[j]
227 | }
228 | }
229 |
230 | // Softmax
231 | sumExp := 0.0
232 | for j := range scores {
233 | scores[j] = math.Exp(scores[j] - maxScore)
234 | sumExp += scores[j]
235 | }
236 | for j := range scores {
237 | scores[j] /= sumExp
238 | }
239 |
240 | // Weighted sum
241 | context := make(attention.Vector, m.EmbedDim)
242 | for j := range context {
243 | for k := range scores {
244 | for d := 0; d < m.AttnDim; d++ {
245 | context[j] += scores[k] * values[k][d] * m.ValueWeights[j][d]
246 | }
247 | }
248 | }
249 |
250 | // Final prediction
251 | logit := dotProduct(context, m.OutputWeights)
252 | prediction := sigmoid(logit)
253 |
254 | // Compute loss with L2 regularization
255 | target := batch.Targets[i]
256 | loss := -(target*math.Log(prediction+1e-10) + (1-target)*math.Log(1-prediction+1e-10))
257 |
258 | // Add L2 regularization
259 | l2Loss := 0.0
260 | for _, w := range m.OutputWeights {
261 | l2Loss += w * w
262 | }
263 | loss += l2Reg * l2Loss * 0.5
264 | batchLoss += loss
265 |
266 | // Backward pass
267 | predErr := prediction - target
268 |
269 | // Output weight gradients
270 | for j := range outputGrads {
271 | outputGrads[j] += predErr * context[j] + l2Reg * m.OutputWeights[j]
272 | }
273 |
274 | // Attention gradients
275 | contextGrad := make(attention.Vector, m.EmbedDim)
276 | for j := range contextGrad {
277 | contextGrad[j] = predErr * m.OutputWeights[j]
278 | }
279 |
280 | // Update attention weight gradients
281 | for j := range batch.Embeddings[i] {
282 | if j >= m.EmbedDim {
283 | continue
284 | }
285 | for k := 0; k < m.AttnDim; k++ {
286 | queryGrad := 0.0
287 | keyGrad := 0.0
288 | valueGrad := 0.0
289 |
290 | for d := range scores {
291 | queryGrad += contextGrad[j] * scores[d] * keys[j][k] / math.Sqrt(float64(m.AttnDim))
292 | keyGrad += contextGrad[j] * scores[d] * queries[j][k] / math.Sqrt(float64(m.AttnDim))
293 | valueGrad += contextGrad[j] * scores[d]
294 | }
295 |
296 | queryGrads[j][k] += queryGrad
297 | keyGrads[j][k] += keyGrad
298 | valueGrads[j][k] += valueGrad
299 | }
300 | }
301 |
302 | // Calculate accuracy
303 | predictedClass := 0.0
304 | if prediction > 0.5 {
305 | predictedClass = 1.0
306 | }
307 | if predictedClass == target {
308 | batchCorrect++
309 | }
310 | }
311 |
312 | // Clip and apply gradients
313 | clipGradients(outputGrads, maxGradNorm)
314 | for i := range m.OutputWeights {
315 | m.OutputWeights[i] -= currentLR * outputGrads[i]
316 | }
317 |
318 | for i := 0; i < m.EmbedDim; i++ {
319 | clipGradients(queryGrads[i], maxGradNorm)
320 | clipGradients(keyGrads[i], maxGradNorm)
321 | clipGradients(valueGrads[i], maxGradNorm)
322 | for j := 0; j < m.AttnDim; j++ {
323 | m.QueryWeights[i][j] -= currentLR * queryGrads[i][j]
324 | m.KeyWeights[i][j] -= currentLR * keyGrads[i][j]
325 | m.ValueWeights[i][j] -= currentLR * valueGrads[i][j]
326 | }
327 | }
328 |
329 | totalLoss += batchLoss
330 | correct += batchCorrect
331 | totalExamples += len(batch.Embeddings)
332 |
333 | // Print batch progress
334 | if (batchIdx+1) % 10 == 0 {
335 | fmt.Printf("Epoch %d, Batch %d/%d, Loss: %.4f, Accuracy: %.2f%%\n",
336 | epoch+1, batchIdx+1, len(batches),
337 | batchLoss/float64(len(batch.Embeddings)),
338 | float64(batchCorrect)*100.0/float64(len(batch.Embeddings)))
339 | }
340 | }
341 |
342 | trainLoss := totalLoss / float64(totalExamples)
343 | trainAcc := float64(correct) * 100.0 / float64(totalExamples)
344 |
345 | // Validation phase
346 | valLoss, valAcc := m.evaluate(valExamples)
347 |
348 | fmt.Printf("Epoch %d complete:\n", epoch+1)
349 | fmt.Printf(" Training - Loss: %.4f, Accuracy: %.2f%%\n", trainLoss, trainAcc)
350 | fmt.Printf(" Validation - Loss: %.4f, Accuracy: %.2f%%\n", valLoss, valAcc)
351 |
352 | // Learning rate decay and early stopping
353 | if valAcc > bestValAcc {
354 | bestValAcc = valAcc
355 | noImprovementCount = 0
356 | } else {
357 | noImprovementCount++
358 | if noImprovementCount >= 2 {
359 | currentLR *= lrDecay
360 | fmt.Printf(" Reducing learning rate to %.6f\n", currentLR)
361 | noImprovementCount = 0
362 | }
363 | }
364 |
365 | if currentLR < learningRate * 0.01 {
366 | fmt.Println("Learning rate too small, stopping training")
367 | break
368 | }
369 | }
370 |
371 | return nil
372 | }
373 |
374 | // project applies weight matrix to input
375 | func (m *Model) project(input attention.Matrix, weights attention.Matrix) attention.Matrix {
376 | output := make(attention.Matrix, len(input))
377 | for i := range output {
378 | output[i] = make(attention.Vector, m.AttnDim)
379 | for j := 0; j < m.AttnDim; j++ {
380 | for k := 0; k < m.EmbedDim && k < len(input[i]); k++ {
381 | output[i][j] += input[i][k] * weights[k][j]
382 | }
383 | }
384 | }
385 | return output
386 | }
387 |
388 | // evaluate computes loss and accuracy on a dataset
389 | func (m *Model) evaluate(examples []TrainingExample) (float64, float64) {
390 | totalLoss := 0.0
391 | correct := 0
392 |
393 | for _, example := range examples {
394 | prediction := m.Predict(example.Text)
395 |
396 | loss := -(example.Sentiment*math.Log(prediction+1e-10) +
397 | (1-example.Sentiment)*math.Log(1-prediction+1e-10))
398 | totalLoss += loss
399 |
400 | predictedClass := 0.0
401 | if prediction > 0.5 {
402 | predictedClass = 1.0
403 | }
404 | if predictedClass == example.Sentiment {
405 | correct++
406 | }
407 | }
408 |
409 | return totalLoss / float64(len(examples)),
410 | float64(correct) * 100.0 / float64(len(examples))
411 | }
412 |
413 | // Predict makes a prediction on new text
414 | func (m *Model) Predict(text string) float64 {
415 | // Tokenize
416 | words := strings.Fields(strings.ToLower(text))
417 | if len(words) > 100 {
418 | words = words[:100]
419 | }
420 | if len(words) == 0 {
421 | return 0.5 // Default prediction for empty text
422 | }
423 |
424 | // Get embeddings
425 | embeddings := make(attention.Matrix, len(words))
426 | for i, word := range words {
427 | wordIdx := hashWord(word) % m.VocabSize
428 | embeddings[i] = make(attention.Vector, m.EmbedDim)
429 | copy(embeddings[i], m.WordEmbedding[wordIdx])
430 | }
431 |
432 | // Apply attention
433 | queries := m.project(embeddings, m.QueryWeights)
434 | keys := m.project(embeddings, m.KeyWeights)
435 | values := m.project(embeddings, m.ValueWeights)
436 |
437 | // Compute attention scores
438 | scores := make([]float64, len(queries))
439 | maxScore := -math.MaxFloat64
440 | for j := range queries {
441 | scores[j] = 0
442 | for k := 0; k < m.AttnDim; k++ {
443 | scores[j] += queries[j][k] * keys[j][k]
444 | }
445 | scores[j] /= math.Sqrt(float64(m.AttnDim))
446 | if scores[j] > maxScore {
447 | maxScore = scores[j]
448 | }
449 | }
450 |
451 | // Softmax
452 | sumExp := 0.0
453 | for j := range scores {
454 | scores[j] = math.Exp(scores[j] - maxScore)
455 | sumExp += scores[j]
456 | }
457 | if sumExp > 0 {
458 | for j := range scores {
459 | scores[j] /= sumExp
460 | }
461 | } else {
462 | // If all scores are very negative, use uniform attention
463 | for j := range scores {
464 | scores[j] = 1.0 / float64(len(scores))
465 | }
466 | }
467 |
468 | // Weighted sum
469 | context := make(attention.Vector, m.EmbedDim)
470 | for j := range context {
471 | for k := range scores {
472 | for d := 0; d < m.AttnDim; d++ {
473 | context[j] += scores[k] * values[k][d] * m.ValueWeights[j][d]
474 | }
475 | }
476 | }
477 |
478 | // Final prediction
479 | logit := dotProduct(context, m.OutputWeights)
480 | return sigmoid(logit)
481 | }
482 |
483 | // SaveModel saves the model to a file
484 | func (m *Model) SaveModel(filename string) error {
485 | file, err := os.Create(filename)
486 | if err != nil {
487 | return fmt.Errorf("failed to create file: %v", err)
488 | }
489 | defer file.Close()
490 |
491 | encoder := json.NewEncoder(file)
492 | if err := encoder.Encode(m); err != nil {
493 | return fmt.Errorf("failed to encode model: %v", err)
494 | }
495 | return nil
496 | }
497 |
498 | // LoadModel loads a model from a file
499 | func LoadModel(filename string) (*Model, error) {
500 | file, err := os.Open(filename)
501 | if err != nil {
502 | return nil, fmt.Errorf("failed to open file: %v", err)
503 | }
504 | defer file.Close()
505 |
506 | var model Model
507 | decoder := json.NewDecoder(file)
508 | if err := decoder.Decode(&model); err != nil {
509 | return nil, fmt.Errorf("failed to decode model: %v", err)
510 | }
511 |
512 | return &model, nil
513 | }
514 |
515 | // loadIMDBData loads the IMDB review dataset
516 | func loadIMDBData(filename string) ([]TrainingExample, error) {
517 | data, err := os.ReadFile(filename)
518 | if err != nil {
519 | return nil, fmt.Errorf("failed to read file: %v", err)
520 | }
521 |
522 | // Split into lines and parse each line as a separate JSON object
523 | lines := strings.Split(string(data), "\n")
524 | examples := make([]TrainingExample, 0, len(lines))
525 |
526 | for _, line := range lines {
527 | if len(strings.TrimSpace(line)) == 0 {
528 | continue
529 | }
530 |
531 | var review IMDBReview
532 | if err := json.Unmarshal([]byte(line), &review); err != nil {
533 | return nil, fmt.Errorf("failed to parse JSON line: %v", err)
534 | }
535 |
536 | // Clean the text by removing HTML line-break tags
537 | text := strings.ReplaceAll(review.Text, "<br />", " ")
538 | text = strings.ReplaceAll(text, "<br>", " ")
539 | text = strings.ReplaceAll(text, "\\/", "/")
540 |
541 | examples = append(examples, TrainingExample{
542 | Text: text,
543 | Sentiment: review.Label,
544 | })
545 | }
546 |
547 | return examples, nil
548 | }
549 |
550 | func main() {
551 | // Load IMDB dataset
552 | examples, err := loadIMDBData("data/imdb_1000.json")
553 | if err != nil {
554 | log.Fatalf("Failed to load IMDB data: %v", err)
555 | }
556 | fmt.Printf("Loaded %d examples from IMDB dataset\n", len(examples))
557 |
558 | // Create and train model
559 | model := NewModel(10000, 64, 32) // vocab size: 10000, embed dim: 64, attention dim: 32
560 |
561 | fmt.Println("Training model...")
562 | if err := model.Train(examples, 10, 0.01); err != nil { // 10 epochs, learning rate 0.01
563 | log.Fatalf("Training failed: %v", err)
564 | }
565 |
566 | // Save the model
567 | if err := model.SaveModel("sentiment_model.json"); err != nil {
568 | log.Fatalf("Failed to save model: %v", err)
569 | }
570 | fmt.Println("Model saved to sentiment_model.json")
571 |
572 | // Test the model
573 | testTexts := []string{
574 | "This movie was absolutely fantastic! The acting was superb.",
575 | "Terrible waste of time. The worst movie I've ever seen.",
576 | "An okay film, nothing special but decent entertainment.",
577 | }
578 |
579 | fmt.Println("\nTesting model predictions:")
580 | for _, text := range testTexts {
581 | prediction := model.Predict(text)
582 | fmt.Printf("Text: %s\nSentiment: %.2f\n\n", text, prediction)
583 | }
584 | }
--------------------------------------------------------------------------------
/examples/serverless_search.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "encoding/json"
5 | "log"
6 | "net/http"
7 | "sort"
8 | "strings"
9 | "sync"
10 | "time"
11 | "github.com/takara-ai/go-attention/attention"
12 | )
13 |
14 | // SearchRequest represents the incoming search query
15 | type SearchRequest struct {
16 | Query string `json:"query"`
17 | Docs []string `json:"documents"`
18 | }
19 |
20 | // SearchResult represents a single matched document
21 | type SearchResult struct {
22 | Text string `json:"text"`
23 | Score float64 `json:"score"`
24 | Rank int `json:"rank"`
25 | }
26 |
27 | // SearchResponse represents the search results
28 | type SearchResponse struct {
29 | Results []SearchResult `json:"results"`
30 | Timing float64 `json:"timing_ms"`
31 | }
32 |
33 | // Singleton pattern for embedder to avoid reinitialization
34 | var (
35 | embedder *SemanticEmbedder
36 | embedderOnce sync.Once
37 | )
38 |
39 | // SemanticEmbedder handles document embedding with attention
40 | type SemanticEmbedder struct {
41 | dimension int
42 | }
43 |
44 | // Simple semantic embedding simulation
45 | func (e *SemanticEmbedder) embedWord(word string) attention.Vector {
46 | word = strings.ToLower(word)
47 | embedding := make(attention.Vector, e.dimension)
48 |
49 | // Animal-related features
50 | if strings.Contains(word, "cat") || strings.Contains(word, "kitten") {
51 | embedding[0] = 1.0 // feline
52 | embedding[1] = 0.8 // pet
53 | embedding[2] = 0.6 // animal
54 | }
55 | if strings.Contains(word, "dog") || strings.Contains(word, "canine") {
56 | embedding[0] = 0.8 // animal
57 | embedding[1] = 0.8 // pet
58 | embedding[3] = 1.0 // canine
59 | }
60 |
61 | // Activity-related features
62 | if strings.Contains(word, "play") || strings.Contains(word, "chase") {
63 | embedding[4] = 1.0 // activity
64 | embedding[5] = 0.7 // movement
65 | }
66 | if strings.Contains(word, "sat") || strings.Contains(word, "love") {
67 | embedding[6] = 0.6 // state
68 | }
69 |
70 | // Tech-related features
71 | if strings.Contains(word, "ai") || strings.Contains(word, "artificial") ||
72 | strings.Contains(word, "intelligence") {
73 | embedding[8] = 1.0 // AI
74 | embedding[9] = 0.8 // technology
75 | embedding[10] = 0.7 // computing
76 | }
77 | if strings.Contains(word, "machine") || strings.Contains(word, "learning") {
78 | embedding[8] = 0.8 // AI
79 | embedding[9] = 0.9 // technology
80 | embedding[11] = 1.0 // learning
81 | }
82 | if strings.Contains(word, "neural") || strings.Contains(word, "network") {
83 | embedding[8] = 0.7 // AI
84 | embedding[9] = 0.8 // technology
85 | embedding[12] = 1.0 // networks
86 | }
87 |
88 | return embedding
89 | }
90 |
91 | // Embed a sentence into a fixed-size vector using attention
92 | func (e *SemanticEmbedder) embedSentence(sentence string) attention.Vector {
93 | words := strings.Fields(sentence)
94 | if len(words) == 0 {
95 | return make(attention.Vector, e.dimension)
96 | }
97 |
98 | // Create embeddings for each word
99 | wordEmbeddings := make(attention.Matrix, len(words))
100 | for i, word := range words {
101 | wordEmbeddings[i] = e.embedWord(word)
102 | }
103 |
104 | // Use attention to combine word embeddings
105 | output, _, err := attention.DotProductAttention(wordEmbeddings[0], wordEmbeddings, wordEmbeddings)
106 | if err != nil {
107 | log.Printf("Error in embedSentence: %v", err)
108 | return make(attention.Vector, e.dimension)
109 | }
110 |
111 | return output
112 | }
113 |
114 | func getEmbedder() *SemanticEmbedder {
115 | embedderOnce.Do(func() {
116 | embedder = &SemanticEmbedder{
117 | dimension: 16,
118 | }
119 | })
120 | return embedder
121 | }
122 |
123 | // HandleSearch is the serverless entry point
124 | func HandleSearch(w http.ResponseWriter, r *http.Request) {
125 | startTime := time.Now()
126 |
127 | if r.Method != http.MethodPost {
128 | http.Error(w, "Only POST method is allowed", http.StatusMethodNotAllowed)
129 | return
130 | }
131 |
132 | // Parse request
133 | var req SearchRequest
134 | if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
135 | http.Error(w, "Invalid request format", http.StatusBadRequest)
136 | return
137 | }
138 |
139 | // Get singleton embedder (zero cold start penalty after first request)
140 | emb := getEmbedder()
141 |
142 | // Process query and documents in parallel
143 | queryEmbed := emb.embedSentence(req.Query)
144 |
145 | // Use goroutines for parallel document processing
146 | docCount := len(req.Docs)
147 | docEmbeddings := make(attention.Matrix, docCount)
148 | var wg sync.WaitGroup
149 | wg.Add(docCount)
150 |
151 | for i := range req.Docs {
152 | go func(idx int) {
153 | defer wg.Done()
154 | docEmbeddings[idx] = emb.embedSentence(req.Docs[idx])
155 | }(i)
156 | }
157 | wg.Wait()
158 |
159 | // Compute attention scores
160 | _, weights, err := attention.DotProductAttention(queryEmbed, docEmbeddings, docEmbeddings)
161 | if err != nil {
162 | http.Error(w, "Processing error", http.StatusInternalServerError)
163 | return
164 | }
165 |
166 | // Prepare results
167 | results := make([]SearchResult, len(weights))
168 | for i := range weights {
169 | results[i] = SearchResult{
170 | Text: req.Docs[i],
171 | Score: weights[i] * 100, // Convert to percentage
172 | Rank: i + 1,
173 | }
174 | }
175 |
176 | // Sort results by descending score
177 | sort.Slice(results, func(i, j int) bool {
178 | return results[i].Score > results[j].Score
179 | })
180 |
181 | // Return top results
182 | response := SearchResponse{
183 | Results: results[:min(3, len(results))],
184 | Timing: float64(time.Since(startTime).Microseconds()) / 1000.0, // Convert to milliseconds
185 | }
186 |
187 | w.Header().Set("Content-Type", "application/json")
188 | json.NewEncoder(w).Encode(response)
189 | }
190 |
191 | func min(a, b int) int {
192 | if a < b {
193 | return a
194 | }
195 | return b
196 | }
197 |
198 | // For local testing
199 | func main() {
200 | log.Printf("Starting semantic search server on :8080...")
201 | http.HandleFunc("/search", HandleSearch)
202 | log.Fatal(http.ListenAndServe(":8080", nil))
203 | }
--------------------------------------------------------------------------------
/examples/transformer_example.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "log"
6 | "math/rand"
7 | "time"
8 |
9 | "github.com/takara-ai/go-attention/attention"
10 | "github.com/takara-ai/go-attention/transformer"
11 | )
12 |
13 | func main() {
14 | // Seed random number generator
15 | rand.Seed(time.Now().UnixNano())
16 |
17 | // Create a transformer layer configuration
18 | config := transformer.TransformerConfig{
19 | DModel: 64, // Model dimension
20 | NumHeads: 4, // Number of attention heads
21 | DHidden: 256, // Hidden dimension in feed-forward network
22 | DropoutRate: 0.1, // Dropout rate (not used in this implementation)
23 | }
24 |
25 | // Create a transformer layer
26 | layer, err := transformer.NewTransformerLayer(config)
27 | if err != nil {
28 | log.Fatalf("Failed to create transformer layer: %v", err)
29 | }
30 |
31 | // Create sample input sequence
32 | // In this example, we'll create a sequence of 3 tokens, each with dimension DModel
33 | input := make(attention.Matrix, 3)
34 | for i := range input {
35 | input[i] = make(attention.Vector, config.DModel)
36 | // Initialize with random values
37 | for j := range input[i] {
38 | input[i][j] = rand.Float64()*2 - 1
39 | }
40 | }
41 |
42 | // Process the sequence through the transformer layer
43 | output, err := layer.Forward(input)
44 | if err != nil {
45 | log.Fatalf("Failed to process sequence: %v", err)
46 | }
47 |
48 | // Print results
49 | fmt.Println("Input sequence:")
50 | for i, vec := range input {
51 | fmt.Printf("Token %d: First 4 values: %v\n", i, vec[:4])
52 | }
53 |
54 | fmt.Println("\nTransformed sequence:")
55 | for i, vec := range output {
56 | fmt.Printf("Token %d: First 4 values: %v\n", i, vec[:4])
57 | }
58 |
59 | // Demonstrate multi-head attention separately
60 | fmt.Println("\nDemonstrating Multi-Head Attention:")
61 |
62 | // Create input vectors with correct dimensions
63 | batchSize := 2
64 |
65 | queries := make(attention.Matrix, batchSize)
66 | keys := make(attention.Matrix, batchSize)
67 | values := make(attention.Matrix, batchSize)
68 |
69 | // Initialize with some random values
70 | for b := 0; b < batchSize; b++ {
71 | // Create query sequence
72 | queries[b] = make(attention.Vector, config.DModel)
73 | for j := range queries[b] {
74 | queries[b][j] = rand.Float64()*2 - 1
75 | }
76 |
77 | // Create key sequence
78 | keys[b] = make(attention.Vector, config.DModel)
79 | for j := range keys[b] {
80 | keys[b][j] = rand.Float64()*2 - 1
81 | }
82 |
83 | // Create value sequence
84 | values[b] = make(attention.Vector, config.DModel)
85 | for j := range values[b] {
86 | values[b][j] = rand.Float64()*2 - 1
87 | }
88 | }
89 |
90 | // Create multi-head attention
91 | mha, err := attention.NewMultiHeadAttention(attention.MultiHeadConfig{
92 | NumHeads: 4,
93 | DModel: config.DModel,
94 | DKey: config.DModel / 4,
95 | DValue: config.DModel / 4,
96 | DropoutRate: 0.1,
97 | })
98 | if err != nil {
99 | log.Fatalf("Failed to create multi-head attention: %v", err)
100 | }
101 |
102 | // Process through multi-head attention
103 | attended, err := mha.Forward(queries, keys, values)
104 | if err != nil {
105 | log.Fatalf("Failed to compute multi-head attention: %v", err)
106 | }
107 |
108 | fmt.Println("\nMulti-Head Attention outputs (first 4 values for each batch):")
109 | for b := range attended {
110 | fmt.Printf("Batch %d: %v\n", b, attended[b][:4])
111 | }
112 | }
--------------------------------------------------------------------------------
/examples/vector_search.go:
--------------------------------------------------------------------------------
1 | // vector_search.go
2 | package main
3 |
4 | import (
5 | "fmt"
6 | "log"
7 | "os"
8 | "sort"
9 | "strconv"
10 | "strings"
11 |
12 | "github.com/takara-ai/go-attention/attention"
13 | )
14 |
15 | // Document holds a title, some content, and an embedding vector.
16 | type Document struct {
17 | Title string
18 | Content string
19 | Embedding attention.Vector
20 | }
21 |
22 | // parseVector converts a comma-separated string (e.g. "1.0,0.0,1.0,0.0")
23 | // into an attention.Vector.
24 | func parseVector(s string) (attention.Vector, error) {
25 | parts := strings.Split(s, ",")
26 | vec := make(attention.Vector, len(parts))
27 | for i, p := range parts {
28 | f, err := strconv.ParseFloat(strings.TrimSpace(p), 64)
29 | if err != nil {
30 | return nil, err
31 | }
32 | vec[i] = f
33 | }
34 | return vec, nil
35 | }
36 |
37 | func main() {
38 | // Define a small "database" of documents with 4-dimensional embeddings.
39 | documents := []Document{
40 | {
41 | Title: "Cats",
42 | Content: "Cats are small, carnivorous mammals that are often valued by humans for companionship.",
43 | Embedding: attention.Vector{1.0, 0.0, 1.0, 0.0},
44 | },
45 | {
46 | Title: "Dogs",
47 | Content: "Dogs are domesticated mammals, known for their loyalty and companionship with humans.",
48 | Embedding: attention.Vector{0.0, 1.0, 0.0, 1.0},
49 | },
50 | {
51 | Title: "Neutral",
52 | Content: "This document does not lean toward any particular subject.",
53 | Embedding: attention.Vector{0.5, 0.5, 0.5, 0.5},
54 | },
55 | {
56 | Title: "Birds",
57 | Content: "Birds are warm-blooded vertebrates characterized by feathers and beaks.",
58 | Embedding: attention.Vector{1.0, 1.0, 0.0, 0.0},
59 | },
60 | }
61 |
62 | 	// Default query vector (for example, looking for cat-like features).
63 | defaultQuery := attention.Vector{1.0, 0.0, 1.0, 0.0}
64 |
65 | // If the user provides a query vector as the first command-line argument,
66 | // parse it. Otherwise, use the default.
67 | var query attention.Vector
68 | var err error
69 | if len(os.Args) > 1 {
70 | query, err = parseVector(os.Args[1])
71 | if err != nil {
72 | log.Fatalf("Error parsing query vector: %v", err)
73 | }
74 | } else {
75 | query = defaultQuery
76 | }
77 |
78 | // Ensure that the query vector dimension matches our document embeddings.
79 | if len(query) != len(documents[0].Embedding) {
80 | log.Fatalf("Query vector dimension (%d) does not match document embedding dimension (%d)",
81 | len(query), len(documents[0].Embedding))
82 | }
83 |
84 | // Build the keys and values matrices from the document embeddings.
85 | // In this simple example, we use the embeddings for both keys and values.
86 | keys := make(attention.Matrix, len(documents))
87 | values := make(attention.Matrix, len(documents))
88 | for i, doc := range documents {
89 | keys[i] = doc.Embedding
90 | values[i] = doc.Embedding
91 | }
92 |
93 | // Compute dot-product attention.
94 | // The returned 'weights' slice contains the attention weight for each document.
95 | _, weights, err := attention.DotProductAttention(query, keys, values)
96 | if err != nil {
97 | log.Fatalf("Error computing attention: %v", err)
98 | }
99 |
100 | // Create a slice that holds each document's index and its corresponding attention score.
101 | type docScore struct {
102 | index int
103 | score float64
104 | }
105 | scores := make([]docScore, len(documents))
106 | for i, w := range weights {
107 | scores[i] = docScore{index: i, score: w}
108 | }
109 |
110 | // Sort documents by descending attention weight (i.e. relevance).
111 | sort.Slice(scores, func(i, j int) bool {
112 | return scores[i].score > scores[j].score
113 | })
114 |
115 | // Print the query and the documents ranked by their relevance.
116 | fmt.Println("Query Vector:", query)
117 | fmt.Println("\nDocument Relevance Scores:")
118 | for _, ds := range scores {
119 | doc := documents[ds.index]
120 | fmt.Printf("Title: %s\nScore: %.3f\nContent: %s\n\n", doc.Title, ds.score, doc.Content)
121 | }
122 | }
123 |
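124 | // Worked example for the default query {1, 0, 1, 0} (illustrative only,
125 | // assuming DotProductAttention scores documents as softmax(q.k/sqrt(d)) with d = 4):
126 | //
127 | //	raw dot products:  Cats = 2.0, Dogs = 0.0, Neutral = 1.0, Birds = 1.0
128 | //	scaled by sqrt(4): 1.0, 0.0, 0.5, 0.5
129 | //	softmax weights:   ~0.387, ~0.143, ~0.235, ~0.235
130 | //
131 | // "Cats" therefore ranks first. If the implementation skips the sqrt(d)
132 | // scaling, the ordering is unchanged but the weights are sharper
133 | // (~0.535, ~0.072, ~0.197, ~0.197).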
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/takara-ai/go-attention
2 |
3 | go 1.23.4
4 |
--------------------------------------------------------------------------------
/transformer/transformer.go:
--------------------------------------------------------------------------------
1 | package transformer
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "math/rand"
7 |
8 | "github.com/takara-ai/go-attention/attention"
9 | )
10 |
11 | // LayerNorm implements layer normalization
12 | type LayerNorm struct {
13 | Dim int
14 | Eps float64
15 | Gamma attention.Vector // Scale parameter
16 | Beta attention.Vector // Shift parameter
17 | }
18 |
19 | // NewLayerNorm creates a new layer normalization module
20 | func NewLayerNorm(dim int, eps float64) *LayerNorm {
21 | gamma := make(attention.Vector, dim)
22 | beta := make(attention.Vector, dim)
23 |
24 | // Initialize parameters
25 | for i := range gamma {
26 | gamma[i] = 1.0 // Initialize scale to 1
27 | beta[i] = 0.0 // Initialize shift to 0
28 | }
29 |
30 | return &LayerNorm{
31 | Dim: dim,
32 | Eps: eps,
33 | Gamma: gamma,
34 | Beta: beta,
35 | }
36 | }
37 |
38 | // Forward applies layer normalization
39 | func (ln *LayerNorm) Forward(input attention.Matrix) (attention.Matrix, error) {
40 | output := make(attention.Matrix, len(input))
41 |
42 | for i, vec := range input {
43 | if len(vec) != ln.Dim {
44 | return nil, fmt.Errorf("input dimension mismatch: expected %d, got %d", ln.Dim, len(vec))
45 | }
46 |
47 | // Compute mean
48 | mean := 0.0
49 | for _, v := range vec {
50 | mean += v
51 | }
52 | mean /= float64(ln.Dim)
53 |
54 | // Compute variance
55 | variance := 0.0
56 | for _, v := range vec {
57 | diff := v - mean
58 | variance += diff * diff
59 | }
60 | variance /= float64(ln.Dim)
61 |
62 | // Normalize
63 | normalized := make(attention.Vector, ln.Dim)
64 | stdDev := math.Sqrt(variance + ln.Eps)
65 | for j, v := range vec {
66 | normalized[j] = ln.Gamma[j]*((v-mean)/stdDev) + ln.Beta[j]
67 | }
68 |
69 | output[i] = normalized
70 | }
71 |
72 | return output, nil
73 | }
74 |
75 | // FeedForward implements a position-wise feed-forward network (Linear -> ReLU -> Linear)
76 | type FeedForward struct {
77 | DModel int
78 | DHidden int
79 | W1, W2 attention.Matrix
80 | B1, B2 attention.Vector
81 | }
82 |
83 | // NewFeedForward creates a new feed-forward network
84 | func NewFeedForward(dModel, dHidden int) *FeedForward {
85 | ff := &FeedForward{
86 | DModel: dModel,
87 | DHidden: dHidden,
88 | W1: make(attention.Matrix, dModel),
89 | W2: make(attention.Matrix, dHidden),
90 | B1: make(attention.Vector, dHidden),
91 | B2: make(attention.Vector, dModel),
92 | }
93 |
94 | 	// Initialize weights with Xavier-style scaling (uniform in [-scale/2, scale/2))
95 | scale1 := math.Sqrt(2.0 / float64(dModel+dHidden))
96 | scale2 := math.Sqrt(2.0 / float64(dHidden+dModel))
97 |
98 | for i := range ff.W1 {
99 | ff.W1[i] = make(attention.Vector, dHidden)
100 | for j := range ff.W1[i] {
101 | ff.W1[i][j] = (rand.Float64() - 0.5) * scale1
102 | }
103 | }
104 |
105 | for i := range ff.W2 {
106 | ff.W2[i] = make(attention.Vector, dModel)
107 | for j := range ff.W2[i] {
108 | ff.W2[i][j] = (rand.Float64() - 0.5) * scale2
109 | }
110 | }
111 |
112 | return ff
113 | }
114 |
115 | // Forward applies the feed-forward network
116 | func (ff *FeedForward) Forward(input attention.Matrix) (attention.Matrix, error) {
117 | // First layer
118 | hidden := make(attention.Matrix, len(input))
119 | for i, vec := range input {
120 | projected, err := projectVector(vec, ff.W1)
121 | if err != nil {
122 | return nil, fmt.Errorf("projecting first layer: %w", err)
123 | }
124 |
125 | 		// Add bias and apply ReLU (entries not set keep the zero default)
126 | hidden[i] = make(attention.Vector, ff.DHidden)
127 | for j := range projected {
128 | val := projected[j] + ff.B1[j]
129 | if val > 0 {
130 | hidden[i][j] = val
131 | }
132 | }
133 | }
134 |
135 | // Second layer
136 | output := make(attention.Matrix, len(input))
137 | for i, vec := range hidden {
138 | projected, err := projectVector(vec, ff.W2)
139 | if err != nil {
140 | return nil, fmt.Errorf("projecting second layer: %w", err)
141 | }
142 |
143 | // Add bias
144 | output[i] = make(attention.Vector, ff.DModel)
145 | for j := range projected {
146 | output[i][j] = projected[j] + ff.B2[j]
147 | }
148 | }
149 |
150 | return output, nil
151 | }
152 |
153 | // projectVector multiplies a row vector by a weight matrix (output = input * W); W must have len(input) rows
154 | func projectVector(input attention.Vector, weights attention.Matrix) (attention.Vector, error) {
155 | if len(weights) == 0 || len(weights[0]) == 0 {
156 | return nil, fmt.Errorf("empty weight matrix")
157 | }
158 | if len(input) != len(weights) {
159 | 		return nil, fmt.Errorf("input dimension (%d) does not match weight matrix rows (%d)", len(input), len(weights))
160 | }
161 |
162 | output := make(attention.Vector, len(weights[0]))
163 | for i := range weights[0] {
164 | for j, w := range weights {
165 | output[i] += input[j] * w[i]
166 | }
167 | }
168 | return output, nil
169 | }
170 |
171 | // TransformerConfig holds configuration for a transformer layer
172 | type TransformerConfig struct {
173 | DModel int // Model dimension
174 | NumHeads int // Number of attention heads
175 | DHidden int // Hidden dimension in feed-forward network
176 | DropoutRate float64 // Dropout rate
177 | }
178 |
179 | // TransformerLayer implements a single transformer layer
180 | type TransformerLayer struct {
181 | Config TransformerConfig
182 | SelfAttn *attention.MultiHeadAttention
183 | FeedForward *FeedForward
184 | Norm1 *LayerNorm
185 | Norm2 *LayerNorm
186 | }
187 |
188 | // NewTransformerLayer creates a new transformer layer
189 | func NewTransformerLayer(config TransformerConfig) (*TransformerLayer, error) {
190 | // Create multi-head attention
191 | attnConfig := attention.MultiHeadConfig{
192 | NumHeads: config.NumHeads,
193 | DModel: config.DModel,
194 | 		DKey:        config.DModel / config.NumHeads, // per-head key dim; assumes DModel is divisible by NumHeads
195 | 		DValue:      config.DModel / config.NumHeads, // per-head value dim
196 | DropoutRate: config.DropoutRate,
197 | }
198 |
199 | selfAttn, err := attention.NewMultiHeadAttention(attnConfig)
200 | if err != nil {
201 | return nil, fmt.Errorf("creating self-attention: %w", err)
202 | }
203 |
204 | return &TransformerLayer{
205 | Config: config,
206 | SelfAttn: selfAttn,
207 | FeedForward: NewFeedForward(config.DModel, config.DHidden),
208 | Norm1: NewLayerNorm(config.DModel, 1e-5),
209 | Norm2: NewLayerNorm(config.DModel, 1e-5),
210 | }, nil
211 | }
212 |
213 | // Forward applies the transformer layer using pre-norm residuals (LayerNorm before each sub-layer)
214 | func (t *TransformerLayer) Forward(input attention.Matrix) (attention.Matrix, error) {
215 | // Self-attention sub-layer
216 | normalized1, err := t.Norm1.Forward(input)
217 | if err != nil {
218 | return nil, fmt.Errorf("normalizing input: %w", err)
219 | }
220 |
221 | attended, err := t.SelfAttn.Forward(normalized1, normalized1, normalized1)
222 | if err != nil {
223 | return nil, fmt.Errorf("computing self-attention: %w", err)
224 | }
225 |
226 | 	// Residual connection around the self-attention sub-layer
227 | residual1 := make(attention.Matrix, len(input))
228 | for i := range input {
229 | residual1[i], err = attention.AddVectors(input[i], attended[i])
230 | if err != nil {
231 | return nil, fmt.Errorf("adding residual connection: %w", err)
232 | }
233 | }
234 |
235 | // Feed-forward sub-layer
236 | normalized2, err := t.Norm2.Forward(residual1)
237 | if err != nil {
238 | return nil, fmt.Errorf("normalizing first sub-layer output: %w", err)
239 | }
240 |
241 | ffOutput, err := t.FeedForward.Forward(normalized2)
242 | if err != nil {
243 | return nil, fmt.Errorf("computing feed-forward: %w", err)
244 | }
245 |
246 | 	// Residual connection around the feed-forward sub-layer
247 | output := make(attention.Matrix, len(input))
248 | for i := range input {
249 | output[i], err = attention.AddVectors(residual1[i], ffOutput[i])
250 | if err != nil {
251 | return nil, fmt.Errorf("adding final residual connection: %w", err)
252 | }
253 | }
254 |
255 | return output, nil
256 | }
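257 |
258 | // Worked example of LayerNorm.Forward on a single row (illustrative only):
259 | // for input {1, 2, 3, 4} with the default Gamma = 1 and Beta = 0,
260 | //
261 | //	mean     = (1+2+3+4)/4             = 2.5
262 | //	variance = (2.25+0.25+0.25+2.25)/4 = 1.25
263 | //	stdDev   = sqrt(1.25 + eps)        ~ 1.118
264 | //	output   ~ {-1.342, -0.447, 0.447, 1.342}
265 | //
266 | // Each row is rescaled to zero mean and unit variance before the learned
267 | // scale (Gamma) and shift (Beta) are applied.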
--------------------------------------------------------------------------------