├── .github └── workflows │ └── go.yaml ├── .gitignore ├── LICENSE ├── README.md ├── analyzer.go ├── distance.go ├── distance_test.go ├── encode.go ├── encode_test.go ├── example └── readme │ └── main.go ├── go.mod ├── go.sum ├── graph.go ├── graph_test.go └── heap ├── heap.go └── heap_test.go /.github/workflows/go.yaml: -------------------------------------------------------------------------------- 1 | name: Go Test 2 | 3 | on: 4 | push: 5 | branches: 6 | pull_request: 7 | branches: 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v2 16 | 17 | - name: Set up Go 18 | uses: actions/setup-go@v3 19 | with: 20 | go-version: 1.22 21 | 22 | - name: Install dependencies 23 | run: go mod tidy 24 | 25 | - name: Run tests 26 | run: go test ./... 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __debug_bin** 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 
13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. 
moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). 
Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. 
Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # hnsw 2 | [![GoDoc](https://godoc.org/github.com/golang/gddo?status.svg)](https://pkg.go.dev/github.com/coder/hnsw@main?utm_source=godoc) 3 | ![Go workflow status](https://github.com/coder/hnsw/actions/workflows/go.yaml/badge.svg) 4 | 5 | 6 | 7 | Package `hnsw` implements Hierarchical Navigable Small World graphs in Go. You 8 | can read up about how they work [here](https://www.pinecone.io/learn/series/faiss/hnsw/). In essence, 9 | they allow for fast approximate nearest neighbor searches with high-dimensional 10 | vector data. 11 | 12 | This package can be thought of as an in-memory alternative to your favorite 13 | vector database (e.g. Pinecone, Weaviate). It implements just the essential 14 | operations: 15 | 16 | | Operation | Complexity | Description | 17 | | --------- | --------------------- | -------------------------------------------- | 18 | | Insert | $O(log(n))$ | Insert a vector into the graph | 19 | | Delete | $O(M^2 \cdot log(n))$ | Delete a vector from the graph | 20 | | Search | $O(log(n))$ | Search for the nearest neighbors of a vector | 21 | | Lookup | $O(1)$ | Retrieve a vector by ID | 22 | 23 | > [!NOTE] 24 | > Complexities are approximate where $n$ is the number of vectors in the graph 25 | > and $M$ is the maximum number of neighbors each node can have. This [paper](https://arxiv.org/pdf/1603.09320) is a good resource for understanding the effect of 26 | > the various construction parameters. 
27 | 28 | ## Usage 29 | 30 | ``` 31 | go get github.com/coder/hnsw@main 32 | ``` 33 | 34 | ```go 35 | g := hnsw.NewGraph[int]() 36 | g.Add( 37 | hnsw.MakeNode(1, []float32{1, 1, 1}), 38 | hnsw.MakeNode(2, []float32{1, -1, 0.999}), 39 | hnsw.MakeNode(3, []float32{1, 0, -0.5}), 40 | ) 41 | 42 | neighbors := g.Search( 43 | []float32{0.5, 0.5, 0.5}, 44 | 1, 45 | ) 46 | fmt.Printf("best friend: %v\n", neighbors[0].Vec) 47 | // Output: best friend: [1 1 1] 48 | ``` 49 | 50 | 51 | 52 | ## Persistence 53 | 54 | While all graph operations are in-memory, `hnsw` provides facilities for loading/saving from persistent storage. 55 | 56 | For an `io.Reader`/`io.Writer` interface, use `Graph.Export` and `Graph.Import`. 57 | 58 | If you're using a single file as the backend, hnsw provides a convenient `SavedGraph` type instead: 59 | 60 | ```go 61 | path := "some.graph" 62 | g1, err := LoadSavedGraph[int](path) 63 | if err != nil { 64 | panic(err) 65 | } 66 | // Insert some vectors 67 | for i := 0; i < 128; i++ { 68 | g1.Add(hnsw.MakeNode(i, []float32{float32(i)})) 69 | } 70 | 71 | // Save to disk 72 | err = g1.Save() 73 | if err != nil { 74 | panic(err) 75 | } 76 | 77 | // Later... 78 | // g2 is a copy of g1 79 | g2, err := LoadSavedGraph[int](path) 80 | if err != nil { 81 | panic(err) 82 | } 83 | ``` 84 | 85 | See more: 86 | * [Export](https://pkg.go.dev/github.com/coder/hnsw#Graph.Export) 87 | * [Import](https://pkg.go.dev/github.com/coder/hnsw#Graph.Import) 88 | * [SavedGraph](https://pkg.go.dev/github.com/coder/hnsw#SavedGraph) 89 | 90 | We use a fast binary encoding for the graph, so you can expect to save/load 91 | nearly at disk speed. 
On my M3 Macbook I get these benchmark results: 92 | 93 | ``` 94 | goos: darwin 95 | goarch: arm64 96 | pkg: github.com/coder/hnsw 97 | BenchmarkGraph_Import-16 4029 259927 ns/op 796.85 MB/s 496022 B/op 3212 allocs/op 98 | BenchmarkGraph_Export-16 7042 168028 ns/op 1232.49 MB/s 239886 B/op 2388 allocs/op 99 | PASS 100 | ok github.com/coder/hnsw 2.624s 101 | ``` 102 | 103 | when saving/loading a graph of 100 vectors with 256 dimensions. 104 | 105 | ## Performance 106 | 107 | By and large the greatest effect you can have on the performance of the graph 108 | is reducing the dimensionality of your data. At 1536 dimensions (OpenAI default), 109 | 70% of the query process under default parameters is spent in the distance function. 110 | 111 | If you're struggling with slowness / latency, consider: 112 | * Reducing dimensionality 113 | * Increasing $M$ 114 | 115 | And, if you're struggling with excess memory usage, consider: 116 | * Reducing $M$ a.k.a `Graph.M` (the maximum number of neighbors each node can have) 117 | * Reducing $m_L$ a.k.a `Graph.Ml` (the level generation parameter) 118 | 119 | ## Memory Overhead 120 | 121 | The memory overhead of a graph looks like: 122 | 123 | $$ 124 | \displaylines{ 125 | mem_{graph} = n \cdot \log(n) \cdot \text{size(key)} \cdot M \\ 126 | mem_{base} = n \cdot d \cdot 4 \\ 127 | mem_{total} = mem_{graph} + mem_{base} 128 | } 129 | $$ 130 | 131 | where: 132 | * $n$ is the number of vectors in the graph 133 | * $\text{size(key)}$ is the average size of the key in bytes 134 | * $M$ is the maximum number of neighbors each node can have 135 | * $d$ is the dimensionality of the vectors 136 | * $mem_{graph}$ is the memory used by the graph structure across all layers 137 | * $mem_{base}$ is the memory used by the vectors themselves in the base or 0th layer 138 | 139 | You can infer that: 140 | * Connectivity ($M$) is very expensive if keys are large 141 | * If $d \cdot 4$ is far larger than $M \cdot \text{size(key)}$, you should expect
linear memory usage spent on representing vector data 142 | * If $d \cdot 4$ is far smaller than $M \cdot \text{size(key)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure 143 | 144 | In the example of a graph with 256 dimensions, and $M = 16$, with 8 byte keys, you would see that each vector takes: 145 | 146 | * $256 \cdot 4 = 1024$ data bytes 147 | * $16 \cdot 8 = 128$ metadata bytes 148 | 149 | and memory growth is mostly linear. 150 | -------------------------------------------------------------------------------- /analyzer.go: -------------------------------------------------------------------------------- 1 | package hnsw 2 | 3 | import "cmp" 4 | 5 | // Analyzer is a struct that holds a graph and provides 6 | // methods for analyzing it. It offers no compatibility guarantee 7 | // as the methods of measuring the graph's health will change 8 | // with the implementation. 9 | type Analyzer[K cmp.Ordered] struct { 10 | Graph *Graph[K] 11 | } 12 | 13 | func (a *Analyzer[T]) Height() int { 14 | return len(a.Graph.layers) 15 | } 16 | 17 | // Connectivity returns the average number of edges in the 18 | // graph for each non-empty layer. 19 | func (a *Analyzer[T]) Connectivity() []float64 { 20 | var layerConnectivity []float64 21 | for _, layer := range a.Graph.layers { 22 | if len(layer.nodes) == 0 { 23 | continue 24 | } 25 | 26 | var sum float64 27 | for _, node := range layer.nodes { 28 | sum += float64(len(node.neighbors)) 29 | } 30 | 31 | layerConnectivity = append(layerConnectivity, sum/float64(len(layer.nodes))) 32 | } 33 | 34 | return layerConnectivity 35 | } 36 | 37 | // Topography returns the number of nodes in each layer of the graph.
38 | func (a *Analyzer[T]) Topography() []int { 39 | var topography []int 40 | for _, layer := range a.Graph.layers { 41 | topography = append(topography, len(layer.nodes)) 42 | } 43 | return topography 44 | } 45 | -------------------------------------------------------------------------------- /distance.go: -------------------------------------------------------------------------------- 1 | package hnsw 2 | 3 | import ( 4 | "math" 5 | "reflect" 6 | 7 | "github.com/viterin/vek/vek32" 8 | ) 9 | 10 | // DistanceFunc is a function that computes the distance between two vectors. 11 | type DistanceFunc func(a, b []float32) float32 12 | 13 | // CosineDistance computes the cosine distance between two vectors. 14 | func CosineDistance(a, b []float32) float32 { 15 | return 1 - vek32.CosineSimilarity(a, b) 16 | } 17 | 18 | // EuclideanDistance computes the Euclidean distance between two vectors. 19 | func EuclideanDistance(a, b []float32) float32 { 20 | // TODO: can we speedup with vek? 21 | var sum float32 = 0 22 | for i := range a { 23 | diff := a[i] - b[i] 24 | sum += diff * diff 25 | } 26 | return float32(math.Sqrt(float64(sum))) 27 | } 28 | 29 | var distanceFuncs = map[string]DistanceFunc{ 30 | "euclidean": EuclideanDistance, 31 | "cosine": CosineDistance, 32 | } 33 | 34 | func distanceFuncToName(fn DistanceFunc) (string, bool) { 35 | for name, f := range distanceFuncs { 36 | fnptr := reflect.ValueOf(fn).Pointer() 37 | fptr := reflect.ValueOf(f).Pointer() 38 | if fptr == fnptr { 39 | return name, true 40 | } 41 | } 42 | return "", false 43 | } 44 | 45 | // RegisterDistanceFunc registers a distance function with a name. 46 | // A distance function must be registered here before a graph can be 47 | // exported and imported. 
48 | func RegisterDistanceFunc(name string, fn DistanceFunc) { 49 | distanceFuncs[name] = fn 50 | } 51 | -------------------------------------------------------------------------------- /distance_test.go: -------------------------------------------------------------------------------- 1 | package hnsw 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestEuclideanDistance(t *testing.T) { 10 | a := []float32{1, 2, 3} 11 | b := []float32{4, 5, 6} 12 | require.Equal(t, float32(5.196152), EuclideanDistance(a, b)) 13 | } 14 | 15 | func TestCosineSimilarity(t *testing.T) { 16 | var a, b []float32 17 | // Same magnitude, same direction. 18 | a = []float32{1, 1, 1} 19 | b = []float32{0.8, 0.8, 0.8} 20 | require.InDelta(t, 0, CosineDistance(a, b), 0.000001) 21 | 22 | // Perpendicular vectors. 23 | a = []float32{1, 0} 24 | b = []float32{0, 1} 25 | require.InDelta(t, 1, CosineDistance(a, b), 0.000001) 26 | 27 | // Equivalent vectors. 28 | a = []float32{1, 0} 29 | b = []float32{1, 0} 30 | require.InDelta(t, 0, CosineDistance(a, b), 0.000001) 31 | } 32 | 33 | func BenchmarkCosineSimilarity(b *testing.B) { 34 | v1 := randFloats(1536) 35 | v2 := randFloats(1536) 36 | b.ResetTimer() 37 | for i := 0; i < b.N; i++ { 38 | CosineDistance(v1, v2) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /encode.go: -------------------------------------------------------------------------------- 1 | package hnsw 2 | 3 | import ( 4 | "bufio" 5 | "cmp" 6 | "encoding/binary" 7 | "fmt" 8 | "io" 9 | "os" 10 | 11 | "github.com/google/renameio" 12 | ) 13 | 14 | // errorEncoder is a helper type to encode multiple values 15 | 16 | var byteOrder = binary.LittleEndian 17 | 18 | func binaryRead(r io.Reader, data interface{}) (int, error) { 19 | switch v := data.(type) { 20 | case *int: 21 | br, ok := r.(io.ByteReader) 22 | if !ok { 23 | return 0, fmt.Errorf("reader does not implement io.ByteReader") 24 | } 25 | 
26 | i, err := binary.ReadVarint(br) 27 | if err != nil { 28 | return 0, err 29 | } 30 | 31 | *v = int(i) 32 | // TODO: this will usually overshoot size. 33 | return binary.MaxVarintLen64, nil 34 | 35 | case *string: 36 | var ln int 37 | _, err := binaryRead(r, &ln) 38 | if err != nil { 39 | return 0, err 40 | } 41 | 42 | s := make([]byte, ln) 43 | _, err = binaryRead(r, &s) 44 | *v = string(s) 45 | return len(s), err 46 | 47 | case *[]float32: 48 | var ln int 49 | _, err := binaryRead(r, &ln) 50 | if err != nil { 51 | return 0, err 52 | } 53 | 54 | *v = make([]float32, ln) 55 | return binary.Size(*v), binary.Read(r, byteOrder, *v) 56 | 57 | case io.ReaderFrom: 58 | n, err := v.ReadFrom(r) 59 | return int(n), err 60 | 61 | default: 62 | return binary.Size(data), binary.Read(r, byteOrder, data) 63 | } 64 | } 65 | 66 | func binaryWrite(w io.Writer, data any) (int, error) { 67 | switch v := data.(type) { 68 | case int: 69 | var buf [binary.MaxVarintLen64]byte 70 | n := binary.PutVarint(buf[:], int64(v)) 71 | n, err := w.Write(buf[:n]) 72 | return n, err 73 | case io.WriterTo: 74 | n, err := v.WriteTo(w) 75 | return int(n), err 76 | case string: 77 | n, err := binaryWrite(w, len(v)) 78 | if err != nil { 79 | return n, err 80 | } 81 | n2, err := io.WriteString(w, v) 82 | if err != nil { 83 | return n + n2, err 84 | } 85 | 86 | return n + n2, nil 87 | case []float32: 88 | n, err := binaryWrite(w, len(v)) 89 | if err != nil { 90 | return n, err 91 | } 92 | return n + binary.Size(v), binary.Write(w, byteOrder, v) 93 | 94 | default: 95 | sz := binary.Size(data) 96 | err := binary.Write(w, byteOrder, data) 97 | if err != nil { 98 | return 0, fmt.Errorf("encoding %T: %w", data, err) 99 | } 100 | return sz, err 101 | } 102 | } 103 | 104 | func multiBinaryWrite(w io.Writer, data ...any) (int, error) { 105 | var written int 106 | for _, d := range data { 107 | n, err := binaryWrite(w, d) 108 | written += n 109 | if err != nil { 110 | return written, err 111 | } 112 | } 113 | 
return written, nil 114 | } 115 | 116 | func multiBinaryRead(r io.Reader, data ...any) (int, error) { 117 | var read int 118 | for i, d := range data { 119 | n, err := binaryRead(r, d) 120 | read += n 121 | if err != nil { 122 | return read, fmt.Errorf("reading %T at index %v: %w", d, i, err) 123 | } 124 | } 125 | return read, nil 126 | } 127 | 128 | const encodingVersion = 1 129 | 130 | // Export writes the graph to a writer. 131 | // 132 | // T must implement io.WriterTo. 133 | func (h *Graph[K]) Export(w io.Writer) error { 134 | distFuncName, ok := distanceFuncToName(h.Distance) 135 | if !ok { 136 | return fmt.Errorf("distance function %v must be registered with RegisterDistanceFunc", h.Distance) 137 | } 138 | _, err := multiBinaryWrite( 139 | w, 140 | encodingVersion, 141 | h.M, 142 | h.Ml, 143 | h.EfSearch, 144 | distFuncName, 145 | ) 146 | if err != nil { 147 | return fmt.Errorf("encode parameters: %w", err) 148 | } 149 | _, err = binaryWrite(w, len(h.layers)) 150 | if err != nil { 151 | return fmt.Errorf("encode number of layers: %w", err) 152 | } 153 | for _, layer := range h.layers { 154 | _, err = binaryWrite(w, len(layer.nodes)) 155 | if err != nil { 156 | return fmt.Errorf("encode number of nodes: %w", err) 157 | } 158 | for _, node := range layer.nodes { 159 | _, err = multiBinaryWrite(w, node.Key, node.Value, len(node.neighbors)) 160 | if err != nil { 161 | return fmt.Errorf("encode node data: %w", err) 162 | } 163 | 164 | for neighbor := range node.neighbors { 165 | _, err = binaryWrite(w, neighbor) 166 | if err != nil { 167 | return fmt.Errorf("encode neighbor %v: %w", neighbor, err) 168 | } 169 | } 170 | } 171 | } 172 | 173 | return nil 174 | } 175 | 176 | // Import reads the graph from a reader. 177 | // T must implement io.ReaderFrom. 178 | // The imported graph does not have to match the exported graph's parameters (except for 179 | // dimensionality). The graph will converge onto the new parameters. 
180 | func (h *Graph[K]) Import(r io.Reader) error { 181 | var ( 182 | version int 183 | dist string 184 | ) 185 | _, err := multiBinaryRead(r, &version, &h.M, &h.Ml, &h.EfSearch, 186 | &dist, 187 | ) 188 | if err != nil { 189 | return err 190 | } 191 | 192 | var ok bool 193 | h.Distance, ok = distanceFuncs[dist] 194 | if !ok { 195 | return fmt.Errorf("unknown distance function %q", dist) 196 | } 197 | if h.Rng == nil { 198 | h.Rng = defaultRand() 199 | } 200 | 201 | if version != encodingVersion { 202 | return fmt.Errorf("incompatible encoding version: %d", version) 203 | } 204 | 205 | var nLayers int 206 | _, err = binaryRead(r, &nLayers) 207 | if err != nil { 208 | return err 209 | } 210 | 211 | h.layers = make([]*layer[K], nLayers) 212 | for i := 0; i < nLayers; i++ { 213 | var nNodes int 214 | _, err = binaryRead(r, &nNodes) 215 | if err != nil { 216 | return err 217 | } 218 | 219 | nodes := make(map[K]*layerNode[K], nNodes) 220 | for j := 0; j < nNodes; j++ { 221 | var key K 222 | var vec Vector 223 | var nNeighbors int 224 | _, err = multiBinaryRead(r, &key, &vec, &nNeighbors) 225 | if err != nil { 226 | return fmt.Errorf("decoding node %d: %w", j, err) 227 | } 228 | 229 | neighbors := make([]K, nNeighbors) 230 | for k := 0; k < nNeighbors; k++ { 231 | var neighbor K 232 | _, err = binaryRead(r, &neighbor) 233 | if err != nil { 234 | return fmt.Errorf("decoding neighbor %d for node %d: %w", k, j, err) 235 | } 236 | neighbors[k] = neighbor 237 | } 238 | 239 | node := &layerNode[K]{ 240 | Node: Node[K]{ 241 | Key: key, 242 | Value: vec, 243 | }, 244 | neighbors: make(map[K]*layerNode[K]), 245 | } 246 | 247 | nodes[key] = node 248 | for _, neighbor := range neighbors { 249 | node.neighbors[neighbor] = nil 250 | } 251 | } 252 | // Fill in neighbor pointers 253 | for _, node := range nodes { 254 | for key := range node.neighbors { 255 | node.neighbors[key] = nodes[key] 256 | } 257 | } 258 | h.layers[i] = &layer[K]{nodes: nodes} 259 | } 260 | 261 | return nil 262 
| } 263 | 264 | // SavedGraph is a wrapper around a graph that persists 265 | // changes to a file upon calls to Save. It is more convenient 266 | // but less powerful than calling Graph.Export and Graph.Import 267 | // directly. 268 | type SavedGraph[K cmp.Ordered] struct { 269 | *Graph[K] 270 | Path string 271 | } 272 | 273 | // LoadSavedGraph opens a graph from a file, reads it, and returns it. 274 | // 275 | // If the file does not exist (i.e. this is a new graph), 276 | // the equivalent of NewGraph is returned. 277 | // 278 | // It does not hold open a file descriptor, so SavedGraph can be forgotten 279 | // without ever calling Save. 280 | func LoadSavedGraph[K cmp.Ordered](path string) (*SavedGraph[K], error) { 281 | f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) 282 | if err != nil { 283 | return nil, err 284 | } 285 | defer f.Close() 286 | info, err := f.Stat() 287 | if err != nil { 288 | return nil, err 289 | } 290 | 291 | g := NewGraph[K]() 292 | if info.Size() > 0 { 293 | err = g.Import(bufio.NewReader(f)) 294 | if err != nil { 295 | return nil, fmt.Errorf("import: %w", err) 296 | } 297 | } 298 | 299 | return &SavedGraph[K]{Graph: g, Path: path}, nil 300 | } 301 | 302 | // Save writes the graph to the file. 
303 | func (g *SavedGraph[K]) Save() error { 304 | tmp, err := renameio.TempFile("", g.Path) 305 | if err != nil { 306 | return err 307 | } 308 | defer tmp.Cleanup() 309 | 310 | wr := bufio.NewWriter(tmp) 311 | err = g.Export(wr) 312 | if err != nil { 313 | return fmt.Errorf("exporting: %w", err) 314 | } 315 | 316 | err = wr.Flush() 317 | if err != nil { 318 | return fmt.Errorf("flushing: %w", err) 319 | } 320 | 321 | err = tmp.CloseAtomicallyReplace() 322 | if err != nil { 323 | return fmt.Errorf("closing atomically: %w", err) 324 | } 325 | 326 | return nil 327 | } 328 | -------------------------------------------------------------------------------- /encode_test.go: -------------------------------------------------------------------------------- 1 | package hnsw 2 | 3 | import ( 4 | "bytes" 5 | "cmp" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func Test_binaryVarint(t *testing.T) { 12 | buf := bytes.NewBuffer(nil) 13 | i := 1337 14 | 15 | n, err := binaryWrite(buf, i) 16 | require.NoError(t, err) 17 | require.Equal(t, 2, n) 18 | 19 | // Ensure that binaryRead doesn't read past the 20 | // varint. 21 | buf.Write([]byte{0, 0, 0, 0}) 22 | 23 | var j int 24 | _, err = binaryRead(buf, &j) 25 | require.NoError(t, err) 26 | require.Equal(t, 1337, j) 27 | 28 | require.Equal( 29 | t, 30 | []byte{0, 0, 0, 0}, 31 | buf.Bytes(), 32 | ) 33 | } 34 | 35 | func Test_binaryWrite_string(t *testing.T) { 36 | buf := bytes.NewBuffer(nil) 37 | s := "hello" 38 | 39 | n, err := binaryWrite(buf, s) 40 | require.NoError(t, err) 41 | // 5 bytes for the string, 1 byte for the length. 
42 | require.Equal(t, 5+1, n) 43 | 44 | var s2 string 45 | _, err = binaryRead(buf, &s2) 46 | require.NoError(t, err) 47 | require.Equal(t, "hello", s2) 48 | 49 | require.Empty(t, buf.Bytes()) 50 | } 51 | 52 | func verifyGraphNodes[K cmp.Ordered](t *testing.T, g *Graph[K]) { 53 | for _, layer := range g.layers { 54 | for _, node := range layer.nodes { 55 | for neighborKey, neighbor := range node.neighbors { 56 | _, ok := layer.nodes[neighbor.Key] 57 | if !ok { 58 | t.Errorf( 59 | "node %v has neighbor %v, but neighbor does not exist", 60 | node.Key, neighbor.Key, 61 | ) 62 | } 63 | 64 | if neighborKey != neighbor.Key { 65 | t.Errorf("node %v has neighbor %v, but neighbor key is %v", node.Key, 66 | neighbor.Key, 67 | neighborKey, 68 | ) 69 | } 70 | } 71 | } 72 | } 73 | } 74 | 75 | // requireGraphApproxEquals checks that two graphs are equal. 76 | func requireGraphApproxEquals[K cmp.Ordered](t *testing.T, g1, g2 *Graph[K]) { 77 | require.Equal(t, g1.Len(), g2.Len()) 78 | a1 := Analyzer[K]{g1} 79 | a2 := Analyzer[K]{g2} 80 | 81 | require.Equal( 82 | t, 83 | a1.Topography(), 84 | a2.Topography(), 85 | ) 86 | 87 | require.Equal( 88 | t, 89 | a1.Connectivity(), 90 | a2.Connectivity(), 91 | ) 92 | 93 | require.NotNil(t, g1.Distance) 94 | require.NotNil(t, g2.Distance) 95 | require.Equal( 96 | t, 97 | g1.Distance([]float32{0.5}, []float32{1}), 98 | g2.Distance([]float32{0.5}, []float32{1}), 99 | ) 100 | 101 | require.Equal(t, 102 | g1.M, 103 | g2.M, 104 | ) 105 | 106 | require.Equal(t, 107 | g1.Ml, 108 | g2.Ml, 109 | ) 110 | 111 | require.Equal(t, 112 | g1.EfSearch, 113 | g2.EfSearch, 114 | ) 115 | 116 | require.NotNil(t, g1.Rng) 117 | require.NotNil(t, g2.Rng) 118 | } 119 | 120 | func TestGraph_ExportImport(t *testing.T) { 121 | g1 := newTestGraph[int]() 122 | for i := 0; i < 128; i++ { 123 | g1.Add( 124 | Node[int]{ 125 | i, randFloats(1), 126 | }, 127 | ) 128 | } 129 | 130 | buf := &bytes.Buffer{} 131 | err := g1.Export(buf) 132 | require.NoError(t, err) 133 | 134 | // 
Don't use newTestGraph to ensure parameters
// are imported.
	g2 := &Graph[int]{}
	err = g2.Import(buf)
	require.NoError(t, err)

	requireGraphApproxEquals(t, g1, g2)

	n1 := g1.Search(
		[]float32{0.5},
		10,
	)

	n2 := g2.Search(
		[]float32{0.5},
		10,
	)

	require.Equal(t, n1, n2)

	verifyGraphNodes(t, g1)
	verifyGraphNodes(t, g2)
}

// TestSavedGraph round-trips a graph through LoadSavedGraph + Save in a
// temp directory: a fresh path yields an empty graph, and a saved graph
// loads back approximately equal to the one written.
func TestSavedGraph(t *testing.T) {
	dir := t.TempDir()

	g1, err := LoadSavedGraph[int](dir + "/graph")
	require.NoError(t, err)
	require.Equal(t, 0, g1.Len())
	for i := 0; i < 128; i++ {
		g1.Add(
			Node[int]{
				i, randFloats(1),
			},
		)
	}

	err = g1.Save()
	require.NoError(t, err)

	g2, err := LoadSavedGraph[int](dir + "/graph")
	require.NoError(t, err)

	requireGraphApproxEquals(t, g1.Graph, g2.Graph)
}

// benchGraphSize is the node count used by the Import/Export benchmarks.
const benchGraphSize = 100

// BenchmarkGraph_Import measures deserialization throughput. The timer is
// stopped around per-iteration setup (fresh reader and graph) so only
// Import itself is measured.
func BenchmarkGraph_Import(b *testing.B) {
	b.ReportAllocs()
	g := newTestGraph[int]()
	for i := 0; i < benchGraphSize; i++ {
		g.Add(
			Node[int]{
				i, randFloats(256),
			},
		)
	}

	buf := &bytes.Buffer{}
	err := g.Export(buf)
	require.NoError(b, err)

	b.ResetTimer()
	b.SetBytes(int64(buf.Len()))

	for i := 0; i < b.N; i++ {
		b.StopTimer()
		rdr := bytes.NewReader(buf.Bytes())
		g := newTestGraph[int]()
		b.StartTimer()
		err = g.Import(rdr)
		require.NoError(b, err)
	}
}

// BenchmarkGraph_Export measures serialization throughput, reusing one
// buffer across iterations; bytes/op is set from the first iteration's
// encoded size.
func BenchmarkGraph_Export(b *testing.B) {
	b.ReportAllocs()
	g := newTestGraph[int]()
	for i := 0; i < benchGraphSize; i++ {
		g.Add(
			Node[int]{
				i, randFloats(256),
			},
		)
	}

	var buf bytes.Buffer
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		err := g.Export(&buf)
		require.NoError(b, err)
		if i == 0 {
			ln := buf.Len()
			b.SetBytes(int64(ln))
		}
		buf.Reset()
	}
}

--------------------------------------------------------------------------------
/example/readme/main.go:
--------------------------------------------------------------------------------
package main

import (
	"fmt"

	"github.com/coder/hnsw"
)

func main() {
	g := hnsw.NewGraph[int]()
	g.Add(
		hnsw.MakeNode(1, []float32{1, 1, 1}),
		hnsw.MakeNode(2, []float32{1, -1, 0.999}),
		hnsw.MakeNode(3, []float32{1, 0, -0.5}),
	)

	neighbors := g.Search(
		[]float32{0.5, 0.5, 0.5},
		1,
	)
	fmt.Printf("best friend: %v\n", neighbors[0].Value)
}

--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
module github.com/coder/hnsw

go 1.21.4

require github.com/stretchr/testify v1.9.0

require github.com/google/renameio v1.0.1

require (
	github.com/chewxy/math32 v1.10.1 // indirect
	github.com/viterin/partial v1.1.0 // indirect
	github.com/viterin/vek v0.4.2 // indirect
	golang.org/x/sys v0.11.0 // indirect
)

require (
	github.com/davecgh/go-spew v1.1.1 // indirect
	github.com/pmezard/go-difflib v1.0.0 // indirect
	golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
	gopkg.in/yaml.v3 v3.0.1 // indirect
)

--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
github.com/chewxy/math32 v1.10.1 h1:LFpeY0SLJXeaiej/eIp2L40VYfscTvKh/FSEZ68uMkU=
github.com/chewxy/math32 v1.10.1/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/google/renameio v1.0.1 h1:Lh/jXZmvZxb0BBeSY5VKEfidcbcbenKjZFzM/q0fSeU= 6 | github.com/google/renameio v1.0.1/go.mod h1:t/HQoYBZSsWSNK35C6CO/TpPLDVWvxOHboWUAweKUpk= 7 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 8 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 9 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 10 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 11 | github.com/viterin/partial v1.1.0 h1:iH1l1xqBlapXsYzADS1dcbizg3iQUKTU1rbwkHv/80E= 12 | github.com/viterin/partial v1.1.0/go.mod h1:oKGAo7/wylWkJTLrWX8n+f4aDPtQMQ6VG4dd2qur5QA= 13 | github.com/viterin/vek v0.4.2 h1:Vyv04UjQT6gcjEFX82AS9ocgNbAJqsHviheIBdPlv5U= 14 | github.com/viterin/vek v0.4.2/go.mod h1:A4JRAe8OvbhdzBL5ofzjBS0J29FyUrf95tQogvtHHUc= 15 | golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= 16 | golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= 17 | golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= 18 | golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 19 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 20 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 21 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 22 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 23 | -------------------------------------------------------------------------------- /graph.go: -------------------------------------------------------------------------------- 1 | package hnsw 2 | 3 | import ( 4 | "cmp" 5 | "fmt" 6 | "math" 7 | "math/rand" 8 | "slices" 9 | "time" 10 | 11 
	"github.com/coder/hnsw/heap"
	"golang.org/x/exp/maps"
)

// Vector is an embedding: the coordinate slice a node is indexed by.
type Vector = []float32

// Node is a node in the graph.
type Node[K cmp.Ordered] struct {
	Key   K
	Value Vector
}

// MakeNode constructs a Node from a key and its embedding.
func MakeNode[K cmp.Ordered](key K, vec Vector) Node[K] {
	return Node[K]{Key: key, Value: vec}
}

// layerNode is a node in a layer of the graph.
type layerNode[K cmp.Ordered] struct {
	Node[K]

	// neighbors is a map of neighbor keys to neighbor nodes.
	// It is a map and not a slice to allow for efficient deletes, esp.
	// when M is high.
	neighbors map[K]*layerNode[K]
}

// addNeighbor adds a new neighbor to the node, evicting the neighbor
// with the worst (largest) distance if the neighbor set exceeds m.
func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFunc) {
	if n.neighbors == nil {
		n.neighbors = make(map[K]*layerNode[K], m)
	}

	n.neighbors[newNode.Key] = newNode
	if len(n.neighbors) <= m {
		// Still within capacity: no eviction needed.
		return
	}

	// Find the neighbor with the worst distance.
	var (
		worstDist = float32(math.Inf(-1))
		worst     *layerNode[K]
	)
	for _, neighbor := range n.neighbors {
		d := dist(neighbor.Value, n.Value)
		// d > worstDist may always be false if the distance function
		// returns NaN, e.g., when the embeddings are zero.
		if d > worstDist || worst == nil {
			worstDist = d
			worst = neighbor
		}
	}

	// Evict the worst neighbor, drop its backlink to us, and let it
	// rebuild connectivity from its remaining neighborhood.
	delete(n.neighbors, worst.Key)
	// Delete backlink from the worst neighbor.
	delete(worst.neighbors, n.Key)
	worst.replenish(m)
}

