├── .circleci └── config.yml ├── LICENSE.txt ├── README.md ├── errors.go ├── gann.go ├── go.mod ├── go.sum ├── index.go ├── index_test.go ├── item.go ├── metric ├── cosine.go ├── cosine_test.go └── metric.go ├── node.go ├── node_test.go ├── queue.go ├── queue_test.go ├── search.go └── search_test.go /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | jobs: 4 | build: 5 | docker: 6 | - image: circleci/golang:1.11.7 7 | steps: 8 | - checkout 9 | - restore_cache: 10 | keys: 11 | - go-mod-v1-{{ checksum "go.sum" }} 12 | - run: go mod download 13 | - run: go test -v -race ./... 14 | - run: go build . 15 | - save_cache: 16 | key: go-mod-v1-{{ checksum "go.sum" }} 17 | paths: 18 | - /go/pkg/mod/cache -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 @mathetake 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gann 2 | [![CircleCI](https://circleci.com/gh/mathetake/gann.svg?style=shield&circle-token=9a6608c5baa7a400661a700127778a9ff8baeee3)](https://circleci.com/gh/mathetake/gann) 3 | [![MIT License](http://img.shields.io/badge/license-MIT-blue.svg?style=flat)](LICENSE) 4 | [![](https://godoc.org/github.com/mathetake/gann?status.svg)](http://godoc.org/github.com/mathetake/gann) 5 | 6 | portfolio_view 7 | 8 | gann (go-approximate-nearest-neighbor) is a library for approximate nearest neighbor search purely written in golang. 9 | 10 | The implemented algorithm is truly inspired by Annoy (https://github.com/spotify/annoy). 11 | 12 | ## feature 13 | 1. purely written in Go: no dependencies out of Go world. 14 | 2. 
easy to tune with only a few parameters 15 | 16 | ## installation 17 | 18 | ``` 19 | go get github.com/mathetake/gann 20 | ``` 21 | 22 | ## parameters 23 | 24 | ### setup phase parameters 25 | 26 | |name|type|description|run-time complexity|space complexity|accuracy| 27 | |:---:|:---:|:---:|:---:|:---:|:---:| 28 | |dim|int| dimension of target vectors| the larger, the more expensive | the larger, the more expensive | N/A | 29 | |nTree|int| # of trees|the larger, the more expensive| the larger, the more expensive | the larger, the more accurate| 30 | |k|int|maximum # of items in a single leaf|the larger, the less expensive| N/A| the larger, the less accurate| 31 | 32 | ### runtime (search phase) parameters 33 | 34 | |name|type|description|time complexity|accuracy| 35 | |:---:|:---:|:---:|:---:|:---:| 36 | |searchNum|int| # of requested neighbors|the larger, the more expensive|N/A| 37 | |bucketScale|float64| affects the size of `bucket` |the larger, the more expensive|the larger, the more accurate| 38 | 39 | `bucketScale` affects the size of the `bucket`, which holds the candidate items whose exact distances are calculated. 40 | The actual size of the bucket is [calculated by](https://github.com/mathetake/gann/blob/357c3abd241bd6455e895a5b392251b06507a8e8/search.go#L30) `int(searchNum * bucketScale)`. 41 | 42 | In the search phase, we traverse the index trees and keep putting the items of each reached leaf into the bucket [until the bucket becomes full](https://github.com/mathetake/gann/blob/357c3abd241bd6455e895a5b392251b06507a8e8/search.go#L48). 43 | Then we [calculate the exact distance between each item in the bucket and the query vector](https://github.com/mathetake/gann/blob/357c3abd241bd6455e895a5b392251b06507a8e8/search.go#L74-L81) to get the approximate nearest neighbors. 44 | 45 | Therefore, the larger `bucketScale` is, the higher the computational cost of a search, but the more accurate its result. 
// Sentinel errors returned by the gann package.
var (
	// errDimensionMismatch is returned by CreateNewIndex when one of the given
	// vectors does not have the requested dimension.
	errDimensionMismatch = errors.New("dimension mismatch")

	// errInvalidIndex reports an index that is malformed and cannot be used,
	// e.g. a queued node id that is missing from the node table during a search.
	errInvalidIndex = errors.New("invalid index")

	// errInvalidKeyVector is returned by GetANNbyVector when the length of the
	// query vector differs from the dimension of the index.
	errInvalidKeyVector = errors.New("invalid key vector")

	// errItemNotFoundOnGivenItemID is returned by GetANNbyItemID when no item
	// with the given id exists in the index.
	errItemNotFoundOnGivenItemID = errors.New("item not found for given item id")

	// errNotEnoughItems is returned by CreateNewIndex when fewer than two
	// vectors are given.
	errNotEnoughItems = errors.New("not enough items to build the tree")
)
-------------------------------------------------------------------------------- /gann.go: -------------------------------------------------------------------------------- 1 | // Package gann can be used for approximate nearest neighbor search. 2 | // 3 | // By calling gann.CreateNewIndex function, we can obtain a search index. 4 | // Its interface is defined in gann.Index: 5 | // 6 | // type Index interface { 7 | // GetANNbyItemID(id int64, searchNum int, bucketScale float64) (ann []int64, err error) 8 | // GetANNbyVector(v []float64, searchNum int, bucketScale float64) (ann []int64, err error) 9 | // } 10 | // 11 | // GetANNbyItemID allows us to pass id of specific item for search execution 12 | // and instead GetANNbyVector allows us to pass a vector. 13 | // 14 | // See README.md for more details. 15 | package gann 16 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mathetake/gann 2 | 3 | require ( 4 | github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 5 | github.com/google/uuid v1.1.1 6 | github.com/kr/pretty v0.1.0 // indirect 7 | github.com/pkg/errors v0.8.1 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= 2 | github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= 3 | github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= 4 | github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 5 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 6 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 7 | 
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 8 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 9 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 10 | github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= 11 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 12 | -------------------------------------------------------------------------------- /index.go: -------------------------------------------------------------------------------- 1 | package gann 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/google/uuid" 7 | "github.com/mathetake/gann/metric" 8 | ) 9 | 10 | // Index is the interface of gann's search index. GetANNbyItemID and GetANNbyVector are different in the form of query. 11 | // GetANNbyItemID can be executed by passing a certain item's id contained in the list of items used in the index building phase. 12 | // GetANNbyVector allows us to pass any vector of proper dimension. 13 | // 14 | // searchNum is the number of requested approximated nearest neighbors, and bucketScale can be tuned to make balance between 15 | // the search result's accuracy and computational complexity in the search phase. 16 | // 17 | // see README.md for more details. 18 | type Index interface { 19 | // GetANNbyItemID ... search approximate nearest neighbors by a given itemID 20 | GetANNbyItemID(id int64, searchNum int, bucketScale float64) (ann []int64, err error) 21 | 22 | // GetANNbyVector ... search approximate nearest neighbors by a given query vector 23 | GetANNbyVector(v []float64, searchNum int, bucketScale float64) (ann []int64, err error) 24 | } 25 | 26 | type index struct { 27 | metric metric.Metric 28 | 29 | // dim ... dimension of the target space 30 | dim int 31 | 32 | // k ... maximum # of items in a single leaf node 33 | k int 34 | 35 | // itemIDToItem ... 
ItemIDToItem 36 | itemIDToItem map[itemId]*item 37 | 38 | // nodeIDToNode ... NodeIDToNode 39 | nodeIDToNode map[nodeId]*node 40 | 41 | // roots ... roots of the trees 42 | roots []*node 43 | 44 | mux *sync.Mutex 45 | } 46 | 47 | // CreateNewIndex build a new search index for given vectors. rawItems should consist of search target vectors and 48 | // its slice index corresponds to the first argument id of GetANNbyItemID. For example, if we want to search approximate 49 | // nearest neighbors of rawItems[3], it can simply achieved by calling index.GetANNbyItemID(3, ...). 50 | // 51 | // dim is the dimension of target spaces. nTree and k are tunable parameters which affects performances of 52 | // the index (see README.md for details.) 53 | // 54 | // The last argument m is type of metric.Metric and represents the metric of the target search space. 55 | // See https://godoc.org/github.com/mathetake/gann/metric for details. 56 | func CreateNewIndex(rawItems [][]float64, dim, nTree, k int, m metric.Metric) (Index, error) { 57 | // verify that given items have same dimension 58 | for _, it := range rawItems { 59 | if len(it) != dim { 60 | return nil, errDimensionMismatch 61 | } 62 | } 63 | 64 | if len(rawItems) < 2 { 65 | return nil, errNotEnoughItems 66 | } 67 | 68 | its := make([]*item, len(rawItems)) 69 | idToItem := make(map[itemId]*item, len(rawItems)) 70 | for i, v := range rawItems { 71 | it := &item{ 72 | id: itemId(i), 73 | vector: v, 74 | } 75 | its[i] = it 76 | idToItem[it.id] = it 77 | } 78 | 79 | idx := &index{ 80 | metric: m, 81 | dim: dim, 82 | k: k, 83 | itemIDToItem: idToItem, 84 | roots: make([]*node, nTree), 85 | nodeIDToNode: map[nodeId]*node{}, 86 | mux: &sync.Mutex{}, 87 | } 88 | 89 | // build 90 | idx.build(its, nTree) 91 | return idx, nil 92 | } 93 | 94 | func (idx *index) build(items []*item, nTree int) { 95 | vs := make([][]float64, len(idx.itemIDToItem)) 96 | for i, it := range items { 97 | vs[i] = it.vector 98 | } 99 | 100 | for i := 0; i < 
nTree; i++ { 101 | nv := idx.metric.GetSplittingVector(vs) 102 | rn := &node{ 103 | id: nodeId(uuid.New().String()), 104 | vec: nv, 105 | idxPtr: idx, 106 | children: map[direction]*node{}, 107 | } 108 | idx.roots[i] = rn 109 | idx.nodeIDToNode[rn.id] = rn 110 | } 111 | 112 | var wg sync.WaitGroup 113 | wg.Add(nTree) 114 | for _, rn := range idx.roots { 115 | rn := rn 116 | go func() { 117 | defer wg.Done() 118 | rn.build(items) 119 | }() 120 | } 121 | wg.Wait() 122 | } 123 | -------------------------------------------------------------------------------- /index_test.go: -------------------------------------------------------------------------------- 1 | package gann 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/rand" 7 | "testing" 8 | "time" 9 | 10 | "github.com/bmizerany/assert" 11 | "github.com/mathetake/gann/metric" 12 | ) 13 | 14 | func init() { 15 | rand.Seed(time.Now().UnixNano()) 16 | } 17 | 18 | func TestCreateNewIndex(t *testing.T) { 19 | for i, c := range []struct { 20 | dim, num, nTree, k int 21 | }{ 22 | {dim: 2, num: 1000, nTree: 10, k: 2}, 23 | {dim: 10, num: 100, nTree: 5, k: 10}, 24 | {dim: 10, num: 100000, nTree: 5, k: 10}, 25 | {dim: 1000, num: 10000, nTree: 5, k: 10}, 26 | } { 27 | c := c 28 | t.Run(fmt.Sprintf("%d-th case", i), func(t *testing.T) { 29 | rawItems := make([][]float64, c.num) 30 | for i := range rawItems { 31 | v := make([]float64, c.dim) 32 | 33 | var norm float64 34 | for j := range v { 35 | cof := rand.Float64() - 0.5 36 | v[j] = cof 37 | norm += cof * cof 38 | } 39 | 40 | norm = math.Sqrt(norm) 41 | for j := range v { 42 | v[j] /= norm 43 | } 44 | 45 | rawItems[i] = v 46 | } 47 | 48 | m, err := metric.NewCosineMetric(c.dim) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | 53 | idx, err := CreateNewIndex(rawItems, c.dim, c.nTree, c.k, m) 54 | if err != nil { 55 | t.Fatal(err) 56 | } 57 | 58 | rawIdx, ok := idx.(*index) 59 | if !ok { 60 | t.Fatal("type assertion failed") 61 | } 62 | 63 | assert.Equal(t, c.nTree, 
len(rawIdx.roots)) 64 | assert.Equal(t, true, len(rawIdx.nodeIDToNode) > c.nTree) 65 | }) 66 | } 67 | 68 | } 69 | 70 | func TestCreateNewIndexNotEnoughItems(t *testing.T) { 71 | rawItems := make([][]float64, 1) 72 | rawItems[0] = []float64{1, 2, 3, 4} 73 | 74 | m, err := metric.NewCosineMetric(4) 75 | if err != nil { 76 | t.Fatal(err) 77 | } 78 | 79 | //1 vector is not enough 80 | _, err = CreateNewIndex(rawItems, 4, 4, 2, m) 81 | if err != errNotEnoughItems { 82 | t.Fatalf("expected error errNotEnoughItems, got %v instead", err) 83 | } 84 | 85 | rawItems2 := make([][]float64, 2) 86 | rawItems2[0] = []float64{1, 2, 3, 4} 87 | rawItems2[1] = []float64{2, 2, 2, 2} 88 | 89 | //2 vectors are ok 90 | _, err = CreateNewIndex(rawItems2, 4, 4, 2, m) 91 | if err != nil { 92 | t.Fatalf("unexpected error %v", err) 93 | } 94 | 95 | } 96 | -------------------------------------------------------------------------------- /item.go: -------------------------------------------------------------------------------- 1 | package gann 2 | 3 | type itemId int64 4 | 5 | type item struct { 6 | id itemId 7 | vector []float64 8 | } 9 | -------------------------------------------------------------------------------- /metric/cosine.go: -------------------------------------------------------------------------------- 1 | package metric 2 | 3 | import ( 4 | "math/rand" 5 | "time" 6 | ) 7 | 8 | const ( 9 | cosineMetricsMaxIteration = 200 10 | cosineMetricsMaxTargetSample = 100 11 | cosineMetricsTwoMeansThreshold = 0.7 12 | cosineMetricsCentroidCalcRatio = 0.0001 13 | ) 14 | 15 | func init() { 16 | rand.Seed(time.Now().UnixNano()) 17 | } 18 | 19 | type cosineDistance struct { 20 | dim int 21 | } 22 | 23 | // NewCosineMetric returns cosineDistance. 24 | // NOTE: We assume that the given vectors are already normalized, i.e. 
the norm equals 1 25 | func NewCosineMetric(dim int) (Metric, error) { 26 | return &cosineDistance{ 27 | dim: dim, 28 | }, nil 29 | } 30 | 31 | func (c *cosineDistance) CalcDistance(v1, v2 []float64) float64 { 32 | var ret float64 33 | for i := range v1 { 34 | ret += v1[i] * v2[i] 35 | } 36 | return -ret 37 | } 38 | 39 | func (c *cosineDistance) GetSplittingVector(vs [][]float64) []float64 { 40 | lvs := len(vs) 41 | // init centroids 42 | k := rand.Intn(lvs) 43 | l := rand.Intn(lvs - 1) 44 | if k == l { 45 | l++ 46 | } 47 | c0 := vs[k] 48 | c1 := vs[l] 49 | 50 | for i := 0; i < cosineMetricsMaxIteration; i++ { 51 | clusterToVecs := map[int][][]float64{} 52 | 53 | iter := cosineMetricsMaxTargetSample 54 | if len(vs) < cosineMetricsMaxTargetSample { 55 | iter = len(vs) 56 | } 57 | for i := 0; i < iter; i++ { 58 | v := vs[rand.Intn(len(vs))] 59 | ip0 := c.CalcDistance(c0, v) 60 | ip1 := c.CalcDistance(c1, v) 61 | if ip0 > ip1 { 62 | clusterToVecs[0] = append(clusterToVecs[0], v) 63 | } else { 64 | clusterToVecs[1] = append(clusterToVecs[1], v) 65 | } 66 | } 67 | 68 | lc0 := len(clusterToVecs[0]) 69 | lc1 := len(clusterToVecs[1]) 70 | 71 | if (float64(lc0)/float64(iter) <= cosineMetricsTwoMeansThreshold) && 72 | (float64(lc1)/float64(iter) <= cosineMetricsTwoMeansThreshold) { 73 | break 74 | } 75 | 76 | // update centroids 77 | if lc0 == 0 || lc1 == 0 { 78 | k := rand.Intn(lvs) 79 | l := rand.Intn(lvs - 1) 80 | if k == l { 81 | l++ 82 | } 83 | c0 = vs[k] 84 | c1 = vs[l] 85 | continue 86 | } 87 | 88 | c0 = make([]float64, c.dim) 89 | it0 := int(float64(lvs) * cosineMetricsCentroidCalcRatio) 90 | for i := 0; i < it0; i++ { 91 | for d := 0; d < c.dim; d++ { 92 | c0[d] += clusterToVecs[0][rand.Intn(lc0)][d] / float64(it0) 93 | } 94 | } 95 | 96 | c1 = make([]float64, c.dim) 97 | it1 := int(float64(lvs)*cosineMetricsCentroidCalcRatio + 1) 98 | for i := 0; i < int(float64(lc1)*cosineMetricsCentroidCalcRatio+1); i++ { 99 | for d := 0; d < c.dim; d++ { 100 | c1[d] += 
clusterToVecs[1][rand.Intn(lc1)][d] / float64(it1) 101 | } 102 | } 103 | } 104 | 105 | ret := make([]float64, c.dim) 106 | for d := 0; d < c.dim; d++ { 107 | v := c0[d] - c1[d] 108 | ret[d] += v 109 | } 110 | return ret 111 | } 112 | 113 | func (c *cosineDistance) CalcDirectionPriority(base, target []float64) float64 { 114 | var ret float64 115 | for i := range base { 116 | ret += base[i] * target[i] 117 | } 118 | return ret 119 | } 120 | -------------------------------------------------------------------------------- /metric/cosine_test.go: -------------------------------------------------------------------------------- 1 | package metric 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "testing" 7 | 8 | "github.com/bmizerany/assert" 9 | ) 10 | 11 | func TestCosineDistance_CalcDirectionPriority(t *testing.T) { 12 | for i, c := range []struct { 13 | v1, v2 []float64 14 | exp float64 15 | dim int 16 | }{ 17 | { 18 | v1: []float64{1.2, 0.1}, 19 | v2: []float64{-1.2, 0.2}, 20 | dim: 2, 21 | exp: -1.42, 22 | }, 23 | { 24 | v1: []float64{1.2, 0.1, 0, 0, 0, 0, 0, 0, 0, 0}, 25 | v2: []float64{-1.2, 0.2, 0, 0, 0, 0, 0, 0, 0, 0}, 26 | dim: 10, 27 | exp: -1.42, 28 | }, 29 | } { 30 | c := c 31 | t.Run(fmt.Sprintf("%d-th case", i), func(t *testing.T) { 32 | cosine := &cosineDistance{dim: c.dim} 33 | actual := cosine.CalcDirectionPriority(c.v1, c.v2) 34 | assert.Equal(t, c.exp, actual) 35 | }) 36 | } 37 | } 38 | 39 | func TestCosineDistance_GetSplittingVector(t *testing.T) { 40 | for i, c := range []struct { 41 | dim, num int 42 | }{ 43 | { 44 | dim: 5, num: 100, 45 | }, 46 | } { 47 | c := c 48 | t.Run(fmt.Sprintf("%d-th case", i), func(t *testing.T) { 49 | cosine := &cosineDistance{dim: c.dim} 50 | vs := make([][]float64, c.num) 51 | for i := 0; i < c.num; i++ { 52 | v := make([]float64, c.dim) 53 | for d := 0; d < c.dim; d++ { 54 | v[d] = rand.Float64() 55 | } 56 | vs[i] = v 57 | } 58 | 59 | cosine.GetSplittingVector(vs) 60 | }) 61 | } 62 | } 63 | 64 | func 
TestCosineDistance_CalcDistance(t *testing.T) { 65 | for i, c := range []struct { 66 | v1, v2 []float64 67 | exp float64 68 | dim int 69 | }{ 70 | { 71 | v1: []float64{1.2, 0.1}, 72 | v2: []float64{-1.2, 0.2}, 73 | dim: 2, 74 | exp: 1.42, 75 | }, 76 | { 77 | v1: []float64{1.2, 0.1, 0, 0, 0, 0, 0, 0, 0, 0}, 78 | v2: []float64{-1.2, 0.2, 0, 0, 0, 0, 0, 0, 0, 0}, 79 | dim: 10, 80 | exp: 1.42, 81 | }, 82 | } { 83 | c := c 84 | t.Run(fmt.Sprintf("%d-th case", i), func(t *testing.T) { 85 | cosine := &cosineDistance{dim: c.dim} 86 | actual := cosine.CalcDistance(c.v1, c.v2) 87 | assert.Equal(t, c.exp, actual) 88 | }) 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /metric/metric.go: -------------------------------------------------------------------------------- 1 | package metric 2 | 3 | // Metric is the interface of metrics which defines target search spaces. 4 | type Metric interface { 5 | // CalcDistance ... calculates the distance between given vectors 6 | CalcDistance(v1, v2 []float64) float64 7 | // GetSplittingVector ... calculates the splitting vector which becomes a node's vector in the index 8 | GetSplittingVector(vs [][]float64) []float64 9 | // CalcDirectionPriority ... calculates the priority of the children nodes which can be used for determining 10 | // which way (right or left child) should go next traversal. The return values must be contained in [-1, 1]. 
11 | CalcDirectionPriority(base, target []float64) float64 12 | } 13 | -------------------------------------------------------------------------------- /node.go: -------------------------------------------------------------------------------- 1 | package gann 2 | 3 | import ( 4 | "github.com/google/uuid" 5 | ) 6 | 7 | type nodeId string 8 | type direction string 9 | 10 | const ( 11 | left direction = "left" 12 | right direction = "right" 13 | ) 14 | 15 | var directions = []direction{left, right} 16 | 17 | type node struct { 18 | idxPtr *index 19 | 20 | id nodeId 21 | 22 | // the normal vector of the hyper plane which splits the space, represented by the node 23 | vec []float64 24 | 25 | // children of node. If len equals 0, then it is leaf node. 26 | children map[direction]*node 27 | 28 | // In our setting, a `leaf` is a kind of node with len(leaf) > 0 29 | leaf []itemId 30 | } 31 | 32 | func (n *node) build(its []*item) { 33 | if len(its) <= n.idxPtr.k { 34 | n.leaf = make([]itemId, len(its)) 35 | for i, it := range its { 36 | n.leaf[i] = it.id 37 | } 38 | return 39 | } 40 | n.buildChildren(its) 41 | } 42 | 43 | func (n *node) buildChildren(its []*item) { 44 | dItems := map[direction][]*item{} 45 | dVectors := map[direction][][]float64{} 46 | for _, it := range its { 47 | if n.idxPtr.metric.CalcDirectionPriority(n.vec, it.vector) < 0 { 48 | dItems[left] = append(dItems[left], it) 49 | dVectors[left] = append(dVectors[left], it.vector) 50 | } else { 51 | dItems[right] = append(dItems[right], it) 52 | dVectors[right] = append(dVectors[right], it.vector) 53 | } 54 | } 55 | 56 | var shouldMerge = false 57 | for _, s := range directions { 58 | if len(dItems[s]) <= n.idxPtr.k { 59 | shouldMerge = true 60 | } 61 | } 62 | 63 | if shouldMerge { 64 | n.leaf = make([]itemId, len(its)) 65 | for i, it := range its { 66 | n.leaf[i] = it.id 67 | } 68 | return 69 | } 70 | 71 | for _, s := range directions { 72 | // build child 73 | c := &node{ 74 | vec: 
n.idxPtr.metric.GetSplittingVector(dVectors[s]), 75 | id: nodeId(uuid.New().String()), 76 | idxPtr: n.idxPtr, 77 | children: make(map[direction]*node, len(directions)), 78 | } 79 | 80 | c.build(dItems[s]) 81 | 82 | // append child for the search phase 83 | n.children[s] = c 84 | 85 | // append child to global map for the search phase 86 | n.idxPtr.mux.Lock() 87 | n.idxPtr.nodeIDToNode[c.id] = c 88 | n.idxPtr.mux.Unlock() 89 | } 90 | return 91 | } 92 | -------------------------------------------------------------------------------- /node_test.go: -------------------------------------------------------------------------------- 1 | package gann 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "testing" 7 | 8 | "github.com/bmizerany/assert" 9 | "github.com/mathetake/gann/metric" 10 | ) 11 | 12 | func TestNodeBuild(t *testing.T) { 13 | for i, c := range []struct { 14 | vec []float64 15 | items []*item 16 | dim, k int 17 | expLeaf bool 18 | }{ 19 | { 20 | vec: []float64{0.0, 1.0}, 21 | items: []*item{ 22 | {id: 0, vector: []float64{0.0, 1.0}}, 23 | {id: 1, vector: []float64{0.0, -1.0}}, 24 | }, 25 | k: 2, 26 | dim: 2, 27 | expLeaf: true, 28 | }, 29 | { 30 | vec: []float64{0.0, 1.0}, 31 | items: []*item{ 32 | {id: 0, vector: []float64{0.0, 1.0}}, 33 | {id: 0, vector: []float64{0.0, 1.1}}, 34 | {id: 0, vector: []float64{0.0, 1.2}}, 35 | {id: 1, vector: []float64{0.0, -1.0}}, 36 | {id: 2, vector: []float64{0.0, -1.1}}, 37 | {id: 2, vector: []float64{0.0, -1.2}}, 38 | }, 39 | k: 2, 40 | dim: 2, 41 | expLeaf: false, 42 | }, 43 | } { 44 | c := c 45 | i := i 46 | t.Run(fmt.Sprintf("%d-th case", i), func(t *testing.T) { 47 | m, err := metric.NewCosineMetric(c.dim) 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | 52 | idxPtr := &index{ 53 | k: 1, 54 | mux: &sync.Mutex{}, 55 | metric: m, 56 | nodeIDToNode: map[nodeId]*node{}, 57 | } 58 | 59 | n := &node{ 60 | id: nodeId(fmt.Sprintf("%d", i)), 61 | vec: c.vec, 62 | idxPtr: idxPtr, 63 | children: make(map[direction]*node, 
len(directions)), 64 | } 65 | n.build(c.items) 66 | 67 | if c.expLeaf { 68 | assert.Equal(t, true, len(n.leaf) > 0) 69 | } else { 70 | assert.Equal(t, true, len(n.leaf) == 0) 71 | assert.Equal(t, true, len(n.children) > 0) 72 | } 73 | }) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /queue.go: -------------------------------------------------------------------------------- 1 | package gann 2 | 3 | type queueItem struct { 4 | value nodeId 5 | index int 6 | priority float64 7 | } 8 | 9 | type priorityQueue []*queueItem 10 | 11 | func (pq priorityQueue) Len() int { return len(pq) } 12 | 13 | func (pq priorityQueue) Less(i, j int) bool { return pq[i].priority < pq[j].priority } 14 | 15 | func (pq priorityQueue) Swap(i, j int) { 16 | pq[i], pq[j] = pq[j], pq[i] 17 | pq[i].index = i 18 | pq[j].index = j 19 | } 20 | 21 | func (pq *priorityQueue) Push(x interface{}) { 22 | n := len(*pq) 23 | item := x.(*queueItem) 24 | item.index = n 25 | *pq = append(*pq, item) 26 | } 27 | 28 | func (pq *priorityQueue) Pop() interface{} { 29 | old := *pq 30 | n := len(old) 31 | item := old[n-1] 32 | item.index = -1 // for safety 33 | *pq = old[0 : n-1] 34 | return item 35 | } 36 | -------------------------------------------------------------------------------- /queue_test.go: -------------------------------------------------------------------------------- 1 | package gann 2 | 3 | import ( 4 | "container/heap" 5 | "fmt" 6 | "math" 7 | "testing" 8 | 9 | "github.com/bmizerany/assert" 10 | ) 11 | 12 | func TestPriorityQueue(t *testing.T) { 13 | for i, c := range []struct { 14 | valueToPriority map[nodeId]float64 15 | expValues []nodeId 16 | }{ 17 | { 18 | valueToPriority: map[nodeId]float64{ 19 | "a": 1, 20 | "b": math.Inf(-1), 21 | "c": 3, 22 | "d": 100, 23 | }, 24 | expValues: []nodeId{ 25 | "b", "a", "c", "d", 26 | }, 27 | }, 28 | { 29 | valueToPriority: map[nodeId]float64{ 30 | "a": 1, 31 | "b": math.Inf(-1), 32 | "c": 3, 33 | "d": 
-10, 34 | }, 35 | expValues: []nodeId{ 36 | "b", "d", "a", "c", 37 | }, 38 | }, 39 | } { 40 | c := c 41 | t.Run(fmt.Sprintf("%d-th case", i), func(t *testing.T) { 42 | var i int 43 | pq := make(priorityQueue, len(c.valueToPriority)) 44 | for v, pr := range c.valueToPriority { 45 | pq[i] = &queueItem{ 46 | value: v, 47 | priority: pr, 48 | index: i, 49 | } 50 | i++ 51 | } 52 | heap.Init(&pq) 53 | 54 | for _, v := range c.expValues { 55 | qi := heap.Pop(&pq).(*queueItem) 56 | assert.Equal(t, qi.value, v) 57 | } 58 | }) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /search.go: -------------------------------------------------------------------------------- 1 | package gann 2 | 3 | import ( 4 | "container/heap" 5 | "math" 6 | "sort" 7 | ) 8 | 9 | func (idx *index) GetANNbyItemID(id int64, searchNum int, bucketScale float64) ([]int64, error) { 10 | it, ok := idx.itemIDToItem[itemId(id)] 11 | if !ok { 12 | return nil, errItemNotFoundOnGivenItemID 13 | } 14 | return idx.GetANNbyVector(it.vector, searchNum, bucketScale) 15 | } 16 | 17 | func (idx *index) GetANNbyVector(v []float64, searchNum int, bucketScale float64) ([]int64, error) { 18 | /* 19 | 1. insert root nodes into the priority queue 20 | 2. search all trees until len(`ann`) is enough. 21 | 3. calculate actual distances to each elements in ann from v. 22 | 4. sort `ann` by distances. 23 | 5. Return the top `num` ones. 24 | */ 25 | 26 | if len(v) != idx.dim { 27 | return nil, errInvalidKeyVector 28 | } 29 | 30 | bucketSize := int(float64(searchNum) * bucketScale) 31 | annMap := make(map[itemId]struct{}, bucketSize) 32 | 33 | pq := priorityQueue{} 34 | 35 | // 1. 36 | for i, r := range idx.roots { 37 | n := &queueItem{ 38 | value: r.id, 39 | index: i, 40 | priority: math.Inf(-1), 41 | } 42 | pq = append(pq, n) 43 | } 44 | 45 | heap.Init(&pq) 46 | 47 | // 2. 
48 | for pq.Len() > 0 && len(annMap) < bucketSize { 49 | q, ok := heap.Pop(&pq).(*queueItem) 50 | d := q.priority 51 | n, ok := idx.nodeIDToNode[q.value] 52 | if !ok { 53 | return nil, errInvalidIndex 54 | } 55 | 56 | if len(n.leaf) > 0 { 57 | for _, id := range n.leaf { 58 | annMap[id] = struct{}{} 59 | } 60 | continue 61 | } 62 | 63 | dp := idx.metric.CalcDirectionPriority(n.vec, v) 64 | heap.Push(&pq, &queueItem{ 65 | value: n.children[left].id, 66 | priority: max(d, dp), 67 | }) 68 | heap.Push(&pq, &queueItem{ 69 | value: n.children[right].id, 70 | priority: max(d, -dp), 71 | }) 72 | } 73 | 74 | // 3. 75 | idToDist := make(map[int64]float64, len(annMap)) 76 | ann := make([]int64, 0, len(annMap)) 77 | for id := range annMap { 78 | iid := int64(id) 79 | ann = append(ann, iid) 80 | idToDist[iid] = idx.metric.CalcDistance(idx.itemIDToItem[id].vector, v) 81 | } 82 | 83 | // 4. 84 | sort.Slice(ann, func(i, j int) bool { 85 | return idToDist[ann[i]] < idToDist[ann[j]] 86 | }) 87 | 88 | // 5. 
89 | if len(ann) > searchNum { 90 | ann = ann[:searchNum] 91 | } 92 | return ann, nil 93 | } 94 | 95 | func max(a, b float64) float64 { 96 | if a < b { 97 | return b 98 | } 99 | return a 100 | } 101 | -------------------------------------------------------------------------------- /search_test.go: -------------------------------------------------------------------------------- 1 | package gann 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/rand" 7 | "sort" 8 | "testing" 9 | 10 | "github.com/mathetake/gann/metric" 11 | ) 12 | 13 | func TestIndex_GetANNbyItemID(t *testing.T) { 14 | for i, c := range []struct { 15 | dim, num, nTree, k int 16 | }{ 17 | {dim: 2, num: 1000, nTree: 10, k: 2}, 18 | {dim: 10, num: 100, nTree: 5, k: 10}, 19 | {dim: 1000, num: 10000, nTree: 5, k: 10}, 20 | } { 21 | c := c 22 | t.Run(fmt.Sprintf("%d-th case", i), func(t *testing.T) { 23 | rawItems := make([][]float64, c.num) 24 | for i := range rawItems { 25 | v := make([]float64, c.dim) 26 | 27 | var norm float64 28 | for j := range v { 29 | cof := rand.Float64() - 0.5 30 | v[j] = cof 31 | norm += cof * cof 32 | } 33 | 34 | norm = math.Sqrt(norm) 35 | for j := range v { 36 | v[j] /= norm 37 | } 38 | 39 | rawItems[i] = v 40 | } 41 | 42 | m, err := metric.NewCosineMetric(c.dim) 43 | if err != nil { 44 | t.Fatal(err) 45 | } 46 | 47 | idx, err := CreateNewIndex(rawItems, c.dim, c.nTree, c.k, m) 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | 52 | if _, err = idx.GetANNbyItemID(0, 10, 2); err != nil { 53 | t.Fatal(err) 54 | } 55 | }) 56 | } 57 | } 58 | 59 | func TestIndex_GetANNbyVector(t *testing.T) { 60 | for i, c := range []struct { 61 | dim, num, nTree, k int 62 | }{ 63 | {dim: 2, num: 1000, nTree: 10, k: 2}, 64 | {dim: 10, num: 100, nTree: 5, k: 10}, 65 | {dim: 1000, num: 10000, nTree: 5, k: 10}, 66 | } { 67 | c := c 68 | t.Run(fmt.Sprintf("%d-th case", i), func(t *testing.T) { 69 | rawItems := make([][]float64, c.num) 70 | for i := range rawItems { 71 | v := make([]float64, c.dim) 72 | 
73 | var norm float64 74 | for j := range v { 75 | cof := rand.Float64() - 0.5 76 | v[j] = cof 77 | norm += cof * cof 78 | } 79 | 80 | norm = math.Sqrt(norm) 81 | for j := range v { 82 | v[j] /= norm 83 | } 84 | 85 | rawItems[i] = v 86 | } 87 | 88 | m, err := metric.NewCosineMetric(c.dim) 89 | if err != nil { 90 | t.Fatal(err) 91 | } 92 | 93 | idx, err := CreateNewIndex(rawItems, c.dim, c.nTree, c.k, m) 94 | if err != nil { 95 | t.Fatal(err) 96 | } 97 | 98 | key := make([]float64, c.dim) 99 | for i := range key { 100 | key[i] = rand.Float64() - 0.5 101 | } 102 | 103 | if _, err = idx.GetANNbyVector(key, 10, 2); err != nil { 104 | t.Fatal(err) 105 | } 106 | }) 107 | } 108 | } 109 | 110 | // This unit test is made to verify if our algorithm can correctly find 111 | // the `exact` neighbors. That is done by checking the ratio of exact 112 | // neighbors in the result returned by `getANNbyVector` is less than 113 | // the given threshold. 114 | func TestAnnSearchAccuracy(t *testing.T) { 115 | for i, c := range []struct { 116 | k, dim, num, nTree, searchNum int 117 | threshold, bucketScale float64 118 | }{ 119 | { 120 | k: 2, 121 | dim: 20, 122 | num: 10000, 123 | nTree: 20, 124 | threshold: 0.90, 125 | searchNum: 200, 126 | bucketScale: 20, 127 | }, 128 | { 129 | k: 2, 130 | dim: 20, 131 | num: 10000, 132 | nTree: 20, 133 | threshold: 0.8, 134 | searchNum: 20, 135 | bucketScale: 1000, 136 | }, 137 | } { 138 | c := c 139 | t.Run(fmt.Sprintf("%d-th case", i), func(t *testing.T) { 140 | rawItems := make([][]float64, c.num) 141 | for i := range rawItems { 142 | v := make([]float64, c.dim) 143 | 144 | var norm float64 145 | for j := range v { 146 | cof := rand.Float64() - 0.5 147 | v[j] = cof 148 | norm += cof * cof 149 | } 150 | 151 | norm = math.Sqrt(norm) 152 | for j := range v { 153 | v[j] /= norm 154 | } 155 | 156 | rawItems[i] = v 157 | } 158 | 159 | m, err := metric.NewCosineMetric(c.dim) 160 | if err != nil { 161 | t.Fatal(err) 162 | } 163 | 164 | idx, err := 
CreateNewIndex(rawItems, c.dim, c.nTree, c.k, m) 165 | if err != nil { 166 | t.Fatal(err) 167 | } 168 | 169 | rawIdx, ok := idx.(*index) 170 | if !ok { 171 | t.Fatal("assertion failed") 172 | } 173 | 174 | // query vector 175 | query := make([]float64, c.dim) 176 | query[0] = 0.1 177 | 178 | // exact neighbors 179 | aDist := map[int64]float64{} 180 | ids := make([]int64, len(rawItems)) 181 | for i, v := range rawItems { 182 | ids[i] = int64(i) 183 | aDist[int64(i)] = rawIdx.metric.CalcDistance(v, query) 184 | } 185 | sort.Slice(ids, func(i, j int) bool { 186 | return aDist[ids[i]] < aDist[ids[j]] 187 | }) 188 | 189 | expectedIDsMap := make(map[int64]struct{}, c.searchNum) 190 | for _, id := range ids[:c.searchNum] { 191 | expectedIDsMap[int64(id)] = struct{}{} 192 | } 193 | 194 | ass, err := idx.GetANNbyVector(query, c.searchNum, c.bucketScale) 195 | if err != nil { 196 | t.Fatal(err) 197 | } 198 | 199 | var count int 200 | for _, id := range ass { 201 | if _, ok := expectedIDsMap[id]; ok { 202 | count++ 203 | } 204 | } 205 | 206 | if ratio := float64(count) / float64(c.searchNum); ratio < c.threshold { 207 | t.Fatalf("Too few exact neighbors found in approximated result: %d / %d = %f", count, c.searchNum, ratio) 208 | } else { 209 | t.Logf("ratio of exact neighbors in approximated result: %d / %d = %f", count, c.searchNum, ratio) 210 | } 211 | }) 212 | } 213 | } 214 | --------------------------------------------------------------------------------