├── pkg.go ├── .travis.yml ├── README.md ├── LICENSE ├── container ├── lists.go └── lists_test.go ├── graph ├── graph.go └── adjacency.go └── matrix ├── ndarray_test.go ├── matrix.go ├── ndarray.go ├── matrix_test.go ├── dense.go ├── sparse_diag.go ├── sparse_coo.go ├── base.go ├── sparse_diag_test.go └── sparse_coo_test.go /pkg.go: -------------------------------------------------------------------------------- 1 | package numgo 2 | 3 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | sudo: false 3 | notifications: 4 | email: 5 | recipients: jesse@ccs.neu.edu 6 | on_success: change 7 | on_failure: always 8 | language: go 9 | go: 10 | - tip 11 | before_script: 12 | - go get github.com/smartystreets/goconvey/convey 13 | before_install: 14 | - go get github.com/axw/gocov/gocov 15 | - go get github.com/mattn/goveralls 16 | - if ! go get code.google.com/p/go.tools/cmd/cover; then go get golang.org/x/tools/cmd/cover; fi 17 | script: 18 | - $HOME/gopath/bin/goveralls -service=travis-ci 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | numgo [![Build Status](https://travis-ci.org/jesand/numgo.svg?branch=master)](https://travis-ci.org/jesand/numgo) [![Coverage Status](https://img.shields.io/coveralls/jesand/numgo.svg)](https://coveralls.io/r/jesand/numgo?branch=master) 2 | =============== 3 | 4 | Linear algebraic tools for Go, based on the excellent NumPy library for Python. 5 | 6 | Project Status 7 | --------------- 8 | 9 | This project is in very early stages of development, and a strong argument could 10 | be made for merging its unique features into another library, such as 11 | https://github.com/gonum/matrix. 
However, I would like to add more of the 12 | features from NumPy and SciPy to build this into a more robust scientific 13 | computing framework. Any comments are welcome. 14 | 15 | Documentation 16 | --------------- 17 | 18 | You can find API documentation at 19 | [godoc](https://godoc.org/github.com/jesand/numgo). 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Jesse Anderton 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /container/lists.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | // List is a container of generic values. 4 | type List interface { 5 | 6 | // Add one or more items to the container 7 | Push(value ...interface{}) 8 | 9 | // Remove an item from the container, or return nil 10 | Pop() interface{} 11 | 12 | // Ask how many items are in the container 13 | Size() int 14 | } 15 | 16 | // Stack is a LIFO List of arbitrary values. The nil Stack is empty and ready to use. 17 | type Stack []interface{} 18 | 19 | // Push appends one or more items to the top of the stack, in argument order. 20 | func (s *Stack) Push(value ...interface{}) { 21 | *s = append(*s, value...) 22 | } 23 | 24 | // Pop removes and returns the top item (LIFO order), or nil if the stack is empty. The vacated slot is cleared so the backing array does not keep the popped value alive. 25 | func (s *Stack) Pop() interface{} { 26 | if len(*s) == 0 { 27 | return nil 28 | } 29 | v := (*s)[len(*s)-1] 30 | (*s)[len(*s)-1] = nil // release the reference held by the backing array 31 | *s = (*s)[:len(*s)-1] 32 | return v 33 | } 34 | 35 | // Size returns the number of items in the stack. 36 | func (s Stack) Size() int { 37 | return len(s) 38 | } 39 | 40 | // Queue is a FIFO List of arbitrary values. The nil Queue is empty and ready to use. 41 | type Queue []interface{} 42 | 43 | // Push appends one or more items to the back of the queue, in argument order. 44 | func (q *Queue) Push(value ...interface{}) { 45 | *q = append(*q, value...) 46 | } 47 | 48 | // Pop removes and returns the front item (FIFO order), or nil if the queue is empty. The vacated slot is cleared so the backing array does not keep the popped value alive. 49 | func (q *Queue) Pop() interface{} { 50 | if len(*q) == 0 { 51 | return nil 52 | } 53 | v := (*q)[0] 54 | (*q)[0] = nil // release the reference held by the backing array 55 | *q = (*q)[1:] 56 | return v 57 | } 58 | 59 | // Size returns the number of items in the queue. 60 | func (q Queue) Size() int { 61 | return len(q) 62 | } 63 | -------------------------------------------------------------------------------- /container/lists_test.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | import ( 4 | "testing" 5 | 6 | .
"github.com/smartystreets/goconvey/convey" 7 | ) 8 | 9 | func TestStack(t *testing.T) { 10 | var ( 11 | s Stack 12 | first int 13 | second int 14 | third int 15 | fourth interface{} 16 | ) 17 | Convey("Given an uninitialized stack", t, func() { 18 | s = nil 19 | 20 | Convey("Then the stack is empty", func() { 21 | So(len(s), ShouldEqual, 0) 22 | }) 23 | 24 | Convey("When I push 1, 2, and 3 individually", func() { 25 | s.Push(1) 26 | s.Push(2) 27 | s.Push(3) 28 | 29 | Convey("Then the stack length is 3", func() { 30 | So(len(s), ShouldEqual, 3) 31 | So(s.Size(), ShouldEqual, 3) 32 | }) 33 | 34 | Convey("When I pop three numbers", func() { 35 | first, _ = s.Pop().(int) 36 | second, _ = s.Pop().(int) 37 | third, _ = s.Pop().(int) 38 | 39 | Convey("Then I get 3, 2, and 1", func() { 40 | So(first, ShouldEqual, 3) 41 | So(second, ShouldEqual, 2) 42 | So(third, ShouldEqual, 1) 43 | }) 44 | 45 | Convey("Then the stack is empty", func() { 46 | So(len(s), ShouldEqual, 0) 47 | }) 48 | 49 | Convey("When I pop again", func() { 50 | fourth, _ = s.Pop().(int) 51 | 52 | Convey("Then I get a zero value", func() { 53 | So(fourth, ShouldBeZeroValue) 54 | }) 55 | 56 | Convey("Then the stack is empty", func() { 57 | So(len(s), ShouldEqual, 0) 58 | }) 59 | }) 60 | }) 61 | }) 62 | 63 | Convey("When I push 1, 2, and 3 all at once", func() { 64 | s.Push(1, 2, 3) 65 | 66 | Convey("Then the stack length is 3", func() { 67 | So(len(s), ShouldEqual, 3) 68 | }) 69 | 70 | Convey("When I pop three numbers", func() { 71 | first, _ = s.Pop().(int) 72 | second, _ = s.Pop().(int) 73 | third, _ = s.Pop().(int) 74 | 75 | Convey("Then I get 3, 2, and 1", func() { 76 | So(first, ShouldEqual, 3) 77 | So(second, ShouldEqual, 2) 78 | So(third, ShouldEqual, 1) 79 | }) 80 | 81 | Convey("Then the stack is empty", func() { 82 | So(len(s), ShouldEqual, 0) 83 | }) 84 | 85 | Convey("When I pop again", func() { 86 | fourth, _ = s.Pop().(int) 87 | 88 | Convey("Then I get a zero value", func() { 89 | So(fourth, 
ShouldBeZeroValue) 90 | }) 91 | 92 | Convey("Then the stack is empty", func() { 93 | So(len(s), ShouldEqual, 0) 94 | }) 95 | }) 96 | }) 97 | }) 98 | }) 99 | } 100 | 101 | func TestQueue(t *testing.T) { 102 | var ( 103 | q Queue 104 | first int 105 | second int 106 | third int 107 | fourth interface{} 108 | ) 109 | Convey("Given an uninitialized queue", t, func() { 110 | q = nil 111 | 112 | Convey("Then the queue is empty", func() { 113 | So(len(q), ShouldEqual, 0) 114 | }) 115 | 116 | Convey("When I push 1, 2, and 3 individually", func() { 117 | q.Push(1) 118 | q.Push(2) 119 | q.Push(3) 120 | 121 | Convey("Then the queue length is 3", func() { 122 | So(len(q), ShouldEqual, 3) 123 | So(q.Size(), ShouldEqual, 3) 124 | }) 125 | 126 | Convey("When I pop three numbers", func() { 127 | first, _ = q.Pop().(int) 128 | second, _ = q.Pop().(int) 129 | third, _ = q.Pop().(int) 130 | 131 | Convey("Then I get 1, 2, and 3", func() { 132 | So(first, ShouldEqual, 1) 133 | So(second, ShouldEqual, 2) 134 | So(third, ShouldEqual, 3) 135 | }) 136 | 137 | Convey("Then the queue is empty", func() { 138 | So(len(q), ShouldEqual, 0) 139 | }) 140 | 141 | Convey("When I pop again", func() { 142 | fourth, _ = q.Pop().(int) 143 | 144 | Convey("Then I get a zero value", func() { 145 | So(fourth, ShouldBeZeroValue) 146 | }) 147 | 148 | Convey("Then the queue is empty", func() { 149 | So(len(q), ShouldEqual, 0) 150 | }) 151 | }) 152 | }) 153 | }) 154 | 155 | Convey("When I push 1, 2, and 3 all at once", func() { 156 | q.Push(1, 2, 3) 157 | 158 | Convey("Then the queue length is 3", func() { 159 | So(len(q), ShouldEqual, 3) 160 | }) 161 | 162 | Convey("When I pop three numbers", func() { 163 | first, _ = q.Pop().(int) 164 | second, _ = q.Pop().(int) 165 | third, _ = q.Pop().(int) 166 | 167 | Convey("Then I get 1, 2, and 3", func() { 168 | So(first, ShouldEqual, 1) 169 | So(second, ShouldEqual, 2) 170 | So(third, ShouldEqual, 3) 171 | }) 172 | 173 | Convey("Then the queue is empty", func() { 174 | 
So(len(q), ShouldEqual, 0) 175 | }) 176 | 177 | Convey("When I pop again", func() { 178 | fourth, _ = q.Pop().(int) 179 | 180 | Convey("Then I get a zero value", func() { 181 | So(fourth, ShouldBeZeroValue) 182 | }) 183 | 184 | Convey("Then the queue is empty", func() { 185 | So(len(q), ShouldEqual, 0) 186 | }) 187 | }) 188 | }) 189 | }) 190 | }) 191 | } 192 | -------------------------------------------------------------------------------- /graph/graph.go: -------------------------------------------------------------------------------- 1 | package graph 2 | 3 | // Identifies a vertex in a graph 4 | type NodeID int 5 | 6 | // The data stored for a node 7 | type Node struct { 8 | ID NodeID 9 | Name string 10 | InDegree int 11 | OutDegree int 12 | } 13 | 14 | // A function to invoke when visiting nodes in the graph. Gives the node, the 15 | // path along which the node was found, and the edge weights along that path. 16 | // Should return true if the search should stop. 17 | type NodeVisitor func(node Node, path []Node, weights []float64) bool 18 | 19 | // The error raised when too many nodes are added 20 | type ErrGraphCapacity struct{} 21 | 22 | func (err ErrGraphCapacity) Error() string { 23 | return "The graph cannot store any additional nodes" 24 | } 25 | 26 | // The error raised when an invalid NodeID is passed to a function 27 | type ErrInvalidNode struct{} 28 | 29 | func (err ErrInvalidNode) Error() string { 30 | return "Invalid graph node ID" 31 | } 32 | 33 | // The error returned when a topological sort fails 34 | type ErrGraphIsCyclic struct{} 35 | 36 | func (err ErrGraphIsCyclic) Error() string { 37 | return "Graph contains cycles" 38 | } 39 | 40 | // A graph, with various graph manipulation routines. Note that any method 41 | // which accepts a NodeID will panic with ErrInvalidNode unless the ID has 42 | // previously been returned by a call to AddNode(). 43 | type Graph interface { 44 | 45 | // Add an edge to the graph. 
If the graph is not directed, the edge is 46 | // bidirectional. If the graph is weighted, the edge is assigned weight 1. 47 | AddEdge(from, to NodeID) 48 | 49 | // Add an edge with a weight to the graph. 50 | AddEdgeWithWeight(from, to NodeID, weight float64) 51 | 52 | // Add a vertex to the graph, with optional name. If the graph's internal 53 | // storage runs out of capacity, this will panic with ErrGraphCapacity. 54 | AddNode(name string) NodeID 55 | 56 | // Returns all children of a given node 57 | Children(of NodeID) []Node 58 | 59 | // Returns all children of a given node and their corresponding edge weights 60 | ChildrenWithWeights(of NodeID) ([]Node, []float64) 61 | 62 | // Make a copy of the graph 63 | Copy() Graph 64 | 65 | // Ask whether a given edge exists in the graph 66 | HasEdge(from, to NodeID) bool 67 | 68 | // Ask whether the graph contains any edges 69 | HasEdges() bool 70 | 71 | // Ask whether the graph contains cycles 72 | HasCycles() bool 73 | 74 | // Ask whether the graph contains any nodes 75 | HasNodes() bool 76 | 77 | // Ask whether a path exists between two nodes 78 | HasPath(from, to NodeID) bool 79 | 80 | // Ask whether the graph is a directed acyclic graph 81 | IsDag() bool 82 | 83 | // Ask whether the graph is directed 84 | IsDirected() bool 85 | 86 | // Ask whether the graph is a tree 87 | IsTree() bool 88 | 89 | // Returns all nodes with out-degree zero 90 | Leaves() []Node 91 | 92 | // Get the node with the given ID 93 | Node(id NodeID) Node 94 | 95 | // Returns all parents of a given node 96 | Parents(of NodeID) []Node 97 | 98 | // Returns all parents of a given node and their corresponding edge weights 99 | ParentsWithWeights(of NodeID) ([]Node, []float64) 100 | 101 | // Remove an edge from the graph 102 | RemoveEdge(from, to NodeID) 103 | 104 | // Returns all nodes with in-degree zero 105 | Roots() []Node 106 | 107 | // Returns the shortest path between two nodes and the edge weights along 108 | // the path 109 | 
ShortestPath(from, to NodeID) ([]Node, []float64) 110 | 111 | // Returns the total weight along the shortest path between two nodes, or 112 | // 0 if there is no such path. 113 | ShortestPathWeight(from, to NodeID) float64 114 | 115 | // Returns the weights along all pairwise shortest paths in the graph. 116 | // In the returned array, weights[from][to] gives the minimum path weight 117 | // from the node with ID=from to the node with ID=to. The weight will be 118 | // positive infinity if there is no such path. 119 | ShortestPathWeights() (weights [][]float64) 120 | 121 | // Returns the number of nodes and edges in the graph 122 | Size() (nodes, edges int) 123 | 124 | // Returns a string representation of the graph 125 | String() string 126 | 127 | // Returns a topological sort of the graph, if possible. All nodes will 128 | // follow their ancestors in the resulting list. If there is no path between 129 | // a given pair of nodes, their ordering is chosen arbitrarily. 130 | // If the graph is not acyclic, fails with ErrGraphIsCyclic. 131 | TopologicalSort() ([]Node, error) 132 | 133 | // Create the transitive closure of the graph. Adds edges so that all nodes 134 | // reachable from a given node have an edge between them. Any added edges 135 | // will be assigned a weight equal to the shortest path weight between the 136 | // nodes in the original graph. 137 | TransitiveClosure() Graph 138 | 139 | // Create the transitive reduction of the graph. Keeps only edges necessary 140 | // to preserve all paths in the graph. The behavior of this algorithm is 141 | // undefined unless the graph is a DAG. 142 | TransitiveReduction() Graph 143 | 144 | // Visit all descendants using breadth-first search. Returns the result of 145 | // the final call to fn. 146 | VisitBFS(from NodeID, fn NodeVisitor) bool 147 | 148 | // Visit all descendants using depth-first search. Returns the result of 149 | // the final call to fn. 
150 | VisitDFS(from NodeID, fn NodeVisitor) bool 151 | } 152 | -------------------------------------------------------------------------------- /matrix/ndarray_test.go: -------------------------------------------------------------------------------- 1 | package matrix 2 | 3 | import ( 4 | . "github.com/smartystreets/goconvey/convey" 5 | "testing" 6 | ) 7 | 8 | func TestA(t *testing.T) { 9 | Convey("A() panics when data doesn't match dimensions", t, func() { 10 | So(func() { A([]int{5}) }, ShouldPanic) 11 | So(func() { A([]int{5}, 1, 2, 3, 4) }, ShouldPanic) 12 | So(func() { A([]int{5}, 1, 2, 3, 4, 5, 6) }, ShouldPanic) 13 | So(func() { A([]int{2, 3}, 1, 2, 3, 4, 5) }, ShouldPanic) 14 | So(func() { A([]int{2, 3}, 1, 2, 3, 4, 5, 6, 7) }, ShouldPanic) 15 | }) 16 | 17 | Convey("Given a 1D array created with A", t, func() { 18 | m := A([]int{5}, 1, 2, 3, 4, 5) 19 | Convey("Shape() is 5", func() { 20 | So(m.Shape(), ShouldResemble, []int{5}) 21 | }) 22 | Convey("Size() is 5", func() { 23 | So(m.Size(), ShouldResemble, 5) 24 | }) 25 | Convey("The data is correct", func() { 26 | So(m.Array(), ShouldResemble, []float64{1, 2, 3, 4, 5}) 27 | }) 28 | }) 29 | 30 | Convey("Given a 2D array created with A", t, func() { 31 | m := A([]int{2, 3}, 1, 2, 3, 4, 5, 6) 32 | Convey("Shape() is 2, 3", func() { 33 | So(m.Shape(), ShouldResemble, []int{2, 3}) 34 | }) 35 | Convey("Size() is 6", func() { 36 | So(m.Size(), ShouldResemble, 6) 37 | }) 38 | Convey("The data is correct", func() { 39 | So(m.Array(), ShouldResemble, []float64{ 40 | 1, 2, 3, 41 | 4, 5, 6, 42 | }) 43 | }) 44 | }) 45 | } 46 | 47 | func TestA1(t *testing.T) { 48 | Convey("Given a 1D array created with A1", t, func() { 49 | m := A1(1, 2, 3, 4, 5) 50 | Convey("Shape() is 5", func() { 51 | So(m.Shape(), ShouldResemble, []int{5}) 52 | }) 53 | Convey("Size() is 5", func() { 54 | So(m.Size(), ShouldResemble, 5) 55 | }) 56 | Convey("The data is correct", func() { 57 | So(m.Array(), ShouldResemble, []float64{1, 2, 3, 4, 5}) 58 | 
}) 59 | }) 60 | } 61 | 62 | func TestA2(t *testing.T) { 63 | Convey("A2 panics given arrays of differing lengths", t, func() { 64 | So(func() { A2([]float64{1}, []float64{1, 2}) }, ShouldPanic) 65 | }) 66 | 67 | Convey("Given a 2D array created with A2", t, func() { 68 | m := A2([]float64{1, 2, 3}, []float64{4, 5, 6}) 69 | Convey("Shape() is 2, 3", func() { 70 | So(m.Shape(), ShouldResemble, []int{2, 3}) 71 | }) 72 | Convey("Size() is 6", func() { 73 | So(m.Size(), ShouldResemble, 6) 74 | }) 75 | Convey("The data is correct", func() { 76 | So(m.Array(), ShouldResemble, []float64{ 77 | 1, 2, 3, 78 | 4, 5, 6, 79 | }) 80 | }) 81 | }) 82 | } 83 | 84 | func TestDense(t *testing.T) { 85 | Convey("Given a dense array with shape 5, 3", t, func() { 86 | array := Dense(5, 3) 87 | 88 | Convey("Shape() is (5, 3)", func() { 89 | So(array.Shape(), ShouldResemble, []int{5, 3}) 90 | }) 91 | 92 | Convey("Size() is 15", func() { 93 | So(array.Size(), ShouldEqual, 15) 94 | }) 95 | 96 | Convey("All values are zero", func() { 97 | for i := 0; i < array.Size(); i++ { 98 | So(array.FlatItem(i), ShouldEqual, 0) 99 | } 100 | }) 101 | }) 102 | } 103 | 104 | func TestOnes(t *testing.T) { 105 | Convey("Given a ones array with shape 5, 3", t, func() { 106 | array := Ones(5, 3) 107 | 108 | Convey("Shape() is (5, 3)", func() { 109 | So(array.Shape(), ShouldResemble, []int{5, 3}) 110 | }) 111 | 112 | Convey("Size() is 15", func() { 113 | So(array.Size(), ShouldEqual, 15) 114 | }) 115 | 116 | Convey("All values are one", func() { 117 | for i := 0; i < array.Size(); i++ { 118 | So(array.FlatItem(i), ShouldEqual, 1) 119 | } 120 | }) 121 | }) 122 | } 123 | 124 | func TestRand(t *testing.T) { 125 | Convey("Given a random array with shape 5, 3", t, func() { 126 | array := Rand(5, 3) 127 | 128 | Convey("Shape() is (5, 3)", func() { 129 | So(array.Shape(), ShouldResemble, []int{5, 3}) 130 | }) 131 | 132 | Convey("Size() is 15", func() { 133 | So(array.Size(), ShouldEqual, 15) 134 | }) 135 | 136 | 
Convey("All values are in (0, 1)", func() { 137 | for i := 0; i < array.Size(); i++ { 138 | So(array.FlatItem(i), ShouldBeBetween, 0, 1) 139 | } 140 | }) 141 | }) 142 | } 143 | 144 | func TestRandN(t *testing.T) { 145 | Convey("Given a random normal array with shape 5, 3", t, func() { 146 | array := RandN(5, 3) 147 | 148 | Convey("Shape() is (5, 3)", func() { 149 | So(array.Shape(), ShouldResemble, []int{5, 3}) 150 | }) 151 | 152 | Convey("Size() is 15", func() { 153 | So(array.Size(), ShouldEqual, 15) 154 | }) 155 | 156 | Convey("All is true", func() { 157 | // There's some small chance of this being false; a one-off failure is ok 158 | So(array.All(), ShouldBeTrue) 159 | }) 160 | }) 161 | } 162 | 163 | func TestWithValue(t *testing.T) { 164 | Convey("Given a WithValue array with shape 5, 3", t, func() { 165 | array := WithValue(3.5, 5, 3) 166 | 167 | Convey("Shape() is (5, 3)", func() { 168 | So(array.Shape(), ShouldResemble, []int{5, 3}) 169 | }) 170 | 171 | Convey("Size() is 15", func() { 172 | So(array.Size(), ShouldEqual, 15) 173 | }) 174 | 175 | Convey("All values are 3.5", func() { 176 | for i := 0; i < array.Size(); i++ { 177 | So(array.FlatItem(i), ShouldEqual, 3.5) 178 | } 179 | }) 180 | }) 181 | } 182 | 183 | func TestZeros(t *testing.T) { 184 | Convey("Given a zeros array with shape 5, 3", t, func() { 185 | array := Zeros(5, 3) 186 | 187 | Convey("Shape() is (5, 3)", func() { 188 | So(array.Shape(), ShouldResemble, []int{5, 3}) 189 | }) 190 | 191 | Convey("Size() is 15", func() { 192 | So(array.Size(), ShouldEqual, 15) 193 | }) 194 | 195 | Convey("All values are zero", func() { 196 | for i := 0; i < array.Size(); i++ { 197 | So(array.FlatItem(i), ShouldEqual, 0) 198 | } 199 | }) 200 | }) 201 | } 202 | -------------------------------------------------------------------------------- /matrix/matrix.go: -------------------------------------------------------------------------------- 1 | package matrix 2 | 3 | import ( 4 | "fmt" 5 | 
"github.com/gonum/matrix/mat64" 6 | "math" 7 | "math/rand" 8 | ) 9 | 10 | // DistType selects a distance calculation supported by Matrix.Dist. 11 | type DistType int 12 | 13 | const ( 14 | EuclideanDist DistType = iota 15 | ) 16 | 17 | // Matrix is a two dimensional NDArray with some special functionality. 18 | type Matrix interface { 19 | NDArray 20 | 21 | // Set the values of the items on a given column 22 | ColSet(col int, values []float64) 23 | 24 | // Get a particular column for read-only access. May or may not be a copy. 25 | Col(col int) []float64 26 | 27 | // Get the number of columns 28 | Cols() int 29 | 30 | // Get a column vector containing the main diagonal elements of the matrix 31 | Diag() Matrix 32 | 33 | // Treat the rows as points, and get the pairwise distance between them. 34 | // Returns a distance matrix D such that D_i,j is the distance between 35 | // rows i and j. 36 | Dist(t DistType) Matrix 37 | 38 | // Get the matrix inverse 39 | Inverse() (Matrix, error) 40 | 41 | // Solve for x, where ax = b and a is the receiver. 42 | LDivide(b Matrix) Matrix 43 | 44 | // Get the result of matrix multiplication between this and some other 45 | // matrices. Matrix dimensions must be aligned correctly for multiplication. 46 | // If A is m x p and B is p x n, then C = A.MProd(B) is the m x n matrix 47 | // with C[i, j] = \sum_{k=1}^p A[i,k] * B[k,j]. 48 | MProd(others ...Matrix) Matrix 49 | 50 | // Get the matrix norm of the specified ordinality (1, 2, infinity, ...) 51 | Norm(ord float64) float64 52 | 53 | // Set the values of the items on a given row 54 | RowSet(row int, values []float64) 55 | 56 | // Get a particular row for read-only access. May or may not be a copy. 57 | Row(row int) []float64 58 | 59 | // Get the number of rows 60 | Rows() int 61 | 62 | // Return the same matrix, but with axes transposed. The same data is used, 63 | // for speed and memory efficiency. Use Copy() to create a new array. 64 | T() Matrix 65 | 66 | // Return a sparse coo copy of the matrix. Coo storage accepts values at
// any position (see SparseCoo), so unlike SparseDiag() this should not need to panic; NOTE(review): confirm in sparse_coo.go. 68 | SparseCoo() Matrix 69 | 70 | // Return a sparse diag copy of the matrix. The method will panic 71 | // if any off-diagonal elements are nonzero. 72 | SparseDiag() Matrix 73 | } 74 | 75 | // Diag creates a square matrix with the specified elements on the main diagonal, and 76 | // zero elsewhere, stored in sparse diagonal form. 77 | func Diag(diag ...float64) Matrix { 78 | array := SparseDiag(len(diag), len(diag)) 79 | for i, v := range diag { 80 | array.ItemSet(v, i, i) 81 | } 82 | return array 83 | } 84 | 85 | // Eye creates a square sparse identity matrix of the specified dimensionality. 86 | func Eye(size int) Matrix { 87 | diag := make([]float64, size) 88 | for i := 0; i < size; i++ { 89 | diag[i] = 1.0 90 | } 91 | return Diag(diag...) 92 | } 93 | 94 | // M creates a rows x cols dense matrix from literal data given row by row; like A, it panics if len(array) != rows*cols. 95 | func M(rows, cols int, array ...float64) Matrix { 96 | return A([]int{rows, cols}, array...).M() 97 | } 98 | 99 | // M2 creates a matrix from literal data, one []float64 per row; the shape is inferred from the rows. 100 | func M2(array ...[]float64) Matrix { 101 | return A2(array...).M() 102 | } 103 | 104 | // SparseCoo creates a sparse rows x cols matrix. Nonzero entries are stored 105 | // by coordinate: one map per row, from column index to value. 106 | // The first len(array) elements of the matrix (in flat order) will be initialized to the 107 | // corresponding nonzero values of array. 108 | func SparseCoo(rows, cols int, array ...float64) Matrix { 109 | m := &sparseCooF64Matrix{ 110 | shape: []int{rows, cols}, 111 | values: make([]map[int]float64, rows), 112 | } 113 | for i := 0; i < rows; i++ { 114 | m.values[i] = make(map[int]float64) 115 | } 116 | for idx, val := range array { 117 | if val != 0 { 118 | m.ItemSet(val, flatToNd(m.shape, idx)...) 119 | } 120 | } 121 | return m 122 | } 123 | 124 | // SparseDiag creates a sparse matrix of the specified dimensionality.
This matrix will be 125 | // stored in diagonal format: the main diagonal is stored as a []float64, and 126 | // all off-diagonal values are zero. The matrix is initialized from diag, or 127 | // to all zeros. Panics if len(diag) exceeds min(rows, cols). 128 | func SparseDiag(rows, cols int, diag ...float64) Matrix { 129 | if len(diag) > rows || len(diag) > cols { 130 | panic(fmt.Sprintf("Can't use %d diag elements in a %dx%d matrix", len(diag), rows, cols)) 131 | } 132 | size := rows 133 | if cols < rows { 134 | size = cols 135 | } 136 | array := &sparseDiagF64Matrix{ 137 | shape: []int{rows, cols}, 138 | diag: make([]float64, size), 139 | } 140 | for pos, v := range diag { 141 | array.diag[pos] = v 142 | } 143 | return array 144 | } 145 | 146 | // SparseRand creates a sparse coo matrix, randomly populated so that approximately 147 | // density * rows * cols cells are filled with random values uniformly 148 | // distributed in [0,1). Panics unless 0 <= density < 1. Empty cells are found by 149 | // redrawing random positions, so a density close to 1 may be extremely slow. 150 | func SparseRand(rows, cols int, density float64) Matrix { 151 | if density < 0 || density >= 1 { 152 | panic(fmt.Sprintf("Can't create a SparseRand matrix: density %f should be in [0, 1)", density)) 153 | } 154 | matrix := SparseCoo(rows, cols) 155 | shape := []int{rows, cols} 156 | size := rows * cols 157 | count := int(float64(size) * density) 158 | for i := 0; i < count; i++ { 159 | for { 160 | coord := flatToNd(shape, rand.Intn(size)) 161 | if matrix.Item(coord...) == 0 { 162 | matrix.ItemSet(rand.Float64(), coord...) 163 | break 164 | } 165 | } 166 | } 167 | return matrix 168 | } 169 | 170 | // SparseRandN creates a sparse coo matrix, randomly populated so that 171 | // approximately density * rows * cols cells are filled with values drawn 172 | // from the standard normal distribution (rand.NormFloat64). Panics unless 173 | // 0 <= density < 1. Note that if density is close to 1, this function may 174 | // be extremely slow.
175 | func SparseRandN(rows, cols int, density float64) Matrix { 176 | if density < 0 || density >= 1 { 177 | panic(fmt.Sprintf("Can't create a SparseRandN matrix: density %f should be in [0, 1)", density)) 178 | } 179 | matrix := SparseCoo(rows, cols) 180 | shape := []int{rows, cols} 181 | size := rows * cols 182 | count := int(float64(size) * density) 183 | for i := 0; i < count; i++ { 184 | for { 185 | coord := flatToNd(shape, rand.Intn(size)) 186 | if matrix.Item(coord...) == 0 { 187 | matrix.ItemSet(rand.NormFloat64(), coord...) 188 | break 189 | } 190 | } 191 | } 192 | return matrix 193 | } 194 | 195 | // ToMat64 converts our matrix type to mat64's dense matrix type. It is built from m.Array(), so for dense inputs it may share storage with m. 196 | func ToMat64(m Matrix) *mat64.Dense { 197 | return mat64.NewDense(m.Rows(), m.Cols(), m.Array()) 198 | } 199 | 200 | // ToMatrix copies any mat64.Matrix into a new dense Matrix of the same shape. 201 | func ToMatrix(m mat64.Matrix) Matrix { 202 | rows, cols := m.Dims() 203 | array := &denseF64Array{ 204 | shape: []int{rows, cols}, 205 | array: make([]float64, rows*cols), 206 | } 207 | for i0 := 0; i0 < rows; i0++ { 208 | for i1 := 0; i1 < cols; i1++ { 209 | array.ItemSet(m.At(i0, i1), i0, i1) 210 | } 211 | } 212 | return array 213 | } 214 | 215 | // Inverse returns the matrix inverse of a, or the error reported by mat64.Inverse. 216 | func Inverse(a Matrix) (Matrix, error) { 217 | inv, err := mat64.Inverse(ToMat64(a)) 218 | if err != nil { 219 | return nil, err 220 | } 221 | return ToMatrix(inv), nil 222 | } 223 | 224 | // LDivide solves ax = b for x. If a is r x c and b is r x k, x is c x k; when mat64 cannot solve the system, a c x k matrix filled with NaN is returned instead. 225 | func LDivide(a, b Matrix) Matrix { 226 | var x mat64.Dense 227 | err := x.Solve(ToMat64(a), ToMat64(b)) 228 | if err != nil { 229 | return WithValue(math.NaN(), a.Cols(), b.Cols()).M() 230 | } 231 | return ToMatrix(&x) 232 | } 233 | 234 | // Norm returns the matrix norm of m of the specified ordinality (1, 2, infinity, ...), as computed by mat64.
235 | func Norm(m Matrix, ord float64) float64 { 236 | return ToMat64(m).Norm(ord) 237 | } 238 | 239 | // Solve is an alias for LDivide: it solves ax = b for x. 240 | func Solve(a, b Matrix) Matrix { 241 | return LDivide(a, b) 242 | } 243 | -------------------------------------------------------------------------------- /matrix/ndarray.go: -------------------------------------------------------------------------------- 1 | // Package matrix contains various utilities for dealing with raw matrices. 2 | // The interface is loosely based on the NumPy package in Python. At present, 3 | // all arrays store float64 values. 4 | // 5 | // NDArray 6 | // 7 | // The NDArray interface describes a multidimensional array. Both dense and 8 | // (2D-only) sparse implementations are available, with handy constructors for 9 | // various array types. In general, the methods in NDArray are those methods 10 | // which would make sense for an array of any dimensionality. 11 | // 12 | // The following constructors all create dense arrays. For sparse 13 | // representations, see the Matrix constructors below.
14 | // 15 | // To create a one-dimensional, initialized array: 16 | // a0 := A1(1.0, 2.0, 3.0) 17 | // 18 | // To create a 2x3 array containing all zeros, use one of: 19 | // a1 := Dense(2, 3) 20 | // a2 := Zeros(2, 3) 21 | // 22 | // To create a 2x3 array containing all ones, use: 23 | // a3 := Ones(2, 3) 24 | // 25 | // To create a 2x3 array with initialized values: 26 | // a4 := A([]int{2,3}, 27 | // 1.0, 2.0, 3.0, 28 | // 4.0, 5.0, 6.0) 29 | // a5 := A2([]float64{1.0, 2.0, 3.0}, 30 | // []float64{4.0, 5.0, 6.0}) 31 | // 32 | // To create a 2x3 array initialized to some arbitrary value: 33 | // a6 := WithValue(0.1, 2, 3) 34 | // 35 | // To create a 2x3 array with random values uniformly distributed in [0, 1): 36 | // a7 := Rand(2, 3) 37 | // 38 | // To create a 2x3 array with random values on the standard normal distribution: 39 | // a8 := RandN(2, 3) 40 | // 41 | // Matrix 42 | // 43 | // The Matrix interface describes operations suited to a two-dimensional array. 44 | // Note that a two-dimensional NDArray can be trivially converted to the Matrix 45 | // type by calling arr.M(). The resulting object will generally be the same, 46 | // but converted to the Matrix type. 47 | // 48 | // The following representations are available. 49 | // A dense matrix stores all values in a []float64. 50 | // A sparse diagonal matrix stores the elements of the main diagonal in a 51 | // []float64, and assumes off-diagonal elements are zero. 52 | // A sparse coo matrix stores nonzero items in a []map[int]float64 (one map per row, keyed by column index). 53 | // 54 | // When possible, function implementations take advantage of matrix sparsity. 55 | // For instance, MProd(), the matrix multiplication function, performs the 56 | // minimum amount of work required based on the types of its arguments.
57 | // 58 | // To create a 2x3 matrix with initialized values: 59 | // m0 := M(2, 3, 60 | // 1.0, 2.0, 3.0, 61 | // 4.0, 5.0, 6.0) 62 | // 63 | // To create a 4x4 matrix with sparse diagonal representation: 64 | // m1 := Diag(1.0, 2.0, 3.0, 4.0) 65 | // 66 | // To create a 4x6 matrix with sparse diagonal representation: 67 | // m2 := SparseDiag(4, 6, 1.0, 2.0, 3.0, 4.0) 68 | // 69 | // To create a 4x4 identity matrix with sparse diagonal representation: 70 | // m3 := Eye(4) 71 | // 72 | // To create an unpopulated 3x4 sparse coo matrix: 73 | // m4 := SparseCoo(3, 4) 74 | // 75 | // To create a 3x4 sparse coo with about half the items randomly populated: 76 | // m5 := SparseRand(3, 4, 0.5) 77 | // m6 := SparseRandN(3, 4, 0.5) 78 | package matrix 79 | 80 | import ( 81 | "fmt" 82 | "math/rand" 83 | "time" 84 | ) 85 | // init seeds the global math/rand source from the clock so random constructors (e.g. SparseRand) differ between runs. NOTE(review): this reseeds the process-wide source, which other packages share. 86 | func init() { 87 | rand.Seed(time.Now().UnixNano()) 88 | } 89 | 90 | // ArraySparsity indicates the representation type of the matrix 91 | type ArraySparsity int 92 | 93 | const ( 94 | DenseArray ArraySparsity = iota // all values stored in a []float64 95 | SparseCooMatrix // nonzero values stored by coordinate 96 | SparseDiagMatrix // only the main diagonal is stored 97 | ) 98 | 99 | // NDArray is an n-dimensional array of numbers which can be manipulated in 100 | // various ways. Concrete implementations can differ; for instance, sparse 101 | // and dense representations are possible.
102 | type NDArray interface { 103 | 104 | // Return the element-wise sum of this array and one or more others 105 | Add(others ...NDArray) NDArray 106 | 107 | // Returns true if and only if all items are nonzero 108 | All() bool 109 | 110 | // Returns true if f is true for all array elements 111 | AllF(f func(v float64) bool) bool 112 | 113 | // Returns true if f is true for all pairs of array elements in the same position 114 | AllF2(f func(v1, v2 float64) bool, other NDArray) bool 115 | 116 | // Returns true if and only if any item is nonzero 117 | Any() bool 118 | 119 | // Returns true if f is true for any array element 120 | AnyF(f func(v float64) bool) bool 121 | 122 | // Returns true if f is true for any pair of array elements in the same position 123 | AnyF2(f func(v1, v2 float64) bool, other NDArray) bool 124 | 125 | // Return the result of applying a function to all elements 126 | Apply(f func(float64) float64) NDArray 127 | 128 | // Get the matrix data as a flattened 1D array; sparse matrices will make 129 | // a copy first. 130 | Array() []float64 131 | 132 | // Create a new array by concatenating this with another array along the 133 | // specified axis. The array shapes must be equal along all other axes. 134 | // It is legal to add a new axis. 135 | Concat(axis int, others ...NDArray) NDArray 136 | 137 | // Returns a duplicate of this array 138 | Copy() NDArray 139 | 140 | // Counts the number of nonzero elements in the array 141 | CountNonzero() int 142 | 143 | // Returns a dense copy of the array 144 | Dense() NDArray 145 | 146 | // Return the element-wise quotient of this array and one or more others. 147 | // This function defines 0 / 0 = 0, so it's useful for sparse arrays. 
148 | Div(others ...NDArray) NDArray 149 | 150 | // Returns true if and only if all elements in the two arrays are equal 151 | Equal(other NDArray) bool 152 | 153 | // Set all array elements to the given value 154 | Fill(value float64) 155 | 156 | // Get the coordinates for the item at the specified flat position 157 | FlatCoord(index int) []int 158 | 159 | // Get an array element in a flattened version of this array 160 | FlatItem(index int) float64 161 | 162 | // Set an array element in a flattened version of this array 163 | FlatItemSet(value float64, index int) 164 | 165 | // Get an array element 166 | Item(index ...int) float64 167 | 168 | // Return the result of adding a scalar value to each array element 169 | ItemAdd(value float64) NDArray 170 | 171 | // Return the result of dividing each array element by a scalar value 172 | ItemDiv(value float64) NDArray 173 | 174 | // Return the result of multiplying each array element by a scalar value 175 | ItemProd(value float64) NDArray 176 | 177 | // Return the result of subtracting a scalar value from each array element 178 | ItemSub(value float64) NDArray 179 | 180 | // Set an array element 181 | ItemSet(value float64, index ...int) 182 | 183 | // Returns the array as a matrix. This is only possible for 1D and 2D arrays; 184 | // 1D arrays of length n are converted into n x 1 vectors.
185 | M() Matrix 186 | 187 | // Get the value of the largest array element 188 | Max() float64 189 | 190 | // Get the value of the smallest array element 191 | Min() float64 192 | 193 | // Return the element-wise product of this array and one or more others 194 | Prod(others ...NDArray) NDArray 195 | 196 | // The number of dimensions in the matrix 197 | NDim() int 198 | 199 | // Return a copy of the array, normalized to sum to 1 200 | Normalize() NDArray 201 | 202 | // Get a 1D copy of the array, in 'C' order: rightmost axes change fastest 203 | Ravel() NDArray 204 | 205 | // A slice giving the size of all array dimensions 206 | Shape() []int 207 | 208 | // The total number of elements in the matrix 209 | Size() int 210 | 211 | // Get an array containing a rectangular slice of this array. 212 | // `from` and `to` should both have one index per axis. The indices 213 | // in `from` and `to` define the first and just-past-last indices you wish 214 | // to select along each axis. Negative indexing is supported: when slicing, 215 | // index -1 refers to the item just past the last and -arr.Size() refers to 216 | // the first element. 217 | Slice(from []int, to []int) NDArray 218 | 219 | // Ask whether the matrix has a sparse representation (useful for optimization) 220 | Sparsity() ArraySparsity 221 | 222 | // Return the element-wise difference of this array and one or more others 223 | Sub(others ...NDArray) NDArray 224 | 225 | // Return the sum of all array elements 226 | Sum() float64 227 | 228 | // Visit all matrix elements, invoking a method on each. If the method 229 | // returns false, iteration is aborted and Visit() returns false. 230 | // Otherwise, it returns true. 231 | Visit(f func(pos []int, value float64) bool) bool 232 | 233 | // Visit just nonzero elements, invoking a method on each. If the method 234 | // returns false, iteration is aborted and VisitNonzero() returns false. 235 | // Otherwise, it returns true.
236 | VisitNonzero(f func(pos []int, value float64) bool) bool 237 | } 238 | 239 | // Create an array from literal data; panics if len(values) doesn't match shape. The shape slice is copied. 240 | func A(shape []int, values ...float64) NDArray { 241 | size := 1 242 | for _, sz := range shape { 243 | size *= sz 244 | } 245 | if len(values) != size { 246 | panic(fmt.Sprintf("Expected %d array elements but got %d", size, len(values))) 247 | } 248 | array := &denseF64Array{ 249 | shape: append(shape[:0:0], shape...), 250 | array: make([]float64, len(values)), 251 | } 252 | copy(array.array[:], values[:]) 253 | return array 254 | } 255 | 256 | // Create a 1D array 257 | func A1(values ...float64) NDArray { 258 | return A([]int{len(values)}, values...) 259 | } 260 | 261 | // Create a 2D array from rows of equal length; at least one row is required. 262 | func A2(rows ...[]float64) NDArray { 263 | array := &denseF64Array{ 264 | shape: []int{len(rows), len(rows[0])}, 265 | array: make([]float64, len(rows)*len(rows[0])), 266 | } 267 | for i0 := 0; i0 < array.shape[0]; i0++ { 268 | if len(rows[i0]) != array.shape[1] { 269 | panic(fmt.Sprintf("A2 got inconsistent array lengths %d and %d", array.shape[1], len(rows[i0]))) 270 | } 271 | for i1 := 0; i1 < array.shape[1]; i1++ { 272 | array.ItemSet(rows[i0][i1], i0, i1) 273 | } 274 | } 275 | return array 276 | } 277 | 278 | // Create an NDArray of float64 values, initialized to zero. The size slice is copied. 279 | func Dense(size ...int) NDArray { 280 | totalSize := 1 281 | for _, sz := range size { 282 | totalSize *= sz 283 | } 284 | return &denseF64Array{ 285 | shape: append(size[:0:0], size...), 286 | array: make([]float64, totalSize), 287 | } 288 | } 289 | 290 | // Create an NDArray of float64 values, initialized to value 291 | func WithValue(value float64, size ...int) NDArray { 292 | array := Dense(size...) 293 | array.Fill(value) 294 | return array 295 | } 296 | 297 | // Create an NDArray of float64 values, initialized to zero 298 | func Zeros(size ...int) NDArray { 299 | return Dense(size...)
300 | } 301 | 302 | // Create an NDArray of float64 values, initialized to one 303 | func Ones(size ...int) NDArray { 304 | return WithValue(1.0, size...) 305 | } 306 | 307 | // Create a dense NDArray of float64 values, initialized to uniformly random 308 | // values in [0, 1). 309 | func Rand(size ...int) NDArray { 310 | array := Dense(size...) 311 | 312 | max := array.Size() 313 | for i := 0; i < max; i++ { 314 | array.FlatItemSet(rand.Float64(), i) 315 | } 316 | 317 | return array 318 | } 319 | 320 | // Create a dense NDArray of float64 values, initialized to random values in 321 | // [-math.MaxFloat64, +math.MaxFloat64] distributed on the standard Normal 322 | // distribution. 323 | func RandN(size ...int) NDArray { 324 | array := Dense(size...) 325 | 326 | max := array.Size() 327 | for i := 0; i < max; i++ { 328 | array.FlatItemSet(rand.NormFloat64(), i) 329 | } 330 | 331 | return array 332 | } 333 | -------------------------------------------------------------------------------- /matrix/matrix_test.go: -------------------------------------------------------------------------------- 1 | package matrix 2 | 3 | import ( 4 | . 
"github.com/smartystreets/goconvey/convey" 5 | "math" 6 | "testing" 7 | ) 8 | 9 | func TestConversion(t *testing.T) { 10 | Convey("Given a Matrix", t, func() { 11 | m := Rand(3, 5).M() 12 | Convey("Converting it ToMat64 works", func() { 13 | m2 := ToMat64(m) 14 | r, c := m2.Dims() 15 | So(r, ShouldEqual, 3) 16 | So(c, ShouldEqual, 5) 17 | for i0 := 0; i0 < 3; i0++ { 18 | for i1 := 0; i1 < 5; i1++ { 19 | So(m2.At(i0, i1), ShouldEqual, m.Item(i0, i1)) 20 | } 21 | } 22 | 23 | Convey("Converting it back to a matrix works", func() { 24 | m3 := ToMatrix(m2) 25 | So(m3.Shape(), ShouldResemble, []int{3, 5}) 26 | for i0 := 0; i0 < 3; i0++ { 27 | for i1 := 0; i1 < 5; i1++ { 28 | So(m3.Item(i0, i1), ShouldEqual, m.Item(i0, i1)) 29 | } 30 | } 31 | }) 32 | }) 33 | }) 34 | } 35 | 36 | func TestDiag(t *testing.T) { 37 | Convey("Given a diagonal array with 3 elements", t, func() { 38 | array := Diag(1, 2, 3) 39 | 40 | Convey("Shape() is (3, 3)", func() { 41 | So(array.Shape(), ShouldResemble, []int{3, 3}) 42 | }) 43 | 44 | Convey("Size() is 9", func() { 45 | So(array.Size(), ShouldEqual, 9) 46 | }) 47 | 48 | Convey("Only diagonal values are set; others are zero", func() { 49 | So(array.Array(), ShouldResemble, []float64{ 50 | 1, 0, 0, 51 | 0, 2, 0, 52 | 0, 0, 3, 53 | }) 54 | }) 55 | }) 56 | } 57 | 58 | func TestEye(t *testing.T) { 59 | Convey("Given an identity array with shape 3, 3", t, func() { 60 | array := Eye(3) 61 | 62 | Convey("Shape() is (3, 3)", func() { 63 | So(array.Shape(), ShouldResemble, []int{3, 3}) 64 | }) 65 | 66 | Convey("Size() is 9", func() { 67 | So(array.Size(), ShouldEqual, 9) 68 | }) 69 | 70 | Convey("Only diagonal values are one; others are zero", func() { 71 | for i0 := 0; i0 < 3; i0++ { 72 | for i1 := 0; i1 < 3; i1++ { 73 | if i0 == i1 { 74 | So(array.Item(i0, i1), ShouldEqual, 1) 75 | } else { 76 | So(array.Item(i0, i1), ShouldEqual, 0) 77 | } 78 | } 79 | } 80 | }) 81 | }) 82 | } 83 | 84 | func TestM(t *testing.T) { 85 | Convey("M() panics when data 
doesn't match dimensions", t, func() { 86 | So(func() { M(2, 3, 1, 2, 3, 4, 5) }, ShouldPanic) 87 | So(func() { M(2, 3, 1, 2, 3, 4, 5, 6, 7) }, ShouldPanic) 88 | }) 89 | 90 | Convey("Given a vector created with M", t, func() { 91 | m := M(5, 1, 1, 2, 3, 4, 5) 92 | Convey("Shape() is 5,1", func() { 93 | So(m.Shape(), ShouldResemble, []int{5, 1}) 94 | }) 95 | Convey("Size() is 5", func() { 96 | So(m.Size(), ShouldResemble, 5) 97 | }) 98 | Convey("The data is correct", func() { 99 | So(m.Array(), ShouldResemble, []float64{1, 2, 3, 4, 5}) 100 | }) 101 | }) 102 | 103 | Convey("Given a 2D matrix created with M", t, func() { 104 | m := M(2, 3, 1, 2, 3, 4, 5, 6) 105 | Convey("Shape() is 2, 3", func() { 106 | So(m.Shape(), ShouldResemble, []int{2, 3}) 107 | }) 108 | Convey("Size() is 6", func() { 109 | So(m.Size(), ShouldResemble, 6) 110 | }) 111 | Convey("The data is correct", func() { 112 | So(m.Array(), ShouldResemble, []float64{ 113 | 1, 2, 3, 114 | 4, 5, 6, 115 | }) 116 | }) 117 | }) 118 | } 119 | 120 | func TestSparseCoo(t *testing.T) { 121 | Convey("Given a 2x3 SparseCoo matrix", t, func() { 122 | m := SparseCoo(2, 3) 123 | 124 | Convey("Shape() is 2, 3", func() { 125 | So(m.Shape(), ShouldResemble, []int{2, 3}) 126 | }) 127 | 128 | Convey("Size() is 6", func() { 129 | So(m.Size(), ShouldResemble, 6) 130 | }) 131 | 132 | Convey("The data is correct", func() { 133 | So(m.Array(), ShouldResemble, []float64{ 134 | 0, 0, 0, 135 | 0, 0, 0, 136 | }) 137 | }) 138 | }) 139 | } 140 | 141 | func TestSparseDiag(t *testing.T) { 142 | Convey("SparseDiag panics when given too many diagonal elements", t, func() { 143 | So(func() { SparseDiag(2, 3, 1, 2, 3, 4, 5, 6, 7) }, ShouldPanic) 144 | }) 145 | 146 | Convey("Given a 2x3 SparseDiag matrix with diagonal elements set", t, func() { 147 | m := SparseDiag(2, 3, 1, 2) 148 | 149 | Convey("Shape() is 2, 3", func() { 150 | So(m.Shape(), ShouldResemble, []int{2, 3}) 151 | }) 152 | 153 | Convey("Size() is 6", func() { 154 | So(m.Size(), 
ShouldResemble, 6) 155 | }) 156 | 157 | Convey("The data is correct", func() { 158 | So(m.Array(), ShouldResemble, []float64{ 159 | 1, 0, 0, 160 | 0, 2, 0, 161 | }) 162 | }) 163 | }) 164 | 165 | Convey("Given a 3x2 SparseDiag matrix with diagonal elements set", t, func() { 166 | m := SparseDiag(3, 2, 1, 2) 167 | 168 | Convey("Shape() is 3, 2", func() { 169 | So(m.Shape(), ShouldResemble, []int{3, 2}) 170 | }) 171 | 172 | Convey("Size() is 6", func() { 173 | So(m.Size(), ShouldResemble, 6) 174 | }) 175 | 176 | Convey("The data is correct", func() { 177 | So(m.Array(), ShouldResemble, []float64{ 178 | 1, 0, 179 | 0, 2, 180 | 0, 0, 181 | }) 182 | }) 183 | }) 184 | 185 | Convey("Given a 3x3 SparseDiag matrix with no diagonal elements set", t, func() { 186 | m := SparseDiag(3, 3) 187 | 188 | Convey("Shape() is 3, 3", func() { 189 | So(m.Shape(), ShouldResemble, []int{3, 3}) 190 | }) 191 | 192 | Convey("Size() is 9", func() { 193 | So(m.Size(), ShouldResemble, 9) 194 | }) 195 | 196 | Convey("The data is correct", func() { 197 | So(m.Array(), ShouldResemble, []float64{ 198 | 0, 0, 0, 199 | 0, 0, 0, 200 | 0, 0, 0, 201 | }) 202 | }) 203 | }) 204 | 205 | Convey("Given a 3x3 SparseDiag matrix with some diagonal elements set", t, func() { 206 | m := SparseDiag(3, 3, 1, 2) 207 | 208 | Convey("Shape() is 3, 3", func() { 209 | So(m.Shape(), ShouldResemble, []int{3, 3}) 210 | }) 211 | 212 | Convey("Size() is 9", func() { 213 | So(m.Size(), ShouldResemble, 9) 214 | }) 215 | 216 | Convey("The data is correct", func() { 217 | So(m.Array(), ShouldResemble, []float64{ 218 | 1, 0, 0, 219 | 0, 2, 0, 220 | 0, 0, 0, 221 | }) 222 | }) 223 | }) 224 | 225 | Convey("Given a 3x3 SparseDiag matrix with all diagonal elements set", t, func() { 226 | m := SparseDiag(3, 3, 1, 2, 3) 227 | 228 | Convey("Shape() is 3, 3", func() { 229 | So(m.Shape(), ShouldResemble, []int{3, 3}) 230 | }) 231 | 232 | Convey("Size() is 9", func() { 233 | So(m.Size(), ShouldResemble, 9) 234 | }) 235 | 236 | Convey("The 
data is correct", func() { 237 | So(m.Array(), ShouldResemble, []float64{ 238 | 1, 0, 0, 239 | 0, 2, 0, 240 | 0, 0, 3, 241 | }) 242 | }) 243 | }) 244 | } 245 | 246 | func TestSparseRand(t *testing.T) { 247 | Convey("SparseRand panics with an invalid density", t, func() { 248 | So(func() { SparseRand(2, 3, -1) }, ShouldPanic) 249 | So(func() { SparseRand(2, 3, 1) }, ShouldPanic) 250 | }) 251 | 252 | Convey("Given a sparse random array with shape 2, 3 and density 0.5", t, func() { 253 | array := SparseRand(2, 3, 0.5) 254 | 255 | Convey("Shape() is (2, 3)", func() { 256 | So(array.Shape(), ShouldResemble, []int{2, 3}) 257 | }) 258 | 259 | Convey("Size() is 6", func() { 260 | So(array.Size(), ShouldEqual, 6) 261 | }) 262 | 263 | Convey("Half the values are filled", func() { 264 | So(array.CountNonzero(), ShouldEqual, 3) 265 | }) 266 | }) 267 | } 268 | 269 | func TestSparseRandN(t *testing.T) { 270 | Convey("SparseRandN panics with an invalid density", t, func() { 271 | So(func() { SparseRandN(2, 3, -1) }, ShouldPanic) 272 | So(func() { SparseRandN(2, 3, 1) }, ShouldPanic) 273 | }) 274 | 275 | Convey("Given a sparse random array with shape 2, 3 and density 0.5", t, func() { 276 | array := SparseRandN(2, 3, 0.5) 277 | 278 | Convey("Shape() is (2, 3)", func() { 279 | So(array.Shape(), ShouldResemble, []int{2, 3}) 280 | }) 281 | 282 | Convey("Size() is 6", func() { 283 | So(array.Size(), ShouldEqual, 6) 284 | }) 285 | 286 | Convey("Half the values are filled", func() { 287 | So(array.CountNonzero(), ShouldEqual, 3) 288 | }) 289 | }) 290 | } 291 | 292 | func TestInverse(t *testing.T) { 293 | Convey("Given an invertible square matrix", t, func() { 294 | m := A2( 295 | []float64{4, 7}, 296 | []float64{2, 6}, 297 | ).M() 298 | 299 | Convey("When I take the inverse", func() { 300 | mi, err := Inverse(m) 301 | So(err, ShouldBeNil) 302 | 303 | Convey("The inverse is correct", func() { 304 | So(mi.Shape(), ShouldResemble, []int{2, 2}) 305 | So(mi.Item(0, 0), ShouldBeBetween, 
0.6-Eps, 0.6+Eps) 306 | So(mi.Item(0, 1), ShouldBeBetween, -0.7-Eps, -0.7+Eps) 307 | So(mi.Item(1, 0), ShouldBeBetween, -0.2-Eps, -0.2+Eps) 308 | So(mi.Item(1, 1), ShouldBeBetween, 0.4-Eps, 0.4+Eps) 309 | }) 310 | 311 | Convey("The inverse gives us back I", func() { 312 | i := m.MProd(mi) 313 | So(i.Shape(), ShouldResemble, []int{2, 2}) 314 | So(i.Item(0, 0), ShouldBeBetween, 1-Eps, 1+Eps) 315 | So(i.Item(0, 1), ShouldBeBetween, -Eps, Eps) 316 | So(i.Item(1, 0), ShouldBeBetween, -Eps, Eps) 317 | So(i.Item(1, 1), ShouldBeBetween, 1-Eps, 1+Eps) 318 | }) 319 | }) 320 | }) 321 | 322 | Convey("Given a non-invertible matrix", t, func() { 323 | m := A([]int{2, 2}, 324 | 0, 0, 325 | 1, 1).M() 326 | 327 | Convey("Inverse returns an error", func() { 328 | mi, err := Inverse(m) 329 | So(mi, ShouldBeNil) 330 | So(err, ShouldNotBeNil) 331 | }) 332 | }) 333 | } 334 | 335 | func TestLDivide(t *testing.T) { 336 | Convey("Given a simple division problem", t, func() { 337 | a := M(3, 3, 338 | 8, 1, 6, 339 | 3, 5, 7, 340 | 4, 9, 2) 341 | b := M(3, 1, 342 | 15, 343 | 15, 344 | 15) 345 | 346 | Convey("When I solve the system", func() { 347 | x := LDivide(a, b) 348 | 349 | Convey("I get the correct solution", func() { 350 | So(x.Shape(), ShouldResemble, []int{3, 1}) 351 | So(x.Item(0, 0), ShouldBeBetween, 1-Eps, 1+Eps) 352 | So(x.Item(1, 0), ShouldBeBetween, 1-Eps, 1+Eps) 353 | So(x.Item(2, 0), ShouldBeBetween, 1-Eps, 1+Eps) 354 | }) 355 | 356 | Convey("The product ax = b is true", func() { 357 | b2 := a.MProd(x) 358 | So(b2.Shape(), ShouldResemble, []int{3, 1}) 359 | So(b2.Item(0, 0), ShouldBeBetween, 15-Eps, 15+Eps) 360 | So(b2.Item(1, 0), ShouldBeBetween, 15-Eps, 15+Eps) 361 | So(b2.Item(2, 0), ShouldBeBetween, 15-Eps, 15+Eps) 362 | }) 363 | }) 364 | }) 365 | 366 | Convey("Given a singular division problem", t, func() { 367 | a := M(3, 3, 368 | 0, 0, 0, 369 | 3, 5, 7, 370 | 4, 9, 2) 371 | b := M(3, 1, 372 | 15, 373 | 15, 374 | 15) 375 | 376 | Convey("When I solve the system", func() { 377 | x := LDivide(a,
b) 378 | 379 | Convey("I get NaN", func() { 380 | So(x.Shape(), ShouldResemble, []int{3, 1}) 381 | So(x.AllF(math.IsNaN), ShouldBeTrue) 382 | }) 383 | }) 384 | }) 385 | } 386 | 387 | func TestNorm(t *testing.T) { 388 | Convey("Given a 3x3 matrix", t, func() { 389 | m := M(3, 3, 390 | 1, 2, 3, 391 | 4, 5, 6, 392 | 7, 8, 9) 393 | 394 | Convey("The 1-norm is correct", func() { 395 | So(Norm(m, 1), ShouldEqual, 18) 396 | }) 397 | 398 | Convey("The 2-norm is correct", func() { 399 | So(Norm(m, 2), ShouldEqual, 16.84810335261421) 400 | }) 401 | 402 | Convey("The inf-norm is correct", func() { 403 | So(Norm(m, math.Inf(1)), ShouldEqual, 24) 404 | }) 405 | }) 406 | } 407 | 408 | func TestSolve(t *testing.T) { 409 | Convey("Given a simple division problem", t, func() { 410 | a := M(3, 3, 411 | 8, 1, 6, 412 | 3, 5, 7, 413 | 4, 9, 2) 414 | b := M(3, 1, 415 | 15, 416 | 15, 417 | 15) 418 | 419 | Convey("When I solve the system", func() { 420 | x := Solve(a, b) 421 | 422 | Convey("I get the correct solution", func() { 423 | So(x.Shape(), ShouldResemble, []int{3, 1}) 424 | So(x.Item(0, 0), ShouldBeBetween, 1-Eps, 1+Eps) 425 | So(x.Item(1, 0), ShouldBeBetween, 1-Eps, 1+Eps) 426 | So(x.Item(2, 0), ShouldBeBetween, 1-Eps, 1+Eps) 427 | }) 428 | 429 | Convey("The product ax = b is true", func() { 430 | b2 := a.MProd(x) 431 | So(b2.Shape(), ShouldResemble, []int{3, 1}) 432 | So(b2.Item(0, 0), ShouldBeBetween, 15-Eps, 15+Eps) 433 | So(b2.Item(1, 0), ShouldBeBetween, 15-Eps, 15+Eps) 434 | So(b2.Item(2, 0), ShouldBeBetween, 15-Eps, 15+Eps) 435 | }) 436 | }) 437 | }) 438 | 439 | Convey("Given a singular division problem", t, func() { 440 | a := M(3, 3, 441 | 0, 0, 0, 442 | 3, 5, 7, 443 | 4, 9, 2) 444 | b := M(3, 1, 445 | 15, 446 | 15, 447 | 15) 448 | 449 | Convey("When I solve the system", func() { 450 | x := Solve(a, b) 451 | 452 | Convey("I get NaN", func() { 453 | So(x.Shape(), ShouldResemble, []int{3, 1}) 454 | So(x.AllF(math.IsNaN), ShouldBeTrue) 455 | }) 456 | }) 457 | }) 458 | } 459 
| -------------------------------------------------------------------------------- /matrix/dense.go: -------------------------------------------------------------------------------- 1 | package matrix 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // An n-dimensional NDArray with dense representation 8 | type denseF64Array struct { 9 | shape []int 10 | array []float64 11 | transpose bool 12 | } 13 | 14 | // Return the element-wise sum of this array and one or more others 15 | func (array denseF64Array) Add(other ...NDArray) NDArray { 16 | return Add(&array, other...) 17 | } 18 | 19 | // Returns true if and only if all items are nonzero 20 | func (array denseF64Array) All() bool { 21 | return All(&array) 22 | } 23 | 24 | // Returns true if f is true for all array elements 25 | func (array denseF64Array) AllF(f func(v float64) bool) bool { 26 | return AllF(&array, f) 27 | } 28 | 29 | // Returns true if f is true for all pairs of array elements in the same position 30 | func (array denseF64Array) AllF2(f func(v1, v2 float64) bool, other NDArray) bool { 31 | return AllF2(&array, f, other) 32 | } 33 | 34 | // Returns true if and only if any item is nonzero 35 | func (array denseF64Array) Any() bool { 36 | return Any(&array) 37 | } 38 | 39 | // Returns true if f is true for any array element 40 | func (array denseF64Array) AnyF(f func(v float64) bool) bool { 41 | return AnyF(&array, f) 42 | } 43 | 44 | // Returns true if f is true for any pair of array elements in the same position 45 | func (array denseF64Array) AnyF2(f func(v1, v2 float64) bool, other NDArray) bool { 46 | return AnyF2(&array, f, other) 47 | } 48 | 49 | // Return the result of applying a function to all elements 50 | func (array denseF64Array) Apply(f func(float64) float64) NDArray { 51 | result := array.copy() 52 | for i, val := range result.array { 53 | result.array[i] = f(val) 54 | } 55 | return result 56 | } 57 | 58 | // Get the matrix data as a flattened 1D array; sparse matrices will make 59 | // a copy 
first. 60 | func (array denseF64Array) Array() []float64 { 61 | if array.transpose { 62 | return array.copy().array 63 | } else { 64 | return array.array 65 | } 66 | } 67 | 68 | // Set the values of the items on a given column 69 | func (array denseF64Array) ColSet(col int, values []float64) { 70 | if col < 0 || col >= array.shape[1] { 71 | panic(fmt.Sprintf("ColSet can't set col %d of a %d-col array", col, array.shape[1])) 72 | } else if len(values) != array.shape[0] { 73 | panic(fmt.Sprintf("ColSet has %d rows but got %d values", array.shape[0], len(values))) 74 | } 75 | for row := 0; row < array.shape[0]; row++ { 76 | array.ItemSet(values[row], row, col) 77 | } 78 | } 79 | 80 | // Get a particular column for read-only access. May or may not be a copy. 81 | func (array denseF64Array) Col(col int) []float64 { 82 | if col < 0 || col >= array.shape[1] { 83 | panic(fmt.Sprintf("Can't get column %d from a %dx%d array", col, array.shape[0], array.shape[1])) 84 | } 85 | result := make([]float64, array.shape[0]) 86 | for row := 0; row < array.shape[0]; row++ { 87 | result[row] = array.Item(row, col) 88 | } 89 | return result 90 | } 91 | 92 | // Get the number of columns 93 | func (array denseF64Array) Cols() int { 94 | return array.shape[1] 95 | } 96 | 97 | // Create a new array by concatenating this with another array along the 98 | // specified axis. The array shapes must be equal along all other axes. 99 | // It is legal to add a new axis. 100 | func (array denseF64Array) Concat(axis int, others ...NDArray) NDArray { 101 | return Concat(axis, &array, others...) 
102 | } 103 | 104 | // Returns a duplicate of this array 105 | func (array denseF64Array) Copy() NDArray { 106 | return array.copy() 107 | } 108 | 109 | // Returns a duplicate of this array, preserving type 110 | func (array denseF64Array) copy() *denseF64Array { 111 | result := &denseF64Array{ 112 | shape: make([]int, len(array.shape)), 113 | array: make([]float64, len(array.array)), 114 | } 115 | copy(result.shape[:], array.shape[:]) 116 | if array.transpose { 117 | for i0 := 0; i0 < array.shape[0]; i0++ { 118 | for i1 := 0; i1 < array.shape[1]; i1++ { 119 | result.ItemSet(array.Item(i0, i1), i0, i1) 120 | } 121 | } 122 | } else { 123 | copy(result.array[:], array.array[:]) 124 | } 125 | return result 126 | } 127 | 128 | // Counts the number of nonzero elements in the array 129 | func (array denseF64Array) CountNonzero() int { 130 | count := 0 131 | for _, v := range array.array { 132 | if v != 0 { 133 | count++ 134 | } 135 | } 136 | return count 137 | } 138 | 139 | // Returns a dense copy of the array 140 | func (array denseF64Array) Dense() NDArray { 141 | return array.copy() 142 | } 143 | 144 | // Get a column vector containing the main diagonal elements of the matrix 145 | func (array denseF64Array) Diag() Matrix { 146 | size := array.shape[0] 147 | if array.shape[1] < size { 148 | size = array.shape[1] 149 | } 150 | result := Dense(size, 1).M() 151 | for i := 0; i < size; i++ { 152 | result.ItemSet(array.Item(i, i), i, 0) 153 | } 154 | return result 155 | } 156 | 157 | // Treat the rows as points, and get the pairwise distance between them. 158 | // Returns a distance matrix D such that D_i,j is the distance between 159 | // rows i and j. 160 | func (array denseF64Array) Dist(t DistType) Matrix { 161 | return Dist(&array, t) 162 | } 163 | 164 | // Return the element-wise quotient of this array and one or more others. 165 | // This function defines 0 / 0 = 0, so it's useful for sparse arrays. 
166 | func (array denseF64Array) Div(other ...NDArray) NDArray { 167 | return Div(&array, other...) 168 | } 169 | 170 | // Returns true if and only if all elements in the two arrays are equal 171 | func (array denseF64Array) Equal(other NDArray) bool { 172 | return Equal(&array, other) 173 | } 174 | 175 | // Set all array elements to the given value 176 | func (array denseF64Array) Fill(value float64) { 177 | Fill(&array, value) 178 | } 179 | 180 | // Get the coordinates for the item at the specified flat position 181 | func (array denseF64Array) FlatCoord(index int) []int { 182 | return flatToNd(array.shape, index) 183 | } 184 | 185 | // Get an array element in a flattened version of this array 186 | func (array denseF64Array) FlatItem(index int) float64 { 187 | if array.transpose { 188 | nd := flatToNd(array.shape, index) 189 | index = ndToFlat([]int{array.shape[1], array.shape[0]}, []int{nd[1], nd[0]}) 190 | } 191 | return array.array[index] 192 | } 193 | 194 | // Set an array element in a flattened version of this array 195 | func (array denseF64Array) FlatItemSet(value float64, index int) { 196 | if array.transpose { 197 | nd := flatToNd(array.shape, index) 198 | index = ndToFlat([]int{array.shape[1], array.shape[0]}, []int{nd[1], nd[0]}) 199 | } 200 | array.array[index] = value 201 | } 202 | 203 | // Get the matrix inverse 204 | func (array denseF64Array) Inverse() (Matrix, error) { 205 | return Inverse(&array) 206 | } 207 | 208 | // Get an array element. For a transposed array the indices are swapped into a new slice, so the caller's index slice is never modified. 209 | func (array denseF64Array) Item(index ...int) float64 { 210 | shape := array.shape 211 | if array.transpose { 212 | index = append([]int{index[1], index[0]}, index[2:]...) 213 | shape = []int{array.shape[1], array.shape[0]} 214 | } 215 | return array.array[ndToFlat(shape, index)] 216 | } 217 | 218 | // Add a scalar value to each array element 219 | func (array *denseF64Array) ItemAdd(value float64) NDArray { 220 | result := array.copy() 221 | for idx := range result.array { 222 | result.array[idx] += value 223 | } 224
| return result 225 | } 226 | 227 | // Divide each array element by a scalar value 228 | func (array *denseF64Array) ItemDiv(value float64) NDArray { 229 | result := array.copy() 230 | for idx := range result.array { 231 | result.array[idx] /= value 232 | } 233 | return result 234 | } 235 | 236 | // Multiply each array element by a scalar value 237 | func (array *denseF64Array) ItemProd(value float64) NDArray { 238 | result := array.copy() 239 | for idx := range result.array { 240 | result.array[idx] *= value 241 | } 242 | return result 243 | } 244 | 245 | // Subtract a scalar value from each array element 246 | func (array *denseF64Array) ItemSub(value float64) NDArray { 247 | result := array.copy() 248 | for idx := range result.array { 249 | result.array[idx] -= value 250 | } 251 | return result 252 | } 253 | 254 | // Set an array element. For a transposed array the indices are swapped into a new slice, so the caller's index slice is never modified. 255 | func (array denseF64Array) ItemSet(value float64, index ...int) { 256 | shape := array.shape 257 | if array.transpose { 258 | index = append([]int{index[1], index[0]}, index[2:]...) 259 | shape = []int{array.shape[1], array.shape[0]} 260 | } 261 | array.array[ndToFlat(shape, index)] = value 262 | } 263 | 264 | // Solve for x, where ax = b. 265 | func (array denseF64Array) LDivide(b Matrix) Matrix { 266 | return LDivide(&array, b) 267 | } 268 | 269 | // Get the result of matrix multiplication between this and some other 270 | // array(s). All arrays must have two dimensions, and the dimensions must 271 | // be aligned correctly for multiplication. 272 | // If A is m x p and B is p x n, then C = A.MProd(B) is the m x n matrix 273 | // with C[i, j] = \sum_{k=1}^p A[i,k] * B[k,j]. 274 | func (array denseF64Array) MProd(others ...Matrix) Matrix { 275 | return MProd(&array, others...)
276 | } 277 | 278 | // Get the value of the largest array element 279 | func (array denseF64Array) Max() float64 { 280 | return Max(&array) 281 | } 282 | 283 | // Get the value of the smallest array element 284 | func (array denseF64Array) Min() float64 { 285 | return Min(&array) 286 | } 287 | 288 | // The number of dimensions in the matrix 289 | func (array denseF64Array) NDim() int { 290 | return len(array.shape) 291 | } 292 | 293 | // Get the matrix norm of the specified ordinality (1, 2, infinity, ...) 294 | func (array denseF64Array) Norm(ord float64) float64 { 295 | return Norm(&array, ord) 296 | } 297 | 298 | // Return a copy of the array, normalized to sum to 1 299 | func (array *denseF64Array) Normalize() NDArray { 300 | return Normalize(array) 301 | } 302 | 303 | // Return the element-wise product of this array and one or more others 304 | func (array denseF64Array) Prod(other ...NDArray) NDArray { 305 | return Prod(&array, other...) 306 | } 307 | 308 | // Get a 1D copy of the array, in 'C' order: rightmost axes change fastest 309 | func (array denseF64Array) Ravel() NDArray { 310 | return Ravel(&array) 311 | } 312 | 313 | // Set the values of the items on a given row 314 | func (array denseF64Array) RowSet(row int, values []float64) { 315 | if row < 0 || row >= array.shape[0] { 316 | panic(fmt.Sprintf("RowSet can't set row %d of a %d-row array", row, array.shape[0])) 317 | } else if len(values) != array.shape[1] { 318 | panic(fmt.Sprintf("RowSet has %d columns but got %d values", array.shape[1], len(values))) 319 | } 320 | for col := 0; col < array.shape[1]; col++ { 321 | array.ItemSet(values[col], row, col) 322 | } 323 | } 324 | 325 | // Get a particular row for read-only access. May or may not be a copy. 
326 | func (array denseF64Array) Row(row int) []float64 { 327 | if row < 0 || row >= array.shape[0] { 328 | panic(fmt.Sprintf("Can't get row %d from a %dx%d array", row, array.shape[0], array.shape[1])) 329 | } 330 | if array.transpose { 331 | return array.copy().Row(row) 332 | } 333 | start := ndToFlat(array.shape, []int{row, 0}) 334 | return array.array[start : start+array.shape[1]] 335 | } 336 | 337 | // Get the number of rows 338 | func (array denseF64Array) Rows() int { 339 | return array.shape[0] 340 | } 341 | 342 | // A slice giving the size of all array dimensions 343 | func (array denseF64Array) Shape() []int { 344 | return array.shape 345 | } 346 | 347 | // The total number of elements in the matrix 348 | func (array denseF64Array) Size() int { 349 | return len(array.array) 350 | } 351 | 352 | // Get an array containing a rectangular slice of this array. `from` and 353 | // `to` should both have one index per axis, giving the first and 354 | // just-past-last indices to select along each axis. 355 | func (array denseF64Array) Slice(from []int, to []int) NDArray { 356 | return Slice(&array, from, to) 357 | } 358 | 359 | // Return a sparse coo copy of the matrix containing its nonzero elements. 360 | func (array denseF64Array) SparseCoo() Matrix { 361 | m := SparseCoo(array.shape[0], array.shape[1]) 362 | array.VisitNonzero(func(pos []int, value float64) bool { 363 | m.ItemSet(value, pos[0], pos[1]) 364 | return true 365 | }) 366 | return m 367 | } 368 | 369 | // Return a sparse diag copy of the matrix; panics if any off-diagonal element is nonzero.
370 | func (array denseF64Array) SparseDiag() Matrix { 371 | m := SparseDiag(array.shape[0], array.shape[1]) 372 | array.VisitNonzero(func(pos []int, value float64) bool { 373 | m.ItemSet(value, pos[0], pos[1]) 374 | return true 375 | }) 376 | return m 377 | } 378 | 379 | // Ask whether the matrix has a sparse representation (useful for optimization) 380 | func (array denseF64Array) Sparsity() ArraySparsity { 381 | return DenseArray 382 | } 383 | 384 | // Return the element-wise difference of this array and one or more others 385 | func (array denseF64Array) Sub(other ...NDArray) NDArray { 386 | return Sub(&array, other...) 387 | } 388 | 389 | // Return the sum of all array elements 390 | func (array denseF64Array) Sum() float64 { 391 | return Sum(&array) 392 | } 393 | 394 | // Returns the array as a matrix. This is only possible for 1D and 2D arrays; 395 | // 1D arrays of length n are converted into n x 1 vectors. 396 | func (array denseF64Array) M() Matrix { 397 | switch array.NDim() { 398 | default: 399 | panic(fmt.Sprintf("Cannot convert a %d-dim array into a matrix", array.NDim())) 400 | 401 | case 1: 402 | return &denseF64Array{ 403 | shape: []int{array.shape[0], 1}, 404 | array: array.array, 405 | transpose: array.transpose, 406 | } 407 | 408 | case 2: 409 | return &array 410 | } 411 | } 412 | 413 | // Return the same matrix, but with axes transposed. The same data is used, 414 | // for speed and memory efficiency. Use Copy() to create a new array. 415 | func (array denseF64Array) T() Matrix { 416 | return &denseF64Array{ 417 | shape: []int{array.shape[1], array.shape[0]}, 418 | array: array.array, 419 | transpose: !array.transpose, 420 | } 421 | } 422 | 423 | // Visit all matrix elements, invoking a method on each. If the method 424 | // returns false, iteration is aborted and VisitNonzero() returns false. 425 | // Otherwise, it returns true. 
426 | func (array denseF64Array) Visit(f func(pos []int, value float64) bool) bool { 427 | for flat, value := range array.array { 428 | var pos []int 429 | if array.transpose { 430 | pOrig := flatToNd([]int{array.shape[1], array.shape[0]}, flat) 431 | pos = []int{pOrig[1], pOrig[0]} 432 | } else { 433 | pos = flatToNd(array.shape, flat) 434 | } 435 | if !f(pos, value) { 436 | return false 437 | } 438 | } 439 | return true 440 | } 441 | 442 | // Visit just nonzero elements, invoking a method on each. If the method 443 | // returns false, iteration is aborted and VisitNonzero() returns false. 444 | // Otherwise, it returns true. 445 | func (array denseF64Array) VisitNonzero(f func(pos []int, value float64) bool) bool { 446 | for flat, value := range array.array { 447 | if value != 0 { 448 | var pos []int 449 | if array.transpose { 450 | pOrig := flatToNd([]int{array.shape[1], array.shape[0]}, flat) 451 | pos = []int{pOrig[1], pOrig[0]} 452 | } else { 453 | pos = flatToNd(array.shape, flat) 454 | } 455 | if !f(pos, value) { 456 | return false 457 | } 458 | } 459 | } 460 | return true 461 | } 462 | -------------------------------------------------------------------------------- /graph/adjacency.go: -------------------------------------------------------------------------------- 1 | package graph 2 | 3 | import ( 4 | "github.com/jesand/numgo/container" 5 | "github.com/jesand/numgo/matrix" 6 | "math" 7 | "strings" 8 | ) 9 | 10 | // Create a graph using a dense adjacency matrix 11 | func NewDenseAdjacencyGraph(directed bool, maxNodes int) Graph { 12 | return &adjacencyGraph{ 13 | directed: directed, 14 | edges: matrix.Dense(maxNodes, maxNodes).M(), 15 | } 16 | } 17 | 18 | // Create a graph using a sparse adjacency matrix 19 | func NewSparseAdjacencyGraph(directed bool, maxNodes int) Graph { 20 | return &adjacencyGraph{ 21 | directed: directed, 22 | edges: matrix.SparseCoo(maxNodes, maxNodes), 23 | } 24 | } 25 | 26 | // A graph represented as an adjacency matrix 27 | type 
adjacencyGraph struct { 28 | 29 | // Whether the graph is directed 30 | directed bool 31 | 32 | // The node information 33 | nodes []Node 34 | 35 | // The adjacency matrix 36 | edges matrix.Matrix 37 | } 38 | 39 | // Add an edge to the graph. If the graph is not directed, the edge is 40 | // bidirectional. If the graph is weighted, the edge is assigned weight 1. 41 | func (graph *adjacencyGraph) AddEdge(from, to NodeID) { 42 | graph.AddEdgeWithWeight(from, to, 1) 43 | } 44 | 45 | // Add an edge with a weight to the graph. 46 | func (graph *adjacencyGraph) AddEdgeWithWeight(from, to NodeID, weight float64) { 47 | if int(from) >= len(graph.nodes) || int(to) >= len(graph.nodes) { 48 | panic(ErrInvalidNode{}) 49 | } 50 | graph.edges.ItemSet(weight, int(from), int(to)) 51 | graph.nodes[from].OutDegree++ 52 | graph.nodes[to].InDegree++ 53 | if !graph.directed { 54 | graph.edges.ItemSet(weight, int(to), int(from)) 55 | graph.nodes[to].OutDegree++ 56 | graph.nodes[from].InDegree++ 57 | } 58 | } 59 | 60 | // Add a vertex to the graph, with optional name. If the graph's internal 61 | // storage runs out of capacity, this will panic with ErrGraphCapacity. 
62 | func (graph *adjacencyGraph) AddNode(name string) NodeID { 63 | if len(graph.nodes) >= graph.edges.Rows() { 64 | panic(ErrGraphCapacity{}) 65 | } 66 | id := NodeID(len(graph.nodes)) 67 | graph.nodes = append(graph.nodes, Node{ 68 | ID: id, 69 | Name: name, 70 | InDegree: 0, 71 | OutDegree: 0, 72 | }) 73 | return id 74 | } 75 | 76 | // Returns all children of a given node 77 | func (graph adjacencyGraph) Children(of NodeID) (children []Node) { 78 | if int(of) >= len(graph.nodes) { 79 | panic(ErrInvalidNode{}) 80 | } 81 | for child := 0; child < len(graph.nodes); child++ { 82 | if graph.edges.Item(int(of), int(child)) != 0 { 83 | children = append(children, graph.nodes[child]) 84 | } 85 | } 86 | return 87 | } 88 | 89 | // Returns all children of a given node and their corresponding edge weights 90 | func (graph adjacencyGraph) ChildrenWithWeights(of NodeID) (children []Node, weights []float64) { 91 | if int(of) >= len(graph.nodes) { 92 | panic(ErrInvalidNode{}) 93 | } 94 | for child := 0; child < len(graph.nodes); child++ { 95 | weight := graph.edges.Item(int(of), int(child)) 96 | if weight != 0 { 97 | children = append(children, graph.nodes[child]) 98 | weights = append(weights, weight) 99 | } 100 | } 101 | return 102 | } 103 | 104 | // Make a copy of the graph 105 | func (graph adjacencyGraph) Copy() Graph { 106 | result := &adjacencyGraph{ 107 | directed: graph.directed, 108 | nodes: make([]Node, len(graph.nodes)), 109 | edges: graph.edges.Copy().M(), 110 | } 111 | copy(result.nodes[:], graph.nodes[:]) 112 | return result 113 | } 114 | 115 | // Ask whether a given edge exists in the graph 116 | func (graph adjacencyGraph) HasEdge(from, to NodeID) bool { 117 | if int(from) >= len(graph.nodes) || int(to) >= len(graph.nodes) { 118 | panic(ErrInvalidNode{}) 119 | } 120 | return graph.edges.Item(int(from), int(to)) != 0 121 | } 122 | 123 | // Ask whether the graph contains any edges 124 | func (graph adjacencyGraph) HasEdges() bool { 125 | return 
graph.edges.CountNonzero() > 0 126 | } 127 | 128 | // Ask whether the graph contains cycles 129 | func (graph adjacencyGraph) HasCycles() bool { 130 | _, err := graph.TopologicalSort() 131 | return err != nil 132 | } 133 | 134 | // Ask whether the graph contains any nodes 135 | func (graph adjacencyGraph) HasNodes() bool { 136 | return len(graph.nodes) > 0 137 | } 138 | 139 | // Ask whether a path exists between two nodes 140 | func (graph adjacencyGraph) HasPath(from, to NodeID) bool { 141 | if int(from) >= len(graph.nodes) || int(to) >= len(graph.nodes) { 142 | panic(ErrInvalidNode{}) 143 | } 144 | p, _ := graph.ShortestPath(from, to) 145 | return len(p) > 0 146 | } 147 | 148 | // Ask whether the graph is a directed acyclic graph 149 | func (graph adjacencyGraph) IsDag() bool { 150 | return graph.IsDirected() && !graph.HasCycles() 151 | } 152 | 153 | // Ask whether the graph is directed 154 | func (graph adjacencyGraph) IsDirected() bool { 155 | return graph.directed 156 | } 157 | 158 | // Ask whether the graph is a tree 159 | func (graph adjacencyGraph) IsTree() bool { 160 | if !graph.IsDag() { 161 | return false 162 | } 163 | for _, node := range graph.nodes { 164 | if node.InDegree > 1 { 165 | return false 166 | } 167 | } 168 | return true 169 | } 170 | 171 | // Returns all nodes with out-degree zero 172 | func (graph adjacencyGraph) Leaves() (leaves []Node) { 173 | for _, node := range graph.nodes { 174 | if node.OutDegree == 0 { 175 | leaves = append(leaves, node) 176 | } 177 | } 178 | return 179 | } 180 | 181 | // Get the node with the given ID 182 | func (graph adjacencyGraph) Node(id NodeID) Node { 183 | if int(id) >= len(graph.nodes) { 184 | panic(ErrInvalidNode{}) 185 | } 186 | return graph.nodes[id] 187 | } 188 | 189 | // Returns all parents of a given node 190 | func (graph adjacencyGraph) Parents(of NodeID) (parents []Node) { 191 | if int(of) >= len(graph.nodes) { 192 | panic(ErrInvalidNode{}) 193 | } 194 | for parent := 0; parent < len(graph.nodes); 
parent++ { 195 | if graph.edges.Item(int(parent), int(of)) != 0 { 196 | parents = append(parents, graph.nodes[parent]) 197 | } 198 | } 199 | return 200 | } 201 | 202 | // Returns all parents of a given node and their corresponding edge weights 203 | func (graph adjacencyGraph) ParentsWithWeights(of NodeID) (parents []Node, weights []float64) { 204 | if int(of) >= len(graph.nodes) { 205 | panic(ErrInvalidNode{}) 206 | } 207 | for parent := 0; parent < len(graph.nodes); parent++ { 208 | weight := graph.edges.Item(int(parent), int(of)) 209 | if weight != 0 { 210 | parents = append(parents, graph.nodes[parent]) 211 | weights = append(weights, weight) 212 | } 213 | } 214 | return 215 | } 216 | 217 | // Remove an edge from the graph 218 | func (graph *adjacencyGraph) RemoveEdge(from, to NodeID) { 219 | if int(from) >= len(graph.nodes) || int(to) >= len(graph.nodes) { 220 | panic(ErrInvalidNode{}) 221 | } else if graph.edges.Item(int(from), int(to)) == 0 { 222 | return 223 | } 224 | graph.edges.ItemSet(0.0, int(from), int(to)) 225 | graph.nodes[from].OutDegree-- 226 | graph.nodes[to].InDegree-- 227 | if !graph.directed { 228 | graph.edges.ItemSet(0.0, int(to), int(from)) 229 | graph.nodes[to].OutDegree-- 230 | graph.nodes[from].InDegree-- 231 | } 232 | } 233 | 234 | // Returns all nodes with in-degree zero 235 | func (graph adjacencyGraph) Roots() (roots []Node) { 236 | for _, node := range graph.nodes { 237 | if node.InDegree == 0 { 238 | roots = append(roots, node) 239 | } 240 | } 241 | return 242 | } 243 | 244 | // Returns the shortest path between two nodes and the edge weights along 245 | // the path 246 | func (graph adjacencyGraph) ShortestPath(from, to NodeID) (p []Node, w []float64) { 247 | if int(from) >= len(graph.nodes) || int(to) >= len(graph.nodes) { 248 | panic(ErrInvalidNode{}) 249 | } 250 | if from == to { 251 | return []Node{graph.nodes[from]}, []float64{} 252 | } 253 | var isNode = func(node Node, path []Node, weights []float64) bool { 254 | if node.ID 
== to { 255 | p = path 256 | w = weights 257 | return true 258 | } 259 | return false 260 | } 261 | graph.VisitBFS(from, isNode) 262 | return 263 | } 264 | 265 | // Returns the total weight along the shortest path between two nodes, or 266 | // 0 if there is no such path. 267 | func (graph adjacencyGraph) ShortestPathWeight(from, to NodeID) (weight float64) { 268 | if int(from) >= len(graph.nodes) || int(to) >= len(graph.nodes) { 269 | panic(ErrInvalidNode{}) 270 | } 271 | path, weights := graph.ShortestPath(from, to) 272 | if len(path) == 0 { 273 | return math.Inf(1) 274 | } 275 | weight = 0 276 | for _, w := range weights { 277 | weight += w 278 | } 279 | return 280 | } 281 | 282 | // Returns the weights along all pairwise shortest paths in the graph. 283 | // In the returned array, weights[from][to] gives the minimum path weight 284 | // from the node with ID=from to the node with ID=to. The weight will be 285 | // positive infinity if there is no such path. 286 | func (graph adjacencyGraph) ShortestPathWeights() (weights [][]float64) { 287 | return graph.floydWarshallShortestPathWeights() 288 | } 289 | 290 | // The Floyd-Warshall algorithm for all-pairs shortest path weights 291 | func (graph adjacencyGraph) floydWarshallShortestPathWeights() (weights [][]float64) { 292 | 293 | // Initialize all path weights 294 | var numNodes = len(graph.nodes) 295 | for i := 0; i < numNodes; i++ { 296 | weights = append(weights, make([]float64, numNodes)) 297 | for j := 0; j < numNodes; j++ { 298 | if i != j { 299 | weight := graph.edges.Item(i, j) 300 | if weight == 0 { 301 | weight = math.Inf(1) 302 | } 303 | weights[i][j] = weight 304 | } 305 | } 306 | } 307 | 308 | // Update based on shortest paths 309 | for k := 0; k < numNodes; k++ { 310 | for i := 0; i < numNodes; i++ { 311 | for j := 0; j < numNodes; j++ { 312 | path := weights[i][k] + weights[k][j] 313 | if path < weights[i][j] { 314 | weights[i][j] = path 315 | } 316 | } 317 | } 318 | } 319 | 320 | return 321 | } 
322 | 323 | // Returns the number of nodes and edges in the graph 324 | func (graph adjacencyGraph) Size() (nodes, edges int) { 325 | return len(graph.nodes), graph.edges.CountNonzero() 326 | } 327 | 328 | // Returns a string representation of the graph 329 | func (graph adjacencyGraph) String() string { 330 | var ss []string 331 | printed := make(map[NodeID]bool) 332 | for _, r := range graph.Roots() { 333 | ss = append(ss, graph.nodeString(r, 0, printed)) 334 | } 335 | return strings.Join(ss, "\n") 336 | } 337 | 338 | func (graph adjacencyGraph) nodeString(node Node, depth int, printed map[NodeID]bool) string { 339 | var ss []string 340 | for i := 0; i < depth; i++ { 341 | ss = append(ss, "| ") 342 | } 343 | ss = append(ss, node.Name) 344 | ss = append(ss, "\n") 345 | printed[node.ID] = true 346 | for _, child := range graph.Children(node.ID) { 347 | if printed[child.ID] { 348 | for i := 0; i < depth+1; i++ { 349 | ss = append(ss, "| ") 350 | } 351 | ss = append(ss, child.Name) 352 | ss = append(ss, " \n") 353 | } else { 354 | ss = append(ss, graph.nodeString(child, depth+1, printed)) 355 | } 356 | } 357 | return strings.Join(ss, "") 358 | } 359 | 360 | // Returns a topological sort of the graph, if possible. All nodes will 361 | // follow their ancestors in the resulting list. If there is no path between 362 | // a given pair of nodes, their ordering is chosen arbitrarily. 363 | // If the graph is not acyclic, fails with ErrGraphIsCyclic. 
364 | func (graph adjacencyGraph) TopologicalSort() (order []Node, err error) { 365 | var ( 366 | g = graph.Copy() 367 | queue = g.Roots() 368 | ) 369 | for len(queue) > 0 { 370 | node := queue[0] 371 | queue = queue[1:] 372 | order = append(order, graph.Node(node.ID)) 373 | for _, child := range g.Children(node.ID) { 374 | if child.InDegree == 1 { 375 | queue = append(queue, child) 376 | } 377 | g.RemoveEdge(node.ID, child.ID) 378 | } 379 | } 380 | if g.HasEdges() { 381 | return nil, ErrGraphIsCyclic{} 382 | } 383 | return 384 | } 385 | 386 | // Create the transitive closure of the graph. Adds edges so that all nodes 387 | // reachable from a given node have an edge between them. Any added edges 388 | // will be assigned a weight equal to the shortest path weight between the 389 | // nodes in the original graph. 390 | func (graph adjacencyGraph) TransitiveClosure() Graph { 391 | var ( 392 | numNodes = len(graph.nodes) 393 | g = graph.Copy() 394 | weights = g.ShortestPathWeights() 395 | ) 396 | for i := 0; i < numNodes; i++ { 397 | for j := 0; j < numNodes; j++ { 398 | if i != j && graph.edges.Item(i, j) == 0 && !math.IsInf(weights[i][j], 1) { 399 | g.AddEdgeWithWeight(NodeID(i), NodeID(j), weights[i][j]) 400 | } 401 | } 402 | } 403 | return g 404 | } 405 | 406 | // Create the transitive reduction of the graph. Keeps only edges necessary 407 | // to preserve all paths in the graph. The behavior of this algorithm is 408 | // undefined unless the graph is a DAG. 
409 | func (graph adjacencyGraph) TransitiveReduction() Graph { 410 | var ( 411 | numNodes = len(graph.nodes) 412 | g = graph.Copy().(*adjacencyGraph) 413 | ) 414 | weights := g.ShortestPathWeights() 415 | for i := 0; i < numNodes; i++ { 416 | for j := 0; j < numNodes; j++ { 417 | if weights[i][j] != 0 && !math.IsInf(weights[i][j], 1) { 418 | for k := 0; k < numNodes; k++ { 419 | if weights[j][k] != 0 && !math.IsInf(weights[j][k], 1) { 420 | g.RemoveEdge(NodeID(i), NodeID(k)) 421 | } 422 | } 423 | } 424 | } 425 | } 426 | return g 427 | } 428 | 429 | // Information for visiting nodes 430 | type visitInfo struct { 431 | Node Node 432 | Path []Node 433 | Weights []float64 434 | } 435 | 436 | // Generic graph visitor method 437 | func visit(graph Graph, frontier container.List, fn NodeVisitor) bool { 438 | visited := make(map[NodeID]bool) 439 | for frontier.Size() > 0 { 440 | next := frontier.Pop().(*visitInfo) 441 | visited[next.Node.ID] = true 442 | if fn(next.Node, next.Path, next.Weights) { 443 | return true 444 | } 445 | children, weights := graph.ChildrenWithWeights(next.Node.ID) 446 | for i := 0; i < len(children); i++ { 447 | if !visited[children[i].ID] { 448 | frontier.Push(&visitInfo{ 449 | Node: children[i], 450 | Path: append(next.Path, children[i]), 451 | Weights: append(next.Weights, weights[i]), 452 | }) 453 | } 454 | } 455 | } 456 | return false 457 | } 458 | 459 | // Visit all descendants using breadth-first search. Returns the result of 460 | // the final call to fn. 461 | func (graph *adjacencyGraph) VisitBFS(from NodeID, fn NodeVisitor) bool { 462 | if int(from) >= len(graph.nodes) { 463 | panic(ErrInvalidNode{}) 464 | } 465 | frontier := &container.Queue{} 466 | frontier.Push(&visitInfo{ 467 | Node: graph.nodes[from], 468 | Path: []Node{graph.nodes[from]}, 469 | Weights: []float64{}, 470 | }) 471 | return visit(graph, frontier, fn) 472 | } 473 | 474 | // Visit all descendants using depth-first search. 
Returns the result of 475 | // the final call to fn. 476 | func (graph *adjacencyGraph) VisitDFS(from NodeID, fn NodeVisitor) bool { 477 | if int(from) >= len(graph.nodes) { 478 | panic(ErrInvalidNode{}) 479 | } 480 | frontier := &container.Stack{} 481 | frontier.Push(&visitInfo{ 482 | Node: graph.nodes[from], 483 | Path: []Node{graph.nodes[from]}, 484 | Weights: []float64{}, 485 | }) 486 | return visit(graph, frontier, fn) 487 | } 488 | -------------------------------------------------------------------------------- /matrix/sparse_diag.go: -------------------------------------------------------------------------------- 1 | package matrix 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // A sparse 2D Matrix with diagonal representation: only the main diagonal is 8 | // stored; all other values are zero. 9 | type sparseDiagF64Matrix struct { 10 | shape []int 11 | diag []float64 12 | } 13 | 14 | // Return the element-wise sum of this array and one or more others 15 | func (array sparseDiagF64Matrix) Add(other ...NDArray) NDArray { 16 | return Add(&array, other...) 
17 | } 18 | 19 | // Returns true if and only if all items are nonzero 20 | func (array sparseDiagF64Matrix) All() bool { 21 | return false 22 | } 23 | 24 | // Returns true if f is true for all array elements 25 | func (array sparseDiagF64Matrix) AllF(f func(v float64) bool) bool { 26 | return AllF(&array, f) 27 | } 28 | 29 | // Returns true if f is true for all pairs of array elements in the same position 30 | func (array sparseDiagF64Matrix) AllF2(f func(v1, v2 float64) bool, other NDArray) bool { 31 | return AllF2(&array, f, other) 32 | } 33 | 34 | // Returns true if and only if any item is nonzero 35 | func (array sparseDiagF64Matrix) Any() bool { 36 | for _, v := range array.diag { 37 | if v != 0 { 38 | return true 39 | } 40 | } 41 | return false 42 | } 43 | 44 | // Returns true if f is true for any array element 45 | func (array sparseDiagF64Matrix) AnyF(f func(v float64) bool) bool { 46 | return AnyF(&array, f) 47 | } 48 | 49 | // Returns true if f is true for any pair of array elements in the same position 50 | func (array sparseDiagF64Matrix) AnyF2(f func(v1, v2 float64) bool, other NDArray) bool { 51 | return AnyF2(&array, f, other) 52 | } 53 | 54 | // Return the result of applying a function to all elements 55 | func (array sparseDiagF64Matrix) Apply(f func(float64) float64) NDArray { 56 | return Apply(&array, f) 57 | } 58 | 59 | // Get the matrix data as a flattened 1D array; sparse matrices will make 60 | // a copy first. 
61 | func (array sparseDiagF64Matrix) Array() []float64 { 62 | return array.Dense().Array() 63 | } 64 | 65 | // Set the values of the items on a given column 66 | func (array sparseDiagF64Matrix) ColSet(col int, values []float64) { 67 | if col < 0 || col >= array.shape[1] { 68 | panic(fmt.Sprintf("ColSet can't set col %d of a %d-col array", col, array.shape[1])) 69 | } else if len(values) != array.shape[0] { 70 | panic(fmt.Sprintf("ColSet has %d rows but got %d values", array.shape[0], len(values))) 71 | } 72 | for row := 0; row < array.shape[0]; row++ { 73 | if row != col { 74 | if values[row] != 0 { 75 | panic(fmt.Sprintf("ColSet can't set cell (%d, %d) of a %dx%d sparse diagonal matrix", row, col, array.shape[0], array.shape[1])) 76 | } 77 | } else { 78 | array.diag[row] = values[row] 79 | } 80 | } 81 | } 82 | 83 | // Get a particular column for read-only access. May or may not be a copy. 84 | func (array sparseDiagF64Matrix) Col(col int) []float64 { 85 | if col < 0 || col >= array.shape[1] { 86 | panic(fmt.Sprintf("Can't get column %d from a %dx%d array", col, array.shape[0], array.shape[1])) 87 | } 88 | result := make([]float64, array.shape[1]) 89 | result[col] = array.diag[col] 90 | return result 91 | } 92 | 93 | // Get the number of columns 94 | func (array sparseDiagF64Matrix) Cols() int { 95 | return array.shape[1] 96 | } 97 | 98 | // Create a new array by concatenating this with another array along the 99 | // specified axis. The array shapes must be equal along all other axes. 100 | // It is legal to add a new axis. 101 | func (array sparseDiagF64Matrix) Concat(axis int, others ...NDArray) NDArray { 102 | return Concat(axis, &array, others...) 
103 | } 104 | 105 | // Returns a duplicate of this array 106 | func (array sparseDiagF64Matrix) Copy() NDArray { 107 | return array.copy() 108 | } 109 | 110 | // Returns a duplicate of this array, preserving type 111 | func (array sparseDiagF64Matrix) copy() *sparseDiagF64Matrix { 112 | result := &sparseDiagF64Matrix{ 113 | shape: make([]int, len(array.shape)), 114 | diag: make([]float64, len(array.diag)), 115 | } 116 | copy(result.shape[:], array.shape[:]) 117 | copy(result.diag[:], array.diag[:]) 118 | return result 119 | } 120 | 121 | // Counts the number of nonzero elements in the array 122 | func (array sparseDiagF64Matrix) CountNonzero() int { 123 | count := 0 124 | for _, v := range array.diag { 125 | if v != 0 { 126 | count++ 127 | } 128 | } 129 | return count 130 | } 131 | 132 | // Returns a dense copy of the array 133 | func (array sparseDiagF64Matrix) Dense() NDArray { 134 | result := Dense(array.shape...) 135 | for pos, val := range array.diag { 136 | result.ItemSet(val, pos, pos) 137 | } 138 | return result 139 | } 140 | 141 | // Get a column vector containing the main diagonal elements of the matrix 142 | func (array sparseDiagF64Matrix) Diag() Matrix { 143 | return A([]int{len(array.diag), 1}, array.diag...).M() 144 | } 145 | 146 | // Treat the rows as points, and get the pairwise distance between them. 147 | // Returns a distance matrix D such that D_i,j is the distance between 148 | // rows i and j. 149 | func (array sparseDiagF64Matrix) Dist(t DistType) Matrix { 150 | return Dist(&array, t) 151 | } 152 | 153 | // Return the element-wise quotient of this array and one or more others. 154 | // This function defines 0 / 0 = 0, so it's useful for sparse arrays. 155 | func (array sparseDiagF64Matrix) Div(other ...NDArray) NDArray { 156 | return Div(&array, other...) 
157 | } 158 | 159 | // Returns true if and only if all elements in the two arrays are equal 160 | func (array sparseDiagF64Matrix) Equal(other NDArray) bool { 161 | return Equal(&array, other) 162 | } 163 | 164 | // Set all array elements to the given value 165 | func (array sparseDiagF64Matrix) Fill(value float64) { 166 | panic("Can't Fill() a sparse diagonal matrix") 167 | } 168 | 169 | // Get the coordinates for the item at the specified flat position 170 | func (array sparseDiagF64Matrix) FlatCoord(index int) []int { 171 | return flatToNd(array.shape, index) 172 | } 173 | 174 | // Get an array element in a flattened verison of this array 175 | func (array sparseDiagF64Matrix) FlatItem(index int) float64 { 176 | coord := flatToNd(array.shape, index) 177 | if coord[0] != coord[1] || coord[0] >= len(array.diag) { 178 | return 0 179 | } 180 | return array.diag[coord[0]] 181 | } 182 | 183 | // Set an array element in a flattened version of this array 184 | func (array sparseDiagF64Matrix) FlatItemSet(value float64, index int) { 185 | coord := flatToNd(array.shape, index) 186 | if coord[0] != coord[1] || coord[0] >= len(array.diag) { 187 | panic(fmt.Sprintf("FlatItemSet index %v invalid for sparse diagonal array shape %v", index, array.shape)) 188 | } 189 | array.diag[coord[0]] = value 190 | } 191 | 192 | // Get the matrix inverse 193 | func (array sparseDiagF64Matrix) Inverse() (Matrix, error) { 194 | return Inverse(&array) 195 | } 196 | 197 | // Get an array element 198 | func (array sparseDiagF64Matrix) Item(index ...int) float64 { 199 | if len(index) != 2 || index[0] >= array.shape[0] || index[1] >= array.shape[1] { 200 | panic(fmt.Sprintf("Item indices %v invalid for array shape %v", index, array.shape)) 201 | } else if index[0] != index[1] || index[0] >= len(array.diag) { 202 | return 0 203 | } 204 | return array.diag[index[0]] 205 | } 206 | 207 | // Add a scalar value to each array element 208 | func (array *sparseDiagF64Matrix) ItemAdd(value float64) NDArray 
{ 209 | result := &denseF64Array{ 210 | shape: make([]int, 2), 211 | array: make([]float64, array.shape[0]*array.shape[1]), 212 | } 213 | copy(result.shape[:], array.shape[:]) 214 | flat := 0 215 | for row := 0; row < array.shape[0]; row++ { 216 | for col := 0; col < array.shape[1]; col++ { 217 | if row == col { 218 | result.array[flat] = array.diag[row] + value 219 | } else { 220 | result.array[flat] = value 221 | } 222 | flat++ 223 | } 224 | } 225 | return result 226 | } 227 | 228 | // Divide each array element by a scalar value 229 | func (array *sparseDiagF64Matrix) ItemDiv(value float64) NDArray { 230 | result := array.copy() 231 | for i := 0; i < len(result.diag); i++ { 232 | result.diag[i] /= value 233 | } 234 | return result 235 | } 236 | 237 | // Multiply each array element by a scalar value 238 | func (array *sparseDiagF64Matrix) ItemProd(value float64) NDArray { 239 | result := array.copy() 240 | for i := 0; i < len(result.diag); i++ { 241 | result.diag[i] *= value 242 | } 243 | return result 244 | } 245 | 246 | // Subtract a scalar value from each array element 247 | func (array *sparseDiagF64Matrix) ItemSub(value float64) NDArray { 248 | result := &denseF64Array{ 249 | shape: make([]int, 2), 250 | array: make([]float64, array.shape[0]*array.shape[1]), 251 | } 252 | copy(result.shape[:], array.shape[:]) 253 | flat := 0 254 | for row := 0; row < array.shape[0]; row++ { 255 | for col := 0; col < array.shape[1]; col++ { 256 | if row == col { 257 | result.array[flat] = array.diag[row] - value 258 | } else { 259 | result.array[flat] = -value 260 | } 261 | flat++ 262 | } 263 | } 264 | return result 265 | } 266 | 267 | // Set an array element 268 | func (array sparseDiagF64Matrix) ItemSet(value float64, index ...int) { 269 | if len(index) != 2 || index[0] >= array.shape[0] || index[1] >= array.shape[1] { 270 | panic(fmt.Sprintf("ItemSet indices %v invalid for array shape %v", index, array.shape)) 271 | } else if index[0] != index[1] { 272 | 
panic(fmt.Sprintf("ItemSet indices %v invalid for sparse diagonal array", index)) 273 | } 274 | array.diag[index[0]] = value 275 | } 276 | 277 | // Solve for x, where ax = b. 278 | func (array sparseDiagF64Matrix) LDivide(b Matrix) Matrix { 279 | return LDivide(&array, b) 280 | } 281 | 282 | // Get the result of matrix multiplication between this and some other 283 | // array(s). All arrays must have two dimensions, and the dimensions must 284 | // be aligned correctly for multiplication. 285 | // If A is m x p and B is p x n, then C = A.MProd(B) is the m x n matrix 286 | // with C[i, j] = \sum_{k=1}^p A[i,k] * B[k,j]. 287 | func (array sparseDiagF64Matrix) MProd(others ...Matrix) Matrix { 288 | return MProd(&array, others...) 289 | } 290 | 291 | // Get the value of the largest array element 292 | func (array sparseDiagF64Matrix) Max() float64 { 293 | return Max(&array) 294 | } 295 | 296 | // Get the value of the smallest array element 297 | func (array sparseDiagF64Matrix) Min() float64 { 298 | return Min(&array) 299 | } 300 | 301 | // The number of dimensions in the matrix 302 | func (array sparseDiagF64Matrix) NDim() int { 303 | return len(array.shape) 304 | } 305 | 306 | // Get the matrix norm of the specified ordinality (1, 2, infinity, ...) 307 | func (array sparseDiagF64Matrix) Norm(ord float64) float64 { 308 | return Norm(&array, ord) 309 | } 310 | 311 | // Return a copy of the array, normalized to sum to 1 312 | func (array *sparseDiagF64Matrix) Normalize() NDArray { 313 | return Normalize(array) 314 | } 315 | 316 | // Return the element-wise product of this array and one or more others 317 | func (array sparseDiagF64Matrix) Prod(other ...NDArray) NDArray { 318 | return Prod(&array, other...) 
319 | } 320 | 321 | // Get a 1D copy of the array, in 'C' order: rightmost axes change fastest 322 | func (array sparseDiagF64Matrix) Ravel() NDArray { 323 | return Ravel(&array) 324 | } 325 | 326 | // Set the values of the items on a given row 327 | func (array sparseDiagF64Matrix) RowSet(row int, values []float64) { 328 | if row < 0 || row >= array.shape[0] { 329 | panic(fmt.Sprintf("RowSet can't set row %d of a %d-row array", row, array.shape[0])) 330 | } else if len(values) != array.shape[1] { 331 | panic(fmt.Sprintf("RowSet has %d columns but got %d values", array.shape[1], len(values))) 332 | } 333 | for col := 0; col < array.shape[1]; col++ { 334 | if row != col { 335 | if values[col] != 0 { 336 | panic(fmt.Sprintf("RowSet can't set cell (%d, %d) of a %dx%d sparse diagonal matrix", row, col, array.shape[0], array.shape[1])) 337 | } 338 | } else { 339 | array.diag[col] = values[col] 340 | } 341 | } 342 | } 343 | 344 | // Get a particular row for read-only access. May or may not be a copy. 345 | func (array sparseDiagF64Matrix) Row(row int) []float64 { 346 | if row < 0 || row >= array.shape[0] { 347 | panic(fmt.Sprintf("Can't get row %d from a %dx%d array", row, array.shape[0], array.shape[1])) 348 | } 349 | result := make([]float64, array.shape[0]) 350 | result[row] = array.diag[row] 351 | return result 352 | } 353 | 354 | // Get the number of rows 355 | func (array sparseDiagF64Matrix) Rows() int { 356 | return array.shape[0] 357 | } 358 | 359 | // A slice giving the size of all array dimensions 360 | func (array sparseDiagF64Matrix) Shape() []int { 361 | return array.shape 362 | } 363 | 364 | // The total number of elements in the matrix 365 | func (array sparseDiagF64Matrix) Size() int { 366 | return array.shape[0] * array.shape[1] 367 | } 368 | 369 | // Get an array containing a rectangular slice of this array. 370 | // `from` and `to` should both have one index per axis. 
The indices 371 | // in `from` and `to` define the first and just-past-last indices you wish 372 | // to select along each axis. 373 | func (array sparseDiagF64Matrix) Slice(from []int, to []int) NDArray { 374 | return Slice(&array, from, to) 375 | } 376 | 377 | // Return a sparse coo copy of the matrix. The method will panic 378 | // if any off-diagonal elements are nonzero. 379 | func (array sparseDiagF64Matrix) SparseCoo() Matrix { 380 | m := SparseCoo(array.shape[0], array.shape[1]) 381 | array.VisitNonzero(func(pos []int, value float64) bool { 382 | m.ItemSet(value, pos[0], pos[1]) 383 | return true 384 | }) 385 | return m 386 | } 387 | 388 | // Return a sparse diag copy of the matrix. The method will panic 389 | // if any off-diagonal elements are nonzero. 390 | func (array sparseDiagF64Matrix) SparseDiag() Matrix { 391 | return array.copy() 392 | } 393 | 394 | // Ask whether the matrix has a sparse representation (useful for optimization) 395 | func (array sparseDiagF64Matrix) Sparsity() ArraySparsity { 396 | return SparseDiagMatrix 397 | } 398 | 399 | // Return the element-wise difference of this array and one or more others 400 | func (array sparseDiagF64Matrix) Sub(other ...NDArray) NDArray { 401 | return Sub(&array, other...) 402 | } 403 | 404 | // Return the sum of all array elements 405 | func (array sparseDiagF64Matrix) Sum() float64 { 406 | return Sum(&array) 407 | } 408 | 409 | // Returns the array as a matrix. 410 | func (array sparseDiagF64Matrix) M() Matrix { 411 | return &array 412 | } 413 | 414 | // Return the same matrix, but with axes transposed. The same data is used, 415 | // for speed and memory efficiency. Use Copy() to create a new array. 416 | func (array sparseDiagF64Matrix) T() Matrix { 417 | return &sparseDiagF64Matrix{ 418 | shape: []int{array.shape[1], array.shape[0]}, 419 | diag: array.diag, 420 | } 421 | } 422 | 423 | // Visit all matrix elements, invoking a method on each. 
If the method 424 | // returns false, iteration is aborted and VisitNonzero() returns false. 425 | // Otherwise, it returns true. 426 | func (array sparseDiagF64Matrix) Visit(f func(pos []int, value float64) bool) bool { 427 | for row := 0; row < array.shape[0]; row++ { 428 | for col := 0; col < array.shape[1]; col++ { 429 | var value float64 430 | if row == col { 431 | value = array.diag[row] 432 | } else { 433 | value = 0 434 | } 435 | if !f([]int{row, col}, value) { 436 | return false 437 | } 438 | } 439 | } 440 | return true 441 | } 442 | 443 | // Visit just nonzero elements, invoking a method on each. If the method 444 | // returns false, iteration is aborted and VisitNonzero() returns false. 445 | // Otherwise, it returns true. 446 | func (array sparseDiagF64Matrix) VisitNonzero(f func(pos []int, value float64) bool) bool { 447 | for idx, value := range array.diag { 448 | if !f([]int{idx, idx}, value) { 449 | return false 450 | } 451 | } 452 | return true 453 | } 454 | -------------------------------------------------------------------------------- /matrix/sparse_coo.go: -------------------------------------------------------------------------------- 1 | package matrix 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // A sparse 2D Matrix with coordinate representation 8 | type sparseCooF64Matrix struct { 9 | shape []int 10 | values []map[int]float64 11 | transpose bool 12 | } 13 | 14 | // Return the element-wise sum of this array and one or more others 15 | func (array sparseCooF64Matrix) Add(other ...NDArray) NDArray { 16 | return Add(&array, other...) 
17 | } 18 | 19 | // Returns true if and only if all items are nonzero 20 | func (array sparseCooF64Matrix) All() bool { 21 | return All(&array) 22 | } 23 | 24 | // Returns true if f is true for all array elements 25 | func (array sparseCooF64Matrix) AllF(f func(v float64) bool) bool { 26 | return AllF(&array, f) 27 | } 28 | 29 | // Returns true if f is true for all pairs of array elements in the same position 30 | func (array sparseCooF64Matrix) AllF2(f func(v1, v2 float64) bool, other NDArray) bool { 31 | return AllF2(&array, f, other) 32 | } 33 | 34 | // Returns true if and only if any item is nonzero 35 | func (array sparseCooF64Matrix) Any() bool { 36 | return Any(&array) 37 | } 38 | 39 | // Returns true if f is true for any array element 40 | func (array sparseCooF64Matrix) AnyF(f func(v float64) bool) bool { 41 | return AnyF(&array, f) 42 | } 43 | 44 | // Returns true if f is true for any pair of array elements in the same position 45 | func (array sparseCooF64Matrix) AnyF2(f func(v1, v2 float64) bool, other NDArray) bool { 46 | return AnyF2(&array, f, other) 47 | } 48 | 49 | // Return the result of applying a function to all elements 50 | func (array sparseCooF64Matrix) Apply(f func(float64) float64) NDArray { 51 | return Apply(&array, f) 52 | } 53 | 54 | // Get the matrix data as a flattened 1D array; sparse matrices will make 55 | // a copy first. 
56 | func (array sparseCooF64Matrix) Array() []float64 { 57 | return array.Dense().Array() 58 | } 59 | 60 | // Set the values of the items on a given column 61 | func (array *sparseCooF64Matrix) ColSet(col int, values []float64) { 62 | if col < 0 || col >= array.shape[1] { 63 | panic(fmt.Sprintf("ColSet can't set col %d of a %d-col array", col, array.shape[1])) 64 | } else if len(values) != array.shape[0] { 65 | panic(fmt.Sprintf("ColSet has %d rows but got %d values", array.shape[0], len(values))) 66 | } 67 | for row := 0; row < array.shape[0]; row++ { 68 | array.ItemSet(values[row], row, col) 69 | } 70 | } 71 | 72 | // Get a particular column for read-only access. May or may not be a copy. 73 | func (array sparseCooF64Matrix) Col(col int) []float64 { 74 | if col < 0 || col >= array.shape[1] { 75 | panic(fmt.Sprintf("Can't get column %d from a %dx%d array", col, array.shape[0], array.shape[1])) 76 | } 77 | result := make([]float64, array.shape[1]) 78 | for row, val := range array.values { 79 | result[row] = val[col] 80 | } 81 | return result 82 | } 83 | 84 | // Get the number of columns 85 | func (array sparseCooF64Matrix) Cols() int { 86 | return array.shape[1] 87 | } 88 | 89 | // Create a new array by concatenating this with another array along the 90 | // specified axis. The array shapes must be equal along all other axes. 91 | // It is legal to add a new axis. 92 | func (array sparseCooF64Matrix) Concat(axis int, others ...NDArray) NDArray { 93 | return Concat(axis, &array, others...) 
94 | } 95 | 96 | // Returns a duplicate of this array 97 | func (array sparseCooF64Matrix) Copy() NDArray { 98 | return array.copy() 99 | } 100 | 101 | // Returns a duplicate of this array, preserving type 102 | func (array sparseCooF64Matrix) copy() *sparseCooF64Matrix { 103 | result := SparseCoo(array.shape[0], array.shape[1]).(*sparseCooF64Matrix) 104 | if array.transpose { 105 | for row, val := range array.values { 106 | for col, v := range val { 107 | result.values[col][row] = v 108 | } 109 | } 110 | } else { 111 | for row, val := range array.values { 112 | for col, v := range val { 113 | result.values[row][col] = v 114 | } 115 | } 116 | } 117 | return result 118 | } 119 | 120 | // Counts the number of nonzero elements in the array 121 | func (array sparseCooF64Matrix) CountNonzero() int { 122 | count := 0 123 | for _, val := range array.values { 124 | count += len(val) 125 | } 126 | return count 127 | } 128 | 129 | // Returns a dense copy of the array 130 | func (array sparseCooF64Matrix) Dense() NDArray { 131 | var result NDArray 132 | result = Dense(array.shape...) 133 | for row, val := range array.values { 134 | for col, v := range val { 135 | if array.transpose { 136 | result.ItemSet(v, col, row) 137 | } else { 138 | result.ItemSet(v, row, col) 139 | } 140 | } 141 | } 142 | return result 143 | } 144 | 145 | // Get a column vector containing the main diagonal elements of the matrix 146 | func (array sparseCooF64Matrix) Diag() Matrix { 147 | size := array.shape[0] 148 | if array.shape[1] < size { 149 | size = array.shape[1] 150 | } 151 | result := Dense(size, 1).M() 152 | for row, val := range array.values { 153 | result.ItemSet(val[row], row, 0) 154 | } 155 | return result 156 | } 157 | 158 | // Treat the rows as points, and get the pairwise distance between them. 159 | // Returns a distance matrix D such that D_i,j is the distance between 160 | // rows i and j. 
161 | func (array sparseCooF64Matrix) Dist(t DistType) Matrix { 162 | return Dist(&array, t) 163 | } 164 | 165 | // Return the element-wise quotient of this array and one or more others. 166 | // This function defines 0 / 0 = 0, so it's useful for sparse arrays. 167 | func (array sparseCooF64Matrix) Div(other ...NDArray) NDArray { 168 | return Div(&array, other...) 169 | } 170 | 171 | // Returns true if and only if all elements in the two arrays are equal 172 | func (array sparseCooF64Matrix) Equal(other NDArray) bool { 173 | return Equal(&array, other) 174 | } 175 | 176 | // Set all array elements to the given value 177 | func (array sparseCooF64Matrix) Fill(value float64) { 178 | panic("Can't Fill() a sparse coo matrix") 179 | } 180 | 181 | // Get the coordinates for the item at the specified flat position 182 | func (array sparseCooF64Matrix) FlatCoord(index int) []int { 183 | return flatToNd(array.shape, index) 184 | } 185 | 186 | // Get an array element in a flattened verison of this array 187 | func (array sparseCooF64Matrix) FlatItem(index int) float64 { 188 | // This is ok with transpose, because array.Item does transposition for us 189 | nd := flatToNd(array.shape, index) 190 | return array.Item(nd[0], nd[1]) 191 | } 192 | 193 | // Set an array element in a flattened version of this array 194 | func (array *sparseCooF64Matrix) FlatItemSet(value float64, index int) { 195 | nd := flatToNd(array.shape, index) 196 | if array.transpose { 197 | array.ItemSet(value, nd[0], nd[1]) 198 | } 199 | array.ItemSet(value, nd[0], nd[1]) 200 | } 201 | 202 | // Get the matrix inverse 203 | func (array sparseCooF64Matrix) Inverse() (Matrix, error) { 204 | return Inverse(&array) 205 | } 206 | 207 | // Get an array element 208 | func (array sparseCooF64Matrix) Item(index ...int) float64 { 209 | if len(index) != 2 || index[0] >= array.shape[0] || index[1] >= array.shape[1] { 210 | panic(fmt.Sprintf("Item indices %v invalid for array shape %v", index, array.shape)) 211 | } 212 | 
if array.transpose { 213 | index[0], index[1] = index[1], index[0] 214 | } 215 | return array.values[index[0]][index[1]] 216 | } 217 | 218 | // Add a scalar value to each array element 219 | func (array *sparseCooF64Matrix) ItemAdd(value float64) NDArray { 220 | result := WithValue(value, array.shape...) 221 | for row, val := range array.values { 222 | for col, v := range val { 223 | if array.transpose { 224 | result.ItemSet(v+value, col, row) 225 | } else { 226 | result.ItemSet(v+value, row, col) 227 | } 228 | } 229 | } 230 | return result 231 | } 232 | 233 | // Divide each array element by a scalar value 234 | func (array *sparseCooF64Matrix) ItemDiv(value float64) NDArray { 235 | result := Dense(array.shape...) 236 | for row, val := range array.values { 237 | for col, v := range val { 238 | if array.transpose { 239 | result.ItemSet(v/value, col, row) 240 | } else { 241 | result.ItemSet(v/value, row, col) 242 | } 243 | } 244 | } 245 | return result 246 | } 247 | 248 | // Multiply each array element by a scalar value 249 | func (array *sparseCooF64Matrix) ItemProd(value float64) NDArray { 250 | result := Dense(array.shape...) 251 | for row, val := range array.values { 252 | for col, v := range val { 253 | if array.transpose { 254 | result.ItemSet(v*value, col, row) 255 | } else { 256 | result.ItemSet(v*value, row, col) 257 | } 258 | } 259 | } 260 | return result 261 | } 262 | 263 | // Subtract a scalar value from each array element 264 | func (array *sparseCooF64Matrix) ItemSub(value float64) NDArray { 265 | result := WithValue(-value, array.shape...) 
266 | for row, val := range array.values { 267 | for col, v := range val { 268 | if array.transpose { 269 | result.ItemSet(v-value, col, row) 270 | } else { 271 | result.ItemSet(v-value, row, col) 272 | } 273 | } 274 | } 275 | return result 276 | } 277 | 278 | // Set an array element 279 | func (array *sparseCooF64Matrix) ItemSet(value float64, index ...int) { 280 | if len(index) != 2 || index[0] >= array.shape[0] || index[1] >= array.shape[1] { 281 | panic(fmt.Sprintf("Item indices %v invalid for array shape %v", index, array.shape)) 282 | } 283 | if array.transpose { 284 | index[0], index[1] = index[1], index[0] 285 | } 286 | if value == 0 { 287 | delete(array.values[index[0]], index[1]) 288 | } else { 289 | array.values[index[0]][index[1]] = value 290 | } 291 | } 292 | 293 | // Solve for x, where ax = b. 294 | func (array sparseCooF64Matrix) LDivide(b Matrix) Matrix { 295 | return LDivide(&array, b) 296 | } 297 | 298 | // Get the result of matrix multiplication between this and some other 299 | // array(s). All arrays must have two dimensions, and the dimensions must 300 | // be aligned correctly for multiplication. 301 | // If A is m x p and B is p x n, then C = A.MProd(B) is the m x n matrix 302 | // with C[i, j] = \sum_{k=1}^p A[i,k] * B[k,j]. 303 | func (array sparseCooF64Matrix) MProd(others ...Matrix) Matrix { 304 | return MProd(&array, others...) 305 | } 306 | 307 | // Get the value of the largest array element 308 | func (array sparseCooF64Matrix) Max() float64 { 309 | return Max(&array) 310 | } 311 | 312 | // Get the value of the smallest array element 313 | func (array sparseCooF64Matrix) Min() float64 { 314 | return Min(&array) 315 | } 316 | 317 | // The number of dimensions in the matrix 318 | func (array sparseCooF64Matrix) NDim() int { 319 | return len(array.shape) 320 | } 321 | 322 | // Get the matrix norm of the specified ordinality (1, 2, infinity, ...) 
323 | func (array sparseCooF64Matrix) Norm(ord float64) float64 { 324 | return Norm(&array, ord) 325 | } 326 | 327 | // Return a copy of the array, normalized to sum to 1 328 | func (array *sparseCooF64Matrix) Normalize() NDArray { 329 | return Normalize(array) 330 | } 331 | 332 | // Return the element-wise product of this array and one or more others 333 | func (array sparseCooF64Matrix) Prod(other ...NDArray) NDArray { 334 | return Prod(&array, other...) 335 | } 336 | 337 | // Get a 1D copy of the array, in 'C' order: rightmost axes change fastest 338 | func (array sparseCooF64Matrix) Ravel() NDArray { 339 | return Ravel(&array) 340 | } 341 | 342 | // Set the values of the items on a given row 343 | func (array *sparseCooF64Matrix) RowSet(row int, values []float64) { 344 | if row < 0 || row >= array.shape[0] { 345 | panic(fmt.Sprintf("RowSet can't set row %d of a %d-row array", row, array.shape[0])) 346 | } else if len(values) != array.shape[1] { 347 | panic(fmt.Sprintf("RowSet has %d columns but got %d values", array.shape[1], len(values))) 348 | } 349 | for col := 0; col < array.shape[1]; col++ { 350 | array.ItemSet(values[col], row, col) 351 | } 352 | } 353 | 354 | // Get a particular row for read-only access. May or may not be a copy. 
355 | func (array sparseCooF64Matrix) Row(row int) []float64 { 356 | if row < 0 || row >= array.shape[0] { 357 | panic(fmt.Sprintf("Can't get row %d from a %dx%d array", row, array.shape[0], array.shape[1])) 358 | } 359 | result := make([]float64, array.shape[0]) 360 | for col, val := range array.values[row] { 361 | result[col] = val 362 | } 363 | return result 364 | } 365 | 366 | // Get the number of rows 367 | func (array sparseCooF64Matrix) Rows() int { 368 | return array.shape[0] 369 | } 370 | 371 | // A slice giving the size of all array dimensions 372 | func (array sparseCooF64Matrix) Shape() []int { 373 | return array.shape 374 | } 375 | 376 | // The total number of elements in the matrix 377 | func (array sparseCooF64Matrix) Size() int { 378 | return array.shape[0] * array.shape[1] 379 | } 380 | 381 | // Get an array containing a rectangular slice of this array. 382 | // `from` and `to` should both have one index per axis. The indices 383 | // in `from` and `to` define the first and just-past-last indices you wish 384 | // to select along each axis. 385 | func (array sparseCooF64Matrix) Slice(from []int, to []int) NDArray { 386 | return Slice(&array, from, to) 387 | } 388 | 389 | // Return a sparse coo copy of the matrix. The method will panic 390 | // if any off-diagonal elements are nonzero. 391 | func (array sparseCooF64Matrix) SparseCoo() Matrix { 392 | return array.copy() 393 | } 394 | 395 | // Return a sparse diag copy of the matrix. The method will panic 396 | // if any off-diagonal elements are nonzero. 
397 | func (array sparseCooF64Matrix) SparseDiag() Matrix { 398 | m := SparseDiag(array.shape[0], array.shape[1]) 399 | array.VisitNonzero(func(pos []int, value float64) bool { 400 | m.ItemSet(value, pos[0], pos[1]) 401 | return true 402 | }) 403 | return m 404 | } 405 | 406 | // Ask whether the matrix has a sparse representation (useful for optimization) 407 | func (array sparseCooF64Matrix) Sparsity() ArraySparsity { 408 | return SparseCooMatrix 409 | } 410 | 411 | // Return the element-wise difference of this array and one or more others 412 | func (array sparseCooF64Matrix) Sub(other ...NDArray) NDArray { 413 | return Sub(&array, other...) 414 | } 415 | 416 | // Return the sum of all array elements 417 | func (array sparseCooF64Matrix) Sum() float64 { 418 | return Sum(&array) 419 | } 420 | 421 | // Returns the array as a matrix. This is only possible for 1D and 2D arrays; 422 | // 1D arrays of length n are converted into n x 1 vectors. 423 | func (array sparseCooF64Matrix) M() Matrix { 424 | return &array 425 | } 426 | 427 | // Return the same matrix, but with axes transposed. The same data is used, 428 | // for speed and memory efficiency. Use Copy() to create a new array. 429 | func (array sparseCooF64Matrix) T() Matrix { 430 | return &sparseCooF64Matrix{ 431 | shape: []int{array.shape[1], array.shape[0]}, 432 | values: array.values, 433 | transpose: !array.transpose, 434 | } 435 | } 436 | 437 | // Visit all matrix elements, invoking a method on each. If the method 438 | // returns false, iteration is aborted and VisitNonzero() returns false. 439 | // Otherwise, it returns true. 
440 | func (array sparseCooF64Matrix) Visit(f func(pos []int, value float64) bool) bool { 441 | for row := 0; row < array.shape[0]; row++ { 442 | for col := 0; col < array.shape[1]; col++ { 443 | if array.transpose { 444 | if !f([]int{row, col}, array.values[col][row]) { 445 | return false 446 | } 447 | } else { 448 | if !f([]int{row, col}, array.values[row][col]) { 449 | return false 450 | } 451 | } 452 | } 453 | } 454 | return true 455 | } 456 | 457 | // Visit just nonzero elements, invoking a method on each. If the method 458 | // returns false, iteration is aborted and VisitNonzero() returns false. 459 | // Otherwise, it returns true. 460 | func (array sparseCooF64Matrix) VisitNonzero(f func(pos []int, value float64) bool) bool { 461 | for row, val := range array.values { 462 | for col, v := range val { 463 | if array.transpose { 464 | if !f([]int{col, row}, v) { 465 | return false 466 | } 467 | } else { 468 | if !f([]int{row, col}, v) { 469 | return false 470 | } 471 | } 472 | } 473 | } 474 | return true 475 | } 476 | -------------------------------------------------------------------------------- /matrix/base.go: -------------------------------------------------------------------------------- 1 | package matrix 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | ) 7 | 8 | // Get the flat index for the specified indices. Negative indexing is supported: 9 | // an index of -1 refers to the final array element. 
10 | func ndToFlat(shape []int, index []int) int { 11 | if len(index) != len(shape) { 12 | panic(fmt.Sprintf("Indices %v invalid for array shape %v", index, shape)) 13 | } 14 | flat := 0 15 | for i := range shape { 16 | if index[i] >= shape[i] || index[i] < -shape[i] { 17 | panic(fmt.Sprintf("Indices %v invalid for array shape %v", index, shape)) 18 | } else if index[i] < 0 { 19 | flat += index[i] + shape[i] 20 | } else { 21 | flat += index[i] 22 | } 23 | if i < len(shape)-1 { 24 | flat *= shape[i+1] 25 | } 26 | } 27 | return flat 28 | } 29 | 30 | // Get the indices for the specified flat index. 31 | func flatToNd(shape []int, flat int) []int { 32 | size := 1 33 | for _, v := range shape { 34 | size *= v 35 | } 36 | if flat >= size || flat < -size { 37 | panic(fmt.Sprintf("Flat index %v invalid for array shape %v", flat, shape)) 38 | } 39 | if flat < 0 { 40 | flat += size 41 | } 42 | index := make([]int, len(shape)) 43 | for axis := 0; axis < len(index); axis++ { 44 | index[axis] = flat 45 | for a2 := axis + 1; a2 < len(index); a2++ { 46 | index[axis] /= shape[a2] 47 | } 48 | index[axis] %= shape[axis] 49 | } 50 | return index 51 | } 52 | 53 | // Return the element-wise sum of this array and one or more others 54 | func Add(array NDArray, others ...NDArray) NDArray { 55 | var result NDArray 56 | sp := array.Sparsity() 57 | sh := array.Shape() 58 | for _, o := range others { 59 | switch o.Sparsity() { 60 | case DenseArray: 61 | sp = DenseArray 62 | case SparseCooMatrix: 63 | if sp == SparseDiagMatrix { 64 | sp = SparseCooMatrix 65 | } 66 | } 67 | sh2 := o.Shape() 68 | if len(sh2) != len(sh) { 69 | panic(fmt.Sprintf("Can't add arrays with shapes %v and %v", sh, sh2)) 70 | } 71 | for i := range sh { 72 | if sh[i] != sh2[i] { 73 | panic(fmt.Sprintf("Can't add arrays with shapes %v and %v", sh, sh2)) 74 | } 75 | } 76 | } 77 | 78 | switch sp { 79 | case array.Sparsity(): 80 | result = array.Copy() 81 | case DenseArray: 82 | result = array.Dense() 83 | case 
SparseCooMatrix: 84 | result = SparseCoo(sh[0], sh[1]) 85 | array.VisitNonzero(func(pos []int, value float64) bool { 86 | result.ItemSet(value, pos...) 87 | return true 88 | }) 89 | } 90 | for _, o := range others { 91 | o.VisitNonzero(func(pos []int, value float64) bool { 92 | result.ItemSet(result.Item(pos...)+value, pos...) 93 | return true 94 | }) 95 | } 96 | return result 97 | } 98 | 99 | // Returns true if and only if all items are nonzero 100 | func All(array NDArray) bool { 101 | switch array.Sparsity() { 102 | case SparseDiagMatrix: 103 | return false 104 | default: 105 | return array.CountNonzero() == array.Size() 106 | } 107 | } 108 | 109 | // Returns true if f is true for all array elements 110 | func AllF(array NDArray, f func(v float64) bool) bool { 111 | counted := 0 112 | all := array.VisitNonzero(func(pos []int, value float64) bool { 113 | counted++ 114 | if !f(value) { 115 | return false 116 | } 117 | return true 118 | }) 119 | if !all || (counted < array.Size() && !f(0)) { 120 | return false 121 | } 122 | return true 123 | } 124 | 125 | // Returns true if f is true for all pairs of array elements in the same position 126 | func AllF2(array NDArray, f func(v1, v2 float64) bool, other NDArray) bool { 127 | sh1 := array.Shape() 128 | sh2 := other.Shape() 129 | if len(sh1) != len(sh2) { 130 | panic("AllF2() requires two arrays of the same shape") 131 | } 132 | for i := 0; i < len(sh1); i++ { 133 | if sh1[i] != sh2[i] { 134 | panic("AllF2() requires two arrays of the same shape") 135 | } 136 | } 137 | size := array.Size() 138 | for i := 0; i < size; i++ { 139 | if !f(array.FlatItem(i), other.FlatItem(i)) { 140 | return false 141 | } 142 | } 143 | return true 144 | } 145 | 146 | // Returns true if and only if any item is nonzero 147 | func Any(array NDArray) bool { 148 | return !array.VisitNonzero(func(pos []int, value float64) bool { 149 | return false 150 | }) 151 | } 152 | 153 | // Returns true if f is true for any array element 154 | func 
AnyF(array NDArray, f func(v float64) bool) bool { 155 | counted := 0 156 | allFalse := array.VisitNonzero(func(pos []int, value float64) bool { 157 | counted++ 158 | if f(value) { 159 | return false 160 | } 161 | return true 162 | }) 163 | if !allFalse || (counted < array.Size() && f(0)) { 164 | return true 165 | } 166 | return false 167 | } 168 | 169 | // Returns true if f is true for any pair of array elements in the same position 170 | func AnyF2(array NDArray, f func(v1, v2 float64) bool, other NDArray) bool { 171 | sh1 := array.Shape() 172 | sh2 := other.Shape() 173 | if len(sh1) != len(sh2) { 174 | panic("AnyF2() requires two arrays of the same shape") 175 | } 176 | for i := 0; i < len(sh1); i++ { 177 | if sh1[i] != sh2[i] { 178 | panic("AnyF2() requires two arrays of the same shape") 179 | } 180 | } 181 | size := array.Size() 182 | for i := 0; i < size; i++ { 183 | if f(array.FlatItem(i), other.FlatItem(i)) { 184 | return true 185 | } 186 | } 187 | return false 188 | } 189 | 190 | // Return the result of applying a function to all elements 191 | func Apply(array NDArray, f func(float64) float64) NDArray { 192 | result := array.Dense() 193 | size := result.Size() 194 | for i := 0; i < size; i++ { 195 | value := f(result.FlatItem(i)) 196 | result.FlatItemSet(value, i) 197 | } 198 | return result 199 | } 200 | 201 | // Create a new array by concatenating this with one or more others along the 202 | // specified axis. The array shapes must be equal along all other axes. 203 | // It is legal to add a new axis. 
204 | func Concat(axis int, array NDArray, others ...NDArray) NDArray { 205 | if len(others) < 1 { 206 | return array.Copy() 207 | } 208 | 209 | // Calculate the new array shape 210 | shs := make([][]int, 1+len(others)) 211 | shs[0] = array.Shape() 212 | if axis > len(shs[0]) { 213 | panic(fmt.Sprintf("Can't concat %d-d arrays along invalid axis %d", len(shs[0]), axis)) 214 | } 215 | for i := 1; i < len(shs); i++ { 216 | shs[i] = others[i-1].Shape() 217 | if len(shs[0]) != len(shs[i]) { 218 | panic(fmt.Sprintf("Can't concat arrays with %d and %d dims", len(shs[0]), len(shs[i]))) 219 | } 220 | } 221 | 222 | var shOut []int 223 | if axis == len(shs[0]) { 224 | shOut = make([]int, len(shs[0])+1) 225 | } else { 226 | shOut = make([]int, len(shs[0])) 227 | } 228 | for i := range shOut { 229 | if i != axis { 230 | for j := 1; j < len(shs); j++ { 231 | if shs[0][i] != shs[j][i] { 232 | panic(fmt.Sprintf("Can't concat arrays along axis %d with unequal size on axis %d", axis, i)) 233 | } 234 | } 235 | shOut[i] = shs[0][i] 236 | } else if i < len(shs[0]) { 237 | for j := 0; j < len(shs); j++ { 238 | shOut[i] += shs[j][i] 239 | } 240 | } else { 241 | shOut[i] = len(shs) 242 | } 243 | } 244 | result := Dense(shOut...) 245 | 246 | // Copy the arrays 247 | size := result.Size() 248 | var ( 249 | value float64 250 | src []int 251 | ) 252 | for i := 0; i < size; i++ { 253 | src = flatToNd(shOut, i) 254 | if axis < len(shs[0]) { 255 | // Not creating a new dimension 256 | for j := 0; j < len(shs); j++ { 257 | if src[axis] >= shs[j][axis] { 258 | src[axis] -= shs[j][axis] 259 | } else if j == 0 { 260 | value = array.Item(src...) 261 | break 262 | } else { 263 | value = others[j-1].Item(src...) 264 | break 265 | } 266 | } 267 | } else if src[axis] == 0 { 268 | value = array.Item(src[:axis]...) 269 | } else { 270 | value = others[src[axis]-1].Item(src[:axis]...) 
271 | } 272 | 273 | result.FlatItemSet(value, i) 274 | } 275 | 276 | return result 277 | } 278 | 279 | // Treat the rows as points, and get the pairwise distance between them. 280 | // Returns a distance matrix D such that D_i,j is the distance between 281 | // rows i and j. 282 | func Dist(m Matrix, t DistType) Matrix { 283 | var dist = Dense(m.Rows(), m.Rows()).M() 284 | for i := 1; i < m.Rows(); i++ { 285 | ri := m.Row(i) 286 | for j := 0; j <= i; j++ { 287 | rj := m.Row(j) 288 | var v float64 289 | switch t { 290 | case EuclideanDist: 291 | for idx, riv := range ri { 292 | v += math.Pow(riv-rj[idx], 2) 293 | } 294 | v = math.Sqrt(v) 295 | default: 296 | panic(fmt.Sprintf("Can't calculate distance of invalid type %v", t)) 297 | } 298 | dist.ItemSet(v, i, j) 299 | dist.ItemSet(v, j, i) 300 | } 301 | } 302 | return dist 303 | } 304 | 305 | // Return the element-wise quotient of this array and one or more others. 306 | // This function defines 0 / 0 = 0, so it's useful for sparse arrays. 307 | func Div(array NDArray, others ...NDArray) NDArray { 308 | sh := array.Shape() 309 | for _, o := range others { 310 | sh2 := o.Shape() 311 | if len(sh2) != len(sh) { 312 | panic(fmt.Sprintf("Can't divide arrays with shapes %v and %v", sh, sh2)) 313 | } 314 | for i := range sh { 315 | if sh[i] != sh2[i] { 316 | panic(fmt.Sprintf("Can't divide arrays with shapes %v and %v", sh, sh2)) 317 | } 318 | } 319 | } 320 | 321 | result := array.Copy() 322 | for _, o := range others { 323 | result.VisitNonzero(func(pos []int, value float64) bool { 324 | result.ItemSet(value/o.Item(pos...), pos...) 
325 | return true 326 | }) 327 | } 328 | return result 329 | } 330 | 331 | // Returns true if and only if all elements in the two arrays are equal 332 | func Equal(array, other NDArray) bool { 333 | sh1 := array.Shape() 334 | sh2 := other.Shape() 335 | if len(sh1) != len(sh2) { 336 | return false 337 | } 338 | for d := 0; d < len(sh1); d++ { 339 | if sh1[d] != sh2[d] { 340 | return false 341 | } 342 | } 343 | 344 | size := array.Size() 345 | for idx := 0; idx < size; idx++ { 346 | if array.FlatItem(idx) != other.FlatItem(idx) { 347 | return false 348 | } 349 | } 350 | return true 351 | } 352 | 353 | // Set all array elements to the given value 354 | func Fill(array NDArray, value float64) { 355 | if array.Sparsity() != DenseArray { 356 | panic("Can't Fill() a sparse array") 357 | } 358 | size := array.Size() 359 | for idx := 0; idx < size; idx++ { 360 | array.FlatItemSet(value, idx) 361 | } 362 | } 363 | 364 | // Add a scalar value to each array element 365 | func ItemAdd(array NDArray, value float64) NDArray { 366 | if value == 0 { 367 | return array.Copy() 368 | } 369 | result := array.Dense() 370 | size := result.Size() 371 | for idx := 0; idx < size; idx++ { 372 | result.FlatItemSet(result.FlatItem(idx)+value, idx) 373 | } 374 | return result 375 | } 376 | 377 | // Divide each array element by a scalar value 378 | func ItemDiv(array NDArray, value float64) NDArray { 379 | if value == 1 { 380 | return array.Copy() 381 | } 382 | result := array.Copy() 383 | result.VisitNonzero(func(pos []int, v float64) bool { 384 | result.ItemSet(v/value, pos...) 385 | return true 386 | }) 387 | return result 388 | } 389 | 390 | // Multiply each array element by a scalar value 391 | func ItemProd(array NDArray, value float64) NDArray { 392 | if value == 1 { 393 | return array.Copy() 394 | } 395 | result := array.Copy() 396 | result.VisitNonzero(func(pos []int, v float64) bool { 397 | result.ItemSet(v*value, pos...) 
398 | return true 399 | }) 400 | return result 401 | } 402 | 403 | // Subtract a scalar value from each array element 404 | func ItemSub(array NDArray, value float64) NDArray { 405 | if value == 0 { 406 | return array.Copy() 407 | } 408 | result := array.Dense() 409 | size := result.Size() 410 | for idx := 0; idx < size; idx++ { 411 | result.FlatItemSet(result.FlatItem(idx)-value, idx) 412 | } 413 | return result 414 | } 415 | 416 | // Get the result of matrix multiplication between this and some other 417 | // array(s). All arrays must have two dimensions, and the dimensions must 418 | // be aligned correctly for multiplication. 419 | // If A is m x p and B is p x n, then C = A.MProd(B) is the m x n matrix 420 | // with C[i, j] = \sum_{k=1}^p A[i,k] * B[k,j]. 421 | func MProd(array Matrix, others ...Matrix) Matrix { 422 | if len(others) < 1 { 423 | return array.Copy().M() 424 | } 425 | var ( 426 | left = array 427 | leftSh = array.Shape() 428 | leftSp = array.Sparsity() 429 | result Matrix 430 | ) 431 | for _, right := range others { 432 | rightSh := right.Shape() 433 | rightSp := right.Sparsity() 434 | if leftSh[1] != rightSh[0] { 435 | panic(fmt.Sprintf("Can't MProd a %dx%d to a %dx%d array; inner dimensions must match", leftSh[0], leftSh[1], rightSh[0], rightSh[1])) 436 | } 437 | 438 | if leftSp == SparseDiagMatrix { 439 | lDiag := left.Diag().Array() 440 | switch rightSp { 441 | case SparseDiagMatrix: 442 | rDiag := right.Diag().Array() 443 | resDiag := make([]float64, len(lDiag)) 444 | for idx, v := range lDiag { 445 | resDiag[idx] = v * rDiag[idx] 446 | } 447 | result = Diag(resDiag...) 
448 | case SparseCooMatrix: 449 | result = SparseCoo(leftSh[0], rightSh[1]) 450 | spRes := result.(*sparseCooF64Matrix) 451 | right.VisitNonzero(func(pos []int, value float64) bool { 452 | spRes.values[pos[0]][pos[1]] += lDiag[pos[0]] * value 453 | return true 454 | }) 455 | default: 456 | result = Dense(leftSh[0], rightSh[1]).M() 457 | resArr := result.Array() 458 | rArr := right.Array() 459 | for i := 0; i < leftSh[0]; i++ { 460 | for j := 0; j < rightSh[1]; j++ { 461 | resArr[i*rightSh[1]+j] = lDiag[i] * rArr[i*rightSh[1]+j] 462 | } 463 | } 464 | } 465 | 466 | } else if leftSp == SparseCooMatrix && rightSp == SparseCooMatrix { 467 | result = SparseCoo(leftSh[0], rightSh[1]) 468 | spRes := result.(*sparseCooF64Matrix) 469 | spRight := right.(*sparseCooF64Matrix) 470 | left.VisitNonzero(func(pos []int, value float64) bool { 471 | for j := 0; j < rightSh[1]; j++ { 472 | spRes.values[pos[0]][j] += value * spRight.values[pos[1]][j] 473 | } 474 | return true 475 | }) 476 | 477 | } else if rightSp == SparseDiagMatrix { 478 | rDiag := right.Diag().Array() 479 | if leftSp == SparseCooMatrix { 480 | resArr := make([]float64, leftSh[0]*rightSh[1]) 481 | left.VisitNonzero(func(pos []int, value float64) bool { 482 | resArr[pos[0]*rightSh[1]+pos[1]] += value * rDiag[pos[1]] 483 | return true 484 | }) 485 | result = SparseCoo(leftSh[0], rightSh[1]) 486 | for idx, v := range resArr { 487 | if v != 0 { 488 | result.FlatItemSet(v, idx) 489 | } 490 | } 491 | } else { 492 | result = Dense(leftSh[0], rightSh[1]).M() 493 | resArr := result.Array() 494 | lArr := left.Array() 495 | for i := 0; i < leftSh[0]; i++ { 496 | for j := 0; j < rightSh[1]; j++ { 497 | resArr[i*rightSh[1]+j] = lArr[i*leftSh[1]+j] * rDiag[j] 498 | } 499 | } 500 | } 501 | 502 | } else { 503 | result = Dense(leftSh[0], rightSh[1]).M() 504 | resArr := result.Array() 505 | lArr := left.Array() 506 | rArr := right.Array() 507 | for i := 0; i < leftSh[0]; i++ { 508 | for j := 0; j < rightSh[1]; j++ { 509 | value := 0.0 
510 | for k := 0; k < leftSh[1]; k++ { 511 | value += lArr[i*leftSh[1]+k] * rArr[k*rightSh[1]+j] 512 | } 513 | resArr[i*rightSh[1]+j] = value 514 | } 515 | } 516 | } 517 | 518 | left = result 519 | leftSh = result.Shape() 520 | leftSp = result.Sparsity() 521 | } 522 | return result 523 | } 524 | 525 | // Max returns the value of the largest array element. If the largest value reported by VisitNonzero is negative and fewer than Size() elements were visited, the unvisited (zero) elements win and Max returns 0; an empty array (Size() == 0) yields -Inf. 526 | func Max(array NDArray) float64 { 527 | max := math.Inf(-1) 528 | counted := 0 529 | array.VisitNonzero(func(pos []int, value float64) bool { 530 | counted++ 531 | if value > max { 532 | max = value 533 | } 534 | return true 535 | }) 536 | if max < 0 && counted < array.Size() { 537 | max = 0 538 | } 539 | return max 540 | } 541 | 542 | // Min returns the value of the smallest array element. If the smallest value reported by VisitNonzero is positive and fewer than Size() elements were visited, the unvisited (zero) elements win and Min returns 0; an empty array (Size() == 0) yields +Inf. 543 | func Min(array NDArray) float64 { 544 | min := math.Inf(+1) 545 | counted := 0 546 | array.VisitNonzero(func(pos []int, value float64) bool { 547 | counted++ 548 | if value < min { 549 | min = value 550 | } 551 | return true 552 | }) 553 | if min > 0 && counted < array.Size() { 554 | min = 0 555 | } 556 | return min 557 | } 558 | 559 | // Normalize returns the array scaled so its elements sum to 1, computed as array.ItemDiv(sum). When the sum is 0 (nothing to scale by) or already 1, it returns array.Copy() instead. 560 | func Normalize(array NDArray) NDArray { 561 | s := array.Sum() 562 | if s != 0 && s != 1 { 563 | return array.ItemDiv(s) 564 | } else { 565 | return array.Copy() 566 | } 567 | } 568 | 569 | // Prod returns the element-wise product of array and each of others, starting from array.Copy(); each pass multiplies only the entries the running result reports via VisitNonzero. Panics if any operand's shape differs from array's. 570 | func Prod(array NDArray, others ...NDArray) NDArray { 571 | sh := array.Shape() 572 | for _, o := range others { 573 | sh2 := o.Shape() 574 | if len(sh2) != len(sh) { 575 | panic(fmt.Sprintf("Can't multiply arrays with shapes %v and %v", sh, sh2)) 576 | } 577 | for i := range sh { 578 | if sh[i] != sh2[i] { 579 | panic(fmt.Sprintf("Can't multiply arrays with shapes %v and %v", sh, sh2)) 580 | } 581 | } 582 | } 583 | 584 | result := array.Copy() 585 | for _, o := range others { 586 | result.VisitNonzero(func(pos []int, value float64) bool { 587 |
result.ItemSet(value*o.Item(pos...), pos...) 588 | return true 589 | }) 590 | } 591 | return result 592 | } 593 | 594 | // Ravel returns a dense 1D copy of the array in 'C' order (rightmost axes change fastest), placing each value reported by VisitNonzero at its flattened index; entries that are not visited are left as Dense allocated them. 595 | func Ravel(array NDArray) NDArray { 596 | result := Dense(array.Size()) 597 | shape := array.Shape() 598 | array.VisitNonzero(func(pos []int, value float64) bool { 599 | result.ItemSet(value, ndToFlat(shape, pos)) 600 | return true 601 | }) 602 | return result 603 | } 604 | 605 | // Slice returns a new dense array holding a rectangular slice of array. 606 | // `from` and `to` must each have one index per axis; they give the first and 607 | // just-past-last indices to select along each axis. A negative index counts 608 | // from the end of the axis, with -1 meaning just past the last element, so a 609 | // `to` of -1 selects through the end of that axis. Panics if `from` or `to` has 610 | // the wrong length, or if a stop index comes before its start index. 611 | func Slice(array NDArray, from []int, to []int) NDArray { 612 | sh := array.Shape() 613 | if len(from) != len(sh) || len(to) != len(sh) { 614 | panic("Invalid Slice() indices: the arguments should have the same length as the array") 615 | } 616 | 617 | // Convert negative indices: v < 0 maps to v + len(axis) + 1, so -1 means len(axis) 618 | start := make([]int, len(sh)) 619 | for idx, v := range from { 620 | if v < 0 { 621 | start[idx] = v + sh[idx] + 1 622 | } else { 623 | start[idx] = v 624 | } 625 | } 626 | stop := make([]int, len(sh)) 627 | for idx, v := range to { 628 | if v < 0 { 629 | stop[idx] = v + sh[idx] + 1 630 | } else { 631 | stop[idx] = v 632 | } 633 | if stop[idx] < start[idx] { 634 | panic(fmt.Sprintf("Invalid Slice() indices: %d is before %d", to[idx], from[idx])) 635 | } 636 | } 637 | 638 | // Create an empty array 639 | shape := make([]int, len(sh)) 640 | for idx := range shape { 641 | shape[idx] = stop[idx] - start[idx] 642 | } 643 | result := Dense(shape...)
644 | 645 | // Copy the values into the new array: `index` walks the selected region with the last axis changing fastest, and each step writes flat position i of the result 646 | size := result.Size() 647 | index := make([]int, len(from)) 648 | copy(index[:], start[:]) 649 | 650 | for i := 0; i < size; i++ { 651 | result.FlatItemSet(array.Item(index...), i) 652 | for j := len(index) - 1; j >= 0; j-- { 653 | index[j]++ 654 | if index[j] == stop[j] { 655 | index[j] = start[j] 656 | } else { 657 | break 658 | } 659 | } 660 | } 661 | 662 | return result 663 | } 664 | 665 | // Sub returns array minus each of others, element-wise. The result uses array's storage type unless an operand forces a more general one: any dense operand makes it dense, and a sparse COO operand promotes a sparse diagonal array to sparse COO. Panics if any operand's shape differs from array's. 666 | func Sub(array NDArray, others ...NDArray) NDArray { 667 | var result NDArray 668 | sp := array.Sparsity() 669 | sh := array.Shape() 670 | for _, o := range others { 671 | switch o.Sparsity() { 672 | case DenseArray: 673 | sp = DenseArray 674 | case SparseCooMatrix: 675 | if sp == SparseDiagMatrix { 676 | sp = SparseCooMatrix 677 | } 678 | } 679 | sh2 := o.Shape() 680 | if len(sh2) != len(sh) { 681 | panic(fmt.Sprintf("Can't subtract arrays with shapes %v and %v", sh, sh2)) 682 | } 683 | for i := range sh { 684 | if sh[i] != sh2[i] { 685 | panic(fmt.Sprintf("Can't subtract arrays with shapes %v and %v", sh, sh2)) 686 | } 687 | } 688 | } 689 | 690 | switch sp { 691 | case array.Sparsity(): 692 | result = array.Copy() 693 | case DenseArray: 694 | result = array.Dense() 695 | case SparseCooMatrix: 696 | result = SparseCoo(sh[0], sh[1]) 697 | array.VisitNonzero(func(pos []int, value float64) bool { 698 | result.ItemSet(value, pos...) 699 | return true 700 | }) 701 | } 702 | for _, o := range others { 703 | o.VisitNonzero(func(pos []int, value float64) bool { 704 | result.ItemSet(result.Item(pos...)-value, pos...)
705 | return true 706 | }) 707 | } 708 | return result 709 | } 710 | 711 | // Return the sum of all array elements 712 | func Sum(array NDArray) float64 { 713 | var result float64 714 | array.VisitNonzero(func(pos []int, value float64) bool { 715 | result += value 716 | return true 717 | }) 718 | return result 719 | } 720 | -------------------------------------------------------------------------------- /matrix/sparse_diag_test.go: -------------------------------------------------------------------------------- 1 | package matrix 2 | 3 | import ( 4 | . "github.com/smartystreets/goconvey/convey" 5 | "math" 6 | "testing" 7 | ) 8 | 9 | func TestSparseDiagAddDivMulSub(t *testing.T) { 10 | Convey("Given two sparse diag matrices", t, func() { 11 | d1 := Diag(1, 2, 3, 4) 12 | d2 := Diag(5, 6, 7, 8) 13 | 14 | Convey("Add works", func() { 15 | a := d1.Add(d2) 16 | So(a.Shape(), ShouldResemble, []int{4, 4}) 17 | So(a.Array(), ShouldResemble, []float64{ 18 | 6, 0, 0, 0, 19 | 0, 8, 0, 0, 20 | 0, 0, 10, 0, 21 | 0, 0, 0, 12, 22 | }) 23 | }) 24 | 25 | Convey("Div works", func() { 26 | a := d2.Div(d1) 27 | So(a.Shape(), ShouldResemble, []int{4, 4}) 28 | So(a.Array(), ShouldResemble, []float64{ 29 | 5, 0, 0, 0, 30 | 0, 3, 0, 0, 31 | 0, 0, 7. 
/ 3, 0, 32 | 0, 0, 0, 2, 33 | }) 34 | }) 35 | 36 | Convey("Prod works", func() { 37 | a := d2.Prod(d1) 38 | So(a.Shape(), ShouldResemble, []int{4, 4}) 39 | So(a.Array(), ShouldResemble, []float64{ 40 | 5, 0, 0, 0, 41 | 0, 12, 0, 0, 42 | 0, 0, 21, 0, 43 | 0, 0, 0, 32, 44 | }) 45 | }) 46 | 47 | Convey("Sub works", func() { 48 | a := d2.Sub(d1) 49 | So(a.Shape(), ShouldResemble, []int{4, 4}) 50 | So(a.Array(), ShouldResemble, []float64{ 51 | 4, 0, 0, 0, 52 | 0, 4, 0, 0, 53 | 0, 0, 4, 0, 54 | 0, 0, 0, 4, 55 | }) 56 | }) 57 | }) 58 | } 59 | 60 | func TestSparseDiagAllAny(t *testing.T) { 61 | Convey("Given partial and empty arrays", t, func() { 62 | partial := Diag(1, 2, 3) 63 | empty := Diag(0, 0, 0) 64 | 65 | Convey("All() is correct", func() { 66 | So(partial.All(), ShouldBeFalse) 67 | So(empty.All(), ShouldBeFalse) 68 | }) 69 | 70 | Convey("Any() is correct", func() { 71 | So(partial.Any(), ShouldBeTrue) 72 | So(empty.Any(), ShouldBeFalse) 73 | }) 74 | }) 75 | 76 | Convey("Given pos, neg, and mixed arrays", t, func() { 77 | pos := Diag(1, 2, 3) 78 | neg := Diag(-1, -2, -3) 79 | mixed := Diag(1, -3, 3) 80 | fge := func(v float64) bool { return v >= 0 } 81 | fgt := func(v float64) bool { return v > 0 } 82 | f2ge := func(v1, v2 float64) bool { return v1 >= v2 } 83 | f2gt := func(v1, v2 float64) bool { return v1 > v2 } 84 | 85 | Convey("AllF() is correct", func() { 86 | So(pos.AllF(fge), ShouldBeTrue) 87 | So(mixed.AllF(fge), ShouldBeFalse) 88 | So(neg.AllF(fge), ShouldBeFalse) 89 | }) 90 | 91 | Convey("AnyF() is correct", func() { 92 | So(pos.AnyF(fgt), ShouldBeTrue) 93 | So(mixed.AnyF(fgt), ShouldBeTrue) 94 | So(neg.AnyF(fgt), ShouldBeFalse) 95 | }) 96 | 97 | Convey("AllF2() is correct", func() { 98 | So(pos.AllF2(f2ge, neg), ShouldBeTrue) 99 | So(mixed.AllF2(f2ge, neg), ShouldBeFalse) 100 | So(neg.AllF2(f2ge, neg), ShouldBeTrue) 101 | }) 102 | 103 | Convey("AnyF2() is correct", func() { 104 | So(pos.AnyF2(f2gt, neg), ShouldBeTrue) 105 | So(mixed.AnyF2(f2gt, neg), 
ShouldBeTrue) 106 | So(neg.AnyF2(f2gt, neg), ShouldBeFalse) 107 | }) 108 | }) 109 | } 110 | 111 | func TestSparseDiagApply(t *testing.T) { 112 | Convey("Apply works", t, func() { 113 | a := Diag(1, 2, 3) 114 | a2 := a.Apply(func(v float64) float64 { return 2 * v }) 115 | So(a2.Shape(), ShouldResemble, []int{3, 3}) 116 | So(a2.Array(), ShouldResemble, []float64{ 117 | 2, 0, 0, 118 | 0, 4, 0, 119 | 0, 0, 6, 120 | }) 121 | }) 122 | } 123 | 124 | func TestSparseDiagConversion(t *testing.T) { 125 | Convey("Given an array", t, func() { 126 | a := SparseDiag(3, 4, 1, 3, 5) 127 | 128 | Convey("Conversion to Dense works", func() { 129 | b := a.Dense() 130 | So(b.Dense().Shape(), ShouldResemble, []int{3, 4}) 131 | So(b.Dense().Array(), ShouldResemble, []float64{ 132 | 1, 0, 0, 0, 133 | 0, 3, 0, 0, 134 | 0, 0, 5, 0, 135 | }) 136 | }) 137 | 138 | Convey("Conversion of transpose to Dense works", func() { 139 | b := a.T().Dense() 140 | So(b.Dense().Shape(), ShouldResemble, []int{4, 3}) 141 | So(b.Dense().Array(), ShouldResemble, []float64{ 142 | 1, 0, 0, 143 | 0, 3, 0, 144 | 0, 0, 5, 145 | 0, 0, 0, 146 | }) 147 | }) 148 | 149 | Convey("Conversion to sparse coo works", func() { 150 | b := a.SparseCoo() 151 | So(b.Dense().Shape(), ShouldResemble, []int{3, 4}) 152 | So(b.Dense().Array(), ShouldResemble, []float64{ 153 | 1, 0, 0, 0, 154 | 0, 3, 0, 0, 155 | 0, 0, 5, 0, 156 | }) 157 | }) 158 | 159 | Convey("Conversion of transpose to sparse coo works", func() { 160 | b := a.T().SparseCoo() 161 | So(b.Dense().Shape(), ShouldResemble, []int{4, 3}) 162 | So(b.Dense().Array(), ShouldResemble, []float64{ 163 | 1, 0, 0, 164 | 0, 3, 0, 165 | 0, 0, 5, 166 | 0, 0, 0, 167 | }) 168 | }) 169 | 170 | Convey("Conversion to diag works", func() { 171 | b := a.SparseDiag() 172 | So(b.Dense().Shape(), ShouldResemble, []int{3, 4}) 173 | So(b.Dense().Array(), ShouldResemble, []float64{ 174 | 1, 0, 0, 0, 175 | 0, 3, 0, 0, 176 | 0, 0, 5, 0, 177 | }) 178 | }) 179 | 180 | Convey("Conversion of transpose to 
diag works", func() { 181 | b := a.T().SparseDiag() 182 | So(b.Dense().Shape(), ShouldResemble, []int{4, 3}) 183 | So(b.Dense().Array(), ShouldResemble, []float64{ 184 | 1, 0, 0, 185 | 0, 3, 0, 186 | 0, 0, 5, 187 | 0, 0, 0, 188 | }) 189 | }) 190 | }) 191 | } 192 | 193 | func TestSparseDiagColColSetCols(t *testing.T) { 194 | Convey("Given an array", t, func() { 195 | a := Diag(1, 2, 3) 196 | 197 | Convey("Cols is correct", func() { 198 | So(a.Cols(), ShouldEqual, 3) 199 | }) 200 | 201 | Convey("Col panics with invalid input", func() { 202 | So(func() { a.Col(-1) }, ShouldPanic) 203 | So(func() { a.Col(3) }, ShouldPanic) 204 | }) 205 | 206 | Convey("Col works", func() { 207 | So(a.Col(0), ShouldResemble, []float64{1, 0, 0}) 208 | So(a.Col(1), ShouldResemble, []float64{0, 2, 0}) 209 | So(a.Col(2), ShouldResemble, []float64{0, 0, 3}) 210 | }) 211 | 212 | Convey("ColSet panics", func() { 213 | So(func() { a.ColSet(-1, []float64{0, 1, 2}) }, ShouldPanic) 214 | So(func() { a.ColSet(3, []float64{0, 1, 2}) }, ShouldPanic) 215 | So(func() { a.ColSet(0, []float64{0, 1}) }, ShouldPanic) 216 | So(func() { a.ColSet(0, []float64{0, 1, 2}) }, ShouldPanic) 217 | So(func() { a.ColSet(0, []float64{0, 1, 2, 3}) }, ShouldPanic) 218 | }) 219 | }) 220 | } 221 | 222 | func TestSparseDiagDiag(t *testing.T) { 223 | Convey("Given an array", t, func() { 224 | a := Diag(1, 2, 3) 225 | 226 | Convey("Diag() works", func() { 227 | d := a.Diag() 228 | So(d.Shape(), ShouldResemble, []int{3, 1}) 229 | So(d.Array(), ShouldResemble, []float64{1, 2, 3}) 230 | }) 231 | 232 | Convey("Transpose Diag() works", func() { 233 | d := a.T().M().Diag() 234 | So(d.Shape(), ShouldResemble, []int{3, 1}) 235 | So(d.Array(), ShouldResemble, []float64{1, 2, 3}) 236 | }) 237 | }) 238 | } 239 | 240 | func TestSparseDiagDist(t *testing.T) { 241 | Convey("Given a SparseDiag matrix", t, func() { 242 | m := Diag(1, 2, -1).M() 243 | 244 | Convey("Invalid distance types panic", func() { 245 | So(func() { m.Dist(DistType(-1)) 
}, ShouldPanic) 246 | }) 247 | 248 | Convey("Euclidean distance works", func() { 249 | d := m.Dist(EuclideanDist) 250 | So(d.Array(), ShouldResemble, []float64{ 251 | 0, math.Sqrt(5), math.Sqrt(2), 252 | math.Sqrt(5), 0, math.Sqrt(5), 253 | math.Sqrt(2), math.Sqrt(5), 0, 254 | }) 255 | }) 256 | }) 257 | } 258 | 259 | func TestSparseDiagInverseNormLDivide(t *testing.T) { 260 | Convey("Given an invertible diagonal matrix", t, func() { 261 | m := Diag(2, 3, 5) 262 | 263 | Convey("When I take the inverse", func() { 264 | mi, err := m.Inverse() 265 | So(err, ShouldBeNil) 266 | 267 | Convey("The inverse is correct", func() { 268 | So(mi.Shape(), ShouldResemble, []int{3, 3}) 269 | So(mi.Array(), ShouldResemble, []float64{ 270 | .5, 0, 0, 271 | 0, 1. / 3, 0, 272 | 0, 0, .2, 273 | }) 274 | }) 275 | 276 | Convey("The inverse gives us back I", func() { 277 | i := m.MProd(mi) 278 | So(i.Shape(), ShouldResemble, []int{3, 3}) 279 | So(i.Array(), ShouldResemble, []float64{ 280 | 1, 0, 0, 281 | 0, 1, 0, 282 | 0, 0, 1, 283 | }) 284 | }) 285 | }) 286 | }) 287 | 288 | Convey("Given a 3x3 matrix", t, func() { 289 | m := Diag(2, 3, 5) 290 | 291 | Convey("The 1-norm is correct", func() { 292 | So(m.Norm(1), ShouldEqual, 5) 293 | }) 294 | 295 | Convey("The 2-norm is correct", func() { 296 | So(m.Norm(2), ShouldEqual, 5) 297 | }) 298 | 299 | Convey("The inf-norm is correct", func() { 300 | So(m.Norm(math.Inf(1)), ShouldEqual, 5) 301 | }) 302 | }) 303 | 304 | Convey("Given a simple division problem", t, func() { 305 | a := Diag(4, 6, 8) 306 | b := Diag(2, 2, 2) 307 | 308 | Convey("When I solve the system", func() { 309 | x := a.LDivide(b) 310 | 311 | Convey("I get the correct solution", func() { 312 | So(x.Shape(), ShouldResemble, []int{3, 3}) 313 | So(x.Array(), ShouldResemble, []float64{ 314 | .5, 0, 0, 315 | 0, 1. 
/ 3, 0, 316 | 0, 0, .25, 317 | }) 318 | }) 319 | 320 | Convey("The product ax = b is true", func() { 321 | b2 := a.MProd(x) 322 | So(b2.Shape(), ShouldResemble, []int{3, 3}) 323 | So(b2.Equal(b), ShouldBeTrue) 324 | }) 325 | }) 326 | }) 327 | } 328 | 329 | func TestSparseDiagItemMath(t *testing.T) { 330 | Convey("Given a diag array", t, func() { 331 | a := Diag(1, 2, 3) 332 | 333 | Convey("When I call ItemAdd", func() { 334 | a2 := a.ItemAdd(1) 335 | Convey("The result is correct", func() { 336 | So(a2.Array(), ShouldResemble, []float64{ 337 | 2, 1, 1, 338 | 1, 3, 1, 339 | 1, 1, 4, 340 | }) 341 | }) 342 | }) 343 | 344 | Convey("When I call ItemDiv", func() { 345 | a2 := a.ItemDiv(2) 346 | Convey("The result is correct", func() { 347 | So(a2.Array(), ShouldResemble, []float64{ 348 | .5, 0, 0, 349 | 0, 1, 0, 350 | 0, 0, 1.5, 351 | }) 352 | }) 353 | }) 354 | 355 | Convey("When I call ItemProd", func() { 356 | a2 := a.ItemProd(2) 357 | Convey("The result is correct", func() { 358 | So(a2.Array(), ShouldResemble, []float64{ 359 | 2, 0, 0, 360 | 0, 4, 0, 361 | 0, 0, 6, 362 | }) 363 | }) 364 | }) 365 | 366 | Convey("When I call ItemSub", func() { 367 | a2 := a.ItemSub(1) 368 | Convey("The result is correct", func() { 369 | So(a2.Array(), ShouldResemble, []float64{ 370 | 0, -1, -1, 371 | -1, 1, -1, 372 | -1, -1, 2, 373 | }) 374 | }) 375 | }) 376 | }) 377 | } 378 | 379 | func TestSparseDiagMaxMin(t *testing.T) { 380 | Convey("Given positive and negative diagonal arrays", t, func() { 381 | pos := Diag(1, 2, 3) 382 | neg := Diag(-1, -2, -3) 383 | 384 | Convey("Max is right", func() { 385 | So(pos.Max(), ShouldEqual, 3) 386 | So(neg.Max(), ShouldEqual, 0) 387 | }) 388 | 389 | Convey("Min is right", func() { 390 | So(pos.Min(), ShouldEqual, 0) 391 | So(neg.Min(), ShouldEqual, -3) 392 | }) 393 | }) 394 | } 395 | 396 | func TestSparseDiagRowRowSetRows(t *testing.T) { 397 | Convey("Given an array", t, func() { 398 | a := Diag(1, 2, 3) 399 | 400 | Convey("Rows is correct", func() { 
401 | So(a.Rows(), ShouldEqual, 3) 402 | }) 403 | 404 | Convey("Row panics with invalid input", func() { 405 | So(func() { a.Row(-1) }, ShouldPanic) 406 | So(func() { a.Row(3) }, ShouldPanic) 407 | }) 408 | 409 | Convey("Row works", func() { 410 | So(a.Row(0), ShouldResemble, []float64{1, 0, 0}) 411 | So(a.Row(1), ShouldResemble, []float64{0, 2, 0}) 412 | So(a.Row(2), ShouldResemble, []float64{0, 0, 3}) 413 | }) 414 | 415 | Convey("RowSet panics", func() { 416 | So(func() { a.RowSet(-1, []float64{0, 1, 2}) }, ShouldPanic) 417 | So(func() { a.RowSet(3, []float64{0, 1, 2}) }, ShouldPanic) 418 | So(func() { a.RowSet(0, []float64{0, 1}) }, ShouldPanic) 419 | So(func() { a.RowSet(0, []float64{0, 1, 2}) }, ShouldPanic) 420 | So(func() { a.RowSet(0, []float64{0, 1, 2, 3}) }, ShouldPanic) 421 | }) 422 | }) 423 | } 424 | 425 | func TestSparseDiagVisit(t *testing.T) { 426 | Convey("Given a sparse diag array", t, func() { 427 | a := SparseDiag(4, 3, 1.0, 2.0, 3.0) 428 | 429 | Convey("Visit sees all items", func() { 430 | saw := Zeros(a.Shape()...) 431 | b := Zeros(a.Shape()...) 432 | count := 0 433 | a.Visit(func(pos []int, value float64) bool { 434 | count++ 435 | b.ItemSet(value, pos...) 436 | saw.ItemSet(1, pos...) 437 | return true 438 | }) 439 | So(count, ShouldEqual, 12) 440 | So(saw.CountNonzero(), ShouldEqual, 12) 441 | So(b.Array(), ShouldResemble, []float64{ 442 | 1, 0, 0, 443 | 0, 2, 0, 444 | 0, 0, 3, 445 | 0, 0, 0, 446 | }) 447 | }) 448 | 449 | Convey("Visit stops early if f() returns false", func() { 450 | saw := Zeros(a.Shape()...) 451 | b := Zeros(a.Shape()...) 452 | count := 0 453 | a.Visit(func(pos []int, value float64) bool { 454 | count++ 455 | b.ItemSet(value, pos...) 456 | saw.ItemSet(1, pos...) 
457 | if saw.CountNonzero() >= 5 { 458 | return false 459 | } 460 | return true 461 | }) 462 | So(count, ShouldEqual, 5) 463 | So(saw.CountNonzero(), ShouldEqual, 5) 464 | So(b.Array(), ShouldResemble, []float64{ 465 | 1, 0, 0, 466 | 0, 2, 0, 467 | 0, 0, 0, 468 | 0, 0, 0, 469 | }) 470 | }) 471 | 472 | Convey("VisitNonzero sees just nonzero items", func() { 473 | saw := Zeros(a.Shape()...) 474 | b := Zeros(a.Shape()...) 475 | count := 0 476 | a.VisitNonzero(func(pos []int, value float64) bool { 477 | count++ 478 | b.ItemSet(value, pos...) 479 | saw.ItemSet(1, pos...) 480 | return true 481 | }) 482 | So(count, ShouldEqual, 3) 483 | So(saw.CountNonzero(), ShouldEqual, 3) 484 | So(b.Array(), ShouldResemble, []float64{ 485 | 1, 0, 0, 486 | 0, 2, 0, 487 | 0, 0, 3, 488 | 0, 0, 0, 489 | }) 490 | }) 491 | 492 | Convey("VisitNonzero stops early if f() returns false", func() { 493 | saw := Zeros(a.Shape()...) 494 | b := Zeros(a.Shape()...) 495 | count := 0 496 | a.VisitNonzero(func(pos []int, value float64) bool { 497 | count++ 498 | b.ItemSet(value, pos...) 499 | saw.ItemSet(1, pos...) 500 | if saw.CountNonzero() >= 2 { 501 | return false 502 | } 503 | return true 504 | }) 505 | So(count, ShouldEqual, 2) 506 | So(saw.CountNonzero(), ShouldEqual, 2) 507 | So(b.Array(), ShouldResemble, []float64{ 508 | 1, 0, 0, 509 | 0, 2, 0, 510 | 0, 0, 0, 511 | 0, 0, 0, 512 | }) 513 | }) 514 | 515 | Convey("T().Visit sees all items", func() { 516 | saw := Zeros(a.T().Shape()...) 517 | b := Zeros(a.T().Shape()...) 518 | count := 0 519 | a.T().Visit(func(pos []int, value float64) bool { 520 | count++ 521 | b.ItemSet(value, pos...) 522 | saw.ItemSet(1, pos...) 523 | return true 524 | }) 525 | So(count, ShouldEqual, 12) 526 | So(saw.CountNonzero(), ShouldEqual, 12) 527 | So(b.Array(), ShouldResemble, []float64{ 528 | 1, 0, 0, 0, 529 | 0, 2, 0, 0, 530 | 0, 0, 3, 0, 531 | }) 532 | }) 533 | 534 | Convey("T().Visit stops early if f() returns false", func() { 535 | saw := Zeros(a.T().Shape()...) 
536 | b := Zeros(a.T().Shape()...) 537 | count := 0 538 | a.T().Visit(func(pos []int, value float64) bool { 539 | count++ 540 | b.ItemSet(value, pos...) 541 | saw.ItemSet(1, pos...) 542 | if saw.CountNonzero() >= 6 { 543 | return false 544 | } 545 | return true 546 | }) 547 | So(count, ShouldEqual, 6) 548 | So(saw.CountNonzero(), ShouldEqual, 6) 549 | So(b.Array(), ShouldResemble, []float64{ 550 | 1, 0, 0, 0, 551 | 0, 2, 0, 0, 552 | 0, 0, 0, 0, 553 | }) 554 | }) 555 | 556 | Convey("T().VisitNonzero sees just nonzero items", func() { 557 | saw := Zeros(a.T().Shape()...) 558 | b := Zeros(a.T().Shape()...) 559 | count := 0 560 | a.T().VisitNonzero(func(pos []int, value float64) bool { 561 | count++ 562 | b.ItemSet(value, pos...) 563 | saw.ItemSet(1, pos...) 564 | return true 565 | }) 566 | So(count, ShouldEqual, 3) 567 | So(saw.CountNonzero(), ShouldEqual, 3) 568 | So(b.Array(), ShouldResemble, []float64{ 569 | 1, 0, 0, 0, 570 | 0, 2, 0, 0, 571 | 0, 0, 3, 0, 572 | }) 573 | }) 574 | 575 | Convey("T().VisitNonzero stops early if f() returns false", func() { 576 | saw := Zeros(a.T().Shape()...) 577 | b := Zeros(a.T().Shape()...) 578 | count := 0 579 | a.T().VisitNonzero(func(pos []int, value float64) bool { 580 | count++ 581 | b.ItemSet(value, pos...) 582 | saw.ItemSet(1, pos...) 
583 | if saw.CountNonzero() >= 2 { 584 | return false 585 | } 586 | return true 587 | }) 588 | So(count, ShouldEqual, 2) 589 | So(saw.CountNonzero(), ShouldEqual, 2) 590 | So(b.Array(), ShouldResemble, []float64{ 591 | 1, 0, 0, 0, 592 | 0, 2, 0, 0, 593 | 0, 0, 0, 0, 594 | }) 595 | }) 596 | }) 597 | } 598 | 599 | func TestSparseDiagMatrix(t *testing.T) { 600 | Convey("Given a sparse matrix with shape 5, 3", t, func() { 601 | array := SparseDiag(5, 3) 602 | 603 | Convey("All is false", func() { 604 | So(array.All(), ShouldBeFalse) 605 | }) 606 | 607 | Convey("Any is false", func() { 608 | So(array.Any(), ShouldBeFalse) 609 | }) 610 | 611 | Convey("CountNonzero is correct", func() { 612 | So(array.CountNonzero(), ShouldEqual, 0) 613 | }) 614 | 615 | Convey("Size is 15", func() { 616 | So(array.Size(), ShouldEqual, 15) 617 | }) 618 | 619 | Convey("NDim is 2", func() { 620 | So(array.NDim(), ShouldEqual, 2) 621 | }) 622 | 623 | Convey("Shape is (5, 3)", func() { 624 | So(array.Shape(), ShouldResemble, []int{5, 3}) 625 | }) 626 | 627 | Convey("Item() panics given invalid input", func() { 628 | So(func() { array.Item(0) }, ShouldPanic) 629 | So(func() { array.Item(0, 0, 0) }, ShouldPanic) 630 | So(func() { array.Item(5, 0) }, ShouldPanic) 631 | So(func() { array.Item(0, 3) }, ShouldPanic) 632 | }) 633 | 634 | Convey("ItemSet() panics given invalid input", func() { 635 | So(func() { array.ItemSet(1.0, 0) }, ShouldPanic) 636 | So(func() { array.ItemSet(1.0, 0, 0, 0) }, ShouldPanic) 637 | So(func() { array.ItemSet(1.0, 5, 0) }, ShouldPanic) 638 | So(func() { array.ItemSet(1.0, 0, 3) }, ShouldPanic) 639 | }) 640 | 641 | Convey("Item() returns the right values", func() { 642 | for i0 := 0; i0 < 5; i0++ { 643 | for i1 := 0; i1 < 3; i1++ { 644 | So(array.Item(i0, i1), ShouldEqual, 0) 645 | } 646 | } 647 | }) 648 | 649 | Convey("Sum() is zero", func() { 650 | So(array.Sum(), ShouldEqual, 0) 651 | }) 652 | 653 | Convey("When I call Normalize()", func() { 654 | norm := 
array.Normalize() 655 | Convey("Item() returns all zeros", func() { 656 | for i0 := 0; i0 < 5; i0++ { 657 | for i1 := 0; i1 < 3; i1++ { 658 | So(norm.Item(i0, i1), ShouldEqual, 0) 659 | } 660 | } 661 | }) 662 | }) 663 | 664 | Convey("When I call ItemSet", func() { 665 | array.ItemSet(1, 1, 1) 666 | array.ItemSet(2, 2, 2) 667 | 668 | Convey("Off-diagonal elements panic", func() { 669 | So(func() { array.ItemSet(2, 3, 2) }, ShouldPanic) 670 | }) 671 | 672 | Convey("All is false", func() { 673 | So(array.All(), ShouldBeFalse) 674 | }) 675 | 676 | Convey("Any is true", func() { 677 | So(array.Any(), ShouldBeTrue) 678 | }) 679 | 680 | Convey("Ravel is correct", func() { 681 | r := array.Ravel() 682 | So(r.Shape(), ShouldResemble, []int{15}) 683 | So(r.Array(), ShouldResemble, []float64{ 684 | 0, 0, 0, 685 | 0, 1, 0, 686 | 0, 0, 2, 687 | 0, 0, 0, 688 | 0, 0, 0, 689 | }) 690 | }) 691 | 692 | Convey("Slice works", func() { 693 | r := array.Slice([]int{1, 1}, []int{3, 3}) 694 | So(r.Shape(), ShouldResemble, []int{2, 2}) 695 | So(r.Array(), ShouldResemble, []float64{ 696 | 1, 0, 697 | 0, 2, 698 | }) 699 | }) 700 | 701 | Convey("CountNonzero is correct", func() { 702 | So(array.CountNonzero(), ShouldEqual, 2) 703 | }) 704 | 705 | Convey("Equal is correct", func() { 706 | So(array.Equal(SparseDiag(5, 3, 0, 1, 2)), ShouldBeTrue) 707 | }) 708 | 709 | Convey("Item() returns updates", func() { 710 | for i0 := 0; i0 < 5; i0++ { 711 | for i1 := 0; i1 < 3; i1++ { 712 | if i0 == 1 && i1 == 1 { 713 | So(array.Item(i0, i1), ShouldEqual, 1) 714 | } else if i0 == 2 && i1 == 2 { 715 | So(array.Item(i0, i1), ShouldEqual, 2) 716 | } else { 717 | So(array.Item(i0, i1), ShouldEqual, 0) 718 | } 719 | } 720 | } 721 | }) 722 | 723 | Convey("FlatItem returns the correct values", func() { 724 | for i := 0; i < array.Size(); i++ { 725 | switch i { 726 | case 1*3 + 1: 727 | So(array.FlatItem(i), ShouldEqual, 1) 728 | case 2*3 + 2: 729 | So(array.FlatItem(i), ShouldEqual, 2) 730 | default: 731 | 
So(array.FlatItem(i), ShouldEqual, 0) 732 | } 733 | } 734 | }) 735 | 736 | Convey("Sum() is correct", func() { 737 | So(array.Sum(), ShouldEqual, 3) 738 | }) 739 | 740 | Convey("When I call Normalize", func() { 741 | norm := array.Normalize() 742 | 743 | Convey("Item() is correct", func() { 744 | for i0 := 0; i0 < 5; i0++ { 745 | for i1 := 0; i1 < 3; i1++ { 746 | if i0 == 1 && i1 == 1 { 747 | So(norm.Item(i0, i1), ShouldEqual, 1.0/3) 748 | } else if i0 == 2 && i1 == 2 { 749 | So(norm.Item(i0, i1), ShouldEqual, 2.0/3) 750 | } else { 751 | So(norm.Item(i0, i1), ShouldEqual, 0) 752 | } 753 | } 754 | } 755 | }) 756 | 757 | Convey("Sum() is 1", func() { 758 | So(norm.Sum(), ShouldEqual, 1) 759 | }) 760 | }) 761 | }) 762 | 763 | Convey("When I call FlatItemSet", func() { 764 | array.FlatItemSet(1, 1*3+1) 765 | array.FlatItemSet(2, 2*3+2) 766 | 767 | Convey("Off-diagonal elements panic", func() { 768 | So(func() { array.FlatItemSet(3, 3*3+2) }, ShouldPanic) 769 | }) 770 | 771 | Convey("All is false", func() { 772 | So(array.All(), ShouldBeFalse) 773 | }) 774 | 775 | Convey("Any is true", func() { 776 | So(array.Any(), ShouldBeTrue) 777 | }) 778 | 779 | Convey("Item() returns updates", func() { 780 | for i0 := 0; i0 < 5; i0++ { 781 | for i1 := 0; i1 < 3; i1++ { 782 | if i0 == 1 && i1 == 1 { 783 | So(array.Item(i0, i1), ShouldEqual, 1) 784 | } else if i0 == 2 && i1 == 2 { 785 | So(array.Item(i0, i1), ShouldEqual, 2) 786 | } else { 787 | So(array.Item(i0, i1), ShouldEqual, 0) 788 | } 789 | } 790 | } 791 | }) 792 | 793 | Convey("FlatItem returns the correct values", func() { 794 | for i := 0; i < array.Size(); i++ { 795 | switch i { 796 | case 1*3 + 1: 797 | So(array.FlatItem(i), ShouldEqual, 1) 798 | case 2*3 + 2: 799 | So(array.FlatItem(i), ShouldEqual, 2) 800 | default: 801 | So(array.FlatItem(i), ShouldEqual, 0) 802 | } 803 | } 804 | }) 805 | }) 806 | 807 | Convey("Fill panics", func() { 808 | So(func() { array.Fill(3) }, ShouldPanic) 809 | }) 810 | }) 811 | 812 | 
Convey("Given two 2D sparse matrices of equal length", t, func() { 813 | a1 := SparseDiag(4, 3, 1.0, 1.0, 1.0) 814 | a2 := SparseDiag(4, 3, 2.0, 2.0, 2.0) 815 | 816 | Convey("Concat works along axis 0", func() { 817 | a3 := a1.Concat(0, a2) 818 | So(a3.Shape(), ShouldResemble, []int{8, 3}) 819 | for i0 := 0; i0 < 8; i0++ { 820 | for i1 := 0; i1 < 3; i1++ { 821 | if i0 == i1 { 822 | So(a3.Item(i0, i1), ShouldEqual, 1) 823 | } else if i0-4 == i1 { 824 | So(a3.Item(i0, i1), ShouldEqual, 2) 825 | } else { 826 | So(a3.Item(i0, i1), ShouldEqual, 0) 827 | } 828 | } 829 | } 830 | }) 831 | 832 | Convey("Concat works along axis 1", func() { 833 | a3 := a1.Concat(1, a2) 834 | So(a3.Shape(), ShouldResemble, []int{4, 6}) 835 | for i0 := 0; i0 < 4; i0++ { 836 | for i1 := 0; i1 < 6; i1++ { 837 | if i0 == i1 && i0 < 3 { 838 | So(a3.Item(i0, i1), ShouldEqual, 1) 839 | } else if i0 == i1-3 { 840 | So(a3.Item(i0, i1), ShouldEqual, 2) 841 | } else { 842 | So(a3.Item(i0, i1), ShouldEqual, 0) 843 | } 844 | } 845 | } 846 | }) 847 | 848 | Convey("Concat works along axis 2", func() { 849 | a3 := a1.Concat(2, a2) 850 | So(a3.Shape(), ShouldResemble, []int{4, 3, 2}) 851 | for i0 := 0; i0 < 4; i0++ { 852 | for i1 := 0; i1 < 3; i1++ { 853 | for i2 := 0; i2 < 2; i2++ { 854 | if i0 == i1 { 855 | if i2 == 0 { 856 | So(a3.Item(i0, i1, i2), ShouldEqual, 1) 857 | } else { 858 | So(a3.Item(i0, i1, i2), ShouldEqual, 2) 859 | } 860 | } else { 861 | So(a3.Item(i0, i1, i2), ShouldEqual, 0) 862 | } 863 | } 864 | } 865 | } 866 | }) 867 | 868 | Convey("Concat panics along axis 3", func() { 869 | So(func() { a1.Concat(3, a2) }, ShouldPanic) 870 | }) 871 | }) 872 | 873 | Convey("Given a 2x3 and 3x4 array", t, func() { 874 | left := Rand(2, 3).M() 875 | right := Rand(3, 4).M() 876 | Convey("MProd() works", func() { 877 | result := left.MProd(right) 878 | So(result.Shape(), ShouldResemble, []int{2, 4}) 879 | for i0 := 0; i0 < 2; i0++ { 880 | for i1 := 0; i1 < 4; i1++ { 881 | c := left.Item(i0, 0)*right.Item(0, 
i1) + 882 | left.Item(i0, 1)*right.Item(1, i1) + 883 | left.Item(i0, 2)*right.Item(2, i1) 884 | So(result.Item(i0, i1), ShouldEqual, c) 885 | } 886 | } 887 | }) 888 | }) 889 | } 890 | -------------------------------------------------------------------------------- /matrix/sparse_coo_test.go: -------------------------------------------------------------------------------- 1 | package matrix 2 | 3 | import ( 4 | . "github.com/smartystreets/goconvey/convey" 5 | "math" 6 | "testing" 7 | ) 8 | 9 | func cooDiag(values ...float64) Matrix { 10 | m := SparseCoo(len(values), len(values)) 11 | for i, v := range values { 12 | m.ItemSet(v, i, i) 13 | } 14 | return m 15 | } 16 | 17 | func TestSparseCooAddDivMulSub(t *testing.T) { 18 | Convey("Given two sparse diag matrices", t, func() { 19 | d1 := cooDiag(1, 2, 3, 4) 20 | d2 := cooDiag(5, 6, 7, 8) 21 | 22 | Convey("Add works", func() { 23 | a := d1.Add(d2) 24 | So(a.Shape(), ShouldResemble, []int{4, 4}) 25 | So(a.Array(), ShouldResemble, []float64{ 26 | 6, 0, 0, 0, 27 | 0, 8, 0, 0, 28 | 0, 0, 10, 0, 29 | 0, 0, 0, 12, 30 | }) 31 | }) 32 | 33 | Convey("Div works", func() { 34 | a := d2.Div(d1) 35 | So(a.Shape(), ShouldResemble, []int{4, 4}) 36 | So(a.Array(), ShouldResemble, []float64{ 37 | 5, 0, 0, 0, 38 | 0, 3, 0, 0, 39 | 0, 0, 7. 
/ 3, 0, 40 | 0, 0, 0, 2, 41 | }) 42 | }) 43 | 44 | Convey("Prod works", func() { 45 | a := d2.Prod(d1) 46 | So(a.Shape(), ShouldResemble, []int{4, 4}) 47 | So(a.Array(), ShouldResemble, []float64{ 48 | 5, 0, 0, 0, 49 | 0, 12, 0, 0, 50 | 0, 0, 21, 0, 51 | 0, 0, 0, 32, 52 | }) 53 | }) 54 | 55 | Convey("Sub works", func() { 56 | a := d2.Sub(d1) 57 | So(a.Shape(), ShouldResemble, []int{4, 4}) 58 | So(a.Array(), ShouldResemble, []float64{ 59 | 4, 0, 0, 0, 60 | 0, 4, 0, 0, 61 | 0, 0, 4, 0, 62 | 0, 0, 0, 4, 63 | }) 64 | }) 65 | }) 66 | } 67 | 68 | func TestSparseCooAllAny(t *testing.T) { 69 | Convey("Given partial and empty arrays", t, func() { 70 | partial := cooDiag(1, 2, 3) 71 | empty := cooDiag(0, 0, 0) 72 | 73 | Convey("All() is correct", func() { 74 | So(partial.All(), ShouldBeFalse) 75 | So(empty.All(), ShouldBeFalse) 76 | }) 77 | 78 | Convey("Any() is correct", func() { 79 | So(partial.Any(), ShouldBeTrue) 80 | So(empty.Any(), ShouldBeFalse) 81 | }) 82 | }) 83 | 84 | Convey("Given pos, neg, and mixed arrays", t, func() { 85 | pos := cooDiag(1, 2, 3) 86 | neg := cooDiag(-1, -2, -3) 87 | mixed := cooDiag(1, -3, 3) 88 | fge := func(v float64) bool { return v >= 0 } 89 | fgt := func(v float64) bool { return v > 0 } 90 | f2ge := func(v1, v2 float64) bool { return v1 >= v2 } 91 | f2gt := func(v1, v2 float64) bool { return v1 > v2 } 92 | 93 | Convey("AllF() is correct", func() { 94 | So(pos.AllF(fge), ShouldBeTrue) 95 | So(mixed.AllF(fge), ShouldBeFalse) 96 | So(neg.AllF(fge), ShouldBeFalse) 97 | }) 98 | 99 | Convey("AnyF() is correct", func() { 100 | So(pos.AnyF(fgt), ShouldBeTrue) 101 | So(mixed.AnyF(fgt), ShouldBeTrue) 102 | So(neg.AnyF(fgt), ShouldBeFalse) 103 | }) 104 | 105 | Convey("AllF2() is correct", func() { 106 | So(pos.AllF2(f2ge, neg), ShouldBeTrue) 107 | So(mixed.AllF2(f2ge, neg), ShouldBeFalse) 108 | So(neg.AllF2(f2ge, neg), ShouldBeTrue) 109 | }) 110 | 111 | Convey("AnyF2() is correct", func() { 112 | So(pos.AnyF2(f2gt, neg), ShouldBeTrue) 113 | 
So(mixed.AnyF2(f2gt, neg), ShouldBeTrue) 114 | So(neg.AnyF2(f2gt, neg), ShouldBeFalse) 115 | }) 116 | }) 117 | } 118 | 119 | func TestSparseCooApply(t *testing.T) { 120 | Convey("Apply works", t, func() { 121 | a := cooDiag(1, 2, 3) 122 | a2 := a.Apply(func(v float64) float64 { return 2 * v }) 123 | So(a2.Shape(), ShouldResemble, []int{3, 3}) 124 | So(a2.Array(), ShouldResemble, []float64{ 125 | 2, 0, 0, 126 | 0, 4, 0, 127 | 0, 0, 6, 128 | }) 129 | }) 130 | } 131 | 132 | func TestSparseCooColColSetCols(t *testing.T) { 133 | Convey("Given an array", t, func() { 134 | a := cooDiag(1, 2, 3) 135 | 136 | Convey("Cols is correct", func() { 137 | So(a.Cols(), ShouldEqual, 3) 138 | }) 139 | 140 | Convey("Col panics with invalid input", func() { 141 | So(func() { a.Col(-1) }, ShouldPanic) 142 | So(func() { a.Col(3) }, ShouldPanic) 143 | }) 144 | 145 | Convey("Col works", func() { 146 | So(a.Col(0), ShouldResemble, []float64{1, 0, 0}) 147 | So(a.Col(1), ShouldResemble, []float64{0, 2, 0}) 148 | So(a.Col(2), ShouldResemble, []float64{0, 0, 3}) 149 | }) 150 | 151 | Convey("ColSet panics with invalid input", func() { 152 | So(func() { a.ColSet(-1, []float64{0, 1, 2}) }, ShouldPanic) 153 | So(func() { a.ColSet(3, []float64{0, 1, 2}) }, ShouldPanic) 154 | So(func() { a.ColSet(0, []float64{0, 1}) }, ShouldPanic) 155 | So(func() { a.ColSet(0, []float64{0, 1, 2, 3}) }, ShouldPanic) 156 | }) 157 | 158 | Convey("ColSet works", func() { 159 | a.ColSet(0, []float64{0, 1, 2}) 160 | So(a.Array(), ShouldResemble, []float64{ 161 | 0, 0, 0, 162 | 1, 2, 0, 163 | 2, 0, 3, 164 | }) 165 | a.ColSet(1, []float64{0, 1, 2}) 166 | So(a.Array(), ShouldResemble, []float64{ 167 | 0, 0, 0, 168 | 1, 1, 0, 169 | 2, 2, 3, 170 | }) 171 | a.ColSet(2, []float64{0, 1, 2}) 172 | So(a.Array(), ShouldResemble, []float64{ 173 | 0, 0, 0, 174 | 1, 1, 1, 175 | 2, 2, 2, 176 | }) 177 | }) 178 | }) 179 | } 180 | 181 | func TestSparseCooConversion(t *testing.T) { 182 | Convey("Given an array", t, func() { 183 | a := 
SparseCoo(3, 4, 184 | 1, 2, 0, 0, 185 | 0, 3, 4, 0, 186 | 0, 0, 5, 6) 187 | 188 | Convey("Conversion to Dense works", func() { 189 | b := a.Dense() 190 | So(b.Dense().Shape(), ShouldResemble, []int{3, 4}) 191 | So(b.Dense().Array(), ShouldResemble, []float64{ 192 | 1, 2, 0, 0, 193 | 0, 3, 4, 0, 194 | 0, 0, 5, 6, 195 | }) 196 | }) 197 | 198 | Convey("Conversion of transpose to Dense works", func() { 199 | b := a.T().Dense() 200 | So(b.Dense().Shape(), ShouldResemble, []int{4, 3}) 201 | So(b.Dense().Array(), ShouldResemble, []float64{ 202 | 1, 0, 0, 203 | 2, 3, 0, 204 | 0, 4, 5, 205 | 0, 0, 6, 206 | }) 207 | }) 208 | 209 | Convey("Conversion to sparse coo works", func() { 210 | b := a.SparseCoo() 211 | So(b.Dense().Shape(), ShouldResemble, []int{3, 4}) 212 | So(b.Dense().Array(), ShouldResemble, []float64{ 213 | 1, 2, 0, 0, 214 | 0, 3, 4, 0, 215 | 0, 0, 5, 6, 216 | }) 217 | }) 218 | 219 | Convey("Conversion of transpose to sparse coo works", func() { 220 | b := a.T().SparseCoo() 221 | So(b.Dense().Shape(), ShouldResemble, []int{4, 3}) 222 | So(b.Dense().Array(), ShouldResemble, []float64{ 223 | 1, 0, 0, 224 | 2, 3, 0, 225 | 0, 4, 5, 226 | 0, 0, 6, 227 | }) 228 | }) 229 | 230 | Convey("Conversion to diag panics if matrix is not diagonal", func() { 231 | So(func() { a.SparseDiag() }, ShouldPanic) 232 | }) 233 | 234 | Convey("Conversion to diag works", func() { 235 | a.ItemSet(0, 0, 1) 236 | a.ItemSet(0, 1, 2) 237 | a.ItemSet(0, 2, 3) 238 | b := a.SparseDiag() 239 | So(b.Dense().Shape(), ShouldResemble, []int{3, 4}) 240 | So(b.Dense().Array(), ShouldResemble, []float64{ 241 | 1, 0, 0, 0, 242 | 0, 3, 0, 0, 243 | 0, 0, 5, 0, 244 | }) 245 | }) 246 | 247 | Convey("Conversion of transpose to diag works", func() { 248 | a.ItemSet(0, 0, 1) 249 | a.ItemSet(0, 1, 2) 250 | a.ItemSet(0, 2, 3) 251 | b := a.T().SparseDiag() 252 | So(b.Dense().Shape(), ShouldResemble, []int{4, 3}) 253 | So(b.Dense().Array(), ShouldResemble, []float64{ 254 | 1, 0, 0, 255 | 0, 3, 0, 256 | 0, 0, 5, 257 
| 0, 0, 0, 258 | }) 259 | }) 260 | }) 261 | } 262 | 263 | func TestSparseCooDiag(t *testing.T) { 264 | Convey("Given an array", t, func() { 265 | a := SparseCoo(3, 4) 266 | a.ItemSet(1.0, 0, 0) 267 | a.ItemSet(2.0, 1, 1) 268 | a.ItemSet(3.0, 2, 2) 269 | 270 | Convey("Diag() works", func() { 271 | d := a.Diag() 272 | So(d.Shape(), ShouldResemble, []int{3, 1}) 273 | So(d.Array(), ShouldResemble, []float64{1, 2, 3}) 274 | }) 275 | 276 | Convey("Transpose Diag() works", func() { 277 | d := a.T().M().Diag() 278 | So(d.Shape(), ShouldResemble, []int{3, 1}) 279 | So(d.Array(), ShouldResemble, []float64{1, 2, 3}) 280 | }) 281 | }) 282 | } 283 | 284 | func TestSparseCooDist(t *testing.T) { 285 | Convey("Given a SparseCoo matrix", t, func() { 286 | m := SparseCoo(3, 2).M() 287 | m.ItemSet(1, 0, 0) 288 | m.ItemSet(2, 0, 1) 289 | m.ItemSet(3, 1, 0) 290 | m.ItemSet(2, 1, 1) 291 | m.ItemSet(-1, 2, 0) 292 | m.ItemSet(4, 2, 1) 293 | 294 | Convey("Invalid distance types panic", func() { 295 | So(func() { m.Dist(DistType(-1)) }, ShouldPanic) 296 | }) 297 | 298 | Convey("Euclidean distance works", func() { 299 | d := m.Dist(EuclideanDist) 300 | So(d.Array(), ShouldResemble, []float64{ 301 | 0, 2, math.Sqrt(8), 302 | 2, 0, math.Sqrt(20), 303 | math.Sqrt(8), math.Sqrt(20), 0, 304 | }) 305 | }) 306 | }) 307 | } 308 | 309 | func TestSparseCooInverseNormLDivide(t *testing.T) { 310 | Convey("Given an invertible diagonal matrix", t, func() { 311 | m := cooDiag(2, 3, 5) 312 | 313 | Convey("When I take the inverse", func() { 314 | mi, err := m.Inverse() 315 | So(err, ShouldBeNil) 316 | 317 | Convey("The inverse is correct", func() { 318 | So(mi.Shape(), ShouldResemble, []int{3, 3}) 319 | So(mi.Array(), ShouldResemble, []float64{ 320 | .5, 0, 0, 321 | 0, 1. 
/ 3, 0, 322 | 0, 0, .2, 323 | }) 324 | }) 325 | 326 | Convey("The inverse gives us back I", func() { 327 | i := m.MProd(mi) 328 | So(i.Shape(), ShouldResemble, []int{3, 3}) 329 | So(i.Array(), ShouldResemble, []float64{ 330 | 1, 0, 0, 331 | 0, 1, 0, 332 | 0, 0, 1, 333 | }) 334 | }) 335 | }) 336 | }) 337 | 338 | Convey("Given a 3x3 matrix", t, func() { 339 | m := cooDiag(2, 3, 5) 340 | 341 | Convey("The 1-norm is correct", func() { 342 | So(m.Norm(1), ShouldEqual, 5) 343 | }) 344 | 345 | Convey("The 2-norm is correct", func() { 346 | So(m.Norm(2), ShouldEqual, 5) 347 | }) 348 | 349 | Convey("The inf-norm is correct", func() { 350 | So(m.Norm(math.Inf(1)), ShouldEqual, 5) 351 | }) 352 | }) 353 | 354 | Convey("Given a simple division problem", t, func() { 355 | a := cooDiag(4, 6, 8) 356 | b := cooDiag(2, 2, 2) 357 | 358 | Convey("When I solve the system", func() { 359 | x := a.LDivide(b) 360 | 361 | Convey("I get the correct solution", func() { 362 | So(x.Shape(), ShouldResemble, []int{3, 3}) 363 | So(x.Array(), ShouldResemble, []float64{ 364 | .5, 0, 0, 365 | 0, 1. 
/ 3, 0, 366 | 0, 0, .25, 367 | }) 368 | }) 369 | 370 | Convey("The product ax = b is true", func() { 371 | b2 := a.MProd(x) 372 | So(b2.Shape(), ShouldResemble, []int{3, 3}) 373 | So(b2.Equal(b), ShouldBeTrue) 374 | }) 375 | }) 376 | }) 377 | } 378 | 379 | func TestSparseCooItemMath(t *testing.T) { 380 | Convey("Given a diag array", t, func() { 381 | a := cooDiag(1, 2, 3) 382 | 383 | Convey("When I call ItemAdd", func() { 384 | a2 := a.ItemAdd(1) 385 | Convey("The result is correct", func() { 386 | So(a2.Array(), ShouldResemble, []float64{ 387 | 2, 1, 1, 388 | 1, 3, 1, 389 | 1, 1, 4, 390 | }) 391 | }) 392 | }) 393 | 394 | Convey("When I call ItemDiv", func() { 395 | a2 := a.ItemDiv(2) 396 | Convey("The result is correct", func() { 397 | So(a2.Array(), ShouldResemble, []float64{ 398 | .5, 0, 0, 399 | 0, 1, 0, 400 | 0, 0, 1.5, 401 | }) 402 | }) 403 | }) 404 | 405 | Convey("When I call ItemProd", func() { 406 | a2 := a.ItemProd(2) 407 | Convey("The result is correct", func() { 408 | So(a2.Array(), ShouldResemble, []float64{ 409 | 2, 0, 0, 410 | 0, 4, 0, 411 | 0, 0, 6, 412 | }) 413 | }) 414 | }) 415 | 416 | Convey("When I call ItemSub", func() { 417 | a2 := a.ItemSub(1) 418 | Convey("The result is correct", func() { 419 | So(a2.Array(), ShouldResemble, []float64{ 420 | 0, -1, -1, 421 | -1, 1, -1, 422 | -1, -1, 2, 423 | }) 424 | }) 425 | }) 426 | }) 427 | } 428 | 429 | func TestSparseCooMaxMin(t *testing.T) { 430 | Convey("Given positive and negative diagonal arrays", t, func() { 431 | pos := cooDiag(1, 2, 3) 432 | neg := cooDiag(-1, -2, -3) 433 | 434 | Convey("Max is right", func() { 435 | So(pos.Max(), ShouldEqual, 3) 436 | So(neg.Max(), ShouldEqual, 0) 437 | }) 438 | 439 | Convey("Min is right", func() { 440 | So(pos.Min(), ShouldEqual, 0) 441 | So(neg.Min(), ShouldEqual, -3) 442 | }) 443 | }) 444 | } 445 | 446 | func TestSparseCooRowRowSetRows(t *testing.T) { 447 | Convey("Given an array", t, func() { 448 | a := cooDiag(1, 2, 3) 449 | 450 | Convey("Rows is correct", 
func() { 451 | So(a.Rows(), ShouldEqual, 3) 452 | }) 453 | 454 | Convey("Row panics with invalid input", func() { 455 | So(func() { a.Row(-1) }, ShouldPanic) 456 | So(func() { a.Row(3) }, ShouldPanic) 457 | }) 458 | 459 | Convey("Row works", func() { 460 | So(a.Row(0), ShouldResemble, []float64{1, 0, 0}) 461 | So(a.Row(1), ShouldResemble, []float64{0, 2, 0}) 462 | So(a.Row(2), ShouldResemble, []float64{0, 0, 3}) 463 | }) 464 | 465 | Convey("RowSet panics with invalid input", func() { 466 | So(func() { a.RowSet(-1, []float64{0, 1, 2}) }, ShouldPanic) 467 | So(func() { a.RowSet(3, []float64{0, 1, 2}) }, ShouldPanic) 468 | So(func() { a.RowSet(0, []float64{0, 1}) }, ShouldPanic) 469 | So(func() { a.RowSet(0, []float64{0, 1, 2, 3}) }, ShouldPanic) 470 | }) 471 | 472 | Convey("RowSet works", func() { 473 | a.RowSet(0, []float64{0, 1, 2}) 474 | So(a.Array(), ShouldResemble, []float64{ 475 | 0, 1, 2, 476 | 0, 2, 0, 477 | 0, 0, 3, 478 | }) 479 | a.RowSet(1, []float64{0, 1, 2}) 480 | So(a.Array(), ShouldResemble, []float64{ 481 | 0, 1, 2, 482 | 0, 1, 2, 483 | 0, 0, 3, 484 | }) 485 | a.RowSet(2, []float64{0, 1, 2}) 486 | So(a.Array(), ShouldResemble, []float64{ 487 | 0, 1, 2, 488 | 0, 1, 2, 489 | 0, 1, 2, 490 | }) 491 | }) 492 | }) 493 | } 494 | 495 | func TestSparseCooTranspose(t *testing.T) { 496 | Convey("Given a transposed matrix", t, func() { 497 | a := SparseCoo(4, 3) 498 | a.ItemSet(1.0, 0, 1) 499 | a.ItemSet(2.0, 2, 0) 500 | a.ItemSet(3.0, 3, 2) 501 | tr := a.T() 502 | 503 | Convey("The matrix statistics are correct", func() { 504 | So(tr.Size(), ShouldEqual, 12) 505 | So(tr.Shape(), ShouldResemble, []int{3, 4}) 506 | }) 507 | 508 | Convey("FlatItem and FlatItemSet work", func() { 509 | tr.FlatItemSet(-1, 8) 510 | So(tr.Array(), ShouldResemble, []float64{ 511 | 0, 0, 2, 0, 512 | 1, 0, 0, 0, 513 | -1, 0, 0, 3, 514 | }) 515 | So(tr.Item(1, 0), ShouldEqual, 1) 516 | So(tr.Item(2, 0), ShouldEqual, -1) 517 | So(tr.FlatItem(4), ShouldEqual, 1) 518 | So(tr.FlatItem(8), 
ShouldEqual, -1) 519 | }) 520 | 521 | Convey("Item and ItemSet work", func() { 522 | tr.ItemSet(-1, 2, 2) 523 | So(tr.Array(), ShouldResemble, []float64{ 524 | 0, 0, 2, 0, 525 | 1, 0, 0, 0, 526 | 0, 0, -1, 3, 527 | }) 528 | So(tr.Item(2, 2), ShouldEqual, -1) 529 | So(tr.Item(2, 3), ShouldEqual, 3) 530 | So(tr.FlatItem(10), ShouldEqual, -1) 531 | So(tr.FlatItem(11), ShouldEqual, 3) 532 | }) 533 | 534 | Convey("ItemAdd works", func() { 535 | res := tr.ItemAdd(1) 536 | So(res.Array(), ShouldResemble, []float64{ 537 | 1, 1, 3, 1, 538 | 2, 1, 1, 1, 539 | 1, 1, 1, 4, 540 | }) 541 | }) 542 | 543 | Convey("ItemDiv works", func() { 544 | res := tr.ItemDiv(2) 545 | So(res.Array(), ShouldResemble, []float64{ 546 | 0, 0, 1, 0, 547 | .5, 0, 0, 0, 548 | 0, 0, 0, 1.5, 549 | }) 550 | }) 551 | 552 | Convey("ItemProd works", func() { 553 | res := tr.ItemProd(2) 554 | So(res.Array(), ShouldResemble, []float64{ 555 | 0, 0, 4, 0, 556 | 2, 0, 0, 0, 557 | 0, 0, 0, 6, 558 | }) 559 | }) 560 | 561 | Convey("ItemSub works", func() { 562 | res := tr.ItemSub(1) 563 | So(res.Array(), ShouldResemble, []float64{ 564 | -1, -1, 1, -1, 565 | 0, -1, -1, -1, 566 | -1, -1, -1, 2, 567 | }) 568 | }) 569 | 570 | Convey("Array() is correct", func() { 571 | So(a.Array(), ShouldResemble, []float64{ 572 | 0, 1, 0, 573 | 0, 0, 0, 574 | 2, 0, 0, 575 | 0, 0, 3, 576 | }) 577 | So(tr.Array(), ShouldResemble, []float64{ 578 | 0, 0, 2, 0, 579 | 1, 0, 0, 0, 580 | 0, 0, 0, 3, 581 | }) 582 | }) 583 | 584 | Convey("Copy() is correct", func() { 585 | So(tr.Copy().Array(), ShouldResemble, []float64{ 586 | 0, 0, 2, 0, 587 | 1, 0, 0, 0, 588 | 0, 0, 0, 3, 589 | }) 590 | }) 591 | 592 | Convey("Ravel() is correct", func() { 593 | So(tr.Ravel().Array(), ShouldResemble, []float64{ 594 | 0, 0, 2, 0, 595 | 1, 0, 0, 0, 596 | 0, 0, 0, 3, 597 | }) 598 | }) 599 | }) 600 | } 601 | 602 | func TestSparseCooVisit(t *testing.T) { 603 | Convey("Given a sparse coo array", t, func() { 604 | a := SparseCoo(4, 3) 605 | for row := 0; row < 4; 
row++ { 606 | for col := 0; col < 3; col++ { 607 | if row != col { 608 | a.ItemSet(float64(row*3+col+1), row, col) 609 | } 610 | } 611 | } 612 | 613 | Convey("Visit sees all items", func() { 614 | saw := Zeros(a.Shape()...) 615 | b := Zeros(a.Shape()...) 616 | count := 0 617 | a.Visit(func(pos []int, value float64) bool { 618 | count++ 619 | b.ItemSet(value, pos...) 620 | saw.ItemSet(1, pos...) 621 | return true 622 | }) 623 | So(count, ShouldEqual, 12) 624 | So(saw.CountNonzero(), ShouldEqual, 12) 625 | So(b.Array(), ShouldResemble, []float64{ 626 | 0, 2, 3, 627 | 4, 0, 6, 628 | 7, 8, 0, 629 | 10, 11, 12, 630 | }) 631 | }) 632 | 633 | Convey("Visit stops early if f() returns false", func() { 634 | saw := Zeros(a.Shape()...) 635 | b := Zeros(a.Shape()...) 636 | count := 0 637 | a.Visit(func(pos []int, value float64) bool { 638 | count++ 639 | b.ItemSet(value, pos...) 640 | saw.ItemSet(1, pos...) 641 | if saw.CountNonzero() >= 5 { 642 | return false 643 | } 644 | return true 645 | }) 646 | So(count, ShouldEqual, 5) 647 | So(saw.CountNonzero(), ShouldEqual, 5) 648 | b.VisitNonzero(func(pos []int, value float64) bool { 649 | So(value, ShouldEqual, a.Item(pos...)) 650 | return true 651 | }) 652 | }) 653 | 654 | Convey("VisitNonzero sees just nonzero items", func() { 655 | saw := Zeros(a.Shape()...) 656 | b := Zeros(a.Shape()...) 657 | count := 0 658 | a.VisitNonzero(func(pos []int, value float64) bool { 659 | count++ 660 | b.ItemSet(value, pos...) 661 | saw.ItemSet(1, pos...) 662 | return true 663 | }) 664 | So(count, ShouldEqual, 9) 665 | So(saw.CountNonzero(), ShouldEqual, 9) 666 | So(b.Array(), ShouldResemble, []float64{ 667 | 0, 2, 3, 668 | 4, 0, 6, 669 | 7, 8, 0, 670 | 10, 11, 12, 671 | }) 672 | }) 673 | 674 | Convey("VisitNonzero stops early if f() returns false", func() { 675 | saw := Zeros(a.Shape()...) 676 | b := Zeros(a.Shape()...) 677 | count := 0 678 | a.VisitNonzero(func(pos []int, value float64) bool { 679 | count++ 680 | b.ItemSet(value, pos...) 
681 | saw.ItemSet(1, pos...) 682 | if saw.CountNonzero() >= 5 { 683 | return false 684 | } 685 | return true 686 | }) 687 | So(count, ShouldEqual, 5) 688 | So(saw.CountNonzero(), ShouldEqual, 5) 689 | b.VisitNonzero(func(pos []int, value float64) bool { 690 | So(value, ShouldEqual, a.Item(pos...)) 691 | return true 692 | }) 693 | }) 694 | 695 | Convey("T().Visit sees all items", func() { 696 | saw := Zeros(a.T().Shape()...) 697 | b := Zeros(a.T().Shape()...) 698 | count := 0 699 | a.T().Visit(func(pos []int, value float64) bool { 700 | count++ 701 | b.ItemSet(value, pos...) 702 | saw.ItemSet(1, pos...) 703 | return true 704 | }) 705 | So(count, ShouldEqual, 12) 706 | So(saw.CountNonzero(), ShouldEqual, 12) 707 | So(b.Array(), ShouldResemble, []float64{ 708 | 0, 4, 7, 10, 709 | 2, 0, 8, 11, 710 | 3, 6, 0, 12, 711 | }) 712 | }) 713 | 714 | Convey("T().Visit stops early if f() returns false", func() { 715 | saw := Zeros(a.T().Shape()...) 716 | b := Zeros(a.T().Shape()...) 717 | count := 0 718 | a.T().Visit(func(pos []int, value float64) bool { 719 | count++ 720 | b.ItemSet(value, pos...) 721 | saw.ItemSet(1, pos...) 722 | if saw.CountNonzero() >= 5 { 723 | return false 724 | } 725 | return true 726 | }) 727 | So(count, ShouldEqual, 5) 728 | So(saw.CountNonzero(), ShouldEqual, 5) 729 | b.VisitNonzero(func(pos []int, value float64) bool { 730 | So(value, ShouldEqual, a.T().Item(pos...)) 731 | return true 732 | }) 733 | }) 734 | 735 | Convey("T().VisitNonzero sees just nonzero items", func() { 736 | saw := Zeros(a.T().Shape()...) 737 | b := Zeros(a.T().Shape()...) 738 | count := 0 739 | a.T().VisitNonzero(func(pos []int, value float64) bool { 740 | count++ 741 | b.ItemSet(value, pos...) 742 | saw.ItemSet(1, pos...) 
743 | return true 744 | }) 745 | // So(count, ShouldEqual, 9) 746 | // So(saw.CountNonzero(), ShouldEqual, 9) 747 | So(b.Array(), ShouldResemble, []float64{ 748 | 0, 4, 7, 10, 749 | 2, 0, 8, 11, 750 | 3, 6, 0, 12, 751 | }) 752 | }) 753 | 754 | Convey("T().VisitNonzero stops early if f() returns false", func() { 755 | saw := Zeros(a.T().Shape()...) 756 | b := Zeros(a.T().Shape()...) 757 | count := 0 758 | a.T().VisitNonzero(func(pos []int, value float64) bool { 759 | count++ 760 | b.ItemSet(value, pos...) 761 | saw.ItemSet(1, pos...) 762 | if saw.CountNonzero() >= 5 { 763 | return false 764 | } 765 | return true 766 | }) 767 | So(count, ShouldEqual, 5) 768 | So(saw.CountNonzero(), ShouldEqual, 5) 769 | b.VisitNonzero(func(pos []int, value float64) bool { 770 | So(value, ShouldEqual, a.T().Item(pos...)) 771 | return true 772 | }) 773 | }) 774 | }) 775 | } 776 | 777 | func TestSparseCooMatrix(t *testing.T) { 778 | Convey("Given a sparse matrix with shape 5, 3", t, func() { 779 | array := SparseCoo(5, 3) 780 | 781 | Convey("All is false", func() { 782 | So(array.All(), ShouldBeFalse) 783 | }) 784 | 785 | Convey("Any is false", func() { 786 | So(array.Any(), ShouldBeFalse) 787 | }) 788 | 789 | Convey("Size is 15", func() { 790 | So(array.Size(), ShouldEqual, 15) 791 | }) 792 | 793 | Convey("NDim is 2", func() { 794 | So(array.NDim(), ShouldEqual, 2) 795 | }) 796 | 797 | Convey("Shape is (5, 3)", func() { 798 | So(array.Shape(), ShouldResemble, []int{5, 3}) 799 | }) 800 | 801 | Convey("Item() panics with invalid indices", func() { 802 | So(func() { array.Item(5, 0) }, ShouldPanic) 803 | So(func() { array.Item(0, 3) }, ShouldPanic) 804 | So(func() { array.Item(0) }, ShouldPanic) 805 | So(func() { array.Item(0, 0, 0) }, ShouldPanic) 806 | }) 807 | 808 | Convey("ItemSet() panics with invalid indices", func() { 809 | So(func() { array.ItemSet(1.0, 5, 0) }, ShouldPanic) 810 | So(func() { array.ItemSet(1.0, 0, 3) }, ShouldPanic) 811 | So(func() { array.ItemSet(1.0, 0) }, 
ShouldPanic) 812 | So(func() { array.ItemSet(1.0, 0, 0, 0) }, ShouldPanic) 813 | }) 814 | 815 | Convey("Item() returns all zeros", func() { 816 | for i0 := 0; i0 < 5; i0++ { 817 | for i1 := 0; i1 < 3; i1++ { 818 | So(array.Item(i0, i1), ShouldEqual, 0) 819 | } 820 | } 821 | }) 822 | 823 | Convey("Sum() is zero", func() { 824 | So(array.Sum(), ShouldEqual, 0) 825 | }) 826 | 827 | Convey("When I call Normalize()", func() { 828 | norm := array.Normalize() 829 | Convey("Item() returns all zeros", func() { 830 | for i0 := 0; i0 < 5; i0++ { 831 | for i1 := 0; i1 < 3; i1++ { 832 | So(norm.Item(i0, i1), ShouldEqual, 0) 833 | } 834 | } 835 | }) 836 | }) 837 | 838 | Convey("When I call ItemSet", func() { 839 | array.ItemSet(1, 1, 0) 840 | array.ItemSet(2, 3, 2) 841 | 842 | Convey("All is false", func() { 843 | So(array.All(), ShouldBeFalse) 844 | }) 845 | 846 | Convey("Any is true", func() { 847 | So(array.Any(), ShouldBeTrue) 848 | }) 849 | 850 | Convey("Equal is correct", func() { 851 | So(array.Equal(Zeros(5, 3)), ShouldBeFalse) 852 | So(array.Equal(array.Copy()), ShouldBeTrue) 853 | }) 854 | 855 | Convey("Item() returns updates", func() { 856 | for i0 := 0; i0 < 5; i0++ { 857 | for i1 := 0; i1 < 3; i1++ { 858 | if i0 == 1 && i1 == 0 { 859 | So(array.Item(i0, i1), ShouldEqual, 1) 860 | } else if i0 == 3 && i1 == 2 { 861 | So(array.Item(i0, i1), ShouldEqual, 2) 862 | } else { 863 | So(array.Item(i0, i1), ShouldEqual, 0) 864 | } 865 | } 866 | } 867 | }) 868 | 869 | Convey("FlatItem returns the correct values", func() { 870 | for i := 0; i < array.Size(); i++ { 871 | switch i { 872 | case 1*3 + 0: 873 | So(array.FlatItem(i), ShouldEqual, 1) 874 | case 3*3 + 2: 875 | So(array.FlatItem(i), ShouldEqual, 2) 876 | default: 877 | So(array.FlatItem(i), ShouldEqual, 0) 878 | } 879 | } 880 | }) 881 | 882 | Convey("Sum() is correct", func() { 883 | So(array.Sum(), ShouldEqual, 3) 884 | }) 885 | 886 | Convey("When I call Normalize", func() { 887 | norm := array.Normalize() 888 | 889 | 
Convey("Item() is correct", func() { 890 | for i0 := 0; i0 < 5; i0++ { 891 | for i1 := 0; i1 < 3; i1++ { 892 | if i0 == 1 && i1 == 0 { 893 | So(norm.Item(i0, i1), ShouldEqual, 1.0/3) 894 | } else if i0 == 3 && i1 == 2 { 895 | So(norm.Item(i0, i1), ShouldEqual, 2.0/3) 896 | } else { 897 | So(norm.Item(i0, i1), ShouldEqual, 0) 898 | } 899 | } 900 | } 901 | }) 902 | 903 | Convey("Sum() is 1", func() { 904 | So(norm.Sum(), ShouldEqual, 1) 905 | }) 906 | }) 907 | }) 908 | 909 | Convey("When I call FlatItemSet", func() { 910 | array.FlatItemSet(1, 1*3+0) 911 | array.FlatItemSet(2, 3*3+2) 912 | 913 | Convey("All is false", func() { 914 | So(array.All(), ShouldBeFalse) 915 | }) 916 | 917 | Convey("Any is true", func() { 918 | So(array.Any(), ShouldBeTrue) 919 | }) 920 | 921 | Convey("Item() returns updates", func() { 922 | for i0 := 0; i0 < 5; i0++ { 923 | for i1 := 0; i1 < 3; i1++ { 924 | if i0 == 1 && i1 == 0 { 925 | So(array.Item(i0, i1), ShouldEqual, 1) 926 | } else if i0 == 3 && i1 == 2 { 927 | So(array.Item(i0, i1), ShouldEqual, 2) 928 | } else { 929 | So(array.Item(i0, i1), ShouldEqual, 0) 930 | } 931 | } 932 | } 933 | }) 934 | 935 | Convey("FlatItem returns the correct values", func() { 936 | for i := 0; i < array.Size(); i++ { 937 | switch i { 938 | case 1*3 + 0: 939 | So(array.FlatItem(i), ShouldEqual, 1) 940 | case 3*3 + 2: 941 | So(array.FlatItem(i), ShouldEqual, 2) 942 | default: 943 | So(array.FlatItem(i), ShouldEqual, 0) 944 | } 945 | } 946 | }) 947 | }) 948 | 949 | Convey("Fill panics", func() { 950 | So(func() { array.Fill(3) }, ShouldPanic) 951 | }) 952 | }) 953 | 954 | Convey("Given a random array with shape 5,3", t, func() { 955 | array := SparseRand(5, 3, 0.2) 956 | 957 | Convey("When I call T", func() { 958 | tr := array.T() 959 | 960 | Convey("The transpose shape is 3x5", func() { 961 | So(tr.Shape(), ShouldResemble, []int{3, 5}) 962 | }) 963 | 964 | Convey("The transpose size is 15", func() { 965 | So(tr.Size(), ShouldEqual, 15) 966 | }) 967 | 968 | 
Convey("FlatItem works", func() { 969 | next := 0 970 | for i0 := 0; i0 < 3; i0++ { 971 | for i1 := 0; i1 < 5; i1++ { 972 | So(tr.FlatItem(next), ShouldEqual, array.Item(i1, i0)) 973 | next++ 974 | } 975 | } 976 | }) 977 | 978 | Convey("Item works", func() { 979 | for i0 := 0; i0 < 3; i0++ { 980 | for i1 := 0; i1 < 5; i1++ { 981 | So(tr.Item(i0, i1), ShouldEqual, array.Item(i1, i0)) 982 | } 983 | } 984 | }) 985 | }) 986 | 987 | Convey("A slice with no indices panics", func() { 988 | So(func() { array.Slice([]int{}, []int{}) }, ShouldPanic) 989 | }) 990 | 991 | Convey("A slice with too few indices panics", func() { 992 | So(func() { array.Slice([]int{1}, []int{2}) }, ShouldPanic) 993 | }) 994 | 995 | Convey("A slice with too many indices panics", func() { 996 | So(func() { array.Slice([]int{1, 2, 3}, []int{4, 5, 6}) }, ShouldPanic) 997 | }) 998 | 999 | Convey("The slice from [2:0, 0:1] panics", func() { 1000 | So(func() { array.Slice([]int{2, 0}, []int{0, 1}) }, ShouldPanic) 1001 | }) 1002 | 1003 | Convey("The slice from [-1:-3, 0:1] panics", func() { 1004 | So(func() { array.Slice([]int{-1, 0}, []int{-3, 1}) }, ShouldPanic) 1005 | }) 1006 | 1007 | Convey("The slice from [0:2, 1:3] is correct", func() { 1008 | slice := array.Slice([]int{0, 1}, []int{2, 3}) 1009 | So(slice.Shape(), ShouldResemble, []int{2, 2}) 1010 | next := 0 1011 | for i0 := 0; i0 < 2; i0++ { 1012 | for i1 := 1; i1 < 3; i1++ { 1013 | So(slice.FlatItem(next), ShouldEqual, array.Item(i0, i1)) 1014 | next++ 1015 | } 1016 | } 1017 | }) 1018 | 1019 | Convey("The slice from [3:5,0:2] is correct", func() { 1020 | slice := array.Slice([]int{3, 0}, []int{5, 2}) 1021 | So(slice.Shape(), ShouldResemble, []int{2, 2}) 1022 | next := 0 1023 | for i0 := 3; i0 < 5; i0++ { 1024 | for i1 := 0; i1 < 2; i1++ { 1025 | So(slice.FlatItem(next), ShouldEqual, array.Item(i0, i1)) 1026 | next++ 1027 | } 1028 | } 1029 | }) 1030 | 1031 | Convey("The slice from [3:-1,0:-2] is correct", func() { 1032 | slice := 
array.Slice([]int{3, 0}, []int{-1, -2}) 1033 | So(slice.Shape(), ShouldResemble, []int{2, 2}) 1034 | next := 0 1035 | for i0 := 3; i0 < 5; i0++ { 1036 | for i1 := 0; i1 < 2; i1++ { 1037 | So(slice.FlatItem(next), ShouldEqual, array.Item(i0, i1)) 1038 | next++ 1039 | } 1040 | } 1041 | }) 1042 | 1043 | Convey("The slice from [-3:-1,-3:-2] is correct", func() { 1044 | slice := array.Slice([]int{-3, -3}, []int{-1, -2}) 1045 | So(slice.Shape(), ShouldResemble, []int{2, 1}) 1046 | next := 0 1047 | for i0 := 3; i0 < 5; i0++ { 1048 | for i1 := 1; i1 < 2; i1++ { 1049 | So(slice.FlatItem(next), ShouldEqual, array.Item(i0, i1)) 1050 | next++ 1051 | } 1052 | } 1053 | }) 1054 | }) 1055 | 1056 | Convey("Given two 2D sparse matrices of equal length", t, func() { 1057 | a1 := SparseCoo(4, 3) 1058 | a2 := SparseCoo(4, 3) 1059 | for row := 0; row < 4; row++ { 1060 | for col := 0; col < 3; col++ { 1061 | a1.ItemSet(1, row, col) 1062 | a2.ItemSet(2, row, col) 1063 | } 1064 | } 1065 | 1066 | Convey("Concat works along axis 0", func() { 1067 | a3 := a1.Concat(0, a2) 1068 | So(a3.Shape(), ShouldResemble, []int{8, 3}) 1069 | for i0 := 0; i0 < 8; i0++ { 1070 | for i1 := 0; i1 < 3; i1++ { 1071 | if i0 < 4 { 1072 | So(a3.Item(i0, i1), ShouldEqual, 1) 1073 | } else { 1074 | So(a3.Item(i0, i1), ShouldEqual, 2) 1075 | } 1076 | } 1077 | } 1078 | }) 1079 | 1080 | Convey("Concat works along axis 1", func() { 1081 | a3 := a1.Concat(1, a2) 1082 | So(a3.Shape(), ShouldResemble, []int{4, 6}) 1083 | for i0 := 0; i0 < 4; i0++ { 1084 | for i1 := 0; i1 < 6; i1++ { 1085 | if i1 < 3 { 1086 | So(a3.Item(i0, i1), ShouldEqual, 1) 1087 | } else { 1088 | So(a3.Item(i0, i1), ShouldEqual, 2) 1089 | } 1090 | } 1091 | } 1092 | }) 1093 | 1094 | Convey("Concat works along axis 2", func() { 1095 | a3 := a1.Concat(2, a2) 1096 | So(a3.Shape(), ShouldResemble, []int{4, 3, 2}) 1097 | for i0 := 0; i0 < 4; i0++ { 1098 | for i1 := 0; i1 < 3; i1++ { 1099 | for i2 := 0; i2 < 2; i2++ { 1100 | if i2 == 0 { 1101 | So(a3.Item(i0, 
i1, i2), ShouldEqual, 1) 1102 | } else { 1103 | So(a3.Item(i0, i1, i2), ShouldEqual, 2) 1104 | } 1105 | } 1106 | } 1107 | } 1108 | }) 1109 | 1110 | Convey("Concat panics along axis 3", func() { 1111 | So(func() { a1.Concat(3, a2) }, ShouldPanic) 1112 | }) 1113 | }) 1114 | 1115 | Convey("Given a 2x3 and 3x4 array", t, func() { 1116 | left := SparseRand(2, 3, 0.2).M() 1117 | right := SparseRand(3, 4, 0.2).M() 1118 | Convey("MProd() works", func() { 1119 | result := left.MProd(right) 1120 | So(result.Shape(), ShouldResemble, []int{2, 4}) 1121 | for i0 := 0; i0 < 2; i0++ { 1122 | for i1 := 0; i1 < 4; i1++ { 1123 | c := left.Item(i0, 0)*right.Item(0, i1) + 1124 | left.Item(i0, 1)*right.Item(1, i1) + 1125 | left.Item(i0, 2)*right.Item(2, i1) 1126 | So(result.Item(i0, i1), ShouldEqual, c) 1127 | } 1128 | } 1129 | }) 1130 | }) 1131 | } 1132 | --------------------------------------------------------------------------------