// searchCandidate pairs a node with its distance to the search target so
// it can be ordered in a heap.
type searchCandidate[K cmp.Ordered] struct {
	node *layerNode[K]
	dist float32
}

// Less orders candidates by ascending distance (min-heap).
func (s searchCandidate[K]) Less(o searchCandidate[K]) bool {
	return s.dist < o.dist
}

// search returns the layer node closest to the target node
// within the same layer.
func (n *layerNode[K]) search(
	// k is the number of candidates in the result set.
	k int,
	efSearch int,
	target Vector,
	distance DistanceFunc,
) []searchCandidate[K] {
	// This is a basic greedy algorithm to find the entry point at the given level
	// that is closest to the target node.
	candidates := heap.Heap[searchCandidate[K]]{}
	candidates.Init(make([]searchCandidate[K], 0, efSearch))
	candidates.Push(
		searchCandidate[K]{
			node: n,
			dist: distance(n.Value, target),
		},
	)
	var (
		result  = heap.Heap[searchCandidate[K]]{}
		visited = make(map[K]bool)
	)
	result.Init(make([]searchCandidate[K], 0, k))

	// Begin with the entry node in the result set.
	result.Push(candidates.Min())
	visited[n.Key] = true

	for candidates.Len() > 0 {
		var (
			current  = candidates.Pop().node
			improved = false
		)

		// We iterate the map in a sorted, deterministic fashion for
		// tests.
		neighborKeys := maps.Keys(current.neighbors)
		slices.Sort(neighborKeys)
		for _, neighborID := range neighborKeys {
			neighbor := current.neighbors[neighborID]
			if visited[neighborID] {
				continue
			}
			visited[neighborID] = true

			dist := distance(neighbor.Value, target)
			improved = improved || dist < result.Min().dist
			if result.Len() < k {
				result.Push(searchCandidate[K]{node: neighbor, dist: dist})
			} else if dist < result.Max().dist {
				// NOTE(review): heap.Max returns the last element of the heap's
				// backing array, which is a leaf but not necessarily the true
				// maximum (see heap.go). The pruning here is therefore
				// approximate — confirm this is intentional.
				result.PopLast()
				result.Push(searchCandidate[K]{node: neighbor, dist: dist})
			}

			candidates.Push(searchCandidate[K]{node: neighbor, dist: dist})
			// Always store candidates if we haven't reached the limit.
			if candidates.Len() > efSearch {
				candidates.PopLast()
			}
		}

		// Termination condition: no improvement in distance and at least
		// kMin candidates in the result set.
		if !improved && result.Len() >= k {
			break
		}
	}

	return result.Slice()
}

