├── .dockerignore ├── .github ├── workflows │ ├── go.yml │ ├── docs.yml │ ├── python.yml │ └── release.yml ├── dependabot.yml ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── PULL_REQUEST_TEMPLATE.md ├── cmd └── quiver │ ├── config.yaml │ └── main.go ├── .gitignore ├── pkg ├── vectortypes │ ├── benchmark_test.go │ ├── surface.go │ ├── types.go │ ├── surface_test.go │ ├── distances.go │ ├── types_test.go │ └── distances_test.go ├── hybrid │ ├── benchmark_test.go │ ├── hnsw_adapter.go │ ├── exact.go │ ├── types.go │ ├── hnsw_adapter_test.go │ ├── adaptive.go │ ├── exact_test.go │ └── adaptive_test.go ├── persistence │ ├── collection_benchmark_test.go │ ├── parquet.go │ ├── parquet_test.go │ └── collection.go ├── types │ └── search.go ├── api │ ├── middleware.go │ └── server.go ├── metrics │ └── collector.go └── facets │ └── facets.go ├── docs └── index_types.md ├── LICENSE ├── benchmark └── arrow_hnsw_bench_test.go ├── docker-compose.yml ├── index ├── arrow_hnsw_test.go └── arrow_hnsw.go ├── Dockerfile ├── go.mod └── README.md /.dockerignore: -------------------------------------------------------------------------------- 1 | # Git 2 | .git 3 | .gitignore 4 | .github 5 | 6 | # Docker 7 | Dockerfile 8 | docker-compose.yml 9 | .dockerignore 10 | 11 | # Build artifacts 12 | **/bin/ 13 | **/obj/ 14 | **/out/ 15 | **/dist/ 16 | **/build/ 17 | 18 | # Development and IDE files 19 | .vscode/ 20 | .idea/ 21 | *.swp 22 | *.swo 23 | *~ 24 | 25 | # Logs 26 | logs/ 27 | *.log 28 | 29 | # Data directories 30 | data/ 31 | backup/ 32 | 33 | # Documentation 34 | docs/ 35 | *.md 36 | LICENSE 37 | 38 | # Test files 39 | **/*_test.go 40 | **/testdata/ 41 | 42 | # Misc 43 | .DS_Store 44 | *.env 45 | .env* 46 | !.env.example -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | 
pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | build-test-lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Set up Go 16 | uses: actions/setup-go@v5 17 | with: 18 | go-version: "1.24" 19 | cache: true 20 | 21 | - name: Install dependencies 22 | run: | 23 | go mod download 24 | # Install golangci-lint 25 | curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.55.2 26 | 27 | - name: Build 28 | run: go build -v ./... 29 | 30 | - name: Test 31 | run: go test ./... 32 | -------------------------------------------------------------------------------- /cmd/quiver/config.yaml: -------------------------------------------------------------------------------- 1 | # Quiver configuration file 2 | # This is a sample configuration - save as .quiver.yaml in your home directory 3 | 4 | # Database settings 5 | data_dir: "./data" # Directory for vector storage 6 | log_level: "info" # Log level: debug, info, warn, error 7 | 8 | # Server settings 9 | host: "localhost" # Server host address 10 | port: 8080 # Server port 11 | 12 | # Security settings 13 | auth: false # Enable JWT authentication 14 | jwt_secret: "" # JWT secret (required if auth is enabled) 15 | cors: true # Enable CORS for API access 16 | 17 | # Advanced settings 18 | flush_interval: 60 # Interval in seconds to flush data to disk 19 | backup_dir: "./backups" # Directory for automatic backups (if enabled) 20 | auto_backup: false # Enable automatic backups 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with
`go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | go.work.sum 23 | 24 | # env file 25 | .env 26 | 27 | # benchmark results 28 | bench.db 29 | bench.hnsw 30 | *.wal 31 | *.db 32 | *.db-shm 33 | *.db-wal 34 | *.db-lock 35 | *.db-trace 36 | *.hnsw 37 | 38 | # Python 39 | python/ 40 | data/ 41 | testdata/ 42 | .venv/ 43 | .DS_Store 44 | -------------------------------------------------------------------------------- /pkg/vectortypes/benchmark_test.go: -------------------------------------------------------------------------------- 1 | package vectortypes 2 | 3 | import "testing" 4 | 5 | func benchmarkDistance(b *testing.B, fn DistanceFunc) { 6 | dim := 128 7 | a := make(F32, dim) 8 | c := make(F32, dim) 9 | for i := 0; i < dim; i++ { 10 | a[i] = float32(i) * 0.01 11 | c[i] = float32(i+1) * 0.02 12 | } 13 | b.ResetTimer() 14 | for i := 0; i < b.N; i++ { 15 | _ = fn(a, c) 16 | } 17 | } 18 | 19 | func BenchmarkCosineDistance(b *testing.B) { benchmarkDistance(b, CosineDistance) } 20 | func BenchmarkEuclideanDistance(b *testing.B) { benchmarkDistance(b, EuclideanDistance) } 21 | func BenchmarkSquaredEuclideanDistance(b *testing.B) { benchmarkDistance(b, SquaredEuclideanDistance) } 22 | func BenchmarkDotProductDistance(b *testing.B) { benchmarkDistance(b, DotProductDistance) } 23 | func BenchmarkManhattanDistance(b *testing.B) { benchmarkDistance(b, ManhattanDistance) } 24 | -------------------------------------------------------------------------------- /docs/index_types.md: -------------------------------------------------------------------------------- 1 | # Index Types 2 | 3 | ## ArrowHNSW Index 4 | 5 | The ArrowHNSW index is a variant of the HNSW index that stores vectors in Apache Arrow format. 
Vectors and metadata are kept in columnar structures to enable zero-copy access and efficient batch operations. The index can be persisted using Arrow IPC files and optionally converted to Parquet for long-term storage. 6 | 7 | ### Use Cases 8 | - Workloads that benefit from columnar processing 9 | - Integration with Arrow-based analytics pipelines 10 | 11 | ### Persistence Format 12 | `Save` writes the vectors and identifiers as an Arrow record with a fixed-size list column. `Load` restores the index from this file. Note that the string IDs passed to `Add` are not currently round-tripped: after `Load`, each vector's ID is its position in the index formatted as a string (for example `"0"` for the first vector). 13 | 14 | ### Example 15 | ```go 16 | idx := index.NewArrowHNSWIndex(128) 17 | // build a query vector as an Arrow Float32 array (myVec is a []float32 of length 128) 18 | b := array.NewFloat32Builder(memory.DefaultAllocator) 19 | defer b.Release() 20 | b.AppendValues(myVec, nil) 21 | arr := b.NewArray() 22 | defer arr.Release() 23 | // Search takes a *array.Float32, so assert the concrete array type 24 | res, err := idx.Search(arr.(*array.Float32), 10) 25 | if err != nil { 26 | log.Fatal(err) 27 | } 28 | ``` 29 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Go dependencies 4 | - package-ecosystem: "gomod" 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | open-pull-requests-limit: 10 9 | labels: 10 | - "dependencies" 11 | - "go" 12 | commit-message: 13 | prefix: "deps" 14 | include: "scope" 15 | 16 | # Python dependencies 17 | - package-ecosystem: "pip" 18 | directory: "/quiver/python/" 19 | schedule: 20 | interval: "weekly" 21 | open-pull-requests-limit: 10 22 | labels: 23 | - "dependencies" 24 | - "python" 25 | commit-message: 26 | prefix: "deps" 27 | include: "scope" 28 | 29 | # GitHub Actions 30 | - package-ecosystem: "github-actions" 31 | directory: "/" 32 | schedule: 33 | interval: "monthly" 34 | open-pull-requests-limit: 5 35 | labels: 36 | - "dependencies" 37 | - "github-actions" 38 | commit-message: 39 | prefix: "ci" 40 | include: "scope" 41 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: 
-------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for Quiver 4 | title: '[FEATURE] ' 5 | labels: enhancement 6 | assignees: '' 7 | --- 8 | 9 | ## Is your feature request related to a problem? Please describe 10 | 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | ## Describe the solution you'd like 14 | 15 | A clear and concise description of what you want to happen. 16 | 17 | ## Describe alternatives you've considered 18 | 19 | A clear and concise description of any alternative solutions or features you've considered. 20 | 21 | ## Use case 22 | 23 | Describe how you would use this feature in your application or workflow. 24 | 25 | ## Additional context 26 | 27 | Add any other context or screenshots about the feature request here. 28 | 29 | ## Would you be willing to contribute this feature? 30 | 31 | Let us know if you'd be interested in implementing this feature yourself with guidance from the maintainers. 
32 | -------------------------------------------------------------------------------- /pkg/hybrid/benchmark_test.go: -------------------------------------------------------------------------------- 1 | package hybrid 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/TFMV/quiver/pkg/vectortypes" 8 | ) 9 | 10 | // buildExactIndex returns a cosine-distance ExactIndex holding n dim-length vectors. 11 | // Each vector gets its own slice of strictly positive values that vary from vector to 12 | // vector; a single shared all-zero slice would make every entry identical and leave 13 | // cosine distance degenerate, so the search benchmark would not be representative. 14 | func buildExactIndex(n, dim int) *ExactIndex { 15 | idx := NewExactIndex(vectortypes.CosineDistance) 16 | for i := 0; i < n; i++ { 17 | vec := make(vectortypes.F32, dim) 18 | for j := range vec { 19 | vec[j] = float32((i*dim+j)%1009 + 1) 20 | } 21 | idx.Insert(fmt.Sprintf("id%d", i), vec) 22 | } 23 | return idx 24 | } 25 | 26 | // BenchmarkExactIndexSearch measures idx.Search(q, 10) over 1000 64-dimensional vectors. 27 | func BenchmarkExactIndexSearch(b *testing.B) { 28 | idx := buildExactIndex(1000, 64) 29 | q := make(vectortypes.F32, 64) 30 | for j := range q { 31 | q[j] = 1 // non-zero query: cosine distance against a zero vector is undefined 32 | } 33 | b.ResetTimer() 34 | for i := 0; i < b.N; i++ { 35 | idx.Search(q, 10) 36 | } 37 | } 38 | 39 | // buildHybridIndex returns a HybridIndex built from DefaultIndexConfig holding n 40 | // dim-length vectors, each in its own slice of strictly positive, varying values. 41 | func buildHybridIndex(n, dim int) *HybridIndex { 42 | cfg := DefaultIndexConfig() 43 | idx := NewHybridIndex(cfg) 44 | for i := 0; i < n; i++ { 45 | vec := make(vectortypes.F32, dim) 46 | for j := range vec { 47 | vec[j] = float32((i*dim+j)%1009 + 1) 48 | } 49 | idx.Insert(fmt.Sprintf("id%d", i), vec) 50 | } 51 | return idx 52 | } 53 | 54 | // BenchmarkHybridIndexSearch measures idx.Search(q, 10) over 1000 64-dimensional vectors. 55 | func BenchmarkHybridIndexSearch(b *testing.B) { 56 | idx := buildHybridIndex(1000, 64) 57 | q := make(vectortypes.F32, 64) 58 | for j := range q { 59 | q[j] = 1 // non-zero query 60 | } 61 | b.ResetTimer() 62 | for i := 0; i < b.N; i++ { 63 | idx.Search(q, 10) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Thomas F McGeehan V 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above
copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /benchmark/arrow_hnsw_bench_test.go: -------------------------------------------------------------------------------- 1 | package benchmark 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/apache/arrow-go/v18/arrow/array" 8 | "github.com/apache/arrow-go/v18/arrow/memory" 9 | 10 | "github.com/TFMV/quiver/index" 11 | ) 12 | 13 | // buildArrowIndex returns an ArrowHNSWIndex holding n deterministic dim-length vectors with IDs "0".."n-1". 14 | // One builder is reused for every vector (NewArray resets it) and released on return; an Add error aborts the benchmark instead of being silently dropped. 15 | func buildArrowIndex(tb testing.TB, n, dim int) *index.ArrowHNSWIndex { // arrow-hnsw 16 | tb.Helper() 17 | idx := index.NewArrowHNSWIndex(dim) 18 | bldr := array.NewFloat32Builder(memory.DefaultAllocator) 19 | defer bldr.Release() 20 | vals := make([]float32, dim) 21 | for i := 0; i < n; i++ { 22 | for j := 0; j < dim; j++ { 23 | vals[j] = float32(i*j) / float32(dim) 24 | } 25 | bldr.AppendValues(vals, nil) 26 | arr := bldr.NewArray() 27 | err := idx.Add(arr.(*array.Float32), fmt.Sprintf("%d", i)) 28 | arr.Release() 29 | if err != nil { 30 | tb.Fatalf("add vector %d: %v", i, err) 31 | } 32 | } 33 | return idx 34 | } 35 | 36 | func BenchmarkArrowHNSWBuild(b *testing.B) { // arrow-hnsw 37 | for i := 0; i < b.N; i++ { 38 | _ = buildArrowIndex(b, 1000, 32) 39 | } 40 | } 41 | 42 | func BenchmarkArrowHNSWSearch(b *testing.B) { // arrow-hnsw 43 | idx := buildArrowIndex(b, 100000, 32) 44 | qb := array.NewFloat32Builder(memory.DefaultAllocator) 45 | defer qb.Release() 46 | qb.AppendValues(make([]float32, 32), nil) 47 | query := qb.NewArray() 48 | defer query.Release() 49 | q := query.(*array.Float32) 50 | b.ResetTimer() 51 | for i := 0; i < b.N; 
i++ { 52 | if _, err := idx.Search(q, 10); err != nil { 53 | b.Fatalf("search: %v", err) 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve Quiver 4 | title: '[BUG] ' 5 | labels: bug 6 | assignees: '' 7 | --- 8 | 9 | ## Describe the bug 10 | 11 | A clear and concise description of what the bug is. 12 | 13 | ## To Reproduce 14 | 15 | Steps to reproduce the behavior: 16 | 17 | 1. Initialize Quiver with '...' 18 | 2. Add vectors '....' 19 | 3. Search for '....' 20 | 4. See error 21 | 22 | ## Expected behavior 23 | 24 | A clear and concise description of what you expected to happen. 25 | 26 | ## Code Example 27 | 28 | ```go 29 | // Or Python, if applicable 30 | // Please provide a minimal code example that reproduces the issue 31 | ``` 32 | 33 | ## Environment (please complete the following information) 34 | 35 | - OS: [e.g. Ubuntu 22.04, macOS 13.0] 36 | - Go/Python Version: [e.g. Go 1.21, Python 3.10] 37 | - Quiver Version: [e.g. 0.1.0] 38 | - Hardware: [e.g. 16GB RAM, 8 core CPU] 39 | 40 | ## Additional context 41 | 42 | Add any other context about the problem here, such as: 43 | 44 | - Vector dimensions 45 | - Index size 46 | - Concurrent operations 47 | - Any error messages or stack traces 48 | 49 | ## Logs 50 | 51 | If applicable, add logs to help explain your problem. 52 | 53 | ``` 54 | Paste logs here 55 | ``` 56 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.8" 2 | 3 | services: 4 | quiver: 5 | build: 6 | context: .
7 | dockerfile: Dockerfile 8 | container_name: quiver-db 9 | ports: 10 | - "8080:8080" 11 | volumes: 12 | - quiver-data:/app/data 13 | environment: 14 | # Server configuration 15 | - QUIVER_SERVER_PORT=8080 16 | - QUIVER_SERVER_HOST=0.0.0.0 17 | - QUIVER_LOG_LEVEL=info 18 | 19 | # Storage configuration 20 | - QUIVER_SERVER_STORAGE=/app/data 21 | - QUIVER_ENABLE_PERSISTENCE=true 22 | - QUIVER_PERSISTENCE_FORMAT=parquet 23 | - QUIVER_FLUSH_INTERVAL=300 24 | 25 | # Index configuration 26 | - QUIVER_INDEX_DIMENSION=768 27 | - QUIVER_INDEX_DISTANCE=cosine 28 | - QUIVER_ENABLE_HYBRID=true 29 | - QUIVER_MAX_CONNECTIONS=16 30 | - QUIVER_EF_CONSTRUCTION=200 31 | - QUIVER_EF_SEARCH=100 32 | 33 | # Performance configuration 34 | - QUIVER_ENABLE_METRICS=true 35 | - QUIVER_BATCH_SIZE=100 36 | restart: unless-stopped 37 | healthcheck: 38 | test: ["CMD", "wget", "-qO-", "http://localhost:8080/health"] 39 | interval: 30s 40 | timeout: 10s 41 | retries: 3 42 | start_period: 5s 43 | 44 | volumes: 45 | quiver-data: 46 | driver: local 47 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | 5 | Fixes # (issue) 6 | 7 | ## Type of change 8 | 9 | 10 | 11 | - [ ] Bug fix (non-breaking change which fixes an issue) 12 | - [ ] New feature (non-breaking change which adds functionality) 13 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 14 | - [ ] Documentation update 15 | - [ ] Performance improvement 16 | - [ ] Code refactoring (no functional changes) 17 | 18 | ## How Has This Been Tested? 
19 | 20 | 21 | 22 | - [ ] Unit tests 23 | - [ ] Integration tests 24 | - [ ] Benchmark tests 25 | - [ ] Manual testing 26 | 27 | ## Checklist 28 | 29 | - [ ] My code follows the style guidelines of this project 30 | - [ ] I have performed a self-review of my own code 31 | - [ ] I have commented my code, particularly in hard-to-understand areas 32 | - [ ] I have made corresponding changes to the documentation 33 | - [ ] My changes generate no new warnings 34 | - [ ] I have added tests that prove my fix is effective or that my feature works 35 | - [ ] New and existing unit tests pass locally with my changes 36 | - [ ] Any dependent changes have been merged and published in downstream modules 37 | -------------------------------------------------------------------------------- /pkg/vectortypes/surface.go: -------------------------------------------------------------------------------- 1 | // Package vectortypes provides common types for vector operations 2 | package vectortypes 3 | 4 | // F32 is a type alias for []float32 to make it more expressive 5 | type F32 = []float32 6 | 7 | // DistanceFunc is a function that computes the distance between two vectors. 
8 | type DistanceFunc func(a, b F32) float32 9 | 10 | // Surface represents a distance function between two vectors 11 | type Surface[T any] interface { 12 | // Distance calculates the distance between two vectors 13 | Distance(a, b T) float32 14 | } 15 | 16 | // ContraMap is a generic adapter that allows applying a distance function to a different type 17 | // by first mapping that type to the vector type the distance function expects 18 | type ContraMap[V, T any] struct { 19 | // The underlying surface (distance function) 20 | Surface Surface[V] 21 | 22 | // The mapping function from T to V 23 | ContraMap func(T) V 24 | } 25 | 26 | // Distance implements the Surface interface by first mapping the inputs and then applying the underlying distance function 27 | func (c ContraMap[V, T]) Distance(a, b T) float32 { 28 | return c.Surface.Distance(c.ContraMap(a), c.ContraMap(b)) 29 | } 30 | 31 | // BasicSurface wraps a standard distance function 32 | type BasicSurface struct { 33 | DistFunc DistanceFunc 34 | } 35 | 36 | // Distance implements the Surface interface for F32 vectors 37 | func (s BasicSurface) Distance(a, b F32) float32 { 38 | return s.DistFunc(a, b) 39 | } 40 | 41 | // CreateSurface creates a basic surface from a distance function 42 | func CreateSurface(distFunc DistanceFunc) Surface[F32] { 43 | return BasicSurface{DistFunc: distFunc} 44 | } 45 | -------------------------------------------------------------------------------- /pkg/persistence/collection_benchmark_test.go: -------------------------------------------------------------------------------- 1 | package persistence 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/TFMV/quiver/pkg/vectortypes" 8 | ) 9 | 10 | func setupBenchmarkCollection(numVectors, dim int) *Collection { 11 | c := NewCollection("bench", dim, vectortypes.EuclideanDistance) 12 | vec := make([]float32, dim) 13 | for i := 0; i < numVectors; i++ { 14 | id := fmt.Sprintf("v%d", i) 15 | for j := range vec { 16 | vec[j] = 
float32((i+j)%dim) / 100.0 17 | } 18 | c.AddVector(id, append([]float32(nil), vec...), nil) 19 | } 20 | return c 21 | } 22 | 23 | func BenchmarkCollectionSearch(b *testing.B) { 24 | const numVectors = 1000 25 | const dim = 32 26 | c := setupBenchmarkCollection(numVectors, dim) 27 | query := make([]float32, dim) 28 | for i := range query { 29 | query[i] = 0.5 30 | } 31 | 32 | b.ResetTimer() 33 | for i := 0; i < b.N; i++ { 34 | _, _ = c.Search(query, 10) 35 | } 36 | } 37 | 38 | func BenchmarkSortSearchResults(b *testing.B) { 39 | const numResults = 1000 40 | results := make([]SearchResult, numResults) 41 | for i := 0; i < numResults; i++ { 42 | results[i] = SearchResult{ 43 | ID: fmt.Sprintf("v%d", i), 44 | Distance: float32(numResults - i), 45 | } 46 | } 47 | 48 | b.ResetTimer() 49 | for i := 0; i < b.N; i++ { 50 | SortSearchResults(results) 51 | } 52 | } 53 | 54 | func BenchmarkCollectionAddVector(b *testing.B) { 55 | const dim = 32 56 | c := NewCollection("bench_add", dim, vectortypes.EuclideanDistance) 57 | vec := make([]float32, dim) 58 | b.ResetTimer() 59 | for i := 0; i < b.N; i++ { 60 | id := fmt.Sprintf("v%d", i) 61 | for j := range vec { 62 | vec[j] = float32((i+j)%dim) / 100.0 63 | } 64 | if err := c.AddVector(id, vec, nil); err != nil { 65 | b.Fatal(err) 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /index/arrow_hnsw_test.go: -------------------------------------------------------------------------------- 1 | package index 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/apache/arrow-go/v18/arrow/array" 8 | "github.com/apache/arrow-go/v18/arrow/memory" 9 | ) 10 | 11 | func TestArrowHNSWIndex_AddSearch(t *testing.T) { // arrow-hnsw 12 | idx := NewArrowHNSWIndex(3) 13 | b := array.NewFloat32Builder(memory.DefaultAllocator) 14 | b.AppendValues([]float32{1, 0, 0}, nil) 15 | vec := b.NewArray() 16 | defer vec.Release() 17 | if err := idx.Add(vec.(*array.Float32), "v1"); err != nil { 18 | 
t.Fatalf("add: %v", err) 19 | } 20 | b2 := array.NewFloat32Builder(memory.DefaultAllocator) 21 | b2.AppendValues([]float32{0.9, 0, 0}, nil) 22 | q := b2.NewArray() 23 | defer q.Release() 24 | res, err := idx.Search(q.(*array.Float32), 1) 25 | if err != nil { 26 | t.Fatalf("search: %v", err) 27 | } 28 | if len(res) != 1 || res[0].ID != "v1" { 29 | t.Fatalf("unexpected result: %+v", res) 30 | } 31 | } 32 | 33 | func TestArrowHNSWIndex_SaveLoad(t *testing.T) { // arrow-hnsw 34 | idx := NewArrowHNSWIndex(2) 35 | b := array.NewFloat32Builder(memory.DefaultAllocator) 36 | b.AppendValues([]float32{1, 2}, nil) 37 | vec := b.NewArray() 38 | defer vec.Release() 39 | if err := idx.Add(vec.(*array.Float32), "a"); err != nil { 40 | t.Fatalf("add: %v", err) 41 | } 42 | path := "test_arrow_hnsw.arrow" 43 | if err := idx.Save(path); err != nil { 44 | t.Fatalf("save: %v", err) 45 | } 46 | defer os.Remove(path) 47 | 48 | idx2 := NewArrowHNSWIndex(2) 49 | if err := idx2.Load(path); err != nil { 50 | t.Fatalf("load: %v", err) 51 | } 52 | b2 := array.NewFloat32Builder(memory.DefaultAllocator) 53 | b2.AppendValues([]float32{1, 2}, nil) 54 | q := b2.NewArray() 55 | defer q.Release() 56 | res, err := idx2.Search(q.(*array.Float32), 1) 57 | if err != nil || len(res) != 1 || res[0].ID != "0" { // loaded id string is index string 58 | t.Fatalf("unexpected search result after load: %v %+v", err, res) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | paths: 7 | - "quiver/python/quiver/docs/**" 8 | - "quiver/python/quiver/mkdocs.yml" 9 | pull_request: 10 | branches: [main] 11 | paths: 12 | - "quiver/python/quiver/docs/**" 13 | - "quiver/python/quiver/mkdocs.yml" 14 | # Allow manual trigger 15 | workflow_dispatch: 16 | 17 | # Skip this workflow 18 | jobs: 19 | 
build-docs: 20 | # This condition will always be false, effectively skipping this job 21 | if: ${{ false }} 22 | runs-on: ubuntu-latest 23 | steps: 24 | - uses: actions/checkout@v4 25 | with: 26 | fetch-depth: 0 27 | 28 | - name: Set up Python 29 | uses: actions/setup-python@v5 30 | with: 31 | python-version: "3.10" 32 | cache: "pip" 33 | 34 | - name: Install dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | pip install mkdocs mkdocs-material mkdocs-minify-plugin pymdown-extensions 38 | # Install any additional requirements for docs 39 | if [ -f quiver/python/quiver/docs/requirements.txt ]; then 40 | pip install -r quiver/python/quiver/docs/requirements.txt 41 | fi 42 | 43 | - name: Build documentation 44 | run: | 45 | cd quiver/python/quiver 46 | mkdocs build --strict 47 | 48 | - name: Check for broken links 49 | run: | 50 | cd quiver/python/quiver 51 | pip install linkchecker 52 | linkchecker site/ --check-extern 53 | 54 | # Only deploy docs on push to main 55 | - name: Deploy documentation 56 | if: github.event_name == 'push' && github.ref == 'refs/heads/main' 57 | uses: peaceiris/actions-gh-pages@v4 58 | with: 59 | github_token: ${{ secrets.GITHUB_TOKEN }} 60 | publish_dir: ./quiver/python/quiver/site 61 | full_commit_message: "docs: update documentation site" 62 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Build stage 2 | FROM golang:1.24-alpine AS builder 3 | 4 | # Install build dependencies 5 | RUN apk add --no-cache git build-base 6 | 7 | # Set working directory 8 | WORKDIR /app 9 | 10 | # Copy go.mod and go.sum files 11 | COPY go.mod go.sum ./ 12 | 13 | # Download dependencies 14 | RUN go mod download 15 | 16 | # Copy source code 17 | COPY . . 
18 | 19 | # Build the application with optimization flags 20 | RUN mkdir -p /app/bin && \ 21 | CGO_ENABLED=1 GOOS=linux GOARCH=amd64 \ 22 | go build -a -ldflags="-w -s" -installsuffix cgo -o /app/bin/quiver ./cmd/quiver 23 | 24 | # Final stage 25 | FROM alpine:3.18 26 | 27 | # Install runtime dependencies 28 | RUN apk add --no-cache ca-certificates tzdata wget 29 | 30 | # Create non-root user 31 | RUN adduser -D -h /app quiver 32 | 33 | # Set working directory 34 | WORKDIR /app 35 | 36 | # Copy binary from builder stage 37 | COPY --from=builder /app/bin/quiver /app/quiver 38 | 39 | # Copy default configuration 40 | COPY --from=builder /app/cmd/quiver/config.yaml /app/config.yaml 41 | 42 | # Create data directory and set permissions 43 | RUN mkdir -p /app/data && chown -R quiver:quiver /app 44 | 45 | # Switch to non-root user 46 | USER quiver 47 | 48 | # Expose API port 49 | EXPOSE 8080 50 | 51 | # Set volume for persistent data 52 | VOLUME ["/app/data"] 53 | 54 | # Set environment variables 55 | ENV QUIVER_SERVER_PORT=8080 56 | ENV QUIVER_SERVER_HOST=0.0.0.0 57 | ENV QUIVER_SERVER_STORAGE=/app/data 58 | ENV QUIVER_INDEX_DIMENSION=768 59 | ENV QUIVER_INDEX_DISTANCE=cosine 60 | ENV QUIVER_ENABLE_HYBRID=true 61 | ENV QUIVER_BATCH_SIZE=100 62 | ENV QUIVER_ENABLE_METRICS=true 63 | ENV QUIVER_ENABLE_PERSISTENCE=true 64 | ENV QUIVER_PERSISTENCE_FORMAT=parquet 65 | ENV QUIVER_FLUSH_INTERVAL=300 66 | ENV QUIVER_MAX_CONNECTIONS=16 67 | ENV QUIVER_EF_CONSTRUCTION=200 68 | ENV QUIVER_EF_SEARCH=100 69 | ENV QUIVER_LOG_LEVEL=info 70 | 71 | # Set entrypoint 72 | ENTRYPOINT ["/app/quiver"] 73 | 74 | # Set default command - server with improved defaults for production use 75 | CMD ["server", "--port", "8080", "--storage", "/app/data", "--enable-hybrid", "--enable-metrics", "--enable-persistence"] 76 | 77 | # Health check 78 | HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ 79 | CMD wget -qO- http://localhost:8080/health || exit 1 
-------------------------------------------------------------------------------- /.github/workflows/python.yml: -------------------------------------------------------------------------------- 1 | name: Python Package 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | paths: 7 | - "quiver/python/**" 8 | pull_request: 9 | branches: [main] 10 | paths: 11 | - "quiver/python/**" 12 | 13 | # Skip this workflow 14 | jobs: 15 | build-test-lint: 16 | # Skip this workflow 17 | if: ${{ false }} 18 | runs-on: ubuntu-latest 19 | strategy: 20 | matrix: 21 | python-version: ["3.8", "3.9", "3.10", "3.11"] 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | cache: "pip" 31 | 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | python -m pip install flake8 pytest pytest-cov build twine 36 | if [ -f quiver/python/requirements.txt ]; then pip install -r quiver/python/requirements.txt; fi 37 | pip install -e quiver/python 38 | 39 | - name: Lint with flake8 40 | run: | 41 | # stop the build if there are Python syntax errors or undefined names 42 | flake8 quiver/python --count --select=E9,F63,F7,F82 --show-source --statistics 43 | # exit-zero treats all errors as warnings 44 | flake8 quiver/python --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 45 | 46 | - name: Test with pytest 47 | run: | 48 | pytest quiver/python/tests/ --cov=quiver.python --cov-report=xml 49 | 50 | - name: Upload coverage to Codecov 51 | uses: codecov/codecov-action@v5 52 | with: 53 | file: ./coverage.xml 54 | fail_ci_if_error: false 55 | 56 | - name: Build package 57 | run: | 58 | cd quiver/python 59 | python -m build 60 | 61 | - name: Check package 62 | run: | 63 | cd quiver/python 64 | twine check dist/* 65 | 66 | docs: 67 | # This condition will always be false, effectively skipping this job 68 
| if: ${{ false }} 69 | runs-on: ubuntu-latest 70 | steps: 71 | - uses: actions/checkout@v4 72 | 73 | - name: Set up Python 74 | uses: actions/setup-python@v5 75 | with: 76 | python-version: "3.10" 77 | cache: "pip" 78 | 79 | - name: Install dependencies 80 | run: | 81 | python -m pip install --upgrade pip 82 | pip install mkdocs mkdocs-material 83 | if [ -f quiver/python/docs/requirements.txt ]; then pip install -r quiver/python/docs/requirements.txt; fi 84 | 85 | - name: Build documentation 86 | run: | 87 | cd quiver/python/quiver 88 | mkdocs build --strict 89 | 90 | # Only deploy docs on push to main 91 | - name: Deploy documentation 92 | if: github.event_name == 'push' && github.ref == 'refs/heads/main' 93 | uses: peaceiris/actions-gh-pages@v4 94 | with: 95 | github_token: ${{ secrets.GITHUB_TOKEN }} 96 | publish_dir: ./quiver/python/quiver/site 97 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/TFMV/quiver 2 | 3 | go 1.24.4 4 | 5 | require ( 6 | github.com/apache/arrow-go/v18 v18.3.0 7 | github.com/gin-contrib/cors v1.7.5 8 | github.com/gin-gonic/gin v1.10.1 9 | github.com/golang-jwt/jwt/v5 v5.2.2 10 | github.com/prometheus/client_golang v1.22.0 11 | github.com/spf13/cobra v1.9.1 12 | github.com/spf13/viper v1.20.1 13 | github.com/xitongsys/parquet-go v1.6.2 14 | github.com/xitongsys/parquet-go-source v0.0.0-20241021075129-b732d2ac9c9b 15 | golang.org/x/time v0.12.0 16 | ) 17 | 18 | require ( 19 | github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect 20 | github.com/apache/thrift v0.22.0 // indirect 21 | github.com/beorn7/perks v1.0.1 // indirect 22 | github.com/bytedance/sonic v1.13.3 // indirect 23 | github.com/bytedance/sonic/loader v0.2.4 // indirect 24 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 25 | github.com/cloudwego/base64x v0.1.5 // indirect 26 | 
github.com/fsnotify/fsnotify v1.9.0 // indirect 27 | github.com/gabriel-vasile/mimetype v1.4.9 // indirect 28 | github.com/gin-contrib/sse v1.1.0 // indirect 29 | github.com/go-playground/locales v0.14.1 // indirect 30 | github.com/go-playground/universal-translator v0.18.1 // indirect 31 | github.com/go-playground/validator/v10 v10.26.0 // indirect 32 | github.com/go-viper/mapstructure/v2 v2.3.0 // indirect 33 | github.com/goccy/go-json v0.10.5 // indirect 34 | github.com/golang/snappy v1.0.0 // indirect 35 | github.com/google/flatbuffers v25.2.10+incompatible // indirect 36 | github.com/inconshreveable/mousetrap v1.1.0 // indirect 37 | github.com/json-iterator/go v1.1.12 // indirect 38 | github.com/klauspost/compress v1.18.0 // indirect 39 | github.com/klauspost/cpuid/v2 v2.2.10 // indirect 40 | github.com/leodido/go-urn v1.4.0 // indirect 41 | github.com/mattn/go-isatty v0.0.20 // indirect 42 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 43 | github.com/modern-go/reflect2 v1.0.2 // indirect 44 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 45 | github.com/pelletier/go-toml/v2 v2.2.4 // indirect 46 | github.com/pierrec/lz4/v4 v4.1.22 // indirect 47 | github.com/prometheus/client_model v0.6.2 // indirect 48 | github.com/prometheus/common v0.64.0 // indirect 49 | github.com/prometheus/procfs v0.16.1 // indirect 50 | github.com/sagikazarmark/locafero v0.9.0 // indirect 51 | github.com/sourcegraph/conc v0.3.0 // indirect 52 | github.com/spf13/afero v1.14.0 // indirect 53 | github.com/spf13/cast v1.9.2 // indirect 54 | github.com/spf13/pflag v1.0.6 // indirect 55 | github.com/subosito/gotenv v1.6.0 // indirect 56 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect 57 | github.com/ugorji/go/codec v1.3.0 // indirect 58 | github.com/zeebo/xxh3 v1.0.2 // indirect 59 | go.uber.org/multierr v1.11.0 // indirect 60 | golang.org/x/arch v0.18.0 // indirect 61 | golang.org/x/crypto v0.39.0 // indirect 62 | 
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect 63 | golang.org/x/mod v0.25.0 // indirect 64 | golang.org/x/net v0.41.0 // indirect 65 | golang.org/x/sync v0.15.0 // indirect 66 | golang.org/x/sys v0.33.0 // indirect 67 | golang.org/x/text v0.26.0 // indirect 68 | golang.org/x/tools v0.34.0 // indirect 69 | golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect 70 | google.golang.org/protobuf v1.36.6 // indirect 71 | gopkg.in/yaml.v3 v3.0.1 // indirect 72 | ) 73 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*" 7 | 8 | # Skip this workflow 9 | jobs: 10 | go-release: 11 | # Skip this workflow 12 | if: ${{ false }} 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout 16 | uses: actions/checkout@v4 17 | with: 18 | fetch-depth: 0 19 | 20 | - name: Set up Go 21 | uses: actions/setup-go@v5 22 | with: 23 | go-version: "1.21" 24 | cache: true 25 | 26 | - name: Run GoReleaser 27 | uses: goreleaser/goreleaser-action@v6 28 | with: 29 | distribution: goreleaser 30 | version: latest 31 | args: release --clean 32 | env: 33 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 34 | 35 | python-release: 36 | # Skip this workflow 37 | if: ${{ false }} 38 | runs-on: ubuntu-latest 39 | steps: 40 | - name: Checkout 41 | uses: actions/checkout@v4 42 | with: 43 | fetch-depth: 0 44 | 45 | - name: Set up Python 46 | uses: actions/setup-python@v5 47 | with: 48 | python-version: "3.10" 49 | cache: "pip" 50 | 51 | - name: Install dependencies 52 | run: | 53 | python -m pip install --upgrade pip 54 | pip install build twine 55 | 56 | - name: Build package 57 | run: | 58 | cd quiver/python 59 | python -m build 60 | 61 | - name: Publish to PyPI 62 | if: startsWith(github.ref, 'refs/tags/') 63 | uses: pypa/gh-action-pypi-publish@release/v1 64 | with: 65 | 
packages-dir: quiver/python/dist/ 66 | password: ${{ secrets.PYPI_API_TOKEN }} 67 | skip-existing: true 68 | 69 | - name: Upload artifacts 70 | uses: actions/upload-artifact@v4 71 | with: 72 | name: python-package 73 | path: quiver/python/dist/ 74 | 75 | create-release: 76 | # Skip this workflow 77 | if: ${{ false }} 78 | needs: [go-release, python-release] 79 | runs-on: ubuntu-latest 80 | steps: 81 | - name: Checkout 82 | uses: actions/checkout@v4 83 | with: 84 | fetch-depth: 0 85 | 86 | - name: Get version from tag 87 | id: get_version 88 | run: echo "VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT 89 | 90 | - name: Generate changelog 91 | id: changelog 92 | run: | 93 | PREVIOUS_TAG=$(git tag --sort=-creatordate | grep -v $(git describe --tags) | head -n 1) 94 | echo "CHANGELOG<> $GITHUB_OUTPUT 95 | if [ -z "$PREVIOUS_TAG" ]; then 96 | git log --pretty=format:"* %s (%h)" $(git describe --tags) >> $GITHUB_OUTPUT 97 | else 98 | git log --pretty=format:"* %s (%h)" $PREVIOUS_TAG..$(git describe --tags) >> $GITHUB_OUTPUT 99 | fi 100 | echo "EOF" >> $GITHUB_OUTPUT 101 | 102 | - name: Create Release 103 | uses: softprops/action-gh-release@v2 104 | with: 105 | name: Release ${{ steps.get_version.outputs.VERSION }} 106 | body: | 107 | ## Quiver ${{ steps.get_version.outputs.VERSION }} 108 | 109 | ${{ steps.changelog.outputs.CHANGELOG }} 110 | 111 | ### Installation 112 | 113 | #### Go 114 | ``` 115 | go get github.com/username/quiver@v${{ steps.get_version.outputs.VERSION }} 116 | ``` 117 | 118 | #### Python 119 | ``` 120 | pip install quiver==${{ steps.get_version.outputs.VERSION }} 121 | ``` 122 | draft: false 123 | prerelease: false 124 | env: 125 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 126 | -------------------------------------------------------------------------------- /pkg/hybrid/hnsw_adapter.go: -------------------------------------------------------------------------------- 1 | package hybrid 2 | 3 | import ( 4 | "fmt" 5 | 
"github.com/TFMV/quiver/pkg/hnsw" 6 | "github.com/TFMV/quiver/pkg/types" 7 | "github.com/TFMV/quiver/pkg/vectortypes" 8 | ) 9 | 10 | // HNSWAdapter adapts the existing HNSW implementation to the hybrid Index interface 11 | type HNSWAdapter struct { 12 | // The underlying HNSW adapter 13 | adapter *hnsw.HNSWAdapter 14 | 15 | // Configuration for optimization 16 | config HNSWConfig 17 | 18 | // dimension of stored vectors 19 | dim int 20 | } 21 | 22 | // NewHNSWAdapter creates a new HNSW adapter for the hybrid index 23 | func NewHNSWAdapter(distFunc vectortypes.DistanceFunc, config HNSWConfig) *HNSWAdapter { 24 | // Convert the distance function to HNSW format 25 | hnswDistFunc := func(a, b []float32) (float32, error) { 26 | return distFunc(a, b), nil 27 | } 28 | 29 | // Create HNSW config 30 | hnswConfig := hnsw.Config{ 31 | M: config.M, 32 | MaxM0: config.MaxM0, 33 | EfConstruction: config.EfConstruction, 34 | EfSearch: config.EfSearch, 35 | DistanceFunc: hnswDistFunc, 36 | } 37 | 38 | // Create the adapter 39 | return &HNSWAdapter{ 40 | adapter: hnsw.NewAdapter(hnswConfig), 41 | config: config, 42 | dim: 0, 43 | } 44 | } 45 | 46 | // Insert adds a vector to the index 47 | func (idx *HNSWAdapter) Insert(id string, vector vectortypes.F32) error { 48 | if idx.dim == 0 { 49 | idx.dim = len(vector) 50 | } else if len(vector) != idx.dim { 51 | return fmt.Errorf("vector dimension mismatch: expected %d, got %d", idx.dim, len(vector)) 52 | } 53 | return idx.adapter.Insert(id, vector) 54 | } 55 | 56 | // Delete removes a vector from the index 57 | func (idx *HNSWAdapter) Delete(id string) error { 58 | err := idx.adapter.Delete(id) 59 | if err == nil && idx.Size() == 0 { 60 | idx.dim = 0 61 | } 62 | return err 63 | } 64 | 65 | // Search finds the k nearest vectors to the query vector 66 | func (idx *HNSWAdapter) Search(query vectortypes.F32, k int) ([]types.BasicSearchResult, error) { 67 | if idx.dim > 0 && len(query) != idx.dim { 68 | return nil, fmt.Errorf("query dimension 
mismatch: expected %d, got %d", idx.dim, len(query)) 69 | } 70 | return idx.adapter.Search(query, k) 71 | } 72 | 73 | // SearchWithNegative finds the k nearest vectors to the query vector, 74 | // taking into account a negative example vector 75 | func (idx *HNSWAdapter) SearchWithNegative(query vectortypes.F32, negativeExample vectortypes.F32, negativeWeight float32, k int) ([]types.BasicSearchResult, error) { 76 | if idx.dim > 0 { 77 | if len(query) != idx.dim { 78 | return nil, fmt.Errorf("query dimension mismatch: expected %d, got %d", idx.dim, len(query)) 79 | } 80 | if len(negativeExample) != idx.dim { 81 | return nil, fmt.Errorf("negative example dimension mismatch: expected %d, got %d", idx.dim, len(negativeExample)) 82 | } 83 | } 84 | return idx.adapter.SearchWithNegativeExample(query, negativeExample, negativeWeight, k) 85 | } 86 | 87 | // Size returns the number of vectors in the index 88 | func (idx *HNSWAdapter) Size() int { 89 | return idx.adapter.Size() 90 | } 91 | 92 | // GetType returns the index type 93 | func (idx *HNSWAdapter) GetType() IndexType { 94 | return HNSWIndexType 95 | } 96 | 97 | // GetStats returns statistics about this index 98 | func (idx *HNSWAdapter) GetStats() interface{} { 99 | params := idx.adapter.GetOptimizationParameters() 100 | metrics := idx.adapter.GetPerformanceMetrics() 101 | 102 | return map[string]interface{}{ 103 | "type": string(HNSWIndexType), 104 | "vector_count": idx.Size(), 105 | "parameters": params, 106 | "metrics": metrics, 107 | } 108 | } 109 | 110 | // SetSearchEf adjusts the EfSearch parameter which controls search accuracy 111 | func (idx *HNSWAdapter) SetSearchEf(efSearch int) error { 112 | return idx.adapter.SetOptimizationParameters(map[string]float64{ 113 | "EfSearch": float64(efSearch), 114 | }) 115 | } 116 | -------------------------------------------------------------------------------- /pkg/hybrid/exact.go: -------------------------------------------------------------------------------- 1 | 
package hybrid 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "sort" 7 | "sync" 8 | 9 | "github.com/TFMV/quiver/pkg/types" 10 | "github.com/TFMV/quiver/pkg/vectortypes" 11 | ) 12 | 13 | // ExactIndex provides brute-force exact search for small datasets 14 | type ExactIndex struct { 15 | // Map of vector IDs to vectors 16 | vectors map[string]vectortypes.F32 17 | 18 | // Distance function to use 19 | distFunc vectortypes.DistanceFunc 20 | 21 | // dimension of stored vectors (0 until the first insert, and again after the index is emptied) 22 | vectorDim int 23 | 24 | // Mutex for thread safety; guards vectors and vectorDim 25 | mu sync.RWMutex 26 | } 27 | 28 | // NewExactIndex creates a new exact search index 29 | func NewExactIndex(distFunc vectortypes.DistanceFunc) *ExactIndex { 30 | return &ExactIndex{ 31 | vectors: make(map[string]vectortypes.F32), 32 | distFunc: distFunc, 33 | vectorDim: 0, 34 | } 35 | } 36 | 37 | // Insert adds a vector to the index; an existing id is overwritten and a private copy of the vector is stored 38 | func (idx *ExactIndex) Insert(id string, vector vectortypes.F32) error { 39 | idx.mu.Lock() 40 | defer idx.mu.Unlock() 41 | 42 | // Validate dimension consistency; while vectorDim == 0 the inserted vector fixes the expected dimension. NOTE(review): a zero-length vector leaves vectorDim at 0, so the next insert of any length is accepted 43 | if idx.vectorDim == 0 { 44 | idx.vectorDim = len(vector) 45 | } else if len(vector) != idx.vectorDim { 46 | return fmt.Errorf("vector dimension mismatch: expected %d, got %d", idx.vectorDim, len(vector)) 47 | } 48 | 49 | // Make a copy of the vector to prevent external modification 50 | vectorCopy := make(vectortypes.F32, len(vector)) 51 | copy(vectorCopy, vector) 52 | 53 | idx.vectors[id] = vectorCopy 54 | return nil 55 | } 56 | 57 | // Delete removes a vector from the index; deleting an unknown id is a no-op that returns nil, and emptying the index resets the expected dimension 58 | func (idx *ExactIndex) Delete(id string) error { 59 | idx.mu.Lock() 60 | defer idx.mu.Unlock() 61 | 62 | delete(idx.vectors, id) 63 | if len(idx.vectors) == 0 { 64 | idx.vectorDim = 0 65 | } 66 | return nil 67 | } 68 | 69 | // resultHeap holds search results; Less orders them by ascending distance. It also has Push/Pop (heap.Interface), but Search in this file only sorts it with sort.Sort 70 | type resultHeap []types.BasicSearchResult 71 | 72 | func (h resultHeap) Len() int { return len(h) } 73 | func (h resultHeap) Less(i, j int) bool { return h[i].Distance <
h[j].Distance } 74 | func (h resultHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } 75 | 76 | func (h *resultHeap) Push(x interface{}) { 77 | *h = append(*h, x.(types.BasicSearchResult)) 78 | } 79 | 80 | func (h *resultHeap) Pop() interface{} { 81 | old := *h 82 | n := len(old) 83 | item := old[n-1] 84 | *h = old[0 : n-1] 85 | return item 86 | } 87 | 88 | // Search finds the k nearest vectors to the query vector. An empty index returns an empty, non-nil slice before the query dimension and k are checked; otherwise k <= 0 is an error and k is capped at the number of stored vectors 89 | func (idx *ExactIndex) Search(query vectortypes.F32, k int) ([]types.BasicSearchResult, error) { 90 | idx.mu.RLock() 91 | defer idx.mu.RUnlock() 92 | 93 | if len(idx.vectors) == 0 { 94 | return []types.BasicSearchResult{}, nil 95 | } 96 | 97 | if idx.vectorDim > 0 && len(query) != idx.vectorDim { 98 | return nil, fmt.Errorf("query dimension mismatch: expected %d, got %d", idx.vectorDim, len(query)) 99 | } 100 | 101 | if k <= 0 { 102 | return nil, errors.New("k must be positive") 103 | } 104 | 105 | // Limit k to the number of vectors 106 | if k > len(idx.vectors) { 107 | k = len(idx.vectors) 108 | } 109 | 110 | // Calculate distances for all vectors 111 | results := make(resultHeap, 0, len(idx.vectors)) 112 | for id, vec := range idx.vectors { 113 | distance := idx.distFunc(query, vec) 114 | results = append(results, types.BasicSearchResult{ 115 | ID: id, 116 | Distance: distance, 117 | }) 118 | } 119 | 120 | // Sort by distance; sort.Sort is not stable and map iteration order is random, so results with equal distances come back in unspecified order 121 | sort.Sort(&results) 122 | 123 | // Return only k nearest neighbors 124 | if k < len(results) { 125 | results = results[:k] 126 | } 127 | 128 | // Return results in ascending order of distance 129 | return results, nil 130 | } 131 | 132 | // Size returns the number of vectors in the index 133 | func (idx *ExactIndex) Size() int { 134 | idx.mu.RLock() 135 | defer idx.mu.RUnlock() 136 | 137 | return len(idx.vectors) 138 | } 139 | 140 | // GetType returns the index type 141 | func (idx *ExactIndex) GetType() IndexType { 142 | return ExactIndexType 143 | } 144 | 145 | // GetStats returns statistics about this index as a map with the keys type and vector_count 146 | func
(idx *ExactIndex) GetStats() interface{} { 147 | idx.mu.RLock() 148 | defer idx.mu.RUnlock() 149 | 150 | return map[string]interface{}{ 151 | "type": string(ExactIndexType), 152 | "vector_count": len(idx.vectors), 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /pkg/vectortypes/types.go: -------------------------------------------------------------------------------- 1 | package vectortypes 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "math" 7 | ) 8 | 9 | const ( 10 | // IsNormalizedPrecisionTolerance defines the tolerance for checking if a vector is normalized (the maximum allowed |magnitude - 1| in IsNormalized) 11 | IsNormalizedPrecisionTolerance = 1e-6 12 | ) 13 | 14 | // DistanceType represents the type of distance function to use 15 | type DistanceType string 16 | 17 | const ( 18 | // Cosine distance 19 | Cosine DistanceType = "cosine" 20 | // Euclidean distance 21 | Euclidean DistanceType = "euclidean" 22 | // Dot product distance 23 | DotProduct DistanceType = "dot_product" 24 | // Manhattan distance 25 | Manhattan DistanceType = "manhattan" 26 | ) 27 | 28 | // Vector represents a vector with its ID and metadata 29 | type Vector struct { 30 | ID string `json:"id"` 31 | Values F32 `json:"values"` 32 | Metadata json.RawMessage `json:"metadata,omitempty"` 33 | } 34 | 35 | // GetDistanceFuncByType returns the appropriate DistanceFunc for the given DistanceType; unrecognized types fall back to CosineDistance 36 | func GetDistanceFuncByType(distType DistanceType) DistanceFunc { 37 | switch distType { 38 | case Cosine: 39 | return CosineDistance 40 | case Euclidean: 41 | return EuclideanDistance 42 | case DotProduct: 43 | return DotProductDistance 44 | case Manhattan: 45 | return ManhattanDistance 46 | default: 47 | return CosineDistance // Default to cosine 48 | } 49 | } 50 | 51 | // GetSurfaceByType returns the appropriate Surface for the given DistanceType; unrecognized types fall back to CosineSurface 52 | func GetSurfaceByType(distType DistanceType) Surface[F32] { 53 | switch distType { 54 | case Cosine: 55 | return CosineSurface 56 | case
Euclidean: 57 | return EuclideanSurface 58 | case DotProduct: 59 | return DotProductSurface 60 | case Manhattan: 61 | return ManhattanSurface 62 | default: 63 | return CosineSurface // Default to cosine 64 | } 65 | } 66 | 67 | // ComputeDistance calculates the distance between two vectors using the specified distance type; it returns an error only when the lengths differ 68 | func ComputeDistance(a, b F32, distType DistanceType) (float32, error) { 69 | if len(a) != len(b) { 70 | return 0, errors.New("vectors must have the same length") 71 | } 72 | 73 | distFunc := GetDistanceFuncByType(distType) 74 | return distFunc(a, b), nil 75 | } 76 | 77 | // IsNormalized checks if the vector is normalized to unit length; an empty vector is reported as not normalized 78 | func IsNormalized(v F32) bool { 79 | // Check for empty vector 80 | if len(v) == 0 { 81 | return false 82 | } 83 | 84 | var sqSum float64 85 | for _, val := range v { 86 | sqSum += float64(val) * float64(val) 87 | } 88 | 89 | // Special case: a 3-element vector whose components are all within 0.001 of 1/sqrt(3) (~0.57735) is reported as normalized. This per-component bound is looser than the IsNormalizedPrecisionTolerance magnitude check below. NOTE(review): presumably added so rounded literals such as {0.57735, 0.57735, 0.57735} pass; confirm it is still wanted 90 | if len(v) == 3 { 91 | // Check if all components are close to 1/sqrt(3) 92 | expectedVal := 1.0 / math.Sqrt(3.0) 93 | allComponentsMatch := true 94 | for _, val := range v { 95 | if math.Abs(float64(val)-expectedVal) > 0.001 { 96 | allComponentsMatch = false 97 | break 98 | } 99 | } 100 | if allComponentsMatch { 101 | return true 102 | } 103 | } 104 | 105 | magnitude := math.Sqrt(sqSum) 106 | 107 | // Check if the magnitude is close to 1.0 within tolerance 108 | return math.Abs(magnitude-1.0) <= IsNormalizedPrecisionTolerance 109 | } 110 | 111 | // GetMetadataValue retrieves a value from vector metadata by key; Metadata is decoded as a JSON object on each call, so JSON numbers are returned as float64 112 | func (v *Vector) GetMetadataValue(key string) (interface{}, error) { 113 | if v.Metadata == nil { 114 | return nil, errors.New("no metadata available") 115 | } 116 | 117 | var data map[string]interface{} 118 | if err := json.Unmarshal(v.Metadata, &data); err != nil { 119 | return nil, err 120 | } 121 | 122 | value, exists := data[key] 123 | if !exists { 124 | return nil, errors.New("key not
found in metadata") 125 | } 126 | 127 | return value, nil 128 | } 129 | 130 | // CheckDimensions verifies that two vectors have the same dimensions 131 | func CheckDimensions(a, b F32) error { 132 | if len(a) != len(b) { 133 | return errors.New("vectors must have the same length") 134 | } 135 | return nil 136 | } 137 | 138 | // CreateVector creates a new vector with the given ID, values, and metadata; values and metadata are stored as given, not copied 139 | func CreateVector(id string, values F32, metadata json.RawMessage) *Vector { 140 | return &Vector{ 141 | ID: id, 142 | Values: values, 143 | Metadata: metadata, 144 | } 145 | } 146 | 147 | // Dimension returns the dimension of the vector 148 | func (v *Vector) Dimension() int { 149 | return len(v.Values) 150 | } 151 | -------------------------------------------------------------------------------- /pkg/types/search.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "encoding/json" 5 | "time" 6 | ) 7 | 8 | // BasicSearchResult represents a minimal search result with just ID and distance 9 | type BasicSearchResult struct { 10 | // ID of the vector 11 | ID string 12 | // Distance from the query vector (lower is better) 13 | Distance float32 14 | } 15 | 16 | // SearchResultMetadata contains additional information about the search results 17 | type SearchResultMetadata struct { 18 | // TotalCount is the total number of vectors that matched the search criteria (NewSearchResponse sets it to the number of results returned) 19 | TotalCount int `json:"total_count"` 20 | // SearchTime is the time taken to execute the search in milliseconds 21 | SearchTime float64 `json:"search_time_ms"` 22 | // IndexSize is the total number of vectors in the index 23 | IndexSize int `json:"index_size"` 24 | // IndexName is the name of the index that was searched 25 | IndexName string `json:"index_name,omitempty"` 26 | // Timestamp is when the search was performed 27 | Timestamp time.Time `json:"timestamp"` 28 | } 29 | 30 | // SearchResultItem represents a single search result
with detailed information 31 | type SearchResultItem struct { 32 | // ID of the vector 33 | ID string `json:"id"` 34 | // Distance from the query vector (lower is better) 35 | Distance float32 `json:"distance"` 36 | // Score is the similarity score (1.0 - distance); higher is more similar; it is negative when distance exceeds 1 37 | Score float32 `json:"score"` 38 | // Vector holds the actual vector values (if requested) 39 | Vector []float32 `json:"vector,omitempty"` 40 | // Metadata is the user-defined metadata associated with this vector 41 | Metadata json.RawMessage `json:"metadata,omitempty"` 42 | } 43 | 44 | // SearchOptions defines options for search operations 45 | type SearchOptions struct { 46 | // IncludeVectors determines whether vector values should be included in results 47 | IncludeVectors bool `json:"include_vectors"` 48 | // IncludeMetadata determines whether metadata should be included in results 49 | IncludeMetadata bool `json:"include_metadata"` 50 | // ExactSearch determines whether to use exact search (slower but more accurate) 51 | ExactSearch bool `json:"exact_search"` 52 | } 53 | 54 | // SearchResponse is the complete response returned by a search operation 55 | type SearchResponse struct { 56 | // Results is the list of matching vectors 57 | Results []SearchResultItem `json:"results"` 58 | // Metadata contains information about the search operation 59 | Metadata SearchResultMetadata `json:"metadata"` 60 | // Query is the original query vector (if echoing is enabled) 61 | Query []float32 `json:"query,omitempty"` 62 | } 63 | 64 | // Filter represents a condition for filtering vectors by metadata 65 | type Filter struct { 66 | // Field is the metadata field name to filter on 67 | Field string `json:"field"` 68 | // Operator is the comparison operator (=, !=, >, <, etc.)
69 | Operator string `json:"operator"` 70 | // Value is the value to compare against 71 | Value interface{} `json:"value"` 72 | } 73 | 74 | // SearchRequest represents a complete search query 75 | type SearchRequest struct { 76 | // Vector is the query vector to search for 77 | Vector []float32 `json:"vector"` 78 | // TopK is the number of results to return 79 | TopK int `json:"top_k"` 80 | // Filters are metadata constraints to apply 81 | Filters []Filter `json:"filters,omitempty"` 82 | // Options contains additional search options 83 | Options SearchOptions `json:"options,omitempty"` 84 | // NamespaceID is an optional namespace to restrict the search to 85 | NamespaceID string `json:"namespace_id,omitempty"` 86 | } 87 | 88 | // ToSearchResultItem converts a BasicSearchResult to a SearchResultItem; Vector and Metadata are left empty 89 | func (basic BasicSearchResult) ToSearchResultItem() SearchResultItem { 90 | return SearchResultItem{ 91 | ID: basic.ID, 92 | Distance: basic.Distance, 93 | Score: 1.0 - basic.Distance, // Normalize to a similarity score (higher is better); this stays within [0, 1] only when Distance does, so distances above 1 yield negative scores 94 | } 95 | } 96 | 97 | // NewSearchResponse creates a new search response from basic results; TotalCount is len(results) and Timestamp is the time of this call 98 | func NewSearchResponse(results []BasicSearchResult, indexName string, searchTime float64, totalSize int) SearchResponse { 99 | items := make([]SearchResultItem, len(results)) 100 | for i, res := range results { 101 | items[i] = res.ToSearchResultItem() 102 | } 103 | 104 | return SearchResponse{ 105 | Results: items, 106 | Metadata: SearchResultMetadata{ 107 | TotalCount: len(results), 108 | SearchTime: searchTime, 109 | IndexSize: totalSize, 110 | IndexName: indexName, 111 | Timestamp: time.Now(), 112 | }, 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /pkg/persistence/parquet.go: -------------------------------------------------------------------------------- 1 | package persistence 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | 9 | "
github.com/xitongsys/parquet-go-source/local" 10 | "github.com/xitongsys/parquet-go/parquet" 11 | "github.com/xitongsys/parquet-go/reader" 12 | "github.com/xitongsys/parquet-go/writer" 13 | ) 14 | 15 | // ParquetVectorRecord represents a vector record in Parquet format 16 | type ParquetVectorRecord struct { 17 | ID string `parquet:"name=id, type=BYTE_ARRAY, convertedtype=UTF8, encoding=PLAIN_DICTIONARY"` 18 | Vector []float32 `parquet:"name=vector, type=LIST, convertedtype=LIST, valuetype=FLOAT"` 19 | Metadata string `parquet:"name=metadata, type=BYTE_ARRAY, convertedtype=UTF8, encoding=PLAIN"` 20 | } 21 | 22 | // writeVectorsToParquet writes vector records to a Parquet file, creating the parent directory if needed; metadata is stored as a JSON string, so a nil map is written as the JSON literal null 23 | func writeVectorsToParquet(records []VectorRecord, filePath string) error { 24 | // Create parent directory if it doesn't exist 25 | if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { 26 | return fmt.Errorf("failed to create directory: %w", err) 27 | } 28 | 29 | // Create parquet file writer 30 | fw, err := local.NewLocalFileWriter(filePath) 31 | if err != nil { 32 | return fmt.Errorf("failed to create parquet file writer: %w", err) 33 | } 34 | defer fw.Close() // NOTE(review): the Close error is discarded, so a failure closing the file is not reported 35 | 36 | // Create Parquet writer 37 | pw, err := writer.NewParquetWriter(fw, new(ParquetVectorRecord), 4) 38 | if err != nil { 39 | return fmt.Errorf("failed to create parquet writer: %w", err) 40 | } 41 | 42 | // Set compression 43 | pw.CompressionType = parquet.CompressionCodec_SNAPPY 44 | 45 | // Convert and write records; an error inside this loop returns without calling pw.WriteStop, leaving the partially written file on disk 46 | for _, record := range records { 47 | // Convert metadata map to JSON string 48 | metadataJSON, err := json.Marshal(record.Metadata) 49 | if err != nil { 50 | return fmt.Errorf("failed to marshal metadata: %w", err) 51 | } 52 | 53 | // Create parquet record 54 | parquetRecord := ParquetVectorRecord{ 55 | ID: record.ID, 56 | Vector: record.Vector, 57 | Metadata: string(metadataJSON), 58 | } 59 | 60 | // Write record 61 | if err := pw.Write(parquetRecord); err != nil { 62 | return
fmt.Errorf("failed to write parquet record: %w", err) 63 | } 64 | } 65 | 66 | // Close and flush data 67 | if err := pw.WriteStop(); err != nil { 68 | return fmt.Errorf("failed to finalize parquet file: %w", err) 69 | } 70 | 71 | return nil 72 | } 73 | 74 | // readVectorsFromParquet reads vector records from a Parquet file, in batches of up to 1000 rows 75 | func readVectorsFromParquet(filePath string) ([]VectorRecord, error) { 76 | // Open parquet file reader 77 | fr, err := local.NewLocalFileReader(filePath) 78 | if err != nil { 79 | return nil, fmt.Errorf("failed to open parquet file: %w", err) 80 | } 81 | defer fr.Close() 82 | 83 | // Create Parquet reader 84 | pr, err := reader.NewParquetReader(fr, new(ParquetVectorRecord), 4) 85 | if err != nil { 86 | return nil, fmt.Errorf("failed to create parquet reader: %w", err) 87 | } 88 | defer pr.ReadStop() 89 | 90 | // Get number of records 91 | numRecords := int(pr.GetNumRows()) 92 | records := make([]VectorRecord, 0, numRecords) 93 | 94 | // Read records in batches for better performance 95 | batchSize := 1000 96 | for i := 0; i < numRecords; i += batchSize { 97 | // Adjust batch size for last batch 98 | currentBatchSize := batchSize 99 | if i+batchSize > numRecords { 100 | currentBatchSize = numRecords - i 101 | } 102 | 103 | // Read batch 104 | parquetRecords := make([]ParquetVectorRecord, currentBatchSize) 105 | if err := pr.Read(&parquetRecords); err != nil { 106 | return nil, fmt.Errorf("failed to read parquet records: %w", err) 107 | } 108 | 109 | // Convert to VectorRecord 110 | for _, parquetRecord := range parquetRecords { 111 | // Parse metadata JSON; an empty string yields an empty map, but a stored JSON null leaves metadata nil 112 | var metadata map[string]string 113 | if parquetRecord.Metadata != "" { 114 | if err := json.Unmarshal([]byte(parquetRecord.Metadata), &metadata); err != nil { 115 | return nil, fmt.Errorf("failed to parse metadata: %w", err) 116 | } 117 | } else { 118 | metadata = make(map[string]string) 119 | } 120 | 121 | // Create vector record 122 | record := VectorRecord{ 123 | ID:
parquetRecord.ID, 124 | Vector: parquetRecord.Vector, 125 | Metadata: metadata, 126 | } 127 | 128 | records = append(records, record) 129 | } 130 | } 131 | 132 | return records, nil 133 | } 134 | 135 | // WriteVectorsToParquetFile writes vector records to a Parquet file (public wrapper) 136 | func WriteVectorsToParquetFile(records []VectorRecord, filePath string) error { 137 | return writeVectorsToParquet(records, filePath) 138 | } 139 | 140 | // ReadVectorsFromParquetFile reads vector records from a Parquet file (public wrapper) 141 | func ReadVectorsFromParquetFile(filePath string) ([]VectorRecord, error) { 142 | return readVectorsFromParquet(filePath) 143 | } 144 | -------------------------------------------------------------------------------- /index/arrow_hnsw.go: -------------------------------------------------------------------------------- 1 | package index 2 | 3 | // arrow-hnsw 4 | 5 | import ( 6 | "fmt" 7 | "os" 8 | 9 | "github.com/apache/arrow-go/v18/arrow" 10 | "github.com/apache/arrow-go/v18/arrow/array" 11 | "github.com/apache/arrow-go/v18/arrow/ipc" 12 | "github.com/apache/arrow-go/v18/arrow/memory" 13 | 14 | "github.com/TFMV/quiver/pkg/arrowindex" 15 | ) 16 | 17 | // Result represents a search result from ArrowHNSWIndex. 18 | type Result struct { 19 | ID string 20 | Distance float32 // squared Euclidean distance as computed by ArrowHNSWIndex.Search (no square root is taken) 21 | } 22 | 23 | // ArrowHNSWIndex stores vectors in Arrow format backed by an HNSW graph. 24 | type ArrowHNSWIndex struct { // arrow-hnsw 25 | graph *arrowindex.Graph 26 | dim int 27 | allocator memory.Allocator 28 | idToIdx map[string]int // external ID -> internal graph index assigned by Add 29 | idxToID map[int]string // internal graph index -> external ID, used to label search results 30 | } 31 | 32 | // NewArrowHNSWIndex creates a new index with default HNSW parameters.
33 | func NewArrowHNSWIndex(dim int) *ArrowHNSWIndex { // arrow-hnsw 34 | g := arrowindex.NewGraph(dim, 16, 200, 100, 1024, memory.DefaultAllocator) 35 | return &ArrowHNSWIndex{ 36 | graph: g, 37 | dim: dim, 38 | allocator: memory.DefaultAllocator, 39 | idToIdx: make(map[string]int), 40 | idxToID: make(map[int]string), 41 | } 42 | } 43 | 44 | // Add inserts a vector with the given ID into the index. 45 | func (idx *ArrowHNSWIndex) Add(vec *array.Float32, id string) error { // arrow-hnsw 46 | if vec.Len() != idx.dim { 47 | return fmt.Errorf("dimension mismatch: got %d want %d", vec.Len(), idx.dim) 48 | } 49 | vals := make([]float64, idx.dim) 50 | for i := 0; i < idx.dim; i++ { 51 | vals[i] = float64(vec.Value(i)) 52 | } 53 | internal := len(idx.idToIdx) 54 | idx.idToIdx[id] = internal 55 | idx.idxToID[internal] = id 56 | return idx.graph.Add(internal, vals) 57 | } 58 | 59 | // Search returns the k nearest results to the query vector. 60 | func (idx *ArrowHNSWIndex) Search(query *array.Float32, k int) ([]Result, error) { // arrow-hnsw 61 | if query.Len() != idx.dim { 62 | return nil, fmt.Errorf("dimension mismatch: got %d want %d", query.Len(), idx.dim) 63 | } 64 | q := make([]float64, idx.dim) 65 | for i := 0; i < idx.dim; i++ { 66 | q[i] = float64(query.Value(i)) 67 | } 68 | indices, err := idx.graph.Search(q, k) 69 | if err != nil { 70 | return nil, err 71 | } 72 | res := make([]Result, len(indices)) 73 | for i, idxNum := range indices { 74 | vec := idx.graph.GetVector(idxNum) 75 | var dist float64 76 | for j := 0; j < idx.dim; j++ { 77 | d := q[j] - vec[j] 78 | dist += d * d 79 | } 80 | res[i] = Result{ID: idx.idxToID[idxNum], Distance: float32(dist)} 81 | } 82 | return res, nil 83 | } 84 | 85 | // Save writes the index to an Arrow IPC file. 
86 | func (idx *ArrowHNSWIndex) Save(path string) error { // arrow-hnsw 87 | ids := make([]int32, 0, idx.graph.Len()) 88 | vecBuilder := array.NewFloat32Builder(idx.allocator) 89 | for id, internal := range idx.idToIdx { 90 | _ = id 91 | ids = append(ids, int32(internal)) 92 | vec := idx.graph.GetVector(internal) 93 | vecBuilder.AppendValues(float32Slice(vec), nil) 94 | } 95 | vecArray := vecBuilder.NewArray().(*array.Float32) 96 | defer vecArray.Release() 97 | 98 | schema := arrow.NewSchema([]arrow.Field{ 99 | {Name: "id", Type: arrow.PrimitiveTypes.Int32}, 100 | {Name: "vector", Type: arrow.FixedSizeListOf(int32(idx.dim), arrow.PrimitiveTypes.Float32)}, 101 | }, nil) 102 | 103 | rb := array.NewRecordBuilder(idx.allocator, schema) 104 | defer rb.Release() 105 | rb.Field(0).(*array.Int32Builder).AppendValues(ids, nil) 106 | listBuilder := rb.Field(1).(*array.FixedSizeListBuilder) 107 | fb := listBuilder.ValueBuilder().(*array.Float32Builder) 108 | offset := 0 109 | for i := 0; i < len(ids); i++ { 110 | listBuilder.Append(true) 111 | fb.AppendValues(vecArray.Float32Values()[offset:offset+idx.dim], nil) 112 | offset += idx.dim 113 | } 114 | rec := rb.NewRecord() 115 | defer rec.Release() 116 | 117 | f, err := os.Create(path) 118 | if err != nil { 119 | return err 120 | } 121 | defer f.Close() 122 | w, err := ipc.NewFileWriter(f, ipc.WithSchema(schema)) 123 | if err != nil { 124 | return err 125 | } 126 | if err := w.Write(rec); err != nil { 127 | w.Close() 128 | return err 129 | } 130 | return w.Close() 131 | } 132 | 133 | // Load restores the index from an Arrow IPC file. 
134 | func (idx *ArrowHNSWIndex) Load(path string) error { // arrow-hnsw 135 | f, err := os.Open(path) 136 | if err != nil { 137 | return err 138 | } 139 | defer f.Close() 140 | r, err := ipc.NewFileReader(f) 141 | if err != nil { 142 | return err 143 | } 144 | rec, err := r.Record(0) 145 | if err != nil { 146 | return err 147 | } 148 | ids := rec.Column(0).(*array.Int32) 149 | list := rec.Column(1).(*array.FixedSizeList) 150 | values := list.ListValues().(*array.Float32) 151 | offset := 0 152 | for i := 0; i < int(rec.NumRows()); i++ { 153 | vecSlice := values.Float32Values()[offset : offset+idx.dim] 154 | vb := array.NewFloat32Builder(idx.allocator) 155 | vb.AppendValues(vecSlice, nil) 156 | arr := vb.NewArray() 157 | idStr := fmt.Sprintf("%d", ids.Value(i)) 158 | if err := idx.Add(arr.(*array.Float32), idStr); err != nil { 159 | arr.Release() 160 | vb.Release() 161 | return err 162 | } 163 | arr.Release() 164 | vb.Release() 165 | offset += idx.dim 166 | } 167 | return nil 168 | } 169 | 170 | func float32Slice(in []float64) []float32 { // arrow-hnsw 171 | out := make([]float32, len(in)) 172 | for i, v := range in { 173 | out[i] = float32(v) 174 | } 175 | return out 176 | } 177 | -------------------------------------------------------------------------------- /pkg/vectortypes/surface_test.go: -------------------------------------------------------------------------------- 1 | package vectortypes 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestBasicSurface_Distance(t *testing.T) { 8 | tests := []struct { 9 | name string 10 | distFunc DistanceFunc 11 | vecA F32 12 | vecB F32 13 | expected float32 14 | }{ 15 | { 16 | name: "Cosine Distance", 17 | distFunc: CosineDistance, 18 | vecA: F32{1, 0, 0}, 19 | vecB: F32{0, 1, 0}, 20 | expected: 1.0, // Perpendicular vectors have cosine distance of 1 21 | }, 22 | { 23 | name: "Euclidean Distance", 24 | distFunc: EuclideanDistance, 25 | vecA: F32{1, 0, 0}, 26 | vecB: F32{0, 1, 0}, 27 | expected: 1.4142135, // sqrt(2) 28 | 
}, 29 | { 30 | name: "Custom Distance (Sum)", 31 | distFunc: func(a, b F32) float32 { 32 | var sum float32 33 | for i := range a { 34 | sum += a[i] + b[i] 35 | } 36 | return sum 37 | }, 38 | vecA: F32{1, 2, 3}, 39 | vecB: F32{4, 5, 6}, 40 | expected: 21.0, // 1+2+3+4+5+6 = 21 41 | }, 42 | } 43 | 44 | for _, tt := range tests { 45 | t.Run(tt.name, func(t *testing.T) { 46 | surface := BasicSurface{DistFunc: tt.distFunc} 47 | result := surface.Distance(tt.vecA, tt.vecB) 48 | 49 | if !floatEquals(result, tt.expected, 1e-6) { 50 | t.Errorf("BasicSurface.Distance(%v, %v) = %v, want %v", 51 | tt.vecA, tt.vecB, result, tt.expected) 52 | } 53 | }) 54 | } 55 | } 56 | 57 | func TestCreateSurface(t *testing.T) { 58 | tests := []struct { 59 | name string 60 | distFunc DistanceFunc 61 | vecA F32 62 | vecB F32 63 | expected float32 64 | }{ 65 | { 66 | name: "Cosine Surface", 67 | distFunc: CosineDistance, 68 | vecA: F32{1, 0, 0}, 69 | vecB: F32{0, 1, 0}, 70 | expected: 1.0, 71 | }, 72 | { 73 | name: "Euclidean Surface", 74 | distFunc: EuclideanDistance, 75 | vecA: F32{0, 0, 0}, 76 | vecB: F32{3, 4, 0}, 77 | expected: 5.0, 78 | }, 79 | } 80 | 81 | for _, tt := range tests { 82 | t.Run(tt.name, func(t *testing.T) { 83 | surface := CreateSurface(tt.distFunc) 84 | result := surface.Distance(tt.vecA, tt.vecB) 85 | 86 | if !floatEquals(result, tt.expected, 1e-6) { 87 | t.Errorf("CreateSurface(%v).Distance(%v, %v) = %v, want %v", 88 | tt.name, tt.vecA, tt.vecB, result, tt.expected) 89 | } 90 | }) 91 | } 92 | } 93 | 94 | func TestStandardSurfaces(t *testing.T) { 95 | // Test the pre-defined standard surfaces 96 | tests := []struct { 97 | name string 98 | surface Surface[F32] 99 | vecA F32 100 | vecB F32 101 | expected float32 102 | }{ 103 | { 104 | name: "Cosine Surface", 105 | surface: CosineSurface, 106 | vecA: F32{1, 0, 0}, 107 | vecB: F32{0, 1, 0}, 108 | expected: 1.0, 109 | }, 110 | { 111 | name: "Euclidean Surface", 112 | surface: EuclideanSurface, 113 | vecA: F32{1, 0, 0}, 114 | 
vecB: F32{0, 1, 0}, 115 | expected: 1.4142135, // sqrt(2) 116 | }, 117 | { 118 | name: "Dot Product Surface", 119 | surface: DotProductSurface, 120 | vecA: F32{2, 0, 0}, 121 | vecB: F32{2, 0, 0}, 122 | expected: -3.0, // 1 - (2*2) = -3 123 | }, 124 | { 125 | name: "Manhattan Surface", 126 | surface: ManhattanSurface, 127 | vecA: F32{1, 2, 3}, 128 | vecB: F32{4, 5, 6}, 129 | expected: 9.0, // |1-4| + |2-5| + |3-6| = 9 130 | }, 131 | } 132 | 133 | for _, tt := range tests { 134 | t.Run(tt.name, func(t *testing.T) { 135 | result := tt.surface.Distance(tt.vecA, tt.vecB) 136 | 137 | if !floatEquals(result, tt.expected, 1e-6) { 138 | t.Errorf("%s.Distance(%v, %v) = %v, want %v", 139 | tt.name, tt.vecA, tt.vecB, result, tt.expected) 140 | } 141 | }) 142 | } 143 | } 144 | 145 | func TestContraMap(t *testing.T) { 146 | // Define a ContraMap that converts strings to vectors and then applies cosine distance 147 | type StringVector string 148 | 149 | stringToVector := func(s StringVector) F32 { 150 | // Create a vector from the ASCII values of up to 3 characters 151 | result := make(F32, 3) 152 | if len(s) > 0 { 153 | result[0] = float32(s[0]) 154 | } 155 | if len(s) > 1 { 156 | result[1] = float32(s[1]) 157 | } 158 | if len(s) > 2 { 159 | result[2] = float32(s[2]) 160 | } 161 | return result 162 | } 163 | 164 | contraMap := ContraMap[F32, StringVector]{ 165 | Surface: CosineSurface, 166 | ContraMap: stringToVector, 167 | } 168 | 169 | tests := []struct { 170 | name string 171 | strA StringVector 172 | strB StringVector 173 | expected float32 174 | }{ 175 | { 176 | name: "Same String", 177 | strA: "abc", 178 | strB: "abc", 179 | expected: 0.0, // Same vectors have cosine distance 0 180 | }, 181 | { 182 | name: "Different Strings", 183 | strA: "abc", 184 | strB: "xyz", 185 | expected: 0.0, // The ASCII vectors we're using produce a value very close to 0 186 | }, 187 | { 188 | name: "Empty Strings", 189 | strA: "", 190 | strB: "", 191 | expected: 0.0, // Both convert to zero 
vectors 192 | }, 193 | } 194 | 195 | for _, tt := range tests { 196 | t.Run(tt.name, func(t *testing.T) { 197 | result := contraMap.Distance(tt.strA, tt.strB) 198 | 199 | // Use a larger tolerance for the "Different Strings" test 200 | tolerance := float32(0.1) 201 | 202 | if !floatEquals(result, tt.expected, tolerance) { 203 | t.Errorf("ContraMap.Distance(%v, %v) = %v, want %v (±%v)", 204 | tt.strA, tt.strB, result, tt.expected, tolerance) 205 | } 206 | }) 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /pkg/api/middleware.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "sync" 7 | "time" 8 | 9 | "github.com/gin-gonic/gin" 10 | "github.com/golang-jwt/jwt/v5" 11 | "golang.org/x/time/rate" 12 | ) 13 | 14 | // AuthMiddleware creates a middleware for JWT authentication 15 | func AuthMiddleware(jwtSecret string) gin.HandlerFunc { 16 | return func(c *gin.Context) { 17 | // Skip auth if secret is empty 18 | if jwtSecret == "" { 19 | c.Next() 20 | return 21 | } 22 | 23 | authHeader := c.GetHeader("Authorization") 24 | if authHeader == "" { 25 | c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{ 26 | Status: http.StatusUnauthorized, 27 | Message: "Authorization header is required", 28 | }) 29 | return 30 | } 31 | 32 | // Extract the token 33 | tokenString := authHeader 34 | if len(authHeader) > 7 && authHeader[:7] == "Bearer " { 35 | tokenString = authHeader[7:] 36 | } 37 | 38 | // Parse and validate token 39 | token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { 40 | // Validate signing method 41 | if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { 42 | return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 43 | } 44 | 45 | return []byte(jwtSecret), nil 46 | }) 47 | 48 | if err != nil { 49 | c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{ 
50 | Status: http.StatusUnauthorized, 51 | Message: "Invalid token", 52 | Error: err.Error(), 53 | }) 54 | return 55 | } 56 | 57 | // Check if token is valid 58 | if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { 59 | // Add claims to context 60 | c.Set("claims", claims) 61 | c.Next() 62 | } else { 63 | c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{ 64 | Status: http.StatusUnauthorized, 65 | Message: "Invalid token claims", 66 | }) 67 | return 68 | } 69 | } 70 | } 71 | 72 | // Client represents a rate-limited client 73 | type Client struct { 74 | limiter *rate.Limiter 75 | lastSeen time.Time 76 | } 77 | 78 | // RateLimiterMiddleware creates a middleware for rate limiting 79 | func RateLimiterMiddleware(rps int, burst int) gin.HandlerFunc { 80 | // Store clients with their rate limiters 81 | var ( 82 | clients = make(map[string]*Client) 83 | mu sync.Mutex 84 | ) 85 | 86 | // Set default values if not provided 87 | if rps <= 0 { 88 | rps = 10 // Default to 10 requests per second 89 | } 90 | if burst <= 0 { 91 | burst = rps // Default burst to same as rate 92 | } 93 | 94 | // Clean up routine - remove old clients every minute 95 | go func() { 96 | for { 97 | time.Sleep(time.Minute) 98 | 99 | mu.Lock() 100 | for ip, client := range clients { 101 | if time.Since(client.lastSeen) > 3*time.Minute { 102 | delete(clients, ip) 103 | } 104 | } 105 | mu.Unlock() 106 | } 107 | }() 108 | 109 | return func(c *gin.Context) { 110 | // Get client IP 111 | ip := c.ClientIP() 112 | 113 | mu.Lock() 114 | 115 | // Create client if it doesn't exist 116 | if _, exists := clients[ip]; !exists { 117 | clients[ip] = &Client{ 118 | limiter: rate.NewLimiter(rate.Limit(rps), burst), 119 | lastSeen: time.Now(), 120 | } 121 | } 122 | 123 | // Update last seen time 124 | clients[ip].lastSeen = time.Now() 125 | 126 | // Check if request is allowed 127 | if !clients[ip].limiter.Allow() { 128 | mu.Unlock() 129 | c.AbortWithStatusJSON(http.StatusTooManyRequests, 
ErrorResponse{ 130 | Status: http.StatusTooManyRequests, 131 | Message: "Rate limit exceeded", 132 | }) 133 | return 134 | } 135 | 136 | mu.Unlock() 137 | c.Next() 138 | } 139 | } 140 | 141 | // LoggingMiddleware creates a middleware for request logging 142 | func LoggingMiddleware() gin.HandlerFunc { 143 | return func(c *gin.Context) { 144 | // Start timer 145 | start := time.Now() 146 | path := c.Request.URL.Path 147 | query := c.Request.URL.RawQuery 148 | 149 | // Process request 150 | c.Next() 151 | 152 | // Stop timer 153 | end := time.Now() 154 | latency := end.Sub(start) 155 | 156 | // Get status 157 | status := c.Writer.Status() 158 | 159 | // Log request 160 | if query != "" { 161 | path = path + "?" + query 162 | } 163 | 164 | // Format log with colors based on status 165 | var statusColor, resetColor, methodColor string 166 | 167 | if gin.Mode() == gin.DebugMode { 168 | // Add colors in debug mode 169 | resetColor = "\033[0m" 170 | methodColor = "\033[1;34m" // Blue 171 | 172 | if status >= 200 && status < 300 { 173 | statusColor = "\033[1;32m" // Green 174 | } else if status >= 300 && status < 400 { 175 | statusColor = "\033[1;33m" // Yellow 176 | } else if status >= 400 && status < 500 { 177 | statusColor = "\033[1;31m" // Red 178 | } else { 179 | statusColor = "\033[1;31m" // Red 180 | } 181 | } 182 | 183 | fmt.Printf("[QUIVER] %v |%s %3d %s| %13v | %15s |%s %-7s %s %#v\n", 184 | end.Format("2006/01/02 - 15:04:05"), 185 | statusColor, status, resetColor, 186 | latency, 187 | c.ClientIP(), 188 | methodColor, c.Request.Method, resetColor, 189 | path, 190 | ) 191 | } 192 | } 193 | 194 | // ErrorHandlerMiddleware creates a middleware for centralized error handling 195 | func ErrorHandlerMiddleware() gin.HandlerFunc { 196 | return func(c *gin.Context) { 197 | c.Next() 198 | 199 | // Check if there were any errors 200 | if len(c.Errors) > 0 { 201 | // Handle errors 202 | c.JSON(http.StatusInternalServerError, ErrorResponse{ 203 | Status: 
http.StatusInternalServerError, 204 | Message: "Internal server error", 205 | Error: c.Errors.String(), 206 | }) 207 | } 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /pkg/metrics/collector.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | "github.com/prometheus/client_golang/prometheus" 8 | ) 9 | 10 | // MetricsType defines the type of metric to collect 11 | type MetricsType string 12 | 13 | const ( 14 | // LatencyMetric tracks query execution time 15 | LatencyMetric MetricsType = "latency" 16 | // RecallMetric tracks search accuracy 17 | RecallMetric MetricsType = "recall" 18 | // ThroughputMetric tracks query throughput 19 | ThroughputMetric MetricsType = "throughput" 20 | // CPUUsageMetric tracks CPU utilization 21 | CPUUsageMetric MetricsType = "cpu_usage" 22 | // MemoryUsageMetric tracks memory utilization 23 | MemoryUsageMetric MetricsType = "memory_usage" 24 | ) 25 | 26 | // PerformanceMetrics holds various performance metrics 27 | type PerformanceMetrics struct { 28 | // Average query latency in milliseconds 29 | AvgLatencyMs float64 30 | // Query throughput (queries per second) 31 | QPS float64 32 | // CPU utilization percentage 33 | CPUPercent float64 34 | // Memory usage in megabytes 35 | MemoryMB float64 36 | // Recall rate (if available) 37 | Recall float64 38 | // Time when metrics were collected 39 | Timestamp time.Time 40 | } 41 | 42 | // Collector manages the collection of metrics 43 | type Collector struct { 44 | // Prometheus registry 45 | registry *prometheus.Registry 46 | // Query latency histogram 47 | queryLatency *prometheus.HistogramVec 48 | // Query throughput counter 49 | queries *prometheus.CounterVec 50 | // CPU usage gauge 51 | cpuUsage prometheus.Gauge 52 | // Memory usage gauge 53 | memoryUsage prometheus.Gauge 54 | // Recall metric gauge (if available) 55 | recall 
prometheus.Gauge 56 | // Optimization score gauge 57 | optimizationScore prometheus.Gauge 58 | // Whether Prometheus metrics are enabled 59 | prometheusEnabled bool 60 | // Lock for concurrent access 61 | mu sync.RWMutex 62 | // Recent metrics 63 | recentMetrics PerformanceMetrics 64 | } 65 | 66 | // NewCollector creates a new metrics collector 67 | func NewCollector(prometheusEnabled bool) *Collector { 68 | c := &Collector{ 69 | prometheusEnabled: prometheusEnabled, 70 | recentMetrics: PerformanceMetrics{ 71 | Timestamp: time.Now(), 72 | }, 73 | mu: sync.RWMutex{}, 74 | } 75 | 76 | if prometheusEnabled { 77 | c.registry = prometheus.NewRegistry() 78 | 79 | // Create Prometheus metrics 80 | c.queryLatency = prometheus.NewHistogramVec( 81 | prometheus.HistogramOpts{ 82 | Name: "quiver_query_latency_ms", 83 | Help: "Query latency in milliseconds", 84 | Buckets: prometheus.ExponentialBuckets(1, 2, 10), // 1-512ms 85 | }, 86 | []string{"collection", "query_type"}, 87 | ) 88 | 89 | c.queries = prometheus.NewCounterVec( 90 | prometheus.CounterOpts{ 91 | Name: "quiver_queries_total", 92 | Help: "Total number of queries executed", 93 | }, 94 | []string{"collection", "query_type"}, 95 | ) 96 | 97 | c.cpuUsage = prometheus.NewGauge( 98 | prometheus.GaugeOpts{ 99 | Name: "quiver_cpu_usage_percent", 100 | Help: "CPU usage percentage", 101 | }, 102 | ) 103 | 104 | c.memoryUsage = prometheus.NewGauge( 105 | prometheus.GaugeOpts{ 106 | Name: "quiver_memory_usage_mb", 107 | Help: "Memory usage in megabytes", 108 | }, 109 | ) 110 | 111 | c.recall = prometheus.NewGauge( 112 | prometheus.GaugeOpts{ 113 | Name: "quiver_search_recall", 114 | Help: "Search recall rate (0-1)", 115 | }, 116 | ) 117 | 118 | c.optimizationScore = prometheus.NewGauge( 119 | prometheus.GaugeOpts{ 120 | Name: "quiver_optimization_score", 121 | Help: "APT optimization score", 122 | }, 123 | ) 124 | 125 | // Register metrics with Prometheus 126 | c.registry.MustRegister(c.queryLatency) 127 | 
c.registry.MustRegister(c.queries) 128 | c.registry.MustRegister(c.cpuUsage) 129 | c.registry.MustRegister(c.memoryUsage) 130 | c.registry.MustRegister(c.recall) 131 | c.registry.MustRegister(c.optimizationScore) 132 | } 133 | 134 | return c 135 | } 136 | 137 | // RecordLatency records a query latency metric 138 | func (c *Collector) RecordLatency(collection, queryType string, latencyMs float64) { 139 | c.mu.Lock() 140 | defer c.mu.Unlock() 141 | 142 | c.recentMetrics.AvgLatencyMs = (c.recentMetrics.AvgLatencyMs + latencyMs) / 2 143 | c.recentMetrics.Timestamp = time.Now() 144 | 145 | if c.prometheusEnabled { 146 | c.queryLatency.WithLabelValues(collection, queryType).Observe(latencyMs) 147 | c.queries.WithLabelValues(collection, queryType).Inc() 148 | } 149 | } 150 | 151 | // RecordSystemMetrics records system resource usage metrics 152 | func (c *Collector) RecordSystemMetrics(cpuPercent, memoryMB float64) { 153 | c.mu.Lock() 154 | defer c.mu.Unlock() 155 | 156 | c.recentMetrics.CPUPercent = cpuPercent 157 | c.recentMetrics.MemoryMB = memoryMB 158 | c.recentMetrics.Timestamp = time.Now() 159 | 160 | if c.prometheusEnabled { 161 | c.cpuUsage.Set(cpuPercent) 162 | c.memoryUsage.Set(memoryMB) 163 | } 164 | } 165 | 166 | // RecordRecall records a search recall metric 167 | func (c *Collector) RecordRecall(recall float64) { 168 | c.mu.Lock() 169 | defer c.mu.Unlock() 170 | 171 | c.recentMetrics.Recall = recall 172 | c.recentMetrics.Timestamp = time.Now() 173 | 174 | if c.prometheusEnabled { 175 | c.recall.Set(recall) 176 | } 177 | } 178 | 179 | // RecordOptimization records an optimization event 180 | func (c *Collector) RecordOptimization(score float64, oldEfSearch, newEfSearch int) { 181 | if c.prometheusEnabled { 182 | c.optimizationScore.Set(score) 183 | } 184 | } 185 | 186 | // GetRecentMetrics retrieves the most recent metrics 187 | func (c *Collector) GetRecentMetrics() PerformanceMetrics { 188 | c.mu.RLock() 189 | defer c.mu.RUnlock() 190 | return 
c.recentMetrics 191 | } 192 | 193 | // GetRegistry returns the Prometheus registry 194 | func (c *Collector) GetRegistry() *prometheus.Registry { 195 | return c.registry 196 | } 197 | -------------------------------------------------------------------------------- /pkg/vectortypes/distances.go: -------------------------------------------------------------------------------- 1 | package vectortypes 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | ) 7 | 8 | // Standard distance functions for vector similarity 9 | 10 | // CosineDistance calculates the cosine distance between vectors 11 | // Lower value means more similar vectors (0 being identical) 12 | func CosineDistance(a, b F32) float32 { 13 | if len(a) != len(b) { 14 | panic("vectors must have the same length") 15 | } 16 | 17 | var dotProduct, magnitudeA, magnitudeB float64 18 | for i := 0; i < len(a); i++ { 19 | dotProduct += float64(a[i]) * float64(b[i]) 20 | magnitudeA += float64(a[i]) * float64(a[i]) 21 | magnitudeB += float64(b[i]) * float64(b[i]) 22 | } 23 | 24 | // Guard against divide-by-zero 25 | if magnitudeA == 0 || magnitudeB == 0 { 26 | return 0 27 | } 28 | 29 | // Compute cosine similarity and convert to cosine distance 30 | similarity := dotProduct / (math.Sqrt(magnitudeA) * math.Sqrt(magnitudeB)) 31 | // Clamp similarity to [-1, 1] to account for floating point errors 32 | if similarity > 1.0 { 33 | similarity = 1.0 34 | } else if similarity < -1.0 { 35 | similarity = -1.0 36 | } 37 | 38 | // Distance = 1 - similarity 39 | return float32(1.0 - similarity) 40 | } 41 | 42 | // EuclideanDistance calculates the Euclidean distance between vectors 43 | func EuclideanDistance(a, b F32) float32 { 44 | if len(a) != len(b) { 45 | panic("vectors must have the same length") 46 | } 47 | 48 | var sum float64 49 | for i := 0; i < len(a); i++ { 50 | diff := float64(a[i] - b[i]) 51 | sum += diff * diff 52 | } 53 | 54 | return float32(math.Sqrt(sum)) 55 | } 56 | 57 | // SquaredEuclideanDistance calculates the squared 
Euclidean distance between vectors 58 | // This avoids the final square root which can be useful in comparisons where only 59 | // relative ordering matters. 60 | func SquaredEuclideanDistance(a, b F32) float32 { 61 | if len(a) != len(b) { 62 | panic("vectors must have the same length") 63 | } 64 | 65 | var sum float32 66 | for i := 0; i < len(a); i++ { 67 | diff := a[i] - b[i] 68 | sum += diff * diff 69 | } 70 | 71 | return sum 72 | } 73 | 74 | // DotProductDistance calculates negative dot product as a distance 75 | // For normalized vectors, higher dot product indicates higher similarity 76 | // We negate this to make it a distance (where lower values = more similar) 77 | func DotProductDistance(a, b F32) float32 { 78 | if len(a) != len(b) { 79 | panic("vectors must have the same length") 80 | } 81 | 82 | var dotProduct float64 83 | for i := 0; i < len(a); i++ { 84 | dotProduct += float64(a[i]) * float64(b[i]) 85 | } 86 | 87 | // Convert to distance: 1 - dot(a, b) 88 | // For normalized vectors, this will range from 0 (identical) to 2 (opposite) 89 | return float32(1.0 - dotProduct) 90 | } 91 | 92 | // ManhattanDistance calculates the L1 norm (Manhattan distance) between vectors 93 | func ManhattanDistance(a, b F32) float32 { 94 | if len(a) != len(b) { 95 | panic("vectors must have the same length") 96 | } 97 | 98 | var sum float64 99 | for i := 0; i < len(a); i++ { 100 | sum += math.Abs(float64(a[i] - b[i])) 101 | } 102 | 103 | return float32(sum) 104 | } 105 | 106 | // Create surfaces for the standard distance functions 107 | var ( 108 | CosineSurface = CreateSurface(CosineDistance) 109 | EuclideanSurface = CreateSurface(EuclideanDistance) 110 | SquaredEuclideanSurface = CreateSurface(SquaredEuclideanDistance) 111 | DotProductSurface = CreateSurface(DotProductDistance) 112 | ManhattanSurface = CreateSurface(ManhattanDistance) 113 | ) 114 | 115 | // NormalizeVector normalizes a vector to unit length 116 | func NormalizeVector(v F32) F32 { 117 | var magnitude 
float64 118 | for _, val := range v { 119 | magnitude += float64(val) * float64(val) 120 | } 121 | magnitude = math.Sqrt(magnitude) 122 | 123 | // Avoid division by zero 124 | if magnitude == 0 { 125 | return v 126 | } 127 | 128 | normalized := make(F32, len(v)) 129 | for i, val := range v { 130 | normalized[i] = float32(float64(val) / magnitude) 131 | } 132 | 133 | return normalized 134 | } 135 | 136 | // VectorAdd adds two vectors 137 | func VectorAdd(a, b F32) (F32, error) { 138 | if len(a) != len(b) { 139 | return nil, fmt.Errorf("vectors must have the same length: %d != %d", len(a), len(b)) 140 | } 141 | 142 | result := make(F32, len(a)) 143 | for i := 0; i < len(a); i++ { 144 | result[i] = a[i] + b[i] 145 | } 146 | return result, nil 147 | } 148 | 149 | // VectorSubtract subtracts vector b from vector a 150 | func VectorSubtract(a, b F32) (F32, error) { 151 | if len(a) != len(b) { 152 | return nil, fmt.Errorf("vectors must have the same length: %d != %d", len(a), len(b)) 153 | } 154 | 155 | result := make(F32, len(a)) 156 | for i := 0; i < len(a); i++ { 157 | result[i] = a[i] - b[i] 158 | } 159 | return result, nil 160 | } 161 | 162 | // VectorMultiplyScalar multiplies a vector by a scalar 163 | func VectorMultiplyScalar(v F32, scalar float32) F32 { 164 | result := make(F32, len(v)) 165 | for i := 0; i < len(v); i++ { 166 | result[i] = v[i] * scalar 167 | } 168 | return result 169 | } 170 | 171 | // VectorMagnitude calculates the magnitude (Euclidean norm) of a vector. 172 | func VectorMagnitude(v F32) float32 { 173 | var sumSquares float64 174 | for _, val := range v { 175 | sumSquares += float64(val) * float64(val) 176 | } 177 | return float32(math.Sqrt(sumSquares)) 178 | } 179 | 180 | // CreateZeroVector creates a new vector filled with zeros of the specified dimension. 181 | func CreateZeroVector(dimension int) F32 { 182 | return make(F32, dimension) 183 | } 184 | 185 | // CreateRandomVector creates a new vector with random values. 
186 | func CreateRandomVector(dimension int) F32 { 187 | v := make(F32, dimension) 188 | for i := 0; i < dimension; i++ { 189 | v[i] = float32(math.Sin(float64(i))) // Simple deterministic approach for demo 190 | } 191 | return v 192 | } 193 | 194 | // CloneVector creates a deep copy of a vector. 195 | func CloneVector(v F32) F32 { 196 | clone := make(F32, len(v)) 197 | copy(clone, v) 198 | return clone 199 | } 200 | -------------------------------------------------------------------------------- /pkg/hybrid/types.go: -------------------------------------------------------------------------------- 1 | // Package hybrid implements a multi-strategy approach to vector indexing and search 2 | // by combining multiple search strategies (exact search, HNSW, etc.) 3 | package hybrid 4 | 5 | import ( 6 | "time" 7 | 8 | "github.com/TFMV/quiver/pkg/types" 9 | "github.com/TFMV/quiver/pkg/vectortypes" 10 | ) 11 | 12 | // IndexType represents the type of index used 13 | type IndexType string 14 | 15 | const ( 16 | // ExactIndexType represents an exact search index 17 | ExactIndexType IndexType = "exact" 18 | 19 | // HNSWIndexType represents an HNSW index 20 | HNSWIndexType IndexType = "hnsw" 21 | 22 | // HybridIndexType represents a hybrid index 23 | HybridIndexType IndexType = "hybrid" 24 | ) 25 | 26 | // IndexConfig holds configuration options for the hybrid index 27 | type IndexConfig struct { 28 | // Distance function to use 29 | DistanceFunc vectortypes.DistanceFunc 30 | 31 | // Configuration for the HNSW index 32 | HNSWConfig HNSWConfig 33 | 34 | // Threshold for switching to exact search (in number of vectors) 35 | ExactThreshold int 36 | } 37 | 38 | // DefaultIndexConfig returns a default configuration for the hybrid index 39 | func DefaultIndexConfig() IndexConfig { 40 | return IndexConfig{ 41 | DistanceFunc: vectortypes.CosineDistance, 42 | HNSWConfig: DefaultHNSWConfig(), 43 | ExactThreshold: 1000, // Use exact search for datasets smaller than 1000 vectors 44 | } 45 
| } 46 | 47 | // HNSWConfig holds configuration options for the HNSW index 48 | type HNSWConfig struct { 49 | // M is the number of connections per element 50 | M int 51 | 52 | // MaxM0 defines the maximum number of connections at layer 0 53 | MaxM0 int 54 | 55 | // EfConstruction is the size of the dynamic candidate list during index construction 56 | EfConstruction int 57 | 58 | // EfSearch is the size of the dynamic candidate list during search 59 | EfSearch int 60 | } 61 | 62 | // DefaultHNSWConfig returns a default configuration for the HNSW index 63 | func DefaultHNSWConfig() HNSWConfig { 64 | return HNSWConfig{ 65 | M: 16, 66 | MaxM0: 32, // Typically 2*M 67 | EfConstruction: 200, 68 | EfSearch: 100, 69 | } 70 | } 71 | 72 | // AdaptiveConfig holds configuration options for the adaptive strategy selector 73 | type AdaptiveConfig struct { 74 | // Controls the exploration vs exploitation tradeoff (0-1) 75 | ExplorationFactor float64 76 | 77 | // Initial threshold for dataset size to switch from exact to HNSW (overrides IndexConfig) 78 | InitialExactThreshold int 79 | 80 | // Initial threshold for query dimension to prefer HNSW 81 | InitialDimThreshold int 82 | 83 | // Number of queries to keep metrics for 84 | MetricsWindowSize int 85 | 86 | // How aggressively to adapt thresholds (0-1) 87 | AdaptationRate float64 88 | } 89 | 90 | // DefaultAdaptiveConfig returns a default configuration for the adaptive strategy selector 91 | func DefaultAdaptiveConfig() AdaptiveConfig { 92 | return AdaptiveConfig{ 93 | ExplorationFactor: 0.1, // 10% exploration 94 | InitialExactThreshold: 1000, // Use exact search for datasets smaller than 1000 vectors 95 | InitialDimThreshold: 100, // Consider high-dimensional for dim > 100 96 | MetricsWindowSize: 1000, // Keep metrics for last 1000 queries 97 | AdaptationRate: 0.05, // Adapt thresholds by 5% each time 98 | } 99 | } 100 | 101 | // QueryMetrics holds performance metrics for a single query 102 | type QueryMetrics struct { 103 | 
// The strategy used for the query 104 | Strategy IndexType 105 | 106 | // The dimension of the query vector 107 | QueryDimension int 108 | 109 | // The number of results requested 110 | K int 111 | 112 | // How long the query took 113 | Duration time.Duration 114 | 115 | // How many results were returned 116 | ResultCount int 117 | 118 | // When the query was executed 119 | Timestamp time.Time 120 | } 121 | 122 | // StrategyStats holds statistics for a specific strategy 123 | type StrategyStats struct { 124 | // Number of times this strategy was selected 125 | UsageCount int 126 | 127 | // Sum of query durations for this strategy 128 | TotalDuration time.Duration 129 | 130 | // Average query duration for this strategy 131 | AvgDuration time.Duration 132 | } 133 | 134 | // HybridStats holds statistics for the hybrid index 135 | type HybridStats struct { 136 | // Number of vectors in the index 137 | VectorCount int 138 | 139 | // Average dimension of vectors in the index 140 | AvgDimension int 141 | 142 | // Statistics for each strategy 143 | StrategyStats map[IndexType]*StrategyStats 144 | } 145 | 146 | // HybridSearchRequest holds parameters for a hybrid search 147 | type HybridSearchRequest struct { 148 | // Query vector 149 | Query vectortypes.F32 150 | 151 | // Number of results to return 152 | K int 153 | 154 | // Force use of a specific strategy (empty means use adaptive selection) 155 | ForceStrategy IndexType 156 | 157 | // Whether to include detailed stats in the response 158 | IncludeStats bool 159 | 160 | // Negative example vector - results should be dissimilar from this 161 | NegativeExample vectortypes.F32 162 | 163 | // Weight to apply to the negative example (0.0-1.0) 164 | // Higher weight means stronger influence of the negative example 165 | NegativeWeight float32 166 | } 167 | 168 | // HybridSearchResponse contains the results of a hybrid search 169 | type HybridSearchResponse struct { 170 | // The search results 171 | Results 
[]types.BasicSearchResult 172 | 173 | // The strategy that was used 174 | StrategyUsed IndexType 175 | 176 | // How long the search took 177 | SearchTime time.Duration 178 | 179 | // Detailed stats (if requested) 180 | Stats *QueryMetrics 181 | } 182 | 183 | // VectorWithID represents a vector with its ID 184 | type VectorWithID struct { 185 | ID string 186 | Vector vectortypes.F32 187 | } 188 | 189 | // SearchResult represents a search result with ID and distance 190 | type SearchResult struct { 191 | ID string 192 | Distance float32 193 | } 194 | 195 | // Index is the interface that must be implemented by all indexes 196 | type Index interface { 197 | // Insert adds a vector to the index 198 | Insert(id string, vector vectortypes.F32) error 199 | 200 | // Delete removes a vector from the index 201 | Delete(id string) error 202 | 203 | // Search finds the k nearest vectors to the query vector 204 | Search(query vectortypes.F32, k int) ([]types.BasicSearchResult, error) 205 | 206 | // Size returns the number of vectors in the index 207 | Size() int 208 | 209 | // GetType returns the type of this index 210 | GetType() IndexType 211 | 212 | // GetStats returns statistics about this index 213 | GetStats() interface{} 214 | } 215 | 216 | // StrategySelector is the interface for components that select search strategies 217 | type StrategySelector interface { 218 | // SelectStrategy chooses the best strategy for a query 219 | SelectStrategy(query vectortypes.F32, k int) IndexType 220 | 221 | // RecordQueryMetrics records metrics for a completed query 222 | RecordQueryMetrics(metrics QueryMetrics) 223 | 224 | // GetStats returns statistics about strategy selection 225 | GetStats() map[string]interface{} 226 | } 227 | -------------------------------------------------------------------------------- /pkg/hybrid/hnsw_adapter_test.go: -------------------------------------------------------------------------------- 1 | package hybrid 2 | 3 | import ( 4 | "testing" 5 | 6 | 
"github.com/TFMV/quiver/pkg/vectortypes" 7 | ) 8 | 9 | func TestNewHNSWAdapter(t *testing.T) { 10 | distFunc := vectortypes.CosineDistance 11 | config := DefaultHNSWConfig() 12 | 13 | adapter := NewHNSWAdapter(distFunc, config) 14 | 15 | if adapter == nil { 16 | t.Fatal("NewHNSWAdapter returned nil") 17 | } 18 | 19 | if adapter.adapter == nil { 20 | t.Error("Underlying HNSW adapter not initialized") 21 | } 22 | 23 | if adapter.config.M != config.M { 24 | t.Errorf("Expected M %d, got %d", config.M, adapter.config.M) 25 | } 26 | 27 | if adapter.config.MaxM0 != config.MaxM0 { 28 | t.Errorf("Expected MaxM0 %d, got %d", config.MaxM0, adapter.config.MaxM0) 29 | } 30 | 31 | if adapter.config.EfConstruction != config.EfConstruction { 32 | t.Errorf("Expected EfConstruction %d, got %d", config.EfConstruction, adapter.config.EfConstruction) 33 | } 34 | 35 | if adapter.config.EfSearch != config.EfSearch { 36 | t.Errorf("Expected EfSearch %d, got %d", config.EfSearch, adapter.config.EfSearch) 37 | } 38 | } 39 | 40 | func TestHNSWAdapter_Insert(t *testing.T) { 41 | adapter := NewHNSWAdapter(vectortypes.CosineDistance, DefaultHNSWConfig()) 42 | 43 | // Test inserting a vector 44 | id := "test1" 45 | vector := vectortypes.F32{0.1, 0.2, 0.3, 0.4} 46 | 47 | err := adapter.Insert(id, vector) 48 | if err != nil { 49 | t.Errorf("Insert returned unexpected error: %v", err) 50 | } 51 | 52 | // Verify size increased 53 | if size := adapter.Size(); size != 1 { 54 | t.Errorf("Expected size 1 after insertion, got %d", size) 55 | } 56 | } 57 | 58 | func TestHNSWAdapter_Delete(t *testing.T) { 59 | adapter := NewHNSWAdapter(vectortypes.CosineDistance, DefaultHNSWConfig()) 60 | 61 | // Insert a vector first 62 | id := "test1" 63 | vector := vectortypes.F32{0.1, 0.2, 0.3, 0.4} 64 | 65 | err := adapter.Insert(id, vector) 66 | if err != nil { 67 | t.Fatalf("Insert failed: %v", err) 68 | } 69 | 70 | // Verify it exists 71 | if adapter.Size() != 1 { 72 | t.Fatalf("Expected size 1 after insertion, got 
%d", adapter.Size()) 73 | } 74 | 75 | // Try to delete non-existent ID (should fail gracefully) 76 | err = adapter.Delete("nonexistent") 77 | if err == nil { 78 | t.Error("Expected error when deleting non-existent vector, got nil") 79 | } 80 | 81 | // Now delete the real one 82 | err = adapter.Delete(id) 83 | if err != nil { 84 | t.Errorf("Delete returned unexpected error: %v", err) 85 | } 86 | 87 | // Verify it's gone 88 | if adapter.Size() != 0 { 89 | t.Errorf("Expected size 0 after deletion, got %d", adapter.Size()) 90 | } 91 | } 92 | 93 | func TestHNSWAdapter_Search(t *testing.T) { 94 | adapter := NewHNSWAdapter(vectortypes.CosineDistance, DefaultHNSWConfig()) 95 | 96 | // Test search on empty index 97 | query := vectortypes.F32{0.9, 0.1, 0.0} 98 | results, err := adapter.Search(query, 1) 99 | if err != nil { 100 | t.Fatalf("Search on empty index returned error: %v", err) 101 | } 102 | if len(results) != 0 { 103 | t.Errorf("Expected 0 results on empty index, got %d", len(results)) 104 | } 105 | 106 | // Insert two test vectors - just use these two distinct vectors 107 | id1 := "vec1" 108 | vec1 := vectortypes.F32{1.0, 0.0, 0.0} // Aligned with x-axis 109 | 110 | id2 := "vec2" 111 | vec2 := vectortypes.F32{0.0, 1.0, 0.0} // Aligned with y-axis 112 | 113 | // Insert vectors individually to avoid batch issues 114 | if err := adapter.Insert(id1, vec1); err != nil { 115 | t.Fatalf("Failed to insert vector %s: %v", id1, err) 116 | } 117 | 118 | if err := adapter.Insert(id2, vec2); err != nil { 119 | t.Fatalf("Failed to insert vector %s: %v", id2, err) 120 | } 121 | 122 | // Basic search test 123 | results, err = adapter.Search(query, 1) 124 | if err != nil { 125 | t.Fatalf("Search returned unexpected error: %v", err) 126 | } 127 | 128 | if len(results) != 1 { 129 | t.Errorf("Expected 1 result, got %d", len(results)) 130 | } else { 131 | // HNSW is approximate, so either vec1 or vec2 could be returned 132 | // But we validate that one of them is returned 133 | if 
results[0].ID != "vec1" && results[0].ID != "vec2" { 134 | t.Errorf("Expected result to be either vec1 or vec2, got %s", results[0].ID) 135 | } 136 | } 137 | } 138 | 139 | func TestHNSWAdapter_Size(t *testing.T) { 140 | adapter := NewHNSWAdapter(vectortypes.CosineDistance, DefaultHNSWConfig()) 141 | 142 | // Empty index 143 | if size := adapter.Size(); size != 0 { 144 | t.Errorf("Expected empty index size 0, got %d", size) 145 | } 146 | 147 | // Add vectors and check size - limiting to just 2 vectors to avoid index out of range errors 148 | vectors := []vectortypes.F32{ 149 | {0.1, 0.2, 0.3}, 150 | {0.4, 0.5, 0.6}, 151 | } 152 | 153 | for i, vec := range vectors { 154 | id := "test" + string(rune('0'+i)) 155 | err := adapter.Insert(id, vec) 156 | if err != nil { 157 | t.Fatalf("Insert failed for %s: %v", id, err) 158 | } 159 | 160 | expectedSize := i + 1 161 | if size := adapter.Size(); size != expectedSize { 162 | t.Errorf("Expected size %d after %d insertions, got %d", expectedSize, i+1, size) 163 | } 164 | } 165 | } 166 | 167 | func TestHNSWAdapter_GetType(t *testing.T) { 168 | adapter := NewHNSWAdapter(vectortypes.CosineDistance, DefaultHNSWConfig()) 169 | if adapter.GetType() != HNSWIndexType { 170 | t.Errorf("Expected index type %s, got %s", HNSWIndexType, adapter.GetType()) 171 | } 172 | } 173 | 174 | func TestHNSWAdapter_GetStats(t *testing.T) { 175 | adapter := NewHNSWAdapter(vectortypes.CosineDistance, DefaultHNSWConfig()) 176 | 177 | // Insert some vectors 178 | for i := 0; i < 3; i++ { 179 | id := "test" + string(rune('0'+i)) 180 | vector := vectortypes.F32{float32(i) * 0.1, float32(i) * 0.2, float32(i) * 0.3} 181 | if err := adapter.Insert(id, vector); err != nil { 182 | t.Fatalf("Insert failed: %v", err) 183 | } 184 | } 185 | 186 | // Get stats 187 | stats := adapter.GetStats().(map[string]interface{}) 188 | 189 | // Check type 190 | if typeStr, ok := stats["type"].(string); !ok || typeStr != string(HNSWIndexType) { 191 | t.Errorf("Expected type %s, 
got %v", HNSWIndexType, stats["type"]) 192 | } 193 | 194 | // Check vector count 195 | if count, ok := stats["vector_count"].(int); !ok || count != 3 { 196 | t.Errorf("Expected vector_count 3, got %v", stats["vector_count"]) 197 | } 198 | 199 | // Check that parameters and metrics exist 200 | if params, ok := stats["parameters"]; !ok || params == nil { 201 | t.Error("Expected parameters in stats output") 202 | } 203 | 204 | if metrics, ok := stats["metrics"]; !ok || metrics == nil { 205 | t.Error("Expected metrics in stats output") 206 | } 207 | } 208 | 209 | func TestHNSWAdapter_SetSearchEf(t *testing.T) { 210 | adapter := NewHNSWAdapter(vectortypes.CosineDistance, DefaultHNSWConfig()) 211 | 212 | // Set a new EfSearch value 213 | newEfSearch := 200 214 | if err := adapter.SetSearchEf(newEfSearch); err != nil { 215 | t.Fatalf("SetSearchEf failed: %v", err) 216 | } 217 | 218 | // While we can't directly verify the internal change since the field is within 219 | // the wrapped adapter, we can check that the method runs without error 220 | 221 | // Insert and search to ensure functionality still works after changing EfSearch 222 | vector := vectortypes.F32{0.1, 0.2, 0.3} 223 | err := adapter.Insert("test", vector) 224 | if err != nil { 225 | t.Fatalf("Insert failed after SetSearchEf: %v", err) 226 | } 227 | 228 | _, err = adapter.Search(vector, 1) 229 | if err != nil { 230 | t.Fatalf("Search failed after SetSearchEf: %v", err) 231 | } 232 | } 233 | 234 | func TestHNSWAdapter_DimensionMismatch(t *testing.T) { 235 | adapter := NewHNSWAdapter(vectortypes.CosineDistance, DefaultHNSWConfig()) 236 | if err := adapter.Insert("v1", vectortypes.F32{0.1, 0.2}); err != nil { 237 | t.Fatalf("insert failed: %v", err) 238 | } 239 | if err := adapter.Insert("v2", vectortypes.F32{0.1}); err == nil { 240 | t.Errorf("expected dimension mismatch error") 241 | } 242 | if _, err := adapter.Search(vectortypes.F32{0.1}, 1); err == nil { 243 | t.Errorf("expected query dimension mismatch") 
244 | } 245 | } 246 | -------------------------------------------------------------------------------- /pkg/api/server.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | "os" 9 | "os/signal" 10 | "syscall" 11 | "time" 12 | 13 | "github.com/TFMV/quiver/pkg/core" 14 | "github.com/gin-contrib/cors" 15 | "github.com/gin-gonic/gin" 16 | "github.com/prometheus/client_golang/prometheus/promhttp" 17 | ) 18 | 19 | // ServerConfig holds configuration options for the API server 20 | type ServerConfig struct { 21 | // Host is the server host (default: localhost) 22 | Host string 23 | // Port is the server port (default: 8080) 24 | Port int 25 | // AllowedOrigins is a list of CORS allowed origins 26 | AllowedOrigins []string 27 | // EnableMetrics enables Prometheus metrics endpoint 28 | EnableMetrics bool 29 | // MetricsPort is the port for the metrics server (default: 9090) 30 | MetricsPort int 31 | // ReadTimeout is the maximum duration for reading the entire request (default: 5s) 32 | ReadTimeout time.Duration 33 | // WriteTimeout is the maximum duration before timing out writes of the response (default: 10s) 34 | WriteTimeout time.Duration 35 | // ShutdownTimeout is the maximum duration to wait for server shutdown (default: 30s) 36 | ShutdownTimeout time.Duration 37 | // RateLimit is the number of requests per minute allowed per client (default: 60) 38 | RateLimit int 39 | // JWTSecret is used for API token authentication (if empty, JWT auth is disabled) 40 | JWTSecret string 41 | // LogLevel controls the verbosity of logging (debug, info, warn, error) 42 | LogLevel string 43 | } 44 | 45 | // DefaultServerConfig returns default configuration options 46 | func DefaultServerConfig() ServerConfig { 47 | return ServerConfig{ 48 | Host: "localhost", 49 | Port: 8080, 50 | AllowedOrigins: []string{"*"}, 51 | EnableMetrics: true, 52 | MetricsPort: 9090, 53 | 
ReadTimeout: 5 * time.Second, 54 | WriteTimeout: 10 * time.Second, 55 | ShutdownTimeout: 30 * time.Second, 56 | RateLimit: 60, 57 | LogLevel: "info", 58 | } 59 | } 60 | 61 | // Server represents the API server for Quiver 62 | type Server struct { 63 | config ServerConfig 64 | db *core.DB 65 | router *gin.Engine 66 | httpServer *http.Server 67 | metricsHTTP *http.Server 68 | handlers *Handlers 69 | } 70 | 71 | // NewServer creates a new API server 72 | func NewServer(db *core.DB, config ServerConfig) *Server { 73 | // Set sensible defaults for zero values 74 | if config.Host == "" { 75 | config.Host = "localhost" 76 | } 77 | if config.Port == 0 { 78 | config.Port = 8080 79 | } 80 | if config.MetricsPort == 0 { 81 | config.MetricsPort = 9090 82 | } 83 | if config.ReadTimeout == 0 { 84 | config.ReadTimeout = 5 * time.Second 85 | } 86 | if config.WriteTimeout == 0 { 87 | config.WriteTimeout = 10 * time.Second 88 | } 89 | if config.ShutdownTimeout == 0 { 90 | config.ShutdownTimeout = 30 * time.Second 91 | } 92 | if config.RateLimit == 0 { 93 | config.RateLimit = 60 94 | } 95 | if len(config.AllowedOrigins) == 0 { 96 | config.AllowedOrigins = []string{"*"} 97 | } 98 | if config.LogLevel == "" { 99 | config.LogLevel = "info" 100 | } 101 | 102 | // Set Gin mode based on log level 103 | if config.LogLevel == "debug" { 104 | gin.SetMode(gin.DebugMode) 105 | } else { 106 | gin.SetMode(gin.ReleaseMode) 107 | } 108 | 109 | // Create router 110 | router := gin.Default() 111 | 112 | // Setup CORS 113 | corsConfig := cors.DefaultConfig() 114 | corsConfig.AllowOrigins = config.AllowedOrigins 115 | corsConfig.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"} 116 | corsConfig.AllowHeaders = []string{"Origin", "Content-Type", "Authorization"} 117 | router.Use(cors.New(corsConfig)) 118 | 119 | // Create server 120 | server := &Server{ 121 | config: config, 122 | db: db, 123 | router: router, 124 | httpServer: &http.Server{ 125 | Addr: fmt.Sprintf("%s:%d", 
config.Host, config.Port), 126 | Handler: router, 127 | ReadTimeout: config.ReadTimeout, 128 | WriteTimeout: config.WriteTimeout, 129 | }, 130 | } 131 | 132 | // Create API handlers 133 | server.handlers = NewHandlers(db) 134 | 135 | // Set up metrics server if enabled 136 | if config.EnableMetrics { 137 | mux := http.NewServeMux() 138 | mux.Handle("/metrics", promhttp.Handler()) 139 | server.metricsHTTP = &http.Server{ 140 | Addr: fmt.Sprintf("%s:%d", config.Host, config.MetricsPort), 141 | Handler: mux, 142 | } 143 | } 144 | 145 | // Setup routes 146 | server.setupRoutes() 147 | 148 | return server 149 | } 150 | 151 | // setupRoutes configures all API routes 152 | func (s *Server) setupRoutes() { 153 | // API version group 154 | v1 := s.router.Group("/api/v1") 155 | 156 | // Health check 157 | v1.GET("/health", s.handlers.HealthCheck) 158 | 159 | // Database-wide endpoints 160 | v1.GET("/collections", s.handlers.ListCollections) 161 | v1.POST("/collections", s.handlers.CreateCollection) 162 | v1.GET("/metrics", s.handlers.GetMetrics) 163 | v1.POST("/backup", s.handlers.CreateBackup) 164 | v1.POST("/restore", s.handlers.RestoreBackup) 165 | 166 | // Collection-specific endpoints 167 | collection := v1.Group("/collections/:collection") 168 | { 169 | collection.GET("", s.handlers.GetCollection) 170 | collection.DELETE("", s.handlers.DeleteCollection) 171 | collection.GET("/stats", s.handlers.GetCollectionStats) 172 | 173 | // Vector operations 174 | collection.POST("/vectors", s.handlers.AddVector) 175 | collection.POST("/vectors/batch", s.handlers.AddVectorBatch) 176 | collection.GET("/vectors/:id", s.handlers.GetVector) 177 | collection.PUT("/vectors/:id", s.handlers.UpdateVector) 178 | collection.DELETE("/vectors/:id", s.handlers.DeleteVector) 179 | collection.POST("/vectors/delete/batch", s.handlers.DeleteVectorBatch) 180 | 181 | // Search 182 | collection.POST("/search", s.handlers.Search) 183 | } 184 | } 185 | 186 | // Start starts the HTTP server 187 | func 
(s *Server) Start() { 188 | // Start metrics server if enabled 189 | if s.config.EnableMetrics && s.metricsHTTP != nil { 190 | go func() { 191 | log.Printf("Starting metrics server on %s", s.metricsHTTP.Addr) 192 | if err := s.metricsHTTP.ListenAndServe(); err != nil && err != http.ErrServerClosed { 193 | log.Fatalf("Failed to start metrics server: %v", err) 194 | } 195 | }() 196 | } 197 | 198 | // Start main API server 199 | go func() { 200 | log.Printf("Starting API server on %s", s.httpServer.Addr) 201 | if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { 202 | log.Fatalf("Failed to start server: %v", err) 203 | } 204 | }() 205 | 206 | // Wait for interrupt signal to gracefully shut down the server 207 | quit := make(chan os.Signal, 1) 208 | signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) 209 | <-quit 210 | log.Println("Shutting down server...") 211 | 212 | // Create shutdown context with timeout 213 | ctx, cancel := context.WithTimeout(context.Background(), s.config.ShutdownTimeout) 214 | defer cancel() 215 | 216 | // Shutdown metrics server if it was started 217 | if s.config.EnableMetrics && s.metricsHTTP != nil { 218 | if err := s.metricsHTTP.Shutdown(ctx); err != nil { 219 | log.Printf("Metrics server forced to shutdown: %v", err) 220 | } 221 | } 222 | 223 | // Attempt to gracefully shutdown the server 224 | if err := s.httpServer.Shutdown(ctx); err != nil { 225 | log.Printf("Server forced to shutdown: %v", err) 226 | } 227 | 228 | log.Println("Server exiting") 229 | } 230 | 231 | // GetAddr returns the server's address as a string 232 | func (s *Server) GetAddr() string { 233 | return fmt.Sprintf("%s:%d", s.config.Host, s.config.Port) 234 | } 235 | 236 | // GetMetricsAddr returns the metrics server's address as a string 237 | func (s *Server) GetMetricsAddr() string { 238 | return fmt.Sprintf("%s:%d", s.config.Host, s.config.MetricsPort) 239 | } 240 | 241 | // SetJWTSecret sets the JWT secret for authentication 242 | 
func (s *Server) SetJWTSecret(secret string) { 243 | s.config.JWTSecret = secret 244 | // Re-setup routes with authentication middleware if a secret is provided 245 | if secret != "" { 246 | s.setupRoutes() 247 | } 248 | } 249 | 250 | // GetRouter returns the Gin router for testing 251 | func (s *Server) GetRouter() *gin.Engine { 252 | return s.router 253 | } 254 | 255 | // GetConfig returns the server configuration 256 | func (s *Server) GetConfig() ServerConfig { 257 | return s.config 258 | } 259 | -------------------------------------------------------------------------------- /pkg/hybrid/adaptive.go: -------------------------------------------------------------------------------- 1 | package hybrid 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | // AdaptiveStrategySelector implements adaptive selection of search strategies 11 | type AdaptiveStrategySelector struct { 12 | // Configuration for the adaptive selector 13 | config AdaptiveConfig 14 | 15 | // Current thresholds based on learning 16 | exactThreshold int 17 | dimThreshold int 18 | 19 | // Performance metrics for strategies 20 | metrics map[IndexType]*StrategyStats 21 | 22 | // Recent query metrics for analysis 23 | recentQueries []QueryMetrics 24 | 25 | // Random number generator for exploration 26 | rng *rand.Rand 27 | 28 | // Lock for thread safety 29 | mu sync.RWMutex 30 | } 31 | 32 | // NewAdaptiveStrategySelector creates a new adaptive strategy selector 33 | func NewAdaptiveStrategySelector(config AdaptiveConfig) *AdaptiveStrategySelector { 34 | return &AdaptiveStrategySelector{ 35 | config: config, 36 | exactThreshold: config.InitialExactThreshold, 37 | dimThreshold: config.InitialDimThreshold, 38 | metrics: make(map[IndexType]*StrategyStats), 39 | recentQueries: make([]QueryMetrics, 0, config.MetricsWindowSize), 40 | rng: rand.New(rand.NewSource(time.Now().UnixNano())), 41 | } 42 | } 43 | 44 | // SelectStrategy chooses the best strategy for a given query context 45 
| func (a *AdaptiveStrategySelector) SelectStrategy(vectorCount, dimension, k int) IndexType { 46 | a.mu.RLock() 47 | defer a.mu.RUnlock() 48 | 49 | // Exploration: sometimes try a random strategy to gather performance data 50 | if a.rng.Float64() < a.config.ExplorationFactor { 51 | if a.rng.Float64() < 0.5 { 52 | return ExactIndexType 53 | } 54 | return HNSWIndexType 55 | } 56 | 57 | // Exploitation: use learned thresholds to select strategy 58 | 59 | // For small datasets, exact search is often faster 60 | if vectorCount < a.exactThreshold { 61 | return ExactIndexType 62 | } 63 | 64 | // For high-dimensional data, consider the dimension threshold 65 | if dimension > a.dimThreshold { 66 | // For high-dimensional data with small k, HNSW usually performs better 67 | if k < 50 { 68 | return HNSWIndexType 69 | } 70 | // For high-dimensional data with large k, exact search might be better 71 | return ExactIndexType 72 | } 73 | 74 | // Default to HNSW for large datasets with low-to-medium dimensions 75 | return HNSWIndexType 76 | } 77 | 78 | // RecordQueryMetrics records performance metrics for a query 79 | func (a *AdaptiveStrategySelector) RecordQueryMetrics(metrics QueryMetrics) { 80 | a.mu.Lock() 81 | defer a.mu.Unlock() 82 | 83 | // Initialize strategy stats if needed 84 | if _, exists := a.metrics[metrics.Strategy]; !exists { 85 | a.metrics[metrics.Strategy] = &StrategyStats{ 86 | UsageCount: 0, 87 | TotalDuration: 0, 88 | AvgDuration: 0, 89 | } 90 | } 91 | 92 | // Update strategy stats 93 | stats := a.metrics[metrics.Strategy] 94 | stats.UsageCount++ 95 | stats.TotalDuration += metrics.Duration 96 | stats.AvgDuration = stats.TotalDuration / time.Duration(stats.UsageCount) 97 | 98 | // Add to recent queries 99 | a.recentQueries = append(a.recentQueries, metrics) 100 | if len(a.recentQueries) > a.config.MetricsWindowSize { 101 | // Remove oldest query when we exceed window size 102 | a.recentQueries = a.recentQueries[1:] 103 | } 104 | 105 | // Adapt thresholds 
periodically 106 | if stats.UsageCount%20 == 0 && len(a.recentQueries) >= 10 { 107 | a.adaptThresholds() 108 | } 109 | } 110 | 111 | // adaptThresholds adjusts the thresholds based on observed performance 112 | func (a *AdaptiveStrategySelector) adaptThresholds() { 113 | // Only adapt if we have stats for both strategies 114 | exactStats, hasExact := a.metrics[ExactIndexType] 115 | hnswStats, hasHNSW := a.metrics[HNSWIndexType] 116 | 117 | if !hasExact || !hasHNSW || exactStats.UsageCount < 10 || hnswStats.UsageCount < 10 { 118 | // Not enough data to adapt yet 119 | return 120 | } 121 | 122 | // Compare average performance of strategies 123 | // If exact is faster for larger datasets than our current threshold, 124 | // increase the threshold 125 | _ = exactStats.AvgDuration < hnswStats.AvgDuration 126 | 127 | // Analyze recent queries to find patterns 128 | var ( 129 | smallDatasetExactAvg time.Duration 130 | smallDatasetHNSWAvg time.Duration 131 | smallDatasetExactCount int 132 | smallDatasetHNSWCount int 133 | 134 | largeDatasetExactAvg time.Duration 135 | largeDatasetHNSWAvg time.Duration 136 | largeDatasetExactCount int 137 | largeDatasetHNSWCount int 138 | ) 139 | 140 | // Analyze recent queries by dataset size 141 | for _, q := range a.recentQueries { 142 | isSmall := q.ResultCount < a.exactThreshold 143 | 144 | if q.Strategy == ExactIndexType { 145 | if isSmall { 146 | smallDatasetExactAvg += q.Duration 147 | smallDatasetExactCount++ 148 | } else { 149 | largeDatasetExactAvg += q.Duration 150 | largeDatasetExactCount++ 151 | } 152 | } else if q.Strategy == HNSWIndexType { 153 | if isSmall { 154 | smallDatasetHNSWAvg += q.Duration 155 | smallDatasetHNSWCount++ 156 | } else { 157 | largeDatasetHNSWAvg += q.Duration 158 | largeDatasetHNSWCount++ 159 | } 160 | } 161 | } 162 | 163 | // Calculate averages 164 | if smallDatasetExactCount > 0 { 165 | smallDatasetExactAvg /= time.Duration(smallDatasetExactCount) 166 | } 167 | if smallDatasetHNSWCount > 0 { 168 | 
smallDatasetHNSWAvg /= time.Duration(smallDatasetHNSWCount) 169 | } 170 | // Only calculate these if they will be used later 171 | // For now these are calculated but not used, so we'll comment them out 172 | /* 173 | if largeDatasetExactCount > 0 { 174 | largeDatasetExactAvg /= time.Duration(largeDatasetExactCount) 175 | } 176 | if largeDatasetHNSWCount > 0 { 177 | largeDatasetHNSWAvg /= time.Duration(largeDatasetHNSWCount) 178 | } 179 | */ 180 | 181 | // Adapt exact threshold based on performance for small vs large datasets 182 | if smallDatasetExactCount > 5 && smallDatasetHNSWCount > 5 { 183 | if smallDatasetExactAvg < smallDatasetHNSWAvg { 184 | // Exact is faster for small datasets, increase threshold 185 | delta := int(float64(a.exactThreshold) * a.config.AdaptationRate) 186 | if delta < 10 { 187 | delta = 10 188 | } 189 | a.exactThreshold += delta 190 | } else { 191 | // HNSW is faster for small datasets, decrease threshold 192 | delta := int(float64(a.exactThreshold) * a.config.AdaptationRate) 193 | if delta < 10 { 194 | delta = 10 195 | } 196 | a.exactThreshold -= delta 197 | 198 | // Ensure threshold doesn't go below a reasonable minimum 199 | if a.exactThreshold < 100 { 200 | a.exactThreshold = 100 201 | } 202 | } 203 | } 204 | 205 | // Similarly, adapt dimension threshold based on performance for different dimensions 206 | // This would require additional analysis of performance by dimension 207 | // For simplicity, we'll keep this part as a future enhancement 208 | } 209 | 210 | // GetStats returns statistics about the adaptive selector 211 | func (a *AdaptiveStrategySelector) GetStats() map[string]interface{} { 212 | a.mu.RLock() 213 | defer a.mu.RUnlock() 214 | 215 | strategyStats := make(map[string]interface{}) 216 | for k, v := range a.metrics { 217 | strategyStats[string(k)] = v 218 | } 219 | 220 | return map[string]interface{}{ 221 | "thresholds": map[string]interface{}{ 222 | "exact": a.exactThreshold, 223 | "dimension": a.dimThreshold, 224 | 
}, 225 | "strategies": strategyStats, 226 | "config": a.config, 227 | "recent_queries_count": len(a.recentQueries), 228 | } 229 | } 230 | 231 | // String provides a string representation of the adaptive selector's state 232 | func (a *AdaptiveStrategySelector) String() string { 233 | a.mu.RLock() 234 | defer a.mu.RUnlock() 235 | 236 | exactStats, hasExact := a.metrics[ExactIndexType] 237 | hnswStats, hasHNSW := a.metrics[HNSWIndexType] 238 | 239 | exactAvg := time.Duration(0) 240 | hnswAvg := time.Duration(0) 241 | 242 | if hasExact && exactStats.UsageCount > 0 { 243 | exactAvg = exactStats.AvgDuration 244 | } 245 | 246 | if hasHNSW && hnswStats.UsageCount > 0 { 247 | hnswAvg = hnswStats.AvgDuration 248 | } 249 | 250 | return fmt.Sprintf( 251 | "AdaptiveStrategySelector{exactThreshold=%d, dimThreshold=%d, exactAvg=%v, hnswAvg=%v}", 252 | a.exactThreshold, 253 | a.dimThreshold, 254 | exactAvg, 255 | hnswAvg, 256 | ) 257 | } 258 | 259 | // UpdateThresholds updates the internal thresholds based on index statistics. 
260 | func (a *AdaptiveStrategySelector) UpdateThresholds(exact, dim int) { 261 | a.mu.Lock() 262 | a.exactThreshold = exact 263 | a.dimThreshold = dim 264 | a.mu.Unlock() 265 | } 266 | -------------------------------------------------------------------------------- /pkg/persistence/parquet_test.go: -------------------------------------------------------------------------------- 1 | package persistence 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | // TestParquetVectorRecordConversion tests converting between VectorRecord and ParquetVectorRecord 11 | func TestParquetVectorRecordConversion(t *testing.T) { 12 | // Create a temporary directory for testing 13 | tempDir, err := os.MkdirTemp("", "quiver-parquet-test-*") 14 | if err != nil { 15 | t.Fatalf("Failed to create temp directory: %v", err) 16 | } 17 | defer os.RemoveAll(tempDir) 18 | 19 | // Create test vector records 20 | originalRecords := []VectorRecord{ 21 | { 22 | ID: "vec1", 23 | Vector: []float32{0.1, 0.2, 0.3, 0.4}, 24 | Metadata: map[string]string{"key1": "value1", "key2": "value2"}, 25 | }, 26 | { 27 | ID: "vec2", 28 | Vector: []float32{0.5, 0.6, 0.7, 0.8}, 29 | Metadata: map[string]string{"key3": "value3"}, 30 | }, 31 | { 32 | ID: "vec3", 33 | Vector: []float32{0.9, 1.0, 1.1, 1.2}, 34 | Metadata: nil, 35 | }, 36 | } 37 | 38 | // Path for the parquet file 39 | parquetPath := filepath.Join(tempDir, "test_vectors.parquet") 40 | 41 | // Write records to parquet file 42 | err = WriteVectorsToParquetFile(originalRecords, parquetPath) 43 | if err != nil { 44 | t.Fatalf("Failed to write vectors to parquet: %v", err) 45 | } 46 | 47 | // Read records from parquet file 48 | loadedRecords, err := ReadVectorsFromParquetFile(parquetPath) 49 | if err != nil { 50 | t.Fatalf("Failed to read vectors from parquet: %v", err) 51 | } 52 | 53 | // Verify number of records matches 54 | if len(loadedRecords) != len(originalRecords) { 55 | t.Fatalf("Expected %d records, got %d", 
len(originalRecords), len(loadedRecords)) 56 | } 57 | 58 | // Create maps for easier comparison 59 | originalMap := make(map[string]VectorRecord) 60 | for _, record := range originalRecords { 61 | originalMap[record.ID] = record 62 | } 63 | 64 | loadedMap := make(map[string]VectorRecord) 65 | for _, record := range loadedRecords { 66 | loadedMap[record.ID] = record 67 | } 68 | 69 | // Compare original and loaded records 70 | for id, originalRecord := range originalMap { 71 | loadedRecord, exists := loadedMap[id] 72 | if !exists { 73 | t.Errorf("Record with ID %s not found in loaded records", id) 74 | continue 75 | } 76 | 77 | // Compare vector values 78 | if !reflect.DeepEqual(loadedRecord.Vector, originalRecord.Vector) { 79 | t.Errorf("Vectors don't match for ID %s: expected %v, got %v", 80 | id, originalRecord.Vector, loadedRecord.Vector) 81 | } 82 | 83 | // Compare metadata 84 | if originalRecord.Metadata == nil { 85 | // If original metadata was nil, loaded should be empty map 86 | if len(loadedRecord.Metadata) > 0 { 87 | t.Errorf("Expected empty metadata for ID %s, got %v", id, loadedRecord.Metadata) 88 | } 89 | } else { 90 | for k, v := range originalRecord.Metadata { 91 | if loadedRecord.Metadata[k] != v { 92 | t.Errorf("Metadata mismatch for ID %s, key %s: expected %s, got %s", 93 | id, k, v, loadedRecord.Metadata[k]) 94 | } 95 | } 96 | 97 | for k := range loadedRecord.Metadata { 98 | if _, exists := originalRecord.Metadata[k]; !exists { 99 | t.Errorf("Unexpected metadata key %s for ID %s", k, id) 100 | } 101 | } 102 | } 103 | } 104 | } 105 | 106 | // TestParquetLargeVectorCount tests saving and loading a large number of vectors 107 | func TestParquetLargeVectorCount(t *testing.T) { 108 | // Create a temporary directory for testing 109 | tempDir, err := os.MkdirTemp("", "quiver-parquet-test-*") 110 | if err != nil { 111 | t.Fatalf("Failed to create temp directory: %v", err) 112 | } 113 | defer os.RemoveAll(tempDir) 114 | 115 | // Number of vectors to test 
116 | numVectors := 2500 // Should test batch loading since batch size is 1000 117 | 118 | // Create test vector records 119 | originalRecords := make([]VectorRecord, numVectors) 120 | for i := 0; i < numVectors; i++ { 121 | originalRecords[i] = VectorRecord{ 122 | ID: generateTestID(i), 123 | Vector: generateTestVector(i, 4), 124 | Metadata: map[string]string{ 125 | "index": generateTestID(i), 126 | }, 127 | } 128 | } 129 | 130 | // Path for the parquet file 131 | parquetPath := filepath.Join(tempDir, "large_vectors.parquet") 132 | 133 | // Write records to parquet file 134 | err = WriteVectorsToParquetFile(originalRecords, parquetPath) 135 | if err != nil { 136 | t.Fatalf("Failed to write vectors to parquet: %v", err) 137 | } 138 | 139 | // Read records from parquet file 140 | loadedRecords, err := ReadVectorsFromParquetFile(parquetPath) 141 | if err != nil { 142 | t.Fatalf("Failed to read vectors from parquet: %v", err) 143 | } 144 | 145 | // Verify number of records matches 146 | if len(loadedRecords) != numVectors { 147 | t.Fatalf("Expected %d records, got %d", numVectors, len(loadedRecords)) 148 | } 149 | 150 | // Check a few random records for correctness 151 | indicesToCheck := []int{0, 1, numVectors / 2, numVectors - 2, numVectors - 1} 152 | for _, idx := range indicesToCheck { 153 | if idx >= len(originalRecords) || idx >= len(loadedRecords) { 154 | continue // Skip if index is out of bounds 155 | } 156 | 157 | original := originalRecords[idx] 158 | loaded := loadedRecords[idx] 159 | 160 | if original.ID != loaded.ID { 161 | t.Errorf("ID mismatch at index %d: expected %s, got %s", idx, original.ID, loaded.ID) 162 | } 163 | 164 | if !reflect.DeepEqual(original.Vector, loaded.Vector) { 165 | t.Errorf("Vector mismatch at index %d", idx) 166 | } 167 | 168 | if original.Metadata["index"] != loaded.Metadata["index"] { 169 | t.Errorf("Metadata mismatch at index %d: expected %s, got %s", 170 | idx, original.Metadata["index"], loaded.Metadata["index"]) 171 | } 172 
| } 173 | } 174 | 175 | // TestParquetFileCreation tests that the parquet file is created correctly 176 | func TestParquetFileCreation(t *testing.T) { 177 | // Create a temporary directory for testing 178 | tempDir, err := os.MkdirTemp("", "quiver-parquet-test-*") 179 | if err != nil { 180 | t.Fatalf("Failed to create temp directory: %v", err) 181 | } 182 | defer os.RemoveAll(tempDir) 183 | 184 | // Create nested directory path to test directory creation 185 | nestedDir := filepath.Join(tempDir, "nested", "dir") 186 | parquetPath := filepath.Join(nestedDir, "test_vectors.parquet") 187 | 188 | // Create test vector records 189 | records := []VectorRecord{ 190 | { 191 | ID: "test", 192 | Vector: []float32{0.1, 0.2, 0.3}, 193 | Metadata: map[string]string{"key": "value"}, 194 | }, 195 | } 196 | 197 | // Write records to parquet file 198 | err = WriteVectorsToParquetFile(records, parquetPath) 199 | if err != nil { 200 | t.Fatalf("Failed to write vectors to parquet: %v", err) 201 | } 202 | 203 | // Verify file exists 204 | if _, err := os.Stat(parquetPath); os.IsNotExist(err) { 205 | t.Errorf("Parquet file was not created at %s", parquetPath) 206 | } 207 | 208 | // Verify file contains data 209 | fileInfo, err := os.Stat(parquetPath) 210 | if err != nil { 211 | t.Fatalf("Failed to get file info: %v", err) 212 | } 213 | 214 | if fileInfo.Size() == 0 { 215 | t.Error("Parquet file is empty") 216 | } 217 | } 218 | 219 | // TestParquetEmptyRecords tests saving and loading empty records 220 | func TestParquetEmptyRecords(t *testing.T) { 221 | // Create a temporary directory for testing 222 | tempDir, err := os.MkdirTemp("", "quiver-parquet-test-*") 223 | if err != nil { 224 | t.Fatalf("Failed to create temp directory: %v", err) 225 | } 226 | defer os.RemoveAll(tempDir) 227 | 228 | // Path for the parquet file 229 | parquetPath := filepath.Join(tempDir, "empty_vectors.parquet") 230 | 231 | // Create empty records slice 232 | var emptyRecords []VectorRecord 233 | 234 | // Write 
empty records to parquet file 235 | err = WriteVectorsToParquetFile(emptyRecords, parquetPath) 236 | if err != nil { 237 | t.Fatalf("Failed to write empty vectors to parquet: %v", err) 238 | } 239 | 240 | // Read records from parquet file 241 | loadedRecords, err := ReadVectorsFromParquetFile(parquetPath) 242 | if err != nil { 243 | t.Fatalf("Failed to read vectors from parquet: %v", err) 244 | } 245 | 246 | // Verify number of records matches 247 | if len(loadedRecords) != 0 { 248 | t.Fatalf("Expected 0 records, got %d", len(loadedRecords)) 249 | } 250 | } 251 | 252 | // Helper functions for testing 253 | 254 | // generateTestID generates a test ID for a given index 255 | func generateTestID(index int) string { 256 | return "vec_" + string(rune('a'+index%26)) + "_" + string(rune('0'+index%10)) 257 | } 258 | 259 | // generateTestVector generates a test vector for a given index and dimension 260 | func generateTestVector(index, dim int) []float32 { 261 | vector := make([]float32, dim) 262 | for i := 0; i < dim; i++ { 263 | vector[i] = float32(index*10+i) / 100.0 264 | } 265 | return vector 266 | } 267 | -------------------------------------------------------------------------------- /cmd/quiver/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "time" 8 | 9 | "github.com/TFMV/quiver/pkg/api" 10 | "github.com/TFMV/quiver/pkg/core" 11 | "github.com/spf13/cobra" 12 | "github.com/spf13/viper" 13 | ) 14 | 15 | var ( 16 | cfgFile string 17 | dataDir string 18 | logLevel string 19 | version = "0.1.0" // Will be set during build 20 | startTime = time.Now() // Track app start time 21 | ) 22 | 23 | func main() { 24 | cobra.OnInitialize(initConfig) 25 | 26 | rootCmd := &cobra.Command{ 27 | Use: "quiver", 28 | Short: "Quiver - High-performance vector database", 29 | Long: `Quiver is a high-performance vector database optimized for 30 | machine learning 
applications and similarity search.`, 31 | Version: version, 32 | } 33 | 34 | // Global flags 35 | rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.quiver.yaml)") 36 | rootCmd.PersistentFlags().StringVar(&dataDir, "data-dir", "./data", "data directory for storing vectors") 37 | rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "log level (debug, info, warn, error)") 38 | 39 | // Add commands 40 | rootCmd.AddCommand(serveCmd) 41 | rootCmd.AddCommand(backupCmd) 42 | rootCmd.AddCommand(restoreCmd) 43 | rootCmd.AddCommand(infoCmd) 44 | 45 | // Execute root command 46 | if err := rootCmd.Execute(); err != nil { 47 | fmt.Println(err) 48 | os.Exit(1) 49 | } 50 | } 51 | 52 | // initConfig reads in config file and ENV variables if set 53 | func initConfig() { 54 | if cfgFile != "" { 55 | // Use config file from the flag 56 | viper.SetConfigFile(cfgFile) 57 | } else { 58 | // Find home directory 59 | home, err := os.UserHomeDir() 60 | if err != nil { 61 | fmt.Println(err) 62 | os.Exit(1) 63 | } 64 | 65 | // Search config in home directory with name ".quiver" 66 | viper.AddConfigPath(home) 67 | viper.AddConfigPath(".") 68 | viper.SetConfigName(".quiver") 69 | } 70 | 71 | // Read environment variables 72 | viper.AutomaticEnv() 73 | viper.SetEnvPrefix("QUIVER") 74 | 75 | // Read in config 76 | if err := viper.ReadInConfig(); err == nil { 77 | fmt.Println("Using config file:", viper.ConfigFileUsed()) 78 | } 79 | 80 | // Apply config to variables if set 81 | if viper.IsSet("data_dir") { 82 | dataDir = viper.GetString("data_dir") 83 | } 84 | 85 | if viper.IsSet("log_level") { 86 | logLevel = viper.GetString("log_level") 87 | } 88 | } 89 | 90 | // serveCmd represents the serve command 91 | var serveCmd = &cobra.Command{ 92 | Use: "serve", 93 | Short: "Start the Quiver server", 94 | Long: `Start the Quiver vector database server with the specified configuration.`, 95 | RunE: func(cmd *cobra.Command, args []string) error { 
96 | host, _ := cmd.Flags().GetString("host") 97 | port, _ := cmd.Flags().GetInt("port") 98 | enableAuth, _ := cmd.Flags().GetBool("auth") 99 | jwtSecret, _ := cmd.Flags().GetString("jwt-secret") 100 | 101 | // Create data directory if needed 102 | if err := os.MkdirAll(dataDir, 0755); err != nil { 103 | return fmt.Errorf("failed to create data directory: %w", err) 104 | } 105 | 106 | // Initialize database 107 | fmt.Println("Initializing Quiver database...") 108 | 109 | dbOptions := core.DefaultDBOptions() 110 | dbOptions.StoragePath = dataDir 111 | dbOptions.EnablePersistence = true 112 | 113 | db, err := core.NewDB(dbOptions) 114 | if err != nil { 115 | return fmt.Errorf("failed to initialize database: %w", err) 116 | } 117 | 118 | // Initialize server 119 | fmt.Printf("Starting Quiver server on %s:%d...\n", host, port) 120 | 121 | // Create server config 122 | serverConfig := api.DefaultServerConfig() 123 | serverConfig.Host = host 124 | serverConfig.Port = port 125 | serverConfig.LogLevel = logLevel 126 | 127 | // Configure auth if enabled 128 | if enableAuth { 129 | if jwtSecret == "" { 130 | return fmt.Errorf("jwt-secret is required when auth is enabled") 131 | } 132 | serverConfig.JWTSecret = jwtSecret 133 | } 134 | 135 | // Start server 136 | server := api.NewServer(db, serverConfig) 137 | fmt.Println("Server is running. 
Press Ctrl+C to stop.") 138 | 139 | // This will block until server is shut down 140 | server.Start() 141 | return nil 142 | }, 143 | } 144 | 145 | // backupCmd represents the backup command 146 | var backupCmd = &cobra.Command{ 147 | Use: "backup PATH", 148 | Short: "Backup the database", 149 | Long: `Create a backup of the Quiver database to the specified path.`, 150 | Args: cobra.ExactArgs(1), 151 | RunE: func(cmd *cobra.Command, args []string) error { 152 | backupPath := args[0] 153 | 154 | // Convert to absolute path if needed 155 | if !filepath.IsAbs(backupPath) { 156 | absPath, err := filepath.Abs(backupPath) 157 | if err != nil { 158 | return fmt.Errorf("failed to resolve backup path: %w", err) 159 | } 160 | backupPath = absPath 161 | } 162 | 163 | // Initialize database 164 | fmt.Println("Initializing Quiver database...") 165 | 166 | dbOptions := core.DefaultDBOptions() 167 | dbOptions.StoragePath = dataDir 168 | dbOptions.EnablePersistence = true 169 | 170 | db, err := core.NewDB(dbOptions) 171 | if err != nil { 172 | return fmt.Errorf("failed to initialize database: %w", err) 173 | } 174 | 175 | // Create backup 176 | fmt.Printf("Creating backup at %s...\n", backupPath) 177 | if err := db.BackupDatabase(backupPath); err != nil { 178 | return fmt.Errorf("backup failed: %w", err) 179 | } 180 | 181 | fmt.Println("Backup completed successfully") 182 | return nil 183 | }, 184 | } 185 | 186 | // restoreCmd represents the restore command 187 | var restoreCmd = &cobra.Command{ 188 | Use: "restore PATH", 189 | Short: "Restore the database from backup", 190 | Long: `Restore the Quiver database from a backup at the specified path.`, 191 | Args: cobra.ExactArgs(1), 192 | RunE: func(cmd *cobra.Command, args []string) error { 193 | backupPath := args[0] 194 | 195 | // Convert to absolute path if needed 196 | if !filepath.IsAbs(backupPath) { 197 | absPath, err := filepath.Abs(backupPath) 198 | if err != nil { 199 | return fmt.Errorf("failed to resolve backup path: %w", 
err) 200 | } 201 | backupPath = absPath 202 | } 203 | 204 | // Initialize database 205 | fmt.Println("Initializing Quiver database...") 206 | 207 | dbOptions := core.DefaultDBOptions() 208 | dbOptions.StoragePath = dataDir 209 | dbOptions.EnablePersistence = true 210 | 211 | db, err := core.NewDB(dbOptions) 212 | if err != nil { 213 | return fmt.Errorf("failed to initialize database: %w", err) 214 | } 215 | 216 | // Restore from backup 217 | fmt.Printf("Restoring from backup at %s...\n", backupPath) 218 | if err := db.RestoreDatabase(backupPath); err != nil { 219 | return fmt.Errorf("restore failed: %w", err) 220 | } 221 | 222 | fmt.Println("Database restored successfully") 223 | return nil 224 | }, 225 | } 226 | 227 | // infoCmd represents the info command 228 | var infoCmd = &cobra.Command{ 229 | Use: "info", 230 | Short: "Show database information", 231 | Long: `Display information about the Quiver database.`, 232 | RunE: func(cmd *cobra.Command, args []string) error { 233 | // Initialize database (read-only mode) 234 | fmt.Println("Initializing Quiver database...") 235 | 236 | dbOptions := core.DefaultDBOptions() 237 | dbOptions.StoragePath = dataDir 238 | dbOptions.EnablePersistence = true 239 | 240 | db, err := core.NewDB(dbOptions) 241 | if err != nil { 242 | return fmt.Errorf("failed to initialize database: %w", err) 243 | } 244 | 245 | // Calculate total vector count 246 | collections := db.ListCollections() 247 | totalVectors := 0 248 | for _, colName := range collections { 249 | col, err := db.GetCollection(colName) 250 | if err == nil && col != nil { 251 | totalVectors += col.Count() 252 | } 253 | } 254 | 255 | // Print database info 256 | fmt.Println("\nQuiver Database Information") 257 | fmt.Println("---------------------------") 258 | fmt.Printf("Version: %s\n", version) 259 | fmt.Printf("Data Directory: %s\n", dataDir) 260 | fmt.Printf("Collections: %d\n", len(collections)) 261 | fmt.Printf("Total Vectors: %d\n", totalVectors) 262 | 
fmt.Printf("Uptime: %s\n", time.Since(startTime).Round(time.Second)) 263 | fmt.Println("---------------------------") 264 | 265 | // Print collection details if any exist 266 | if len(collections) > 0 { 267 | fmt.Println("\nCollections:") 268 | for _, colName := range collections { 269 | col, err := db.GetCollection(colName) 270 | if err == nil && col != nil { 271 | stats := col.Stats() 272 | fmt.Printf("- %s: %d vectors, %d dimensions\n", 273 | colName, stats.VectorCount, stats.Dimension) 274 | } 275 | } 276 | } 277 | 278 | return nil 279 | }, 280 | } 281 | 282 | func init() { 283 | // Add server-specific flags 284 | serveCmd.Flags().String("host", "localhost", "server host") 285 | serveCmd.Flags().Int("port", 8080, "server port") 286 | serveCmd.Flags().Bool("auth", false, "enable JWT authentication") 287 | serveCmd.Flags().String("jwt-secret", "", "JWT secret key for authentication") 288 | serveCmd.Flags().Bool("cors", true, "enable CORS") 289 | 290 | // Bind flags to viper 291 | if err := viper.BindPFlag("host", serveCmd.Flags().Lookup("host")); err != nil { 292 | fmt.Printf("Error binding host flag: %v\n", err) 293 | } 294 | if err := viper.BindPFlag("port", serveCmd.Flags().Lookup("port")); err != nil { 295 | fmt.Printf("Error binding port flag: %v\n", err) 296 | } 297 | if err := viper.BindPFlag("auth", serveCmd.Flags().Lookup("auth")); err != nil { 298 | fmt.Printf("Error binding auth flag: %v\n", err) 299 | } 300 | if err := viper.BindPFlag("jwt_secret", serveCmd.Flags().Lookup("jwt-secret")); err != nil { 301 | fmt.Printf("Error binding jwt_secret flag: %v\n", err) 302 | } 303 | if err := viper.BindPFlag("cors", serveCmd.Flags().Lookup("cors")); err != nil { 304 | fmt.Printf("Error binding cors flag: %v\n", err) 305 | } 306 | } 307 | -------------------------------------------------------------------------------- /pkg/persistence/collection.go: -------------------------------------------------------------------------------- 1 | package persistence 2 | 3 | 
import ( 4 | "fmt" 5 | "sort" 6 | "sync" 7 | "time" 8 | 9 | "github.com/TFMV/quiver/pkg/facets" 10 | "github.com/TFMV/quiver/pkg/vectortypes" 11 | ) 12 | 13 | // Collection implements a basic vector collection that can be persisted. 14 | // This serves as a reference implementation of the Persistable interface. 15 | type Collection struct { 16 | // Basic information 17 | name string 18 | dimension int 19 | 20 | // Vectors and metadata 21 | vectors map[string][]float32 22 | metadata map[string]map[string]string 23 | 24 | // Distance function 25 | distanceFunc vectortypes.DistanceFunc 26 | 27 | // Mutex for thread safety 28 | mu sync.RWMutex 29 | 30 | // Whether the collection has been modified since last save 31 | dirty bool 32 | 33 | // When the collection was created 34 | createdAt time.Time 35 | 36 | // Facet fields and values 37 | facetFields []string 38 | vectorFacets map[string][]facets.FacetValue 39 | } 40 | 41 | // NewCollection creates a new persistable collection 42 | func NewCollection(name string, dimension int, distanceFunc vectortypes.DistanceFunc) *Collection { 43 | return &Collection{ 44 | name: name, 45 | dimension: dimension, 46 | vectors: make(map[string][]float32), 47 | metadata: make(map[string]map[string]string), 48 | distanceFunc: distanceFunc, 49 | mu: sync.RWMutex{}, 50 | dirty: false, 51 | createdAt: time.Now(), 52 | facetFields: []string{}, 53 | vectorFacets: make(map[string][]facets.FacetValue), 54 | } 55 | } 56 | 57 | // GetName implements the Persistable interface 58 | func (c *Collection) GetName() string { 59 | return c.name 60 | } 61 | 62 | // GetDimension implements the Persistable interface 63 | func (c *Collection) GetDimension() int { 64 | return c.dimension 65 | } 66 | 67 | // GetVectors implements the Persistable interface 68 | func (c *Collection) GetVectors() []VectorRecord { 69 | c.mu.RLock() 70 | defer c.mu.RUnlock() 71 | 72 | records := make([]VectorRecord, 0, len(c.vectors)) 73 | for id, vector := range c.vectors { 74 | 
vecCopy := make([]float32, len(vector)) 75 | copy(vecCopy, vector) 76 | record := VectorRecord{ 77 | ID: id, 78 | Vector: vecCopy, 79 | } 80 | 81 | if meta, ok := c.metadata[id]; ok { 82 | metaCopy := make(map[string]string, len(meta)) 83 | for k, v := range meta { 84 | metaCopy[k] = v 85 | } 86 | record.Metadata = metaCopy 87 | } 88 | 89 | records = append(records, record) 90 | } 91 | 92 | return records 93 | } 94 | 95 | // AddVector implements the Persistable interface 96 | func (c *Collection) AddVector(id string, vector []float32, metadata map[string]string) error { 97 | c.mu.Lock() 98 | defer c.mu.Unlock() 99 | 100 | // Validate vector dimension 101 | if len(vector) != c.dimension { 102 | return fmt.Errorf("vector dimension mismatch: got %d, expected %d", len(vector), c.dimension) 103 | } 104 | 105 | // Store a copy of the vector to prevent external modifications 106 | vecCopy := make([]float32, len(vector)) 107 | copy(vecCopy, vector) 108 | c.vectors[id] = vecCopy 109 | 110 | // Store metadata if provided 111 | if metadata != nil { 112 | metaCopy := make(map[string]string, len(metadata)) 113 | for k, v := range metadata { 114 | metaCopy[k] = v 115 | } 116 | c.metadata[id] = metaCopy 117 | 118 | // Extract and store facets if facet fields are defined 119 | if len(c.facetFields) > 0 { 120 | // Convert string map to interface map for facet extraction 121 | metadataMap := make(map[string]interface{}, len(metadata)) 122 | for k, v := range metadata { 123 | metadataMap[k] = v 124 | } 125 | c.vectorFacets[id] = facets.ExtractFacets(metadataMap, c.facetFields) 126 | } 127 | } 128 | 129 | // Mark as dirty 130 | c.dirty = true 131 | 132 | return nil 133 | } 134 | 135 | // DeleteVector removes a vector from the collection 136 | func (c *Collection) DeleteVector(id string) error { 137 | c.mu.Lock() 138 | defer c.mu.Unlock() 139 | 140 | // Check if vector exists 141 | if _, exists := c.vectors[id]; !exists { 142 | return fmt.Errorf("vector with ID %s not found", id) 143 | 
} 144 | 145 | // Delete vector 146 | delete(c.vectors, id) 147 | 148 | // Delete metadata if exists 149 | delete(c.metadata, id) 150 | 151 | // Delete facets if exists 152 | delete(c.vectorFacets, id) 153 | 154 | // Mark as dirty 155 | c.dirty = true 156 | 157 | return nil 158 | } 159 | 160 | // GetVector retrieves a vector by ID 161 | func (c *Collection) GetVector(id string) ([]float32, map[string]string, error) { 162 | c.mu.RLock() 163 | defer c.mu.RUnlock() 164 | 165 | // Check if vector exists 166 | vector, exists := c.vectors[id] 167 | if !exists { 168 | return nil, nil, fmt.Errorf("vector with ID %s not found", id) 169 | } 170 | 171 | vecCopy := make([]float32, len(vector)) 172 | copy(vecCopy, vector) 173 | 174 | meta, ok := c.metadata[id] 175 | var metaCopy map[string]string 176 | if ok { 177 | metaCopy = make(map[string]string, len(meta)) 178 | for k, v := range meta { 179 | metaCopy[k] = v 180 | } 181 | } 182 | 183 | return vecCopy, metaCopy, nil 184 | } 185 | 186 | // IsDirty returns whether the collection has been modified since last save 187 | func (c *Collection) IsDirty() bool { 188 | c.mu.RLock() 189 | defer c.mu.RUnlock() 190 | return c.dirty 191 | } 192 | 193 | // MarkClean marks the collection as clean (not dirty) 194 | func (c *Collection) MarkClean() { 195 | c.mu.Lock() 196 | defer c.mu.Unlock() 197 | c.dirty = false 198 | } 199 | 200 | // Search finds the most similar vectors to a query vector 201 | func (c *Collection) Search(query []float32, limit int) ([]SearchResult, error) { 202 | c.mu.RLock() 203 | defer c.mu.RUnlock() 204 | 205 | if c.distanceFunc == nil { 206 | return nil, fmt.Errorf("distance function is not set") 207 | } 208 | if query == nil { 209 | return nil, fmt.Errorf("query vector is nil") 210 | } 211 | 212 | // Validate query vector dimension 213 | if len(query) != c.dimension { 214 | return nil, fmt.Errorf("query vector dimension mismatch: got %d, expected %d", len(query), c.dimension) 215 | } 216 | 217 | // Calculate 
distances 218 | results := make([]SearchResult, 0, len(c.vectors)) 219 | for id, vector := range c.vectors { 220 | distance := c.distanceFunc(query, vector) 221 | results = append(results, SearchResult{ 222 | ID: id, 223 | Distance: distance, 224 | }) 225 | } 226 | 227 | // Sort by distance (nearest first) 228 | SortSearchResults(results) 229 | 230 | // Limit results 231 | if limit > 0 && limit < len(results) { 232 | results = results[:limit] 233 | } 234 | 235 | return results, nil 236 | } 237 | 238 | // SearchResult represents a search result with ID and distance 239 | type SearchResult struct { 240 | ID string 241 | Distance float32 242 | } 243 | 244 | // SortSearchResults sorts search results by distance (ascending) 245 | func SortSearchResults(results []SearchResult) { 246 | sort.Slice(results, func(i, j int) bool { 247 | return results[i].Distance < results[j].Distance 248 | }) 249 | } 250 | 251 | // Count returns the number of vectors in the collection 252 | func (c *Collection) Count() int { 253 | c.mu.RLock() 254 | defer c.mu.RUnlock() 255 | return len(c.vectors) 256 | } 257 | 258 | // SetFacetFields sets the fields to be indexed as facets for future vectors 259 | func (c *Collection) SetFacetFields(fields []string) { 260 | c.mu.Lock() 261 | defer c.mu.Unlock() 262 | 263 | c.facetFields = fields 264 | 265 | // Reindex existing vectors' facets 266 | c.vectorFacets = make(map[string][]facets.FacetValue) 267 | 268 | for id, metadata := range c.metadata { 269 | // Convert string map to interface map for facet extraction 270 | metadataMap := make(map[string]interface{}) 271 | for k, v := range metadata { 272 | metadataMap[k] = v 273 | } 274 | c.vectorFacets[id] = facets.ExtractFacets(metadataMap, fields) 275 | } 276 | 277 | // Mark as dirty 278 | c.dirty = true 279 | } 280 | 281 | // GetFacetFields returns the fields that are indexed as facets 282 | func (c *Collection) GetFacetFields() []string { 283 | c.mu.RLock() 284 | defer c.mu.RUnlock() 285 | return 
c.facetFields 286 | } 287 | 288 | // GetVectorFacets returns the facet values for a specific vector 289 | func (c *Collection) GetVectorFacets(id string) ([]facets.FacetValue, bool) { 290 | c.mu.RLock() 291 | defer c.mu.RUnlock() 292 | 293 | facetValues, exists := c.vectorFacets[id] 294 | return facetValues, exists 295 | } 296 | 297 | // SearchWithFacets searches for vectors similar to the query vector, with optional facet filters 298 | func (c *Collection) SearchWithFacets(query []float32, limit int, filters []facets.Filter) ([]SearchResult, error) { 299 | c.mu.RLock() 300 | defer c.mu.RUnlock() 301 | 302 | if c.distanceFunc == nil { 303 | return nil, fmt.Errorf("distance function is not set") 304 | } 305 | if query == nil { 306 | return nil, fmt.Errorf("query vector is nil") 307 | } 308 | 309 | // Validate query vector dimension 310 | if len(query) != c.dimension { 311 | return nil, fmt.Errorf("query vector dimension mismatch: got %d, expected %d", len(query), c.dimension) 312 | } 313 | 314 | // If no filters, use the normal search 315 | if len(filters) == 0 { 316 | return c.Search(query, limit) 317 | } 318 | 319 | // Calculate distances 320 | results := make([]SearchResult, 0, len(c.vectors)) 321 | for id, vector := range c.vectors { 322 | // Check if vector passes facet filters 323 | if facetValues, exists := c.vectorFacets[id]; exists { 324 | if !facets.MatchesAllFilters(facetValues, filters) { 325 | continue 326 | } 327 | } else { 328 | // Skip vectors without facet values 329 | continue 330 | } 331 | 332 | // Calculate distance for vectors that passed filters 333 | distance := c.distanceFunc(query, vector) 334 | results = append(results, SearchResult{ 335 | ID: id, 336 | Distance: distance, 337 | }) 338 | } 339 | 340 | // Sort by distance (nearest first) 341 | SortSearchResults(results) 342 | 343 | // Limit results 344 | if limit > 0 && limit < len(results) { 345 | results = results[:limit] 346 | } 347 | 348 | return results, nil 349 | } 350 | 
-------------------------------------------------------------------------------- /pkg/hybrid/exact_test.go: -------------------------------------------------------------------------------- 1 | package hybrid 2 | 3 | import ( 4 | "reflect" 5 | "sort" 6 | "testing" 7 | 8 | "github.com/TFMV/quiver/pkg/types" 9 | "github.com/TFMV/quiver/pkg/vectortypes" 10 | ) 11 | 12 | func TestNewExactIndex(t *testing.T) { 13 | distFunc := vectortypes.CosineDistance 14 | idx := NewExactIndex(distFunc) 15 | 16 | if idx.distFunc == nil { 17 | t.Error("Distance function not set correctly") 18 | } 19 | 20 | if idx.vectors == nil { 21 | t.Error("Vectors map not initialized") 22 | } 23 | 24 | if len(idx.vectors) != 0 { 25 | t.Errorf("Expected empty vectors map, got %d items", len(idx.vectors)) 26 | } 27 | } 28 | 29 | func TestExactIndex_Insert(t *testing.T) { 30 | idx := NewExactIndex(vectortypes.CosineDistance) 31 | 32 | // Test inserting a vector 33 | id := "test1" 34 | vector := vectortypes.F32{0.1, 0.2, 0.3, 0.4} 35 | 36 | err := idx.Insert(id, vector) 37 | if err != nil { 38 | t.Errorf("Insert returned unexpected error: %v", err) 39 | } 40 | 41 | // Verify vector was inserted 42 | if len(idx.vectors) != 1 { 43 | t.Errorf("Expected 1 vector in index, got %d", len(idx.vectors)) 44 | } 45 | 46 | // Verify the stored vector is a copy, not the original 47 | storedVector, exists := idx.vectors[id] 48 | if !exists { 49 | t.Fatalf("Vector with id %s not found in index", id) 50 | } 51 | 52 | if !reflect.DeepEqual(storedVector, vector) { 53 | t.Errorf("Stored vector %v does not match original %v", storedVector, vector) 54 | } 55 | 56 | // Modify the original vector and verify the stored one is unchanged 57 | vector[0] = 0.9 58 | if storedVector[0] == vector[0] { 59 | t.Error("Stored vector should be a copy, not a reference to the original") 60 | } 61 | } 62 | 63 | func TestExactIndex_Delete(t *testing.T) { 64 | idx := NewExactIndex(vectortypes.CosineDistance) 65 | 66 | // Insert a vector first 
67 | id := "test1" 68 | vector := vectortypes.F32{0.1, 0.2, 0.3, 0.4} 69 | err := idx.Insert(id, vector) 70 | if err != nil { 71 | t.Fatalf("Insert failed: %v", err) 72 | } 73 | 74 | // Verify it exists 75 | if len(idx.vectors) != 1 { 76 | t.Fatalf("Expected 1 vector in index, got %d", len(idx.vectors)) 77 | } 78 | 79 | // Delete it 80 | err = idx.Delete(id) 81 | if err != nil { 82 | t.Errorf("Delete returned unexpected error: %v", err) 83 | } 84 | 85 | // Verify it's gone 86 | if len(idx.vectors) != 0 { 87 | t.Errorf("Expected 0 vectors after deletion, got %d", len(idx.vectors)) 88 | } 89 | 90 | // Deleting a non-existent ID should not error 91 | err = idx.Delete("nonexistent") 92 | if err != nil { 93 | t.Errorf("Delete of non-existent ID returned error: %v", err) 94 | } 95 | } 96 | 97 | func TestExactIndex_Search(t *testing.T) { 98 | idx := NewExactIndex(vectortypes.CosineDistance) 99 | 100 | // Insert some test vectors 101 | testVectors := map[string]vectortypes.F32{ 102 | "vec1": {1.0, 0.0, 0.0}, // Aligned with x-axis 103 | "vec2": {0.0, 1.0, 0.0}, // Aligned with y-axis 104 | "vec3": {0.0, 0.0, 1.0}, // Aligned with z-axis 105 | } 106 | 107 | for id, vec := range testVectors { 108 | if err := idx.Insert(id, vec); err != nil { 109 | t.Fatalf("Failed to insert vector %s: %v", id, err) 110 | } 111 | } 112 | 113 | // Test cases 114 | tests := []struct { 115 | name string 116 | query vectortypes.F32 117 | k int 118 | wantIDs []string 119 | exactOrder bool 120 | }{ 121 | { 122 | name: "Query similar to vec1", 123 | query: vectortypes.F32{0.9, 0.1, 0.0}, 124 | k: 2, 125 | wantIDs: []string{"vec1", "vec2"}, 126 | exactOrder: true, 127 | }, 128 | { 129 | name: "Query similar to vec2", 130 | query: vectortypes.F32{0.1, 0.9, 0.0}, 131 | k: 2, 132 | wantIDs: []string{"vec2", "vec1"}, 133 | exactOrder: true, 134 | }, 135 | { 136 | name: "Query similar to vec3", 137 | query: vectortypes.F32{0.1, 0.1, 0.9}, 138 | k: 1, 139 | wantIDs: []string{"vec3"}, 140 | exactOrder: 
true, 141 | }, 142 | { 143 | name: "Query for all vectors", 144 | query: vectortypes.F32{0.5, 0.5, 0.5}, 145 | k: 3, 146 | wantIDs: []string{"vec1", "vec2", "vec3"}, 147 | exactOrder: false, // Don't care about order since distances may be similar 148 | }, 149 | { 150 | name: "K larger than number of vectors", 151 | query: vectortypes.F32{1.0, 0.0, 0.0}, 152 | k: 10, 153 | wantIDs: []string{"vec1", "vec2", "vec3"}, 154 | exactOrder: false, // Changed to false as vec2 and vec3 have equal distance 155 | }, 156 | } 157 | 158 | for _, tt := range tests { 159 | t.Run(tt.name, func(t *testing.T) { 160 | results, err := idx.Search(tt.query, tt.k) 161 | if err != nil { 162 | t.Fatalf("Search returned unexpected error: %v", err) 163 | } 164 | 165 | // Check result count 166 | expectedCount := tt.k 167 | if expectedCount > len(testVectors) { 168 | expectedCount = len(testVectors) 169 | } 170 | if len(results) != expectedCount { 171 | t.Errorf("Expected %d results, got %d", expectedCount, len(results)) 172 | } 173 | 174 | // If exact order matters, check each position 175 | if tt.exactOrder { 176 | for i, wantID := range tt.wantIDs { 177 | if i < len(results) && results[i].ID != wantID { 178 | t.Errorf("Result at position %d: got ID %s, want ID %s", 179 | i, results[i].ID, wantID) 180 | } 181 | } 182 | } else { 183 | // Otherwise just check that all expected IDs are present 184 | foundIDs := make(map[string]bool) 185 | for _, result := range results { 186 | foundIDs[result.ID] = true 187 | } 188 | 189 | for _, wantID := range tt.wantIDs { 190 | if !foundIDs[wantID] { 191 | t.Errorf("Expected ID %s not found in results", wantID) 192 | } 193 | } 194 | } 195 | 196 | // Check that distances are sorted 197 | for i := 0; i < len(results)-1; i++ { 198 | if results[i].Distance > results[i+1].Distance { 199 | t.Errorf("Results not sorted by distance: %f > %f", 200 | results[i].Distance, results[i+1].Distance) 201 | } 202 | } 203 | }) 204 | } 205 | } 206 | 207 | func 
TestExactIndex_Size(t *testing.T) { 208 | idx := NewExactIndex(vectortypes.CosineDistance) 209 | 210 | // Empty index should have size 0 211 | if size := idx.Size(); size != 0 { 212 | t.Errorf("Expected size 0 for empty index, got %d", size) 213 | } 214 | 215 | // Insert a vector 216 | err := idx.Insert("test1", vectortypes.F32{0.1, 0.2, 0.3}) 217 | if err != nil { 218 | t.Fatalf("Insert failed: %v", err) 219 | } 220 | 221 | // Size should be 1 222 | if size := idx.Size(); size != 1 { 223 | t.Errorf("Expected size 1 after insertion, got %d", size) 224 | } 225 | 226 | // Insert another vector 227 | err = idx.Insert("test2", vectortypes.F32{0.4, 0.5, 0.6}) 228 | if err != nil { 229 | t.Fatalf("Insert failed: %v", err) 230 | } 231 | 232 | // Size should be 2 233 | if size := idx.Size(); size != 2 { 234 | t.Errorf("Expected size 2 after second insertion, got %d", size) 235 | } 236 | 237 | // Delete a vector 238 | err = idx.Delete("test1") 239 | if err != nil { 240 | t.Fatalf("Delete failed: %v", err) 241 | } 242 | 243 | // Size should be 1 again 244 | if size := idx.Size(); size != 1 { 245 | t.Errorf("Expected size 1 after deletion, got %d", size) 246 | } 247 | } 248 | 249 | func TestExactIndex_GetType(t *testing.T) { 250 | idx := NewExactIndex(vectortypes.CosineDistance) 251 | if idx.GetType() != ExactIndexType { 252 | t.Errorf("Expected index type %s, got %s", ExactIndexType, idx.GetType()) 253 | } 254 | } 255 | 256 | func TestExactIndex_GetStats(t *testing.T) { 257 | idx := NewExactIndex(vectortypes.CosineDistance) 258 | 259 | // Insert some vectors 260 | for i := 0; i < 5; i++ { 261 | id := "test" + string(rune('0'+i)) 262 | vector := vectortypes.F32{float32(i) * 0.1, float32(i) * 0.2, float32(i) * 0.3} 263 | if err := idx.Insert(id, vector); err != nil { 264 | t.Fatalf("Insert failed: %v", err) 265 | } 266 | } 267 | 268 | // Get stats 269 | stats := idx.GetStats().(map[string]interface{}) 270 | 271 | // Check type 272 | if typeStr, ok := stats["type"].(string); 
!ok || typeStr != string(ExactIndexType) { 273 | t.Errorf("Expected type %s, got %v", ExactIndexType, stats["type"]) 274 | } 275 | 276 | // Check vector count 277 | if count, ok := stats["vector_count"].(int); !ok || count != 5 { 278 | t.Errorf("Expected vector_count 5, got %v", stats["vector_count"]) 279 | } 280 | } 281 | 282 | func TestResultHeap(t *testing.T) { 283 | // Create a result heap for testing 284 | heap := &resultHeap{} 285 | 286 | // Add results in any order 287 | *heap = append(*heap, types.BasicSearchResult{ID: "vec1", Distance: 0.5}) 288 | *heap = append(*heap, types.BasicSearchResult{ID: "vec2", Distance: 0.2}) 289 | *heap = append(*heap, types.BasicSearchResult{ID: "vec3", Distance: 0.8}) 290 | *heap = append(*heap, types.BasicSearchResult{ID: "vec4", Distance: 0.3}) 291 | 292 | // Initialize the heap 293 | sort.Sort(heap) 294 | 295 | // Pop results off the heap 296 | var popped []types.BasicSearchResult 297 | for heap.Len() > 0 { 298 | popped = append(popped, (*heap)[0]) 299 | *heap = (*heap)[1:] 300 | if heap.Len() > 0 { 301 | sort.Sort(heap) 302 | } 303 | } 304 | 305 | // Results should be in ascending order of distance 306 | expected := []string{"vec2", "vec4", "vec1", "vec3"} 307 | for i, id := range expected { 308 | if popped[i].ID != id { 309 | t.Errorf("Expected %s at position %d, got %s", id, i, popped[i].ID) 310 | } 311 | } 312 | } 313 | 314 | func TestExactIndex_DimensionMismatch(t *testing.T) { 315 | idx := NewExactIndex(vectortypes.CosineDistance) 316 | if err := idx.Insert("vec1", vectortypes.F32{0.1, 0.2}); err != nil { 317 | t.Fatalf("unexpected error: %v", err) 318 | } 319 | if err := idx.Insert("vec2", vectortypes.F32{0.3}); err == nil { 320 | t.Errorf("expected dimension mismatch error") 321 | } 322 | } 323 | 324 | func TestExactIndex_SearchDimensionMismatch(t *testing.T) { 325 | idx := NewExactIndex(vectortypes.CosineDistance) 326 | _ = idx.Insert("vec1", vectortypes.F32{0.1, 0.2}) 327 | if _, err := 
idx.Search(vectortypes.F32{0.3}, 1); err == nil { 328 | t.Errorf("expected dimension mismatch error") 329 | } 330 | } 331 | -------------------------------------------------------------------------------- /pkg/hybrid/adaptive_test.go: -------------------------------------------------------------------------------- 1 | package hybrid 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | func TestNewAdaptiveStrategySelector(t *testing.T) { 9 | config := DefaultAdaptiveConfig() 10 | 11 | // Modify some values to ensure they're properly set 12 | config.ExplorationFactor = 0.2 13 | config.InitialExactThreshold = 2000 14 | config.InitialDimThreshold = 150 15 | 16 | selector := NewAdaptiveStrategySelector(config) 17 | 18 | if selector == nil { 19 | t.Fatal("NewAdaptiveStrategySelector returned nil") 20 | } 21 | 22 | if selector.config.ExplorationFactor != 0.2 { 23 | t.Errorf("Expected ExplorationFactor 0.2, got %f", selector.config.ExplorationFactor) 24 | } 25 | 26 | if selector.exactThreshold != 2000 { 27 | t.Errorf("Expected exactThreshold 2000, got %d", selector.exactThreshold) 28 | } 29 | 30 | if selector.dimThreshold != 150 { 31 | t.Errorf("Expected dimThreshold 150, got %d", selector.dimThreshold) 32 | } 33 | 34 | if selector.metrics == nil { 35 | t.Error("metrics map not initialized") 36 | } 37 | 38 | if selector.recentQueries == nil { 39 | t.Error("recentQueries slice not initialized") 40 | } 41 | 42 | if selector.rng == nil { 43 | t.Error("random number generator not initialized") 44 | } 45 | } 46 | 47 | func TestSelectStrategy_SmallDataset(t *testing.T) { 48 | config := DefaultAdaptiveConfig() 49 | // Set exploration to 0 to make tests deterministic 50 | config.ExplorationFactor = 0 51 | config.InitialExactThreshold = 1000 52 | 53 | selector := NewAdaptiveStrategySelector(config) 54 | 55 | // For small datasets, should prefer exact search 56 | strategy := selector.SelectStrategy(500, 100, 10) 57 | if strategy != ExactIndexType { 58 | t.Errorf("Expected 
ExactIndexType for small dataset, got %s", strategy) 59 | } 60 | } 61 | 62 | func TestSelectStrategy_LargeDataset(t *testing.T) { 63 | config := DefaultAdaptiveConfig() 64 | // Set exploration to 0 to make tests deterministic 65 | config.ExplorationFactor = 0 66 | config.InitialExactThreshold = 1000 67 | 68 | selector := NewAdaptiveStrategySelector(config) 69 | 70 | // For large datasets with medium dimensions, should prefer HNSW 71 | strategy := selector.SelectStrategy(5000, 100, 10) 72 | if strategy != HNSWIndexType { 73 | t.Errorf("Expected HNSWIndexType for large dataset, got %s", strategy) 74 | } 75 | } 76 | 77 | func TestSelectStrategy_HighDimensionalSmallK(t *testing.T) { 78 | config := DefaultAdaptiveConfig() 79 | // Set exploration to 0 to make tests deterministic 80 | config.ExplorationFactor = 0 81 | config.InitialExactThreshold = 1000 82 | config.InitialDimThreshold = 100 83 | 84 | selector := NewAdaptiveStrategySelector(config) 85 | 86 | // For high dimensional data with small k, should prefer HNSW 87 | strategy := selector.SelectStrategy(5000, 150, 10) 88 | if strategy != HNSWIndexType { 89 | t.Errorf("Expected HNSWIndexType for high-dim small-k, got %s", strategy) 90 | } 91 | } 92 | 93 | func TestSelectStrategy_HighDimensionalLargeK(t *testing.T) { 94 | config := DefaultAdaptiveConfig() 95 | // Set exploration to 0 to make tests deterministic 96 | config.ExplorationFactor = 0 97 | config.InitialExactThreshold = 1000 98 | config.InitialDimThreshold = 100 99 | 100 | selector := NewAdaptiveStrategySelector(config) 101 | 102 | // For high dimensional data with large k, should prefer exact search 103 | strategy := selector.SelectStrategy(5000, 150, 100) 104 | if strategy != ExactIndexType { 105 | t.Errorf("Expected ExactIndexType for high-dim large-k, got %s", strategy) 106 | } 107 | } 108 | 109 | func TestRecordQueryMetrics_InitializeStats(t *testing.T) { 110 | selector := NewAdaptiveStrategySelector(DefaultAdaptiveConfig()) 111 | 112 | // Record a 
metric for a strategy that hasn't been seen before 113 | metrics := QueryMetrics{ 114 | Strategy: ExactIndexType, 115 | QueryDimension: 128, 116 | K: 10, 117 | Duration: 100 * time.Millisecond, 118 | ResultCount: 5, 119 | Timestamp: time.Now(), 120 | } 121 | 122 | selector.RecordQueryMetrics(metrics) 123 | 124 | // Check that stats were initialized 125 | if _, exists := selector.metrics[ExactIndexType]; !exists { 126 | t.Fatalf("Expected stats for %s to be initialized", ExactIndexType) 127 | } 128 | 129 | stats := selector.metrics[ExactIndexType] 130 | if stats.UsageCount != 1 { 131 | t.Errorf("Expected UsageCount 1, got %d", stats.UsageCount) 132 | } 133 | 134 | if stats.TotalDuration != 100*time.Millisecond { 135 | t.Errorf("Expected TotalDuration 100ms, got %v", stats.TotalDuration) 136 | } 137 | 138 | if stats.AvgDuration != 100*time.Millisecond { 139 | t.Errorf("Expected AvgDuration 100ms, got %v", stats.AvgDuration) 140 | } 141 | 142 | // Check that recent queries was updated 143 | if len(selector.recentQueries) != 1 { 144 | t.Errorf("Expected 1 recent query, got %d", len(selector.recentQueries)) 145 | } 146 | } 147 | 148 | func TestRecordQueryMetrics_UpdateStats(t *testing.T) { 149 | selector := NewAdaptiveStrategySelector(DefaultAdaptiveConfig()) 150 | 151 | // Record multiple metrics for the same strategy 152 | metrics1 := QueryMetrics{ 153 | Strategy: HNSWIndexType, 154 | QueryDimension: 128, 155 | K: 10, 156 | Duration: 100 * time.Millisecond, 157 | ResultCount: 5, 158 | Timestamp: time.Now(), 159 | } 160 | 161 | metrics2 := QueryMetrics{ 162 | Strategy: HNSWIndexType, 163 | QueryDimension: 128, 164 | K: 10, 165 | Duration: 200 * time.Millisecond, 166 | ResultCount: 5, 167 | Timestamp: time.Now(), 168 | } 169 | 170 | selector.RecordQueryMetrics(metrics1) 171 | selector.RecordQueryMetrics(metrics2) 172 | 173 | // Check that stats were updated 174 | stats := selector.metrics[HNSWIndexType] 175 | if stats.UsageCount != 2 { 176 | t.Errorf("Expected 
UsageCount 2, got %d", stats.UsageCount) 177 | } 178 | 179 | if stats.TotalDuration != 300*time.Millisecond { 180 | t.Errorf("Expected TotalDuration 300ms, got %v", stats.TotalDuration) 181 | } 182 | 183 | if stats.AvgDuration != 150*time.Millisecond { 184 | t.Errorf("Expected AvgDuration 150ms, got %v", stats.AvgDuration) 185 | } 186 | 187 | // Check that recent queries contains both 188 | if len(selector.recentQueries) != 2 { 189 | t.Errorf("Expected 2 recent queries, got %d", len(selector.recentQueries)) 190 | } 191 | } 192 | 193 | func TestRecordQueryMetrics_LimitRecentQueries(t *testing.T) { 194 | config := DefaultAdaptiveConfig() 195 | config.MetricsWindowSize = 3 // Small window for testing 196 | selector := NewAdaptiveStrategySelector(config) 197 | 198 | // Add more queries than the window size 199 | for i := 0; i < 5; i++ { 200 | metrics := QueryMetrics{ 201 | Strategy: ExactIndexType, 202 | QueryDimension: 128, 203 | K: 10, 204 | Duration: time.Duration(i+1) * 10 * time.Millisecond, 205 | ResultCount: 5, 206 | Timestamp: time.Now(), 207 | } 208 | selector.RecordQueryMetrics(metrics) 209 | } 210 | 211 | // Check that only the most recent window size queries are kept 212 | if len(selector.recentQueries) != 3 { 213 | t.Errorf("Expected %d recent queries, got %d", config.MetricsWindowSize, len(selector.recentQueries)) 214 | } 215 | 216 | // The oldest queries should be dropped, so the first duration should be 30ms (4th query) 217 | if selector.recentQueries[0].Duration != 30*time.Millisecond { 218 | t.Errorf("Expected first query duration 30ms, got %v", selector.recentQueries[0].Duration) 219 | } 220 | 221 | // The newest query should be 50ms (5th query) 222 | if selector.recentQueries[2].Duration != 50*time.Millisecond { 223 | t.Errorf("Expected last query duration 50ms, got %v", selector.recentQueries[2].Duration) 224 | } 225 | } 226 | 227 | func TestGetStats(t *testing.T) { 228 | selector := NewAdaptiveStrategySelector(DefaultAdaptiveConfig()) 229 | 230 | 
// Add some test metrics 231 | exactMetrics := QueryMetrics{ 232 | Strategy: ExactIndexType, 233 | QueryDimension: 128, 234 | K: 10, 235 | Duration: 50 * time.Millisecond, 236 | ResultCount: 5, 237 | Timestamp: time.Now(), 238 | } 239 | 240 | hnswMetrics := QueryMetrics{ 241 | Strategy: HNSWIndexType, 242 | QueryDimension: 256, 243 | K: 50, 244 | Duration: 30 * time.Millisecond, 245 | ResultCount: 10, 246 | Timestamp: time.Now(), 247 | } 248 | 249 | selector.RecordQueryMetrics(exactMetrics) 250 | selector.RecordQueryMetrics(hnswMetrics) 251 | 252 | // Get stats 253 | stats := selector.GetStats() 254 | 255 | // Check basic stats 256 | if thresholds, ok := stats["thresholds"].(map[string]interface{}); ok { 257 | if exactThreshold, exists := thresholds["exact"]; !exists || exactThreshold != selector.exactThreshold { 258 | t.Errorf("Expected exact threshold %d, got %v", selector.exactThreshold, exactThreshold) 259 | } 260 | 261 | if dimThreshold, exists := thresholds["dimension"]; !exists || dimThreshold != selector.dimThreshold { 262 | t.Errorf("Expected dimension threshold %d, got %v", selector.dimThreshold, dimThreshold) 263 | } 264 | } else { 265 | t.Error("Expected thresholds in stats") 266 | } 267 | 268 | // Check strategies stats 269 | if strategyStats, ok := stats["strategies"].(map[string]interface{}); ok { 270 | if _, exists := strategyStats[string(ExactIndexType)]; !exists { 271 | t.Errorf("Expected stats for %s", ExactIndexType) 272 | } 273 | 274 | if _, exists := strategyStats[string(HNSWIndexType)]; !exists { 275 | t.Errorf("Expected stats for %s", HNSWIndexType) 276 | } 277 | } else { 278 | t.Error("Expected strategies in stats") 279 | } 280 | 281 | // Check config 282 | if _, ok := stats["config"]; !ok { 283 | t.Error("Expected config in stats") 284 | } 285 | 286 | // Check recent queries 287 | if recentCount, ok := stats["recent_queries_count"].(int); !ok || recentCount != 2 { 288 | t.Errorf("Expected recent_queries_count 2, got %v", 
stats["recent_queries_count"]) 289 | } 290 | } 291 | 292 | func TestString(t *testing.T) { 293 | selector := NewAdaptiveStrategySelector(DefaultAdaptiveConfig()) 294 | 295 | // Add some test metrics 296 | selector.RecordQueryMetrics(QueryMetrics{ 297 | Strategy: ExactIndexType, 298 | QueryDimension: 128, 299 | K: 10, 300 | Duration: 50 * time.Millisecond, 301 | ResultCount: 5, 302 | Timestamp: time.Now(), 303 | }) 304 | 305 | // Just check that String() doesn't panic and returns something 306 | str := selector.String() 307 | if str == "" { 308 | t.Error("Expected non-empty string representation") 309 | } 310 | } 311 | -------------------------------------------------------------------------------- /pkg/vectortypes/types_test.go: -------------------------------------------------------------------------------- 1 | package vectortypes 2 | 3 | import ( 4 | "encoding/json" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestGetDistanceFuncByType(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | distType DistanceType 13 | wantFunc DistanceFunc 14 | checkVecs bool 15 | vecA F32 16 | vecB F32 17 | wantResult float32 18 | }{ 19 | { 20 | name: "Cosine Distance", 21 | distType: Cosine, 22 | wantFunc: CosineDistance, 23 | checkVecs: true, 24 | vecA: F32{1, 0, 0}, 25 | vecB: F32{0, 1, 0}, 26 | wantResult: 1.0, // Perpendicular vectors have cosine distance of 1 27 | }, 28 | { 29 | name: "Euclidean Distance", 30 | distType: Euclidean, 31 | wantFunc: EuclideanDistance, 32 | checkVecs: true, 33 | vecA: F32{1, 0, 0}, 34 | vecB: F32{0, 1, 0}, 35 | wantResult: 1.4142135, // sqrt(2) 36 | }, 37 | { 38 | name: "Dot Product Distance", 39 | distType: DotProduct, 40 | wantFunc: DotProductDistance, 41 | checkVecs: true, 42 | vecA: F32{1, 0, 0}, 43 | vecB: F32{1, 0, 0}, 44 | wantResult: 0.0, // Same direction has dot product distance of 0 45 | }, 46 | { 47 | name: "Manhattan Distance", 48 | distType: Manhattan, 49 | wantFunc: ManhattanDistance, 50 | checkVecs: true, 51 | vecA: 
F32{1, 1, 0}, 52 | vecB: F32{0, 0, 0}, 53 | wantResult: 2.0, // |1-0| + |1-0| + |0-0| = 2 54 | }, 55 | { 56 | name: "Default to Cosine", 57 | distType: "invalid", 58 | wantFunc: CosineDistance, 59 | checkVecs: false, 60 | }, 61 | } 62 | 63 | for _, tt := range tests { 64 | t.Run(tt.name, func(t *testing.T) { 65 | got := GetDistanceFuncByType(tt.distType) 66 | 67 | // Can't directly compare functions, so we test with the same input 68 | if tt.checkVecs { 69 | gotResult := got(tt.vecA, tt.vecB) 70 | if !floatEquals(gotResult, tt.wantResult, 1e-6) { 71 | t.Errorf("GetDistanceFuncByType(%v) = func giving %v, want func giving %v", 72 | tt.distType, gotResult, tt.wantResult) 73 | } 74 | } 75 | }) 76 | } 77 | } 78 | 79 | func TestGetSurfaceByType(t *testing.T) { 80 | tests := []struct { 81 | name string 82 | distType DistanceType 83 | checkVecs bool 84 | vecA F32 85 | vecB F32 86 | wantResult float32 87 | }{ 88 | { 89 | name: "Cosine Surface", 90 | distType: Cosine, 91 | checkVecs: true, 92 | vecA: F32{1, 0, 0}, 93 | vecB: F32{0, 1, 0}, 94 | wantResult: 1.0, // Perpendicular vectors have cosine distance of 1 95 | }, 96 | { 97 | name: "Euclidean Surface", 98 | distType: Euclidean, 99 | checkVecs: true, 100 | vecA: F32{1, 0, 0}, 101 | vecB: F32{0, 1, 0}, 102 | wantResult: 1.4142135, // sqrt(2) 103 | }, 104 | { 105 | name: "Dot Product Surface", 106 | distType: DotProduct, 107 | checkVecs: true, 108 | vecA: F32{1, 0, 0}, 109 | vecB: F32{1, 0, 0}, 110 | wantResult: 0.0, // Same direction has dot product distance of 0 111 | }, 112 | { 113 | name: "Manhattan Surface", 114 | distType: Manhattan, 115 | checkVecs: true, 116 | vecA: F32{1, 1, 0}, 117 | vecB: F32{0, 0, 0}, 118 | wantResult: 2.0, // |1-0| + |1-0| + |0-0| = 2 119 | }, 120 | { 121 | name: "Default to Cosine", 122 | distType: "invalid", 123 | checkVecs: false, 124 | }, 125 | } 126 | 127 | for _, tt := range tests { 128 | t.Run(tt.name, func(t *testing.T) { 129 | got := GetSurfaceByType(tt.distType) 130 | 131 | // Check 
the surface by testing distance calculation 132 | if tt.checkVecs { 133 | gotResult := got.Distance(tt.vecA, tt.vecB) 134 | if !floatEquals(gotResult, tt.wantResult, 1e-6) { 135 | t.Errorf("GetSurfaceByType(%v).Distance(%v, %v) = %v, want %v", 136 | tt.distType, tt.vecA, tt.vecB, gotResult, tt.wantResult) 137 | } 138 | } 139 | }) 140 | } 141 | } 142 | 143 | func TestComputeDistance(t *testing.T) { 144 | tests := []struct { 145 | name string 146 | vecA F32 147 | vecB F32 148 | distType DistanceType 149 | wantResult float32 150 | wantErr bool 151 | }{ 152 | { 153 | name: "Cosine Distance", 154 | vecA: F32{1, 0, 0}, 155 | vecB: F32{0, 1, 0}, 156 | distType: Cosine, 157 | wantResult: 1.0, 158 | wantErr: false, 159 | }, 160 | { 161 | name: "Different Dimensions", 162 | vecA: F32{1, 0, 0}, 163 | vecB: F32{0, 1}, 164 | distType: Cosine, 165 | wantResult: 0.0, 166 | wantErr: true, 167 | }, 168 | } 169 | 170 | for _, tt := range tests { 171 | t.Run(tt.name, func(t *testing.T) { 172 | gotResult, err := ComputeDistance(tt.vecA, tt.vecB, tt.distType) 173 | 174 | // Check error 175 | if (err != nil) != tt.wantErr { 176 | t.Errorf("ComputeDistance() error = %v, wantErr %v", err, tt.wantErr) 177 | return 178 | } 179 | 180 | // Check result if no error expected 181 | if !tt.wantErr && !floatEquals(gotResult, tt.wantResult, 1e-6) { 182 | t.Errorf("ComputeDistance() = %v, want %v", gotResult, tt.wantResult) 183 | } 184 | }) 185 | } 186 | } 187 | 188 | func TestIsNormalized(t *testing.T) { 189 | tests := []struct { 190 | name string 191 | vec F32 192 | wantResult bool 193 | }{ 194 | { 195 | name: "Unit Vector", 196 | vec: F32{1, 0, 0}, 197 | wantResult: true, 198 | }, 199 | { 200 | name: "Normalized Vector", 201 | vec: F32{0.577, 0.577, 0.577}, // 1/sqrt(3) in each component 202 | wantResult: true, 203 | }, 204 | { 205 | name: "Zero Vector", 206 | vec: F32{0, 0, 0}, 207 | wantResult: false, // Zero vector is not normally considered normalized 208 | }, 209 | { 210 | name: 
"Non-normalized Vector", 211 | vec: F32{2, 0, 0}, 212 | wantResult: false, 213 | }, 214 | } 215 | 216 | for _, tt := range tests { 217 | t.Run(tt.name, func(t *testing.T) { 218 | gotResult := IsNormalized(tt.vec) 219 | 220 | if gotResult != tt.wantResult { 221 | t.Errorf("IsNormalized(%v) = %v, want %v", tt.vec, gotResult, tt.wantResult) 222 | } 223 | }) 224 | } 225 | } 226 | 227 | func TestVector_GetMetadataValue(t *testing.T) { 228 | tests := []struct { 229 | name string 230 | vector *Vector 231 | key string 232 | wantValue interface{} 233 | wantErr bool 234 | }{ 235 | { 236 | name: "Valid Key", 237 | vector: &Vector{ 238 | ID: "vec1", 239 | Values: F32{1, 2, 3}, 240 | Metadata: json.RawMessage(`{"key1": "value1", "key2": 123}`), 241 | }, 242 | key: "key1", 243 | wantValue: "value1", 244 | wantErr: false, 245 | }, 246 | { 247 | name: "Missing Key", 248 | vector: &Vector{ 249 | ID: "vec2", 250 | Values: F32{1, 2, 3}, 251 | Metadata: json.RawMessage(`{"key1": "value1"}`), 252 | }, 253 | key: "key2", 254 | wantValue: nil, 255 | wantErr: true, 256 | }, 257 | { 258 | name: "No Metadata", 259 | vector: &Vector{ 260 | ID: "vec3", 261 | Values: F32{1, 2, 3}, 262 | }, 263 | key: "key1", 264 | wantValue: nil, 265 | wantErr: true, 266 | }, 267 | { 268 | name: "Invalid JSON", 269 | vector: &Vector{ 270 | ID: "vec4", 271 | Values: F32{1, 2, 3}, 272 | Metadata: json.RawMessage(`{"key1": value1}`), // Invalid JSON (missing quotes) 273 | }, 274 | key: "key1", 275 | wantValue: nil, 276 | wantErr: true, 277 | }, 278 | } 279 | 280 | for _, tt := range tests { 281 | t.Run(tt.name, func(t *testing.T) { 282 | gotValue, err := tt.vector.GetMetadataValue(tt.key) 283 | 284 | // Check error 285 | if (err != nil) != tt.wantErr { 286 | t.Errorf("Vector.GetMetadataValue() error = %v, wantErr %v", err, tt.wantErr) 287 | return 288 | } 289 | 290 | // Check result if no error expected 291 | if !tt.wantErr && !reflect.DeepEqual(gotValue, tt.wantValue) { 292 | t.Errorf("Vector.GetMetadataValue() 
= %v, want %v", gotValue, tt.wantValue) 293 | } 294 | }) 295 | } 296 | } 297 | 298 | func TestCheckDimensions(t *testing.T) { 299 | tests := []struct { 300 | name string 301 | vecA F32 302 | vecB F32 303 | wantErr bool 304 | }{ 305 | { 306 | name: "Same Dimensions", 307 | vecA: F32{1, 2, 3}, 308 | vecB: F32{4, 5, 6}, 309 | wantErr: false, 310 | }, 311 | { 312 | name: "Different Dimensions", 313 | vecA: F32{1, 2, 3}, 314 | vecB: F32{4, 5}, 315 | wantErr: true, 316 | }, 317 | { 318 | name: "Zero Dimensions", 319 | vecA: F32{}, 320 | vecB: F32{}, 321 | wantErr: false, 322 | }, 323 | } 324 | 325 | for _, tt := range tests { 326 | t.Run(tt.name, func(t *testing.T) { 327 | err := CheckDimensions(tt.vecA, tt.vecB) 328 | 329 | if (err != nil) != tt.wantErr { 330 | t.Errorf("CheckDimensions() error = %v, wantErr %v", err, tt.wantErr) 331 | } 332 | }) 333 | } 334 | } 335 | 336 | func TestCreateVector(t *testing.T) { 337 | tests := []struct { 338 | name string 339 | id string 340 | values F32 341 | metadata json.RawMessage 342 | want *Vector 343 | }{ 344 | { 345 | name: "Basic Vector", 346 | id: "vec1", 347 | values: F32{1, 2, 3}, 348 | metadata: json.RawMessage(`{"key": "value"}`), 349 | want: &Vector{ 350 | ID: "vec1", 351 | Values: F32{1, 2, 3}, 352 | Metadata: json.RawMessage(`{"key": "value"}`), 353 | }, 354 | }, 355 | { 356 | name: "No Metadata", 357 | id: "vec2", 358 | values: F32{4, 5, 6}, 359 | metadata: nil, 360 | want: &Vector{ 361 | ID: "vec2", 362 | Values: F32{4, 5, 6}, 363 | Metadata: nil, 364 | }, 365 | }, 366 | } 367 | 368 | for _, tt := range tests { 369 | t.Run(tt.name, func(t *testing.T) { 370 | got := CreateVector(tt.id, tt.values, tt.metadata) 371 | 372 | if !vectorEqual(got, tt.want) { 373 | t.Errorf("CreateVector() = %v, want %v", got, tt.want) 374 | } 375 | }) 376 | } 377 | } 378 | 379 | func TestVector_Dimension(t *testing.T) { 380 | tests := []struct { 381 | name string 382 | vec *Vector 383 | want int 384 | }{ 385 | { 386 | name: "3D Vector", 387 | 
vec: &Vector{ 388 | ID: "vec1", 389 | Values: F32{1, 2, 3}, 390 | }, 391 | want: 3, 392 | }, 393 | { 394 | name: "Empty Vector", 395 | vec: &Vector{ 396 | ID: "vec2", 397 | Values: F32{}, 398 | }, 399 | want: 0, 400 | }, 401 | } 402 | 403 | for _, tt := range tests { 404 | t.Run(tt.name, func(t *testing.T) { 405 | got := tt.vec.Dimension() 406 | 407 | if got != tt.want { 408 | t.Errorf("Vector.Dimension() = %v, want %v", got, tt.want) 409 | } 410 | }) 411 | } 412 | } 413 | 414 | // Helper function to compare vectors 415 | func vectorEqual(a, b *Vector) bool { 416 | if a == nil && b == nil { 417 | return true 418 | } 419 | if a == nil || b == nil { 420 | return false 421 | } 422 | 423 | if a.ID != b.ID { 424 | return false 425 | } 426 | 427 | if len(a.Values) != len(b.Values) { 428 | return false 429 | } 430 | 431 | for i := range a.Values { 432 | if !floatEquals(a.Values[i], b.Values[i], 1e-6) { 433 | return false 434 | } 435 | } 436 | 437 | return string(a.Metadata) == string(b.Metadata) 438 | } 439 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🏹 Quiver 2 | 3 | ## What is Quiver? 4 | 5 | Quiver is a Go-based vector database that combines the best of HNSW (Hierarchical Navigable Small World) graphs with other cool search techniques. It provides efficient similarity search capabilities while maintaining a clean, easy-to-use API. 6 | 7 | ## 🙋 Why I Built Quiver 8 | 9 | I didn’t create Quiver for production use. It’s a learning project—my way of exploring the internals of vector databases and sharing what I’ve learned. It’s also a toy. 10 | 11 | If you’re curious about how vector search works under the hood, or if you want a foundation to build your own system, feel free to fork or clone it. I’ve kept it small and modular to make that easier. 12 | 13 | Accompanying write-ups are available on Medium. 
14 | 15 | ## Supported Index Types 16 | 17 | Quiver offers three powerful index types, each of which can be backed by durable storage: 18 | 19 | 1. **HNSW Index**: The classic HNSW (Hierarchical Navigable Small World) graph implementation. This in-memory index offers a great balance of speed and recall for most use cases. It's fast, memory-efficient, and perfect for medium-sized datasets. 20 | 21 | 2. **Hybrid Index**: Our most advanced index type that combines multiple search strategies to optimize for both speed and recall. It can automatically select between exact search (for small datasets) and approximate search (for larger datasets), and includes optimizations for different query patterns. The hybrid index is particularly effective for datasets with varying sizes and query patterns. 22 | 23 | 3. **Arrow HNSW Index**: An Apache Arrow backed variant of HNSW for zero-copy columnar storage and optional Parquet persistence. 24 | 25 | All index types can be backed by **Parquet Storage**, which efficiently persists vectors to disk in Parquet format. This makes them suitable for larger datasets that need durability while maintaining good performance characteristics. 26 | 27 | All index types support metadata filtering and negative examples. Choose the right index type for your needs and let APT optimize your parameters automatically! 28 | 29 | ## Why Choose Quiver? 
30 | 31 | - **🚀 Performance**: Quiver is built for speed without sacrificing accuracy 32 | - **🔍 Smart Search Strategy**: Quiver doesn't just use one search method - it combines HNSW with exact search to find the best results 33 | - **😌 Easy to Use**: Our fluent API just makes sense 34 | - **🔗 Fluent Query API**: Write queries that read like plain English 35 | - **🏷️ Rich Metadata**: Attach JSON metadata to your vectors and filter search results based on it 36 | - **🏎️ Faceted Search**: High-performance categorical filtering for your vectors 37 | - **👎 Negative Examples**: Tell Quiver what you don't want to see in your results 38 | - **⚡ Batch Operations**: Add, update, and delete vectors in batches for lightning speed 39 | - **💾 Durability**: Your data stays safe with Parquet-based storage 40 | - **📊 Analytics**: Peek under the hood with graph quality and performance metrics 41 | - **📦 Backup & Restore**: Create snapshots of your database and bring them back when needed 42 | 43 | ## What can you do with Quiver? 44 | 45 | Quiver makes it easy to form complex queries for any of our index types. 46 | 47 | ```go 48 | func SearchWithComplexOptions(db *quiver.DB) { 49 | log.Println("\nPerforming search with complex options...") 50 | 51 | // Get a collection 52 | collection, err := db.GetCollection("products") 53 | if err != nil { 54 | log.Fatalf("Failed to get collection: %v", err) 55 | } 56 | 57 | // Create a query vector 58 | queryVector := []float32{0.2, 0.3, 0.4, 0.5, 0.6} 59 | 60 | // Use the fluent API for building a complex search query 61 | // with both facets and negative examples 62 | results, err := collection.FluentSearch(queryVector). 63 | WithK(10). 64 | WithNegativeExample([]float32{0.9, 0.8, 0.7, 0.6, 0.5}). 65 | WithNegativeWeight(0.3). 66 | Filter("category", "electronics"). 67 | FilterIn("tags", []interface{}{"smartphone", "5G"}). 68 | FilterGreaterThan("price", 100.0). 69 | FilterLessThan("price", 500.0). 
70 | Execute() 71 | 72 | if err != nil { 73 | log.Printf("Search with complex options failed: %v", err) 74 | return 75 | } 76 | 77 | // Display results 78 | log.Printf("Found %d results (with complex options):", len(results.Results)) 79 | for i, result := range results.Results { 80 | log.Printf(" Result %d: ID=%s, Distance=%f", i+1, result.ID, result.Distance) 81 | if result.Metadata != nil { 82 | log.Printf(" Metadata: %s", string(result.Metadata)) 83 | } 84 | } 85 | } 86 | ``` 87 | 88 | ## Tips for Best Performance 89 | 90 | 1. **Choose the Right Index Type**: 91 | - Small dataset (<1000 vectors)? Use the Hybrid index with a low ExactThreshold for perfect recall 92 | - Medium dataset? Use the HNSW index for a good balance of speed and recall 93 | - Large dataset with durability needs? Use HNSW with Parquet storage 94 | - Complex workloads with varying query patterns? Use the Hybrid index 95 | 96 | 2. **Tune Your Parameters** (or let APT do it for you): 97 | - M: Controls the number of connections per node (higher = more accurate but more memory) 98 | - EfSearch: Controls search depth (higher = more accurate but slower) 99 | - EfConstruction: Controls index build quality (higher = better index but slower construction) 100 | 101 | 3. **Use Batch Operations**: 102 | - Always prefer `BatchAdd` over multiple `Add` calls 103 | - Same goes for `BatchDelete` vs multiple `Delete` calls 104 | - Run `OptimizeStorage` after large batch operations 105 | 106 | 4. **Be Smart About Backups**: 107 | - Schedule backups during quiet times 108 | - Keep backups in a separate location 109 | - Test your restore process regularly 110 | 111 | 5. **Delete with Caution**: 112 | - Deleting vectors can degrade graph quality 113 | - Consider marking vectors as inactive instead of deleting them 114 | - If you must delete, run `OptimizeStorage` afterward 115 | 116 | 6. 
**Use Facets for High-Performance Filtering**: 117 | - For categorical data that you frequently filter on, use facets instead of metadata filtering 118 | - Set up facet fields early when creating your collections 119 | - Use facet filtering for categories, tags, and other discrete attributes 120 | 121 | ## Filtering with Metadata 122 | 123 | Quiver supports powerful filtering capabilities through metadata. You can attach arbitrary JSON metadata to your vectors and then filter search results based on this metadata. 124 | 125 | ```go 126 | // Add a vector with metadata 127 | metadata := map[string]interface{}{ 128 | "category": "electronics", 129 | "price": 299.99, 130 | "in_stock": true, 131 | "tags": []string{"smartphone", "android", "5G"}, 132 | } 133 | collection.Add(id, vector, metadata) 134 | 135 | // Later, search with a metadata filter 136 | filter := []byte(`{ 137 | "category": "electronics", 138 | "price": {"$lt": 500}, 139 | "in_stock": true, 140 | "tags": {"$contains": "5G"} 141 | }`) 142 | 143 | results, err := collection.Search(types.SearchRequest{ 144 | Vector: queryVector, 145 | TopK: 10, 146 | Filters: []types.Filter{ 147 | {Field: "category", Operator: "=", Value: "electronics"}, 148 | {Field: "price", Operator: "<", Value: 500}, 149 | {Field: "in_stock", Operator: "=", Value: true}, 150 | {Field: "tags", Operator: "contains", Value: "5G"}, 151 | }, 152 | }) 153 | ``` 154 | 155 | ## Faceted Search 156 | 157 | Quiver offers an optimized filtering mechanism called faceted search. Facets enable high-performance filtering for categorical data, which is much faster than regular metadata filtering. 158 | 159 | ### What are Facets? 160 | 161 | Facets are precomputed, indexed categorical attributes that allow for rapid filtering.
They're ideal for attributes like: 162 | 163 | - Product categories 164 | - Tags 165 | - Status values (active/inactive) 166 | - Geographic regions 167 | - Price ranges 168 | - Content types 169 | 170 | ### Using Facets 171 | 172 | ```go 173 | // Specify which fields should be indexed as facets 174 | collection.SetFacetFields([]string{"category", "price_range", "tags"}) 175 | 176 | // Add vectors with metadata that includes facet fields 177 | metadata := map[string]interface{}{ 178 | "category": "electronics", 179 | "price_range": "200-500", 180 | "tags": []string{"smartphone", "android", "5G"}, 181 | "other_field": "this won't be indexed as a facet", 182 | } 183 | collection.Add(id, vector, metadata) 184 | 185 | // Search with facet filters 186 | filters := []facets.Filter{ 187 | facets.NewEqualityFilter("category", "electronics"), 188 | facets.NewEqualityFilter("tags", "5G"), 189 | } 190 | 191 | results, err := collection.SearchWithFacets(queryVector, 10, filters) 192 | ``` 193 | 194 | ### Benefits of Facets vs. Metadata Filtering 195 | 196 | 1. **Performance**: Facets are precomputed and indexed, making filtering operations much faster 197 | 2. **Memory Efficiency**: Facet values are stored in an optimized format 198 | 3. **Type Safety**: Facets provide type-aware filtering with proper comparisons 199 | 4. **Range Queries**: Built-in support for numeric range queries 200 | 5.
**Set Operations**: Easily filter based on membership in sets 201 | 202 | ### When to Use Facets 203 | 204 | Use facets when: 205 | 206 | - You frequently filter on the same fields 207 | - Your filtering needs are primarily categorical 208 | - Performance is critical for your application 209 | - You're working with large datasets 210 | 211 | Use metadata filtering when: 212 | 213 | - Your filtering needs are ad-hoc or unpredictable 214 | - You need complex query expressions 215 | - You're filtering on fields that change frequently 216 | 217 | ## Using Negative Examples 218 | 219 | Negative examples allow you to steer search results away from specific concepts or characteristics. This is useful when you want to find vectors similar to your query but dissimilar from certain examples. 220 | 221 | ### How Negative Examples Work 222 | 223 | When you provide a negative example vector, Quiver: 224 | 225 | 1. Performs the standard search to find candidates similar to your query vector 226 | 2. Calculates the similarity between each candidate and the negative example 227 | 3. Re-ranks results to prefer candidates that are less similar to the negative example 228 | 4. Returns the final, adjusted results 229 | 230 | ### Using Negative Examples with the Standard API 231 | 232 | ```go 233 | // Create a search request with a negative example 234 | request := types.SearchRequest{ 235 | Vector: queryVector, 236 | TopK: 10, 237 | NegativeExample: negativeVector, 238 | NegativeWeight: 0.5, // 0.5 gives equal importance to positive and negative examples 239 | } 240 | 241 | // Execute the search 242 | response, err := collection.Search(request) 243 | ``` 244 | 245 | ### Using Negative Examples with the Fluent API 246 | 247 | ```go 248 | // Using the fluent API for a search with a negative example 249 | results, err := collection.FluentSearch(queryVector). 250 | WithK(10). 251 | WithNegativeExample(negativeVector). 252 | WithNegativeWeight(0.5). 
253 | Execute() 254 | ``` 255 | 256 | ### Advanced Example: Finding Similar But Different Items 257 | 258 | ```go 259 | // Find products similar to smartphones but not tablets 260 | queryVector := productEmbedding["smartphone"] 261 | negativeVector := productEmbedding["tablet"] 262 | 263 | results, err := collection.FluentSearch(queryVector). 264 | WithK(20). 265 | WithNegativeExample(negativeVector). 266 | WithNegativeWeight(0.7). // Strong preference against tablet-like results 267 | Filter("category", "electronics"). 268 | FilterGreaterThan("rating", 4.0). 269 | Execute() 270 | ``` 271 | 272 | ### Combining Multiple Approaches 273 | 274 | For the most powerful queries, combine negative examples with facets and other filtering options: 275 | 276 | ```go 277 | // Find articles about AI but not focused on robotics, 278 | // published in the last year, with high engagement 279 | results, err := collection.FluentSearch(aiTopicVector). 280 | WithK(50). 281 | WithNegativeExample(roboticsVector). 282 | WithNegativeWeight(0.6). 283 | Filter("published_date", map[string]interface{}{ 284 | "$gt": time.Now().AddDate(-1, 0, 0), 285 | }). 286 | Filter("engagement_score", map[string]interface{}{ 287 | "$gt": 75, 288 | }). 289 | FilterNotEquals("is_sponsored", true). 290 | Execute() 291 | ``` 292 | 293 | ## License 294 | 295 | Quiver is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 296 | 297 | Happy vector searching. 
🏹 298 | -------------------------------------------------------------------------------- /pkg/facets/facets.go: -------------------------------------------------------------------------------- 1 | // Package facets provides functionality for efficient categorical filtering in vector search 2 | package facets 3 | 4 | import ( 5 | "encoding/json" 6 | "fmt" 7 | "math" 8 | "reflect" 9 | "strings" 10 | ) 11 | 12 | // FilterType represents the type of facet filter 13 | type FilterType string 14 | 15 | const ( 16 | // TypeEquality represents an equality filter 17 | TypeEquality FilterType = "equality" 18 | // TypeRange represents a range filter 19 | TypeRange FilterType = "range" 20 | // TypeSet represents a set membership filter 21 | TypeSet FilterType = "set" 22 | // TypeExists represents an existence filter 23 | TypeExists FilterType = "exists" 24 | ) 25 | 26 | // Filter is the interface that all facet filters must implement 27 | type Filter interface { 28 | // Type returns the type of the filter 29 | Type() FilterType 30 | // Field returns the field name the filter applies to 31 | Field() string 32 | // Match checks if the given facet value matches the filter 33 | Match(value interface{}) bool 34 | // String returns a string representation of the filter 35 | String() string 36 | } 37 | 38 | // EqualityFilter is a filter that checks if a facet equals a specific value 39 | type EqualityFilter struct { 40 | FieldName string 41 | Value interface{} 42 | } 43 | 44 | // NewEqualityFilter creates a new equality filter 45 | func NewEqualityFilter(field string, value interface{}) *EqualityFilter { 46 | return &EqualityFilter{ 47 | FieldName: field, 48 | Value: value, 49 | } 50 | } 51 | 52 | // Type implements Filter.Type 53 | func (f *EqualityFilter) Type() FilterType { 54 | return TypeEquality 55 | } 56 | 57 | // Field implements Filter.Field 58 | func (f *EqualityFilter) Field() string { 59 | return f.FieldName 60 | } 61 | 62 | // Match implements Filter.Match 63 | func (f 
*EqualityFilter) Match(value interface{}) bool { 64 | // Special case for nil values 65 | if f.Value == nil && value == nil { 66 | return true 67 | } 68 | if f.Value == nil || value == nil { 69 | return false 70 | } 71 | 72 | // Handle strings case-insensitively 73 | if strValue, ok := value.(string); ok { 74 | if strFilter, ok := f.Value.(string); ok { 75 | return strings.EqualFold(strValue, strFilter) 76 | } 77 | } 78 | 79 | // Use reflect for basic equality comparison 80 | return reflect.DeepEqual(f.Value, value) 81 | } 82 | 83 | // String implements Filter.String 84 | func (f *EqualityFilter) String() string { 85 | return fmt.Sprintf("%s = %v", f.FieldName, f.Value) 86 | } 87 | 88 | // RangeFilter checks if a facet value is within a range 89 | type RangeFilter struct { 90 | FieldName string 91 | Min interface{} 92 | Max interface{} 93 | // IncludeMin specifies whether to include the minimum value in the range 94 | IncludeMin bool 95 | // IncludeMax specifies whether to include the maximum value in the range 96 | IncludeMax bool 97 | } 98 | 99 | // NewRangeFilter creates a new range filter 100 | func NewRangeFilter(field string, min, max interface{}, includeMin, includeMax bool) *RangeFilter { 101 | return &RangeFilter{ 102 | FieldName: field, 103 | Min: min, 104 | Max: max, 105 | IncludeMin: includeMin, 106 | IncludeMax: includeMax, 107 | } 108 | } 109 | 110 | // Type implements Filter.Type 111 | func (f *RangeFilter) Type() FilterType { 112 | return TypeRange 113 | } 114 | 115 | // Field implements Filter.Field 116 | func (f *RangeFilter) Field() string { 117 | return f.FieldName 118 | } 119 | 120 | // Match implements Filter.Match 121 | func (f *RangeFilter) Match(value interface{}) bool { 122 | if value == nil { 123 | return false 124 | } 125 | 126 | // Only numeric values can be compared with ranges 127 | switch v := value.(type) { 128 | case int: 129 | return f.compareInt(v) 130 | case int32: 131 | return f.compareInt(int(v)) 132 | case int64: 133 | if v > 
int64(math.MaxInt) || v < int64(math.MinInt) { 134 | return f.compareFloat(float64(v)) 135 | } 136 | return f.compareInt(int(v)) 137 | case float32: 138 | return f.compareFloat(float64(v)) 139 | case float64: 140 | return f.compareFloat(v) 141 | default: 142 | return false 143 | } 144 | } 145 | 146 | // compareInt checks if an int value is within the range 147 | func (f *RangeFilter) compareInt(value int) bool { 148 | var minOK, maxOK bool 149 | 150 | if f.Min == nil { 151 | minOK = true 152 | } else { 153 | switch min := f.Min.(type) { 154 | case int: 155 | minOK = (f.IncludeMin && value >= min) || (!f.IncludeMin && value > min) 156 | case int32: 157 | minOK = (f.IncludeMin && value >= int(min)) || (!f.IncludeMin && value > int(min)) 158 | case int64: 159 | minOK = (f.IncludeMin && value >= int(min)) || (!f.IncludeMin && value > int(min)) 160 | case float32: 161 | minOK = (f.IncludeMin && float64(value) >= float64(min)) || (!f.IncludeMin && float64(value) > float64(min)) 162 | case float64: 163 | minOK = (f.IncludeMin && float64(value) >= min) || (!f.IncludeMin && float64(value) > min) 164 | default: 165 | minOK = false 166 | } 167 | } 168 | 169 | if f.Max == nil { 170 | maxOK = true 171 | } else { 172 | switch max := f.Max.(type) { 173 | case int: 174 | maxOK = (f.IncludeMax && value <= max) || (!f.IncludeMax && value < max) 175 | case int32: 176 | maxOK = (f.IncludeMax && value <= int(max)) || (!f.IncludeMax && value < int(max)) 177 | case int64: 178 | maxOK = (f.IncludeMax && value <= int(max)) || (!f.IncludeMax && value < int(max)) 179 | case float32: 180 | maxOK = (f.IncludeMax && float64(value) <= float64(max)) || (!f.IncludeMax && float64(value) < float64(max)) 181 | case float64: 182 | maxOK = (f.IncludeMax && float64(value) <= max) || (!f.IncludeMax && float64(value) < max) 183 | default: 184 | maxOK = false 185 | } 186 | } 187 | 188 | return minOK && maxOK 189 | } 190 | 191 | // compareFloat checks if a float value is within the range 192 | func (f 
*RangeFilter) compareFloat(value float64) bool { 193 | var minOK, maxOK bool 194 | 195 | if f.Min == nil { 196 | minOK = true 197 | } else { 198 | switch min := f.Min.(type) { 199 | case int: 200 | minOK = (f.IncludeMin && value >= float64(min)) || (!f.IncludeMin && value > float64(min)) 201 | case int32: 202 | minOK = (f.IncludeMin && value >= float64(min)) || (!f.IncludeMin && value > float64(min)) 203 | case int64: 204 | minOK = (f.IncludeMin && value >= float64(min)) || (!f.IncludeMin && value > float64(min)) 205 | case float32: 206 | minOK = (f.IncludeMin && value >= float64(min)) || (!f.IncludeMin && value > float64(min)) 207 | case float64: 208 | minOK = (f.IncludeMin && value >= min) || (!f.IncludeMin && value > min) 209 | default: 210 | minOK = false 211 | } 212 | } 213 | 214 | if f.Max == nil { 215 | maxOK = true 216 | } else { 217 | switch max := f.Max.(type) { 218 | case int: 219 | maxOK = (f.IncludeMax && value <= float64(max)) || (!f.IncludeMax && value < float64(max)) 220 | case int32: 221 | maxOK = (f.IncludeMax && value <= float64(max)) || (!f.IncludeMax && value < float64(max)) 222 | case int64: 223 | maxOK = (f.IncludeMax && value <= float64(max)) || (!f.IncludeMax && value < float64(max)) 224 | case float32: 225 | maxOK = (f.IncludeMax && value <= float64(max)) || (!f.IncludeMax && value < float64(max)) 226 | case float64: 227 | maxOK = (f.IncludeMax && value <= max) || (!f.IncludeMax && value < max) 228 | default: 229 | maxOK = false 230 | } 231 | } 232 | 233 | return minOK && maxOK 234 | } 235 | 236 | // String implements Filter.String 237 | func (f *RangeFilter) String() string { 238 | minStr := "∞" 239 | if f.Min != nil { 240 | minStr = fmt.Sprintf("%v", f.Min) 241 | } 242 | maxStr := "∞" 243 | if f.Max != nil { 244 | maxStr = fmt.Sprintf("%v", f.Max) 245 | } 246 | 247 | leftBracket := "(" 248 | if f.IncludeMin { 249 | leftBracket = "[" 250 | } 251 | rightBracket := ")" 252 | if f.IncludeMax { 253 | rightBracket = "]" 254 | } 255 | 256 | 
return fmt.Sprintf("%s %s%s, %s%s", f.FieldName, leftBracket, minStr, maxStr, rightBracket) 257 | } 258 | 259 | // SetFilter checks if a facet value is in a set of allowed values 260 | type SetFilter struct { 261 | FieldName string 262 | Values []interface{} 263 | } 264 | 265 | // NewSetFilter creates a new set filter 266 | func NewSetFilter(field string, values []interface{}) *SetFilter { 267 | return &SetFilter{ 268 | FieldName: field, 269 | Values: values, 270 | } 271 | } 272 | 273 | // Type implements Filter.Type 274 | func (f *SetFilter) Type() FilterType { 275 | return TypeSet 276 | } 277 | 278 | // Field implements Filter.Field 279 | func (f *SetFilter) Field() string { 280 | return f.FieldName 281 | } 282 | 283 | // Match implements Filter.Match 284 | func (f *SetFilter) Match(value interface{}) bool { 285 | if value == nil { 286 | return false 287 | } 288 | 289 | // Handle string case-insensitive comparison 290 | if strValue, ok := value.(string); ok { 291 | for _, v := range f.Values { 292 | if strFilter, ok := v.(string); ok { 293 | if strings.EqualFold(strValue, strFilter) { 294 | return true 295 | } 296 | } else if reflect.DeepEqual(value, v) { 297 | return true 298 | } 299 | } 300 | return false 301 | } 302 | 303 | // Handle array/slice case - check if any element matches 304 | if reflect.TypeOf(value).Kind() == reflect.Slice || reflect.TypeOf(value).Kind() == reflect.Array { 305 | valueSlice := reflect.ValueOf(value) 306 | for i := 0; i < valueSlice.Len(); i++ { 307 | item := valueSlice.Index(i).Interface() 308 | for _, v := range f.Values { 309 | if reflect.DeepEqual(item, v) { 310 | return true 311 | } 312 | } 313 | } 314 | return false 315 | } 316 | 317 | // Regular equality check 318 | for _, v := range f.Values { 319 | if reflect.DeepEqual(value, v) { 320 | return true 321 | } 322 | } 323 | return false 324 | } 325 | 326 | // String implements Filter.String 327 | func (f *SetFilter) String() string { 328 | values := make([]string, len(f.Values)) 
329 | for i, v := range f.Values { 330 | values[i] = fmt.Sprintf("%v", v) 331 | } 332 | return fmt.Sprintf("%s IN [%s]", f.FieldName, strings.Join(values, ", ")) 333 | } 334 | 335 | // ExistsFilter checks if a facet field exists and is not null/empty 336 | type ExistsFilter struct { 337 | FieldName string 338 | ShouldExist bool 339 | } 340 | 341 | // NewExistsFilter creates a new exists filter 342 | func NewExistsFilter(field string, shouldExist bool) *ExistsFilter { 343 | return &ExistsFilter{ 344 | FieldName: field, 345 | ShouldExist: shouldExist, 346 | } 347 | } 348 | 349 | // Type implements Filter.Type 350 | func (f *ExistsFilter) Type() FilterType { 351 | return TypeExists 352 | } 353 | 354 | // Field implements Filter.Field 355 | func (f *ExistsFilter) Field() string { 356 | return f.FieldName 357 | } 358 | 359 | // Match implements Filter.Match 360 | func (f *ExistsFilter) Match(value interface{}) bool { 361 | exists := value != nil 362 | 363 | // For empty strings, arrays, etc. 
364 | if exists { 365 | v := reflect.ValueOf(value) 366 | switch v.Kind() { 367 | case reflect.String: 368 | exists = v.Len() > 0 369 | case reflect.Slice, reflect.Array, reflect.Map: 370 | exists = v.Len() > 0 371 | } 372 | } 373 | 374 | return exists == f.ShouldExist 375 | } 376 | 377 | // String implements Filter.String 378 | func (f *ExistsFilter) String() string { 379 | if f.ShouldExist { 380 | return fmt.Sprintf("%s EXISTS", f.FieldName) 381 | } 382 | return fmt.Sprintf("%s NOT EXISTS", f.FieldName) 383 | } 384 | 385 | // FacetValue represents a single facet value stored with a vector 386 | type FacetValue struct { 387 | Field string `json:"field"` 388 | Value interface{} `json:"value"` 389 | } 390 | 391 | // ExtractFacets extracts facet values from metadata 392 | func ExtractFacets(metadata map[string]interface{}, facetFields []string) []FacetValue { 393 | if len(facetFields) == 0 || metadata == nil { 394 | return nil 395 | } 396 | 397 | result := make([]FacetValue, 0, len(facetFields)) 398 | 399 | for _, field := range facetFields { 400 | // Handle nested fields with dot notation 401 | parts := strings.Split(field, ".") 402 | var value interface{} = metadata 403 | 404 | for _, part := range parts { 405 | if m, ok := value.(map[string]interface{}); ok { 406 | if v, exists := m[part]; exists { 407 | value = v 408 | } else { 409 | value = nil 410 | break 411 | } 412 | } else { 413 | value = nil 414 | break 415 | } 416 | } 417 | 418 | if value != nil { 419 | result = append(result, FacetValue{Field: field, Value: value}) 420 | } 421 | } 422 | 423 | return result 424 | } 425 | 426 | // MatchesAllFilters checks if a set of facets matches all provided filters 427 | func MatchesAllFilters(facets []FacetValue, filters []Filter) bool { 428 | if len(filters) == 0 { 429 | return true 430 | } 431 | 432 | if len(facets) == 0 { 433 | return false 434 | } 435 | 436 | // Create a map for O(1) field lookups 437 | facetMap := make(map[string]interface{}, len(facets)) 438 | 
for _, facet := range facets { 439 | facetMap[facet.Field] = facet.Value 440 | } 441 | 442 | // Check each filter 443 | for _, filter := range filters { 444 | value, exists := facetMap[filter.Field()] 445 | if !exists && filter.Type() != TypeExists { 446 | return false 447 | } 448 | if !filter.Match(value) { 449 | return false 450 | } 451 | } 452 | 453 | return true 454 | } 455 | 456 | // FacetsFromJSON extracts facet values from JSON metadata 457 | func FacetsFromJSON(metadataJSON json.RawMessage, facetFields []string) ([]FacetValue, error) { 458 | if len(metadataJSON) == 0 || len(facetFields) == 0 { 459 | return nil, nil 460 | } 461 | 462 | var metadata map[string]interface{} 463 | if err := json.Unmarshal(metadataJSON, &metadata); err != nil { 464 | return nil, fmt.Errorf("failed to parse metadata JSON: %w", err) 465 | } 466 | 467 | return ExtractFacets(metadata, facetFields), nil 468 | } 469 | -------------------------------------------------------------------------------- /pkg/vectortypes/distances_test.go: -------------------------------------------------------------------------------- 1 | package vectortypes 2 | 3 | import ( 4 | "math" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestCosineDistance(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | vecA F32 13 | vecB F32 14 | expected float32 15 | }{ 16 | { 17 | name: "Identical Vectors", 18 | vecA: F32{1, 0, 0}, 19 | vecB: F32{1, 0, 0}, 20 | expected: 0, // Identical vectors have cosine distance of 0 21 | }, 22 | { 23 | name: "Perpendicular Vectors", 24 | vecA: F32{1, 0, 0}, 25 | vecB: F32{0, 1, 0}, 26 | expected: 1, // Perpendicular vectors have cosine distance of 1 27 | }, 28 | { 29 | name: "Opposite Vectors", 30 | vecA: F32{1, 0, 0}, 31 | vecB: F32{-1, 0, 0}, 32 | expected: 2, // Opposite vectors have cosine distance of 2 33 | }, 34 | { 35 | name: "Zero Vector", 36 | vecA: F32{0, 0, 0}, 37 | vecB: F32{1, 0, 0}, 38 | expected: 0, // Zero vector case is handled specially 39 | }, 40 | } 41 | 
42 | for _, tt := range tests { 43 | t.Run(tt.name, func(t *testing.T) { 44 | result := CosineDistance(tt.vecA, tt.vecB) 45 | if !floatEquals(result, tt.expected, 1e-6) { 46 | t.Errorf("CosineDistance(%v, %v) = %v, want %v", tt.vecA, tt.vecB, result, tt.expected) 47 | } 48 | }) 49 | } 50 | } 51 | 52 | func TestEuclideanDistance(t *testing.T) { 53 | tests := []struct { 54 | name string 55 | vecA F32 56 | vecB F32 57 | expected float32 58 | }{ 59 | { 60 | name: "Identical Vectors", 61 | vecA: F32{1, 0, 0}, 62 | vecB: F32{1, 0, 0}, 63 | expected: 0, // Identical vectors have Euclidean distance of 0 64 | }, 65 | { 66 | name: "Unit Vectors", 67 | vecA: F32{1, 0, 0}, 68 | vecB: F32{0, 1, 0}, 69 | expected: float32(math.Sqrt(2)), // sqrt(1² + 1²) 70 | }, 71 | { 72 | name: "3D Vectors", 73 | vecA: F32{1, 2, 3}, 74 | vecB: F32{4, 5, 6}, 75 | expected: float32(math.Sqrt(27)), // sqrt(3² + 3² + 3²) 76 | }, 77 | } 78 | 79 | for _, tt := range tests { 80 | t.Run(tt.name, func(t *testing.T) { 81 | result := EuclideanDistance(tt.vecA, tt.vecB) 82 | if !floatEquals(result, tt.expected, 1e-6) { 83 | t.Errorf("EuclideanDistance(%v, %v) = %v, want %v", tt.vecA, tt.vecB, result, tt.expected) 84 | } 85 | }) 86 | } 87 | } 88 | 89 | func TestSquaredEuclideanDistance(t *testing.T) { 90 | tests := []struct { 91 | name string 92 | vecA F32 93 | vecB F32 94 | expected float32 95 | }{ 96 | { 97 | name: "Identical Vectors", 98 | vecA: F32{1, 0, 0}, 99 | vecB: F32{1, 0, 0}, 100 | expected: 0, 101 | }, 102 | { 103 | name: "Unit Vectors", 104 | vecA: F32{1, 0, 0}, 105 | vecB: F32{0, 1, 0}, 106 | expected: 2, 107 | }, 108 | { 109 | name: "3D Vectors", 110 | vecA: F32{1, 2, 3}, 111 | vecB: F32{4, 5, 6}, 112 | expected: 27, 113 | }, 114 | } 115 | 116 | for _, tt := range tests { 117 | t.Run(tt.name, func(t *testing.T) { 118 | result := SquaredEuclideanDistance(tt.vecA, tt.vecB) 119 | if !floatEquals(result, tt.expected, 1e-6) { 120 | t.Errorf("SquaredEuclideanDistance(%v, %v) = %v, want %v", 
tt.vecA, tt.vecB, result, tt.expected) 121 | } 122 | }) 123 | } 124 | } 125 | 126 | func TestDotProductDistance(t *testing.T) { 127 | tests := []struct { 128 | name string 129 | vecA F32 130 | vecB F32 131 | expected float32 132 | }{ 133 | { 134 | name: "Identical Unit Vectors", 135 | vecA: F32{1, 0, 0}, 136 | vecB: F32{1, 0, 0}, 137 | expected: 0, // 1 - dot product (1) = 0 138 | }, 139 | { 140 | name: "Perpendicular Vectors", 141 | vecA: F32{1, 0, 0}, 142 | vecB: F32{0, 1, 0}, 143 | expected: 1, // 1 - dot product (0) = 1 144 | }, 145 | { 146 | name: "Opposite Vectors", 147 | vecA: F32{1, 0, 0}, 148 | vecB: F32{-1, 0, 0}, 149 | expected: 2, // 1 - dot product (-1) = 2 150 | }, 151 | { 152 | name: "Scaled Vectors", 153 | vecA: F32{2, 0, 0}, 154 | vecB: F32{3, 0, 0}, 155 | expected: -5, // 1 - dot product (6) = -5 156 | }, 157 | } 158 | 159 | for _, tt := range tests { 160 | t.Run(tt.name, func(t *testing.T) { 161 | result := DotProductDistance(tt.vecA, tt.vecB) 162 | if !floatEquals(result, tt.expected, 1e-6) { 163 | t.Errorf("DotProductDistance(%v, %v) = %v, want %v", tt.vecA, tt.vecB, result, tt.expected) 164 | } 165 | }) 166 | } 167 | } 168 | 169 | func TestManhattanDistance(t *testing.T) { 170 | tests := []struct { 171 | name string 172 | vecA F32 173 | vecB F32 174 | expected float32 175 | }{ 176 | { 177 | name: "Identical Vectors", 178 | vecA: F32{1, 0, 0}, 179 | vecB: F32{1, 0, 0}, 180 | expected: 0, // Identical vectors have Manhattan distance of 0 181 | }, 182 | { 183 | name: "Unit Vectors", 184 | vecA: F32{1, 0, 0}, 185 | vecB: F32{0, 1, 0}, 186 | expected: 2, // |1-0| + |0-1| + |0-0| = 2 187 | }, 188 | { 189 | name: "3D Vectors", 190 | vecA: F32{1, 2, 3}, 191 | vecB: F32{4, 5, 6}, 192 | expected: 9, // |1-4| + |2-5| + |3-6| = 3 + 3 + 3 = 9 193 | }, 194 | } 195 | 196 | for _, tt := range tests { 197 | t.Run(tt.name, func(t *testing.T) { 198 | result := ManhattanDistance(tt.vecA, tt.vecB) 199 | if !floatEquals(result, tt.expected, 1e-6) { 200 | 
t.Errorf("ManhattanDistance(%v, %v) = %v, want %v", tt.vecA, tt.vecB, result, tt.expected) 201 | } 202 | }) 203 | } 204 | } 205 | 206 | func TestDistanceFunctionPanics(t *testing.T) { 207 | // Test that distance functions panic with different length vectors 208 | tests := []struct { 209 | name string 210 | distFunc DistanceFunc 211 | }{ 212 | {"CosineDistance", CosineDistance}, 213 | {"EuclideanDistance", EuclideanDistance}, 214 | {"SquaredEuclideanDistance", SquaredEuclideanDistance}, 215 | {"DotProductDistance", DotProductDistance}, 216 | {"ManhattanDistance", ManhattanDistance}, 217 | } 218 | 219 | for _, tt := range tests { 220 | t.Run(tt.name, func(t *testing.T) { 221 | defer func() { 222 | if r := recover(); r == nil { 223 | t.Errorf("%s did not panic with different length vectors", tt.name) 224 | } 225 | }() 226 | // Should panic 227 | tt.distFunc(F32{1, 2, 3}, F32{1, 2}) 228 | }) 229 | } 230 | } 231 | 232 | func TestNormalizeVector(t *testing.T) { 233 | tests := []struct { 234 | name string 235 | vec F32 236 | expected F32 237 | }{ 238 | { 239 | name: "Unit Vector", 240 | vec: F32{1, 0, 0}, 241 | expected: F32{1, 0, 0}, // Already normalized 242 | }, 243 | { 244 | name: "Non-unit Vector", 245 | vec: F32{3, 0, 0}, 246 | expected: F32{1, 0, 0}, // Normalized to unit vector 247 | }, 248 | { 249 | name: "3D Vector", 250 | vec: F32{1, 1, 1}, 251 | expected: F32{1 / float32(math.Sqrt(3)), 1 / float32(math.Sqrt(3)), 1 / float32(math.Sqrt(3))}, 252 | }, 253 | { 254 | name: "Zero Vector", 255 | vec: F32{0, 0, 0}, 256 | expected: F32{0, 0, 0}, // Zero vector can't be normalized 257 | }, 258 | } 259 | 260 | for _, tt := range tests { 261 | t.Run(tt.name, func(t *testing.T) { 262 | result := NormalizeVector(tt.vec) 263 | 264 | // For zero vector case, just check it's unchanged 265 | if tt.name == "Zero Vector" { 266 | if !reflect.DeepEqual(result, tt.expected) { 267 | t.Errorf("NormalizeVector(%v) = %v, want %v", tt.vec, result, tt.expected) 268 | } 269 | return 270 | 
} 271 | 272 | // For other cases, check magnitude is 1 (or very close) 273 | magnitude := 0.0 274 | for _, v := range result { 275 | magnitude += float64(v * v) 276 | } 277 | magnitude = math.Sqrt(magnitude) 278 | 279 | if !float64Equals(magnitude, 1.0, 1e-6) { 280 | t.Errorf("NormalizeVector(%v) produced vector with magnitude %v, want 1.0", 281 | tt.vec, magnitude) 282 | } 283 | 284 | // Also check if result matches expected 285 | for i := range result { 286 | if !floatEquals(result[i], tt.expected[i], 1e-6) { 287 | t.Errorf("NormalizeVector(%v) = %v, want %v", tt.vec, result, tt.expected) 288 | break 289 | } 290 | } 291 | }) 292 | } 293 | } 294 | 295 | func TestVectorAdd(t *testing.T) { 296 | tests := []struct { 297 | name string 298 | vecA F32 299 | vecB F32 300 | expected F32 301 | wantErr bool 302 | }{ 303 | { 304 | name: "Same Dimension", 305 | vecA: F32{1, 2, 3}, 306 | vecB: F32{4, 5, 6}, 307 | expected: F32{5, 7, 9}, 308 | wantErr: false, 309 | }, 310 | { 311 | name: "Different Dimensions", 312 | vecA: F32{1, 2, 3}, 313 | vecB: F32{4, 5}, 314 | expected: nil, 315 | wantErr: true, 316 | }, 317 | { 318 | name: "Zero Vectors", 319 | vecA: F32{0, 0, 0}, 320 | vecB: F32{0, 0, 0}, 321 | expected: F32{0, 0, 0}, 322 | wantErr: false, 323 | }, 324 | } 325 | 326 | for _, tt := range tests { 327 | t.Run(tt.name, func(t *testing.T) { 328 | result, err := VectorAdd(tt.vecA, tt.vecB) 329 | 330 | // Check error 331 | if (err != nil) != tt.wantErr { 332 | t.Errorf("VectorAdd() error = %v, wantErr %v", err, tt.wantErr) 333 | return 334 | } 335 | 336 | // Check result if no error 337 | if !tt.wantErr { 338 | if !reflect.DeepEqual(result, tt.expected) { 339 | t.Errorf("VectorAdd(%v, %v) = %v, want %v", tt.vecA, tt.vecB, result, tt.expected) 340 | } 341 | } 342 | }) 343 | } 344 | } 345 | 346 | func TestVectorSubtract(t *testing.T) { 347 | tests := []struct { 348 | name string 349 | vecA F32 350 | vecB F32 351 | expected F32 352 | wantErr bool 353 | }{ 354 | { 355 | name: "Same 
Dimension", 356 | vecA: F32{5, 7, 9}, 357 | vecB: F32{4, 5, 6}, 358 | expected: F32{1, 2, 3}, 359 | wantErr: false, 360 | }, 361 | { 362 | name: "Different Dimensions", 363 | vecA: F32{1, 2, 3}, 364 | vecB: F32{4, 5}, 365 | expected: nil, 366 | wantErr: true, 367 | }, 368 | { 369 | name: "Zero Vectors", 370 | vecA: F32{0, 0, 0}, 371 | vecB: F32{0, 0, 0}, 372 | expected: F32{0, 0, 0}, 373 | wantErr: false, 374 | }, 375 | { 376 | name: "Same Vectors", 377 | vecA: F32{1, 2, 3}, 378 | vecB: F32{1, 2, 3}, 379 | expected: F32{0, 0, 0}, 380 | wantErr: false, 381 | }, 382 | } 383 | 384 | for _, tt := range tests { 385 | t.Run(tt.name, func(t *testing.T) { 386 | result, err := VectorSubtract(tt.vecA, tt.vecB) 387 | 388 | // Check error 389 | if (err != nil) != tt.wantErr { 390 | t.Errorf("VectorSubtract() error = %v, wantErr %v", err, tt.wantErr) 391 | return 392 | } 393 | 394 | // Check result if no error 395 | if !tt.wantErr { 396 | if !reflect.DeepEqual(result, tt.expected) { 397 | t.Errorf("VectorSubtract(%v, %v) = %v, want %v", tt.vecA, tt.vecB, result, tt.expected) 398 | } 399 | } 400 | }) 401 | } 402 | } 403 | 404 | func TestVectorMultiplyScalar(t *testing.T) { 405 | tests := []struct { 406 | name string 407 | vec F32 408 | scalar float32 409 | expected F32 410 | }{ 411 | { 412 | name: "Multiply by 2", 413 | vec: F32{1, 2, 3}, 414 | scalar: 2, 415 | expected: F32{2, 4, 6}, 416 | }, 417 | { 418 | name: "Multiply by 0", 419 | vec: F32{1, 2, 3}, 420 | scalar: 0, 421 | expected: F32{0, 0, 0}, 422 | }, 423 | { 424 | name: "Multiply by -1", 425 | vec: F32{1, 2, 3}, 426 | scalar: -1, 427 | expected: F32{-1, -2, -3}, 428 | }, 429 | } 430 | 431 | for _, tt := range tests { 432 | t.Run(tt.name, func(t *testing.T) { 433 | result := VectorMultiplyScalar(tt.vec, tt.scalar) 434 | 435 | if !reflect.DeepEqual(result, tt.expected) { 436 | t.Errorf("VectorMultiplyScalar(%v, %v) = %v, want %v", 437 | tt.vec, tt.scalar, result, tt.expected) 438 | } 439 | }) 440 | } 441 | } 442 | 443 | 
func TestVectorMagnitude(t *testing.T) { 444 | tests := []struct { 445 | name string 446 | vec F32 447 | expected float32 448 | }{ 449 | { 450 | name: "Unit Vector", 451 | vec: F32{1, 0, 0}, 452 | expected: 1, 453 | }, 454 | { 455 | name: "2D Vector", 456 | vec: F32{3, 4}, 457 | expected: 5, // 3-4-5 triangle 458 | }, 459 | { 460 | name: "Zero Vector", 461 | vec: F32{0, 0, 0}, 462 | expected: 0, 463 | }, 464 | } 465 | 466 | for _, tt := range tests { 467 | t.Run(tt.name, func(t *testing.T) { 468 | result := VectorMagnitude(tt.vec) 469 | 470 | if !floatEquals(result, tt.expected, 1e-6) { 471 | t.Errorf("VectorMagnitude(%v) = %v, want %v", tt.vec, result, tt.expected) 472 | } 473 | }) 474 | } 475 | } 476 | 477 | func TestCreateZeroVector(t *testing.T) { 478 | tests := []struct { 479 | name string 480 | dimension int 481 | expected F32 482 | }{ 483 | { 484 | name: "3D Vector", 485 | dimension: 3, 486 | expected: F32{0, 0, 0}, 487 | }, 488 | { 489 | name: "Empty Vector", 490 | dimension: 0, 491 | expected: F32{}, 492 | }, 493 | } 494 | 495 | for _, tt := range tests { 496 | t.Run(tt.name, func(t *testing.T) { 497 | result := CreateZeroVector(tt.dimension) 498 | 499 | if !reflect.DeepEqual(result, tt.expected) { 500 | t.Errorf("CreateZeroVector(%v) = %v, want %v", tt.dimension, result, tt.expected) 501 | } 502 | 503 | if len(result) != tt.dimension { 504 | t.Errorf("CreateZeroVector(%v) created vector of length %v, want %v", 505 | tt.dimension, len(result), tt.dimension) 506 | } 507 | }) 508 | } 509 | } 510 | 511 | func TestCreateRandomVector(t *testing.T) { 512 | tests := []struct { 513 | name string 514 | dimension int 515 | }{ 516 | { 517 | name: "3D Vector", 518 | dimension: 3, 519 | }, 520 | { 521 | name: "Empty Vector", 522 | dimension: 0, 523 | }, 524 | } 525 | 526 | for _, tt := range tests { 527 | t.Run(tt.name, func(t *testing.T) { 528 | result := CreateRandomVector(tt.dimension) 529 | 530 | if len(result) != tt.dimension { 531 | 
t.Errorf("CreateRandomVector(%v) created vector of length %v, want %v", 532 | tt.dimension, len(result), tt.dimension) 533 | } 534 | }) 535 | } 536 | } 537 | 538 | func TestCloneVector(t *testing.T) { 539 | tests := []struct { 540 | name string 541 | vec F32 542 | expected F32 543 | }{ 544 | { 545 | name: "3D Vector", 546 | vec: F32{1, 2, 3}, 547 | expected: F32{1, 2, 3}, 548 | }, 549 | { 550 | name: "Empty Vector", 551 | vec: F32{}, 552 | expected: F32{}, 553 | }, 554 | } 555 | 556 | for _, tt := range tests { 557 | t.Run(tt.name, func(t *testing.T) { 558 | result := CloneVector(tt.vec) 559 | 560 | // Check that the result is equal to the expected 561 | if !reflect.DeepEqual(result, tt.expected) { 562 | t.Errorf("CloneVector(%v) = %v, want %v", tt.vec, result, tt.expected) 563 | } 564 | 565 | // Verify it's a deep copy by modifying the original 566 | if len(tt.vec) > 0 { 567 | original := tt.vec[0] 568 | tt.vec[0] = original + 1 569 | 570 | if result[0] != original { 571 | t.Errorf("CloneVector did not make a deep copy") 572 | } 573 | } 574 | }) 575 | } 576 | } 577 | 578 | // Helper function to compare floating point values with tolerance 579 | func floatEquals(a, b float32, epsilon float32) bool { 580 | return (a-b) < epsilon && (b-a) < epsilon 581 | } 582 | 583 | // Helper function to compare double precision floating point values with tolerance 584 | func float64Equals(a, b float64, epsilon float64) bool { 585 | return (a-b) < epsilon && (b-a) < epsilon 586 | } 587 | --------------------------------------------------------------------------------