├── .gitignore ├── LICENSE ├── README.md ├── attention ├── attention.go └── multihead.go ├── examples ├── api_examples.go ├── data │ └── imdb_1000.json ├── sentiment_trainer.go ├── serverless_search.go ├── transformer_example.go └── vector_search.go ├── go.mod └── transformer └── transformer.go /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | go.work.sum 23 | 24 | # env file 25 | .env 26 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Takara AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-attention 2 | 3 | Takara.ai Logo 4 | 5 | From the Frontier Research Team at takara.ai, we present the first pure Go implementation of attention mechanisms and transformer layers, designed for high performance and ease of use. 6 | 7 | ## Quick Start 8 | 9 | Run our comprehensive examples: 10 | 11 | ```bash 12 | # Get the module 13 | go get github.com/takara-ai/go-attention 14 | 15 | # Run the examples (from the repository root) 16 | go run examples/api_examples.go 17 | ``` 18 | 19 | ## API Documentation 20 | 21 | ### Core Types 22 | 23 | ```go 24 | type Vector []float64 // Represents a 1D vector of float64 values 25 | type Matrix []Vector // Represents a 2D matrix of float64 values 26 | ``` 27 | 28 | ### 1. Basic Dot-Product Attention 29 | 30 | The simplest form of attention mechanism. Useful for basic sequence processing tasks.
31 | 32 | ```go 33 | import "github.com/takara-ai/go-attention/attention" 34 | 35 | // Create query-key-value setup 36 | query := attention.Vector{1.0, 0.0, 1.0, 0.0} // Pattern to search for 37 | keys := attention.Matrix{ 38 | {1.0, 0.0, 1.0, 0.0}, // Similar to query 39 | {0.0, 1.0, 0.0, 1.0}, // Different from query 40 | {0.5, 0.5, 0.5, 0.5}, // Neutral pattern 41 | } 42 | values := attention.Matrix{ 43 | {1.0, 2.0}, // Value for similar key 44 | {3.0, 4.0}, // Value for different key 45 | {5.0, 6.0}, // Value for neutral key 46 | } 47 | 48 | // Compute attention 49 | output, weights, err := attention.DotProductAttention(query, keys, values) 50 | if err != nil { 51 | log.Fatal(err) 52 | } 53 | 54 | // Output will be a weighted combination of values based on query-key similarity 55 | // Weights will show how much attention each key received 56 | ``` 57 | 58 | ### 2. Multi-Head Attention 59 | 60 | A more sophisticated attention mechanism that can capture different types of relationships in parallel. 61 | 62 | ```go 63 | import "github.com/takara-ai/go-attention/attention" 64 | 65 | // Configure multi-head attention 66 | config := attention.MultiHeadConfig{ 67 | NumHeads: 4, // Number of parallel attention heads 68 | DModel: 64, // Size of input/output embeddings 69 | DKey: 16, // Size per head (DModel/NumHeads) 70 | DValue: 16, // Size per head (DModel/NumHeads) 71 | DropoutRate: 0.1, // Currently unused (dropout is not implemented yet) 72 | } 73 | 74 | // Create the attention module 75 | mha, err := attention.NewMultiHeadAttention(config) 76 | if err != nil { 77 | log.Fatal(err) 78 | } 79 | 80 | // Process sequences (batched input) 81 | batchSize, seqLen := 2, 3 // Process 2 sequences, each with 3 tokens 82 | 83 | // Create input matrices [batchSize*seqLen × DModel] (sequences flattened row-wise) 84 | queries := make(attention.Matrix, batchSize*seqLen) 85 | keys := make(attention.Matrix, batchSize*seqLen) 86 | values := make(attention.Matrix, batchSize*seqLen) 87 | 88 | // Initialize your matrices with actual data...
89 | 90 | // Process through multi-head attention 91 | output, err := mha.Forward(queries, keys, values) 92 | if err != nil { 93 | log.Fatal(err) 94 | } 95 | ``` 96 | 97 | ### 3. Full Transformer Layer 98 | 99 | Complete transformer layer with self-attention and feed-forward network. 100 | 101 | ```go 102 | import ( 103 | "github.com/takara-ai/go-attention/transformer" 104 | "github.com/takara-ai/go-attention/attention" 105 | ) 106 | 107 | // Configure transformer layer 108 | config := transformer.TransformerConfig{ 109 | DModel: 64, // Size of token embeddings 110 | NumHeads: 4, // Number of attention heads 111 | DHidden: 256, // Size of feed-forward hidden layer 112 | DropoutRate: 0.1, // For regularization 113 | } 114 | 115 | // Create transformer layer 116 | layer, err := transformer.NewTransformerLayer(config) 117 | if err != nil { 118 | log.Fatal(err) 119 | } 120 | 121 | // Create input sequence [seq_len × d_model] 122 | seqLen := 3 123 | input := make(attention.Matrix, seqLen) 124 | for i := range input { 125 | input[i] = make(attention.Vector, config.DModel) 126 | // Fill with your embedding data... 127 | } 128 | 129 | // Process through transformer 130 | output, err := layer.Forward(input) 131 | if err != nil { 132 | log.Fatal(err) 133 | } 134 | ``` 135 | 136 | ## Example Output 137 | 138 | When running the examples, you'll see: 139 | 140 | 1. **Dot-Product Attention**: 141 | 142 | ``` 143 | Query: [1 0 1 0] 144 | Attention Weights: [0.523 0.174 0.302] // Shows focus on similar patterns 145 | Output: [2.558 3.558] // Weighted combination of values 146 | ``` 147 | 148 | 2. **Multi-Head Attention**: 149 | 150 | ``` 151 | Input dimensions: [2 batches × 3 tokens × 64 features] 152 | Output shape: [6×64] 153 | ``` 154 | 155 | 3. **Transformer Layer**: 156 | ``` 157 | Input shape: [3×64] 158 | Output shape: [3×64] 159 | ``` 160 | 161 | ## Common Use Cases 162 | 163 | 1.
**Text Processing**: 164 | 165 | - Sequence-to-sequence translation 166 | - Document summarization 167 | - Sentiment analysis 168 | 169 | 2. **Time Series**: 170 | 171 | - Financial forecasting 172 | - Sensor data analysis 173 | - Anomaly detection 174 | 175 | 3. **Structured Data**: 176 | - Graph node embedding 177 | - Feature interaction modeling 178 | - Recommendation systems 179 | 180 | ## Performance Considerations 181 | 182 | - Matrix operations are optimized for CPU 183 | - Memory allocations are minimized 184 | - Batch processing for better throughput 185 | - No external dependencies 186 | 187 | For more detailed examples, see the `examples` directory in the repository. 188 | 189 | ## Why go-attention? 190 | 191 | This module was created to provide a clean, efficient, and dependency-free implementation of attention mechanisms in Go. It's particularly useful for: 192 | 193 | - **Edge Computing**: Zero external dependencies makes it perfect for edge devices where dependency management is crucial 194 | - **Real-time Processing**: Pure Go implementation ensures predictable performance for real-time applications 195 | - **Cloud-native Applications**: Efficient batched operations support high-throughput scaling in cloud environments 196 | - **Embedded Systems**: Predictable resource usage and minimal memory allocations 197 | - **Production Systems**: Comprehensive error handling and type safety for robust production deployments 198 | 199 | ## Features 200 | 201 | - Efficient dot-product attention mechanism (upgraded with Scalable-Softmax (SSMax, s=1) for improved long-context performance) 202 | - Multi-head attention support 203 | - Full transformer layer implementation with: 204 | - Layer normalization 205 | - Position-wise feed-forward networks 206 | - Residual connections 207 | - Batched operations for improved performance 208 | 209 | ## Roadmap 210 | 211 | Future improvements may include: 212 | 213 | - Positional encoding implementations 214 | - Dropout 
support
215 | - CUDA acceleration support
216 | - Additional transformer variants
217 | - Pre-trained models
218 | - Training utilities
219 | 
220 | ## Contributing
221 | 
222 | Contributions are welcome! Please feel free to submit a Pull Request.
223 | 
224 | ## License
225 | 
226 | MIT License - see LICENSE file for details
227 | 
228 | ---
229 | 
230 | For research inquiries and press, please reach out to research@takara.ai
231 | 
232 | > 人類を変革する
233 | 
--------------------------------------------------------------------------------
/attention/attention.go:
--------------------------------------------------------------------------------
1 | package attention
2 | 
3 | import (
4 | 	"fmt"
5 | 	"math"
6 | )
7 | 
8 | // Vector represents a slice of float64 values.
9 | type Vector []float64
10 | 
11 | // Matrix represents a 2D slice of float64 values, stored as a slice of rows.
12 | type Matrix []Vector
13 | 
14 | // DotProduct computes the dot product of two vectors.
15 | // It returns an error if the vectors have different lengths.
16 | func DotProduct(v1, v2 Vector) (float64, error) {
17 | 	if len(v1) != len(v2) {
18 | 		return 0, fmt.Errorf("vector dimensions mismatch: %d != %d", len(v1), len(v2))
19 | 	}
20 | 
21 | 	sum := 0.0
22 | 	for i := range v1 {
23 | 		sum += v1[i] * v2[i]
24 | 	}
25 | 	return sum, nil
26 | }
27 | 
28 | // Softmax applies the softmax function to x and returns a new vector.
29 | // An empty input yields an empty, non-nil vector.
30 | func Softmax(x Vector) Vector {
31 | 	if len(x) == 0 {
32 | 		return Vector{}
33 | 	}
34 | 
35 | 	// Subtract the maximum before exponentiating for numerical stability.
36 | 	maxVal := x[0]
37 | 	for _, v := range x[1:] {
38 | 		if v > maxVal {
39 | 			maxVal = v
40 | 		}
41 | 	}
42 | 
43 | 	exps := make(Vector, len(x))
44 | 	sumExp := 0.0
45 | 	for i, v := range x {
46 | 		exps[i] = math.Exp(v - maxVal)
47 | 		sumExp += exps[i]
48 | 	}
49 | 
50 | 	// Defensive guard against division by zero. For finite inputs the
51 | 	// maximum element contributes exp(0) = 1, so sumExp is at least 1 and
52 | 	// this branch is not taken; if it ever is, exps is returned unnormalized.
53 | 	if sumExp == 0 {
54 | 		return exps
55 | 	}
56 | 
57 | 	for i := range exps {
58 | 		exps[i] /= sumExp
59 | 	}
60 | 	return exps
61 | }
62 | 
63 | // ScaleVector returns a new vector containing v multiplied by scale.
64 | func ScaleVector(v Vector, scale float64) Vector {
65 | 	result := make(Vector, len(v))
66 | 	for i, val := range v {
67 | 		result[i] = val * scale
68 | 	}
69 | 	return result
70 | }
71 | 
72 | // AddVectors adds two vectors element-wise and returns a new vector.
73 | // It returns an error if the vectors have different lengths.
74 | func AddVectors(v1, v2 Vector) (Vector, error) {
75 | 	if len(v1) != len(v2) {
76 | 		return nil, fmt.Errorf("vector dimensions mismatch: %d != %d", len(v1), len(v2))
77 | 	}
78 | 
79 | 	result := make(Vector, len(v1))
80 | 	for i := range v1 {
81 | 		result[i] = v1[i] + v2[i]
82 | 	}
83 | 	return result, nil
84 | }
85 | 
86 | // DotProductAttention computes scaled dot-product attention for a single
87 | // query using Scalable-Softmax (SSMax, s=1): the 1/sqrt(d_k)-scaled scores
88 | // are multiplied by log(n), where n is the number of keys, before softmax.
89 | //
90 | // Shapes: query [d_k], keys [n, d_k], values [n, d_v].
91 | // Returns the attended vector [d_v] and the attention weights [n].
92 | //
93 | // When keys is empty but values has a non-empty first row, a zero vector of
94 | // length len(values[0]) and empty weights are returned. Otherwise an error is
95 | // returned for empty inputs, mismatched key/value counts, zero-length value
96 | // rows, or inconsistent query/key/value dimensions.
97 | func DotProductAttention(query Vector, keys, values Matrix) (Vector, Vector, error) {
98 | 	n := len(keys)
99 | 	if n == 0 {
100 | 		// Nothing to attend to: return a zero vector of the value dimension
101 | 		// when that dimension can be determined.
102 | 		if len(values) > 0 && len(values[0]) > 0 {
103 | 			return make(Vector, len(values[0])), Vector{}, nil
104 | 		}
105 | 		return nil, nil, fmt.Errorf("empty keys and values provided")
106 | 	}
107 | 
108 | 	if len(values) != n {
109 | 		return nil, nil, fmt.Errorf("number of keys (%d) must match number of values (%d)", n, len(values))
110 | 	}
111 | 	if len(values[0]) == 0 {
112 | 		return nil, nil, fmt.Errorf("value dimension (d_v) cannot be zero")
113 | 	}
114 | 	dV := len(values[0])
115 | 
116 | 	dK := len(query)
117 | 	if len(keys[0]) != dK {
118 | 		return nil, nil, fmt.Errorf("query dimension (%d) must match key dimension (%d)", dK, len(keys[0]))
119 | 	}
120 | 
121 | 	// Standard 1/sqrt(d_k) scaling; left at 1 when d_k is 0 to avoid
122 | 	// dividing by zero.
123 | 	scale := 1.0
124 | 	if dK > 0 {
125 | 		scale = 1.0 / math.Sqrt(float64(dK))
126 | 	}
127 | 
128 | 	scores := make(Vector, n)
129 | 	for i, key := range keys {
130 | 		if len(key) != dK {
131 | 			return nil, nil, fmt.Errorf("key dimension mismatch at index %d: expected %d, got %d", i, dK, len(key))
132 | 		}
133 | 		score, err := DotProduct(query, key)
134 | 		if err != nil {
135 | 			// Unreachable after the length check above; kept defensively.
136 | 			return nil, nil, fmt.Errorf("error computing dot product for key %d: %w", i, err)
137 | 		}
138 | 		scores[i] = score * scale
139 | 	}
140 | 
141 | 	// SSMax (s=1): multiply the scores by log(n). For n == 1 the single
142 | 	// softmax weight is 1 regardless, so the log(1) = 0 factor is skipped.
143 | 	if n > 1 {
144 | 		logN := math.Log(float64(n))
145 | 		for i := range scores {
146 | 			scores[i] *= logN
147 | 		}
148 | 	}
149 | 
150 | 	weights := Softmax(scores)
151 | 
152 | 	// Weighted sum of the value rows.
153 | 	attended := make(Vector, dV)
154 | 	for i, weight := range weights {
155 | 		if len(values[i]) != dV {
156 | 			return nil, nil, fmt.Errorf("value dimension mismatch at index %d: expected %d, got %d", i, dV, len(values[i]))
157 | 		}
158 | 		// A zero weight contributes nothing; skip the inner loop.
159 | 		if weight == 0 {
160 | 			continue
161 | 		}
162 | 		for j, v := range values[i] {
163 | 			attended[j] += weight * v
164 | 		}
165 | 	}
166 | 
167 | 	return attended, weights, nil
168 | }
169 | 
170 | // TODO: kv caching
171 | // TODO: tokenization support
172 | // TODO: rotary positional embedding support RoPE
173 | // TODO: SwarmFormer Local-Global Hierarchical Attention Support
--------------------------------------------------------------------------------
/attention/multihead.go:
--------------------------------------------------------------------------------
1 | package attention
2 | 
3 | import (
4 | 	"fmt"
5 | 	"math"
6 | 	"math/rand"
7 | )
8 | 
9 | // MultiHeadConfig holds configuration for multi-head attention.
10 | type MultiHeadConfig struct {
11 | 	NumHeads    int     // Number of attention heads
12 | 	DModel      int     // Model dimension
13 | 	DKey        int     // Key dimension per head
14 | 	DValue      int     // Value dimension per head
15 | 	DropoutRate float64 // Dropout rate (not implemented in this version for simplicity)
16 | }
17 | 
18 | // MultiHeadAttention implements the multi-head attention mechanism.
19 | type MultiHeadAttention struct {
20 | 	config MultiHeadConfig
21 | 
22 | 	// Linear projections for each head
23 | 	queryProj []Matrix // [num_heads][d_model][d_k]
24 | 	keyProj   []Matrix // [num_heads][d_model][d_k]
25 | 	valueProj []Matrix // [num_heads][d_model][d_v]
26 | 
27 | 	// Output projection
28 | 	outputProj Matrix // [d_model][num_heads * d_v]
29 | }
30 | 
31 | // NewMultiHeadAttention creates a multi-head attention module with randomly
32 | // initialized projection weights.
33 | //
34 | // It returns an error if NumHeads, DModel, DKey or DValue is not positive,
35 | // or if DModel is not divisible by NumHeads. DropoutRate is stored but not
36 | // used.
37 | func NewMultiHeadAttention(config MultiHeadConfig) (*MultiHeadAttention, error) {
38 | 	if config.NumHeads <= 0 {
39 | 		return nil, fmt.Errorf("number of heads must be positive, got %d", config.NumHeads)
40 | 	}
41 | 	if config.DModel <= 0 {
42 | 		return nil, fmt.Errorf("model dimension must be positive, got %d", config.DModel)
43 | 	}
44 | 	if config.DModel%config.NumHeads != 0 {
45 | 		return nil, fmt.Errorf("model dimension (%d) must be divisible by number of heads (%d)", config.DModel, config.NumHeads)
46 | 	}
47 | 	// A negative DKey/DValue would make randomMatrix panic in make, and a
48 | 	// zero one would only surface later as an "empty weight matrix" error
49 | 	// from Forward, so reject both up front.
50 | 	if config.DKey <= 0 {
51 | 		return nil, fmt.Errorf("key dimension must be positive, got %d", config.DKey)
52 | 	}
53 | 	if config.DValue <= 0 {
54 | 		return nil, fmt.Errorf("value dimension must be positive, got %d", config.DValue)
55 | 	}
56 | 
57 | 	mha := &MultiHeadAttention{
58 | 		config:     config,
59 | 		queryProj:  make([]Matrix, config.NumHeads),
60 | 		keyProj:    make([]Matrix, config.NumHeads),
61 | 		valueProj:  make([]Matrix, config.NumHeads),
62 | 		outputProj: make(Matrix, config.DModel),
63 | 	}
64 | 
65 | 	// Per-head input projections.
66 | 	for h := 0; h < config.NumHeads; h++ {
67 | 		mha.queryProj[h] = randomMatrix(config.DModel, config.DKey)
68 | 		mha.keyProj[h] = randomMatrix(config.DModel, config.DKey)
69 | 		mha.valueProj[h] = randomMatrix(config.DModel, config.DValue)
70 | 	}
71 | 
72 | 	// Output projection: uniform in [-0.5, 0.5) / sqrt(DValue).
73 | 	for i := range mha.outputProj {
74 | 		mha.outputProj[i] = make(Vector, config.NumHeads*config.DValue)
75 | 		for j := range mha.outputProj[i] {
76 | 			mha.outputProj[i][j] = (rand.Float64() - 0.5) / math.Sqrt(float64(config.DValue))
77 | 		}
78 | 	}
79 | 
80 | 	return mha, nil
81 | }
82 | 
83 | // Forward computes multi-head attention over row-major inputs.
84 | //
85 | // query, key and value are [rows, d_model] matrices with the same number of
86 | // rows; the result is [rows, d_model]. To process several sequences, callers
87 | // flatten them to [batch_size*seq_len, d_model]. Every query row attends over
88 | // all key rows, so rows belonging to different sequences are not isolated
89 | // from one another.
90 | func (mha *MultiHeadAttention) Forward(query, key, value Matrix) (Matrix, error) {
91 | 	batchSize := len(query)
92 | 	if batchSize != len(key) || batchSize != len(value) {
93 | 		return nil, fmt.Errorf("batch size mismatch: query(%d), key(%d), value(%d)", batchSize, len(key), len(value))
94 | 	}
95 | 
96 | 	// Compute attention for each head.
97 | 	headOutputs := make([]Matrix, mha.config.NumHeads)
98 | 	for h := 0; h < mha.config.NumHeads; h++ {
99 | 		projQuery, err := projectBatch(query, mha.queryProj[h])
100 | 		if err != nil {
101 | 			return nil, fmt.Errorf("projecting query for head %d: %w", h, err)
102 | 		}
103 | 
104 | 		projKey, err := projectBatch(key, mha.keyProj[h])
105 | 		if err != nil {
106 | 			return nil, fmt.Errorf("projecting key for head %d: %w", h, err)
107 | 		}
108 | 
109 | 		projValue, err := projectBatch(value, mha.valueProj[h])
110 | 		if err != nil {
111 | 			return nil, fmt.Errorf("projecting value for head %d: %w", h, err)
112 | 		}
113 | 
114 | 		// Each projected query row attends over every projected key row.
115 | 		headOutputs[h] = make(Matrix, batchSize)
116 | 		for b := 0; b < batchSize; b++ {
117 | 			attended, _, err := DotProductAttention(projQuery[b], projKey, projValue)
118 | 			if err != nil {
119 | 				return nil, fmt.Errorf("computing attention for batch %d, head %d: %w", b, h, err)
120 | 			}
121 | 			headOutputs[h][b] = attended
122 | 		}
123 | 	}
124 | 
125 | 	// Concatenate the heads for each row and apply the output projection.
126 | 	output := make(Matrix, batchSize)
127 | 	for b := 0; b < batchSize; b++ {
128 | 		concat := make(Vector, 0, mha.config.NumHeads*mha.config.DValue)
129 | 		for h := 0; h < mha.config.NumHeads; h++ {
130 | 			concat = append(concat, headOutputs[h][b]...)
131 | 		}
132 | 
133 | 		output[b] = make(Vector, mha.config.DModel)
134 | 		for i := range output[b] {
135 | 			for j, v := range concat {
136 | 				output[b][i] += v * mha.outputProj[i][j]
137 | 			}
138 | 		}
139 | 	}
140 | 
141 | 	return output, nil
142 | }
143 | 
144 | // Helper functions
145 | 
146 | // randomMatrix returns a rows×cols matrix drawn from the Xavier/Glorot
147 | // uniform distribution U(-limit, limit) with limit = sqrt(6/(rows+cols)),
148 | // which gives each weight a variance of 2/(rows+cols).
149 | func randomMatrix(rows, cols int) Matrix {
150 | 	mat := make(Matrix, rows)
151 | 	limit := math.Sqrt(6.0 / float64(rows+cols))
152 | 	for i := range mat {
153 | 		mat[i] = make(Vector, cols)
154 | 		for j := range mat[i] {
155 | 			mat[i][j] = (2*rand.Float64() - 1) * limit
156 | 		}
157 | 	}
158 | 	return mat
159 | }
160 | 
161 | // projectBatch applies projectVector to every row of input.
162 | func projectBatch(input Matrix, weights Matrix) (Matrix, error) {
163 | 	output := make(Matrix, len(input))
164 | 	for i, vec := range input {
165 | 		projected, err := projectVector(vec, weights)
166 | 		if err != nil {
167 | 			return nil, err
168 | 		}
169 | 		output[i] = projected
170 | 	}
171 | 	return output, nil
172 | }
173 | 
174 | // projectVector computes the row-vector product input × weights, where
175 | // weights is [len(input)][out]; the result has length out.
176 | func projectVector(input Vector, weights Matrix) (Vector, error) {
177 | 	if len(weights) == 0 || len(weights[0]) == 0 {
178 | 		return nil, fmt.Errorf("empty weight matrix")
179 | 	}
180 | 	if len(input) != len(weights) {
181 | 		return nil, fmt.Errorf("input dimension (%d) does not match weights (%d)", len(input), len(weights))
182 | 	}
183 | 
184 | 	output := make(Vector, len(weights[0]))
185 | 	// Walk the weights row by row (cache-friendly). Each output[i] still
186 | 	// accumulates its terms in increasing row order, so the floating-point
187 | 	// result matches a column-by-column summation exactly.
188 | 	for j, row := range weights {
189 | 		if len(row) != len(output) {
190 | 			return nil, fmt.Errorf("weight row %d has length %d, expected %d", j, len(row), len(output))
191 | 		}
192 | 		x := input[j]
193 | 		for i, w := range row {
194 | 			output[i] += x * w
195 | 		}
196 | 	}
197 | 	return output, nil
198 | }
--------------------------------------------------------------------------------
/examples/api_examples.go:
--------------------------------------------------------------------------------
1 | // This script demonstrates all key APIs of the go-attention module
2 | package main
3 | 
4 | import (
5 | 	"fmt"
6 | 	"log"
7 | 	"github.com/takara-ai/go-attention/attention"
8 | 	"github.com/takara-ai/go-attention/transformer"
9 | )
10 | 
11 | 
func main() { 12 | // 1. Basic Dot-Product Attention 13 | fmt.Println("\n=== 1. Basic Dot-Product Attention ===") 14 | testDotProductAttention() 15 | 16 | // 2. Multi-Head Attention 17 | fmt.Println("\n=== 2. Multi-Head Attention ===") 18 | testMultiHeadAttention() 19 | 20 | // 3. Full Transformer Layer 21 | fmt.Println("\n=== 3. Full Transformer Layer ===") 22 | testTransformerLayer() 23 | } 24 | 25 | func testDotProductAttention() { 26 | // Create a simple query-key-value setup 27 | query := attention.Vector{1.0, 0.0, 1.0, 0.0} // Looking for patterns similar to [1,0,1,0] 28 | keys := attention.Matrix{ 29 | {1.0, 0.0, 1.0, 0.0}, // Similar to query 30 | {0.0, 1.0, 0.0, 1.0}, // Different from query 31 | {0.5, 0.5, 0.5, 0.5}, // Neutral pattern 32 | } 33 | values := attention.Matrix{ 34 | {1.0, 2.0}, // Value for similar key 35 | {3.0, 4.0}, // Value for different key 36 | {5.0, 6.0}, // Value for neutral key 37 | } 38 | 39 | // Compute attention 40 | output, weights, err := attention.DotProductAttention(query, keys, values) 41 | if err != nil { 42 | log.Fatal(err) 43 | } 44 | 45 | fmt.Println("Query:", query) 46 | fmt.Println("Attention Weights:", weights) 47 | fmt.Println("Output:", output) 48 | } 49 | 50 | func testMultiHeadAttention() { 51 | // Configure multi-head attention 52 | config := attention.MultiHeadConfig{ 53 | NumHeads: 4, 54 | DModel: 64, 55 | DKey: 16, // DModel / NumHeads 56 | DValue: 16, // DModel / NumHeads 57 | DropoutRate: 0.1, 58 | } 59 | 60 | // Create multi-head attention module 61 | mha, err := attention.NewMultiHeadAttention(config) 62 | if err != nil { 63 | log.Fatal(err) 64 | } 65 | 66 | // Create sample input with dimensions: 67 | // - batch_size: number of sequences to process in parallel 68 | // - seq_len: number of tokens in each sequence 69 | // - d_model: dimension of each token's embedding 70 | batchSize, seqLen := 2, 3 71 | 72 | // Create input matrices with shape [batchSize × seqLen × DModel] 73 | queries := 
make(attention.Matrix, batchSize*seqLen) 74 | keys := make(attention.Matrix, batchSize*seqLen) 75 | values := make(attention.Matrix, batchSize*seqLen) 76 | 77 | // Initialize with a deterministic pattern 78 | for i := 0; i < batchSize*seqLen; i++ { 79 | queries[i] = make(attention.Vector, config.DModel) 80 | keys[i] = make(attention.Vector, config.DModel) 81 | values[i] = make(attention.Vector, config.DModel) 82 | for j := 0; j < config.DModel; j++ { 83 | val := float64(i*j) / float64(config.DModel) 84 | queries[i][j] = val 85 | keys[i][j] = val 86 | values[i][j] = val 87 | } 88 | } 89 | 90 | // Process through multi-head attention 91 | output, err := mha.Forward(queries, keys, values) 92 | if err != nil { 93 | log.Fatal(err) 94 | } 95 | 96 | fmt.Printf("Input dimensions: [%d batches × %d tokens × %d features]\n", 97 | batchSize, seqLen, config.DModel) 98 | fmt.Printf("Output shape: [%d×%d]\n", len(output), len(output[0])) 99 | fmt.Println("First few output values:", output[0][:4]) 100 | } 101 | 102 | func testTransformerLayer() { 103 | // Configure transformer layer 104 | config := transformer.TransformerConfig{ 105 | DModel: 64, 106 | NumHeads: 4, 107 | DHidden: 256, 108 | DropoutRate: 0.1, 109 | } 110 | 111 | // Create transformer layer 112 | layer, err := transformer.NewTransformerLayer(config) 113 | if err != nil { 114 | log.Fatal(err) 115 | } 116 | 117 | // Create sample sequence (seq_len=3, d_model=64) 118 | input := createRandomMatrix(3, config.DModel) 119 | 120 | // Process through transformer 121 | output, err := layer.Forward(input) 122 | if err != nil { 123 | log.Fatal(err) 124 | } 125 | 126 | fmt.Printf("Input shape: [%d×%d]\n", len(input), len(input[0])) 127 | fmt.Printf("Output shape: [%d×%d]\n", len(output), len(output[0])) 128 | fmt.Println("First token before:", input[0][:4]) 129 | fmt.Println("First token after:", output[0][:4]) 130 | } 131 | 132 | // Helper function to create random matrices 133 | func createRandomMatrix(rows, cols int) 
attention.Matrix { 134 | matrix := make(attention.Matrix, rows) 135 | for i := range matrix { 136 | matrix[i] = make(attention.Vector, cols) 137 | for j := range matrix[i] { 138 | matrix[i][j] = float64(i+j) / float64(cols) // Deterministic pattern for testing 139 | } 140 | } 141 | return matrix 142 | } -------------------------------------------------------------------------------- /examples/sentiment_trainer.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "log" 7 | "math" 8 | "math/rand" 9 | "os" 10 | "strings" 11 | "time" 12 | 13 | "github.com/takara-ai/go-attention/attention" 14 | ) 15 | 16 | // IMDBReview represents a single review from the IMDB dataset 17 | type IMDBReview struct { 18 | Text string `json:"text"` 19 | Label float64 `json:"label"` 20 | } 21 | 22 | // TrainingExample represents a single training instance 23 | type TrainingExample struct { 24 | Text string `json:"text"` 25 | Sentiment float64 `json:"sentiment"` 26 | } 27 | 28 | // Model represents our sentiment analysis model 29 | type Model struct { 30 | WordEmbedding []attention.Vector `json:"word_embedding"` // Simple word embedding lookup 31 | QueryWeights attention.Matrix `json:"query_weights"` // Query projection weights 32 | KeyWeights attention.Matrix `json:"key_weights"` // Key projection weights 33 | ValueWeights attention.Matrix `json:"value_weights"` // Value projection weights 34 | OutputWeights attention.Vector `json:"output_weights"` // Output layer weights 35 | VocabSize int `json:"vocab_size"` 36 | EmbedDim int `json:"embed_dim"` 37 | AttnDim int `json:"attn_dim"` 38 | } 39 | 40 | // NewModel creates a new sentiment analysis model 41 | func NewModel(vocabSize, embedDim, attnDim int) *Model { 42 | // Initialize word embeddings with Xavier initialization 43 | scale := math.Sqrt(2.0 / float64(embedDim)) 44 | wordEmbedding := make([]attention.Vector, vocabSize) 45 | for i := range 
wordEmbedding { 46 | wordEmbedding[i] = make(attention.Vector, embedDim) 47 | for j := range wordEmbedding[i] { 48 | wordEmbedding[i][j] = rand.NormFloat64() * scale 49 | } 50 | } 51 | 52 | // Initialize attention weights 53 | queryWeights := make(attention.Matrix, embedDim) 54 | keyWeights := make(attention.Matrix, embedDim) 55 | valueWeights := make(attention.Matrix, embedDim) 56 | for i := 0; i < embedDim; i++ { 57 | queryWeights[i] = make(attention.Vector, attnDim) 58 | keyWeights[i] = make(attention.Vector, attnDim) 59 | valueWeights[i] = make(attention.Vector, attnDim) 60 | for j := 0; j < attnDim; j++ { 61 | queryWeights[i][j] = rand.NormFloat64() * scale 62 | keyWeights[i][j] = rand.NormFloat64() * scale 63 | valueWeights[i][j] = rand.NormFloat64() * scale 64 | } 65 | } 66 | 67 | // Initialize output weights 68 | outputWeights := make(attention.Vector, embedDim) 69 | for i := range outputWeights { 70 | outputWeights[i] = rand.NormFloat64() * scale 71 | } 72 | 73 | return &Model{ 74 | WordEmbedding: wordEmbedding, 75 | QueryWeights: queryWeights, 76 | KeyWeights: keyWeights, 77 | ValueWeights: valueWeights, 78 | OutputWeights: outputWeights, 79 | VocabSize: vocabSize, 80 | EmbedDim: embedDim, 81 | AttnDim: attnDim, 82 | } 83 | } 84 | 85 | // sigmoid computes the sigmoid function 86 | func sigmoid(x float64) float64 { 87 | return 1.0 / (1.0 + math.Exp(-x)) 88 | } 89 | 90 | // dotProduct computes the dot product of two vectors 91 | func dotProduct(a, b attention.Vector) float64 { 92 | sum := 0.0 93 | for i := range a { 94 | sum += a[i] * b[i] 95 | } 96 | return sum 97 | } 98 | 99 | // hashWord creates a simple hash for word to vocabulary mapping 100 | func hashWord(word string) int { 101 | hash := 0 102 | for _, c := range word { 103 | hash = (hash*31 + int(c)) % 10000 104 | } 105 | return hash 106 | } 107 | 108 | // Batch represents a batch of training examples 109 | type Batch struct { 110 | Embeddings []attention.Matrix 111 | Targets []float64 112 | } 113 | 
114 | // createBatch creates a batch from examples 115 | func (m *Model) createBatch(examples []TrainingExample, batchSize int) []Batch { 116 | var batches []Batch 117 | for i := 0; i < len(examples); i += batchSize { 118 | end := i + batchSize 119 | if end > len(examples) { 120 | end = len(examples) 121 | } 122 | 123 | batch := Batch{ 124 | Embeddings: make([]attention.Matrix, end-i), 125 | Targets: make([]float64, end-i), 126 | } 127 | 128 | for j, example := range examples[i:end] { 129 | words := strings.Fields(strings.ToLower(example.Text)) 130 | if len(words) > 100 { 131 | words = words[:100] 132 | } 133 | 134 | embeddings := make(attention.Matrix, len(words)) 135 | for k, word := range words { 136 | wordIdx := hashWord(word) % m.VocabSize 137 | embeddings[k] = m.WordEmbedding[wordIdx] 138 | } 139 | batch.Embeddings[j] = embeddings 140 | batch.Targets[j] = example.Sentiment 141 | } 142 | batches = append(batches, batch) 143 | } 144 | return batches 145 | } 146 | 147 | // clipGradients clips gradients to prevent explosion 148 | func clipGradients(grads attention.Vector, maxNorm float64) { 149 | var norm float64 150 | for _, g := range grads { 151 | norm += g * g 152 | } 153 | norm = math.Sqrt(norm) 154 | if norm > maxNorm { 155 | scale := maxNorm / norm 156 | for i := range grads { 157 | grads[i] *= scale 158 | } 159 | } 160 | } 161 | 162 | // Train trains the model on a dataset 163 | func (m *Model) Train(examples []TrainingExample, epochs int, learningRate float64) error { 164 | rand.Seed(time.Now().UnixNano()) 165 | 166 | // Training hyperparameters 167 | batchSize := 32 168 | maxGradNorm := 5.0 169 | l2Reg := 0.01 170 | lrDecay := 0.95 171 | 172 | // Split into training and validation sets (80-20 split) 173 | rand.Shuffle(len(examples), func(i, j int) { 174 | examples[i], examples[j] = examples[j], examples[i] 175 | }) 176 | splitIdx := int(float64(len(examples)) * 0.8) 177 | trainExamples := examples[:splitIdx] 178 | valExamples := examples[splitIdx:] 179 
| 180 | bestValAcc := 0.0 181 | noImprovementCount := 0 182 | currentLR := learningRate 183 | 184 | for epoch := 0; epoch < epochs; epoch++ { 185 | // Create batches 186 | batches := m.createBatch(trainExamples, batchSize) 187 | 188 | // Training phase 189 | totalLoss := 0.0 190 | correct := 0 191 | totalExamples := 0 192 | 193 | for batchIdx, batch := range batches { 194 | batchLoss := 0.0 195 | batchCorrect := 0 196 | 197 | // Accumulate gradients for the batch 198 | queryGrads := make(attention.Matrix, m.EmbedDim) 199 | keyGrads := make(attention.Matrix, m.EmbedDim) 200 | valueGrads := make(attention.Matrix, m.EmbedDim) 201 | outputGrads := make(attention.Vector, m.EmbedDim) 202 | 203 | // Initialize gradient matrices 204 | for i := 0; i < m.EmbedDim; i++ { 205 | queryGrads[i] = make(attention.Vector, m.AttnDim) 206 | keyGrads[i] = make(attention.Vector, m.AttnDim) 207 | valueGrads[i] = make(attention.Vector, m.AttnDim) 208 | } 209 | 210 | for i := range batch.Embeddings { 211 | // Forward pass 212 | queries := m.project(batch.Embeddings[i], m.QueryWeights) 213 | keys := m.project(batch.Embeddings[i], m.KeyWeights) 214 | values := m.project(batch.Embeddings[i], m.ValueWeights) 215 | 216 | // Compute attention scores 217 | scores := make([]float64, len(queries)) 218 | maxScore := -math.MaxFloat64 219 | for j := range queries { 220 | scores[j] = 0 221 | for k := 0; k < m.AttnDim; k++ { 222 | scores[j] += queries[j][k] * keys[j][k] 223 | } 224 | scores[j] /= math.Sqrt(float64(m.AttnDim)) 225 | if scores[j] > maxScore { 226 | maxScore = scores[j] 227 | } 228 | } 229 | 230 | // Softmax 231 | sumExp := 0.0 232 | for j := range scores { 233 | scores[j] = math.Exp(scores[j] - maxScore) 234 | sumExp += scores[j] 235 | } 236 | for j := range scores { 237 | scores[j] /= sumExp 238 | } 239 | 240 | // Weighted sum 241 | context := make(attention.Vector, m.EmbedDim) 242 | for j := range context { 243 | for k := range scores { 244 | for d := 0; d < m.AttnDim; d++ { 245 | 
context[j] += scores[k] * values[k][d] * m.ValueWeights[j][d] 246 | } 247 | } 248 | } 249 | 250 | // Final prediction 251 | logit := dotProduct(context, m.OutputWeights) 252 | prediction := sigmoid(logit) 253 | 254 | // Compute loss with L2 regularization 255 | target := batch.Targets[i] 256 | loss := -(target*math.Log(prediction+1e-10) + (1-target)*math.Log(1-prediction+1e-10)) 257 | 258 | // Add L2 regularization 259 | l2Loss := 0.0 260 | for _, w := range m.OutputWeights { 261 | l2Loss += w * w 262 | } 263 | loss += l2Reg * l2Loss * 0.5 264 | batchLoss += loss 265 | 266 | // Backward pass 267 | error := prediction - target 268 | 269 | // Output weight gradients 270 | for j := range outputGrads { 271 | outputGrads[j] += error * context[j] + l2Reg * m.OutputWeights[j] 272 | } 273 | 274 | // Attention gradients 275 | contextGrad := make(attention.Vector, m.EmbedDim) 276 | for j := range contextGrad { 277 | contextGrad[j] = error * m.OutputWeights[j] 278 | } 279 | 280 | // Update attention weight gradients 281 | for j := range batch.Embeddings[i] { 282 | if j >= m.EmbedDim { 283 | continue 284 | } 285 | for k := 0; k < m.AttnDim; k++ { 286 | queryGrad := 0.0 287 | keyGrad := 0.0 288 | valueGrad := 0.0 289 | 290 | for d := range scores { 291 | queryGrad += contextGrad[j] * scores[d] * keys[j][k] / math.Sqrt(float64(m.AttnDim)) 292 | keyGrad += contextGrad[j] * scores[d] * queries[j][k] / math.Sqrt(float64(m.AttnDim)) 293 | valueGrad += contextGrad[j] * scores[d] 294 | } 295 | 296 | queryGrads[j][k] += queryGrad 297 | keyGrads[j][k] += keyGrad 298 | valueGrads[j][k] += valueGrad 299 | } 300 | } 301 | 302 | // Calculate accuracy 303 | predictedClass := 0.0 304 | if prediction > 0.5 { 305 | predictedClass = 1.0 306 | } 307 | if predictedClass == target { 308 | batchCorrect++ 309 | } 310 | } 311 | 312 | // Clip and apply gradients 313 | clipGradients(outputGrads, maxGradNorm) 314 | for i := range m.OutputWeights { 315 | m.OutputWeights[i] -= currentLR * outputGrads[i] 
316 | } 317 | 318 | for i := 0; i < m.EmbedDim; i++ { 319 | clipGradients(queryGrads[i], maxGradNorm) 320 | clipGradients(keyGrads[i], maxGradNorm) 321 | clipGradients(valueGrads[i], maxGradNorm) 322 | for j := 0; j < m.AttnDim; j++ { 323 | m.QueryWeights[i][j] -= currentLR * queryGrads[i][j] 324 | m.KeyWeights[i][j] -= currentLR * keyGrads[i][j] 325 | m.ValueWeights[i][j] -= currentLR * valueGrads[i][j] 326 | } 327 | } 328 | 329 | totalLoss += batchLoss 330 | correct += batchCorrect 331 | totalExamples += len(batch.Embeddings) 332 | 333 | // Print batch progress 334 | if (batchIdx+1) % 10 == 0 { 335 | fmt.Printf("Epoch %d, Batch %d/%d, Loss: %.4f, Accuracy: %.2f%%\n", 336 | epoch+1, batchIdx+1, len(batches), 337 | batchLoss/float64(len(batch.Embeddings)), 338 | float64(batchCorrect)*100.0/float64(len(batch.Embeddings))) 339 | } 340 | } 341 | 342 | trainLoss := totalLoss / float64(totalExamples) 343 | trainAcc := float64(correct) * 100.0 / float64(totalExamples) 344 | 345 | // Validation phase 346 | valLoss, valAcc := m.evaluate(valExamples) 347 | 348 | fmt.Printf("Epoch %d complete:\n", epoch+1) 349 | fmt.Printf(" Training - Loss: %.4f, Accuracy: %.2f%%\n", trainLoss, trainAcc) 350 | fmt.Printf(" Validation - Loss: %.4f, Accuracy: %.2f%%\n", valLoss, valAcc) 351 | 352 | // Learning rate decay and early stopping 353 | if valAcc > bestValAcc { 354 | bestValAcc = valAcc 355 | noImprovementCount = 0 356 | } else { 357 | noImprovementCount++ 358 | if noImprovementCount >= 2 { 359 | currentLR *= lrDecay 360 | fmt.Printf(" Reducing learning rate to %.6f\n", currentLR) 361 | noImprovementCount = 0 362 | } 363 | } 364 | 365 | if currentLR < learningRate * 0.01 { 366 | fmt.Println("Learning rate too small, stopping training") 367 | break 368 | } 369 | } 370 | 371 | return nil 372 | } 373 | 374 | // project applies weight matrix to input 375 | func (m *Model) project(input attention.Matrix, weights attention.Matrix) attention.Matrix { 376 | output := make(attention.Matrix, 
len(input)) 377 | for i := range output { 378 | output[i] = make(attention.Vector, m.AttnDim) 379 | for j := 0; j < m.AttnDim; j++ { 380 | for k := 0; k < m.EmbedDim && k < len(input[i]); k++ { 381 | output[i][j] += input[i][k] * weights[k][j] 382 | } 383 | } 384 | } 385 | return output 386 | } 387 | 388 | // evaluate computes loss and accuracy on a dataset 389 | func (m *Model) evaluate(examples []TrainingExample) (float64, float64) { 390 | totalLoss := 0.0 391 | correct := 0 392 | 393 | for _, example := range examples { 394 | prediction := m.Predict(example.Text) 395 | 396 | loss := -(example.Sentiment*math.Log(prediction+1e-10) + 397 | (1-example.Sentiment)*math.Log(1-prediction+1e-10)) 398 | totalLoss += loss 399 | 400 | predictedClass := 0.0 401 | if prediction > 0.5 { 402 | predictedClass = 1.0 403 | } 404 | if predictedClass == example.Sentiment { 405 | correct++ 406 | } 407 | } 408 | 409 | return totalLoss / float64(len(examples)), 410 | float64(correct) * 100.0 / float64(len(examples)) 411 | } 412 | 413 | // Predict makes a prediction on new text 414 | func (m *Model) Predict(text string) float64 { 415 | // Tokenize 416 | words := strings.Fields(strings.ToLower(text)) 417 | if len(words) > 100 { 418 | words = words[:100] 419 | } 420 | if len(words) == 0 { 421 | return 0.5 // Default prediction for empty text 422 | } 423 | 424 | // Get embeddings 425 | embeddings := make(attention.Matrix, len(words)) 426 | for i, word := range words { 427 | wordIdx := hashWord(word) % m.VocabSize 428 | embeddings[i] = make(attention.Vector, m.EmbedDim) 429 | copy(embeddings[i], m.WordEmbedding[wordIdx]) 430 | } 431 | 432 | // Apply attention 433 | queries := m.project(embeddings, m.QueryWeights) 434 | keys := m.project(embeddings, m.KeyWeights) 435 | values := m.project(embeddings, m.ValueWeights) 436 | 437 | // Compute attention scores 438 | scores := make([]float64, len(queries)) 439 | maxScore := -math.MaxFloat64 440 | for j := range queries { 441 | scores[j] = 0 442 | 
for k := 0; k < m.AttnDim; k++ { 443 | scores[j] += queries[j][k] * keys[j][k] 444 | } 445 | scores[j] /= math.Sqrt(float64(m.AttnDim)) 446 | if scores[j] > maxScore { 447 | maxScore = scores[j] 448 | } 449 | } 450 | 451 | // Softmax 452 | sumExp := 0.0 453 | for j := range scores { 454 | scores[j] = math.Exp(scores[j] - maxScore) 455 | sumExp += scores[j] 456 | } 457 | if sumExp > 0 { 458 | for j := range scores { 459 | scores[j] /= sumExp 460 | } 461 | } else { 462 | // If all scores are very negative, use uniform attention 463 | for j := range scores { 464 | scores[j] = 1.0 / float64(len(scores)) 465 | } 466 | } 467 | 468 | // Weighted sum 469 | context := make(attention.Vector, m.EmbedDim) 470 | for j := range context { 471 | for k := range scores { 472 | for d := 0; d < m.AttnDim; d++ { 473 | context[j] += scores[k] * values[k][d] * m.ValueWeights[j][d] 474 | } 475 | } 476 | } 477 | 478 | // Final prediction 479 | logit := dotProduct(context, m.OutputWeights) 480 | return sigmoid(logit) 481 | } 482 | 483 | // SaveModel saves the model to a file 484 | func (m *Model) SaveModel(filename string) error { 485 | file, err := os.Create(filename) 486 | if err != nil { 487 | return fmt.Errorf("failed to create file: %v", err) 488 | } 489 | defer file.Close() 490 | 491 | encoder := json.NewEncoder(file) 492 | if err := encoder.Encode(m); err != nil { 493 | return fmt.Errorf("failed to encode model: %v", err) 494 | } 495 | return nil 496 | } 497 | 498 | // LoadModel loads a model from a file 499 | func LoadModel(filename string) (*Model, error) { 500 | file, err := os.Open(filename) 501 | if err != nil { 502 | return nil, fmt.Errorf("failed to open file: %v", err) 503 | } 504 | defer file.Close() 505 | 506 | var model Model 507 | decoder := json.NewDecoder(file) 508 | if err := decoder.Decode(&model); err != nil { 509 | return nil, fmt.Errorf("failed to decode model: %v", err) 510 | } 511 | 512 | return &model, nil 513 | } 514 | 515 | // loadIMDBData loads the IMDB 
review dataset 516 | func loadIMDBData(filename string) ([]TrainingExample, error) { 517 | data, err := os.ReadFile(filename) 518 | if err != nil { 519 | return nil, fmt.Errorf("failed to read file: %v", err) 520 | } 521 | 522 | // Split into lines and parse each line as a separate JSON object 523 | lines := strings.Split(string(data), "\n") 524 | examples := make([]TrainingExample, 0, len(lines)) 525 | 526 | for _, line := range lines { 527 | if len(strings.TrimSpace(line)) == 0 { 528 | continue 529 | } 530 | 531 | var review IMDBReview 532 | if err := json.Unmarshal([]byte(line), &review); err != nil { 533 | return nil, fmt.Errorf("failed to parse JSON line: %v", err) 534 | } 535 | 536 | // Clean the text by removing HTML tags 537 | text := strings.ReplaceAll(review.Text, "
", " ") 538 | text = strings.ReplaceAll(text, "
", " ") 539 | text = strings.ReplaceAll(text, "\\/", "/") 540 | 541 | examples = append(examples, TrainingExample{ 542 | Text: text, 543 | Sentiment: review.Label, 544 | }) 545 | } 546 | 547 | return examples, nil 548 | } 549 | 550 | func main() { 551 | // Load IMDB dataset 552 | examples, err := loadIMDBData("data/imdb_1000.json") 553 | if err != nil { 554 | log.Fatalf("Failed to load IMDB data: %v", err) 555 | } 556 | fmt.Printf("Loaded %d examples from IMDB dataset\n", len(examples)) 557 | 558 | // Create and train model 559 | model := NewModel(10000, 64, 32) // vocab size: 10000, embed dim: 64, attention dim: 32 560 | 561 | fmt.Println("Training model...") 562 | if err := model.Train(examples, 10, 0.01); err != nil { // More epochs, higher learning rate 563 | log.Fatalf("Training failed: %v", err) 564 | } 565 | 566 | // Save the model 567 | if err := model.SaveModel("sentiment_model.json"); err != nil { 568 | log.Fatalf("Failed to save model: %v", err) 569 | } 570 | fmt.Println("Model saved to sentiment_model.json") 571 | 572 | // Test the model 573 | testTexts := []string{ 574 | "This movie was absolutely fantastic! The acting was superb.", 575 | "Terrible waste of time. 
The worst movie I've ever seen.", 576 | "An okay film, nothing special but decent entertainment.", 577 | } 578 | 579 | fmt.Println("\nTesting model predictions:") 580 | for _, text := range testTexts { 581 | prediction := model.Predict(text) 582 | fmt.Printf("Text: %s\nSentiment: %.2f\n\n", text, prediction) 583 | } 584 | } -------------------------------------------------------------------------------- /examples/serverless_search.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "log" 6 | "net/http" 7 | "sort" 8 | "strings" 9 | "sync" 10 | "time" 11 | "github.com/takara-ai/go-attention/attention" 12 | ) 13 | 14 | // SearchRequest represents the incoming search query 15 | type SearchRequest struct { 16 | Query string `json:"query"` 17 | Docs []string `json:"documents"` 18 | } 19 | 20 | // SearchResult represents a single matched document 21 | type SearchResult struct { 22 | Text string `json:"text"` 23 | Score float64 `json:"score"` 24 | Rank int `json:"rank"` 25 | } 26 | 27 | // SearchResponse represents the search results 28 | type SearchResponse struct { 29 | Results []SearchResult `json:"results"` 30 | Timing float64 `json:"timing_ms"` 31 | } 32 | 33 | // Singleton pattern for embedder to avoid reinitialization 34 | var ( 35 | embedder *SemanticEmbedder 36 | embedderOnce sync.Once 37 | ) 38 | 39 | // SemanticEmbedder handles document embedding with attention 40 | type SemanticEmbedder struct { 41 | dimension int 42 | } 43 | 44 | // Simple semantic embedding simulation 45 | func (e *SemanticEmbedder) embedWord(word string) attention.Vector { 46 | word = strings.ToLower(word) 47 | embedding := make(attention.Vector, e.dimension) 48 | 49 | // Animal-related features 50 | if strings.Contains(word, "cat") || strings.Contains(word, "kitten") { 51 | embedding[0] = 1.0 // feline 52 | embedding[1] = 0.8 // pet 53 | embedding[2] = 0.6 // animal 54 | } 55 | if 
strings.Contains(word, "dog") || strings.Contains(word, "canine") { 56 | embedding[0] = 0.8 // animal 57 | embedding[1] = 0.8 // pet 58 | embedding[3] = 1.0 // canine 59 | } 60 | 61 | // Activity-related features 62 | if strings.Contains(word, "play") || strings.Contains(word, "chase") { 63 | embedding[4] = 1.0 // activity 64 | embedding[5] = 0.7 // movement 65 | } 66 | if strings.Contains(word, "sat") || strings.Contains(word, "love") { 67 | embedding[6] = 0.6 // state 68 | } 69 | 70 | // Tech-related features 71 | if strings.Contains(word, "ai") || strings.Contains(word, "artificial") || 72 | strings.Contains(word, "intelligence") { 73 | embedding[8] = 1.0 // AI 74 | embedding[9] = 0.8 // technology 75 | embedding[10] = 0.7 // computing 76 | } 77 | if strings.Contains(word, "machine") || strings.Contains(word, "learning") { 78 | embedding[8] = 0.8 // AI 79 | embedding[9] = 0.9 // technology 80 | embedding[11] = 1.0 // learning 81 | } 82 | if strings.Contains(word, "neural") || strings.Contains(word, "network") { 83 | embedding[8] = 0.7 // AI 84 | embedding[9] = 0.8 // technology 85 | embedding[12] = 1.0 // networks 86 | } 87 | 88 | return embedding 89 | } 90 | 91 | // Embed a sentence into a fixed-size vector using attention 92 | func (e *SemanticEmbedder) embedSentence(sentence string) attention.Vector { 93 | words := strings.Fields(sentence) 94 | if len(words) == 0 { 95 | return make(attention.Vector, e.dimension) 96 | } 97 | 98 | // Create embeddings for each word 99 | wordEmbeddings := make(attention.Matrix, len(words)) 100 | for i, word := range words { 101 | wordEmbeddings[i] = e.embedWord(word) 102 | } 103 | 104 | // Use attention to combine word embeddings 105 | output, _, err := attention.DotProductAttention(wordEmbeddings[0], wordEmbeddings, wordEmbeddings) 106 | if err != nil { 107 | log.Printf("Error in embedSentence: %v", err) 108 | return make(attention.Vector, e.dimension) 109 | } 110 | 111 | return output 112 | } 113 | 114 | func getEmbedder() 
*SemanticEmbedder { 115 | embedderOnce.Do(func() { 116 | embedder = &SemanticEmbedder{ 117 | dimension: 16, 118 | } 119 | }) 120 | return embedder 121 | } 122 | 123 | // HandleSearch is the serverless entry point 124 | func HandleSearch(w http.ResponseWriter, r *http.Request) { 125 | startTime := time.Now() 126 | 127 | if r.Method != http.MethodPost { 128 | http.Error(w, "Only POST method is allowed", http.StatusMethodNotAllowed) 129 | return 130 | } 131 | 132 | // Parse request 133 | var req SearchRequest 134 | if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 135 | http.Error(w, "Invalid request format", http.StatusBadRequest) 136 | return 137 | } 138 | 139 | // Get singleton embedder (zero cold start penalty after first request) 140 | emb := getEmbedder() 141 | 142 | // Process query and documents in parallel 143 | queryEmbed := emb.embedSentence(req.Query) 144 | 145 | // Use goroutines for parallel document processing 146 | docCount := len(req.Docs) 147 | docEmbeddings := make(attention.Matrix, docCount) 148 | var wg sync.WaitGroup 149 | wg.Add(docCount) 150 | 151 | for i := range req.Docs { 152 | go func(idx int) { 153 | defer wg.Done() 154 | docEmbeddings[idx] = emb.embedSentence(req.Docs[idx]) 155 | }(i) 156 | } 157 | wg.Wait() 158 | 159 | // Compute attention scores 160 | _, weights, err := attention.DotProductAttention(queryEmbed, docEmbeddings, docEmbeddings) 161 | if err != nil { 162 | http.Error(w, "Processing error", http.StatusInternalServerError) 163 | return 164 | } 165 | 166 | // Prepare results 167 | results := make([]SearchResult, len(weights)) 168 | for i := range weights { 169 | results[i] = SearchResult{ 170 | Text: req.Docs[i], 171 | Score: weights[i] * 100, // Convert to percentage 172 | Rank: i + 1, 173 | } 174 | } 175 | 176 | // Sort results (quick sort is more efficient than bubble sort) 177 | sort.Slice(results, func(i, j int) bool { 178 | return results[i].Score > results[j].Score 179 | }) 180 | 181 | // Return top results 
182 | response := SearchResponse{ 183 | Results: results[:min(3, len(results))], 184 | Timing: float64(time.Since(startTime).Microseconds()) / 1000.0, // Convert to milliseconds 185 | } 186 | 187 | w.Header().Set("Content-Type", "application/json") 188 | json.NewEncoder(w).Encode(response) 189 | } 190 | 191 | func min(a, b int) int { 192 | if a < b { 193 | return a 194 | } 195 | return b 196 | } 197 | 198 | // For local testing 199 | func main() { 200 | log.Printf("Starting semantic search server on :8080...") 201 | http.HandleFunc("/search", HandleSearch) 202 | log.Fatal(http.ListenAndServe(":8080", nil)) 203 | } -------------------------------------------------------------------------------- /examples/transformer_example.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "math/rand" 7 | "time" 8 | 9 | "github.com/takara-ai/go-attention/attention" 10 | "github.com/takara-ai/go-attention/transformer" 11 | ) 12 | 13 | func main() { 14 | // Seed random number generator 15 | rand.Seed(time.Now().UnixNano()) 16 | 17 | // Create a transformer layer configuration 18 | config := transformer.TransformerConfig{ 19 | DModel: 64, // Model dimension 20 | NumHeads: 4, // Number of attention heads 21 | DHidden: 256, // Hidden dimension in feed-forward network 22 | DropoutRate: 0.1, // Dropout rate (not used in this implementation) 23 | } 24 | 25 | // Create a transformer layer 26 | layer, err := transformer.NewTransformerLayer(config) 27 | if err != nil { 28 | log.Fatalf("Failed to create transformer layer: %v", err) 29 | } 30 | 31 | // Create sample input sequence 32 | // In this example, we'll create a sequence of 3 tokens, each with dimension DModel 33 | input := make(attention.Matrix, 3) 34 | for i := range input { 35 | input[i] = make(attention.Vector, config.DModel) 36 | // Initialize with random values 37 | for j := range input[i] { 38 | input[i][j] = rand.Float64()*2 - 1 39 | } 40 
| } 41 | 42 | // Process the sequence through the transformer layer 43 | output, err := layer.Forward(input) 44 | if err != nil { 45 | log.Fatalf("Failed to process sequence: %v", err) 46 | } 47 | 48 | // Print results 49 | fmt.Println("Input sequence:") 50 | for i, vec := range input { 51 | fmt.Printf("Token %d: First 4 values: %v\n", i, vec[:4]) 52 | } 53 | 54 | fmt.Println("\nTransformed sequence:") 55 | for i, vec := range output { 56 | fmt.Printf("Token %d: First 4 values: %v\n", i, vec[:4]) 57 | } 58 | 59 | // Demonstrate multi-head attention separately 60 | fmt.Println("\nDemonstrating Multi-Head Attention:") 61 | 62 | // Create input vectors with correct dimensions 63 | batchSize := 2 64 | 65 | queries := make(attention.Matrix, batchSize) 66 | keys := make(attention.Matrix, batchSize) 67 | values := make(attention.Matrix, batchSize) 68 | 69 | // Initialize with some random values 70 | for b := 0; b < batchSize; b++ { 71 | // Create query sequence 72 | queries[b] = make(attention.Vector, config.DModel) 73 | for j := range queries[b] { 74 | queries[b][j] = rand.Float64()*2 - 1 75 | } 76 | 77 | // Create key sequence 78 | keys[b] = make(attention.Vector, config.DModel) 79 | for j := range keys[b] { 80 | keys[b][j] = rand.Float64()*2 - 1 81 | } 82 | 83 | // Create value sequence 84 | values[b] = make(attention.Vector, config.DModel) 85 | for j := range values[b] { 86 | values[b][j] = rand.Float64()*2 - 1 87 | } 88 | } 89 | 90 | // Create multi-head attention 91 | mha, err := attention.NewMultiHeadAttention(attention.MultiHeadConfig{ 92 | NumHeads: 4, 93 | DModel: config.DModel, 94 | DKey: config.DModel / 4, 95 | DValue: config.DModel / 4, 96 | DropoutRate: 0.1, 97 | }) 98 | if err != nil { 99 | log.Fatalf("Failed to create multi-head attention: %v", err) 100 | } 101 | 102 | // Process through multi-head attention 103 | attended, err := mha.Forward(queries, keys, values) 104 | if err != nil { 105 | log.Fatalf("Failed to compute multi-head attention: %v", err) 
106 | } 107 | 108 | fmt.Println("\nMulti-Head Attention outputs (first 4 values for each batch):") 109 | for b := range attended { 110 | fmt.Printf("Batch %d: %v\n", b, attended[b][:4]) 111 | } 112 | } -------------------------------------------------------------------------------- /examples/vector_search.go: -------------------------------------------------------------------------------- 1 | // main.go 2 | package main 3 | 4 | import ( 5 | "fmt" 6 | "log" 7 | "os" 8 | "sort" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/takara-ai/go-attention/attention" 13 | ) 14 | 15 | // Document holds a title, some content, and an embedding vector. 16 | type Document struct { 17 | Title string 18 | Content string 19 | Embedding attention.Vector 20 | } 21 | 22 | // parseVector converts a comma-separated string (e.g. "1.0,0.0,1.0,0.0") 23 | // into an attention.Vector. 24 | func parseVector(s string) (attention.Vector, error) { 25 | parts := strings.Split(s, ",") 26 | vec := make(attention.Vector, len(parts)) 27 | for i, p := range parts { 28 | f, err := strconv.ParseFloat(strings.TrimSpace(p), 64) 29 | if err != nil { 30 | return nil, err 31 | } 32 | vec[i] = f 33 | } 34 | return vec, nil 35 | } 36 | 37 | func main() { 38 | // Define a small "database" of documents with 4-dimensional embeddings. 
39 | documents := []Document{ 40 | { 41 | Title: "Cats", 42 | Content: "Cats are small, carnivorous mammals that are often valued by humans for companionship.", 43 | Embedding: attention.Vector{1.0, 0.0, 1.0, 0.0}, 44 | }, 45 | { 46 | Title: "Dogs", 47 | Content: "Dogs are domesticated mammals, known for their loyalty and companionship with humans.", 48 | Embedding: attention.Vector{0.0, 1.0, 0.0, 1.0}, 49 | }, 50 | { 51 | Title: "Neutral", 52 | Content: "This document does not lean toward any particular subject.", 53 | Embedding: attention.Vector{0.5, 0.5, 0.5, 0.5}, 54 | }, 55 | { 56 | Title: "Birds", 57 | Content: "Birds are warm-blooded vertebrates characterized by feathers and beaks.", 58 | Embedding: attention.Vector{1.0, 1.0, 0.0, 0.0}, 59 | }, 60 | } 61 | 62 | // Default query vector (for example, looking for cat‑like features). 63 | defaultQuery := attention.Vector{1.0, 0.0, 1.0, 0.0} 64 | 65 | // If the user provides a query vector as the first command-line argument, 66 | // parse it. Otherwise, use the default. 67 | var query attention.Vector 68 | var err error 69 | if len(os.Args) > 1 { 70 | query, err = parseVector(os.Args[1]) 71 | if err != nil { 72 | log.Fatalf("Error parsing query vector: %v", err) 73 | } 74 | } else { 75 | query = defaultQuery 76 | } 77 | 78 | // Ensure that the query vector dimension matches our document embeddings. 79 | if len(query) != len(documents[0].Embedding) { 80 | log.Fatalf("Query vector dimension (%d) does not match document embedding dimension (%d)", 81 | len(query), len(documents[0].Embedding)) 82 | } 83 | 84 | // Build the keys and values matrices from the document embeddings. 85 | // In this simple example, we use the embeddings for both keys and values. 86 | keys := make(attention.Matrix, len(documents)) 87 | values := make(attention.Matrix, len(documents)) 88 | for i, doc := range documents { 89 | keys[i] = doc.Embedding 90 | values[i] = doc.Embedding 91 | } 92 | 93 | // Compute dot-product attention. 
94 | // The returned 'weights' slice contains the attention weight for each document. 95 | _, weights, err := attention.DotProductAttention(query, keys, values) 96 | if err != nil { 97 | log.Fatalf("Error computing attention: %v", err) 98 | } 99 | 100 | // Create a slice that holds each document's index and its corresponding attention score. 101 | type docScore struct { 102 | index int 103 | score float64 104 | } 105 | scores := make([]docScore, len(documents)) 106 | for i, w := range weights { 107 | scores[i] = docScore{index: i, score: w} 108 | } 109 | 110 | // Sort documents by descending attention weight (i.e. relevance). 111 | sort.Slice(scores, func(i, j int) bool { 112 | return scores[i].score > scores[j].score 113 | }) 114 | 115 | // Print the query and the documents ranked by their relevance. 116 | fmt.Println("Query Vector:", query) 117 | fmt.Println("\nDocument Relevance Scores:") 118 | for _, ds := range scores { 119 | doc := documents[ds.index] 120 | fmt.Printf("Title: %s\nScore: %.3f\nContent: %s\n\n", doc.Title, ds.score, doc.Content) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/takara-ai/go-attention 2 | 3 | go 1.23.4 4 | -------------------------------------------------------------------------------- /transformer/transformer.go: -------------------------------------------------------------------------------- 1 | package transformer 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/rand" 7 | 8 | "github.com/takara-ai/go-attention/attention" 9 | ) 10 | 11 | // LayerNorm implements layer normalization 12 | type LayerNorm struct { 13 | Dim int 14 | Eps float64 15 | Gamma attention.Vector // Scale parameter 16 | Beta attention.Vector // Shift parameter 17 | } 18 | 19 | // NewLayerNorm creates a new layer normalization module 20 | func NewLayerNorm(dim int, eps float64) *LayerNorm { 21 | 
gamma := make(attention.Vector, dim) 22 | beta := make(attention.Vector, dim) 23 | 24 | // Initialize parameters 25 | for i := range gamma { 26 | gamma[i] = 1.0 // Initialize scale to 1 27 | beta[i] = 0.0 // Initialize shift to 0 28 | } 29 | 30 | return &LayerNorm{ 31 | Dim: dim, 32 | Eps: eps, 33 | Gamma: gamma, 34 | Beta: beta, 35 | } 36 | } 37 | 38 | // Forward applies layer normalization 39 | func (ln *LayerNorm) Forward(input attention.Matrix) (attention.Matrix, error) { 40 | output := make(attention.Matrix, len(input)) 41 | 42 | for i, vec := range input { 43 | if len(vec) != ln.Dim { 44 | return nil, fmt.Errorf("input dimension mismatch: expected %d, got %d", ln.Dim, len(vec)) 45 | } 46 | 47 | // Compute mean 48 | mean := 0.0 49 | for _, v := range vec { 50 | mean += v 51 | } 52 | mean /= float64(ln.Dim) 53 | 54 | // Compute variance 55 | variance := 0.0 56 | for _, v := range vec { 57 | diff := v - mean 58 | variance += diff * diff 59 | } 60 | variance /= float64(ln.Dim) 61 | 62 | // Normalize 63 | normalized := make(attention.Vector, ln.Dim) 64 | stdDev := math.Sqrt(variance + ln.Eps) 65 | for j, v := range vec { 66 | normalized[j] = ln.Gamma[j]*((v-mean)/stdDev) + ln.Beta[j] 67 | } 68 | 69 | output[i] = normalized 70 | } 71 | 72 | return output, nil 73 | } 74 | 75 | // FeedForward implements a position-wise feed-forward network 76 | type FeedForward struct { 77 | DModel int 78 | DHidden int 79 | W1, W2 attention.Matrix 80 | B1, B2 attention.Vector 81 | } 82 | 83 | // NewFeedForward creates a new feed-forward network 84 | func NewFeedForward(dModel, dHidden int) *FeedForward { 85 | ff := &FeedForward{ 86 | DModel: dModel, 87 | DHidden: dHidden, 88 | W1: make(attention.Matrix, dModel), 89 | W2: make(attention.Matrix, dHidden), 90 | B1: make(attention.Vector, dHidden), 91 | B2: make(attention.Vector, dModel), 92 | } 93 | 94 | // Initialize weights with Xavier initialization 95 | scale1 := math.Sqrt(2.0 / float64(dModel+dHidden)) 96 | scale2 := math.Sqrt(2.0 
/ float64(dHidden+dModel)) 97 | 98 | for i := range ff.W1 { 99 | ff.W1[i] = make(attention.Vector, dHidden) 100 | for j := range ff.W1[i] { 101 | ff.W1[i][j] = (rand.Float64() - 0.5) * scale1 102 | } 103 | } 104 | 105 | for i := range ff.W2 { 106 | ff.W2[i] = make(attention.Vector, dModel) 107 | for j := range ff.W2[i] { 108 | ff.W2[i][j] = (rand.Float64() - 0.5) * scale2 109 | } 110 | } 111 | 112 | return ff 113 | } 114 | 115 | // Forward applies the feed-forward network 116 | func (ff *FeedForward) Forward(input attention.Matrix) (attention.Matrix, error) { 117 | // First layer 118 | hidden := make(attention.Matrix, len(input)) 119 | for i, vec := range input { 120 | projected, err := projectVector(vec, ff.W1) 121 | if err != nil { 122 | return nil, fmt.Errorf("projecting first layer: %w", err) 123 | } 124 | 125 | // Add bias and apply ReLU 126 | hidden[i] = make(attention.Vector, ff.DHidden) 127 | for j := range projected { 128 | val := projected[j] + ff.B1[j] 129 | if val > 0 { 130 | hidden[i][j] = val 131 | } 132 | } 133 | } 134 | 135 | // Second layer 136 | output := make(attention.Matrix, len(input)) 137 | for i, vec := range hidden { 138 | projected, err := projectVector(vec, ff.W2) 139 | if err != nil { 140 | return nil, fmt.Errorf("projecting second layer: %w", err) 141 | } 142 | 143 | // Add bias 144 | output[i] = make(attention.Vector, ff.DModel) 145 | for j := range projected { 146 | output[i][j] = projected[j] + ff.B2[j] 147 | } 148 | } 149 | 150 | return output, nil 151 | } 152 | 153 | // Helper function to project a vector through a weight matrix 154 | func projectVector(input attention.Vector, weights attention.Matrix) (attention.Vector, error) { 155 | if len(weights) == 0 || len(weights[0]) == 0 { 156 | return nil, fmt.Errorf("empty weight matrix") 157 | } 158 | if len(input) != len(weights) { 159 | return nil, fmt.Errorf("input dimension (%d) does not match weights (%d)", len(input), len(weights)) 160 | } 161 | 162 | output := 
make(attention.Vector, len(weights[0])) 163 | for i := range weights[0] { 164 | for j, w := range weights { 165 | output[i] += input[j] * w[i] 166 | } 167 | } 168 | return output, nil 169 | } 170 | 171 | // TransformerConfig holds configuration for a transformer layer 172 | type TransformerConfig struct { 173 | DModel int // Model dimension 174 | NumHeads int // Number of attention heads 175 | DHidden int // Hidden dimension in feed-forward network 176 | DropoutRate float64 // Dropout rate 177 | } 178 | 179 | // TransformerLayer implements a single transformer layer 180 | type TransformerLayer struct { 181 | Config TransformerConfig 182 | SelfAttn *attention.MultiHeadAttention 183 | FeedForward *FeedForward 184 | Norm1 *LayerNorm 185 | Norm2 *LayerNorm 186 | } 187 | 188 | // NewTransformerLayer creates a new transformer layer 189 | func NewTransformerLayer(config TransformerConfig) (*TransformerLayer, error) { 190 | // Create multi-head attention 191 | attnConfig := attention.MultiHeadConfig{ 192 | NumHeads: config.NumHeads, 193 | DModel: config.DModel, 194 | DKey: config.DModel / config.NumHeads, 195 | DValue: config.DModel / config.NumHeads, 196 | DropoutRate: config.DropoutRate, 197 | } 198 | 199 | selfAttn, err := attention.NewMultiHeadAttention(attnConfig) 200 | if err != nil { 201 | return nil, fmt.Errorf("creating self-attention: %w", err) 202 | } 203 | 204 | return &TransformerLayer{ 205 | Config: config, 206 | SelfAttn: selfAttn, 207 | FeedForward: NewFeedForward(config.DModel, config.DHidden), 208 | Norm1: NewLayerNorm(config.DModel, 1e-5), 209 | Norm2: NewLayerNorm(config.DModel, 1e-5), 210 | }, nil 211 | } 212 | 213 | // Forward applies the transformer layer 214 | func (t *TransformerLayer) Forward(input attention.Matrix) (attention.Matrix, error) { 215 | // Self-attention sub-layer 216 | normalized1, err := t.Norm1.Forward(input) 217 | if err != nil { 218 | return nil, fmt.Errorf("normalizing input: %w", err) 219 | } 220 | 221 | attended, err := 
t.SelfAttn.Forward(normalized1, normalized1, normalized1) 222 | if err != nil { 223 | return nil, fmt.Errorf("computing self-attention: %w", err) 224 | } 225 | 226 | // Add & Norm 227 | residual1 := make(attention.Matrix, len(input)) 228 | for i := range input { 229 | residual1[i], err = attention.AddVectors(input[i], attended[i]) 230 | if err != nil { 231 | return nil, fmt.Errorf("adding residual connection: %w", err) 232 | } 233 | } 234 | 235 | // Feed-forward sub-layer 236 | normalized2, err := t.Norm2.Forward(residual1) 237 | if err != nil { 238 | return nil, fmt.Errorf("normalizing first sub-layer output: %w", err) 239 | } 240 | 241 | ffOutput, err := t.FeedForward.Forward(normalized2) 242 | if err != nil { 243 | return nil, fmt.Errorf("computing feed-forward: %w", err) 244 | } 245 | 246 | // Add & Norm 247 | output := make(attention.Matrix, len(input)) 248 | for i := range input { 249 | output[i], err = attention.AddVectors(residual1[i], ffOutput[i]) 250 | if err != nil { 251 | return nil, fmt.Errorf("adding final residual connection: %w", err) 252 | } 253 | } 254 | 255 | return output, nil 256 | } --------------------------------------------------------------------------------