├── .gitattributes
├── rustfmt.toml
├── Cargo.toml
├── README.md
├── .gitignore
├── LICENSE
└── src
    ├── benchmark.sh
    ├── faiss_run.py
    ├── search.rs
    └── main.rs

/.gitattributes:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rustfmt.toml:
--------------------------------------------------------------------------------
max_width = 80

--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "ann"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
rand = "0.8"
rayon = "1.7.0"
itertools = "0.10.5"
clap = "2.33.3"
json = "0.12"
dashmap = "5.4.0"

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FANN
FANN stands for "Fennel Approximate Nearest Neighbor" and is a tiny (~200 LOC) Rust library for approximate nearest neighbor search.
It is meant for educational purposes. You are free to use it in production at your own risk (though since this is only ~200 lines of Rust, you can read and inspect the entire code before using it in prod).

# Fennel
[Fennel](https://fennel.ai/) offers a realtime feature engineering platform as a fully managed service. You can
read the Fennel [docs here](https://docs.fennel.ai/).

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Generated by Cargo
# will have compiled files and executables
/target/

# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock

# These are backup files generated by rustfmt
**/*.rs.bk

# IDE/OS-based additional files
.DS_Store
__pycache__

# Files only for benchmarking that are auto-downloaded - only possible to move to
# git with LFS but not necessary
data/*

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 fennel-ai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/src/benchmark.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# This benchmarking script is responsible for:
# 1) Making the data directory if needed
# 2) Downloading the FastText embeddings if they are not available locally
# 3) Running the FAISS HNSW benchmarking
# 4) Running the Rust Annoy-style index benchmarking

# Make the data folder if it does not exist
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
ROOT_DIR="$( cd "${SCRIPT_DIR}/.." >/dev/null 2>&1 && pwd )"
DATA_DIR="$ROOT_DIR/data"
mkdir -p "$DATA_DIR"

# Download the wiki-news vector file if it does not exist
WIKIDATA_FILE="${DATA_DIR}/wikidata.vec"
WIKIDATA_URL="https://dl.fbaipublicfiles.com/fasttext/vectors-english/wiki-news-300d-1M.vec.zip"
if [ ! -f "$WIKIDATA_FILE" ]; then
    echo "$WIKIDATA_FILE does not exist. Downloading..."
    curl -o temporary.zip "$WIKIDATA_URL"
    unzip temporary.zip
    mv wiki-news-300d-1M.vec "$WIKIDATA_FILE"
    rm -rf temporary.zip
else
    echo "$WIKIDATA_FILE already exists."
fi

# Run the FAISS benchmarking
FAISS_FILE="${SCRIPT_DIR}/faiss_run.py"
python3 "$FAISS_FILE" --input-vec "$WIKIDATA_FILE" --data-dir "$DATA_DIR"

# Benchmark our Rust index
CWD_TO_PRESERVE=$(pwd)
cd "$ROOT_DIR"
cargo build --release
./target/release/ann "$DATA_DIR" "$WIKIDATA_FILE"
cd "$CWD_TO_PRESERVE"

--------------------------------------------------------------------------------
/src/faiss_run.py:
--------------------------------------------------------------------------------
try:
    import faiss
except Exception as ex:
    print("Please ensure faiss-cpu==1.7.3 is installed locally")
    raise ex
import pickle
import os
import numpy as np
import time
import argparse


def convert_raw_data_if_needed(data_dir: str, input_vec_path: str):
    # Parses the raw .vec file (header line "num_vectors dim", then one
    # "word x_1 ... x_dim" line per vector) into an .npz matrix plus a
    # pickled word -> row-index map, and caches both under [data_dir]
    input_file_name = os.path.split(input_vec_path)[1]
    input_file_no_ext = os.path.splitext(input_file_name)[0]
    processed_npz_path = os.path.join(data_dir, f"faiss_{input_file_no_ext}.npz")
    processed_pkl_path = os.path.join(data_dir, f"faiss_{input_file_no_ext}.pkl")
    if not os.path.exists(processed_npz_path) or not os.path.exists(processed_pkl_path):
        # We must actually process the data
        start = time.time()
        with open(input_vec_path, "r") as f:
            data = f.readlines()
        num_vectors, dim = [int(value) for value in data[0].split()]
        print(f"Running with {num_vectors} vectors and {dim} dimensions")
        sample_data = np.zeros((num_vectors, dim), dtype=np.float32)
        word_to_index_map = {}
        for i in range(len(sample_data)):
            entries = data[i + 1].split()
            word_to_index_map[entries[0]] = i
            for j in range(dim):
                sample_data[i, j] = float(entries[j + 1])
        del data
        end = time.time()
        np.savez(processed_npz_path, x=sample_data)
        pickle.dump(word_to_index_map, open(processed_pkl_path, "wb"))
        print(f"Loaded {sample_data.shape}-shaped input data in {end-start} seconds")
    return processed_npz_path, processed_pkl_path


def load_data(npz_path: str, pkl_path: str):
    # Expects [convert_raw_data_if_needed] above to have been called once to pre-process the files
    sample_data = np.load(npz_path)["x"]
    word_to_index_map = pickle.load(open(pkl_path, "rb"))
    num_vectors, dim = sample_data.shape
    print(f"Running with {num_vectors} vectors and {dim} dimensions")
    index_to_word_map = {}
    for word, i in word_to_index_map.items():
        index_to_word_map[i] = word
    return sample_data, word_to_index_map, index_to_word_map


def index_and_check_runtime(
    data, ef_search=None, ef_construction=None,
    max_node_size=15, top_k=20, save_all_distances_folder=None):
    print(f"ef_search={ef_search}, ef_construction={ef_construction}, max_node_size={max_node_size}, top_k={top_k}")
    # Index this data into FAISS
    idx_start = time.time()
    index = faiss.IndexHNSWFlat(data.shape[1], max_node_size)
    if ef_search is not None:
        index.hnsw.efSearch = ef_search
    if ef_construction is not None:
        index.hnsw.efConstruction = ef_construction
    index.add(data)
    idx_end = time.time()
    idx_time = idx_end - idx_start
    print(f"Indexed data into HNSWFlat Index in {idx_time} seconds")
    search_data = data[np.random.choice(data.shape[0], size=1000)].reshape(1000, 1, -1)
    sch_start = time.time()
    for i in range(1000):
        # Take a random sample of vectors and just time search on FAISS
        index.search(search_data[i], top_k)
    sch_end = time.time()
    avg_sch_time = (sch_end - sch_start) / 1000
    print(f"Average time searching in bulk took {avg_sch_time} seconds")
    all_distances = []
    for vector in data:
        D, _ = index.search(vector.reshape(1, -1), top_k)
        all_distances.append(sum(D[0]**(1/2)) / len(D[0]))
    print(f"Average Euclidean Distance = {sum(all_distances)/len(all_distances)}")
    if save_all_distances_folder is not None:
        print("Saving all distances for future visualization")
        filename = f"faiss_{ef_search}_{ef_construction}_{max_node_size}_{top_k}.pkl"
        path = os.path.join(save_all_distances_folder, filename)
        pickle.dump(all_distances, open(path, "wb"))
    return index


def search_a_word_faiss(
    word, top_k, sample_data, word_to_index_map, index_to_word_map, index):
    search = np.zeros((1, sample_data.shape[1]), dtype=np.float32)
    search[0] = sample_data[word_to_index_map[word]]
    D, I = index.search(search, top_k)
    words = [index_to_word_map[I[0][i]] for i in range(I.shape[1])]
    return D[0]**(1/2), words


def search_a_word_exhaustive(
    word, top_k, sample_data, word_to_index_map, index_to_word_map):
    data = sample_data[word_to_index_map[word]]
    euc_distances = ((sample_data - data)**2).sum(axis=1)**(1/2)
    first_k_indices = euc_distances.argsort()[:top_k]
    distances = [euc_distances[idx] for idx in first_k_indices]
    words = [index_to_word_map[idx] for idx in first_k_indices]
    return distances, words


def display_comparison(
    word, top_k, sample_data, word_to_index_map, index_to_word_map, index):
    faiss_dist, faiss_words = search_a_word_faiss(
        word, top_k, sample_data, word_to_index_map, index_to_word_map, index)
    print(f"Word: {word}")
    print(f"FAISS Euclidean Dist: {faiss_dist}")
    print(f"FAISS Words: {faiss_words}")
    ex_dist, ex_words = search_a_word_exhaustive(
        word, top_k, sample_data, word_to_index_map, index_to_word_map)
print(f"Exhaustive Euclidean Dist: {ex_dist}") 118 | print(f"Exhaustive Words: {ex_words}") 119 | print() 120 | 121 | 122 | # Parse the location of the input file 123 | parser = argparse.ArgumentParser(description="Run FAISS Benchmarking") 124 | parser.add_argument("--data-dir", type=str, help="The data directory path") 125 | parser.add_argument("--input-vec", type=str, help="The path to the input vec file") 126 | args = parser.parse_args() 127 | 128 | # Load in the sample data 129 | # First time ever, run [convert_raw_data] and then you can save the conversion time 130 | npz_path, pkl_path = convert_raw_data_if_needed(args.data_dir, args.input_vec) 131 | sample_data, word_to_index_map, index_to_word_map = load_data(npz_path, pkl_path) 132 | 133 | # On the default configurations, print all distances too 134 | index = index_and_check_runtime(sample_data, save_all_distances_folder=args.data_dir) 135 | 136 | # Now, evaluate FAISS results qualitatively 137 | other_display_args = [20, sample_data, word_to_index_map, index_to_word_map, index] 138 | display_comparison("river", *other_display_args) 139 | display_comparison("war", *other_display_args) 140 | display_comparison("love", *other_display_args) 141 | display_comparison("education", *other_display_args) 142 | -------------------------------------------------------------------------------- /src/search.rs: -------------------------------------------------------------------------------- 1 | use dashmap::DashSet; 2 | use itertools::Itertools; 3 | use rand::prelude::SliceRandom; 4 | use rayon::prelude::*; 5 | use std::{cmp::min, collections::HashSet}; 6 | 7 | #[derive(Eq, PartialEq, Hash)] 8 | pub struct HashKey([u32; N]); 9 | 10 | struct HyperPlane { 11 | coefficients: Vector, 12 | constant: f32, 13 | } 14 | impl HyperPlane { 15 | pub fn point_is_above(&self, point: &Vector) -> bool { 16 | self.coefficients.dot_product(point) + self.constant >= 0.0 17 | } 18 | } 19 | 20 | #[derive(Copy, Clone)] 21 | pub struct Vector(pub [f32; N]); 22 | impl Vector { 23 | pub fn subtract_from(&self, vector: &Vector) -> Vector { 24 | let mapped = self.0.iter().zip(vector.0).map(|(a, b)| b - a); 25 | let coords: [f32; N] = mapped.collect::>().try_into().unwrap(); 26 | return Vector(coords); 27 | } 28 | pub fn avg(&self, vector: &Vector) -> Vector { 29 | let mapped = self.0.iter().zip(vector.0).map(|(a, b)| (a + b) / 2.0); 30 | let coords: [f32; N] = mapped.collect::>().try_into().unwrap(); 31 | return Vector(coords); 32 | } 33 | pub fn dot_product(&self, vector: &Vector) -> f32 { 34 | let zipped_iter = self.0.iter().zip(vector.0); 35 | return zipped_iter.map(|(a, b)| a * b).sum::(); 36 | } 37 | pub fn to_hashkey(&self) -> HashKey { 38 | // f32 in Rust doesn't implement hash. We use bytes to dedup. 
        // can't differentiate ~16M ways NaN is written, it's safe for us
        let bit_iter = self.0.iter().map(|a| a.to_bits());
        let data: [u32; N] = bit_iter.collect::<Vec<_>>().try_into().unwrap();
        return HashKey::<N>(data);
    }
    pub fn sq_euc_dis(&self, vector: &Vector<N>) -> f32 {
        let zipped_iter = self.0.iter().zip(vector.0);
        return zipped_iter.map(|(a, b)| (a - b).powi(2)).sum();
    }
}

enum Node<const N: usize> {
    Inner(Box<InnerNode<N>>),
    Leaf(Box<LeafNode<N>>),
}
struct LeafNode<const N: usize>(Vec<usize>);
struct InnerNode<const N: usize> {
    hyperplane: HyperPlane<N>,
    left_node: Node<N>,
    right_node: Node<N>,
}
pub struct ANNIndex<const N: usize> {
    trees: Vec<Node<N>>,
    ids: Vec<i32>,
    values: Vec<Vector<N>>,
}
impl<const N: usize> ANNIndex<N> {
    fn build_hyperplane(
        indexes: &Vec<usize>,
        all_vecs: &Vec<Vector<N>>,
    ) -> (HyperPlane<N>, Vec<usize>, Vec<usize>) {
        let sample: Vec<_> = indexes
            .choose_multiple(&mut rand::thread_rng(), 2)
            .collect();
        // cartesian eq for hyperplane n * (x - x_0) = 0
        // n (normal vector) is the coefs x_1 to x_n
        let (a, b) = (*sample[0], *sample[1]);
        let coefficients = all_vecs[a].subtract_from(&all_vecs[b]);
        let point_on_plane = all_vecs[a].avg(&all_vecs[b]);
        let constant = -coefficients.dot_product(&point_on_plane);
        let hyperplane = HyperPlane::<N> {
            coefficients,
            constant,
        };
        let (mut above, mut below) = (vec![], vec![]);
        for &id in indexes.iter() {
            if hyperplane.point_is_above(&all_vecs[id]) {
                above.push(id)
            } else {
                below.push(id)
            };
        }
        return (hyperplane, above, below);
    }

    fn build_a_tree(
        max_size: i32,
        indexes: &Vec<usize>,
        all_vecs: &Vec<Vector<N>>,
    ) -> Node<N> {
        if indexes.len() <= (max_size as usize) {
            return Node::Leaf(Box::new(LeafNode::<N>(indexes.clone())));
        }
        let (plane, above, below) = Self::build_hyperplane(indexes, all_vecs);
        let node_above = Self::build_a_tree(max_size, &above, all_vecs);
        let node_below = Self::build_a_tree(max_size, &below, all_vecs);
        return Node::Inner(Box::new(InnerNode::<N> {
            hyperplane: plane,
            left_node: node_below,
            right_node: node_above,
        }));
    }

    fn deduplicate(
        vectors: &Vec<Vector<N>>,
        ids: &Vec<i32>,
        dedup_vectors: &mut Vec<Vector<N>>,
        ids_of_dedup_vectors: &mut Vec<i32>,
    ) {
        let mut hashes_seen = HashSet::new();
        // Start at 0 so the very first vector is also kept
        for i in 0..vectors.len() {
            let hash_key = vectors[i].to_hashkey();
            if !hashes_seen.contains(&hash_key) {
                hashes_seen.insert(hash_key);
                dedup_vectors.push(vectors[i]);
                ids_of_dedup_vectors.push(ids[i]);
            }
        }
    }

    pub fn build_index(
        num_trees: i32,
        max_size: i32,
        vecs: &Vec<Vector<N>>,
        vec_ids: &Vec<i32>,
    ) -> ANNIndex<N> {
        let (mut unique_vecs, mut ids) = (vec![], vec![]);
        Self::deduplicate(vecs, vec_ids, &mut unique_vecs, &mut ids);
        // Trees hold an index into the [unique_vecs] list which is not
        // necessarily its id, if duplicates existed
        let all_indexes: Vec<usize> = (0..unique_vecs.len()).collect();
        let trees: Vec<_> = (0..num_trees)
            .into_par_iter()
            .map(|_| Self::build_a_tree(max_size, &all_indexes, &unique_vecs))
            .collect();
        return ANNIndex::<N> {
            trees,
            ids,
            values: unique_vecs,
        };
    }

    fn tree_result(
        query: Vector<N>,
        n: i32,
        tree: &Node<N>,
        candidates: &DashSet<usize>,
    ) -> i32 {
        // take everything in node, if still needed, take from alternate subtree
        match tree {
            Node::Leaf(box_leaf) => {
                let leaf_values = &(box_leaf.0);
                let num_candidates_found = min(n as usize, leaf_values.len());
                for i in 0..num_candidates_found {
                    candidates.insert(leaf_values[i]);
                }
                return num_candidates_found as i32;
            }
            Node::Inner(inner) => {
                let above = (*inner).hyperplane.point_is_above(&query);
                let (main, backup) = match above {
                    true => (&(inner.right_node), &(inner.left_node)),
                    false => (&(inner.left_node), &(inner.right_node)),
                };
                match Self::tree_result(query, n, main, candidates) {
                    k if k < n => {
                        k + Self::tree_result(query, n - k, backup, candidates)
                    }
                    k => k,
                }
            }
        }
    }

    pub fn search_approximate(
        &self,
        query: Vector<N>,
        top_k: i32,
    ) -> Vec<(i32, f32)> {
        let candidates = DashSet::new();
        self.trees.par_iter().for_each(|tree| {
            Self::tree_result(query, top_k, tree, &candidates);
        });
        candidates
            .into_iter()
            .map(|idx| (idx, self.values[idx].sq_euc_dis(&query)))
            .sorted_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
            .take(top_k as usize)
            .map(|(idx, dis)| (self.ids[idx], dis))
            .collect()
    }
}

--------------------------------------------------------------------------------
/src/main.rs:
--------------------------------------------------------------------------------
use rand::prelude::IteratorRandom;
use rayon::prelude::*;
use std::io::BufRead;
mod search;
use clap::{App, Arg};
use json;
use search::{ANNIndex, Vector};
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::Write;

fn search_exhaustive<const N: usize>(
    all_data: &Vec<Vector<N>>,
    vector: &Vector<N>,
    top_k: i32,
) -> HashSet<i32> {
    let enumerated_iter = all_data.iter().enumerate();
    let mut idx_sq_euc_dis: Vec<(usize, f32)> = enumerated_iter
        .map(|(i, can)| (i, can.sq_euc_dis(vector)))
        .collect();
    idx_sq_euc_dis.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
    // Return a set of IDs corresponding to the closest matches
    let mut final_candidates = std::collections::HashSet::new();
    for i in 0..top_k as usize {
        final_candidates.insert(idx_sq_euc_dis[i].0 as i32);
    }
    return final_candidates;
}

fn load_raw_wiki_data<const N: usize>(
    filename: &str,
    all_data: &mut Vec<Vector<N>>,
    word_to_idx_mapping: &mut HashMap<String, usize>,
    idx_to_word_mapping: &mut HashMap<usize, String>,
) {
    // wiki-news has 999,994 vectors in 300 dimensions
    let file = std::fs::File::open(filename).expect("could not read file");
    let reader = std::io::BufReader::new(file);
    let mut cur_idx: usize = 0;
    // We skip the first line that simply has metadata
    for maybe_line in reader.lines().skip(1) {
        let line = maybe_line.expect("Should decode the line");
        let mut data_on_line_iter = line.split_whitespace();
        let word = data_on_line_iter
            .next()
            .expect("Each line begins with a word");
        // Update the mappings
        word_to_idx_mapping.insert(word.to_owned(), cur_idx);
        idx_to_word_mapping.insert(cur_idx, word.to_owned());
        cur_idx += 1;
        // Parse the vector. Everything except the word on the line is the vector
        let embedding: [f32; N] = data_on_line_iter
            .map(|s| s.parse::<f32>().unwrap())
            .collect::<Vec<_>>()
            .try_into()
            .unwrap();
        all_data.push(search::Vector(embedding));
    }
}

fn search_approximate_as_hashset<const N: usize>(
    index: &ANNIndex<N>,
    vector: Vector<N>,
    top_k: i32,
) -> HashSet<i32> {
    let nearby_idx_and_distance = index.search_approximate(vector, top_k);
    let mut id_hashset = std::collections::HashSet::new();
    for &(idx, _) in nearby_idx_and_distance.iter() {
        id_hashset.insert(idx);
    }
    return id_hashset;
}

fn build_and_benchmark_index<const N: usize>(
    my_input_data: &Vec<Vector<N>>,
    num_trees: i32,
    max_node_size: i32,
    top_k: i32,
    words_to_visualize: &Vec<String>,
    word_to_idx_mapping: &HashMap<String, usize>,
    idx_to_word_mapping: &HashMap<usize, String>,
    sample_idx: Option<&Vec<i32>>,
) -> Vec<HashSet<i32>> {
    println!(
        "dimensions={}, num_trees={}, max_node_size={}, top_k={}",
        N, num_trees, max_node_size, top_k
    );
    // Build the index
    let start = std::time::Instant::now();
    let my_ids: Vec<i32> = (0..my_input_data.len() as i32).collect();
    let index = ANNIndex::<N>::build_index(
        num_trees,
        max_node_size,
        &my_input_data,
        &my_ids,
    );
    let duration = start.elapsed();
    println!("Built ANN index in {}-D in {:?}", N, duration);
    // Benchmark it with 1000 sequential queries
    let benchmark_idx: Vec<i32> = (0..my_input_data.len() as i32)
        .choose_multiple(&mut rand::thread_rng(), 1000);
    let mut search_vectors: Vec<Vector<N>> = Vec::new();
    for idx in benchmark_idx {
        search_vectors.push(my_input_data[idx as usize]);
    }
    let start = std::time::Instant::now();
    for i in 0..1000 {
        index.search_approximate(search_vectors[i], top_k);
    }
    let duration = start.elapsed() / 1000;
    println!("Bulk ANN-search in {}-D has average time {:?}", N, duration);
    // Visualize some words
    for word in words_to_visualize.iter() {
        println!("Currently visualizing {}", word);
        let word_index = word_to_idx_mapping[word];
        let embedding = my_input_data[word_index];
        let nearby_idx_and_distance =
            index.search_approximate(embedding, top_k);
        for &(idx, distance) in nearby_idx_and_distance.iter() {
            println!(
                "{}, distance={}",
                idx_to_word_mapping[&(idx as usize)],
                distance.sqrt()
            );
        }
    }
    // If [sample_idx] provided, only find the top_k neighbours for those
    // and return that data. Otherwise, find it for all vectors in the
    // corpus. When benchmarking other hyper-parameters, we use a smaller
    // [sample_idx] set to control run-time and get efficient estimates
    // of performance metrics.
    let start = std::time::Instant::now();
    let mut subset: Vec<Vector<N>> = Vec::new();
    let sample_from_my_data = match sample_idx {
        Some(sample_indices) => {
            for &idx in sample_indices {
                subset.push(my_input_data[idx as usize]);
            }
            &subset
        }
        None => my_input_data,
    };
    let index_results: Vec<HashSet<i32>> = sample_from_my_data
        .par_iter()
        .map(|&vector| search_approximate_as_hashset(&index, vector, top_k))
        .collect();
    let duration = start.elapsed();
    println!(
        "Collected {} quality results in {:?}",
        index_results.len(),
        duration
    );
    return index_results;
}

fn analyze_average_euclidean_metrics<const N: usize>(
    json_path: &String,
    all_embedding_data: &Vec<Vector<N>>,
    index_results: &Vec<HashSet<i32>>,
    sample_idx: Option<&Vec<i32>>,
) {
    let start = std::time::Instant::now();
    let mut subset: Vec<Vector<N>> = Vec::new();
    // If [sample_idx] is provided, that means len(sample_idx) ==
    // len(index_results) and index_results[i] are the nearest neighbours
    // for all_embedding_data[sample_idx[i]].
    // If not provided, it is assumed that index_results are available
    // for all vectors in our corpus.
    let sample_from_my_data = match sample_idx {
        Some(sample_indices) => {
            assert!(sample_indices.len() == index_results.len());
            for &idx in sample_indices {
                subset.push(all_embedding_data[idx as usize]);
            }
            &subset
        }
        None => all_embedding_data,
    };
    let mut euc_distances: Vec<f32> = Vec::new();
    for (i, neighbours) in index_results.iter().enumerate() {
        // Ignore the distance here since we must compute it against the full 300-D embedding
        let mut sum_dist_to_neighbours = 0.0;
        for &neighbour_id in neighbours.iter() {
            sum_dist_to_neighbours += sample_from_my_data[i]
                .sq_euc_dis(&all_embedding_data[neighbour_id as usize])
                .sqrt();
        }
        let avg_dist_to_neighbours =
            sum_dist_to_neighbours / neighbours.len() as f32;
        euc_distances.push(avg_dist_to_neighbours);
    }
    let mean_euc: f32 =
        euc_distances.iter().sum::<f32>() / euc_distances.len() as f32;
    println!("Average Euclidean {} in {:?}", mean_euc, start.elapsed());
    let json_string = json::stringify(euc_distances);
    // Write the JSON to a file
    let mut file = File::create(json_path).expect("Failed to create file");
    file.write_all(json_string.as_bytes())
        .expect("Failed to write JSON to file");
}

fn analyze_recall_metrics<const N: usize>(
    exhaustive_results: &Vec<HashSet<i32>>,
    reduced_index_results: &Vec<HashSet<i32>>,
) {
    let start = std::time::Instant::now();
    // The exhaustive search and index results should line up for us
    // to compare the recall.
    assert!(reduced_index_results.len() == exhaustive_results.len());
    let mut total_recall_pct = 0.0;
    for (i, true_neighbours) in exhaustive_results.iter().enumerate() {
        let mut num_matches_with_brute_results = 0.0;
        for approx_id in reduced_index_results[i].iter() {
            if true_neighbours.contains(approx_id) {
                num_matches_with_brute_results += 1.0;
            }
        }
        total_recall_pct +=
            num_matches_with_brute_results / true_neighbours.len() as f32;
    }
    let average_recall = total_recall_pct / exhaustive_results.len() as f32;
    let duration = start.elapsed();
    println!("Average Recall% = {} in {:?}", average_recall, duration);
}

fn main() {
    // Parse command line arguments
    let data_dir_arg =
        Arg::with_name("data-dir").takes_value(true).required(true);
    let input_vec_arg =
        Arg::with_name("input-vec").takes_value(true).required(true);
    let app = App::new("FANN").arg(data_dir_arg).arg(input_vec_arg);
    let matches = app.get_matches();
    let data_dir = matches.value_of("data-dir").unwrap();
    let input_vec_path = matches.value_of("input-vec").unwrap();
    // Parse the data from wiki-news
    const DIM: usize = 300;
    const TOP_K: i32 = 20;
    let start = std::time::Instant::now();
    let mut my_input_data: Vec<Vector<DIM>> = Vec::new();
    let mut word_to_idx_mapping: HashMap<String, usize> = HashMap::new();
    let mut idx_to_word_mapping: HashMap<usize, String> = HashMap::new();
    load_raw_wiki_data::<DIM>(
        input_vec_path,
        &mut my_input_data,
        &mut word_to_idx_mapping,
        &mut idx_to_word_mapping,
    );
    let duration = start.elapsed();
    println!("Parsed {} vectors in {:?}", my_input_data.len(), duration);
    // Try the naive exact-search for TOP_K elements
    let start = std::time::Instant::now();
    search_exhaustive::<DIM>(&my_input_data, &my_input_data[0], TOP_K);
    let duration = start.elapsed();
    println!("Found vectors via brute-search in {:?}", duration);
    // Take 1000 random vectors from the input data and find their TOP_K nearest
    // neighbors using the exhaustive/brute-force approach. This allows us to
    // calculate recall for our implementations. This is a list of randomly chosen
    // indices from our main embedding set - we use a subset to
    // estimate our metrics due to the computation cost.
    let start = std::time::Instant::now();
    rayon::ThreadPoolBuilder::new()
        .num_threads(5)
        .build_global()
        .unwrap();
    let sample_idx: Vec<i32> = (0..my_input_data.len() as i32)
        .choose_multiple(&mut rand::thread_rng(), 1000);
    // Make a vector of hashsets where the hashset at position i holds the exhaustive
    // nearest neighbours for the embedding at position idx in my_input_data, where idx
    // is at position i in sample_idx. This means [exhaustive_results] and [sample_idx]
    // are perfectly aligned
    let exhaustive_results: Vec<HashSet<i32>> = sample_idx
        .par_iter()
        .map(|&idx| {
            search_exhaustive::<DIM>(
                &my_input_data,
                &my_input_data[idx as usize],
                TOP_K,
            )
        })
        .collect();
    let duration = start.elapsed();
    println!("Found exhaustive neighbors for sample in {:?}", duration);
    // Build, benchmark and visualize our index with default parameters
    let words_to_visualize: Vec<String> = ["river", "war", "love", "education"]
        .into_iter()
        .map(|x| x.to_owned())
        .collect();
    let index_results = build_and_benchmark_index::<DIM>(
        &my_input_data,
        3,
        15,
        TOP_K,
        &words_to_visualize,
        &word_to_idx_mapping,
        &idx_to_word_mapping,
        None,
    );
    let path =
        format!("{}/trees_{}_max_node_{}_k_{}.json", data_dir, 3, 15, TOP_K);
    analyze_average_euclidean_metrics::<DIM>(
        &path,
        &my_input_data,
        &index_results,
        None,
    );
    // We only have exhaustive data for [sample_idx]
    let reduced_index_results: Vec<HashSet<i32>> = sample_idx
        .iter()
        .map(|&idx| index_results[idx as usize].clone())
        .collect();
    analyze_recall_metrics::<DIM>(&exhaustive_results, &reduced_index_results);
    // Try some other hyper-parameters: new values for max_node_size and num_trees
    // at dim=300, to see how they change the index's accuracy/quality.
    let no_words: Vec<String> = Vec::new();
    for max_size in [5, 15, 30] {
        for num_trees in [3, 9, 15] {
            // Skip num_trees=3 and max_size=15 since that was our default config
            // we already saw above.
            if (num_trees == 3) && (max_size == 15) {
                continue;
            }
            let index_results = build_and_benchmark_index::<DIM>(
                &my_input_data,
                num_trees,
                max_size,
                TOP_K,
                &no_words,
                &word_to_idx_mapping,
                &idx_to_word_mapping,
                Some(&sample_idx),
            );
            let path = format!(
                "{}/trees_{}_max_node_{}_k_{}.json",
                data_dir, num_trees, max_size, TOP_K
            );
            analyze_average_euclidean_metrics::<DIM>(
                &path,
                &my_input_data,
                &index_results,
                Some(&sample_idx),
            );
            // We only have exhaustive data for [sample_idx]
            analyze_recall_metrics::<DIM>(&exhaustive_results, &index_results);
        }
    }
}

--------------------------------------------------------------------------------
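Usage note (added for illustration, not a file in the repository): a minimal sketch of how the index in src/search.rs could be driven if it were consumed as a library module. Only Vector, ANNIndex::build_index, and search_approximate come from src/search.rs; the 2-D toy vectors, the ids, and the demo function below are made up for this example.

// Hypothetical caller; assumes src/search.rs is reachable as a `search` module.
use search::{ANNIndex, Vector};

fn demo() {
    // Three toy vectors in 2 dimensions, with arbitrary i32 ids.
    let vecs = vec![Vector([1.0, 0.0]), Vector([0.0, 1.0]), Vector([0.9, 0.1])];
    let ids = vec![10, 20, 30];
    // Build a forest of 3 trees, splitting any node holding more than 2 vectors.
    let index = ANNIndex::<2>::build_index(3, 2, &vecs, &ids);
    // Returns up to top_k (id, squared Euclidean distance) pairs, closest first.
    let neighbours = index.search_approximate(Vector([1.0, 0.1]), 2);
    for (id, dist) in neighbours {
        // Take the square root to report the actual Euclidean distance.
        println!("id={} distance={}", id, dist.sqrt());
    }
}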