// replenish tops the node back up to m neighbors by borrowing from its
// neighbors' neighborhoods. Called after an eviction or deletion leaves
// the node under-connected.
func (n *layerNode[K]) replenish(m int) {
	if len(n.neighbors) >= m {
		return
	}

	// Restore connectivity by adding new neighbors.
	// This is a naive implementation that could be improved by
	// using a priority queue to find the best candidates.
	for _, neighbor := range n.neighbors {
		for key, candidate := range neighbor.neighbors {
			if _, ok := n.neighbors[key]; ok {
				// do not add duplicates
				continue
			}
			if candidate == n {
				continue
			}
			// NOTE(review): CosineDistance is hard-coded here even though the
			// graph's DistanceFunc is configurable (the tests use
			// EuclideanDistance). Plumbing the graph's distance function into
			// replenish would fix the inconsistency, but requires a signature
			// change that also touches addNeighbor's caller — confirm and fix
			// as a follow-up.
			n.addNeighbor(candidate, m, CosineDistance)
			if len(n.neighbors) >= m {
				return
			}
		}
	}
}

// isolate removes the node from the graph by deleting every neighbor's
// backlink to it, replenishing each affected neighbor afterwards.
func (n *layerNode[K]) isolate(m int) {
	for _, neighbor := range n.neighbors {
		delete(neighbor.neighbors, n.Key)
		neighbor.replenish(m)
	}
}

