├── .gitignore ├── README.md ├── bitsetpool ├── bitsetpool.go └── bitsetpool_test.go ├── distqueue ├── distqueue.go └── distqueue_test.go ├── env.sh ├── examples └── simple.go ├── f32 ├── f32_amd64.go ├── l2squared8_avx_amd64.s ├── l2squared_amd64.s └── l2squared_test.go ├── get_test_data.sh ├── hnsw.go └── hnsw_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | sift 2 | siftsmall 3 | src/ 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-hnsw 2 | 3 | go-hnsw is a Go implementation of the HNSW approximate nearest-neighbour search algorithm implemented in C++ in https://github.com/searchivarius/nmslib and described in https://arxiv.org/abs/1603.09320 4 | 5 | ## Usage 6 | 7 | Simple usage example. See the examples folder for more. 8 | Note that both index building and searching can be safely done in parallel with multiple goroutines. 9 | You can always extend the index, even while searching.
10 | 11 | ```go 12 | package main 13 | 14 | import ( 15 | "fmt" 16 | "math/rand" 17 | "time" 18 | 19 | "github.com/Bithack/go-hnsw" 20 | ) 21 | 22 | func main() { 23 | 24 | const ( 25 | M = 32 26 | efConstruction = 400 27 | efSearch = 100 28 | K = 10 29 | ) 30 | 31 | var zero hnsw.Point = make([]float32, 128) 32 | 33 | h := hnsw.New(M, efConstruction, zero) 34 | h.Grow(10000) 35 | 36 | // Note that added ID:s must start from 1 37 | for i := 1; i <= 10000; i++ { 38 | h.Add(randomPoint(), uint32(i)) 39 | if (i)%1000 == 0 { 40 | fmt.Printf("%v points added\n", i) 41 | } 42 | } 43 | 44 | start := time.Now() 45 | for i := 0; i < 1000; i++ { 46 | Search(randomPoint, efSearch, K) 47 | } 48 | stop := time.Since(start) 49 | 50 | fmt.Printf("%v queries / second (single thread)\n", 1000.0/stop.Seconds()) 51 | } 52 | 53 | func randomPoint() hnsw.Point { 54 | var v hnsw.Point = make([]float32, 128) 55 | for i := range v { 56 | v[i] = rand.Float32() 57 | } 58 | return v 59 | } 60 | 61 | ``` -------------------------------------------------------------------------------- /bitsetpool/bitsetpool.go: -------------------------------------------------------------------------------- 1 | package bitsetpool 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/willf/bitset" 7 | ) 8 | 9 | type poolItem struct { 10 | b bitset.BitSet 11 | busy bool 12 | } 13 | 14 | type BitsetPool struct { 15 | sync.RWMutex 16 | pool []poolItem 17 | } 18 | 19 | func New() *BitsetPool { 20 | var bp BitsetPool 21 | bp.pool = make([]poolItem, 0) 22 | return &bp 23 | } 24 | 25 | func (bp *BitsetPool) Free(i int) { 26 | bp.Lock() 27 | bp.pool[i].busy = false 28 | bp.Unlock() 29 | } 30 | 31 | func (bp *BitsetPool) Get() (int, *bitset.BitSet) { 32 | bp.Lock() 33 | for i := range bp.pool { 34 | if !bp.pool[i].busy { 35 | bp.pool[i].busy = true 36 | bp.pool[i].b.ClearAll() 37 | bp.Unlock() 38 | return i, &bp.pool[i].b 39 | } 40 | } 41 | id := len(bp.pool) 42 | bp.pool = append(bp.pool, poolItem{}) 43 | bp.Unlock() 44 | 
return id, &bp.pool[id].b 45 | } 46 | -------------------------------------------------------------------------------- /bitsetpool/bitsetpool_test.go: -------------------------------------------------------------------------------- 1 | package bitsetpool 2 | 3 | import ( 4 | "math/rand" 5 | "testing" 6 | "time" 7 | 8 | "github.com/willf/bitset" 9 | ) 10 | 11 | func TestBitset(t *testing.T) { 12 | 13 | start2 := time.Now() 14 | for j := 0; j < 100000; j++ { 15 | b2 := make(map[uint32]bool) 16 | for i := 0; i < 100; i++ { 17 | n := rand.Intn(1000000) 18 | b2[uint32(n)] = true 19 | m := rand.Intn(1000000) 20 | if b2[uint32(m)] == false { 21 | } 22 | } 23 | } 24 | stop2 := time.Since(start2) 25 | t.Logf("map done in %v", stop2.Seconds()) 26 | 27 | start := time.Now() 28 | for j := 0; j < 100000; j++ { 29 | var b1 bitset.BitSet 30 | for i := 0; i < 100; i++ { 31 | n := rand.Intn(1000000) 32 | b1.Set(uint(n)) 33 | m := rand.Intn(1000000) 34 | b1.Test(uint(m)) 35 | } 36 | } 37 | stop := time.Since(start) 38 | t.Logf("bitset done in %v", stop.Seconds()) 39 | 40 | start3 := time.Now() 41 | pool := NewBitsetPool() 42 | for j := 0; j < 100000; j++ { 43 | id, b := pool.Get() 44 | for i := 0; i < 100; i++ { 45 | n := rand.Intn(1000000) 46 | b.Set(uint(n)) 47 | m := rand.Intn(1000000) 48 | b.Test(uint(m)) 49 | } 50 | pool.Free(id) 51 | } 52 | stop3 := time.Since(start3) 53 | t.Logf("bitset pool done in %v", stop3.Seconds()) 54 | 55 | t.Logf("Performance boost %.2f%%", 100*(1-stop3.Seconds()/stop2.Seconds())) 56 | } 57 | -------------------------------------------------------------------------------- /distqueue/distqueue.go: -------------------------------------------------------------------------------- 1 | package distqueue 2 | 3 | type Item struct { 4 | ID uint32 5 | D float32 6 | } 7 | 8 | type DistQueueClosestFirst struct { 9 | initiated bool 10 | items []*Item 11 | Size int 12 | } 13 | 14 | func (pq *DistQueueClosestFirst) Init() *DistQueueClosestFirst { 15 | pq.items = 
make([]*Item, 1, pq.Size+1) 16 | pq.items[0] = nil // Heap queue first element should always be nil 17 | pq.initiated = true 18 | return pq 19 | } 20 | 21 | func (pq *DistQueueClosestFirst) Reset() { 22 | pq.items = pq.items[0:1] 23 | } 24 | func (pq *DistQueueClosestFirst) Items() []*Item { 25 | return pq.items[1:] 26 | } 27 | func (pq *DistQueueClosestFirst) Reserve(n int) { 28 | if n > len(pq.items)-1 { 29 | // reserve memory by setting the slice capacity 30 | items2 := make([]*Item, len(pq.items), n+1) 31 | copy(pq.items, items2) 32 | pq.items = items2 33 | } 34 | } 35 | 36 | // Push the value item into the priority queue with provided priority. 37 | func (pq *DistQueueClosestFirst) Push(id uint32, d float32) *Item { 38 | if !pq.initiated { 39 | pq.Init() 40 | } 41 | item := &Item{ID: id, D: d} 42 | pq.items = append(pq.items, item) 43 | pq.swim(len(pq.items) - 1) 44 | return item 45 | } 46 | 47 | func (pq *DistQueueClosestFirst) PushItem(item *Item) { 48 | if !pq.initiated { 49 | pq.Init() 50 | } 51 | pq.items = append(pq.items, item) 52 | pq.swim(len(pq.items) - 1) 53 | } 54 | 55 | func (pq *DistQueueClosestFirst) Pop() *Item { 56 | if len(pq.items) <= 1 { 57 | return nil 58 | } 59 | var max = pq.items[1] 60 | //pq.items[1], pq.items[len(pq.items)-1] = pq.items[len(pq.items)-1], pq.items[1] 61 | pq.items[1], pq.items[len(pq.items)-1] = pq.items[len(pq.items)-1], pq.items[1] 62 | pq.items = pq.items[0 : len(pq.items)-1] 63 | pq.sink(1) 64 | return max 65 | } 66 | 67 | func (pq *DistQueueClosestFirst) Top() (uint32, float32) { 68 | if len(pq.items) <= 1 { 69 | return 0, 0 70 | } 71 | return pq.items[1].ID, pq.items[1].D 72 | } 73 | 74 | func (pq *DistQueueClosestFirst) Head() (uint32, float32) { 75 | if len(pq.items) <= 1 { 76 | return 0, 0 77 | } 78 | return pq.items[1].ID, pq.items[1].D 79 | } 80 | 81 | func (pq *DistQueueClosestFirst) Len() int { 82 | return len(pq.items) - 1 83 | } 84 | 85 | func (pq *DistQueueClosestFirst) Empty() bool { 86 | return 
len(pq.items) == 1 87 | } 88 | 89 | func (pq *DistQueueClosestFirst) swim(k int) { 90 | for k > 1 && (pq.items[k/2].D > pq.items[k].D) { 91 | pq.items[k], pq.items[k/2] = pq.items[k/2], pq.items[k] 92 | k = k / 2 93 | } 94 | } 95 | 96 | func (pq *DistQueueClosestFirst) sink(k int) { 97 | for 2*k <= len(pq.items)-1 { 98 | var j = 2 * k 99 | if j < len(pq.items)-1 && (pq.items[j].D > pq.items[j+1].D) { 100 | j++ 101 | } 102 | if !(pq.items[k].D > pq.items[j].D) { 103 | break 104 | } 105 | pq.items[k], pq.items[j] = pq.items[j], pq.items[k] 106 | k = j 107 | } 108 | } 109 | 110 | type DistQueueClosestLast struct { 111 | initiated bool 112 | items []*Item 113 | Size int 114 | } 115 | 116 | func (pq *DistQueueClosestLast) Init() *DistQueueClosestLast { 117 | pq.items = make([]*Item, 1, pq.Size+1) 118 | pq.items[0] = nil // Heap queue first element should always be nil 119 | pq.initiated = true 120 | return pq 121 | } 122 | 123 | func (pq *DistQueueClosestLast) Items() []*Item { 124 | return pq.items[1:] 125 | } 126 | func (pq *DistQueueClosestLast) Reserve(n int) { 127 | if n > len(pq.items)-1 { 128 | // reserve memory by setting the slice capacity 129 | items2 := make([]*Item, len(pq.items), n+1) 130 | copy(pq.items, items2) 131 | pq.items = items2 132 | } 133 | } 134 | 135 | // Push the value item into the priority queue with provided priority. 
136 | func (pq *DistQueueClosestLast) Push(id uint32, d float32) *Item { 137 | if !pq.initiated { 138 | pq.Init() 139 | } 140 | item := &Item{ID: id, D: d} 141 | pq.items = append(pq.items, item) 142 | pq.swim(len(pq.items) - 1) 143 | return item 144 | } 145 | 146 | // PopAndPush pops the top element and adds a new to the heap in one operation which is faster than two seperate calls to Pop and Push 147 | func (pq *DistQueueClosestLast) PopAndPush(id uint32, d float32) *Item { 148 | if !pq.initiated { 149 | pq.Init() 150 | } 151 | item := &Item{ID: id, D: d} 152 | pq.items[1] = item 153 | pq.sink(1) 154 | return item 155 | } 156 | 157 | func (pq *DistQueueClosestLast) PushItem(item *Item) { 158 | if !pq.initiated { 159 | pq.Init() 160 | } 161 | pq.items = append(pq.items, item) 162 | pq.swim(len(pq.items) - 1) 163 | } 164 | 165 | func (pq *DistQueueClosestLast) Pop() *Item { 166 | if len(pq.items) <= 1 { 167 | return nil 168 | } 169 | var max = pq.items[1] 170 | pq.items[1], pq.items[len(pq.items)-1] = pq.items[len(pq.items)-1], pq.items[1] 171 | pq.items = pq.items[0 : len(pq.items)-1] 172 | pq.sink(1) 173 | return max 174 | } 175 | 176 | func (pq *DistQueueClosestLast) Top() (uint32, float32) { 177 | if len(pq.items) <= 1 { 178 | return 0, 0 179 | } 180 | return pq.items[1].ID, pq.items[1].D 181 | } 182 | 183 | func (pq *DistQueueClosestLast) Head() (uint32, float32) { 184 | if len(pq.items) <= 1 { 185 | return 0, 0 186 | } 187 | return pq.items[1].ID, pq.items[1].D 188 | } 189 | 190 | func (pq *DistQueueClosestLast) Len() int { 191 | return len(pq.items) - 1 192 | } 193 | 194 | func (pq *DistQueueClosestLast) Empty() bool { 195 | return len(pq.items) == 1 196 | } 197 | 198 | func (pq *DistQueueClosestLast) swim(k int) { 199 | for k > 1 && (pq.items[k/2].D < pq.items[k].D) { 200 | pq.items[k], pq.items[k/2] = pq.items[k/2], pq.items[k] 201 | //pq.exch(k/2, k) 202 | k = k / 2 203 | } 204 | } 205 | 206 | func (pq *DistQueueClosestLast) sink(k int) { 207 | for 2*k <= 
len(pq.items)-1 { 208 | var j = 2 * k 209 | if j < len(pq.items)-1 && (pq.items[j].D < pq.items[j+1].D) { 210 | j++ 211 | } 212 | if !(pq.items[k].D < pq.items[j].D) { 213 | break 214 | } 215 | pq.items[k], pq.items[j] = pq.items[j], pq.items[k] 216 | k = j 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /distqueue/distqueue_test.go: -------------------------------------------------------------------------------- 1 | package distqueue 2 | 3 | import ( 4 | "math/rand" 5 | "testing" 6 | ) 7 | 8 | func TestQueue(t *testing.T) { 9 | 10 | pq := &DistQueueClosestFirst{} 11 | 12 | for i := 0; i < 10; i++ { 13 | pq.Push(rand.Uint32(), float32(rand.Float64())) 14 | } 15 | 16 | t.Log("Closest first, pop") 17 | ID, D := pq.Top() 18 | t.Logf("TOP before first top: %v %v", ID, D) 19 | var l float32 = 0.0 20 | for pq.Len() > 0 { 21 | item := pq.Pop() 22 | if item.D < l { 23 | t.Error("Incorrect order") 24 | } 25 | l = item.D 26 | t.Logf("%+v", item) 27 | } 28 | 29 | pq2 := &DistQueueClosestLast{} 30 | l = 1.0 31 | pq2.Init() 32 | pq2.Reserve(200) // try reserve 33 | for i := 0; i < 10; i++ { 34 | pq2.Push(rand.Uint32(), float32(rand.Float64())) 35 | } 36 | t.Log("Closest last, pop") 37 | for !pq2.Empty() { 38 | item := pq2.Pop() 39 | if item.D > l { 40 | t.Error("Incorrect order") 41 | } 42 | l = item.D 43 | t.Logf("%+v", item) 44 | } 45 | } 46 | 47 | func TestKBest(t *testing.T) { 48 | 49 | pq := &DistQueueClosestFirst{} 50 | pq.Reserve(5) // reserve less than needed 51 | for i := 0; i < 20; i++ { 52 | pq.Push(rand.Uint32(), rand.Float32()) 53 | } 54 | 55 | // return K best matches, ordered as best first 56 | t.Log("closest last, still return K best") 57 | K := 10 58 | for pq.Len() > K { 59 | pq.Pop() 60 | } 61 | res := make([]*Item, K) 62 | for i := K - 1; i >= 0; i-- { 63 | res[i] = pq.Pop() 64 | } 65 | for i := 0; i < len(res); i++ { 66 | t.Logf("%+v", res[i]) 67 | } 68 | } 69 | 
-------------------------------------------------------------------------------- /env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export GOPATH=`pwd` 4 | export GOBIN=$GOPATH/bin 5 | -------------------------------------------------------------------------------- /examples/simple.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "time" 7 | 8 | hnsw "github.com/Bithack/go-hnsw" 9 | ) 10 | 11 | func main() { 12 | 13 | const ( 14 | M = 32 15 | efConstruction = 400 16 | efSearch = 100 17 | K = 10 18 | ) 19 | 20 | var zero hnsw.Point = make([]float32, 128) 21 | 22 | h := hnsw.New(M, efConstruction, zero) 23 | h.Grow(10000) 24 | 25 | for i := 1; i <= 10000; i++ { 26 | h.Add(randomPoint(), uint32(i)) 27 | if (i)%1000 == 0 { 28 | fmt.Printf("%v points added\n", i) 29 | } 30 | } 31 | 32 | fmt.Printf("Generating queries and calculating true answers using bruteforce search...\n") 33 | queries := make([]hnsw.Point, 1000) 34 | truth := make([][]uint32, 1000) 35 | for i := range queries { 36 | queries[i] = randomPoint() 37 | result := h.SearchBrute(queries[i], K) 38 | truth[i] = make([]uint32, K) 39 | for j := K - 1; j >= 0; j-- { 40 | item := result.Pop() 41 | truth[i][j] = item.ID 42 | } 43 | } 44 | 45 | fmt.Printf("Now searching with HNSW...\n") 46 | hits := 0 47 | start := time.Now() 48 | for i := 0; i < 1000; i++ { 49 | result := h.Search(queries[i], efSearch, K) 50 | for j := 0; j < K; j++ { 51 | item := result.Pop() 52 | for k := 0; k < K; k++ { 53 | if item.ID == truth[i][k] { 54 | hits++ 55 | } 56 | } 57 | } 58 | } 59 | stop := time.Since(start) 60 | 61 | fmt.Printf("%v queries / second (single thread)\n", 1000.0/stop.Seconds()) 62 | fmt.Printf("Average 10-NN precision: %v\n", float64(hits)/(1000.0*float64(K))) 63 | 64 | } 65 | 66 | func randomPoint() hnsw.Point { 67 | var v hnsw.Point = make([]float32, 128) 
68 | for i := range v { 69 | v[i] = rand.Float32() 70 | } 71 | return v 72 | } 73 | -------------------------------------------------------------------------------- /f32/f32_amd64.go: -------------------------------------------------------------------------------- 1 | //+build !noasm,!appengine 2 | 3 | package f32 4 | 5 | func L2Squared(x, y []float32) float32 6 | 7 | func L2Squared8AVX(x, y []float32) float32 8 | -------------------------------------------------------------------------------- /f32/l2squared8_avx_amd64.s: -------------------------------------------------------------------------------- 1 | //+build !noasm,!appengine 2 | 3 | #include "textflag.h" 4 | 5 | // This version is AVX optimized for vectors where the dimension is a multiple of 8 6 | // Latest GO versions seems to align []float32 slices on 32-bytes on a 64-bit system, so we skip checks for this... 7 | 8 | // func L2Squared8AVX(x, y []float32) (sum float32) 9 | TEXT ·L2Squared8AVX(SB), NOSPLIT, $0 10 | MOVQ x_base+0(FP), SI // SI = &x 11 | MOVQ x_len+8(FP), AX // AX = len(x) 12 | MOVQ y_base+24(FP), DI // DI = &y 13 | 14 | MOVQ AX, BX // BX = len(x) 15 | 16 | SHLQ $2, AX 17 | ADDQ AX, SI 18 | ADDQ AX, DI 19 | SHRQ $2, AX 20 | NEGQ AX 21 | 22 | BYTE $0xc5; BYTE $0xfc; BYTE $0x57; BYTE $0xc0 // vxorps ymm0,ymm0,ymm0 23 | 24 | ANDQ $0xF, BX // BX = len % 16 25 | JZ l2_loop_16 26 | 27 | // PRE LOOP, 8 values 28 | BYTE $0xc5; BYTE $0xfc; BYTE $0x28; BYTE $0x0c; BYTE $0x86 //vmovaps ymm1,YMMWORD PTR [esi+eax*4] 29 | BYTE $0xc5; BYTE $0xf4; BYTE $0x5c; BYTE $0x0c; BYTE $0x87 // vsubps ymm1,ymm1,YMMWORD PTR [edi+eax*4] 30 | BYTE $0xc5; BYTE $0xf4; BYTE $0x59; BYTE $0xc9 31 | BYTE $0xc5; BYTE $0xfc; BYTE $0x58; BYTE $0xc1; 32 | ADDQ $8, AX 33 | 34 | l2_loop_16: 35 | BYTE $0xc5; BYTE $0xfc; BYTE $0x28; BYTE $0x0c; BYTE $0x86 //vmovaps ymm1,YMMWORD PTR [esi+eax*4] 36 | BYTE $0xc5; BYTE $0xfc; BYTE $0x28; BYTE $0x54; BYTE $0x86; BYTE $0x20 //vmovaps ymm2,YMMWORD PTR [esi+eax*4+0x20] 37 | BYTE $0xc5; BYTE 
$0xf4; BYTE $0x5c; BYTE $0x0c; BYTE $0x87 // vsubps ymm1,ymm1,YMMWORD PTR [edi+eax*4] 38 | BYTE $0xc5; BYTE $0xec; BYTE $0x5c; BYTE $0x54; BYTE $0x87; BYTE $0x20 // vsubps ymm2,ymm2,YMMWORD PTR [edi+eax*4+0x20] 39 | BYTE $0xc5; BYTE $0xf4; BYTE $0x59; BYTE $0xc9 // vmulps ymmX,ymmX,ymmX 40 | BYTE $0xc5; BYTE $0xec; BYTE $0x59; BYTE $0xd2 41 | BYTE $0xc5; BYTE $0xfc; BYTE $0x58; BYTE $0xc1; // vaddps ymm0,ymm0,ymmX 42 | BYTE $0xc5; BYTE $0xfc; BYTE $0x58; BYTE $0xc2; 43 | ADDQ $16, AX // eax += 16 44 | JS l2_loop_16 // jump if negative 45 | 46 | l2_end: 47 | //auto x = _mm256_permute2f128_ps(v, v, 1); 48 | BYTE $0xc4; BYTE $0xe3; BYTE $0x7d; BYTE $0x06; BYTE $0xc8; BYTE $0x01; // vperm2f128 ymm1,ymm0,ymm0,0x1 49 | //auto y = _mm256_add_ps(v, x); 50 | BYTE $0xc5;BYTE $0xfc; BYTE $0x58;BYTE $0xc1; // vaddps ymm0,ymm0,ymm1 51 | //x = _mm256_shuffle_ps(y, y, _MM_SHUFFLE(2, 3, 0, 1)=0xB1); 52 | //_MM_SHUFFLE 53 | BYTE $0xc5;BYTE $0xfc;BYTE $0xc6;BYTE $0xc8; BYTE $0xb1 // vshufps ymm1,ymm0,ymm0,0xb1 54 | //x = _mm256_add_ps(x, y); 55 | BYTE $0xc5;BYTE $0xf4; BYTE $0x58;BYTE $0xc8 // vaddps ymm1,ymm1,ymm0 56 | //y = _mm256_shuffle_ps(x, x, _MM_SHUFFLE(1, 0, 3, 2)=0x8E); 57 | BYTE $0xc5;BYTE $0xf4; BYTE $0xc6;BYTE $0xc1; BYTE $0x8e // vshufps ymm0,ymm1,ymm1,0x8e 58 | //return _mm256_add_ps(x, y); 59 | BYTE $0xc5; BYTE $0xf4; BYTE $0x58; BYTE $0xc8 // vaddps ymm1,ymm1,ymm0 60 | 61 | VZEROUPPER 62 | MOVSS X1, ret+48(FP) // Return final sum. 63 | 64 | RET 65 | -------------------------------------------------------------------------------- /f32/l2squared_amd64.s: -------------------------------------------------------------------------------- 1 | //+build !noasm,!appengine 2 | 3 | #include "textflag.h" 4 | 5 | // This is the 16-byte SSE2 version. 
6 | // It skips pointer alignment checks, since latest GO versions seems to align all []float32 slices on 16-bytes 7 | 8 | // func L2Squared(x, y []float32) (sum float32) 9 | TEXT ·L2Squared(SB), NOSPLIT, $0 10 | MOVQ x_base+0(FP), SI // SI = &x 11 | MOVQ y_base+24(FP), DI // DI = &y 12 | 13 | MOVQ x_len+8(FP), BX // BX = min( len(x), len(y) ) 14 | CMPQ y_len+32(FP), BX 15 | CMOVQLE y_len+32(FP), BX 16 | CMPQ BX, $0 // if BX == 0 { return } 17 | JE l2_end 18 | 19 | XORPS X1, X1 // sum = 0 20 | XORQ AX, AX // i = 0 21 | 22 | MOVQ BX, CX 23 | ANDQ $0xF, BX // BX = len % 16 24 | SHRQ $4, CX // CX = int( len / 16 ) 25 | JZ l2_tail4_start // if CX == 0 { return } 26 | 27 | l2_loop: // Loop unrolled 16x do { 28 | MOVAPS (SI)(AX*4), X2 // X2 = x[i:i+4] 29 | MOVAPS 16(SI)(AX*4), X3 30 | MOVAPS 32(SI)(AX*4), X4 31 | MOVAPS 48(SI)(AX*4), X5 32 | 33 | SUBPS (DI)(AX*4), X2 // X2 -= y[i:i+4] 34 | SUBPS 16(DI)(AX*4), X3 35 | SUBPS 32(DI)(AX*4), X4 36 | SUBPS 48(DI)(AX*4), X5 37 | 38 | MULPS X2, X2 39 | MULPS X3, X3 40 | MULPS X4, X4 41 | MULPS X5, X5 42 | 43 | ADDPS X2, X1 44 | ADDPS X3, X1 45 | ADDPS X4, X1 46 | ADDPS X5, X1 47 | 48 | ADDQ $16, AX // i += 16 49 | LOOP l2_loop // while (--CX) > 0 50 | CMPQ BX, $0 // if BX == 0 { return } 51 | JE l2_end 52 | 53 | l2_tail4_start: // Reset loop counter for 4-wide tail loop 54 | MOVQ BX, CX // CX = floor( BX / 4 ) 55 | SHRQ $2, CX 56 | JZ l2_tail_start // if CX == 0 { goto l2_tail_start } 57 | 58 | l2_tail4: // Loop unrolled 4x do { 59 | MOVUPS (SI)(AX*4), X2 // X2 = x[i] 60 | SUBPS (DI)(AX*4), X2 // X2 -= y[i:i+4] 61 | MULPS X2, X2 // X2 *= X2 62 | ADDPS X2, X1 // X1 += X2 63 | ADDQ $4, AX // i += 4 64 | LOOP l2_tail4 // } while --CX > 0 65 | 66 | l2_tail_start: // Reset loop counter for 1-wide tail loop 67 | MOVQ BX, CX // CX = BX % 4 68 | ANDQ $3, CX 69 | JZ l2_end // if CX == 0 { return } 70 | 71 | l2_tail: 72 | MOVSS (SI)(AX*4), X2 // X1 = x[i] 73 | SUBSS (DI)(AX*4), X2 // X1 -= y[i] 74 | MULSS X2, X2 // X1 *= a 75 | ADDSS X2, 
X1 // sum += X2 76 | INCQ AX // i++ 77 | LOOP l2_tail // } while --CX > 0 78 | 79 | l2_end: 80 | 81 | MOVUPS X1, X0 82 | SHUFPS $0x93, X0, X0 83 | ADDPS X0, X1 84 | SHUFPS $0x93, X0, X0 85 | ADDPS X0, X1 86 | SHUFPS $0x93, X0, X0 87 | ADDPS X0, X1 88 | 89 | MOVSS X1, ret+48(FP) // Return final sum. 90 | RET 91 | -------------------------------------------------------------------------------- /f32/l2squared_test.go: -------------------------------------------------------------------------------- 1 | package f32 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "testing" 7 | "time" 8 | "unsafe" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func DistGo(a, b []float32) (r float32) { 14 | var d float32 15 | for i := range a { 16 | d = a[i] - b[i] 17 | r += d * d 18 | } 19 | return r 20 | } 21 | 22 | func Test1(t *testing.T) { 23 | a := []float32{1} 24 | b := []float32{4} 25 | assert.Equal(t, DistGo(a, b), L2Squared(a, b), "Incorrect") 26 | } 27 | 28 | func Test4(t *testing.T) { 29 | a := []float32{1, 2, 3, 4} 30 | b := []float32{4, 3, 2, 1} 31 | assert.Equal(t, DistGo(a, b), L2Squared(a, b), "Incorrect") 32 | } 33 | 34 | func Test5(t *testing.T) { 35 | a := []float32{1, 2, 3, 4, 1} 36 | b := []float32{4, 3, 2, 1, 9} 37 | assert.Equal(t, DistGo(a, b), L2Squared(a, b), "Incorrect") 38 | } 39 | 40 | func Test21(t *testing.T) { 41 | a := []float32{1, 2, 3, 4, 1, 1, 2, 3, 4, 1, 1, 2, 3, 4, 1, 1, 2, 3, 4, 1, 9} 42 | b := []float32{4, 3, 2, 1, 9, 4, 3, 2, 1, 9, 4, 3, 2, 1, 9, 4, 3, 2, 1, 9, 0} 43 | assert.Equal(t, DistGo(a, b), L2Squared(a, b), "Incorrect") 44 | } 45 | 46 | func TestAlignment(t *testing.T) { 47 | for i := 0; i < 10000; i++ { 48 | a := make([]float32, rand.Intn(256)) 49 | assert.True(t, uintptr(unsafe.Pointer(&a))%16 == 0, "[]float32 Not 16-bytes aligned!") 50 | assert.True(t, uintptr(unsafe.Pointer(&a))%32 == 0, "[]float32 Not 32-bytes aligned!") 51 | } 52 | } 53 | 54 | func Test24(t *testing.T) { 55 | a := []float32{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 
3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4} 56 | b := []float32{4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4} 57 | assert.Equal(t, DistGo(a, b), L2Squared(a, b), "Incorrect") 58 | assert.Equal(t, DistGo(b, a), L2Squared8AVX(a, b), "8avx Incorrect") 59 | } 60 | 61 | func Test128(t *testing.T) { 62 | a := []float32{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 63 | 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 64 | 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 65 | 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 66 | } 67 | b := []float32{4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 68 | 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 69 | 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 70 | 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 71 | } 72 | assert.Equal(t, DistGo(a, b), L2Squared(a, b), "Incorrect") 73 | assert.Equal(t, DistGo(b, a), L2Squared8AVX(a, b), "8avx Incorrect") 74 | } 75 | 76 | func TestBenchmark(t *testing.T) { 77 | a2 := []float32{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 78 | 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 79 | 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 80 | 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 81 | } 82 | b2 := []float32{4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 83 | 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 
3, 1, 4, 84 | 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 85 | 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 86 | } 87 | l := 10000000 88 | fmt.Printf("Testing %v calls with %v dim []float32\n", l, len(a2)) 89 | 90 | start := time.Now() 91 | for i := 0; i < l; i++ { 92 | L2Squared(a2, b2) 93 | } 94 | stop := time.Since(start) 95 | fmt.Printf("l2squared Done in %v. %v calcs / second\n", stop, float64(l)/stop.Seconds()) 96 | 97 | start = time.Now() 98 | for i := 0; i < l; i++ { 99 | L2Squared8AVX(a2, b2) 100 | } 101 | stop = time.Since(start) 102 | fmt.Printf("l2squared8AVX Done in %v. %v calcs / second\n", stop, float64(l)/stop.Seconds()) 103 | 104 | start = time.Now() 105 | for i := 0; i < l; i++ { 106 | DistGo(a2, b2) 107 | } 108 | stop = time.Since(start) 109 | fmt.Printf("Go version done in %v. %v calcs / second\n", stop, float64(l)/stop.Seconds()) 110 | } 111 | -------------------------------------------------------------------------------- /get_test_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget ftp://ftp.irisa.fr/local/texmex/corpus/siftsmall.tar.gz 3 | wget ftp://ftp.irisa.fr/local/texmex/corpus/sift.tar.gz 4 | tar -xf siftsmall.tar.gz 5 | tar -xf sift.tar.gz 6 | rm siftsmall.tar.gz 7 | rm sift.tar.gz 8 | -------------------------------------------------------------------------------- /hnsw.go: -------------------------------------------------------------------------------- 1 | package hnsw 2 | 3 | import ( 4 | "compress/gzip" 5 | "encoding/binary" 6 | "fmt" 7 | "io" 8 | "math" 9 | "math/rand" 10 | "os" 11 | "sync" 12 | "time" 13 | 14 | "github.com/Bithack/go-hnsw/bitsetpool" 15 | "github.com/Bithack/go-hnsw/distqueue" 16 | "github.com/Bithack/go-hnsw/f32" 17 | ) 18 | 19 | type Point []float32 20 | 21 | func (a Point) Size() int { 22 | return len(a) * 4 23 | } 24 | 25 | type node struct { 26 
| sync.RWMutex 27 | locked bool 28 | p Point 29 | level int 30 | friends [][]uint32 31 | } 32 | 33 | type Hnsw struct { 34 | sync.RWMutex 35 | M int 36 | M0 int 37 | efConstruction int 38 | linkMode int 39 | DelaunayType int 40 | 41 | DistFunc func([]float32, []float32) float32 42 | 43 | nodes []node 44 | 45 | bitset *bitsetpool.BitsetPool 46 | 47 | LevelMult float64 48 | maxLayer int 49 | enterpoint uint32 50 | } 51 | 52 | // Load opens a index file previously written by Save(). Returnes a new index and the timestamp the file was written 53 | func Load(filename string) (*Hnsw, int64, error) { 54 | f, err := os.Open(filename) 55 | if err != nil { 56 | return nil, 0, err 57 | } 58 | z, err := gzip.NewReader(f) 59 | if err != nil { 60 | return nil, 0, err 61 | } 62 | 63 | timestamp := readInt64(z) 64 | 65 | h := new(Hnsw) 66 | h.M = readInt32(z) 67 | h.M0 = readInt32(z) 68 | h.efConstruction = readInt32(z) 69 | h.linkMode = readInt32(z) 70 | h.DelaunayType = readInt32(z) 71 | h.LevelMult = readFloat64(z) 72 | h.maxLayer = readInt32(z) 73 | h.enterpoint = uint32(readInt32(z)) 74 | 75 | h.DistFunc = f32.L2Squared8AVX 76 | h.bitset = bitsetpool.New() 77 | 78 | l := readInt32(z) 79 | h.nodes = make([]node, l) 80 | 81 | for i := range h.nodes { 82 | 83 | l := readInt32(z) 84 | h.nodes[i].p = make([]float32, l) 85 | 86 | err = binary.Read(z, binary.LittleEndian, h.nodes[i].p) 87 | if err != nil { 88 | panic(err) 89 | } 90 | h.nodes[i].level = readInt32(z) 91 | 92 | l = readInt32(z) 93 | h.nodes[i].friends = make([][]uint32, l) 94 | 95 | for j := range h.nodes[i].friends { 96 | l := readInt32(z) 97 | h.nodes[i].friends[j] = make([]uint32, l) 98 | err = binary.Read(z, binary.LittleEndian, h.nodes[i].friends[j]) 99 | if err != nil { 100 | panic(err) 101 | } 102 | } 103 | 104 | } 105 | 106 | z.Close() 107 | f.Close() 108 | 109 | return h, timestamp, nil 110 | } 111 | 112 | // Save writes to current index to a gzipped binary data file 113 | func (h *Hnsw) Save(filename string) 
error { 114 | f, err := os.Create(filename) 115 | if err != nil { 116 | return err 117 | } 118 | z := gzip.NewWriter(f) 119 | 120 | timestamp := time.Now().Unix() 121 | 122 | writeInt64(timestamp, z) 123 | 124 | writeInt32(h.M, z) 125 | writeInt32(h.M0, z) 126 | writeInt32(h.efConstruction, z) 127 | writeInt32(h.linkMode, z) 128 | writeInt32(h.DelaunayType, z) 129 | writeFloat64(h.LevelMult, z) 130 | writeInt32(h.maxLayer, z) 131 | writeInt32(int(h.enterpoint), z) 132 | 133 | l := len(h.nodes) 134 | writeInt32(l, z) 135 | 136 | if err != nil { 137 | return err 138 | } 139 | for _, n := range h.nodes { 140 | l := len(n.p) 141 | writeInt32(l, z) 142 | err = binary.Write(z, binary.LittleEndian, []float32(n.p)) 143 | if err != nil { 144 | panic(err) 145 | } 146 | writeInt32(n.level, z) 147 | 148 | l = len(n.friends) 149 | writeInt32(l, z) 150 | for _, f := range n.friends { 151 | l := len(f) 152 | writeInt32(l, z) 153 | err = binary.Write(z, binary.LittleEndian, f) 154 | if err != nil { 155 | panic(err) 156 | } 157 | } 158 | } 159 | 160 | z.Close() 161 | f.Close() 162 | 163 | return nil 164 | } 165 | 166 | func writeInt64(v int64, w io.Writer) { 167 | err := binary.Write(w, binary.LittleEndian, &v) 168 | if err != nil { 169 | panic(err) 170 | } 171 | } 172 | 173 | func writeInt32(v int, w io.Writer) { 174 | i := int32(v) 175 | err := binary.Write(w, binary.LittleEndian, &i) 176 | if err != nil { 177 | panic(err) 178 | } 179 | } 180 | 181 | func readInt32(r io.Reader) int { 182 | var i int32 183 | err := binary.Read(r, binary.LittleEndian, &i) 184 | if err != nil { 185 | panic(err) 186 | } 187 | return int(i) 188 | } 189 | 190 | func writeFloat64(v float64, w io.Writer) { 191 | err := binary.Write(w, binary.LittleEndian, &v) 192 | if err != nil { 193 | panic(err) 194 | } 195 | } 196 | 197 | func readInt64(r io.Reader) (v int64) { 198 | err := binary.Read(r, binary.LittleEndian, &v) 199 | if err != nil { 200 | panic(err) 201 | } 202 | return 203 | } 204 | 205 | func 
readFloat64(r io.Reader) (v float64) { 206 | err := binary.Read(r, binary.LittleEndian, &v) 207 | if err != nil { 208 | panic(err) 209 | } 210 | return 211 | } 212 | 213 | func (h *Hnsw) getFriends(n uint32, level int) []uint32 { 214 | if len(h.nodes[n].friends) < level+1 { 215 | return make([]uint32, 0) 216 | } 217 | return h.nodes[n].friends[level] 218 | } 219 | 220 | func (h *Hnsw) Link(first, second uint32, level int) { 221 | 222 | maxL := h.M 223 | if level == 0 { 224 | maxL = h.M0 225 | } 226 | 227 | h.RLock() 228 | node := &h.nodes[first] 229 | h.RUnlock() 230 | 231 | node.Lock() 232 | 233 | // check if we have allocated friends slices up to this level? 234 | if len(node.friends) < level+1 { 235 | for j := len(node.friends); j <= level; j++ { 236 | // allocate new list with 0 elements but capacity maxL 237 | node.friends = append(node.friends, make([]uint32, 0, maxL)) 238 | } 239 | // now grow it by one and add the first connection for this layer 240 | node.friends[level] = node.friends[level][0:1] 241 | node.friends[level][0] = second 242 | 243 | } else { 244 | // we did have some already... 
this will allocate more space if it overflows maxL 245 | node.friends[level] = append(node.friends[level], second) 246 | } 247 | 248 | l := len(node.friends[level]) 249 | 250 | if l > maxL { 251 | 252 | // to many links, deal with it 253 | 254 | switch h.DelaunayType { 255 | case 0: 256 | resultSet := &distqueue.DistQueueClosestLast{Size: len(node.friends[level])} 257 | 258 | for _, n := range node.friends[level] { 259 | resultSet.Push(n, h.DistFunc(node.p, h.nodes[n].p)) 260 | } 261 | for resultSet.Len() > maxL { 262 | resultSet.Pop() 263 | } 264 | // FRIENDS ARE STORED IN DISTANCE ORDER, closest at index 0 265 | node.friends[level] = node.friends[level][0:maxL] 266 | for i := maxL - 1; i >= 0; i-- { 267 | item := resultSet.Pop() 268 | node.friends[level][i] = item.ID 269 | } 270 | 271 | case 1: 272 | 273 | resultSet := &distqueue.DistQueueClosestFirst{Size: len(node.friends[level])} 274 | 275 | for _, n := range node.friends[level] { 276 | resultSet.Push(n, h.DistFunc(node.p, h.nodes[n].p)) 277 | } 278 | h.getNeighborsByHeuristicClosestFirst(resultSet, maxL) 279 | 280 | // FRIENDS ARE STORED IN DISTANCE ORDER, closest at index 0 281 | node.friends[level] = node.friends[level][0:maxL] 282 | for i := 0; i < maxL; i++ { 283 | item := resultSet.Pop() 284 | node.friends[level][i] = item.ID 285 | } 286 | } 287 | } 288 | node.Unlock() 289 | } 290 | 291 | func (h *Hnsw) getNeighborsByHeuristicClosestLast(resultSet1 *distqueue.DistQueueClosestLast, M int) { 292 | if resultSet1.Len() <= M { 293 | return 294 | } 295 | resultSet := &distqueue.DistQueueClosestFirst{Size: resultSet1.Len()} 296 | tempList := &distqueue.DistQueueClosestFirst{Size: resultSet1.Len()} 297 | result := make([]*distqueue.Item, 0, M) 298 | for resultSet1.Len() > 0 { 299 | resultSet.PushItem(resultSet1.Pop()) 300 | } 301 | for resultSet.Len() > 0 { 302 | if len(result) >= M { 303 | break 304 | } 305 | e := resultSet.Pop() 306 | good := true 307 | for _, r := range result { 308 | if 
h.DistFunc(h.nodes[r.ID].p, h.nodes[e.ID].p) < e.D { 309 | good = false 310 | break 311 | } 312 | } 313 | if good { 314 | result = append(result, e) 315 | } else { 316 | tempList.PushItem(e) 317 | } 318 | } 319 | for len(result) < M && tempList.Len() > 0 { 320 | result = append(result, tempList.Pop()) 321 | } 322 | for _, item := range result { 323 | resultSet1.PushItem(item) 324 | } 325 | } 326 | 327 | func (h *Hnsw) getNeighborsByHeuristicClosestFirst(resultSet *distqueue.DistQueueClosestFirst, M int) { 328 | if resultSet.Len() <= M { 329 | return 330 | } 331 | tempList := &distqueue.DistQueueClosestFirst{Size: resultSet.Len()} 332 | result := make([]*distqueue.Item, 0, M) 333 | for resultSet.Len() > 0 { 334 | if len(result) >= M { 335 | break 336 | } 337 | e := resultSet.Pop() 338 | good := true 339 | for _, r := range result { 340 | if h.DistFunc(h.nodes[r.ID].p, h.nodes[e.ID].p) < e.D { 341 | good = false 342 | break 343 | } 344 | } 345 | if good { 346 | result = append(result, e) 347 | } else { 348 | tempList.PushItem(e) 349 | } 350 | } 351 | for len(result) < M && tempList.Len() > 0 { 352 | result = append(result, tempList.Pop()) 353 | } 354 | resultSet.Reset() 355 | 356 | for _, item := range result { 357 | resultSet.PushItem(item) 358 | } 359 | } 360 | 361 | func New(M int, efConstruction int, first Point) *Hnsw { 362 | 363 | h := Hnsw{} 364 | h.M = M 365 | // default values used in c++ implementation 366 | h.LevelMult = 1 / math.Log(float64(M)) 367 | h.efConstruction = efConstruction 368 | h.M0 = 2 * M 369 | h.DelaunayType = 1 370 | 371 | h.bitset = bitsetpool.New() 372 | 373 | h.DistFunc = f32.L2Squared8AVX 374 | 375 | // add first point, it will be our enterpoint (index 0) 376 | h.nodes = []node{node{level: 0, p: first}} 377 | 378 | return &h 379 | } 380 | 381 | func (h *Hnsw) Stats() string { 382 | s := "HNSW Index\n" 383 | s = s + fmt.Sprintf("M: %v, efConstruction: %v\n", h.M, h.efConstruction) 384 | s = s + fmt.Sprintf("DelaunayType: %v\n", 
h.DelaunayType) 385 | s = s + fmt.Sprintf("Number of nodes: %v\n", len(h.nodes)) 386 | s = s + fmt.Sprintf("Max layer: %v\n", h.maxLayer) 387 | memoryUseData := 0 388 | memoryUseIndex := 0 389 | levCount := make([]int, h.maxLayer+1) 390 | conns := make([]int, h.maxLayer+1) 391 | connsC := make([]int, h.maxLayer+1) 392 | for i := range h.nodes { 393 | levCount[h.nodes[i].level]++ 394 | for j := 0; j <= h.nodes[i].level; j++ { 395 | if len(h.nodes[i].friends) > j { 396 | l := len(h.nodes[i].friends[j]) 397 | conns[j] += l 398 | connsC[j]++ 399 | } 400 | } 401 | memoryUseData += h.nodes[i].p.Size() 402 | memoryUseIndex += h.nodes[i].level*h.M*4 + h.M0*4 403 | } 404 | for i := range levCount { 405 | avg := conns[i] / max(1, connsC[i]) 406 | s = s + fmt.Sprintf("Level %v: %v nodes, average number of connections %v\n", i, levCount[i], avg) 407 | } 408 | s = s + fmt.Sprintf("Memory use for data: %v (%v bytes / point)\n", memoryUseData, memoryUseData/len(h.nodes)) 409 | s = s + fmt.Sprintf("Memory use for index: %v (avg %v bytes / point)\n", memoryUseIndex, memoryUseIndex/len(h.nodes)) 410 | return s 411 | } 412 | 413 | func (h *Hnsw) Grow(size int) { 414 | if size+1 <= len(h.nodes) { 415 | return 416 | } 417 | newNodes := make([]node, len(h.nodes), size+1) 418 | copy(newNodes, h.nodes) 419 | h.nodes = newNodes 420 | 421 | } 422 | 423 | func (h *Hnsw) Add(q Point, id uint32) { 424 | 425 | if id == 0 { 426 | panic("Id 0 is reserved, use ID:s starting from 1 when building index") 427 | } 428 | 429 | // generate random level 430 | curlevel := int(math.Floor(-math.Log(rand.Float64() * h.LevelMult))) 431 | 432 | epID := h.enterpoint 433 | currentMaxLayer := h.nodes[epID].level 434 | ep := &distqueue.Item{ID: h.enterpoint, D: h.DistFunc(h.nodes[h.enterpoint].p, q)} 435 | 436 | // assume Grow has been called in advance 437 | newID := id 438 | newNode := node{p: q, level: curlevel, friends: make([][]uint32, min(curlevel, currentMaxLayer)+1)} 439 | 440 | // first pass, find another 
ep if curlevel < maxLayer 441 | for level := currentMaxLayer; level > curlevel; level-- { 442 | changed := true 443 | for changed { 444 | changed = false 445 | for _, i := range h.getFriends(ep.ID, level) { 446 | d := h.DistFunc(h.nodes[i].p, q) 447 | if d < ep.D { 448 | ep = &distqueue.Item{ID: i, D: d} 449 | changed = true 450 | } 451 | } 452 | } 453 | } 454 | 455 | // second pass, ef = efConstruction 456 | // loop through every level from the new nodes level down to level 0 457 | // create new connections in every layer 458 | for level := min(curlevel, currentMaxLayer); level >= 0; level-- { 459 | 460 | resultSet := &distqueue.DistQueueClosestLast{} 461 | h.searchAtLayer(q, resultSet, h.efConstruction, ep, level) 462 | switch h.DelaunayType { 463 | case 0: 464 | // shrink resultSet to M closest elements (the simple heuristic) 465 | for resultSet.Len() > h.M { 466 | resultSet.Pop() 467 | } 468 | case 1: 469 | h.getNeighborsByHeuristicClosestLast(resultSet, h.M) 470 | } 471 | newNode.friends[level] = make([]uint32, resultSet.Len()) 472 | for i := resultSet.Len() - 1; i >= 0; i-- { 473 | item := resultSet.Pop() 474 | // store in order, closest at index 0 475 | newNode.friends[level][i] = item.ID 476 | } 477 | } 478 | 479 | h.Lock() 480 | // Add it and increase slice length if neccessary 481 | if len(h.nodes) < int(newID)+1 { 482 | h.nodes = h.nodes[0 : newID+1] 483 | } 484 | h.nodes[newID] = newNode 485 | h.Unlock() 486 | 487 | // now add connections to newNode from newNodes neighbours (makes it visible in the graph) 488 | for level := min(curlevel, currentMaxLayer); level >= 0; level-- { 489 | for _, n := range newNode.friends[level] { 490 | h.Link(n, newID, level) 491 | } 492 | } 493 | 494 | h.Lock() 495 | if curlevel > h.maxLayer { 496 | h.maxLayer = curlevel 497 | h.enterpoint = newID 498 | } 499 | h.Unlock() 500 | } 501 | 502 | func (h *Hnsw) searchAtLayer(q Point, resultSet *distqueue.DistQueueClosestLast, efConstruction int, ep *distqueue.Item, level int) { 
503 | 504 | var pool, visited = h.bitset.Get() 505 | //visited := make(map[uint32]bool) 506 | 507 | candidates := &distqueue.DistQueueClosestFirst{Size: efConstruction * 3} 508 | 509 | visited.Set(uint(ep.ID)) 510 | //visited[ep.ID] = true 511 | candidates.Push(ep.ID, ep.D) 512 | 513 | resultSet.Push(ep.ID, ep.D) 514 | 515 | for candidates.Len() > 0 { 516 | _, lowerBound := resultSet.Top() // worst distance so far 517 | c := candidates.Pop() 518 | 519 | if c.D > lowerBound { 520 | // since candidates is sorted, it wont get any better... 521 | break 522 | } 523 | 524 | if len(h.nodes[c.ID].friends) >= level+1 { 525 | friends := h.nodes[c.ID].friends[level] 526 | for _, n := range friends { 527 | if !visited.Test(uint(n)) { 528 | visited.Set(uint(n)) 529 | d := h.DistFunc(q, h.nodes[n].p) 530 | _, topD := resultSet.Top() 531 | if resultSet.Len() < efConstruction { 532 | item := resultSet.Push(n, d) 533 | candidates.PushItem(item) 534 | } else if topD > d { 535 | // keep length of resultSet to max efConstruction 536 | item := resultSet.PopAndPush(n, d) 537 | candidates.PushItem(item) 538 | } 539 | } 540 | } 541 | } 542 | } 543 | h.bitset.Free(pool) 544 | } 545 | 546 | // SearchBrute returns the true K nearest neigbours to search point q 547 | func (h *Hnsw) SearchBrute(q Point, K int) *distqueue.DistQueueClosestLast { 548 | resultSet := &distqueue.DistQueueClosestLast{Size: K} 549 | for i := 1; i < len(h.nodes); i++ { 550 | d := h.DistFunc(h.nodes[i].p, q) 551 | if resultSet.Len() < K { 552 | resultSet.Push(uint32(i), d) 553 | continue 554 | } 555 | _, topD := resultSet.Head() 556 | if d < topD { 557 | resultSet.PopAndPush(uint32(i), d) 558 | continue 559 | } 560 | } 561 | return resultSet 562 | } 563 | 564 | // Benchmark test precision by comparing the results of SearchBrute and Search 565 | func (h *Hnsw) Benchmark(q Point, ef int, K int) float64 { 566 | result := h.Search(q, ef, K) 567 | groundTruth := h.SearchBrute(q, K) 568 | truth := make([]uint32, 0) 569 | for 
groundTruth.Len() > 0 { 570 | truth = append(truth, groundTruth.Pop().ID) 571 | } 572 | p := 0 573 | for result.Len() > 0 { 574 | i := result.Pop() 575 | for j := 0; j < K; j++ { 576 | if truth[j] == i.ID { 577 | p++ 578 | } 579 | } 580 | } 581 | return float64(p) / float64(K) 582 | } 583 | 584 | func (h *Hnsw) Search(q Point, ef int, K int) *distqueue.DistQueueClosestLast { 585 | 586 | h.RLock() 587 | currentMaxLayer := h.maxLayer 588 | ep := &distqueue.Item{ID: h.enterpoint, D: h.DistFunc(h.nodes[h.enterpoint].p, q)} 589 | h.RUnlock() 590 | 591 | resultSet := &distqueue.DistQueueClosestLast{Size: ef + 1} 592 | // first pass, find best ep 593 | for level := currentMaxLayer; level > 0; level-- { 594 | changed := true 595 | for changed { 596 | changed = false 597 | for _, i := range h.getFriends(ep.ID, level) { 598 | d := h.DistFunc(h.nodes[i].p, q) 599 | if d < ep.D { 600 | ep.ID, ep.D = i, d 601 | changed = true 602 | } 603 | } 604 | } 605 | } 606 | h.searchAtLayer(q, resultSet, ef, ep, 0) 607 | 608 | for resultSet.Len() > K { 609 | resultSet.Pop() 610 | } 611 | return resultSet 612 | } 613 | 614 | func min(a, b int) int { 615 | if a < b { 616 | return a 617 | } 618 | return b 619 | } 620 | 621 | func max(a, b int) int { 622 | if a > b { 623 | return a 624 | } 625 | return b 626 | } 627 | -------------------------------------------------------------------------------- /hnsw_test.go: -------------------------------------------------------------------------------- 1 | package hnsw 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "math" 7 | "os" 8 | "runtime" 9 | "sync" 10 | "sync/atomic" 11 | "testing" 12 | "time" 13 | 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | var prefix = "siftsmall/siftsmall" 18 | var dataSize = 10000 19 | var efSearch = []int{1, 2, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 300, 400} 20 | var queries []Point 21 | var truth [][]uint32 22 | 23 | func TestMain(m *testing.M) { 24 | // LOAD QUERIES AND GROUNDTRUTH 25 | 
fmt.Printf("Loading query records\n") 26 | queries, truth = loadQueriesFromFvec(prefix) 27 | os.Exit(m.Run()) 28 | } 29 | func TestSaveLoad(t *testing.T) { 30 | h := buildIndex() 31 | testSearch(h) 32 | 33 | fmt.Printf("Saving to index.dat\n") 34 | err := h.Save("index.dat") 35 | assert.Nil(t, err) 36 | 37 | fmt.Printf("Loading from index.dat\n") 38 | h2, timestamp, err := Load("index.dat") 39 | assert.Nil(t, err) 40 | 41 | fmt.Printf("Index loaded, time saved was %v", time.Unix(timestamp, 0)) 42 | 43 | fmt.Printf(h2.Stats()) 44 | testSearch(h2) 45 | } 46 | 47 | func TestSIFT(t *testing.T) { 48 | h := buildIndex() 49 | testSearch(h) 50 | } 51 | 52 | func buildIndex() *Hnsw { 53 | // BUILD INDEX 54 | var p Point 55 | p = make([]float32, 128) 56 | h := New(4, 200, p) 57 | h.DelaunayType = 1 58 | h.Grow(dataSize) 59 | 60 | buildStart := time.Now() 61 | fmt.Printf("Loading data and building index\n") 62 | points := make(chan job) 63 | go loadDataFromFvec(prefix, points) 64 | buildFromChan(h, points) 65 | buildStop := time.Since(buildStart) 66 | fmt.Printf("Index build in %v\n", buildStop) 67 | fmt.Printf(h.Stats()) 68 | 69 | return h 70 | } 71 | 72 | func testSearch(h *Hnsw) { 73 | // SEARCH 74 | for _, ef := range efSearch { 75 | fmt.Printf("Now searching with ef=%v\n", ef) 76 | bestPrecision := 0.0 77 | bestTime := 999.0 78 | for i := 0; i < 10; i++ { 79 | start := time.Now() 80 | p := search(h, queries, truth, ef) 81 | stop := time.Since(start) 82 | bestPrecision = math.Max(bestPrecision, p) 83 | bestTime = math.Min(bestTime, stop.Seconds()/float64(len(queries))) 84 | } 85 | fmt.Printf("Best Precision 10-NN: %v\n", bestPrecision) 86 | fmt.Printf("Best time: %v s (%v queries / s)\n", bestTime, 1/bestTime) 87 | } 88 | } 89 | 90 | type job struct { 91 | p Point 92 | id uint32 93 | } 94 | 95 | func buildFromChan(h *Hnsw, points chan job) { 96 | var wg sync.WaitGroup 97 | for i := 0; i < runtime.NumCPU(); i++ { 98 | wg.Add(1) 99 | go func() { 100 | for { 101 | job, more 
:= <-points 102 | if !more { 103 | wg.Done() 104 | return 105 | } 106 | h.Add(job.p, job.id) 107 | } 108 | }() 109 | } 110 | wg.Wait() 111 | } 112 | 113 | func search(h *Hnsw, queries []Point, truth [][]uint32, efSearch int) float64 { 114 | var p int32 115 | var wg sync.WaitGroup 116 | l := runtime.NumCPU() 117 | b := len(queries) / l 118 | 119 | for i := 0; i < runtime.NumCPU(); i++ { 120 | wg.Add(1) 121 | go func(queries []Point, truth [][]uint32) { 122 | for j := range queries { 123 | results := h.Search(queries[j], efSearch, 10) 124 | // calc 10-NN precision 125 | for results.Len() > 10 { 126 | results.Pop() 127 | } 128 | for _, item := range results.Items() { 129 | for k := 0; k < 10; k++ { 130 | // !!! Our index numbers starts from 1 131 | if int32(truth[j][k]) == int32(item.ID)-1 { 132 | atomic.AddInt32(&p, 1) 133 | } 134 | } 135 | } 136 | } 137 | wg.Done() 138 | }(queries[i*b:i*b+b], truth[i*b:i*b+b]) 139 | } 140 | wg.Wait() 141 | return (float64(p) / float64(10*b*l)) 142 | } 143 | 144 | func readFloat32(f *os.File) (float32, error) { 145 | bs := make([]byte, 4) 146 | _, err := f.Read(bs) 147 | return float32(math.Float32frombits(binary.LittleEndian.Uint32(bs))), err 148 | } 149 | 150 | func readUint32(f *os.File) (uint32, error) { 151 | bs := make([]byte, 4) 152 | _, err := f.Read(bs) 153 | return binary.LittleEndian.Uint32(bs), err 154 | } 155 | 156 | func loadQueriesFromFvec(prefix string) (queries []Point, truth [][]uint32) { 157 | f2, err := os.Open(prefix + "_query.fvecs") 158 | if err != nil { 159 | panic("couldn't open query data file") 160 | } 161 | defer f2.Close() 162 | queries = make([]Point, 10000) 163 | qcount := 0 164 | for { 165 | d, err := readUint32(f2) 166 | if err != nil { 167 | break 168 | } 169 | if d != 128 { 170 | panic("Wrong dimension for this test...") 171 | } 172 | queries[qcount] = make([]float32, 128) 173 | for i := 0; i < int(d); i++ { 174 | queries[qcount][i], err = readFloat32(f2) 175 | } 176 | qcount++ 177 | } 178 | queries 
= queries[0:qcount] // resize it 179 | fmt.Printf("Read %v query records\n", qcount) 180 | fmt.Printf("Loading groundtruth\n") 181 | // load query Vectors 182 | f3, err := os.Open(prefix + "_groundtruth.ivecs") 183 | if err != nil { 184 | panic("couldn't open groundtruth data file") 185 | } 186 | defer f3.Close() 187 | truth = make([][]uint32, 10000) 188 | tcount := 0 189 | for { 190 | d, err := readUint32(f3) 191 | if err != nil { 192 | break 193 | } 194 | if d != 100 { 195 | panic("Wrong dimension for this test...") 196 | } 197 | vec := make([]uint32, d) 198 | for i := 0; i < int(d); i++ { 199 | vec[i], err = readUint32(f3) 200 | } 201 | truth[tcount] = vec 202 | tcount++ 203 | } 204 | fmt.Printf("Read %v truth records\n", tcount) 205 | 206 | if tcount != qcount { 207 | panic("Count mismatch queries <-> groundtruth") 208 | } 209 | 210 | return queries, truth 211 | } 212 | 213 | func loadDataFromFvec(prefix string, points chan job) { 214 | f, err := os.Open(prefix + "_base.fvecs") 215 | if err != nil { 216 | panic("couldn't open data file") 217 | } 218 | defer f.Close() 219 | count := 1 220 | for { 221 | d, err := readUint32(f) 222 | if err != nil { 223 | break 224 | } 225 | if d != 128 { 226 | panic("Wrong dimension for this test...") 227 | } 228 | var vec Point 229 | vec = make([]float32, 128) 230 | for i := 0; i < int(d); i++ { 231 | vec[i], err = readFloat32(f) 232 | } 233 | points <- job{p: vec, id: uint32(count)} 234 | count++ 235 | if count%1000 == 0 { 236 | fmt.Printf("Read %v records\n", count) 237 | } 238 | } 239 | close(points) 240 | } 241 | --------------------------------------------------------------------------------