├── go.mod ├── tensor ├── error.go ├── broadtcast_test.go ├── tensor_utils_test.go ├── math_test.go ├── scalar.go ├── broadcast.go ├── math.go ├── tensor.go └── tensor_utils.go ├── main.go ├── README.md └── nn └── layers └── dense.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/biraj21/nnfs-go 2 | 3 | go 1.22.0 4 | -------------------------------------------------------------------------------- /tensor/error.go: -------------------------------------------------------------------------------- 1 | package tensor 2 | 3 | const ( 4 | // Error message for an empty array or slice. 5 | ErrorEmptyArraySlice = "Found an empty array or slice!" 6 | 7 | // Error message for a value that's neither an array nor slice. 8 | ErrorNonArraySlice = "Value is neither an array nor a slice!" 9 | 10 | // Error message for a non-homologous tensor value. 11 | ErrorNonHomologous = "Tensor is not homologous!" 12 | 13 | // Error message for a non-2D matrix. 14 | ErrorMatMulConflictingDims = "The number of columns in the first matrix must be equal to the number of rows in the second matrix!" 15 | 16 | // Error message for a shape mismatch. 17 | ErrorShapeMismatch = "Shapes of the tensors do not match!" 18 | 19 | // Error message for incompatible reshaping. 20 | ErrorIncompatibleReshape = "Incompatible reshaping!" 21 | 22 | // Error message for tensors incompatible for broadcast 23 | ErrorCannotBroadcast = "Tensors could not be broadcast together!" 
24 | ) 25 | -------------------------------------------------------------------------------- /tensor/broadtcast_test.go: -------------------------------------------------------------------------------- 1 | package tensor 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestBroadcast(t *testing.T) { 9 | t1 := WithValue[int]([]int{1, 2, 3}) 10 | t2 := WithValue[int]([][]int{ 11 | {1}, 12 | {2}, 13 | {3}, 14 | }) 15 | 16 | expectedLen := 2 17 | broadcasts := Broadcast(t1, t2) 18 | if expectedLen != len(broadcasts) { 19 | t.Fatalf("len(broadcasts): expected %v, found %v", expectedLen, len(broadcasts)) 20 | } 21 | 22 | expectedBt1 := WithValue[int]([][]int{ 23 | {1, 2, 3}, 24 | {1, 2, 3}, 25 | {1, 2, 3}, 26 | }) 27 | bt1 := broadcasts[0].ToTensor() 28 | if !reflect.DeepEqual(expectedBt1, bt1) { 29 | t.Fatalf("broadcasts[0].ToTensor(): expected %v, found %v", expectedBt1, bt1) 30 | } 31 | 32 | expectedBt2 := WithValue[int]([][]int{ 33 | {1, 1, 1}, 34 | {2, 2, 2}, 35 | {3, 3, 3}, 36 | }) 37 | bt2 := broadcasts[1].ToTensor() 38 | if !reflect.DeepEqual(expectedBt2, bt2) { 39 | t.Fatalf("broadcasts[1].ToTensor(): expected %v, found %v", expectedBt2, bt2) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /tensor/tensor_utils_test.go: -------------------------------------------------------------------------------- 1 | package tensor 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestCopyWithPadding(t *testing.T) { 9 | padWith := -2 10 | 11 | arr := []int{1, 2, 3} 12 | dest := make([]int, 5) 13 | 14 | copyWithPadding(dest, arr, padWith) 15 | 16 | expected := []int{padWith, padWith, 1, 2, 3} 17 | if !reflect.DeepEqual(expected, dest) { 18 | t.Fatalf("copyWithPadding(): expected %v, got %v", expected, dest) 19 | } 20 | } 21 | 22 | func TestAreShapesBroadcastableYes(t *testing.T) { 23 | areBroadcastable := areShapesBroadcastable( 24 | []uint{5, 1}, 25 | []uint{1, 6}, 26 | []uint{6}, 27 | []uint{}, 28 
| ) 29 | 30 | expected := true 31 | if expected != areBroadcastable { 32 | t.Fatalf("areBroadcastable(): expected %v, got %v", expected, areBroadcastable) 33 | } 34 | } 35 | 36 | func TestAreShapesBroadcastableNo(t *testing.T) { 37 | areBroadcastable := areShapesBroadcastable( 38 | []uint{4, 3}, 39 | []uint{4}, 40 | ) 41 | 42 | expected := false 43 | if expected != areBroadcastable { 44 | t.Fatalf("areBroadcastable(): expected %v, got %v", expected, areBroadcastable) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/biraj21/nnfs-go/nn/layers" 7 | "github.com/biraj21/nnfs-go/tensor" 8 | ) 9 | 10 | func main() { 11 | // 3 samples x 4 features 12 | inputs := tensor.WithValue[float64]([][]float64{ 13 | {1.0, 2.0, 3.0, 2.5}, 14 | {2.0, 5.0, -1.0, 2.0}, 15 | {-1.5, 2.7, 3.3, -0.8}, 16 | }) 17 | 18 | // 4 inputs x 3 neurons (basically each column is weights per neuron) 19 | weights1 := tensor.WithValue[float64]([][]float64{ 20 | {0.2, 0.8, -0.5, 1.0}, 21 | {0.5, -0.91, 0.26, -0.5}, 22 | {-0.26, -0.27, 0.17, 0.87}, 23 | }).Transpose() 24 | 25 | // 3 neurons 26 | biases1 := tensor.WithValue[float64]([][]float64{{2.0, 3.0, 0.5}}) 27 | 28 | // 3 inputs x 3 neurons (3 inputs cuz prev layer has 3 neurons) 29 | weights2 := tensor.WithValue[float64]([][]float64{ 30 | {0.1, -0.14, 0.5}, 31 | {-0.5, 0.12, -0.33}, 32 | {-0.44, 0.73, -0.13}, 33 | }).Transpose() 34 | 35 | // 3 neurons 36 | biases2 := tensor.WithValue[float64]([][]float64{{-1, 2, -0.5}}) 37 | 38 | l1 := layers.DenseInit(3, 4, layers.DenseInitTensors{Weights: weights1, Biases: biases1}) 39 | l1Output := l1.Forward(inputs) 40 | 41 | l2 := layers.DenseInit(3, 3, layers.DenseInitTensors{Weights: weights2, Biases: biases2}) 42 | l2Output := l2.Forward(l1Output) 43 | 44 | fmt.Println(l2Output) 45 | } 46 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Networks from Scratch in Go 2 | 3 | This repository contains a neural network implementation and tensor operations written from scratch in Go. It is inspired by the book [Neural Networks from Scratch](https://nnfs.io/) by Harrison Kinsley (Sentdex) and Daniel Kukiela. 4 | 5 | The book is written in Python (well, English), and neural networks are implemented using NumPy. Since there is no NumPy in Go (not that I'm aware of), I am also building a bit of NumPy functionality from scratch in Go as a byproduct. Even if there was a NumPy equivalent in Go, I probably wouldn't use it. 6 | 7 | **Note:** This project is still a work in progress. It's barely started, to be honest. I'm working on it in my free time, and I'm not sure how far I'll go with this project. I'm doing it for fun and to learn more about neural networks and Go. 8 | 9 | ## Acknowledgements 10 | 11 | - Neural Networks from Scratch book by Harrison Kinsley and Daniel Kukiela. 12 | - 3Blue1Brown YouTube channel by Grant Sanderson. 13 | - Andrej Karpathy's YouTube channel. 14 | - Tomas Beuzen, for [his article on NumPy](https://www.tomasbeuzen.com/python-programming-for-data-science/chapters/chapter6-numpy-addendum.html#memory-layout-and-strides), where I learned about memory layout and strides. 15 | - Myself, for my array visualization tool at [arrayvis.netlify.app](https://arrayvis.netlify.app/). 
16 | -------------------------------------------------------------------------------- /tensor/math_test.go: -------------------------------------------------------------------------------- 1 | package tensor 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | var t1 = WithValue[int]([]int{1, 2, 3}) 9 | var t2 = WithValue[int]([][]int{ 10 | {1}, 11 | {2}, 12 | {3}, 13 | }) 14 | 15 | func TestAdd(t *testing.T) { 16 | expected := WithValue[int]([][]int{ 17 | {2, 3, 4}, 18 | {3, 4, 5}, 19 | {4, 5, 6}, 20 | }) 21 | 22 | result := Add(t1, t2) 23 | if !reflect.DeepEqual(expected, result) { 24 | t.Fatalf("expected %v, got %v", expected, result) 25 | } 26 | } 27 | 28 | func TestSubtract(t *testing.T) { 29 | expected := WithValue[int]([][]int{ 30 | {0, 1, 2}, 31 | {-1, 0, 1}, 32 | {-2, -1, 0}, 33 | }) 34 | 35 | result := Subtract(t1, t2) 36 | if !reflect.DeepEqual(expected, result) { 37 | t.Fatalf("expected %v, got %v", expected, result) 38 | } 39 | } 40 | 41 | func TestMultiply(t *testing.T) { 42 | expected := WithValue[int]([][]int{ 43 | {1, 2, 3}, 44 | {2, 4, 6}, 45 | {3, 6, 9}, 46 | }) 47 | 48 | result := Multiply(t1, t2) 49 | if !reflect.DeepEqual(expected, result) { 50 | t.Fatalf("expected %v, got %v", expected, result) 51 | } 52 | } 53 | 54 | func TestDivide(t *testing.T) { 55 | expected := WithValue[int]([][]int{ 56 | {1, 2, 3}, 57 | {0, 1, 1}, 58 | {0, 0, 1}, 59 | }) 60 | 61 | result := Divide(t1, t2) 62 | if !reflect.DeepEqual(expected, result) { 63 | t.Fatalf("expected %v, got %v", expected, result) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /tensor/scalar.go: -------------------------------------------------------------------------------- 1 | package tensor 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | type IntScalar interface { 9 | int | int8 | int16 | int32 | int64 10 | } 11 | 12 | type UintScalar interface { 13 | uint | uint8 | uint16 | uint32 | uint64 | uintptr 14 | } 15 | 16 | type 
FloatScalar interface { 17 | float32 | float64 18 | } 19 | 20 | // NumericScalarReal is a type that is a real number (int, float64, etc.) 21 | type NumericScalarReal interface { 22 | IntScalar | UintScalar | FloatScalar 23 | } 24 | 25 | // complex number is was making it difficult for me to create randomBetween function 26 | // so i removed them. skill issue 27 | // type NumericScalarComplex interface { 28 | // complex64 | complex128 29 | // } 30 | 31 | // Scalar is a type that is only a single value, not a collection of values. For example, int, float64, etc. 32 | // 33 | // Note that it doesn't include booleans because they are not numeric, and thus numeric operations can't be performed on them. 34 | type Scalar interface { 35 | NumericScalarReal 36 | } 37 | 38 | // Checks if the provided value is a Scalar or not. Panics if it's a Scalar but not of the expected type. 39 | func IsScalar[T interface{}](value interface{}) bool { 40 | val := reflect.ValueOf(value) 41 | switch val.Kind() { 42 | case reflect.Bool, 43 | reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 44 | reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, 45 | reflect.Float32, reflect.Float64, 46 | reflect.Complex64, reflect.Complex128: 47 | 48 | switch val.Interface().(type) { 49 | case T: 50 | // eat 5-star, do nothing 51 | default: 52 | var validScalar T 53 | panic(fmt.Sprintf("Data type mismatch! 
Expected a Scalar of type %T, found %T", validScalar, val.Interface())) 54 | } 55 | return true 56 | default: 57 | return false 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /nn/layers/dense.go: -------------------------------------------------------------------------------- 1 | package layers 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "reflect" 7 | 8 | "github.com/biraj21/nnfs-go/tensor" 9 | ) 10 | 11 | type Dense struct { 12 | numNeurons uint 13 | numInputs uint 14 | 15 | // shape: numInputs x numNeurons 16 | weights *tensor.Tensor[float64] 17 | 18 | // shape: 1 x numNeurons 19 | biases *tensor.Tensor[float64] 20 | } 21 | 22 | // use this to initialize weights & biases with the given tensors 23 | type DenseInitTensors struct { 24 | Weights *tensor.Tensor[float64] 25 | Biases *tensor.Tensor[float64] 26 | } 27 | 28 | func DenseInit(numNeurons, numInputs uint, initTensors DenseInitTensors) Dense { 29 | layer := Dense{ 30 | numNeurons: numNeurons, 31 | numInputs: numInputs, 32 | } 33 | 34 | weightsShape := []uint{numInputs, numNeurons} 35 | biasesShape := []uint{1, numNeurons} 36 | 37 | if initTensors.Weights == nil { 38 | // initialize weights using Uniform Xavier initialization 39 | // note: i just know that this is one of the state of the art methods 40 | // for weight initialization. idk why. 
out of league rn 41 | x := math.Sqrt(6.0 / float64((numNeurons + numInputs))) 42 | layer.weights = tensor.WithRandom[float64](weightsShape, -x, x) 43 | } else { 44 | if reflect.DeepEqual(initTensors.Weights.Shape(), weightsShape) { 45 | layer.weights = initTensors.Weights 46 | } else { 47 | panic(fmt.Sprintf( 48 | "DenseInit(): expected weights of shape %v, got %v", 49 | weightsShape, 50 | initTensors.Weights.Shape(), 51 | )) 52 | } 53 | } 54 | 55 | if initTensors.Biases == nil { 56 | // initialize with zero 57 | layer.biases = tensor.WithShape[float64](biasesShape) 58 | } else { 59 | if reflect.DeepEqual(initTensors.Biases.Shape(), biasesShape) { 60 | layer.biases = initTensors.Biases 61 | } else { 62 | panic(fmt.Sprintf( 63 | "DenseInit(): expected biases of shape %v, got %v", 64 | biasesShape, 65 | initTensors.Biases.Shape(), 66 | )) 67 | } 68 | } 69 | 70 | return layer 71 | } 72 | 73 | func (d *Dense) Forward(inputs *tensor.Tensor[float64]) *tensor.Tensor[float64] { 74 | output := tensor.MatrixMultiplication(inputs, d.weights).Add(d.biases) 75 | return output 76 | } 77 | -------------------------------------------------------------------------------- /tensor/broadcast.go: -------------------------------------------------------------------------------- 1 | package tensor 2 | 3 | import "fmt" 4 | 5 | type BroadcastTensor[T Scalar] struct { 6 | // new shape after broadcast 7 | shape []uint 8 | 9 | // the tensor which this broadcast represents 10 | tensor *Tensor[T] 11 | } 12 | 13 | // Returns the shape of the BroadcastTensor. 14 | func (b *BroadcastTensor[T]) Shape() []uint { 15 | return b.shape 16 | } 17 | 18 | func (b *BroadcastTensor[T]) Get(indices ...int) T { 19 | if len(indices) != len(b.shape) { 20 | panic(fmt.Sprintf("Invalid number of indices %d for broadcast of shape %v", len(indices), b.shape)) 21 | } 22 | 23 | tensorIndices := make([]int, len(b.tensor.shape)) 24 | 25 | // skip extra indices. 
for eg, broadcast has 5 dims & actual tensor has just 3, then skip the first 2 indices 26 | copy(tensorIndices, indices[len(indices)-len(b.tensor.shape):]) 27 | 28 | // make remaining indices compatible with the tensor's shape 29 | for i, dimSize := range b.tensor.shape { 30 | if int(dimSize) <= tensorIndices[i] { 31 | tensorIndices[i] = int(dimSize) - 1 32 | } 33 | } 34 | 35 | return b.tensor.Get(tensorIndices...) 36 | } 37 | 38 | func (b *BroadcastTensor[T]) dataIndexToIndices(dataIndex int) []int { 39 | strides := calculateStrides(b.shape) 40 | 41 | indices := make([]int, len(strides)) 42 | for i := len(strides) - 1; i >= 0; i-- { 43 | // Calculate the index for the current dimension 44 | indices[i] = dataIndex % int(b.shape[i]) 45 | // Update the remaining dataIndex for the next dimension 46 | dataIndex /= int(b.shape[i]) 47 | } 48 | 49 | return indices 50 | } 51 | 52 | func (b *BroadcastTensor[T]) FlattenedGet(index int) T { 53 | indices := b.dataIndexToIndices(index) 54 | return b.Get(indices...) 55 | } 56 | 57 | // Converts the BroadcastTensor to Tensor. 58 | func (b *BroadcastTensor[T]) ToTensor() *Tensor[T] { 59 | t := WithShape[T](b.shape) 60 | 61 | for _, indices := range getAllIndices(b.shape) { 62 | t.Set(indices, b.Get(indices...)) 63 | } 64 | 65 | return t 66 | } 67 | 68 | func CanBroadcast[T Scalar](tensors ...*Tensor[T]) bool { 69 | if len(tensors) == 0 { 70 | panic("At least one tensor is required!") 71 | } 72 | 73 | if len(tensors) == 1 { 74 | return true 75 | } 76 | 77 | shapes := make([][]uint, len(tensors)) 78 | 79 | for i, t := range tensors { 80 | shapes[i] = t.shape 81 | } 82 | 83 | return areShapesBroadcastable(shapes...) 84 | } 85 | 86 | func Broadcast[T Scalar](tensors ...*Tensor[T]) []*BroadcastTensor[T] { 87 | if len(tensors) == 0 { 88 | return []*BroadcastTensor[T]{} 89 | } 90 | 91 | if !CanBroadcast(tensors...) 
{ 92 | panic(ErrorCannotBroadcast) 93 | } 94 | 95 | maxDimensions := 0 96 | for _, t := range tensors { 97 | if len(t.shape) > maxDimensions { 98 | maxDimensions = len(t.shape) 99 | } 100 | } 101 | 102 | broadcastShape := make([]uint, maxDimensions) 103 | for i := 0; i < maxDimensions; i++ { 104 | broadcastShape[i] = 0 105 | } 106 | 107 | for _, t := range tensors { 108 | for j, size := range t.shape { 109 | i := maxDimensions - len(t.shape) + j 110 | if size > broadcastShape[i] { 111 | broadcastShape[i] = size 112 | } 113 | } 114 | } 115 | 116 | broadcast := make([]*BroadcastTensor[T], len(tensors)) 117 | for i, t := range tensors { 118 | broadcast[i] = &BroadcastTensor[T]{ 119 | shape: broadcastShape, 120 | tensor: t, 121 | } 122 | } 123 | 124 | return broadcast 125 | } 126 | -------------------------------------------------------------------------------- /tensor/math.go: -------------------------------------------------------------------------------- 1 | package tensor 2 | 3 | // Performs matrix multiplication on two 2D matrices. 4 | func MatrixMultiplication[T Scalar](t1, t2 *Tensor[T]) (result *Tensor[T]) { 5 | // check if both tensors are 2D matrices 6 | if len(t1.shape) != 2 || len(t2.shape) != 2 { 7 | panic("Both tensors must be 2D matrices!") 8 | } 9 | 10 | // check if the number of columns in the first matrix is equal to the number of rows in the second matrix 11 | if t1.shape[1] != t2.shape[0] { 12 | panic(ErrorMatMulConflictingDims) 13 | } 14 | 15 | resultShape := []uint{t1.shape[0], t2.shape[1]} 16 | result = WithShape[T](resultShape) 17 | 18 | for r := 0; r < int(resultShape[0]); r++ { 19 | for c := 0; c < int(resultShape[1]); c++ { 20 | sumOfProducts := T(0) 21 | for k := 0; k < int(t1.shape[1]); k++ { 22 | sumOfProducts += t1.Get(r, k) * t2.Get(k, c) 23 | } 24 | 25 | result.Set([]int{r, c}, sumOfProducts) 26 | } 27 | } 28 | 29 | return result 30 | } 31 | 32 | // Adds two tensors. 
33 | func Add[T Scalar](t1, t2 *Tensor[T]) *Tensor[T] { 34 | broadcasts := Broadcast(t1, t2) 35 | b1 := broadcasts[0] 36 | b2 := broadcasts[1] 37 | 38 | result := WithShape[T](b1.shape) 39 | numElements := int(countElementsFromShape(result.shape)) 40 | for i := 0; i < numElements; i++ { 41 | result.data[i] = b1.FlattenedGet(i) + b2.FlattenedGet(i) 42 | } 43 | 44 | return result 45 | } 46 | 47 | // Subtracts two tensors. 48 | func Subtract[T Scalar](t1, t2 *Tensor[T]) *Tensor[T] { 49 | broadcasts := Broadcast(t1, t2) 50 | b1 := broadcasts[0] 51 | b2 := broadcasts[1] 52 | 53 | result := WithShape[T](b1.shape) 54 | numElements := int(countElementsFromShape(result.shape)) 55 | for i := 0; i < numElements; i++ { 56 | result.data[i] = b1.FlattenedGet(i) - b2.FlattenedGet(i) 57 | } 58 | 59 | return result 60 | } 61 | 62 | // Multiplies two tensors. 63 | func Multiply[T Scalar](t1, t2 *Tensor[T]) *Tensor[T] { 64 | broadcasts := Broadcast(t1, t2) 65 | b1 := broadcasts[0] 66 | b2 := broadcasts[1] 67 | 68 | result := WithShape[T](b1.shape) 69 | numElements := int(countElementsFromShape(result.shape)) 70 | for i := 0; i < numElements; i++ { 71 | result.data[i] = b1.FlattenedGet(i) * b2.FlattenedGet(i) 72 | } 73 | 74 | return result 75 | } 76 | 77 | // Divides two tensors. 78 | func Divide[T Scalar](t1, t2 *Tensor[T]) *Tensor[T] { 79 | broadcasts := Broadcast(t1, t2) 80 | b1 := broadcasts[0] 81 | b2 := broadcasts[1] 82 | 83 | result := WithShape[T](b1.shape) 84 | numElements := int(countElementsFromShape(result.shape)) 85 | for i := 0; i < numElements; i++ { 86 | result.data[i] = b1.FlattenedGet(i) / b2.FlattenedGet(i) 87 | } 88 | 89 | return result 90 | } 91 | 92 | // Returns the transpose of the given tensor. 
93 | func Transpose[T Scalar](t *Tensor[T]) *Tensor[T] { 94 | numDimensions := len(t.shape) 95 | 96 | // transpose of a 0D & 1D tensors is the same as itself 97 | if numDimensions < 2 { 98 | return t.Copy() 99 | } 100 | 101 | // reverse the shape of the given tensor to obtain the shape of the transposed tensor 102 | reversedShape := make([]uint, numDimensions) 103 | for i := 0; i < numDimensions; i++ { 104 | reversedShape[i] = t.shape[numDimensions-1-i] 105 | } 106 | 107 | // initialize the transposed Tensor 108 | transposedTensor := WithShape[T](reversedShape) 109 | 110 | // get all the indices of the original tensor 111 | originalIndices := getAllIndices(t.shape) 112 | 113 | // array that will hold the reversed location 114 | reversedLocation := make([]int, numDimensions) 115 | 116 | // traverse the original tensor and set the values to the transposed tensor 117 | for _, location := range originalIndices { 118 | // reverse the current location 119 | for j := 0; j < numDimensions; j++ { 120 | reversedLocation[j] = location[numDimensions-1-j] 121 | } 122 | 123 | // set the value at the reversed location to the transposed tensor 124 | transposedTensor.Set(reversedLocation, t.Get(location...)) 125 | } 126 | 127 | return transposedTensor 128 | } 129 | -------------------------------------------------------------------------------- /tensor/tensor.go: -------------------------------------------------------------------------------- 1 | package tensor 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | // Tensor is a struct that represents a multi-dimensional array. 9 | type Tensor[T Scalar] struct { 10 | data []T 11 | dataType reflect.Type 12 | shape []uint 13 | strides []uint 14 | } 15 | 16 | // Returns the value of the tensor. 17 | func (t *Tensor[T]) Value() interface{} { 18 | return t.data 19 | } 20 | 21 | // Returns the shape of the tensor. 
22 | func (t *Tensor[T]) Shape() []uint { 23 | return t.shape 24 | } 25 | 26 | // Returns the number of dimensions of the tensor. 27 | func (t *Tensor[T]) NDims() int { 28 | return len(t.shape) 29 | } 30 | 31 | // Returns the data type of the tensor. 32 | func (t *Tensor[T]) DataType() reflect.Type { 33 | return t.dataType 34 | } 35 | 36 | func (t *Tensor[T]) Copy() *Tensor[T] { 37 | dataCopy := make([]T, len(t.data)) 38 | copy(dataCopy, t.data) 39 | 40 | return &Tensor[T]{ 41 | data: dataCopy, 42 | dataType: t.dataType, 43 | shape: t.shape, 44 | strides: t.strides, 45 | } 46 | } 47 | 48 | func (t *Tensor[T]) Reshape(newDims ...uint) *Tensor[T] { 49 | if len(newDims) == 0 { 50 | panic("Cannot reshape to an empty shape!") 51 | } 52 | 53 | // make sure that reshaping is possible 54 | if countElementsFromShape(newDims) != countElementsFromShape(t.shape) { 55 | panic(fmt.Sprintf("Incompatible reshaping: %v -> %v", t.shape, newDims)) 56 | } 57 | 58 | // create a copy of this tensor 59 | tCopy := t.Copy() 60 | 61 | // update the shape and strides 62 | tCopy.shape = newDims 63 | tCopy.strides = calculateStrides(newDims) 64 | 65 | return tCopy 66 | } 67 | 68 | // Converts multidimensional indices to the index in the flattened data array representation. 69 | func (t *Tensor[T]) indicesToDataIndex(indices ...int) int { 70 | dataIndex := 0 71 | for i, index := range indices { 72 | dataIndex += index * int(t.strides[i]) 73 | } 74 | 75 | return dataIndex 76 | } 77 | 78 | func (t *Tensor[T]) Get(indices ...int) T { 79 | if len(indices) != len(t.shape) { 80 | panic(fmt.Sprintf("Invalid number of indices %d for tensor of shape %v", len(indices), t.shape)) 81 | } 82 | 83 | index := t.indicesToDataIndex(indices...) 
84 | return t.data[index] 85 | } 86 | 87 | func (t *Tensor[T]) Set(indices []int, value T) { 88 | if len(indices) != len(t.shape) { 89 | panic(fmt.Sprintf("Invalid number of indices %d for tensor of shape %v", len(indices), t.shape)) 90 | } 91 | 92 | index := t.indicesToDataIndex(indices...) 93 | t.data[index] = value 94 | } 95 | 96 | // Returns a string representation of the tensor. 97 | func (t *Tensor[T]) String() string { 98 | if t.NDims() < 2 { 99 | return fmt.Sprintf("Tensor(%v)", prettifyTensorValue(t.data)) 100 | } 101 | 102 | // when we've more than 1 dimensions, we first convert the flattened t.data to a multi-dimensional representation 103 | // this makes it easier to work with the data while creating its string representation 104 | 105 | tensorValue := initTensorValue[T](t.shape) 106 | tensorIndices := getAllIndices(t.shape) 107 | for i := 0; i < len(tensorIndices); i++ { 108 | v := t.Get(tensorIndices[i]...) 109 | valueAt(tensorValue, tensorIndices[i]...).Set(reflect.ValueOf(v)) 110 | } 111 | 112 | return fmt.Sprintf("Tensor(%v)", prettifyTensorValue(tensorValue)) 113 | } 114 | 115 | // Adds two tensors. 116 | func (t *Tensor[T]) Add(t2 *Tensor[T]) *Tensor[T] { 117 | return Add(t, t2) 118 | } 119 | 120 | // Subtracts two tensors. 121 | func (t *Tensor[T]) Subtract(t2 *Tensor[T]) *Tensor[T] { 122 | return Subtract(t, t2) 123 | } 124 | 125 | // Multiplies two tensors. 126 | func (t *Tensor[T]) Multiply(t2 *Tensor[T]) *Tensor[T] { 127 | return Multiply(t, t2) 128 | } 129 | 130 | // Divides two tensors. 131 | func (t *Tensor[T]) Divide(t2 *Tensor[T]) *Tensor[T] { 132 | return Divide(t, t2) 133 | } 134 | 135 | // Returns the transpose of the tensor. 136 | func (t *Tensor[T]) Transpose() *Tensor[T] { 137 | return Transpose(t) 138 | } 139 | 140 | // Creates a new tensor with the given shape and initial value. Zero by default. 
141 | func WithShape[T Scalar](shape []uint, initialValue ...T) *Tensor[T] { 142 | if len(initialValue) > 1 { 143 | panic("Only one initial value is allowed!") 144 | } 145 | 146 | for i, dim := range shape { 147 | if dim <= 0 { 148 | panic(fmt.Sprintf("Invalid shape: dimension %d cannot be %d", i, dim)) 149 | } 150 | } 151 | 152 | data := make([]T, countElementsFromShape(shape)) 153 | if len(initialValue) > 0 { 154 | for i := 0; i < len(data); i++ { 155 | data[i] = initialValue[0] 156 | } 157 | } 158 | 159 | // dummy variable to get the data type at runtime 160 | var dataType T 161 | 162 | return &Tensor[T]{ 163 | data: data, 164 | dataType: reflect.TypeOf(dataType), 165 | shape: shape, 166 | strides: calculateStrides(shape), 167 | } 168 | } 169 | 170 | func WithRandom[T Scalar](shape []uint, minValue, maxValue T) *Tensor[T] { 171 | for i, dim := range shape { 172 | if dim <= 0 { 173 | panic(fmt.Sprintf("Invalid shape: dimension %d cannot be %d", i, dim)) 174 | } 175 | } 176 | 177 | data := make([]T, countElementsFromShape(shape)) 178 | for i := 0; i < len(data); i++ { 179 | data[i] = randomBetween(minValue, maxValue) 180 | } 181 | 182 | // dummy variable to get the data type at runtime 183 | var dataType T 184 | 185 | return &Tensor[T]{ 186 | data: data, 187 | dataType: reflect.TypeOf(dataType), 188 | shape: shape, 189 | strides: calculateStrides(shape), 190 | } 191 | } 192 | 193 | // Create a new tensor from the given value. 194 | func WithValue[T Scalar](data interface{}) *Tensor[T] { 195 | // validate that the tensor is homogenous, i.e., all elements are of the same type 196 | ensureHomogeneous[T](data) 197 | 198 | // validate that the tensor is homologous, i.e. 
each list along a dimension is of the same size 199 | ensureHomologous(data) 200 | 201 | // determine the shape of the tensor 202 | shape := detectShape(data) 203 | 204 | numElements := countElementsFromShape(shape) 205 | 206 | // its length would be same as the number of elements in the tensor 207 | tensorIndices := getAllIndices(shape) 208 | 209 | tensorData := make([]T, numElements) 210 | for i := uint(0); i < numElements; i++ { 211 | tensorData[i] = valueAt(data, tensorIndices[i]...).Interface().(T) 212 | } 213 | 214 | // just dummy variable to get the data type at runtime 215 | var dataType T 216 | 217 | // create the tensor & return 218 | return &Tensor[T]{ 219 | data: tensorData, 220 | dataType: reflect.TypeOf(dataType), 221 | shape: shape, 222 | strides: calculateStrides(shape), 223 | } 224 | } 225 | -------------------------------------------------------------------------------- /tensor/tensor_utils.go: -------------------------------------------------------------------------------- 1 | package tensor 2 | 3 | import ( 4 | "fmt" 5 | "math/rand/v2" 6 | "reflect" 7 | "strings" 8 | ) 9 | 10 | // Recursively initializes the tensor with the given shape and initial value. 11 | func initTensorValue[T Scalar](shape []uint, initialValue ...T) interface{} { 12 | if len(initialValue) > 1 { 13 | panic("There cannot be more than one initial value for a tensor!") 14 | } 15 | 16 | if len(shape) == 0 { 17 | if len(initialValue) == 0 { 18 | var emptyValue T 19 | return emptyValue 20 | } 21 | 22 | return initialValue[0] 23 | } 24 | 25 | if len(shape) == 1 { 26 | slice := make([]T, shape[0]) 27 | 28 | if len(initialValue) > 0 { 29 | for i := range slice { 30 | slice[i] = initialValue[0] 31 | } 32 | } 33 | 34 | return slice 35 | } 36 | 37 | slice := make([]interface{}, shape[0]) 38 | for i := range slice { 39 | slice[i] = initTensorValue(shape[1:], initialValue...) 
40 | } 41 | 42 | return slice 43 | } 44 | 45 | // Recursively validates that the values in the tensor are of Scalars and of the same type. 46 | func ensureHomogeneous[T Scalar](value interface{}) { 47 | val := reflect.ValueOf(value) 48 | kind := val.Kind() 49 | 50 | // if it's a Scalar, then just make sure that it's of type T 51 | if IsScalar[T](value) { 52 | return 53 | } 54 | 55 | // it should be a multi-dimensional array or slice 56 | if kind != reflect.Array && kind != reflect.Slice { 57 | panic(ErrorNonArraySlice) 58 | } 59 | 60 | // it should not be empty 61 | if val.Len() == 0 { 62 | panic(ErrorEmptyArraySlice) 63 | } 64 | 65 | // validate each element's type 66 | for i := 0; i < val.Len(); i++ { 67 | elem := val.Index(i) 68 | if elem.Kind() == reflect.Array || elem.Kind() == reflect.Slice { 69 | ensureHomogeneous[T](elem.Interface()) 70 | continue 71 | } 72 | 73 | if IsScalar[T](elem.Interface()) { 74 | continue 75 | } 76 | 77 | var validScalar T 78 | panic(fmt.Sprintf("Unexpected type %T in tensorand. Expected a Scalar of type %T", elem.Interface(), validScalar)) 79 | } 80 | } 81 | 82 | // Checks if the provided tensor value is homologous. Panics with an error message if it's not. 83 | // It first detects and ensure that the tensor matches the shape. Basically, it's a wrapper over detectShape() and ensureShape() functions. 84 | func ensureHomologous(value interface{}) { 85 | shape := detectShape(value) 86 | 87 | // if it's just a scalar or single-dimensional, then it's obviously homologous 88 | if len(shape) < 2 { 89 | return 90 | } 91 | 92 | ensureShape(value, shape, 0) 93 | } 94 | 95 | // Tries to detect the shape of the tensor value. 
96 | func detectShape(value interface{}) (shape []uint) { 97 | val := reflect.ValueOf(value) 98 | shape = []uint{} 99 | 100 | for { 101 | // validate that it's an array or slice 102 | valueKind := val.Kind() 103 | if valueKind != reflect.Array && valueKind != reflect.Slice { 104 | break 105 | } 106 | 107 | // validate that it's not empty 108 | if val.Len() == 0 { 109 | panic(ErrorEmptyArraySlice) 110 | } 111 | 112 | // append the size of the current dimension 113 | shape = append(shape, uint(val.Len())) 114 | 115 | // go deeper to get next dimension's size 116 | val = val.Index(0) 117 | } 118 | 119 | return shape 120 | } 121 | 122 | func ensureShape(value interface{}, shape []uint, currentDim int) { 123 | if len(shape) == 0 { 124 | panic("Invalid shape") 125 | } 126 | 127 | val := reflect.ValueOf(value) 128 | if val.Kind() != reflect.Array && val.Kind() != reflect.Slice { 129 | panic(ErrorNonArraySlice) 130 | } 131 | 132 | if uint(val.Len()) != shape[currentDim] { 133 | panic(fmt.Sprintf("%s Detected shape was %v, but there's a mismatch at dimension %d with size %d. Shouldn't it be of size %d?", 134 | ErrorNonHomologous, 135 | shape, 136 | currentDim, 137 | val.Len(), 138 | shape[currentDim], 139 | )) 140 | } 141 | 142 | for i := 0; i < val.Len(); i++ { 143 | elem := val.Index(i) 144 | 145 | if elem.Kind() == reflect.Array || elem.Kind() == reflect.Slice { 146 | ensureShape(elem.Interface(), shape, currentDim+1) 147 | continue 148 | } else if currentDim < len(shape)-1 { 149 | panic(fmt.Sprintf("%s Detected shape was %v, but found an unexpected %d at dimension %d. Expected a slice or array.", 150 | ErrorNonHomologous, 151 | shape, 152 | elem.Type(), 153 | currentDim, 154 | )) 155 | } 156 | } 157 | } 158 | 159 | // Returns a pretty string representation of the tensor value. 
160 | func prettifyTensorValue(value interface{}, indentation ...int) string { 161 | if IsScalar[any](value) { 162 | return fmt.Sprintf("%v", value) 163 | } 164 | 165 | val := reflect.ValueOf(value) 166 | if val.Kind() != reflect.Array && val.Kind() != reflect.Slice { 167 | panic(ErrorNonArraySlice) 168 | } 169 | 170 | if len(indentation) == 0 { 171 | indentation = []int{0} 172 | } 173 | 174 | indentationLevel := indentation[0] 175 | indentationStr := strings.Repeat(" ", indentationLevel) 176 | if val.Len() == 0 { 177 | return indentationStr + "[]" 178 | } 179 | 180 | isDeepest := true 181 | s := indentationStr + "[" 182 | for i := 0; i < val.Len(); i++ { 183 | elem := val.Index(i) 184 | 185 | if elem.Kind() == reflect.Interface { 186 | elem = elem.Elem() 187 | } 188 | 189 | if elem.Kind() == reflect.Array || elem.Kind() == reflect.Slice { 190 | s += "\n" + prettifyTensorValue(elem.Interface(), indentationLevel+2) 191 | isDeepest = false 192 | } else { 193 | s += fmt.Sprintf("%v", elem.Interface()) 194 | } 195 | 196 | // add comma if it's not the last element 197 | if i < val.Len()-1 { 198 | s += " " 199 | } 200 | 201 | // add new line if it's an array or slice 202 | if (elem.Kind() == reflect.Array || elem.Kind() == reflect.Slice) && i == val.Len()-1 { 203 | s += "\n" 204 | } 205 | } 206 | 207 | /// only add indentation before closing bracket if it's not the deepest level 208 | if !isDeepest { 209 | s += indentationStr 210 | } 211 | 212 | s += "]" 213 | return s 214 | } 215 | 216 | func calculateStrides(shape []uint) []uint { 217 | // if it's a zero-d array, then there's no strides 218 | if len(shape) == 0 { 219 | return []uint{} 220 | } 221 | 222 | // initialize strides. 
it's the same length as the shape 223 | strides := make([]uint, len(shape)) 224 | 225 | // last stride is always 1 226 | strides[len(shape)-1] = 1 227 | 228 | for i := len(shape) - 2; i >= 0; i-- { 229 | strides[i] = shape[i+1] * strides[i+1] 230 | } 231 | 232 | return strides 233 | } 234 | 235 | func valueAt(data interface{}, indices ...int) reflect.Value { 236 | val := reflect.ValueOf(data) 237 | 238 | // handle Scalars 239 | if len(indices) == 1 && IsScalar[any](data) { 240 | return val 241 | } 242 | 243 | if val.Kind() != reflect.Slice && val.Kind() != reflect.Array { 244 | panic(ErrorNonArraySlice) 245 | } 246 | 247 | for i, index := range indices { 248 | if val.Kind() == reflect.Interface { 249 | val = val.Elem() 250 | } 251 | 252 | if val.Kind() == reflect.Slice || val.Kind() == reflect.Array { 253 | val = val.Index(index) 254 | } else { 255 | panic(fmt.Sprintf("Invalid tensor value %v at %v", data, indices[:i+1])) 256 | } 257 | } 258 | 259 | return val 260 | } 261 | 262 | // Returns all the traversable indices for a tenson of the given shape. 263 | // 264 | // For example, if the shape is [2, 3], then the returned indices will be: 265 | // [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]] 266 | func getAllIndices(shape []uint) [][]int { 267 | numDimensions := len(shape) 268 | numElements := countElementsFromShape(shape) 269 | 270 | tensorIndices := make([][]int, numElements) 271 | 272 | if numDimensions == 0 { 273 | tensorIndices[0] = []int{0} 274 | return tensorIndices 275 | } 276 | 277 | indices := make([]int, numDimensions) 278 | 279 | // initialize first set of indices to 0 (eg. [0, 0, ..., 0]) it will be added to tensorIndices in the loop 280 | for i := range indices { 281 | indices[i] = 0 282 | } 283 | 284 | // populate the tensor indices in a pattern like 285 | // [0, 0, 0], [0, 0, 1], [0, 0, 2], ..., [0, 1, 0], [0, 1, 1], [0, 1, 2], ..., [1, 0, 0], [1, 0, 1], [1, 0, 2], ... 
286 | for i := 0; i < len(tensorIndices); i++ { 287 | tensorIndices[i] = make([]int, numDimensions) 288 | copy(tensorIndices[i], indices) 289 | 290 | // increment the indices 291 | for dim := numDimensions - 1; dim >= 0; dim-- { 292 | indices[dim]++ 293 | 294 | // if the index at current dimension is still less than its size, we will keep on 295 | // incrementing that dimension's index only before moving to the next one 296 | if indices[dim] < int(shape[dim]) { 297 | break 298 | } 299 | 300 | // reset the current dimension's index to 0 before moving to the next dimension 301 | indices[dim] = 0 302 | } 303 | } 304 | 305 | return tensorIndices 306 | } 307 | 308 | func countElementsFromShape(shape []uint) uint { 309 | count := uint(1) 310 | for _, dimSize := range shape { 311 | count *= dimSize 312 | } 313 | 314 | return count 315 | } 316 | 317 | func areShapesBroadcastable(shapes ...[]uint) bool { 318 | if len(shapes) == 0 { 319 | panic("areShapesBroadcastable: At least one shape is required") 320 | } 321 | 322 | maxDimensions := 0 323 | for _, shape := range shapes { 324 | if len(shape) > maxDimensions { 325 | maxDimensions = len(shape) 326 | } 327 | } 328 | 329 | firstShape := make([]uint, maxDimensions) 330 | copyWithPadding(firstShape, shapes[0], 1) 331 | 332 | currentShape := make([]uint, maxDimensions) 333 | for i := 1; i < len(shapes); i++ { 334 | copyWithPadding(currentShape, shapes[i], 1) 335 | 336 | for j := 0; j < maxDimensions; j++ { 337 | // general broadcasting rules: https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules 338 | // two dimensions are compatible if 1. they are equal, or 2. 
one of them is 1 339 | f := firstShape[j] 340 | c := currentShape[j] 341 | if f != c && f != 1 && c != 1 { 342 | return false 343 | } 344 | } 345 | } 346 | 347 | return true 348 | } 349 | 350 | func copyWithPadding[T Scalar](dest []T, src []T, padWith T) { 351 | if len(dest) < len(src) { 352 | panic("Length of destination array cannot be lesser than that of source array.") 353 | } 354 | 355 | // add padding 356 | for i := 0; i < len(dest)-len(src); i++ { 357 | dest[i] = padWith 358 | } 359 | 360 | // copy the remaining values 361 | copy(dest[len(dest)-len(src):], src) 362 | } 363 | 364 | func randomBetween[T Scalar](minValue, maxValue T) T { 365 | switch any(minValue).(type) { 366 | case int, int8, int16, int32, int64: 367 | return T(rand.Int64N(int64(maxValue-minValue)) + int64(minValue)) 368 | case uint, uint8, uint16, uint32, uint64, uintptr: 369 | return T(rand.Uint64N(uint64(maxValue-minValue)) + uint64(minValue)) 370 | case float32: 371 | return T(rand.Float32()*float32(maxValue-minValue) + float32(minValue)) 372 | case float64: 373 | return T(rand.Float64()*float64(maxValue-minValue) + float64(minValue)) 374 | default: 375 | panic("Unsupported type for random number generation") 376 | } 377 | } 378 | --------------------------------------------------------------------------------