type layer[K cmp.Ordered] struct {
	// nodes is a map of node IDs to nodes.
	// All nodes in a higher layer are also in the lower layers, an essential
	// property of the graph.
	//
	// NOTE(review): an earlier comment claimed nodes is "exported for interop
	// with encoding/gob", but the field is unexported — presumably
	// serialization is handled elsewhere (encode.go); confirm and drop the
	// stale claim there too.
	nodes map[K]*layerNode[K]
}

// entry returns the entry node of the layer.
// It doesn't matter which node is returned, nor that the entry node is
// consistent between calls, so we just return the first node found in the
// map to avoid tracking extra state.
func (l *layer[K]) entry() *layerNode[K] {
	if l == nil {
		return nil
	}
	for _, node := range l.nodes {
		return node
	}
	return nil
}

// size returns the number of nodes in the layer, or 0 for a nil layer.
func (l *layer[K]) size() int {
	if l == nil {
		return 0
	}
	return len(l.nodes)
}

// Graph is a Hierarchical Navigable Small World graph.
// All public parameters must be set before adding nodes to the graph.
// K is cmp.Ordered instead of comparable so that keys can be sorted.
type Graph[K cmp.Ordered] struct {
	// Distance is the distance function used to compare embeddings.
	Distance DistanceFunc

	// Rng is used for level generation. It may be set to a deterministic value
	// for reproducibility. Note that deterministic number generation can lead to
	// degenerate graphs when exposed to adversarial inputs.
	Rng *rand.Rand

	// M is the maximum number of neighbors to keep for each node.
	// A good default for OpenAI embeddings is 16.
	M int

	// Ml is the level generation factor.
	// E.g., for Ml = 0.25, each layer is 1/4 the size of the previous layer.
	Ml float64

	// EfSearch is the number of nodes to consider in the search phase.
	// 20 is a reasonable default. Higher values improve search accuracy at
	// the expense of memory.
	EfSearch int

	// layers is a slice of layers in the graph, ordered from the base layer
	// (index 0, all nodes) upward.
	layers []*layer[K]
}

