├── lib ├── neighbor-s3.rb └── neighbor │ ├── s3 │ ├── version.rb │ └── index.rb │ └── s3.rb ├── CHANGELOG.md ├── Gemfile ├── examples ├── Gemfile └── disco_item_recs.rb ├── .gitignore ├── Rakefile ├── test ├── test_helper.rb └── index_test.rb ├── neighbor-s3.gemspec ├── LICENSE.txt └── README.md /lib/neighbor-s3.rb: -------------------------------------------------------------------------------- 1 | require_relative "neighbor/s3" 2 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## 0.1.0 (2025-10-02) 2 | 3 | - First release 4 | -------------------------------------------------------------------------------- /Gemfile: -------------------------------------------------------------------------------- 1 | source "https://rubygems.org" 2 | 3 | gemspec 4 | 5 | gem "rake" 6 | gem "minitest" 7 | -------------------------------------------------------------------------------- /examples/Gemfile: -------------------------------------------------------------------------------- 1 | source "https://rubygems.org" 2 | 3 | gemspec path: ".." 
4 | 5 | gem "disco" 6 | -------------------------------------------------------------------------------- /lib/neighbor/s3/version.rb: -------------------------------------------------------------------------------- 1 | module Neighbor 2 | module S3 3 | VERSION = "0.1.0" 4 | end 5 | end 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.bundle/ 2 | /.yardoc 3 | /_yardoc/ 4 | /coverage/ 5 | /doc/ 6 | /pkg/ 7 | /spec/reports/ 8 | /tmp/ 9 | *.lock 10 | -------------------------------------------------------------------------------- /Rakefile: -------------------------------------------------------------------------------- 1 | require "bundler/gem_tasks" 2 | require "rake/testtask" 3 | 4 | Rake::TestTask.new(:test) do |t| 5 | t.libs << "test" 6 | t.test_files = FileList["test/**/*_test.rb"] 7 | end 8 | 9 | task default: :test 10 | -------------------------------------------------------------------------------- /test/test_helper.rb: -------------------------------------------------------------------------------- 1 | require "bundler/setup" 2 | Bundler.require(:default) 3 | require "minitest/autorun" 4 | 5 | class Minitest::Test 6 | def assert_elements_in_delta(expected, actual) 7 | assert_equal expected.size, actual.size 8 | expected.zip(actual) do |exp, act| 9 | assert_in_delta exp, act 10 | end 11 | end 12 | 13 | def bucket 14 | ENV.fetch("S3_BUCKET") 15 | end 16 | end 17 | -------------------------------------------------------------------------------- /lib/neighbor/s3.rb: -------------------------------------------------------------------------------- 1 | # dependencies 2 | require "aws-sdk-s3vectors" 3 | 4 | # modules 5 | require_relative "s3/index" 6 | require_relative "s3/version" 7 | 8 | module Neighbor 9 | module S3 10 | class Error < StandardError; end 11 | 12 | class << self 13 | attr_writer :client 14 | 15 | def client 16 | 
@client ||= Aws::S3Vectors::Client.new 17 | end 18 | end 19 | end 20 | end 21 | -------------------------------------------------------------------------------- /examples/disco_item_recs.rb: -------------------------------------------------------------------------------- 1 | require "disco" 2 | require "neighbor-s3" 3 | 4 | index = Neighbor::S3::Index.new("movies", bucket: "my-bucket", dimensions: 20, distance: "cosine") 5 | index.drop if index.exists? 6 | index.create 7 | 8 | data = Disco.load_movielens 9 | recommender = Disco::Recommender.new(factors: 20) 10 | recommender.fit(data) 11 | 12 | index.add_all(recommender.item_ids.map { |v| {id: v, vector: recommender.item_factors(v)} }) 13 | 14 | pp index.search_id("Star Wars (1977)").map { |v| v[:id] } 15 | -------------------------------------------------------------------------------- /neighbor-s3.gemspec: -------------------------------------------------------------------------------- 1 | require_relative "lib/neighbor/s3/version" 2 | 3 | Gem::Specification.new do |spec| 4 | spec.name = "neighbor-s3" 5 | spec.version = Neighbor::S3::VERSION 6 | spec.summary = "Nearest neighbor search for Ruby and S3 Vectors" 7 | spec.homepage = "https://github.com/ankane/neighbor-s3" 8 | spec.license = "MIT" 9 | 10 | spec.author = "Andrew Kane" 11 | spec.email = "andrew@ankane.org" 12 | 13 | spec.files = Dir["*.{md,txt}", "{lib}/**/*"] 14 | spec.require_path = "lib" 15 | 16 | spec.required_ruby_version = ">= 3.2" 17 | 18 | spec.add_dependency "aws-sdk-s3vectors" 19 | end 20 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2025 Andrew Kane 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including 
without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neighbor S3 2 | 3 | Nearest neighbor search for Ruby and S3 Vectors 4 | 5 | ## Installation 6 | 7 | Add this line to your application’s Gemfile: 8 | 9 | ```ruby 10 | gem "neighbor-s3" 11 | ``` 12 | 13 | Create a [vector bucket](https://console.aws.amazon.com/s3/vector-buckets) and set your AWS credentials in your environment: 14 | 15 | ```sh 16 | AWS_ACCESS_KEY_ID=... 17 | AWS_SECRET_ACCESS_KEY=... 
18 | ``` 19 | 20 | ## Getting Started 21 | 22 | Create an index 23 | 24 | ```ruby 25 | index = Neighbor::S3::Index.new("items", bucket: "my-bucket", dimensions: 3, distance: "cosine") 26 | index.create 27 | ``` 28 | 29 | Add vectors 30 | 31 | ```ruby 32 | index.add(1, [1, 1, 1]) 33 | index.add(2, [2, 2, 2]) 34 | index.add(3, [1, 1, 2]) 35 | ``` 36 | 37 | Search for nearest neighbors to a vector 38 | 39 | ```ruby 40 | index.search([1, 1, 1], count: 5) 41 | ``` 42 | 43 | Search for nearest neighbors to a vector in the index 44 | 45 | ```ruby 46 | index.search_id(1, count: 5) 47 | ``` 48 | 49 | IDs are treated as strings by default, but can also be treated as integers 50 | 51 | ```ruby 52 | Neighbor::S3::Index.new("items", id_type: "integer", ...) 53 | ``` 54 | 55 | ## Operations 56 | 57 | Add or update a vector 58 | 59 | ```ruby 60 | index.add(id, vector) 61 | ``` 62 | 63 | Add or update multiple vectors 64 | 65 | ```ruby 66 | index.add_all([{id: 1, vector: [1, 2, 3]}, {id: 2, vector: [4, 5, 6]}]) 67 | ``` 68 | 69 | Get a vector 70 | 71 | ```ruby 72 | index.find(id) 73 | ``` 74 | 75 | Get all vectors 76 | 77 | ```ruby 78 | index.find_in_batches do |batch| 79 | # ... 
80 | end 81 | ``` 82 | 83 | Remove a vector 84 | 85 | ```ruby 86 | index.remove(id) 87 | ``` 88 | 89 | Remove multiple vectors 90 | 91 | ```ruby 92 | index.remove_all(ids) 93 | ``` 94 | 95 | ## Metadata 96 | 97 | Add a vector with metadata 98 | 99 | ```ruby 100 | index.add(id, vector, metadata: {category: "A"}) 101 | ``` 102 | 103 | Add multiple vectors with metadata 104 | 105 | ```ruby 106 | index.add_all([ 107 | {id: 1, vector: [1, 2, 3], metadata: {category: "A"}}, 108 | {id: 2, vector: [4, 5, 6], metadata: {category: "B"}} 109 | ]) 110 | ``` 111 | 112 | Get metadata with search results 113 | 114 | ```ruby 115 | index.search(vector, with_metadata: true) 116 | ``` 117 | 118 | Filter by metadata 119 | 120 | ```ruby 121 | index.search(vector, filter: {category: "A"}) 122 | ``` 123 | 124 | Supports [these operators](https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-vectors-metadata-filtering.html#s3-vectors-metadata-filtering-filterable) 125 | 126 | Specify non-filterable metadata on index creation 127 | 128 | ```ruby 129 | Neighbor::S3::Index.new(name, non_filterable: ["category"], ...) 130 | ``` 131 | 132 | ## Example 133 | 134 | You can use Neighbor S3 for online item-based recommendations with [Disco](https://github.com/ankane/disco). We’ll use MovieLens data for this example. 
135 | 136 | Create an index 137 | 138 | ```ruby 139 | index = Neighbor::S3::Index.create("movies", bucket: "my-bucket", dimensions: 20, distance: "cosine") 140 | ``` 141 | 142 | Fit the recommender 143 | 144 | ```ruby 145 | data = Disco.load_movielens 146 | recommender = Disco::Recommender.new(factors: 20) 147 | recommender.fit(data) 148 | ``` 149 | 150 | Store the item factors 151 | 152 | ```ruby 153 | index.add_all(recommender.item_ids.map { |v| {id: v, vector: recommender.item_factors(v)} }) 154 | ``` 155 | 156 | And get similar movies 157 | 158 | ```ruby 159 | index.search_id("Star Wars (1977)").map { |v| v[:id] } 160 | ``` 161 | 162 | See the [complete code](examples/disco_item_recs.rb) 163 | 164 | ## Reference 165 | 166 | Get index info 167 | 168 | ```ruby 169 | index.info 170 | ``` 171 | 172 | Check if an index exists 173 | 174 | ```ruby 175 | index.exists? 176 | ``` 177 | 178 | Drop an index 179 | 180 | ```ruby 181 | index.drop 182 | ``` 183 | 184 | ## History 185 | 186 | View the [changelog](https://github.com/ankane/neighbor-s3/blob/master/CHANGELOG.md) 187 | 188 | ## Contributing 189 | 190 | Everyone is encouraged to help improve this project. 
Here are a few ways you can help: 191 | 192 | - [Report bugs](https://github.com/ankane/neighbor-s3/issues) 193 | - Fix bugs and [submit pull requests](https://github.com/ankane/neighbor-s3/pulls) 194 | - Write, clarify, or fix documentation 195 | - Suggest or add new features 196 | 197 | To get started with development: 198 | 199 | ```sh 200 | git clone https://github.com/ankane/neighbor-s3.git 201 | cd neighbor-s3 202 | bundle install 203 | bundle exec rake test 204 | ``` 205 | -------------------------------------------------------------------------------- /lib/neighbor/s3/index.rb: -------------------------------------------------------------------------------- 1 | module Neighbor 2 | module S3 3 | class Index 4 | def initialize(name, bucket:, dimensions:, distance:, id_type: "string", non_filterable: nil) 5 | @name = name 6 | @bucket = bucket 7 | @dimensions = dimensions.to_i 8 | 9 | @distance_metric = 10 | case distance.to_s 11 | when "euclidean" 12 | "euclidean" 13 | when "cosine" 14 | "cosine" 15 | else 16 | raise ArgumentError, "invalid distance" 17 | end 18 | 19 | @int_ids = 20 | case id_type.to_s 21 | when "string" 22 | false 23 | when "integer" 24 | true 25 | else 26 | raise ArgumentError, "invalid id_type" 27 | end 28 | 29 | @non_filterable = non_filterable.to_a 30 | end 31 | 32 | def self.create(*args, **options) 33 | index = new(*args, **options) 34 | index.create 35 | index 36 | end 37 | 38 | def create 39 | options = { 40 | vector_bucket_name: @bucket, 41 | index_name: @name, 42 | data_type: "float32", 43 | dimension: @dimensions, 44 | distance_metric: @distance_metric 45 | } 46 | if @non_filterable.any? 47 | options[:metadata_configuration] = { 48 | non_filterable_metadata_keys: @non_filterable 49 | } 50 | end 51 | client.create_index(options) 52 | nil 53 | end 54 | 55 | def exists? 
56 | client.get_index({ 57 | vector_bucket_name: @bucket, 58 | index_name: @name 59 | }) 60 | true 61 | rescue Aws::S3Vectors::Errors::NotFoundException 62 | false 63 | end 64 | 65 | def info 66 | client.get_index({ 67 | vector_bucket_name: @bucket, 68 | index_name: @name 69 | }).index.to_h 70 | end 71 | 72 | def add(id, vector, metadata: nil) 73 | add_all([{id: id, vector: vector, metadata: metadata}]) 74 | end 75 | 76 | def add_all(items) 77 | # perform checks first to reduce chance of non-atomic updates 78 | vectors = 79 | items.map do |item| 80 | vector = item.fetch(:vector).to_a 81 | check_dimensions(vector) 82 | 83 | { 84 | key: item_id(item.fetch(:id)).to_s, 85 | data: {float32: vector}, 86 | metadata: item[:metadata] 87 | } 88 | end 89 | 90 | vectors.each_slice(500) do |batch| 91 | client.put_vectors({ 92 | vector_bucket_name: @bucket, 93 | index_name: @name, 94 | vectors: batch 95 | }) 96 | end 97 | nil 98 | end 99 | 100 | def member?(id) 101 | id = item_id(id) 102 | 103 | client.get_vectors({ 104 | vector_bucket_name: @bucket, 105 | index_name: @name, 106 | keys: [id.to_s], 107 | return_data: false, 108 | return_metadata: false 109 | }).vectors.any? 110 | end 111 | alias_method :include?, :member? 
112 | 113 | def remove(id) 114 | remove_all([id]) 115 | end 116 | 117 | def remove_all(ids) 118 | ids = ids.to_a.map { |id| item_id(id) } 119 | 120 | ids.each_slice(500) do |batch| 121 | client.delete_vectors({ 122 | vector_bucket_name: @bucket, 123 | index_name: @name, 124 | keys: batch.map(&:to_s) 125 | }) 126 | end 127 | nil 128 | end 129 | 130 | def find(id, with_metadata: true) 131 | id = item_id(id) 132 | 133 | v = 134 | client.get_vectors({ 135 | vector_bucket_name: @bucket, 136 | index_name: @name, 137 | keys: [id.to_s], 138 | return_data: true, 139 | return_metadata: with_metadata 140 | }).vectors.first 141 | 142 | if v 143 | item = { 144 | id: item_id(v.key), 145 | vector: v.data.float32 146 | } 147 | item[:metadata] = v.metadata if with_metadata 148 | item 149 | end 150 | end 151 | 152 | def find_in_batches(batch_size: 1000, with_metadata: true) 153 | options = { 154 | vector_bucket_name: @bucket, 155 | index_name: @name, 156 | max_results: batch_size, 157 | return_data: true, 158 | return_metadata: with_metadata 159 | } 160 | 161 | begin 162 | resp = client.list_vectors(options) 163 | batch = 164 | resp.vectors.map do |v| 165 | item = { 166 | id: item_id(v.key), 167 | vector: v.data.float32 168 | } 169 | item[:metadata] = v.metadata if with_metadata 170 | item 171 | end 172 | yield batch 173 | options[:next_token] = resp.next_token 174 | end while resp.next_token 175 | end 176 | 177 | def search(vector, count: 5, with_metadata: false, filter: nil) 178 | check_dimensions(vector) 179 | 180 | client.query_vectors({ 181 | vector_bucket_name: @bucket, 182 | index_name: @name, 183 | top_k: count, 184 | query_vector: { 185 | float32: vector, 186 | }, 187 | filter: filter, 188 | return_metadata: with_metadata, 189 | return_distance: true 190 | }).vectors.map do |v| 191 | item = { 192 | id: item_id(v.key), 193 | distance: @distance_metric == "euclidean" ? 
Math.sqrt(v.distance) : v.distance 194 | } 195 | item[:metadata] = v.metadata if with_metadata 196 | item 197 | end 198 | end 199 | 200 | def search_id(id, count: 5, with_metadata: false, filter: nil) 201 | id = item_id(id) 202 | 203 | item = find(id) 204 | unless item 205 | raise Error, "Could not find item #{id}" 206 | end 207 | 208 | result = search(item[:vector], count: count + 1, with_metadata:, filter:) 209 | result.reject { |v| v[:id] == id }.first(count) 210 | end 211 | 212 | def drop 213 | client.delete_index({ 214 | vector_bucket_name: @bucket, 215 | index_name: @name 216 | }) 217 | nil 218 | end 219 | 220 | private 221 | 222 | def check_dimensions(vector) 223 | if vector.size != @dimensions 224 | raise ArgumentError, "expected #{@dimensions} dimensions" 225 | end 226 | end 227 | 228 | def item_id(id) 229 | @int_ids ? Integer(id) : id.to_s 230 | end 231 | 232 | def client 233 | S3.client 234 | end 235 | end 236 | end 237 | end 238 | -------------------------------------------------------------------------------- /test/index_test.rb: -------------------------------------------------------------------------------- 1 | require_relative "test_helper" 2 | 3 | class IndexTest < Minitest::Test 4 | def setup 5 | super 6 | index = Neighbor::S3::Index.new("items", bucket: bucket, dimensions: 3, distance: "euclidean") 7 | index.drop if index.exists? 8 | end 9 | 10 | def test_create 11 | index = Neighbor::S3::Index.new("items", bucket: bucket, dimensions: 3, distance: "euclidean") 12 | assert_nil index.create 13 | end 14 | 15 | def test_create_exists 16 | index = create_index 17 | 18 | error = assert_raises(Aws::S3Vectors::Errors::ConflictException) do 19 | index.create 20 | end 21 | assert_equal "An index with the specified name already exists", error.message 22 | end 23 | 24 | def test_exists 25 | index = Neighbor::S3::Index.new("items", bucket: bucket, dimensions: 3, distance: "euclidean") 26 | assert_equal false, index.exists? 
27 | index.create 28 | assert_equal true, index.exists? 29 | end 30 | 31 | def test_info 32 | index = create_index 33 | info = index.info 34 | assert_equal "items", info[:index_name] 35 | assert_equal "float32", info[:data_type] 36 | assert_equal 3, info[:dimension] 37 | assert_equal "cosine", info[:distance_metric] 38 | end 39 | 40 | def test_info_missing 41 | index = Neighbor::S3::Index.new("items", bucket: bucket, dimensions: 3, distance: "euclidean") 42 | error = assert_raises(Aws::S3Vectors::Errors::NotFoundException) do 43 | index.info 44 | end 45 | assert_equal "The specified index could not be found", error.message 46 | end 47 | 48 | def test_add 49 | index = create_index 50 | assert_nil index.add(1, [1, 1, 1]) 51 | assert_nil index.add(1, [2, 2, 2]) 52 | assert_equal [2, 2, 2], index.find(1)[:vector] 53 | end 54 | 55 | def test_add_metadata 56 | index = create_index 57 | assert_nil index.add(1, [1, 1, 1], metadata: {category: "A"}) 58 | assert_equal ({"category" => "A"}), index.find(1)[:metadata] 59 | 60 | assert_nil index.add(1, [2, 2, 2]) 61 | assert_empty index.find(1)[:metadata] 62 | end 63 | 64 | def test_add_different_dimensions 65 | index = create_index 66 | error = assert_raises(ArgumentError) do 67 | index.add(4, [1, 2]) 68 | end 69 | assert_equal "expected 3 dimensions", error.message 70 | end 71 | 72 | def test_add_before_create 73 | index = Neighbor::S3::Index.new("items", bucket: bucket, dimensions: 3, distance: "euclidean", id_type: "integer") 74 | error = assert_raises(Aws::S3Vectors::Errors::NotFoundException) do 75 | index.add(1, [1, 1, 1]) 76 | end 77 | assert_equal "The specified index could not be found", error.message 78 | end 79 | 80 | def test_add_all 81 | index = create_index 82 | assert_nil index.add_all([{id: 1, vector: [1, 1, 1]}, {id: 2, vector: [2, 2, 2]}]) 83 | assert_nil index.add_all([{id: 1, vector: [1, 1, 1]}, {id: 3, vector: [1, 1, 2]}]) 84 | end 85 | 86 | def test_add_all_different_dimensions 87 | index = create_index 88 
| error = assert_raises(ArgumentError) do 89 | index.add_all([{id: 1, vector: [1, 1, 1]}, {id: 4, vector: [1, 2]}]) 90 | end 91 | assert_equal "expected 3 dimensions", error.message 92 | end 93 | 94 | def test_add_all_missing_key 95 | index = create_index 96 | error = assert_raises(KeyError) do 97 | index.add_all([{id: 1}]) 98 | end 99 | assert_equal "key not found: :vector", error.message 100 | end 101 | 102 | def test_member 103 | index = create_index 104 | add_items(index) 105 | assert_equal true, index.member?(2) 106 | assert_equal false, index.member?(4) 107 | end 108 | 109 | def test_include 110 | index = create_index 111 | add_items(index) 112 | assert_equal true, index.include?(2) 113 | assert_equal false, index.include?(4) 114 | end 115 | 116 | def test_remove 117 | index = create_index(distance: "euclidean", id_type: "integer") 118 | add_items(index) 119 | assert_nil index.remove(2) 120 | assert_nil index.remove(4) 121 | assert_equal [1, 3], index.search([1, 1, 1]).map { |v| v[:id] } 122 | end 123 | 124 | def test_remove_all 125 | index = create_index(distance: "euclidean", id_type: "integer") 126 | add_items(index) 127 | assert_nil index.remove_all([2, 4]) 128 | assert_equal [1, 3], index.search([1, 1, 1]).map { |v| v[:id] } 129 | end 130 | 131 | def test_find 132 | index = create_index 133 | add_items(index) 134 | assert_elements_in_delta [1, 1, 1], index.find(1)[:vector] 135 | assert_elements_in_delta [2, 2, 2], index.find(2)[:vector] 136 | assert_elements_in_delta [1, 1, 2], index.find(3)[:vector] 137 | assert_nil index.find(4) 138 | end 139 | 140 | def test_find_metadata 141 | index = create_index 142 | index.add(1, [1, 1, 1], metadata: {category: "A"}) 143 | index.add(2, [-1, -1, -1], metadata: {category: "B"}) 144 | index.add(3, [1, 1, 0]) 145 | 146 | assert_equal ({"category" => "A"}), index.find(1)[:metadata] 147 | assert_equal ({"category" => "B"}), index.find(2)[:metadata] 148 | assert_empty index.find(3)[:metadata] 149 | assert_nil 
index.find(4) 150 | end 151 | 152 | def test_find_in_batches 153 | index = create_index 154 | add_items(index) 155 | batches = [] 156 | index.find_in_batches(batch_size: 2) do |batch| 157 | batches << batch 158 | end 159 | assert_equal 2, batches.size 160 | end 161 | 162 | def test_find_in_batches_batch_size 163 | index = create_index 164 | error = assert_raises(Aws::S3Vectors::Errors::ValidationException) do 165 | index.find_in_batches(batch_size: 1001) 166 | end 167 | assert_match "Member must be between 1 and 1000, inclusive", error.message 168 | end 169 | 170 | def test_search_euclidean 171 | index = create_index(distance: "euclidean", id_type: "integer") 172 | add_items(index) 173 | result = index.search([1, 1, 1]) 174 | assert_equal [1, 3, 2], result.map { |v| v[:id] } 175 | assert_elements_in_delta [0, 1, 1.7320507764816284], result.map { |v| v[:distance] } 176 | end 177 | 178 | def test_search_cosine 179 | index = create_index(distance: "cosine", id_type: "integer") 180 | index.add(1, [1, 1, 1]) 181 | index.add(2, [-1, -1, -1]) 182 | index.add(3, [1, 1, 2]) 183 | result = index.search([1, 1, 1]) 184 | assert_equal [1, 3, 2], result.map { |v| v[:id] } 185 | assert_elements_in_delta [0, 0.05719095841050148, 2], result.map { |v| v[:distance] } 186 | end 187 | 188 | def test_search_with_metadata 189 | index = create_index 190 | index.add(1, [1, 1, 1], metadata: {category: "A", quantity: 2}) 191 | index.add(2, [-1, -1, -1], metadata: {category: "B", quantity: 4}) 192 | index.add(3, [1, 1, 0]) 193 | 194 | result = index.search([1, 1, 1], with_metadata: true) 195 | assert_equal ({"category" => "A", "quantity" => 2}), result[0][:metadata] 196 | assert_empty result[1][:metadata] 197 | assert_equal ({"category" => "B", "quantity" => 4}), result[2][:metadata] 198 | end 199 | 200 | def test_search_filter 201 | index = create_index(distance: "cosine", id_type: "integer") 202 | index.add(1, [1, 1, 1], metadata: {category: "A", quantity: 2}) 203 | index.add(2, [-1, -1, 
-1], metadata: {category: "B", quantity: 4}) 204 | index.add(3, [1, 1, 0]) 205 | 206 | result = index.search([1, 1, 1], filter: {category: "B"}) 207 | assert_equal [2], result.map { |v| v[:id] } 208 | 209 | result = index.search([1, 1, 1], filter: {quantity: {"$gt" => 2}}) 210 | assert_equal [2], result.map { |v| v[:id] } 211 | 212 | result = index.search([1, 1, 1], filter: {quantity: {"$exists" => true}}) 213 | assert_equal [1, 2], result.map { |v| v[:id] } 214 | end 215 | 216 | def test_search_non_filterable 217 | index = create_index(distance: "cosine", id_type: "integer", non_filterable: ["category"]) 218 | index.add(1, [1, 1, 1], metadata: {category: "A"}) 219 | 220 | error = assert_raises(Aws::S3Vectors::Errors::ValidationException) do 221 | index.search([1, 1, 1], filter: {category: "A"}) 222 | end 223 | assert_equal "Invalid use of non-filterable metadata in filter", error.message 224 | end 225 | 226 | def test_search_different_dimensions 227 | index = create_index 228 | error = assert_raises(ArgumentError) do 229 | index.search([1, 2]) 230 | end 231 | assert_equal "expected 3 dimensions", error.message 232 | end 233 | 234 | def test_search_id_euclidean 235 | index = create_index(distance: "euclidean", id_type: "integer") 236 | add_items(index) 237 | result = index.search_id(1) 238 | assert_equal [3, 2], result.map { |v| v[:id] } 239 | assert_elements_in_delta [1, 1.7320507764816284], result.map { |v| v[:distance] } 240 | end 241 | 242 | def test_search_id_cosine 243 | index = create_index(distance: "cosine", id_type: "integer") 244 | add_items(index) 245 | result = index.search_id(1) 246 | assert_equal [2, 3], result.map { |v| v[:id] } 247 | assert_elements_in_delta [0, 0.05719095841050148], result.map { |v| v[:distance] } 248 | end 249 | 250 | def test_search_id_with_metadata 251 | index = create_index 252 | index.add(1, [1, 1, 1], metadata: {category: "A", quantity: 2}) 253 | index.add(2, [-1, -1, -1], metadata: {category: "B", quantity: 4}) 254 | 
index.add(3, [1, 1, 0]) 255 | 256 | result = index.search_id(1, with_metadata: true) 257 | assert_empty result[0][:metadata] 258 | assert_equal ({"category" => "B", "quantity" => 4}), result[1][:metadata] 259 | end 260 | 261 | def test_search_id_filter 262 | index = create_index(distance: "cosine", id_type: "integer") 263 | index.add(1, [1, 1, 1], metadata: {category: "A", quantity: 2}) 264 | index.add(2, [-1, -1, -1], metadata: {category: "B", quantity: 4}) 265 | index.add(3, [1, 1, 0]) 266 | 267 | result = index.search_id(1, filter: {category: "B"}) 268 | assert_equal [2], result.map { |v| v[:id] } 269 | end 270 | 271 | def test_search_id_missing 272 | index = create_index 273 | error = assert_raises(Neighbor::S3::Error) do 274 | index.search_id(4) 275 | end 276 | assert_equal "Could not find item 4", error.message 277 | end 278 | 279 | def test_drop 280 | index = create_index 281 | assert_equal true, index.exists? 282 | assert_nil index.drop 283 | assert_equal false, index.exists? 284 | assert_nil index.drop 285 | end 286 | 287 | def test_id_type_integer 288 | index = create_index(distance: "euclidean", id_type: "integer") 289 | index.add(1, [1, 1, 1]) 290 | index.add("2", [-1, -1, -1]) 291 | error = assert_raises(ArgumentError) do 292 | index.add("3a", [1, 1, 0]) 293 | end 294 | assert_match "invalid value for Integer()", error.message 295 | assert_equal [2], index.search_id(1).map { |v| v[:id] } 296 | assert_equal [1, 2], index.search([1, 1, 1]).map { |v| v[:id] } 297 | end 298 | 299 | def test_id_type_string 300 | index = create_index(distance: "euclidean", id_type: "string") 301 | index.add(1, [1, 1, 1]) 302 | index.add("2", [-1, -1, -1]) 303 | assert_equal ["2"], index.search_id(1).map { |v| v[:id] } 304 | assert_equal ["1", "2"], index.search([1, 1, 1]).map { |v| v[:id] } 305 | end 306 | 307 | private 308 | 309 | def create_index(**options) 310 | options[:distance] ||= ["euclidean", "cosine"].sample 311 | Neighbor::S3::Index.create("items", bucket: bucket, 
dimensions: 3, **options) 312 | end 313 | 314 | def add_items(index) 315 | items = [ 316 | {id: 1, vector: [1, 1, 1]}, 317 | {id: 2, vector: [2, 2, 2]}, 318 | {id: 3, vector: [1, 1, 2]} 319 | ] 320 | index.add_all(items) 321 | end 322 | end 323 | --------------------------------------------------------------------------------