// defaultRand returns a time-seeded source for level generation.
func defaultRand() *rand.Rand {
	return rand.New(rand.NewSource(time.Now().UnixNano()))
}

// NewGraph returns a new graph with default parameters, roughly designed for
// storing OpenAI embeddings.
func NewGraph[K cmp.Ordered]() *Graph[K] {
	return &Graph[K]{
		M:        16,
		Ml:       0.25,
		Distance: CosineDistance,
		EfSearch: 20,
		Rng:      defaultRand(),
	}
}

// maxLevel returns an upper-bound on the number of levels in the graph
// based on the size of the base layer.
262 | func maxLevel(ml float64, numNodes int) int { 263 | if ml == 0 { 264 | panic("ml must be greater than 0") 265 | } 266 | 267 | if numNodes == 0 { 268 | return 1 269 | } 270 | 271 | l := math.Log(float64(numNodes)) 272 | l /= math.Log(1 / ml) 273 | 274 | m := int(math.Round(l)) + 1 275 | 276 | return m 277 | } 278 | 279 | // randomLevel generates a random level for a new node. 280 | func (h *Graph[K]) randomLevel() int { 281 | // max avoids having to accept an additional parameter for the maximum level 282 | // by calculating a probably good one from the size of the base layer. 283 | max := 1 284 | if len(h.layers) > 0 { 285 | if h.Ml == 0 { 286 | panic("(*Graph).Ml must be greater than 0") 287 | } 288 | max = maxLevel(h.Ml, h.layers[0].size()) 289 | } 290 | 291 | for level := 0; level < max; level++ { 292 | if h.Rng == nil { 293 | h.Rng = defaultRand() 294 | } 295 | r := h.Rng.Float64() 296 | if r > h.Ml { 297 | return level 298 | } 299 | } 300 | 301 | return max 302 | } 303 | 304 | func (g *Graph[K]) assertDims(n Vector) { 305 | if len(g.layers) == 0 { 306 | return 307 | } 308 | hasDims := g.Dims() 309 | if hasDims != len(n) { 310 | panic(fmt.Sprint("embedding dimension mismatch: ", hasDims, " != ", len(n))) 311 | } 312 | } 313 | 314 | // Dims returns the number of dimensions in the graph, or 315 | // 0 if the graph is empty. 316 | func (g *Graph[K]) Dims() int { 317 | if len(g.layers) == 0 { 318 | return 0 319 | } 320 | return len(g.layers[0].entry().Value) 321 | } 322 | 323 | func ptr[T any](v T) *T { 324 | return &v 325 | } 326 | 327 | // Add inserts nodes into the graph. 328 | // If another node with the same ID exists, it is replaced. 329 | func (g *Graph[K]) Add(nodes ...Node[K]) { 330 | for _, node := range nodes { 331 | key := node.Key 332 | vec := node.Value 333 | 334 | g.assertDims(vec) 335 | insertLevel := g.randomLevel() 336 | // Create layers that don't exist yet. 
337 | for insertLevel >= len(g.layers) { 338 | g.layers = append(g.layers, &layer[K]{}) 339 | } 340 | 341 | if insertLevel < 0 { 342 | panic("invalid level") 343 | } 344 | 345 | var elevator *K 346 | 347 | preLen := g.Len() 348 | 349 | // Insert node at each layer, beginning with the highest. 350 | for i := len(g.layers) - 1; i >= 0; i-- { 351 | layer := g.layers[i] 352 | newNode := &layerNode[K]{ 353 | Node: Node[K]{ 354 | Key: key, 355 | Value: vec, 356 | }, 357 | } 358 | 359 | // Insert the new node into the layer. 360 | if layer.entry() == nil { 361 | layer.nodes = map[K]*layerNode[K]{key: newNode} 362 | continue 363 | } 364 | 365 | // Now at the highest layer with more than one node, so we can begin 366 | // searching for the best way to enter the graph. 367 | searchPoint := layer.entry() 368 | 369 | // On subsequent layers, we use the elevator node to enter the graph 370 | // at the best point. 371 | if elevator != nil { 372 | searchPoint = layer.nodes[*elevator] 373 | } 374 | 375 | if g.Distance == nil { 376 | panic("(*Graph).Distance must be set") 377 | } 378 | 379 | neighborhood := searchPoint.search(g.M, g.EfSearch, vec, g.Distance) 380 | if len(neighborhood) == 0 { 381 | // This should never happen because the searchPoint itself 382 | // should be in the result set. 383 | panic("no nodes found") 384 | } 385 | 386 | // Re-set the elevator node for the next layer. 387 | elevator = ptr(neighborhood[0].node.Key) 388 | 389 | if insertLevel >= i { 390 | if _, ok := layer.nodes[key]; ok { 391 | g.Delete(key) 392 | } 393 | // Insert the new node into the layer. 394 | layer.nodes[key] = newNode 395 | for _, node := range neighborhood { 396 | // Create a bi-directional edge between the new node and the best node. 397 | node.node.addNeighbor(newNode, g.M, g.Distance) 398 | newNode.addNeighbor(node.node, g.M, g.Distance) 399 | } 400 | } 401 | } 402 | 403 | // Invariant check: the node should have been added to the graph. 
404 | if g.Len() != preLen+1 { 405 | panic("node not added") 406 | } 407 | } 408 | } 409 | 410 | // Search finds the k nearest neighbors from the target node. 411 | func (h *Graph[K]) Search(near Vector, k int) []Node[K] { 412 | h.assertDims(near) 413 | if len(h.layers) == 0 { 414 | return nil 415 | } 416 | 417 | var ( 418 | efSearch = h.EfSearch 419 | 420 | elevator *K 421 | ) 422 | 423 | for layer := len(h.layers) - 1; layer >= 0; layer-- { 424 | searchPoint := h.layers[layer].entry() 425 | if elevator != nil { 426 | searchPoint = h.layers[layer].nodes[*elevator] 427 | } 428 | 429 | // Descending hierarchies 430 | if layer > 0 { 431 | nodes := searchPoint.search(1, efSearch, near, h.Distance) 432 | elevator = ptr(nodes[0].node.Key) 433 | continue 434 | } 435 | 436 | nodes := searchPoint.search(k, efSearch, near, h.Distance) 437 | out := make([]Node[K], 0, len(nodes)) 438 | 439 | for _, node := range nodes { 440 | out = append(out, node.node.Node) 441 | } 442 | 443 | return out 444 | } 445 | 446 | panic("unreachable") 447 | } 448 | 449 | // Len returns the number of nodes in the graph. 450 | func (h *Graph[K]) Len() int { 451 | if len(h.layers) == 0 { 452 | return 0 453 | } 454 | return h.layers[0].size() 455 | } 456 | 457 | // Delete removes a node from the graph by key. 458 | // It tries to preserve the clustering properties of the graph by 459 | // replenishing connectivity in the affected neighborhoods. 460 | func (h *Graph[K]) Delete(key K) bool { 461 | if len(h.layers) == 0 { 462 | return false 463 | } 464 | 465 | var deleted bool 466 | for _, layer := range h.layers { 467 | node, ok := layer.nodes[key] 468 | if !ok { 469 | continue 470 | } 471 | delete(layer.nodes, key) 472 | node.isolate(h.M) 473 | deleted = true 474 | } 475 | 476 | return deleted 477 | } 478 | 479 | // Lookup returns the vector with the given key. 
func (h *Graph[K]) Lookup(key K) (Vector, bool) {
	if len(h.layers) == 0 {
		return nil, false
	}

	// The base layer contains every node, so it is the only layer we need
	// to consult.
	node, ok := h.layers[0].nodes[key]
	if !ok {
		return nil, false
	}
	return node.Value, ok
}

--------------------------------------------------------------------------------
/graph_test.go:
--------------------------------------------------------------------------------
package hnsw

import (
	"cmp"
	"math/rand"
	"strconv"
	"testing"

	"github.com/stretchr/testify/require"
)

func Test_maxLevel(t *testing.T) {
	var m int

	m = maxLevel(0.5, 10)
	require.Equal(t, 4, m)

	m = maxLevel(0.5, 1000)
	require.Equal(t, 11, m)
}

func Test_layerNode_search(t *testing.T) {
	entry := &layerNode[int]{
		Node: Node[int]{
			Value: Vector{0},
			Key:   0,
		},
		neighbors: map[int]*layerNode[int]{
			1: {
				Node: Node[int]{
					Value: Vector{1},
					Key:   1,
				},
			},
			2: {
				Node: Node[int]{
					Value: Vector{2},
					Key:   2,
				},
			},
			3: {
				Node: Node[int]{
					Value: Vector{3},
					Key:   3,
				},
				neighbors: map[int]*layerNode[int]{
					4: {
						// NOTE(review): map key 4 holds a node whose Key is 5,
						// duplicating the node under map key 5 below — likely a
						// typo for Key: 4. The assertions at the bottom of this
						// test currently depend on this value; confirm intent
						// before changing.
						Node: Node[int]{
							Value: Vector{4},
							Key:   5,
						},
					},
					5: {
						Node: Node[int]{
							Value: Vector{5},
							Key:   5,
						},
					},
				},
			},
		},
	}

	best := entry.search(2, 4, []float32{4}, EuclideanDistance)

	require.Equal(t, 5, best[0].node.Key)
	require.Equal(t, 3, best[1].node.Key)
	require.Len(t, best, 2)
}

// newTestGraph returns a graph with a fixed seed so test results are
// deterministic.
func newTestGraph[K cmp.Ordered]() *Graph[K] {
	return &Graph[K]{
		M:        6,
		Distance: EuclideanDistance,
		Ml:       0.5,
		EfSearch: 20,
		Rng:      rand.New(rand.NewSource(0)),
	}
}

func TestGraph_AddSearch(t *testing.T) {
	t.Parallel()

	g := newTestGraph[int]()

	for i := 0; i < 128; i++ {
		g.Add(
			Node[int]{
				Key:   i,
				Value: Vector{float32(i)},
			},
		)
	}

	al := Analyzer[int]{Graph: g}

	// Layers should be approximately log2(128) = 7
	// Look for an approximate doubling of the number of nodes in each layer.
	// (These exact counts depend on the seeded Rng above.)
	require.Equal(t, []int{
		128,
		67,
		28,
		12,
		6,
		2,
		1,
		1,
	}, al.Topography())

	nearest := g.Search(
		[]float32{64.5},
		4,
	)

	require.Len(t, nearest, 4)
	require.EqualValues(
		t,
		[]Node[int]{
			{64, Vector{64}},
			{65, Vector{65}},
			{62, Vector{62}},
			{63, Vector{63}},
		},
		nearest,
	)
}

func TestGraph_AddDelete(t *testing.T) {
	t.Parallel()

	g := newTestGraph[int]()
	for i := 0; i < 128; i++ {
		g.Add(Node[int]{
			Key:   i,
			Value: Vector{float32(i)},
		})
	}

	require.Equal(t, 128, g.Len())
	an := Analyzer[int]{Graph: g}

	preDeleteConnectivity := an.Connectivity()

	// Delete every even node.
	for i := 0; i < 128; i += 2 {
		ok := g.Delete(i)
		require.True(t, ok)
	}

	require.Equal(t, 64, g.Len())

	postDeleteConnectivity := an.Connectivity()

	// Connectivity should be the same for the lowest layer.
	require.Equal(
		t, preDeleteConnectivity[0],
		postDeleteConnectivity[0],
	)

	t.Run("DeleteNotFound", func(t *testing.T) {
		ok := g.Delete(-1)
		require.False(t, ok)
	})
}

func Benchmark_HSNW(b *testing.B) {
	b.ReportAllocs()

	sizes := []int{100, 1000, 10000}

	// Use this to ensure that complexity is O(log n) where n = h.Len().
	for _, size := range sizes {
		b.Run(strconv.Itoa(size), func(b *testing.B) {
			g := Graph[int]{}
			g.Ml = 0.5
			g.Distance = EuclideanDistance
			for i := 0; i < size; i++ {
				g.Add(Node[int]{
					Key:   i,
					Value: Vector{float32(i)},
				})
			}
			b.ResetTimer()

			b.Run("Search", func(b *testing.B) {
				for i := 0; i < b.N; i++ {
					g.Search(
						[]float32{float32(i % size)},
						4,
					)
				}
			})
		})
	}
}

// randFloats returns n pseudo-random float32s in [0, 1).
func randFloats(n int) []float32 {
	x := make([]float32, n)
	for i := range x {
		x[i] = rand.Float32()
	}
	return x
}

func Benchmark_HNSW_1536(b *testing.B) {
	b.ReportAllocs()

	g := newTestGraph[int]()
	const size = 1000
	points := make([]Node[int], size)
	for i := 0; i < size; i++ {
		points[i] = Node[int]{
			Key:   i,
			Value: Vector(randFloats(1536)),
		}
		g.Add(points[i])
	}
	b.ResetTimer()

	b.Run("Search", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			g.Search(
				points[i%size].Value,
				4,
			)
		}
	})
}

func TestGraph_DefaultCosine(t *testing.T) {
	g := NewGraph[int]()
	g.Add(
		Node[int]{Key: 1, Value: Vector{1, 1}},
		Node[int]{Key: 2, Value: Vector{0, 1}},
		Node[int]{Key: 3, Value: Vector{1, -1}},
	)

	neighbors := g.Search(
		[]float32{0.5, 0.5},
		1,
	)

	require.Equal(
		t,
		[]Node[int]{
			{1, Vector{1, 1}},
		},
		neighbors,
	)
}

--------------------------------------------------------------------------------
/heap/heap.go:
--------------------------------------------------------------------------------
package heap

import "container/heap"

// Lessable is an interface that allows a type to be compared to another of the same type.
// It is used to define the order of elements in the heap.
type Lessable[T any] interface {
	Less(T) bool
}

// innerHeap adapts a flat []T to the standard library's heap.Interface.
type innerHeap[T Lessable[T]] struct {
	data []T
}

func (h *innerHeap[T]) Len() int {
	return len(h.data)
}

func (h *innerHeap[T]) Less(i, j int) bool {
	return h.data[i].Less(h.data[j])
}

func (h *innerHeap[T]) Swap(i, j int) {
	h.data[i], h.data[j] = h.data[j], h.data[i]
}

// Push appends x to the backing array; container/heap restores the heap
// invariant afterwards. (Uses `any` rather than the pre-1.18 spelling
// interface{}, consistent with the generics already used in this file.)
func (h *innerHeap[T]) Push(x any) {
	h.data = append(h.data, x.(T))
}

// Pop removes and returns the last element of the backing array;
// container/heap has already swapped the element to remove into that
// position. The vacated slot is zeroed so that any pointers held by T do
// not keep the popped value reachable through the slice's capacity.
func (h *innerHeap[T]) Pop() any {
	n := len(h.data)
	x := h.data[n-1]
	var zero T
	h.data[n-1] = zero
	h.data = h.data[:n-1]
	return x
}

// Heap represents the heap data structure using a flat array to store the elements.
// It is a wrapper around the standard library's heap.
type Heap[T Lessable[T]] struct {
	inner innerHeap[T]
}

// Init establishes the heap invariants required by the other routines in this package.
// Init is idempotent with respect to the heap invariants
// and may be called whenever the heap invariants may have been invalidated.
// The complexity is O(n) where n = h.Len().
func (h *Heap[T]) Init(d []T) {
	h.inner.data = d
	heap.Init(&h.inner)
}

// Len returns the number of elements in the heap.
func (h *Heap[T]) Len() int {
	return h.inner.Len()
}

// Push pushes the element x onto the heap.
// The complexity is O(log n) where n = h.Len().
func (h *Heap[T]) Push(x T) {
	heap.Push(&h.inner, x)
}

// Pop removes and returns the minimum element (according to Less) from the heap.
// The complexity is O(log n) where n = h.Len().
// Pop is equivalent to Remove(h, 0).
func (h *Heap[T]) Pop() T {
	return heap.Pop(&h.inner).(T)
}

// PopLast removes and returns the element stored last in the heap's
// backing array.
// NOTE(review): in a min-heap the last array element is always a leaf but
// is NOT necessarily the maximum element; this pairs with the Max
// approximation below. Confirm callers (the pruning in layerNode.search)
// intend this.
func (h *Heap[T]) PopLast() T {
	return h.Remove(h.Len() - 1)
}

// Remove removes and returns the element at index i from the heap.
// The complexity is O(log n) where n = h.Len().
func (h *Heap[T]) Remove(i int) T {
	return heap.Remove(&h.inner, i).(T)
}

// Min returns the minimum element in the heap.
func (h *Heap[T]) Min() T {
	return h.inner.data[0]
}

// Max returns the element stored last in the heap's backing array.
// NOTE(review): the heap invariant only guarantees this is a leaf, not
// the true maximum — treat this as an approximation of "worst element".
// TODO confirm whether an exact maximum is required by callers.
func (h *Heap[T]) Max() T {
	return h.inner.data[h.inner.Len()-1]
}

// Slice returns the heap's backing slice, in heap order (not sorted).
func (h *Heap[T]) Slice() []T {
	return h.inner.data
}

--------------------------------------------------------------------------------
/heap/heap_test.go:
--------------------------------------------------------------------------------
package heap

import (
	"math/rand"
	"slices"
	"testing"

	"github.com/stretchr/testify/require"
)

type Int int

func (i Int) Less(j Int) bool {
	return i < j
}

// TestHeap pushes random values and verifies that repeated Pop drains
// them in sorted (ascending) order.
func TestHeap(t *testing.T) {
	h := Heap[Int]{}

	for i := 0; i < 20; i++ {
		h.Push(Int(rand.Int() % 100))
	}

	require.Equal(t, 20, h.Len())

	var inOrder []Int
	for h.Len() > 0 {
		inOrder = append(inOrder, h.Pop())
	}

	if !slices.IsSorted(inOrder) {
		t.Errorf("Heap did not return sorted elements: %+v", inOrder)
	}
}