├── serve_goko ├── Readme.md ├── src │ ├── http │ │ ├── mod.rs │ │ ├── maker.rs │ │ └── message.rs │ ├── lib.rs │ ├── api │ │ ├── path.rs │ │ ├── parameters.rs │ │ ├── knn.rs │ │ ├── tracker.rs │ │ └── mod.rs │ ├── core │ │ ├── mod.rs │ │ ├── tracker_worker.rs │ │ └── internal_service.rs │ ├── parsers │ │ ├── msgpack_dense.rs │ │ └── mod.rs │ └── errors.rs ├── Cargo.toml └── examples │ ├── mnist_server.rs │ └── ember_server.rs ├── NOTICE.txt ├── pygoko ├── MANIFEST.in ├── requirements-dev.txt ├── pygoko │ └── __init__.py ├── pyproject.toml ├── tests │ ├── test_one_d_viz.py │ └── two_d_viz.py ├── Cargo.toml ├── src │ ├── lib.rs │ ├── plugins.rs │ ├── node.rs │ └── layer.rs └── setup.py ├── .travis.yml ├── .gitignore ├── goko ├── src │ ├── covertree │ │ ├── mod.rs │ │ └── query_tools │ │ │ ├── mod.rs │ │ │ ├── query_items.rs │ │ │ └── trace_query_heap.rs │ ├── monomap │ │ ├── README.md │ │ └── inner.rs │ ├── plugins │ │ ├── discrete │ │ │ └── mod.rs │ │ ├── gaussians │ │ │ ├── mod.rs │ │ │ └── svd_gaussian.rs │ │ ├── utils.rs │ │ ├── labels.rs │ │ └── mod.rs │ ├── lib.rs │ ├── errors.rs │ ├── utils.rs │ └── query_interface │ │ └── mod.rs ├── protos │ └── tree_file_format.proto ├── build.rs ├── Cargo.toml ├── benches │ └── path_bench.rs ├── README.md └── examples │ ├── mnist.rs │ ├── ember_sequence_track.rs │ └── ember_drop.rs ├── Cargo.toml ├── pointcloud ├── src │ ├── metrics │ │ ├── mod.rs │ │ ├── l1_f32.rs │ │ ├── l2_f32.rs │ │ ├── l1_misc.rs │ │ └── l2_misc.rs │ ├── data_sources │ │ ├── mod.rs │ │ └── sparse_ram.rs │ ├── lib.rs │ ├── loaders │ │ ├── mod.rs │ │ └── csv_loaders.rs │ └── summaries │ │ └── mod.rs ├── README.md └── Cargo.toml ├── .flake8 ├── README.md ├── data └── setup_data.py └── examples ├── graphistry_vis.ipynb ├── mnist_knn.ipynb └── ember_chronological_drift.py /serve_goko/Readme.md: -------------------------------------------------------------------------------- 1 | hi -------------------------------------------------------------------------------- /NOTICE.txt: -------------------------------------------------------------------------------- 1 | Goko 2 | Copyright 2012-2018 Elasticsearch B.V. 
-------------------------------------------------------------------------------- /pygoko/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include pyproject.toml Cargo.toml 2 | recursive-include src * -------------------------------------------------------------------------------- /pygoko/requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pip>=19.1 2 | pytest>=3.5.0 3 | setuptools-rust>=0.11.5 4 | pytest-benchmark>=3.1.1 -------------------------------------------------------------------------------- /serve_goko/src/http/mod.rs: -------------------------------------------------------------------------------- 1 | mod maker; 2 | mod message; 3 | mod service; 4 | 5 | pub use service::GokoHttp; 6 | pub use message::ResponseFuture; 7 | pub use maker::MakeGokoHttp; -------------------------------------------------------------------------------- /pygoko/pygoko/__init__.py: -------------------------------------------------------------------------------- 1 | from .pygoko import CoverTree, PyBayesCategoricalTracker, PyKLDivergenceBaseline 2 | 3 | __all__ = ["CoverTree", "PyBayesCategoricalTracker", "PyKLDivergenceBaseline"] 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | jobs: 2 | include: 3 | - language: rust 4 | rust: 5 | - nightly 6 | fast_finish: true 7 | script: 8 | - cargo build --verbose 9 | - cargo test --verbose -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | Cargo.lock 3 | data 4 | models 5 | build 6 | dist 7 | *.egg-info 8 | *.png 9 | __pycache__ 10 | add_license.py 11 | .DS_store 12 | .vscode 13 | .ipynb_checkpoints 14 | graphistry_creds.json 15 | -------------------------------------------------------------------------------- /goko/src/covertree/mod.rs: -------------------------------------------------------------------------------- 1 | pub(crate) mod builders; 2 | pub(crate) mod data_caches; 3 | pub mod layer; 4 | pub mod node; 5 | pub mod query_tools; 6 | 7 | mod tree; 8 | 9 | pub use builders::CoverTreeBuilder; 10 | pub use tree::*; 11 | -------------------------------------------------------------------------------- /pygoko/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pygoko" 3 | version = "0.1.0" 4 | description = "A fast covertree implementation" 5 | authors = [] 6 | 7 | [build-system] 8 | requires = ["setuptools>=41.0.0", "wheel", "setuptools_rust>=0.10.2", "toml"] 9 | build-backend = "setuptools.build_meta" 10 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | 3 | members = [ 4 | "goko", 5 | "pygoko", 6 | "pointcloud", 7 | "serve_goko", 8 | ] 9 | default-members = [ 10 | "goko", 11 | "pointcloud", 12 | "serve_goko", 13 | ] 14 | 15 | [profile.dev] 16 | opt-level = 0 17 | 18 | [profile.release] 19 | opt-level = 3 -------------------------------------------------------------------------------- /pointcloud/src/metrics/mod.rs: -------------------------------------------------------------------------------- 1 | //! Metrics. 
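//!
//! `L2` and `L1` below are zero-sized marker types rather than norms computed in
//! this file directly: a metric is selected at compile time by passing one of them
//! as a type parameter, e.g. the `DefaultLabeledCloud::<L2>` alias in `lib.rs` or
//! the `CoverTreeWriter<DefaultLabeledCloud<L2>>` trees built by the examples.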
2 | 3 | pub mod l2_misc; 4 | pub use l2_misc::*; 5 | pub mod l1_misc; 6 | pub use l1_misc::*; 7 | pub mod l2_f32; 8 | pub use l2_f32::*; 9 | pub mod l1_f32; 10 | pub use l1_f32::*; 11 | 12 | #[derive(Debug)] 13 | /// L2 distance trait. 14 | pub struct L2 {} 15 | /// L1 distance trait 16 | pub struct L1 {} 17 | -------------------------------------------------------------------------------- /goko/src/monomap/README.md: -------------------------------------------------------------------------------- 1 | # Jon Gjengset's evmap 2 | 3 | This is a modification to Jon Gjengset's evmap to allow edits. It duplicates the data, which isn't ideal but is still damn fast and there are no locks. 4 | 5 | This is here as I feel odd forking his library and uploading a new one on crates.io. This is mostly not my code, I just simplified Jon Gjengset's code. -------------------------------------------------------------------------------- /serve_goko/src/lib.rs: -------------------------------------------------------------------------------- 1 | //#![deny(warnings)] 2 | 3 | //! # A server for Goko 4 | //! 5 | //! 6 | //! See [`GokoRequest`] for documentation of how to query the HTTP server. 7 | //mod client; 8 | //pub use client::*; 9 | pub mod parsers; 10 | pub mod errors; 11 | 12 | pub mod api; 13 | pub use api::GokoRequest; 14 | pub use api::GokoResponse; 15 | pub use parsers::PointParser; 16 | 17 | pub mod http; 18 | pub mod core; -------------------------------------------------------------------------------- /pointcloud/README.md: -------------------------------------------------------------------------------- 1 | # Point Cloud 2 | 3 | A dataset access layer that allows for metadata to be attached to points. Used for `goko`. Currently this accelerates distance calculations with a set of `packed_simd` accelerated norms and a `rayon` threadpool while abstracting the access of the datapoints across multiple data files. It's structured in such a way that adding formats should be easy. 4 | 5 | ## Planned Features 6 | 7 | #### Current work 8 | * Benchmarks. 9 | * PCA, & Gaussian calculators. 10 | 11 | #### Near Future 12 | * Cleanup of the metadata feature in `pointcloud` 13 | * Sparse accessors and sparse databacking 14 | 15 | #### Future 16 | * Network interface for distributed datasets. 17 | * Image file abstraction for applications like imagenet. 18 | * Asynchronous access for the network and file accessors. 
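
#### Quick sketch

A minimal, hypothetical example of the YAML-driven loading used by the `goko` benches and the `serve_goko` examples elsewhere in this repository. `labeled_ram_from_yaml`, the `DefaultLabeledCloud` alias, and the `point` accessor are assumed from this crate's `loaders`, `lib.rs`, and `base_traits`; the YAML path is illustrative:

```rust
use pointcloud::*;
use pointcloud::loaders::labeled_ram_from_yaml;

fn main() {
    // The annotation picks the metric (L2) and label types, mirroring
    // serve_goko/examples/ember_server.rs elsewhere in this repository.
    let cloud: DefaultLabeledCloud<L2> =
        labeled_ram_from_yaml("../data/mnist_complex.yml").unwrap();
    // Points are addressed by index, exactly as goko's cover trees see them.
    let _first = cloud.point(0).unwrap();
}
```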
--------------------------------------------------------------------------------
/pygoko/tests/test_one_d_viz.py:
--------------------------------------------------------------------------------
1 | import pygoko
2 |
3 | import numpy as np
4 | from one_d_viz import show1D
5 |
6 |
7 | data = np.array([[0.499], [0.48], [-0.49], [0.0]], dtype=np.float32)
8 |
9 |
10 | tree = pygoko.CoverTree()
11 | tree.set_scale_base(2)
12 | tree.set_leaf_cutoff(0)
13 | tree.fit(data)
14 |
15 | print(tree.knn(tree.data_point(0), 5))
16 |
17 | print("============= KL Divergence =============")
18 | prior_weight = 1.0
19 | observation_weight = 1.0
20 | window_size = 3
21 | sequence_len = 10
22 | sample_rate = 2
23 | sequence_count = 1
24 | baseline = tree.kl_div_dirichlet_baseline(
25 |     prior_weight,
26 |     observation_weight,
27 |     sequence_len,
28 |     sequence_count,
29 |     window_size,
30 |     sample_rate,
31 | )
32 | for i in range(0, sequence_len, sample_rate):
33 |     print(baseline.stats(i))
34 |
35 | show1D(tree, data)
--------------------------------------------------------------------------------
/serve_goko/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "serve_goko"
3 | version = "0.1.0"
4 | authors = ["sven "]
5 | edition = "2018"
6 |
7 | [dependencies]
8 | pin-project = "1.0"
9 | futures-util = { version = "0.3", features = [ "sink" ] }
10 | log = "0.4"
11 | env_logger = "0.7.1"
12 | http = "0.2.3"
13 | warp = "0.3"
14 | bytes = "1.0.1"
15 | hyper = { version = "0.14", features = ["full"] }
16 | tower = { version = "0.4.4", features = ["make", "load", "balance", "util"] }
17 | tokio = { version = "1.1.1", features = ["full"] }
18 | goko = { path = "../goko" }
19 | pointcloud = { path = "../pointcloud" }
20 | serde_json = "1.0.61"
21 | serde = { version = "1.0.123", features = ["derive"] }
22 | indexmap = {version = "1.0.2", features = ["serde-1"]}
23 | rayon = "*"
24 | futures = "0.3.12"
25 | flate2 = "1.0.20"
26 | lazy_static = "*"
27 | rmp-serde = "0.15"
28 | regex = "1.4.3"
29 | base64 = "*"
--------------------------------------------------------------------------------
/pygoko/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "pygoko"
3 | version = "0.4.1"
4 | edition = "2018"
5 |
6 | description = "A Python interface for goko"
7 | readme = "../README.md"
8 |
9 | authors = ["Sven Cattell "]
10 |
11 | documentation = "https://docs.rs/goko"
12 | homepage = "https://github.com/elastic/goko"
13 | repository = "https://github.com/elastic/goko.git"
14 |
15 | license = "Apache-2.0"
16 |
17 | [badges]
18 | travis-ci = { repository = "https://github.com/elastic/goko.git", branch = "master" }
19 |
20 | [toolchain]
21 | channel = "nightly"
22 |
23 | [dependencies]
24 | goko = { path = "../goko" }
25 | pointcloud = { path = "../pointcloud" }
26 | pyo3 = { version = "0.12.4", features = ["extension-module"] }
27 | numpy = "0.12.1"
28 | ndarray = "0.14.0"
29 | rayon = "1.4.0"
30 | rustc-hash = "1.1.0"
31 | rand = { version = "0.7.3", features = ["small_rng"] }
32 |
33 | [lib]
34 | name = "pygoko"
35 | crate-type = ["cdylib"]
--------------------------------------------------------------------------------
/goko/src/plugins/discrete/mod.rs:
--------------------------------------------------------------------------------
1 | //! # Dirichlet probability
2 | //!
3 | //! We know that the users are querying based on what they want to know about.
4 | //! This has some geometric structure, especially for attackers. We have some
5 | //! prior knowledge about the queries: they should function similarly to the training set.
6 | //! Sections of data that are highly populated should have a higher likelihood of being
7 | //! queried.
8 | //!
9 | //! This plugin lets us simulate the unknown distribution of a user's queries in a
10 | //! Bayesian way. There may be more applications of this idea, but defending against
11 | //! attackers is the one that has been demonstrated.
12 |
13 | pub mod baseline;
14 | pub mod categorical;
15 | pub mod dirichlet;
16 | pub mod tracker;
17 |
18 | #[allow(unused_imports)]
19 | pub mod prelude {
20 |     //! Easy way of importing everything
21 |     pub use super::baseline::*;
22 |     pub use super::categorical::*;
23 |     pub use super::dirichlet::*;
24 |     pub use super::tracker::*;
25 | }
26 |
--------------------------------------------------------------------------------
/goko/protos/tree_file_format.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package CoverTree;
4 |
5 | message NodeProto {
6 |     uint64 coverage_count = 1;
7 |     uint64 center_index = 2;
8 |     string name = 3;
9 |     int32 scale_index = 4;
10 |     uint64 parent_center_index = 5;
11 |     int32 parent_scale_index = 6;
12 |
13 |     bool is_leaf = 7;
14 |
15 |     repeated uint64 children_point_indexes = 8;
16 |     repeated int32 children_scale_indexes = 9;
17 |     int32 nested_scale_index = 10;
18 |
19 |     repeated uint64 outlier_point_indexes = 11;
20 |     string outlier_summary_json = 12;
21 |     float radius = 13;
22 | }
23 |
24 | message LayerProto {
25 |     int32 scale_index = 1;
26 |     repeated NodeProto nodes = 2;
27 | }
28 |
29 | message CoreProto {
30 |     bool use_singletons = 1;
31 |     float scale_base = 2;
32 |     uint64 cutoff = 3;
33 |     sint32 resolution = 4;
34 |     string partition_type = 5;
35 |
36 |     uint64 dim = 7;
37 |     uint64 count = 8;
38 |
39 |     int32 root_scale = 9;
40 |     uint64 root_index = 10;
41 |
42 |     repeated LayerProto layers = 11;
43 |     map name_map = 12;
44 | }
--------------------------------------------------------------------------------
/pointcloud/src/data_sources/mod.rs:
--------------------------------------------------------------------------------
1 | /*
2 |  * Licensed to Elasticsearch B.V. under one or more contributor
3 |  * license agreements. See the NOTICE file distributed with
4 |  * this work for additional information regarding copyright
5 |  * ownership. Elasticsearch B.V. licenses this file to you under
6 |  * the Apache License, Version 2.0 (the "License"); you may
7 |  * not use this file except in compliance with the License.
8 |  * You may obtain a copy of the License at
9 |  *
10 |  * http://www.apache.org/licenses/LICENSE-2.0
11 |  *
12 |  * Unless required by applicable law or agreed to in writing,
13 |  * software distributed under the License is distributed on an
14 |  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 |  * KIND, either express or implied. See the License for the
16 |  * specific language governing permissions and limitations
17 |  * under the License.
18 |  */
19 |
20 | //! Some data sources and a trait for dimensioning and uniformly referencing the data they contain.
21 | //! The only sources currently supported are memmaps and RAM blobs.
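//!
//! A minimal sketch of the RAM-backed path; the metric type parameter and the
//! `DataRam::new` signature are assumed from this crate's `lib.rs`:
//!
//! ```no_run
//! use pointcloud::data_sources::DataRam;
//! use pointcloud::L2;
//!
//! // Three points in two dimensions, flattened row-major.
//! let ram: DataRam<L2> = DataRam::new(vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0], 2).unwrap();
//! ```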
22 |
23 | mod memmap_ram;
24 | mod sparse_ram;
25 |
26 | #[allow(dead_code)]
27 | mod memmapf32;
28 |
29 | #[doc(hidden)]
30 | pub use memmap_ram::*;
--------------------------------------------------------------------------------
/pygoko/src/lib.rs:
--------------------------------------------------------------------------------
1 | /*
2 |  * Licensed to Elasticsearch B.V. under one or more contributor
3 |  * license agreements. See the NOTICE file distributed with
4 |  * this work for additional information regarding copyright
5 |  * ownership. Elasticsearch B.V. licenses this file to you under
6 |  * the Apache License, Version 2.0 (the "License"); you may
7 |  * not use this file except in compliance with the License.
8 |  * You may obtain a copy of the License at
9 |  *
10 |  * http://www.apache.org/licenses/LICENSE-2.0
11 |  *
12 |  * Unless required by applicable law or agreed to in writing,
13 |  * software distributed under the License is distributed on an
14 |  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 |  * KIND, either express or implied. See the License for the
16 |  * specific language governing permissions and limitations
17 |  * under the License.
18 |  */
19 |
20 | use pyo3::prelude::*;
21 |
22 | pub mod layer;
23 | pub mod node;
24 | pub mod plugins;
25 | pub mod tree;
26 |
27 | use plugins::*;
28 | use tree::CoverTree;
29 |
30 | #[pymodule]
31 | fn pygoko(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
32 |     m.add_class::<CoverTree>()?;
33 |     m.add_class::<PyBayesCategoricalTracker>()?;
34 |     m.add_class::<PyKLDivergenceBaseline>()?;
35 |     Ok(())
36 | }
--------------------------------------------------------------------------------
/serve_goko/examples/mnist_server.rs:
--------------------------------------------------------------------------------
1 | use goko::utils::*;
2 | use goko::CoverTreeWriter;
3 | use pointcloud::*;
4 | use std::path::Path;
5 | extern crate serve_goko;
6 | use serve_goko::parsers::MsgPackDense;
7 | use serve_goko::http::*;
8 | use serve_goko::core::*;
9 | use std::sync::Arc;
10 | use goko::plugins::discrete::prelude::GokoDirichlet;
11 | use hyper::Server;
12 | use log::LevelFilter;
13 | use env_logger::Builder;
14 |
15 | fn build_tree() -> CoverTreeWriter<DefaultLabeledCloud<L2>> {
16 |     let file_name = "../data/mnist_complex.yml";
17 |     let path = Path::new(file_name);
18 |     if !path.exists() {
19 |         panic!("{} does not exist", file_name);
20 |     }
21 |
22 |     cover_tree_from_labeled_yaml(&path).unwrap()
23 | }
24 |
25 | #[tokio::main(worker_threads = 12)]
26 | pub async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
27 |     let mut builder = Builder::new();
28 |
29 |     builder.filter_level(LevelFilter::Info).init();
30 |     let mut ct_writer = build_tree();
31 |     ct_writer.add_plugin::<GokoDirichlet>(GokoDirichlet {});
32 |     ct_writer.generate_summaries();
33 |     let goko_server = MakeGokoHttp::<_,MsgPackDense>::new(Arc::new(CoreWriter::new(ct_writer)));
34 |
35 |     let addr = ([127, 0, 0, 1], 3031).into();
36 |
37 |     let server = Server::bind(&addr).serve(goko_server);
38 |
39 |     println!("Listening on http://{}", addr);
40 |
41 |     server.await?;
42 |
43 |     Ok(())
44 | }
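A quick smoke test against the server above, as a sketch: `GET /` maps to the tree-parameters request defined in `serve_goko/src/api/parameters.rs` later in this dump, and `hyper`/`tokio` come from serve_goko's own dependency list. The address mirrors mnist_server.rs:

```rust
use hyper::{body, Client, Uri};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let client = Client::new();
    // mnist_server.rs above listens on 127.0.0.1:3031.
    let resp = client.get(Uri::from_static("http://127.0.0.1:3031/")).await?;
    let bytes = body::to_bytes(resp.into_body()).await?;
    println!("{}", String::from_utf8_lossy(&bytes));
    Ok(())
}
```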
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 |
4 | # W503 is for "W503 line break before binary operator". Black sometimes adds this, so it is ignored for compatibility.
5 | # D401 is for "D401 First line should be in imperative mood". Has false positives, and doesn't like titles.
6 | # D400 is for "D400 First line should end with a period".
7 | # E203 is for "E203 Whitespace before ':'". Black is incompatible with this warning.
8 | # ANN101 conflicts with black.
9 | ignore = D401, W503, D400, E203, ANN101
10 |
11 | # ANN101 is for "ANN101 Missing type annotation for self in method", which is very common in testing and annotation is not needed
12 | # ANN102 is for "ANN102 Missing type annotation for cls in classmethod", which is very common in testing and annotation is not needed
13 | # ANN201 is for "ANN201 Missing return type annotation for public function", which is very common in testing and annotation is not needed
14 | # ANN206 is for "ANN206 Missing return type annotation for classmethod", which is very common in testing and annotation is not needed
15 | # D101 is for "D101 Missing docstring in public class"; we have a proliferation of classes due to the way we set up standard tests
16 | # and documenting them all would be pointless
17 | per-file-ignores =
18 |     tests/*: D101, ANN101, ANN102, ANN201, ANN206
19 |     libraries/python/catapult/*: W291, E231, W293
20 |     libraries/python/catapult/tests/*: D100, D101, D102, D103, D104, D200
21 |     services/nametag/*: E402
--------------------------------------------------------------------------------
/goko/build.rs:
--------------------------------------------------------------------------------
1 | /*
2 |  * Licensed to Elasticsearch B.V. under one or more contributor
3 |  * license agreements. See the NOTICE file distributed with
4 |  * this work for additional information regarding copyright
5 |  * ownership. Elasticsearch B.V. licenses this file to you under
6 |  * the Apache License, Version 2.0 (the "License"); you may
7 |  * not use this file except in compliance with the License.
8 |  * You may obtain a copy of the License at
9 |  *
10 |  * http://www.apache.org/licenses/LICENSE-2.0
11 |  *
12 |  * Unless required by applicable law or agreed to in writing,
13 |  * software distributed under the License is distributed on an
14 |  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 |  * KIND, either express or implied. See the License for the
16 |  * specific language governing permissions and limitations
17 |  * under the License.
18 | */ 19 | extern crate protoc_rust; 20 | 21 | #[cfg(not(feature = "docs-only"))] 22 | use std::env; 23 | 24 | #[cfg(not(feature = "docs-only"))] 25 | fn main() { 26 | if env::var("TRAVIS_RUST_VERSION").is_err() { 27 | protoc_rust::Codegen::new() 28 | .out_dir("src") 29 | .include("protos") 30 | .input("protos/tree_file_format.proto") 31 | .run() 32 | .expect("protoc"); 33 | } 34 | println!("cargo:rerun-if-changed=protos/tree_file_format.proto"); 35 | } 36 | 37 | #[cfg(feature = "docs-only")] 38 | fn main() { 39 | println!("NOT Building proto"); 40 | } 41 | -------------------------------------------------------------------------------- /serve_goko/src/api/path.rs: -------------------------------------------------------------------------------- 1 | use pointcloud::*; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | use std::ops::Deref; 5 | 6 | use goko::errors::GokoError; 7 | use crate::core::*; 8 | use super::NodeDistance; 9 | 10 | /// Response: [`PathResponse`] 11 | #[derive(Deserialize, Serialize)] 12 | pub struct PathRequest { 13 | pub point: T, 14 | } 15 | 16 | /// Request: [`PathRequest`] 17 | #[derive(Deserialize, Serialize)] 18 | pub struct PathResponse { 19 | pub path: Vec>, 20 | } 21 | 22 | impl PathRequest { 23 | pub fn process(self, reader: &mut CoreReader) -> Result, GokoError> 24 | where 25 | D: PointCloud, 26 | T: Deref + Send + Sync, 27 | { 28 | let knn = reader.tree.path(&self.point)?; 29 | let pc = &reader.tree.parameters().point_cloud; 30 | 31 | let resp: Result>, GokoError> = knn 32 | .iter() 33 | .map(|(distance, (layer, pi))| { 34 | let label_summary = reader.tree.get_node_label_summary((*layer, *pi)).map(|s| (*s).clone()); 35 | Ok(NodeDistance { 36 | name: pc.name(*pi)?, 37 | layer: *layer, 38 | distance: *distance, 39 | label_summary, 40 | }) 41 | }) 42 | .collect(); 43 | Ok(PathResponse { path: resp? }) 44 | } 45 | } -------------------------------------------------------------------------------- /goko/src/plugins/gaussians/mod.rs: -------------------------------------------------------------------------------- 1 | //! # Probability Distributions Plugins 2 | //! 3 | //! This module containes plugins that simulate probability distributions on the nodes. 4 | //! It also has trackers used to see when queries and sequences are out of distribution. 5 | 6 | use super::*; 7 | use rand::Rng; 8 | use std::fmt::Debug; 9 | 10 | mod diag_gaussian; 11 | pub use diag_gaussian::*; 12 | 13 | /* 14 | There's an issue with rust-numpy and ndarray causing the linear algebra package for ndarray to fail. 15 | 16 | Temporary removal 17 | 18 | mod svd_gaussian; 19 | pub use svd_gaussian::*; 20 | */ 21 | 22 | use pointcloud::PointRef; 23 | 24 | /// 25 | pub trait ContinousDistribution: Clone + 'static { 26 | /// Pass none if you want to test for a singleton, returns 0 if 27 | fn ln_pdf(&self, point: &T) -> Option; 28 | /// Samples a point from this distribution 29 | fn sample(&self, rng: &mut R) -> Vec; 30 | 31 | /// Computes the KL divergence of two bucket probs. 32 | /// KL(self || other) 33 | /// Returns None if the support of the self is not a subset of the support of the other, or the calculation is undefined. 34 | fn kl_divergence(&self, other: &Self) -> Option; 35 | } 36 | 37 | /// 38 | pub trait ContinousBayesianDistribution: ContinousDistribution + Clone + 'static { 39 | /// Adds an observation to the distribution. 40 | /// This currently shifts the underlying parameters of the distribution rather than be tracked. 
41 | fn add_observation(&mut self, point: &T); 42 | } 43 | -------------------------------------------------------------------------------- /pointcloud/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "pointcloud" 3 | version = "0.5.4" 4 | edition = "2018" 5 | 6 | description = "An accessor layer for goko" 7 | readme = "README.md" 8 | 9 | authors = ["Sven Cattell "] 10 | 11 | documentation = "https://docs.rs/pointcloud" 12 | homepage = "https://github.com/elastic/goko" 13 | repository = "https://github.com/elastic/goko.git" 14 | 15 | license-file = "../LICENSE.txt" 16 | 17 | keywords = ["datasets"] 18 | categories = ["science", "data-structures"] 19 | 20 | [badges] 21 | travis-ci = { repository = "https://github.com/elastic/goko.git", branch = "master" } 22 | 23 | [toolchain] 24 | channel = "nightly" 25 | 26 | [features] 27 | default = [] 28 | 29 | [dependencies] 30 | log = "0.4" 31 | csv = "1.1.6" 32 | libc = "0.2" 33 | yaml-rust = "0.4" 34 | rayon = "1.4.0" 35 | packed_simd = { version = "0.3.4", package = "packed_simd_2" } 36 | glob = "0.3.0" 37 | fxhash = "0.2.1" 38 | hashbrown = { version = "0.11.2", features = ["rayon", "serde"] } 39 | serde_json = "1.0.64" 40 | serde = { version = "1.0.116", features = ["derive"] } 41 | flate2 = "1.0.17" 42 | rand = "0.8.3" 43 | smallvec = { version = "1.3.0", features = ["serde"] } 44 | num-traits = "0.2" 45 | ndarray = "0.14.0" 46 | 47 | [target.'cfg(windows)'.dependencies] 48 | winapi = { version = "0.3", features = ["basetsd", "handleapi", "memoryapi", "minwindef", "std", "sysinfoapi"] } 49 | 50 | [dev-dependencies] 51 | tempdir = "0.3" 52 | criterion = "0.3" 53 | assert_approx_eq = "1.0.0" 54 | 55 | [[bench]] 56 | name = "dists_bench" 57 | path = "benches/dists_bench.rs" 58 | harness = false 59 | -------------------------------------------------------------------------------- /goko/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "goko" 3 | version = "0.5.4" 4 | edition = "2018" 5 | 6 | description = "A lock-free, eventually consistent, concurrent covertree." 
7 | readme = "README.md"
8 |
9 | authors = ["Sven Cattell "]
10 |
11 | documentation = "https://docs.rs/goko"
12 | homepage = "https://github.com/elastic/goko"
13 | repository = "https://github.com/elastic/goko.git"
14 |
15 | keywords = ["cover-tree","knn","lock-free"]
16 | categories = ["concurrency", "data-structures"]
17 |
18 | license-file = "../LICENSE.txt"
19 | include = ["protos/tree_file_format.proto","build.rs","src/*","Cargo.toml"]
20 |
21 | [toolchain]
22 | channel = "nightly"
23 |
24 | [features]
25 | docs-only = []
26 |
27 |
28 | [lib]
29 | path = "src/lib.rs"
30 | test = true
31 |
32 | [dependencies]
33 | protobuf = "2.23.0"
34 | rand = { version = "0.8.3", features = ["small_rng"]}
35 | rand_distr = "0.4.0"
36 | yaml-rust = "0.4.5"
37 | pbr = "1.0.4"
38 | fxhash = "0.2.1"
39 | rayon = "1.5.0"
40 | hashbrown = { version = "0.11.2", features = ["rayon"] }
41 | crossbeam-channel = "0.5.1"
42 | pointcloud = { version = "0.5.4", path = "../pointcloud" }
43 | serde = { version = "1.0.125", features = ["derive"] }
44 | smallvec = "1.6.1"
45 | type-map = "0.5.0"
46 | statrs = "0.13.0"
47 | ndarray = "0.14.0"
48 |
49 | [dev-dependencies]
50 | criterion = "0.3.4"
51 | assert_approx_eq = "1.0.0"
52 |
53 | [[bench]]
54 | name = "path_bench"
55 | path = "benches/path_bench.rs"
56 | harness = false
57 |
58 | [build-dependencies]
59 | protoc-rust = "2.23.0"
60 |
61 | [package.metadata.docs.rs]
62 | features = [ "docs-only" ]
--------------------------------------------------------------------------------
/serve_goko/src/http/maker.rs:
--------------------------------------------------------------------------------
1 | use std::sync::Arc;
2 |
3 | use pointcloud::*;
4 |
5 | use tower::Service;
6 |
7 | use std::convert::Infallible;
8 |
9 | use futures::future;
10 | use core::task::Context;
11 | use std::task::Poll;
12 |
13 | use std::marker::PhantomData;
14 | use std::ops::Deref;
15 |
16 | use super::GokoHttp;
17 | use crate::parsers::{PointParser, PointBuffer};
18 | use crate::core::*;
19 |
20 | pub struct MakeGokoHttp<D: PointCloud, P: PointParser> {
21 |     writer: Arc<CoreWriter<D, P::Point>>,
22 |     parser: PhantomData<P>,
23 | }
24 |
25 | impl<D, P> MakeGokoHttp<D, P>
26 | where
27 |     D: PointCloud,
28 |     P: PointParser,
29 |     P::Point: Deref<Target = D::Point> + Send + Sync,
30 | {
31 |     pub fn new(writer: Arc<CoreWriter<D, P::Point>>) -> MakeGokoHttp<D, P> {
32 |         MakeGokoHttp {
33 |             writer,
34 |             parser: PhantomData,
35 |         }
36 |     }
37 | }
38 |
39 | impl<T, D, P> Service<T> for MakeGokoHttp<D, P>
40 | where
41 |     D: PointCloud,
42 |     P: PointParser,
43 |     P::Point: Deref<Target = D::Point> + Send + Sync + 'static,
44 | {
45 |     type Response = GokoHttp<D, P>;
46 |     type Error = Infallible;
47 |     type Future = futures::future::Ready<Result<Self::Response, Self::Error>>;
48 |
49 |     fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
50 |         Poll::Ready(Ok(()))
51 |     }
52 |
53 |     fn call(&mut self, _: T) -> Self::Future {
54 |         let reader = self.writer.reader();
55 |         let parser = PointBuffer::<P>::new();
56 |         future::ready(Ok(GokoHttp::new(reader, parser)))
57 |     }
58 | }
--------------------------------------------------------------------------------
/serve_goko/src/core/mod.rs:
--------------------------------------------------------------------------------
1 | use pointcloud::PointCloud;
2 | use goko::{CoverTreeReader,CoverTreeWriter};
3 | use std::sync::Arc;
4 | use tokio::sync::RwLock;
5 | use std::collections::HashMap;
6 | use std::ops::Deref;
7 |
8 | pub(crate) mod internal_service;
9 | use internal_service::InternalServiceOperator;
10 | use crate::api::{TrackerWorker, TrackingRequest, TrackingResponse};
11 |
12 |
13 | pub struct CoreWriter<D: PointCloud, T> {
14 |     pub(crate) tree: CoverTreeWriter<D>,
15 |     pub(crate) trackers: Arc<RwLock<HashMap<String, InternalServiceOperator<TrackingRequest<T>, TrackingResponse>>>>,
16 |     pub(crate) main_tracker: Arc<InternalServiceOperator<TrackingRequest<T>, TrackingResponse>>,
17 | }
18 |
19 | impl<D: PointCloud, T: Deref<Target = D::Point> + Send + Sync> CoreWriter<D, T> {
20 |     pub fn new(writer: CoverTreeWriter<D>) -> Self {
21 |         let trackers = Arc::new(RwLock::new(HashMap::new()));
22 |         let main_tracker = Arc::new(TrackerWorker::operator(writer.reader()));
23 |         CoreWriter {
24 |             trackers,
25 |             main_tracker,
26 |             tree: writer,
27 |         }
28 |     }
29 |
30 |     pub fn reader(&self) -> CoreReader<D, T> {
31 |         let tree = self.tree.reader();
32 |         CoreReader {
33 |             trackers: Arc::clone(&self.trackers),
34 |             main_tracker: Arc::clone(&self.main_tracker),
35 |             tree,
36 |         }
37 |     }
38 | }
39 |
40 | pub struct CoreReader<D: PointCloud, T> {
41 |     pub(crate) tree: CoverTreeReader<D>,
42 |     pub(crate) trackers: Arc<RwLock<HashMap<String, InternalServiceOperator<TrackingRequest<T>, TrackingResponse>>>>,
43 |     pub(crate) main_tracker: Arc<InternalServiceOperator<TrackingRequest<T>, TrackingResponse>>,
44 | }
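The writer/reader split above mirrors goko's own read-write heads: one `CoreWriter` owns the tree and the tracker registry, and every connection works from a cheap `CoreReader` snapshot. A sketch of the intended flow, with the trait bounds assumed as reconstructed above:

```rust
use std::ops::Deref;
use std::sync::Arc;

use goko::CoverTreeWriter;
use pointcloud::PointCloud;
use serve_goko::core::{CoreReader, CoreWriter};

// Wrap a finished tree once, share the writer, and hand each task its own reader.
fn split_core<D, T>(tree: CoverTreeWriter<D>) -> (Arc<CoreWriter<D, T>>, CoreReader<D, T>)
where
    D: PointCloud,
    T: Deref<Target = D::Point> + Send + Sync,
{
    let writer = Arc::new(CoreWriter::new(tree));
    let reader = writer.reader();
    (writer, reader)
}
```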
--------------------------------------------------------------------------------
/goko/src/covertree/query_tools/mod.rs:
--------------------------------------------------------------------------------
1 | /*
2 |  * Licensed to Elasticsearch B.V. under one or more contributor
3 |  * license agreements. See the NOTICE file distributed with
4 |  * this work for additional information regarding copyright
5 |  * ownership. Elasticsearch B.V. licenses this file to you under
6 |  * the Apache License, Version 2.0 (the "License"); you may
7 |  * not use this file except in compliance with the License.
8 |  * You may obtain a copy of the License at
9 |  *
10 |  * http://www.apache.org/licenses/LICENSE-2.0
11 |  *
12 |  * Unless required by applicable law or agreed to in writing,
13 |  * software distributed under the License is distributed on an
14 |  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 |  * KIND, either express or implied. See the License for the
16 |  * specific language governing permissions and limitations
17 |  * under the License.
18 |  */
19 |
20 | //! Tools and data structures for assisting cover tree queries.
21 |
22 | use crate::NodeAddress;
23 |
24 | pub(crate) mod query_items;
25 |
26 | pub(crate) mod knn_query_heap;
27 | pub use knn_query_heap::KnnQueryHeap;
28 | pub(crate) mod trace_query_heap;
29 | pub use trace_query_heap::MultiscaleQueryHeap;
30 |
31 | /// If you have an algorithm that does local brute force KNN on just the children,
32 | /// implement this to use with the node querying functions.
33 | pub trait RoutingQueryHeap {
34 |     /// Pushes a set of nodes and their distances onto the heap.
35 |     fn push_nodes(
36 |         &mut self,
37 |         indexes: &[NodeAddress],
38 |         dists: &[f32],
39 |         parent_address: Option<NodeAddress>,
40 |     );
41 | }
42 |
43 | /// If you have an algorithm that does local brute force KNN on just the singletons,
44 | /// implement this to use with the node querying functions.
45 | pub trait SingletonQueryHeap {
46 |     /// Shove a bunch of single points onto the heap
47 |     fn push_outliers(&mut self, indexes: &[usize], dists: &[f32]);
48 | }
--------------------------------------------------------------------------------
/serve_goko/src/api/parameters.rs:
--------------------------------------------------------------------------------
1 | use pointcloud::*;
2 |
3 | use goko::PartitionType;
4 | use serde::{Deserialize, Serialize};
5 | use crate::core::*;
6 | use goko::errors::GokoError;
7 |
8 | /// Send a `GET` request to `/` for this
9 | #[derive(Deserialize, Serialize, Clone, Copy)]
10 | pub struct ParametersRequest;
11 |
12 | /// Response to a parameters request
13 | #[derive(Deserialize, Serialize)]
14 | pub struct ParametersResponse {
15 |     /// See paper or main description, governs the number of children of each node. Higher means more children.
16 |     pub scale_base: f32,
17 |     /// If a node covers less than or equal to this number of points, it becomes a leaf.
18 |     pub leaf_cutoff: usize,
19 |     /// If a node has scale index less than or equal to this, it becomes a leaf
20 |     pub min_res_index: i32,
21 |     /// If you don't want singletons messing with your tree and want everything to be a node or an element of a leaf node, make this true.
22 |     pub use_singletons: bool,
23 |     /// The partition type of the tree
24 |     pub partition_type: PartitionType,
25 |     /// This should be replaced by a logging solution
26 |     pub verbosity: u32,
27 |     /// The seed to use for deterministic trees. This is xor-ed with the point index to create a seed for `rand::rngs::SmallRng`.
28 |     pub rng_seed: Option<u64>,
29 | }
30 |
31 | impl ParametersRequest {
32 |     pub fn process<D: PointCloud, T>(self, reader: &mut CoreReader<D, T>) -> Result<ParametersResponse, GokoError> {
33 |         let params = reader.tree.parameters();
34 |         Ok(ParametersResponse {
35 |             scale_base: params.scale_base,
36 |             leaf_cutoff: params.leaf_cutoff,
37 |             min_res_index: params.min_res_index,
38 |             use_singletons: params.use_singletons,
39 |             partition_type: params.partition_type,
40 |             verbosity: params.verbosity,
41 |             rng_seed: params.rng_seed,
42 |         })
43 |     }
44 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Goko, a Geometric Analysis Library
2 |
3 | [![Build Status](https://travis-ci.com/elastic/grandma.svg?branch=master)](https://travis-ci.com/elastic/goko)
4 | [![Crate](https://img.shields.io/crates/v/goko.svg)](https://crates.io/crates/goko)
5 | [![API](https://docs.rs/goko/badge.svg)](https://docs.rs/goko)
6 |
7 | This is a covertree library with some modifications to make it more suitable for real data. Currently it only implements the [fast covertree](http://proceedings.mlr.press/v37/izbicki15.pdf), which is an extension of the original covertree [(pdf)](https://homes.cs.washington.edu/~sham/papers/ml/cover_tree.pdf). There are plans to enable support for full [geometric multi-resolution analysis](https://arxiv.org/pdf/1611.01179.pdf) (GMRA, which is where the library gets its name) and [topological data analysis](https://arxiv.org/pdf/1602.06245.pdf). Help is welcome!
We'd love to collaborate on more cool tricks to do with covertrees, or on coding up the large backlog of planned features to support the currently known tricks.
8 |
9 | ## Project Layout & Documentation
10 |
11 | Data access is handled through the `pointcloud` library. See [here](https://docs.rs/pointcloud) for `pointcloud`'s documentation. This is meant to abstract many files and make them look like one, and because of this it handles computations like adjacency matrices. The covertree implementation is inside the `goko` library; it's the bread and butter of the project. See [here](https://docs.rs/goko) for its documentation.
12 |
13 | The `pygoko` library is a partial Python & NumPy wrapper around `goko`. It can access the components of `goko` for gathering statistics on your trees. Once we settle on how this is implemented we will publish the documentation somewhere.
14 |
15 |
16 | #### License
17 |
18 |
19 | Licensed under the Apache License, Version 2.0.
23 | 24 | 25 | Unless you explicitly state otherwise, any contribution intentionally submitted 26 | for inclusion in this crate by you, as defined in the Apache-2.0 license, shall 27 | be licensed as above, without any additional terms or conditions. 28 | 29 | -------------------------------------------------------------------------------- /goko/benches/path_bench.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to Elasticsearch B.V. under one or more contributor 3 | * license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright 5 | * ownership. Elasticsearch B.V. licenses this file to you under 6 | * the Apache License, Version 2.0 (the "License"); you may 7 | * not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | use goko::*; 21 | use pointcloud::*; 22 | use pointcloud::{data_sources::*, label_sources::*, loaders::*}; 23 | use std::path::Path; 24 | 25 | use std::sync::Arc; 26 | 27 | use criterion::{black_box, criterion_group, criterion_main, Criterion}; 28 | 29 | fn build_tree() -> CoverTreeWriter, SmallIntLabels>> { 30 | let file_name = "data/ember_complex.yml"; 31 | let path = Path::new(file_name); 32 | if !path.exists() { 33 | panic!(file_name.to_owned() + &" does not exist".to_string()); 34 | } 35 | let builder = CoverTreeBuilder::from_yaml(&path); 36 | let point_cloud = labeled_ram_from_yaml("data/ember_complex.yml").unwrap(); 37 | builder.build(Arc::new(point_cloud)).unwrap() 38 | } 39 | 40 | pub fn criterion_benchmark(c: &mut Criterion) { 41 | let ct = build_tree(); 42 | let reader = ct.reader(); 43 | c.bench_function("Known Path 0", |b| { 44 | b.iter(|| reader.known_path(black_box(0))) 45 | }); 46 | 47 | let pointcloud = reader.point_cloud(); 48 | let point = pointcloud.point(0).unwrap(); 49 | c.bench_function("Unknown Path 0", |b| { 50 | b.iter(|| reader.path(black_box(&point))) 51 | }); 52 | } 53 | 54 | criterion_group!(benches, criterion_benchmark); 55 | criterion_main!(benches); 56 | -------------------------------------------------------------------------------- /data/setup_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | mnist = tf.keras.datasets.mnist 4 | import pandas as pd 5 | from sklearn.neighbors import KDTree 6 | import os 7 | 8 | # Base MNIST transform for easy access. 9 | # The Yaml files are often messed with, these are the base files. 
10 | mnist_yaml = ''' 11 | --- 12 | cutoff: 5 13 | resolution: -10 14 | scale_base: 2 15 | data_path: ../data/mnist.dat 16 | labels_path: ../data/mnist.csv 17 | count: 60000 18 | data_dim: 784 19 | labels_dim: 10 20 | in_ram: True 21 | ''' 22 | 23 | mnist_complex_yaml = ''' 24 | --- 25 | cutoff: 0 26 | resolution: -20 27 | scale_base: 1.3 28 | use_singletons: true 29 | verbosity: 0 30 | data_path: ../data/mnist.dat 31 | labels_path: ../data/mnist.csv 32 | count: 60000 33 | data_dim: 784 34 | in_ram: True 35 | schema: 36 | 'y': i32 37 | name: string 38 | ''' 39 | 40 | metaFile = open("data/mnist.yml","wb") 41 | metaFile.write(mnist_yaml.encode('utf-8')) 42 | metaFile.close() 43 | metaFile = open("data/mnist_complex.yml","wb") 44 | metaFile.write(mnist_complex_yaml.encode('utf-8')) 45 | metaFile.close() 46 | 47 | (x_train, y_train),(x_test, y_test) = mnist.load_data() 48 | x_train, x_test = x_train / 255.0, x_test / 255.0 49 | x_train = x_train.astype(np.float32) 50 | x_train = x_train.reshape(-1, 28*28) 51 | dataFile = open("mnist.dat", "wb") 52 | for x in x_train: 53 | dataFile.write(x.tobytes()) 54 | dataFile.close() 55 | y_bools = [y%2 == 0 for y in y_train] 56 | y_str = [str(y) for y in y_train] 57 | 58 | df = pd.DataFrame({"y":y_train,"even":y_bools,"name":y_str}) 59 | df.index.rename('index', inplace=True) 60 | df.to_csv('mnist.csv') 61 | 62 | # KNN data for tests 63 | data = np.memmap("mnist.dat", dtype=np.float32) 64 | data = data.reshape([-1,784]) 65 | 66 | tree = KDTree(data, leaf_size=2) 67 | dist, ind = tree.query(data[:100], k=5) 68 | 69 | dist, ind = tree.query(np.zeros([1,784],dtype=np.float32), k=5) 70 | 71 | nbrs = {"d0":dist[:,0], 72 | "d1":dist[:,1], 73 | "d2":dist[:,2], 74 | "d3":dist[:,3], 75 | "d4":dist[:,4], 76 | "i0": ind[:,0], 77 | "i1": ind[:,1], 78 | "i2": ind[:,2], 79 | "i3": ind[:,3], 80 | "i4": ind[:,4],} 81 | 82 | csv = pd.DataFrame(nbrs) 83 | csv.to_csv("mnist_nbrs.csv") -------------------------------------------------------------------------------- /pointcloud/src/lib.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to Elasticsearch B.V. under one or more contributor 3 | * license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright 5 | * ownership. Elasticsearch B.V. licenses this file to you under 6 | * the Apache License, Version 2.0 (the "License"); you may 7 | * not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | //! # Point Cloud 20 | //! 
Abstracts data access over several files and glues metadata files to vector data files
21 |
22 | #![allow(dead_code)]
23 | //#![deny(warnings)]
24 | #![warn(missing_docs)]
25 | #![allow(clippy::cast_ptr_alignment)]
26 | #![feature(result_flattening)]
27 | #![feature(is_sorted)]
28 | #![feature(generic_associated_types)]
29 |
30 | #[cfg(test)]
31 | #[macro_use]
32 | extern crate assert_approx_eq;
33 |
34 | pub mod pc_errors;
35 |
36 | mod base_traits;
37 | #[doc(inline)]
38 | pub use base_traits::*;
39 | pub mod metrics;
40 |
41 | pub mod points;
42 |
43 | pub mod data_sources;
44 |
45 | pub mod glued_data_cloud;
46 |
47 | pub mod label_sources;
48 | pub mod summaries;
49 |
50 | pub mod loaders;
51 |
52 | use data_sources::DataRam;
53 | use label_sources::SmallIntLabels;
54 |
55 | pub use metrics::L2;
56 |
57 | /// A sensible default for a labeled cloud
58 | pub type DefaultLabeledCloud<M> = SimpleLabeledCloud<DataRam<M>, SmallIntLabels>;
59 | /// A sensible default for an unlabeled cloud
60 | pub type DefaultCloud<M> = DataRam<M>;
61 |
62 | impl<M: Metric<[f32]>> DefaultLabeledCloud<M> {
63 |     /// Simple way of gluing together the most common data source
64 |     pub fn new_simple(data: Vec<f32>, dim: usize, labels: Vec<i64>) -> DefaultLabeledCloud<M> {
65 |         SimpleLabeledCloud::new(
66 |             DataRam::new(data, dim).unwrap(),
67 |             SmallIntLabels::new(labels, None),
68 |         )
69 |     }
70 | }
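A sketch of `new_simple` in use; the `i64` label type and the `point` accessor are assumed from `SmallIntLabels` and the `PointCloud` trait in `base_traits`:

```rust
use pointcloud::*;

fn main() {
    // Four 2-d points, flattened row-major, with one small-int label apiece.
    let cloud = DefaultLabeledCloud::<L2>::new_simple(
        vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0],
        2,
        vec![0, 0, 1, 1],
    );
    // Indexed access is how the cover tree will read these points back.
    let _first = cloud.point(0).unwrap();
}
```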
--------------------------------------------------------------------------------
/goko/README.md:
--------------------------------------------------------------------------------
1 | # Goko
2 |
3 | ## Plugins
4 |
5 | These are going to be the core of the library. Attaching math gadgets to each node of the covertree is the most interesting thing you can do with them, and it is the goal of this library. The objective is to have zero-cost plugins on the fastest covertree implementation possible. It should be impossible to write a faster implementation of your covertree-based algorithm. This is the third near-total rewrite of the library. The previous two implementations had foundational bottlenecks and details that hindered this goal. We feel that this latest architecture has the strongest foundation for future work.
6 |
7 | Plugins are currently partially implemented. Expect them in the coming weeks!
8 |
9 | ## Read-Write heads
10 |
11 | `goko` stores the covertree as an array of layers, each with an arena-style allocation of all nodes on that layer in an `evmap`. This is a pair of hashmaps where all the readers point at one constant map (so it can be safely, locklessly queried), while the writer edits the other map. When the writer is done it swaps the pointer. The readers then can see the updates, while the writer writes the same changes (and more) to the other hashmap. Plugins cannot control the update order; however, the tree will be updated from leaf to root, so you can use recursive logic.
12 |
13 | ## Singleton Children
14 |
15 | Children of a node that would cover only a single point are stored in a singleton list. This saves a significant amount of memory (trees built on Ember are 2-3x smaller) and you can choose not to query these nodes to speed up KNN & other queries. These unique children should be taken into account when writing algorithms for this implementation.
16 |
17 | ## Installation
18 |
19 | This depends on protoc, which can be installed via the `protobuf-compiler` package on Ubuntu. A normal dev machine should have the rest installed, but fresh machines may need to install the following, along with Rust:
20 |
21 | ```bash
22 | sudo apt install -y build-essential libssl-dev libblas-dev liblapacke-dev pkg-config libprotobuf-dev libprotoc-dev protobuf-compiler
23 | ```
24 |
25 | It also requires nightly due to `packed-simd`'s reliance on experimental APIs. Currently the latest version that is confirmed to work is
26 | ```bash
27 | rustup install nightly-2020-09-14
28 | rustup override set nightly-2020-09-14
29 | ```
--------------------------------------------------------------------------------
/pointcloud/src/loaders/mod.rs:
--------------------------------------------------------------------------------
1 | //! Loaders for datasets. These just open them up and return a point cloud.
2 |
3 | use std::path::{Path, PathBuf};
4 |
5 | use crate::base_traits::*;
6 | use crate::data_sources::*;
7 | use crate::glued_data_cloud::*;
8 | use crate::label_sources::*;
9 | use crate::pc_errors::*;
10 |
11 | mod yaml_loaders;
12 | pub use yaml_loaders::*;
13 | mod csv_loaders;
14 | pub use csv_loaders::*;
15 |
16 | /// Opens a set of memmaps of both data and labels
17 | pub fn open_labeled_memmaps<M: Metric<[f32]>>(
18 |     data_dim: usize,
19 |     label_dim: usize,
20 |     data_paths: &[PathBuf],
21 |     labels_paths: &[PathBuf],
22 | ) -> PointCloudResult<HashGluedCloud<SimpleLabeledCloud<DataMemmap<M>, VecLabels>>> {
23 |     if data_paths.len() != labels_paths.len() {
24 |         panic!(
25 |             "Mismatch of label and data paths Data: {:?}, Labels: {:?}",
26 |             data_paths, labels_paths
27 |         );
28 |     }
29 |     let collection: PointCloudResult<Vec<SimpleLabeledCloud<DataMemmap<M>, VecLabels>>> =
30 |         data_paths
31 |             .iter()
32 |             .zip(labels_paths.iter())
33 |             .map(|(dp, lp)| {
34 |                 let data = DataMemmap::<M>::new(data_dim, &dp)?;
35 |                 let labels = DataMemmap::<M>::new(label_dim, &lp)?.convert_to_labels();
36 |                 Ok(SimpleLabeledCloud::new(data, labels))
37 |             })
38 |             .collect();
39 |     Ok(HashGluedCloud::new(collection?))
40 | }
41 |
42 | /// Opens a set of memmaps of just data
43 | pub fn open_memmaps<M: Metric<[f32]>>(
44 |     data_dim: usize,
45 |     data_paths: &[PathBuf],
46 | ) -> PointCloudResult<HashGluedCloud<DataMemmap<M>>> {
47 |     let collection: PointCloudResult<Vec<DataMemmap<M>>> = data_paths
48 |         .iter()
49 |         .map(|dp| DataMemmap::<M>::new(data_dim, &dp))
50 |         .collect();
51 |     Ok(HashGluedCloud::new(collection?))
52 | }
53 |
54 | /// Concatenates a glued data memmap to a single ram dataset
55 | pub fn convert_glued_memmap_to_ram<M: Metric<[f32]>>(
56 |     glued_cloud: HashGluedCloud<DataMemmap<M>>,
57 | ) -> DataRam<M> {
58 |     glued_cloud
59 |         .take_data_sources()
60 |         .drain(0..)
61 | .map(|ds| ds.convert_to_ram()) 62 | .reduce(|mut a, b| { 63 | a.merge(b); 64 | a 65 | }) 66 | .unwrap() 67 | } 68 | -------------------------------------------------------------------------------- /serve_goko/src/api/knn.rs: -------------------------------------------------------------------------------- 1 | use pointcloud::*; 2 | use crate::core::*; 3 | 4 | use serde::{Deserialize, Serialize}; 5 | use std::ops::Deref; 6 | 7 | use goko::errors::GokoError; 8 | 9 | use super::NamedDistance; 10 | 11 | /// Response: [`KnnResponse`] 12 | #[derive(Deserialize, Serialize)] 13 | pub struct KnnRequest { 14 | pub k: usize, 15 | pub point: T, 16 | } 17 | 18 | /// Request: [`KnnRequest`] 19 | #[derive(Deserialize, Serialize)] 20 | pub struct KnnResponse { 21 | pub knn: Vec, 22 | } 23 | 24 | impl KnnRequest { 25 | pub fn process(self, reader: &mut CoreReader) -> Result 26 | where 27 | D: PointCloud, 28 | T: Deref + Send + Sync, 29 | { 30 | let knn = reader.tree.knn(&self.point, self.k)?; 31 | let pc = &reader.tree.parameters().point_cloud; 32 | let resp: Result, GokoError> = knn 33 | .iter() 34 | .map(|(distance, pi)| { 35 | Ok(NamedDistance { 36 | name: pc.name(*pi)?, 37 | distance: *distance, 38 | }) 39 | }) 40 | .collect(); 41 | 42 | Ok(KnnResponse { knn: resp? }) 43 | } 44 | } 45 | 46 | /// Response: [`RoutingKnnResponse`] 47 | #[derive(Deserialize, Serialize)] 48 | pub struct RoutingKnnRequest { 49 | pub k: usize, 50 | pub point: T, 51 | } 52 | 53 | /// Request: [`RoutingKnnRequest`] 54 | #[derive(Deserialize, Serialize)] 55 | pub struct RoutingKnnResponse { 56 | pub routing_knn: Vec, 57 | } 58 | 59 | impl RoutingKnnRequest { 60 | pub fn process(self, reader: &CoreReader) -> Result 61 | where 62 | D: PointCloud, 63 | T: Deref + Send + Sync, 64 | { 65 | let knn = reader.tree.routing_knn(&self.point, self.k)?; 66 | let pc = &reader.tree.parameters().point_cloud; 67 | let resp: Result, GokoError> = knn 68 | .iter() 69 | .map(|(distance, pi)| { 70 | Ok(NamedDistance { 71 | name: pc.name(*pi)?, 72 | distance: *distance, 73 | }) 74 | }) 75 | .collect(); 76 | 77 | Ok(RoutingKnnResponse { routing_knn: resp? }) 78 | } 79 | } -------------------------------------------------------------------------------- /pointcloud/src/loaders/csv_loaders.rs: -------------------------------------------------------------------------------- 1 | use crate::pc_errors::*; 2 | use csv::Reader; 3 | use flate2::read::GzDecoder; 4 | use std::fs::File; 5 | use std::io::Read; 6 | use std::path::Path; 7 | 8 | use crate::label_sources::*; 9 | 10 | /// Opens a CSV and reads a single column from it as a integer label. Negative labels are treated as unlabeled and are masked. 11 | pub fn open_int_csv + std::fmt::Debug>( 12 | path: &P, 13 | index: usize, 14 | ) -> PointCloudResult { 15 | if !path.as_ref().exists() { 16 | panic!("CSV file {:?} does not exist", path); 17 | } 18 | 19 | match File::open(&path) { 20 | Ok(file) => { 21 | if path.as_ref().extension().unwrap() == "gz" { 22 | read_csv(index, path, Reader::from_reader(GzDecoder::new(file))) 23 | } else { 24 | read_csv(index, path, Reader::from_reader(file)) 25 | } 26 | } 27 | Err(e) => panic!("Unable to open csv file {:#?}", e), 28 | } 29 | } 30 | 31 | fn read_csv + std::fmt::Debug, R: Read>( 32 | index: usize, 33 | path: &P, 34 | mut rdr: Reader, 35 | ) -> PointCloudResult { 36 | let mut labels = Vec::new(); 37 | let mut mask = Vec::new(); 38 | 39 | for result in rdr.records() { 40 | // The iterator yields Result, so we check the 41 | // error here. 
42 | let record = result.expect("Unable to read a record from the label CSV"); 43 | match record.get(index) { 44 | Some(val) => { 45 | let val = val.parse::().map_err(|_| { 46 | PointCloudError::ParsingError(ParsingError::CSVReadError { 47 | file_name: path.as_ref().to_string_lossy().to_string(), 48 | line_number: record.position().unwrap().line() as usize, 49 | key: format!("Unable to read u64 from {:?}", record), 50 | }) 51 | })?; 52 | if 0 < val { 53 | mask.push(true); 54 | } else { 55 | mask.push(false); 56 | } 57 | labels.push(val); 58 | } 59 | None => { 60 | labels.push(0); 61 | mask.push(false); 62 | } 63 | } 64 | } 65 | if mask.iter().any(|f| !f) { 66 | Ok(SmallIntLabels::new(labels, Some(mask))) 67 | } else { 68 | Ok(SmallIntLabels::new(labels, None)) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /serve_goko/src/parsers/msgpack_dense.rs: -------------------------------------------------------------------------------- 1 | //! # Parser System 2 | //! 3 | //! This currently isn't safe (we assume the caller is going to send requests with bodies of reasonable size), and does more allocations 4 | //! than it strictly needs to. 5 | 6 | use hyper::{Request, Body}; 7 | 8 | use http::header::CONTENT_TYPE; 9 | use flate2::read::{DeflateDecoder, ZlibDecoder}; 10 | use rmp_serde; 11 | use std::io::Read; 12 | use crate::PointParser; 13 | use log::trace; 14 | use crate::errors::*; 15 | 16 | pub trait ParserService: Send + Sync + 'static { 17 | type Point; 18 | fn parse(&self, bytes: &[u8]) -> Result; 19 | } 20 | 21 | #[derive(Clone)] 22 | pub struct MsgPackDense {} 23 | 24 | pub enum Readers { 25 | Zlib(DeflateDecoder), 26 | Gzip(ZlibDecoder), 27 | None(R), 28 | } 29 | 30 | impl Read for Readers { 31 | fn read(&mut self, buf: &mut [u8]) -> std::io::Result { 32 | use Readers::*; 33 | match self { 34 | Zlib(reader) => reader.read(buf), 35 | Gzip(reader) => reader.read(buf), 36 | None(reader) => reader.read(buf), 37 | } 38 | } 39 | } 40 | 41 | 42 | impl PointParser for MsgPackDense { 43 | type Point = Vec; 44 | fn parse(body_buffer: &[u8], scratch_buffer: &mut Vec, request: &Request) -> Result { 45 | scratch_buffer.clear(); 46 | let mut reader = match request.headers().get(CONTENT_TYPE) { 47 | Some(typestr) => { 48 | let token = typestr.to_str().unwrap(); 49 | match token { 50 | "zlib" => { 51 | Readers::Zlib(DeflateDecoder::new(body_buffer)) 52 | } 53 | "gzip" => { 54 | Readers::Gzip(ZlibDecoder::new(body_buffer)) 55 | } 56 | _ => { 57 | return Err(GokoClientError::parse(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "Unknown Content Type")))); 58 | } 59 | } 60 | } 61 | None => Readers::None(body_buffer), 62 | }; 63 | reader.read_to_end(scratch_buffer).map_err(|e| GokoClientError::parse(Box::new(e)))?; 64 | if scratch_buffer.len() > 0 { 65 | let point: Vec = 66 | rmp_serde::from_read_ref(scratch_buffer).map_err(|e| GokoClientError::Parse(Box::new(e)))?; 67 | trace!("Initial Buffer len: {}, Scratch Buffer Len: {}, Final point lenght: {}", body_buffer.len(), scratch_buffer.len(), point.len()); 68 | Ok(point) 69 | } else { 70 | Err(GokoClientError::MissingBody) 71 | } 72 | } 73 | } -------------------------------------------------------------------------------- /goko/examples/mnist.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to Elasticsearch B.V. under one or more contributor 3 | * license agreements. 
See the NOTICE file distributed with 4 | * this work for additional information regarding copyright 5 | * ownership. Elasticsearch B.V. licenses this file to you under 6 | * the Apache License, Version 2.0 (the "License"); you may 7 | * not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | use goko::utils::*; 21 | use goko::CoverTreeWriter; 22 | use pointcloud::*; 23 | use std::path::Path; 24 | 25 | fn build_tree() -> CoverTreeWriter> { 26 | let file_name = "../data/ember_complex.yml"; 27 | let path = Path::new(file_name); 28 | if !path.exists() { 29 | panic!("{} does not exist", file_name); 30 | } 31 | 32 | cover_tree_from_labeled_yaml(&path).unwrap() 33 | } 34 | 35 | fn main() { 36 | for _i in 0..1 { 37 | let mut ct = build_tree(); 38 | //ct.cluster().unwrap(); 39 | ct.refresh(); 40 | let ct_reader = ct.reader(); 41 | println!("Tree has {} nodes", ct_reader.len()); 42 | for scale_index in ct_reader.scale_range() { 43 | println!( 44 | "Layer {} has {} nodes in scale {}", 45 | scale_index, 46 | ct_reader.layer(scale_index).len(), 47 | ct_reader.scale(scale_index) 48 | ); 49 | } 50 | } 51 | /* 52 | println!("===== Parameters ====="); 53 | let (sb, co, re) = ct.parameters(); 54 | println!( 55 | "{{\"scale_base\":{},\"cutoff\":{},\"resolution\":{}}}", 56 | sb, co, re 57 | ); 58 | println!("===== Schema ====="); 59 | println!("{}", ct.metadata_schema()); 60 | println!("===== KNN ====="); 61 | let query1 = ct.knn_query(&zeros, 5).unwrap(); 62 | println!("{:?}", query1); 63 | println!("===== Center KNN ====="); 64 | let query1 = ct.center_knn_query(&zeros, 5).unwrap(); 65 | println!("{:?}", query1); 66 | println!("===== Trace ====="); 67 | assert!(query1.len() == 5); 68 | let query2 = ct.insert_trace(&zeros).unwrap(); 69 | let trace_report: String = query2 70 | .iter() 71 | .map(|node| node.report_json()) 72 | .collect::>() 73 | .join(","); 74 | println!("[{}]", trace_report); 75 | println!("===== Saving ====="); 76 | */ 77 | } 78 | -------------------------------------------------------------------------------- /pygoko/tests/two_d_viz.py: -------------------------------------------------------------------------------- 1 | 2 | import pygoko 3 | import numpy as np 4 | from math import pi 5 | import matplotlib as mpl 6 | import matplotlib.pyplot as plt 7 | import matplotlib.lines as mlines 8 | import matplotlib.patches as mpatches 9 | from matplotlib.collections import PatchCollection 10 | 11 | cmap = plt.get_cmap("jet") 12 | norm = mpl.colors.Normalize(vmin=0.0, vmax=1.0) 13 | 14 | 15 | def show2D(tree, data): 16 | 17 | top = tree.top_scale() 18 | bottom = tree.bottom_scale() 19 | print(top, bottom) 20 | for j in range(top, bottom, -1): 21 | patches = [] 22 | lines = [] 23 | centers = [] 24 | layer = tree.layer(j) 25 | width = layer.radius() / 2 26 | _, centers = layer.centers() 27 | for c in centers: 28 | patches.append(mpatches.Circle(c, 2 * width, color="Blue")) 29 | for i in range(j, top): 30 | parent_layer = tree.layer(i + 1) 31 | point_indexes, center_points = parent_layer.centers() 32 | for point_index, c in 
zip(point_indexes, center_points): 33 | if not parent_layer.is_leaf(point_index): 34 | for child in parent_layer.child_points(point_index): 35 | lines.append( 36 | mlines.Line2D( 37 | [c[0], child[0]], [c[1], child[1]], color="blue" 38 | ) 39 | ) 40 | 41 | for singleton in parent_layer.singleton_points(point_index): 42 | lines.append( 43 | mlines.Line2D( 44 | [c[0], singleton[0]], [c[1], singleton[1]], color="orange" 45 | ) 46 | ) 47 | 48 | fig, ax = plt.subplots() 49 | centers = np.array(centers) 50 | 51 | ax.set_xlim((-1.6, 1.6)) 52 | ax.set_ylim((-1.6, 1.6)) 53 | 54 | collection = PatchCollection(patches, match_original=True, alpha=0.05) 55 | for line in lines: 56 | ax.add_line(line) 57 | ax.add_collection(collection) 58 | ax.scatter(data[:, 0], data[:, 1], color="orange") 59 | ax.scatter(centers[:, 0], centers[:, 1], color="red") 60 | 61 | ax.axis("off") 62 | fig.set_size_inches(10.00, 10.00) 63 | fig.savefig(f"2d_vis_{j:2d}.png", bbox_inches="tight") 64 | plt.close() 65 | 66 | 67 | if __name__ == "__main__": 68 | numPoints = 120 69 | data = pi * (2 * np.random.rand(numPoints, 1) - 0.5) 70 | data = [np.cos(data).reshape(-1, 1), np.cos(data) * np.sin(data).reshape(-1, 1)] 71 | data = np.concatenate(data, axis=1).astype(np.float32) 72 | 73 | tree = pygoko.CoverTree() 74 | tree.set_leaf_cutoff(0) 75 | tree.set_scale_base(2.0) 76 | tree.fit(data) 77 | 78 | show2D(tree, data) 79 | -------------------------------------------------------------------------------- /pygoko/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import setup 4 | from setuptools.command.test import test as TestCommand 5 | from setuptools.command.sdist import sdist as SdistCommand 6 | 7 | from setuptools_rust import RustExtension 8 | 9 | 10 | class CargoModifiedSdist(SdistCommand): 11 | """Modifies Cargo.toml to use an absolute rather than a relative path 12 | The current implementation of PEP 517 in pip always does builds in an 13 | isolated temporary directory. This causes problems with the build, because 14 | Cargo.toml necessarily refers to the current version of pyo3 by a relative 15 | path. 16 | Since these sdists are never meant to be used for anything other than 17 | tox / pip installs, at sdist build time, we will modify the Cargo.toml 18 | in the sdist archive to include an *absolute* path to pyo3. 
19 | """ 20 | 21 | def make_release_tree(self, base_dir, files): 22 | """Stages the files to be included in archives""" 23 | super().make_release_tree(base_dir, files) 24 | 25 | import toml 26 | 27 | # Cargo.toml is now staged and ready to be modified 28 | cargo_loc = os.path.join(base_dir, "Cargo.toml") 29 | assert os.path.exists(cargo_loc) 30 | 31 | with open(cargo_loc, "r") as f: 32 | cargo_toml = toml.load(f) 33 | 34 | rel_pyo3_path = cargo_toml["dependencies"]["pyo3"]["path"] 35 | base_path = os.path.dirname(__file__) 36 | abs_pyo3_path = os.path.abspath(os.path.join(base_path, rel_pyo3_path)) 37 | 38 | cargo_toml["dependencies"]["pyo3"]["path"] = abs_pyo3_path 39 | 40 | with open(cargo_loc, "w") as f: 41 | toml.dump(cargo_toml, f) 42 | 43 | 44 | class PyTest(TestCommand): 45 | user_options = [] 46 | 47 | def run(self): 48 | self.run_command("test_rust") 49 | 50 | import subprocess 51 | 52 | subprocess.check_call(["pytest", "tests"]) 53 | 54 | 55 | setup_requires = ["setuptools-rust>=0.10.1", "wheel"] 56 | install_requires = [] 57 | tests_require = install_requires + ["pytest", "pytest-benchmark"] 58 | 59 | setup( 60 | name="pygoko", 61 | version="0.4.0", 62 | classifiers=[ 63 | "License :: OSI Approved :: MIT License", 64 | "Development Status :: 3 - Alpha", 65 | "Intended Audience :: Developers", 66 | "Programming Language :: Python", 67 | "Programming Language :: Rust", 68 | "Operating System :: POSIX", 69 | "Operating System :: MacOS :: MacOS X", 70 | ], 71 | packages=["pygoko"], 72 | rust_extensions=[ 73 | RustExtension("pygoko.pygoko", "Cargo.toml", debug=False), 74 | ], 75 | install_requires=install_requires, 76 | tests_require=tests_require, 77 | setup_requires=setup_requires, 78 | include_package_data=True, 79 | zip_safe=False, 80 | cmdclass={"test": PyTest, "sdist": CargoModifiedSdist}, 81 | ) 82 | -------------------------------------------------------------------------------- /serve_goko/src/http/message.rs: -------------------------------------------------------------------------------- 1 | use tokio::sync::{mpsc, oneshot}; 2 | use pin_project::pin_project; 3 | 4 | use http::{Request, Response}; 5 | use hyper::Body; 6 | 7 | use core::task::Context; 8 | use std::future::Future; 9 | use std::pin::Pin; 10 | use std::task::Poll; 11 | 12 | use std::sync::{atomic, Arc, Mutex}; 13 | 14 | use crate::errors::{GokoClientError, InternalServiceError}; 15 | 16 | pub(crate) type HttpResponseSender = oneshot::Sender, GokoClientError>>; 17 | pub(crate) type HttpResponseReciever = oneshot::Receiver, GokoClientError>>; 18 | pub(crate) type HttpRequestSender = mpsc::UnboundedSender; 19 | pub(crate) type HttpRequestReciever = mpsc::UnboundedReceiver; 20 | 21 | #[pin_project] 22 | pub(crate) struct HttpMessage { 23 | pub(crate) request: Option>, 24 | pub(crate) reply: Option, 25 | pub(crate) global_error: Arc>>>, 26 | } 27 | 28 | impl HttpMessage { 29 | pub(crate) fn request(&mut self) -> Option> { 30 | self.request.take() 31 | } 32 | 33 | pub(crate) fn respond(&mut self, response: Result,GokoClientError>) { 34 | match self.reply.take() { 35 | Some(reply) => { 36 | match reply.send(response) { 37 | Ok(_) => (), 38 | Err(_) => { 39 | *self.global_error.lock().unwrap() = Some(Box::new(GokoClientError::Underlying(InternalServiceError::FailedRespSend))); 40 | } 41 | } 42 | } 43 | None => *self.global_error.lock().unwrap() = Some(Box::new(GokoClientError::Underlying(InternalServiceError::DoubleRead))), 44 | } 45 | } 46 | pub(crate) fn error(&mut self, error: impl std::error::Error + Send + 
'static) { 47 | *self.global_error.lock().unwrap() = Some(Box::new(error)); 48 | } 49 | } 50 | 51 | #[pin_project] 52 | pub struct ResponseFuture { 53 | #[pin] 54 | pub(crate) response: HttpResponseReciever, 55 | pub(crate) flight_counter: Arc, 56 | pub(crate) error: Option, 57 | } 58 | 59 | impl Future for ResponseFuture { 60 | type Output = Result, GokoClientError>; 61 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 62 | let this = self.project(); 63 | if let Some(err) = this.error.take() { 64 | return core::task::Poll::Ready(Err(err)); 65 | } 66 | else { 67 | let res = this.response.poll(cx).map(|r| { 68 | match r { 69 | Ok(r) => r.map_err(|e| GokoClientError::from(e)), 70 | Err(e) => Err(GokoClientError::from(e)) 71 | } 72 | }); 73 | this.flight_counter.fetch_sub(1, atomic::Ordering::SeqCst); 74 | res 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /serve_goko/src/core/tracker_worker.rs: -------------------------------------------------------------------------------- 1 | use pointcloud::PointCloud; 2 | use goko::{CoverTreeReader,CoverTreeWriter}; 3 | use std::sync::{Arc,RwLock}; 4 | use crate::{GokoRequest, GokoResponse}; 5 | use std::ops::Deref; 6 | use goko::errors::GokoError; 7 | use std::collections::HashMap; 8 | 9 | use goko::plugins::discrete::tracker::BayesCategoricalTracker; 10 | 11 | pub(crate) type TrackerResponseSender = oneshot::Sender, GokoClientError>>; 12 | pub(crate) type TrackerResponseReciever = oneshot::Receiver, GokoClientError>>; 13 | pub(crate) type TrackerRequestSender = mpsc::UnboundedSender; 14 | pub(crate) type TrackerRequestReciever = mpsc::UnboundedReceiver; 15 | 16 | pub struct TrackerWorker { 17 | in_flight: Arc, 18 | request_snd: TrackerResponseSender, 19 | pointcloud: PhantomData, 20 | global_error: Arc>>>, 21 | } 22 | 23 | impl TrackerWorker 24 | where 25 | D: PointCloud, 26 | { 27 | pub(crate) fn new
<P>(mut tracker: BayesCategoricalTracker<D>) -> TrackerWorker<D> 28 | where P: Deref + Send + Sync + 'static { 29 | let (request_snd, mut request_rcv): (TrackerRequestSender<D>, TrackerRequestReciever<D>
) = 30 | mpsc::unbounded_channel(); 31 | tokio::spawn(async move { 32 | while let Some(mut msg) = request_rcv.recv().await { 33 | let response = match goko_request { 34 | Ok(r) => reader.process(r).await.map_err(|e| e.into()), 35 | Err(e) => Err(e), 36 | }; 37 | match response { 38 | Ok(resp) => msg.respond(into_http(resp)), 39 | Err(e) => msg.respond(Err(e)), 40 | }; 41 | } else { 42 | msg.error(GokoHttpError::DoubleRead) 43 | } 44 | } 45 | }); 46 | let global_error = Arc::new(Mutex::new(None)); 47 | let in_flight = Arc::new(atomic::AtomicU32::new(0)); 48 | TrackerWorker { 49 | in_flight, 50 | request_snd, 51 | pointcloud: PhantomData, 52 | global_error, 53 | } 54 | } 55 | 56 | pub(crate) fn message(&self, request: Request) -> ResponseFuture { 57 | let flight_counter = Arc::clone(&self.in_flight); 58 | self.in_flight.fetch_add(1, atomic::Ordering::SeqCst); 59 | let (reply, response): (TrackerResponseSender, TrackerResponseReciever) = oneshot::channel(); 60 | 61 | let msg = Message { 62 | request: Some(request), 63 | reply: Some(reply), 64 | global_error: Arc::clone(&self.global_error), 65 | }; 66 | 67 | let error = self.request_snd.send(msg).err().map(|_e| GokoHttpError::FailedSend); 68 | ResponseFuture { 69 | response, 70 | flight_counter, 71 | error, 72 | } 73 | } 74 | } -------------------------------------------------------------------------------- /goko/src/lib.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to Elasticsearch B.V. under one or more contributor 3 | * license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright 5 | * ownership. Elasticsearch B.V. licenses this file to you under 6 | * the Apache License, Version 2.0 (the "License"); you may 7 | * not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | #![allow(dead_code)] 21 | #![deny(warnings)] 22 | #![warn(missing_docs)] 23 | #![doc(test(attr(allow(unused_variables), deny(warnings))))] 24 | #![feature(binary_heap_into_iter_sorted)] 25 | #![feature(associated_type_defaults)] 26 | 27 | //! # Goko 28 | //! This is an lock-free efficient implementation of a covertree for data science. The traditional 29 | //! application of this is for KNN, but it can be applied and used in lots of other applications. 30 | //! 31 | //! ## Parameter Guide 32 | //! The structure is controlled by 3 parameters, the most important of which 33 | //! is the scale_base. This should be between 1.2 and 2ish. A higher value will create more outliers. 34 | //! Outliers are not loaded into ram at startup, but a high value slows down creation of a tree 35 | //! significantly. Theoretically, this value doesn't matter to the big O time, but I wouldn't go above 2. 36 | //! 37 | //! The cutoff value controls how many points a leaf is allowed to cover. A smaller value gives faster 38 | //! bottom level queries, but at the cost of higher memory useage. Do not expect a value of 100 will give 39 | //! 1/100 the memory useage of a value of 1. It'd be closer to 1/10 or 1/5th. 
This is because the distribution 40 | //! of the number of children of a node. A high cutoff will increase the compute of the true-knn by a little bit. 41 | //! 42 | //! The resolution is the minimum scale index, this again reduces memory footprint and increases the query time 43 | //! for true KNN. 44 | //! Once a node's resolution dips below this value we stop and covert the remaining coverage into a leaf. 45 | //! This is mainly to stop before floating point errors become an issue. Try to choose it to result in a cutoff of about 46 | //! 2^-9. 47 | //! 48 | //! See the git readme for a description of the algo. 49 | //! 50 | 51 | #[cfg(test)] 52 | #[macro_use] 53 | extern crate smallvec; 54 | 55 | extern crate rand; 56 | 57 | use rayon::prelude::*; 58 | #[cfg(test)] 59 | #[macro_use] 60 | extern crate assert_approx_eq; 61 | 62 | use pointcloud::*; 63 | pub mod errors; 64 | pub use errors::GokoResult; 65 | 66 | pub(crate) mod monomap; 67 | 68 | mod covertree; 69 | pub use covertree::*; 70 | 71 | pub mod query_interface; 72 | 73 | mod tree_file_format; 74 | pub mod utils; 75 | 76 | pub mod plugins; 77 | 78 | /// The data structure explicitly seperates the covertree by layer, and the addressing schema for nodes 79 | /// is a pair for the layer index and the center point index of that node. 80 | pub type NodeAddress = (i32, usize); 81 | /// Like with a node address, the clusters are segmented by layer so we also reference by layer. The ClusterID is not meaningful, it's just a uint. 82 | pub type ClusterAddress = (i32, usize); 83 | -------------------------------------------------------------------------------- /goko/src/plugins/utils.rs: -------------------------------------------------------------------------------- 1 | //! Plugin for labels and metadata 2 | 3 | use super::*; 4 | use crate::covertree::node::CoverNode; 5 | use crate::covertree::CoverTreeReader; 6 | //use pointcloud::*; 7 | use std::sync::Arc; 8 | 9 | /// Contains all points that this node covers, if the coverage is lower than the limit set in the parameters. 10 | #[derive(Debug, Clone)] 11 | pub struct CoverageIndexes { 12 | pis: Arc>, 13 | } 14 | 15 | impl NodePlugin for CoverageIndexes {} 16 | 17 | impl CoverageIndexes { 18 | /// Returns all point indexes that the node covers 19 | pub fn point_indexes(&self) -> &[usize] { 20 | self.pis.as_ref() 21 | } 22 | } 23 | 24 | /// A plugin that helps gather all the indexes that the node covers into an array you can use. 
25 | #[derive(Debug, Clone)] 26 | pub struct GokoCoverageIndexes { 27 | /// The actual limit 28 | pub max: usize, 29 | } 30 | 31 | impl GokoCoverageIndexes { 32 | /// Set up the plugin for restricting the number of indexes we collect into any one node 33 | pub fn restricted(max: usize) -> Self { 34 | Self { max } 35 | } 36 | 37 | /// Set up the plugin for no restrictions 38 | pub fn new() -> Self { 39 | Self { max: usize::MAX } 40 | } 41 | } 42 | 43 | impl GokoPlugin for GokoCoverageIndexes { 44 | type NodeComponent = CoverageIndexes; 45 | fn node_component( 46 | parameters: &Self, 47 | my_node: &CoverNode, 48 | my_tree: &CoverTreeReader, 49 | ) -> Option { 50 | if my_node.coverage_count() < parameters.max { 51 | let mut indexes = my_node.singletons().to_vec(); 52 | // If we're a routing node then grab the childen's values 53 | if let Some((nested_scale, child_addresses)) = my_node.children() { 54 | my_tree.get_node_plugin_and::( 55 | (nested_scale, *my_node.center_index()), 56 | |p| { 57 | indexes.extend(p.point_indexes()); 58 | }, 59 | ); 60 | for ca in child_addresses { 61 | my_tree.get_node_plugin_and::(*ca, |p| { 62 | indexes.extend(p.point_indexes()); 63 | }); 64 | } 65 | } else { 66 | indexes.push(*my_node.center_index()); 67 | } 68 | Some(CoverageIndexes { 69 | pis: Arc::new(indexes), 70 | }) 71 | } else { 72 | None 73 | } 74 | } 75 | } 76 | 77 | #[cfg(test)] 78 | pub(crate) mod tests { 79 | use super::*; 80 | use crate::covertree::tests::build_basic_tree; 81 | 82 | #[test] 83 | fn coverage_sanity() { 84 | let mut ct = build_basic_tree(); 85 | ct.add_plugin::(GokoCoverageIndexes::new()); 86 | let ct_reader = ct.reader(); 87 | let mut untested_addresses = vec![ct_reader.root_address()]; 88 | while let Some(addr) = untested_addresses.pop() { 89 | let count = ct_reader 90 | .get_node_plugin_and::(addr, |p| p.point_indexes().len()) 91 | .unwrap(); 92 | ct_reader.get_node_and(addr, |n| { 93 | assert_eq!(n.coverage_count(), count, "Node: {:?}", n) 94 | }); 95 | 96 | ct_reader.get_node_children_and(addr, |covered, children| { 97 | untested_addresses.push(covered); 98 | untested_addresses.extend(children); 99 | }); 100 | } 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /serve_goko/src/parsers/mod.rs: -------------------------------------------------------------------------------- 1 | use http::Request; 2 | use hyper::Body; 3 | 4 | use core::task::Context; 5 | use std::future::Future; 6 | use std::pin::Pin; 7 | use std::task::Poll; 8 | 9 | use crate::errors::*; 10 | use std::fmt::Debug; 11 | use std::marker::PhantomData; 12 | 13 | use serde::Serialize; 14 | use hyper::body::HttpBody; 15 | use pin_project::pin_project; 16 | 17 | mod msgpack_dense; 18 | pub use msgpack_dense::MsgPackDense; 19 | 20 | pub trait PointParser: Send + 'static { 21 | type Point: Serialize + Send + Sync + Debug + 'static; 22 | fn parse(body_buffer: &[u8], scratch_buffer: &mut Vec, request: &Request) -> Result; 23 | } 24 | 25 | #[pin_project] 26 | pub(crate) struct PointBuffer { 27 | body_buffer: Vec, 28 | point_buffer: Vec, 29 | request: Request, 30 | parser: PhantomData
<P>, 31 | } 32 | 33 | impl<P: PointParser> PointBuffer<P>
{ 34 | pub(crate) fn new() -> Self { 35 | PointBuffer { 36 | body_buffer: Vec::with_capacity(8*1024), 37 | point_buffer: Vec::with_capacity(8*1024), 38 | request: Request::default(), 39 | parser: PhantomData, 40 | } 41 | } 42 | pub(crate) fn switch(&mut self, req: Request) { 43 | self.request = req; 44 | self.body_buffer.clear(); 45 | self.point_buffer.clear(); 46 | } 47 | 48 | pub(crate) fn poll_point(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 49 | let this = self.project(); 50 | let mut body = this.request.body_mut(); 51 | loop { 52 | let new_bytes = match Pin::new(&mut body).poll_data(cx) { 53 | Poll::Ready(data) => data, 54 | Poll::Pending => return Poll::Pending, 55 | }; 56 | if let Some(new_bytes) = new_bytes { 57 | match new_bytes { 58 | Ok(new_bytes) => { 59 | this.body_buffer.extend_from_slice(&new_bytes); 60 | } 61 | Err(e) => { 62 | this.body_buffer.clear(); 63 | this.point_buffer.clear(); 64 | *this.request = Request::default(); 65 | return Poll::Ready(Err(e.into())) 66 | }, 67 | } 68 | } else { 69 | match Pin::new(&mut body).poll_trailers(cx) { 70 | Poll::Ready(_) => (), 71 | Poll::Pending => return Poll::Pending, 72 | } 73 | } 74 | 75 | if body.is_end_stream() { 76 | let point_res = P::parse(this.body_buffer, this.point_buffer, this.request); 77 | this.body_buffer.clear(); 78 | this.point_buffer.clear(); 79 | *this.request = Request::default(); 80 | return Poll::Ready(point_res) 81 | } 82 | } 83 | } 84 | 85 | pub(crate) fn point(&mut self, req: Request) -> PointFuture<'_, P> 86 | where 87 | Self: Unpin + Sized, 88 | { 89 | self.switch(req); 90 | PointFuture{ 91 | req: self, 92 | } 93 | } 94 | } 95 | 96 | #[pin_project] 97 | /// Future that resolves to the next data chunk from `Body` 98 | pub(crate) struct PointFuture<'a, P: PointParser> { 99 | req: &'a mut PointBuffer
<P>
, 100 | } 101 | 102 | impl<'a, P: PointParser> Future for PointFuture<'a, P> { 103 | type Output = Result; 104 | 105 | fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { 106 | Pin::new(&mut *self.req).poll_point(ctx) 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /goko/examples/ember_sequence_track.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to Elasticsearch B.V. under one or more contributor 3 | * license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright 5 | * ownership. Elasticsearch B.V. licenses this file to you under 6 | * the Apache License, Version 2.0 (the "License"); you may 7 | * not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | extern crate protobuf; 21 | extern crate rand; 22 | extern crate yaml_rust; 23 | use std::path::Path; 24 | #[allow(dead_code)] 25 | extern crate goko; 26 | extern crate pointcloud; 27 | use goko::*; 28 | use pointcloud::*; 29 | use pointcloud::{data_sources::*, label_sources::*, loaders::*}; 30 | 31 | use std::sync::Arc; 32 | use std::time; 33 | 34 | use goko::plugins::discrete::prelude::*; 35 | use goko::query_interface::BulkInterface; 36 | 37 | fn build_tree() -> CoverTreeWriter, VecLabels>> { 38 | let file_name = "data/ember_complex.yml"; 39 | let path = Path::new(file_name); 40 | if !path.exists() { 41 | panic!("{} does not exist", file_name); 42 | } 43 | let builder = CoverTreeBuilder::from_yaml(&path); 44 | let point_cloud = vec_labeled_ram_from_yaml("data/ember_complex.yml").unwrap(); 45 | builder.build(Arc::new(point_cloud)).unwrap() 46 | } 47 | 48 | fn build_test_set() -> SimpleLabeledCloud, VecLabels> { 49 | vec_labeled_ram_from_yaml("data/ember_complex_test.yml").unwrap() 50 | } 51 | 52 | fn main() { 53 | let mut ct = build_tree(); 54 | ct.add_plugin::(GokoDirichlet {}); 55 | let test_set = build_test_set(); 56 | //ct.cluster().unwrap(); 57 | ct.refresh(); 58 | let ct_reader = ct.reader(); 59 | println!("Tree has {} nodes", ct_reader.node_count()); 60 | let window_size = 0; 61 | let num_sequence = 8; 62 | 63 | let mut baseline = DirichletBaseline::default(); 64 | baseline.set_sequence_len(test_set.len()); 65 | baseline.set_num_sequences(num_sequence); 66 | println!( 67 | "Gathering baseline with window_size: {}", 68 | window_size 69 | ); 70 | let baseline_start = time::Instant::now(); 71 | let _baseline_data = baseline.train(ct.reader()); 72 | let baseline_elapse = baseline_start.elapsed().as_millis(); 73 | 74 | println!( 75 | "BASELINE: Time elapsed {:?} milliseconds, time per sequence {} milliseconds", 76 | baseline_elapse, 77 | (baseline_elapse as f64) / ((test_set.len() * num_sequence) as f64) 78 | ); 79 | 80 | let mut tracker = 81 | BayesCategoricalTracker::new( window_size, ct.reader()); 82 | let start = time::Instant::now(); 83 | 84 | let points: Vec<&[f32]> = (0..test_set.len()) 85 | .map(|i| test_set.point(i).unwrap()) 86 | .collect(); 87 | let bulk = 
BulkInterface::new(ct.reader()); 88 | let mut paths = bulk.path(&points); 89 | for path in paths.drain(0..) { 90 | tracker.add_path(path.unwrap()); 91 | } 92 | 93 | let elapse = start.elapsed().as_millis(); 94 | println!( 95 | "Time elapsed {:?} milliseconds, time per sequence {} milliseconds", 96 | elapse, 97 | (elapse as f64) / (test_set.len() as f64) 98 | ); 99 | println!("stats: {:?}", tracker.kl_div_stats()); 100 | } 101 | -------------------------------------------------------------------------------- /goko/examples/ember_drop.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to Elasticsearch B.V. under one or more contributor 3 | * license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright 5 | * ownership. Elasticsearch B.V. licenses this file to you under 6 | * the Apache License, Version 2.0 (the "License"); you may 7 | * not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | extern crate protobuf; 21 | extern crate rand; 22 | extern crate yaml_rust; 23 | use std::path::Path; 24 | #[allow(dead_code)] 25 | extern crate goko; 26 | extern crate pointcloud; 27 | use goko::*; 28 | use pointcloud::*; 29 | use pointcloud::{data_sources::*, label_sources::*, loaders::*}; 30 | 31 | use std::collections::HashMap; 32 | use std::sync::Arc; 33 | use std::time; 34 | 35 | use goko::query_interface::BulkInterface; 36 | 37 | fn build_tree() -> CoverTreeWriter, SmallIntLabels>> { 38 | let file_name = "../data/ember_complex.yml"; 39 | let path = Path::new(file_name); 40 | if !path.exists() { 41 | panic!("{} does not exist", file_name); 42 | } 43 | let builder = CoverTreeBuilder::from_yaml(&path); 44 | let point_cloud = labeled_ram_from_yaml("../data/ember_complex.yml").unwrap(); 45 | builder.build(Arc::new(point_cloud)).unwrap() 46 | } 47 | 48 | fn main() { 49 | let mut ct = build_tree(); 50 | ct.generate_summaries(); 51 | ct.refresh(); 52 | let ct_reader = ct.reader(); 53 | println!("Tree has {} nodes", ct_reader.node_count()); 54 | 55 | let start = time::Instant::now(); 56 | let bulk = BulkInterface::new(ct.reader()); 57 | let point_indexes = ct_reader.point_cloud().reference_indexes(); 58 | let tau = 0.05; 59 | let depths = bulk.known_path_and(&point_indexes, |reader, path| { 60 | if let Ok(path) = path { 61 | let mut homogenity_depth = path.len(); 62 | for (i, (_d, a)) in path.iter().enumerate() { 63 | let summ = reader.get_node_label_summary(*a).unwrap(); 64 | if summ.summary.items.len() == 1 { 65 | homogenity_depth = i; 66 | break; 67 | } 68 | let sum = summ.summary.items.iter().map(|(_, c)| c).sum::() as f32; 69 | let max = *summ.summary.items.iter().map(|(_, c)| c).max().unwrap() as f32; 70 | if 1.0 - max / sum < tau { 71 | homogenity_depth = i; 72 | break; 73 | } 74 | } 75 | (path.len(), homogenity_depth) 76 | } else { 77 | (0, 0) 78 | } 79 | }); 80 | 81 | let mut final_depths = HashMap::new(); 82 | for (f, h) in &depths { 83 | final_depths 84 | .entry(f - h) 85 | .and_modify(|c| *c += 1) 86 | 
.or_insert(1); 87 | } 88 | let mut keys: Vec = final_depths.keys().cloned().collect(); 89 | keys.sort(); 90 | println!("Final Depths:"); 91 | for k in keys { 92 | println!("{}: {:?}", k, final_depths.get(&k).unwrap()); 93 | } 94 | 95 | let elapse = start.elapsed().as_millis(); 96 | println!( 97 | "Time elapsed {:?} milliseconds, time per sequence {} milliseconds", 98 | elapse, 99 | (elapse as f64) / (depths.len() as f64) 100 | ); 101 | } 102 | -------------------------------------------------------------------------------- /pygoko/src/plugins.rs: -------------------------------------------------------------------------------- 1 | use goko::plugins::discrete::prelude::*; 2 | use goko::*; 3 | use numpy::PyArray1; 4 | use pointcloud::*; 5 | use pyo3::prelude::*; 6 | use pyo3::types::PyDict; 7 | 8 | /* 9 | pub #[derive(Debug)] 10 | struct PyBucketProbs { 11 | probs: BucketProbs 12 | } 13 | 14 | #[pymethods] 15 | impl PyBucketProbs { 16 | pub fn pdfs(&self) -> PyResult>> { 17 | Array1::from_shape_vec((dim,), m) 18 | .unwrap() 19 | .into_pyarray(py) 20 | .to_owned() 21 | } 22 | } 23 | */ 24 | 25 | #[pyclass(unsendable)] 26 | pub struct PyBayesCategoricalTracker { 27 | pub hkl: BayesCategoricalTracker>, 28 | pub tree: CoverTreeReader>, 29 | } 30 | 31 | #[pymethods] 32 | impl PyBayesCategoricalTracker { 33 | pub fn push(&mut self, point: &PyArray1) { 34 | let results = self 35 | .tree 36 | .path(&point.readonly().as_slice().unwrap()) 37 | .unwrap(); 38 | self.hkl.add_path(results); 39 | } 40 | 41 | pub fn print(&self) { 42 | println!("{:#?}", self.hkl); 43 | } 44 | 45 | pub fn probs(&self, node_address: (i32, usize)) -> Option<(Vec<((i32, usize), f64)>, f64)> { 46 | self.hkl.prob_vector(node_address) 47 | } 48 | 49 | pub fn evidence(&self, node_address: (i32, usize)) -> Option<(Vec<((i32, usize), f64)>, f64)> { 50 | self.hkl.evidence_prob_vector(node_address) 51 | } 52 | 53 | pub fn all_kl(&self) -> Vec<(f64, (i32, usize))> { 54 | self.hkl.all_node_kl() 55 | } 56 | 57 | pub fn kl_div(&self) -> f64 { 58 | self.hkl.kl_div() 59 | } 60 | 61 | pub fn stats(&self) -> PyResult { 62 | let stats = self.hkl.kl_div_stats(); 63 | let gil = pyo3::Python::acquire_gil(); 64 | let py = gil.python(); 65 | let dict = PyDict::new(py); 66 | dict.set_item("max", stats.max)?; 67 | dict.set_item("min", stats.min)?; 68 | dict.set_item("nz_count", stats.nz_count)?; 69 | dict.set_item("moment1_nz", stats.moment1_nz)?; 70 | dict.set_item("moment2_nz", stats.moment2_nz)?; 71 | dict.set_item("sequence_len", stats.sequence_len)?; 72 | Ok(dict.into()) 73 | } 74 | } 75 | 76 | #[pyclass(unsendable)] 77 | pub struct PyKLDivergenceBaseline { 78 | pub baseline: KLDivergenceBaseline, 79 | } 80 | 81 | #[pymethods] 82 | impl PyKLDivergenceBaseline { 83 | pub fn stats(&self, i: usize) -> PyResult { 84 | let stats = self.baseline.stats(i); 85 | let gil = pyo3::Python::acquire_gil(); 86 | let dict = PyDict::new(gil.python()); 87 | let max_dict = PyDict::new(gil.python()); 88 | max_dict.set_item("mean", stats.max.0)?; 89 | max_dict.set_item("var", stats.max.1)?; 90 | dict.set_item("max", max_dict)?; 91 | 92 | let min_dict = PyDict::new(gil.python()); 93 | min_dict.set_item("mean", stats.min.0)?; 94 | min_dict.set_item("var", stats.min.1)?; 95 | dict.set_item("min", min_dict)?; 96 | 97 | let nz_count_dict = PyDict::new(gil.python()); 98 | nz_count_dict.set_item("mean", stats.nz_count.0)?; 99 | nz_count_dict.set_item("var", stats.nz_count.1)?; 100 | dict.set_item("nz_count", nz_count_dict)?; 101 | 102 | let moment1_nz_dict = 
PyDict::new(gil.python()); 103 | moment1_nz_dict.set_item("mean", stats.moment1_nz.0)?; 104 | moment1_nz_dict.set_item("var", stats.moment1_nz.1)?; 105 | dict.set_item("moment1_nz", moment1_nz_dict)?; 106 | 107 | let moment2_nz_dict = PyDict::new(gil.python()); 108 | moment2_nz_dict.set_item("mean", stats.moment2_nz.0)?; 109 | moment2_nz_dict.set_item("var", stats.moment2_nz.1)?; 110 | dict.set_item("moment2_nz", moment2_nz_dict)?; 111 | Ok(dict.into()) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /pointcloud/src/data_sources/sparse_ram.rs: -------------------------------------------------------------------------------- 1 | use crate::pc_errors::PointCloudResult; 2 | use std::convert::TryInto; 3 | use std::marker::PhantomData; 4 | use crate::pc_errors::ParsingError; 5 | 6 | use crate::base_traits::*; 7 | use crate::metrics::*; 8 | use crate::points::*; 9 | 10 | /// The data stored in ram. 11 | #[derive(Debug)] 12 | pub struct SparseDataRam { 13 | name: String, 14 | values: Vec, 15 | col_index: Vec, 16 | row_index: Vec, 17 | dim: usize, 18 | metric: PhantomData, 19 | } 20 | 21 | impl SparseDataRam 22 | where 23 | CoefField: std::fmt::Debug + 'static, 24 | Index: std::fmt::Debug + 'static, 25 | { 26 | pub fn new( 27 | values: Vec, 28 | col_index: Vec, 29 | row_index: Vec, 30 | dim: usize, 31 | ) -> SparseDataRam { 32 | SparseDataRam:: { 33 | name: String::new(), 34 | values, 35 | col_index, 36 | row_index, 37 | dim, 38 | metric: PhantomData, 39 | } 40 | } 41 | } 42 | 43 | impl PointCloud for SparseDataRam 44 | where 45 | M: Metric>, 46 | { 47 | type PointRef<'a> = SparseRef<'a, f32, u32>; 48 | type Point = RawSparse; 49 | type Metric = L2; 50 | type LabelSummary = (); 51 | type Label = (); 52 | type MetaSummary = (); 53 | type Metadata = (); 54 | 55 | fn metadata(&self, _pn: usize) -> PointCloudResult> { 56 | Ok(None) 57 | } 58 | fn metasummary(&self, pns: &[usize]) -> PointCloudResult> { 59 | Ok(SummaryCounter { 60 | summary: (), 61 | nones: pns.len(), 62 | errors: 0, 63 | }) 64 | } 65 | fn label(&self, _pn: usize) -> PointCloudResult> { 66 | Ok(None) 67 | } 68 | fn label_summary(&self, pns: &[usize]) -> PointCloudResult> { 69 | Ok(SummaryCounter { 70 | summary: (), 71 | nones: pns.len(), 72 | errors: 0, 73 | }) 74 | } 75 | fn name(&self, pi: usize) -> PointCloudResult { 76 | Ok(pi.to_string()) 77 | } 78 | fn index(&self, pn: &str) -> PointCloudResult { 79 | pn.parse::().map_err(|_| ParsingError::RegularParsingError("Unable to parse your str into an usize").into()) 80 | } 81 | fn names(&self) -> Vec { 82 | (0..self.len()).map(|i| i.to_string()).collect() 83 | } 84 | 85 | /// The number of samples this cloud covers 86 | fn len(&self) -> usize { 87 | self.row_index.len() - 1 88 | } 89 | /// If this is empty 90 | fn is_empty(&self) -> bool { 91 | self.row_index.len() > 1 92 | } 93 | /// The dimension of the underlying data 94 | fn dim(&self) -> usize { 95 | self.dim 96 | } 97 | /// Indexes used for access 98 | fn reference_indexes(&self) -> Vec { 99 | (0..self.len()).collect() 100 | } 101 | /// Gets a point from this dataset 102 | fn point<'a, 'b: 'a>(&'b self, pn: usize) -> PointCloudResult> { 103 | let lower_bound = self.row_index[pn].try_into(); 104 | let upper_bound = self.row_index[pn + 1].try_into(); 105 | if let (Ok(lower_bound), Ok(upper_bound)) = (lower_bound, upper_bound) { 106 | let values = &self.values[lower_bound..upper_bound]; 107 | let indexes = &self.col_index[lower_bound..upper_bound]; 108 | Ok(SparseRef::new(self.dim, 
values, indexes)) 109 | } else { 110 | panic!("Could not covert a usize into a sparse dimension"); 111 | } 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /goko/src/plugins/gaussians/svd_gaussian.rs: -------------------------------------------------------------------------------- 1 | //! # Diagonal Gaussian 2 | //! 3 | //! This computes a coordinate bound multivariate Gaussian. This can be thought of as a rough 4 | //! simulation of the data underling a node. However we can chose the scale from which we 5 | //! simulate the data, down to the individual point, so this can be arbitrarily accurate. 6 | 7 | use super::*; 8 | use crate::covertree::node::CoverNode; 9 | use crate::covertree::CoverTreeReader; 10 | use crate::plugins::utils::*; 11 | 12 | use ndarray::prelude::*; 13 | use ndarray_linalg::svd::*; 14 | 15 | /// Node component, coded in such a way that it can be efficiently, recursively computed. 16 | #[derive(Debug, Clone, Default)] 17 | pub struct SvdGaussian { 18 | /// Mean of this gaussian 19 | pub mean: Array1, 20 | /// Second Moment 21 | pub vt: Array2, 22 | /// The singular values 23 | pub singular_vals: Array1, 24 | } 25 | /* 26 | impl ContinousDistribution for SvdGaussian { 27 | fn ln_pdf(&self, _point: &PointRef) -> Option { 28 | unimplemented!() 29 | } 30 | 31 | fn sample(&self, _rng: &mut R) -> Vec { 32 | unimplemented!() 33 | } 34 | 35 | fn kl_divergence(&self, _other: &SvdGaussian) -> Option { 36 | unimplemented!() 37 | } 38 | } 39 | */ 40 | impl SvdGaussian { 41 | /// Mean: `moment1/count` 42 | pub fn mean(&self) -> Array1 { 43 | self.mean.clone() 44 | } 45 | } 46 | 47 | impl NodePlugin for SvdGaussian {} 48 | 49 | /// Zero sized type that can be passed around. Equivilant to `()` 50 | #[derive(Debug, Clone)] 51 | pub struct GokoSvdGaussian { 52 | max_points: usize, 53 | min_points: usize, 54 | tau: f32, 55 | } 56 | 57 | impl GokoSvdGaussian { 58 | /// Specify the max number of points, and the min that you want to compute the SVD over, and the tau used for dimension calulations 59 | pub fn new(min_points: usize, max_points: usize, tau: f32) -> GokoSvdGaussian { 60 | GokoSvdGaussian { 61 | max_points, 62 | min_points, 63 | tau, 64 | } 65 | } 66 | } 67 | 68 | impl GokoPlugin for GokoSvdGaussian { 69 | type NodeComponent = SvdGaussian; 70 | fn prepare_tree(parameters: &Self, my_tree: &mut CoverTreeWriter) { 71 | my_tree.add_plugin::(GokoCoverageIndexes::restricted( 72 | parameters.max_points, 73 | )); 74 | my_tree.add_plugin::(GokoDiagGaussian::recursive()); 75 | } 76 | fn node_component( 77 | parameters: &Self, 78 | my_node: &CoverNode, 79 | my_tree: &CoverTreeReader, 80 | ) -> Option { 81 | if my_node.coverage_count() > parameters.min_points { 82 | let points = my_node.get_plugin_and::(|p| { 83 | my_tree 84 | .parameters() 85 | .point_cloud 86 | .points_dense_matrix(p.point_indexes()) 87 | .unwrap() 88 | }); 89 | if let Some(mut points) = points { 90 | let mean = my_node 91 | .get_plugin_and::(|p| { 92 | Array1::from_shape_vec((p.dim(),), p.mean()).unwrap() 93 | }) 94 | .unwrap(); 95 | for mut p in points.axis_iter_mut(Axis(0)) { 96 | p -= &mean; 97 | } 98 | 99 | let (_u, singular_vals, vt) = points.svd(false, true).unwrap(); 100 | let vt = vt.unwrap(); 101 | Some(SvdGaussian { 102 | singular_vals, 103 | vt, 104 | mean, 105 | }) 106 | } else { 107 | None 108 | } 109 | } else { 110 | None 111 | } 112 | } 113 | } 114 | /* 115 | #[cfg(test)] 116 | pub(crate) mod tests { 117 | use super::*; 118 | use 
crate::covertree::tests::build_basic_tree; 119 | } 120 | */ 121 | -------------------------------------------------------------------------------- /examples/graphistry_vis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | "file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": "3.8.5-final" 14 | }, 15 | "orig_nbformat": 2, 16 | "kernelspec": { 17 | "name": "python3", 18 | "display_name": "Python 3" 19 | } 20 | }, 21 | "nbformat": 4, 22 | "nbformat_minor": 2, 23 | "cells": [ 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import json\n", 31 | "import numpy as np\n", 32 | "import pandas as pd\n", 33 | "from pygoko import CoverTree\n", 34 | "with open(\"graphistry_creds.json\") as creds:\n", 35 | " creds = json.load(creds)\n", 36 | "import graphistry\n", 37 | "graphistry.register(api=3, username=creds[\"username\"], password=creds[\"password\"])\n", 38 | "\n", 39 | "import tensorflow as tf\n", 40 | "mnist = tf.keras.datasets.mnist\n", 41 | "\n", 42 | "(x_train, y_train),(x_test, y_test) = mnist.load_data()\n", 43 | "x_train, x_test = x_train / 255.0, x_test / 255.0\n", 44 | "x_train = x_train.astype(np.float32)\n", 45 | "y_train = y_train.astype(np.int64)\n", 46 | "x_train = x_train.reshape(-1, 28*28)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "tree = CoverTree()\n", 56 | "tree.set_leaf_cutoff(50)\n", 57 | "tree.set_scale_base(1.3)\n", 58 | "tree.set_min_res_index(-20)\n", 59 | "tree.fit(x_train,y_train)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "source = []\n", 69 | "destination = []\n", 70 | "weight = []\n", 71 | "depth = []\n", 72 | "\n", 73 | "node_colors = [0xFF000000, 0xFFFF0000, 0xFFFFFF00, 0x00FF0000, 0x0000FF00, 0xFF00FF00, 0x88000000, 0x88888800, 0x00880000, 0x00008800]\n", 74 | "\n", 75 | "nodes_centers = set()\n", 76 | "\n", 77 | "unvisited_nodes = [tree.root()]\n", 78 | "while len(unvisited_nodes) > 0: \n", 79 | " node = unvisited_nodes.pop()\n", 80 | " nodes_centers.add(node.address()[1])\n", 81 | " for child in node.children():\n", 82 | " source.append(node.address()[1])\n", 83 | " destination.append(child.address()[1])\n", 84 | " depth.append(node.address()[0] * 0xFF000000)\n", 85 | " weight.append(1.3**(-node.address()[0]))\n", 86 | " unvisited_nodes.append(child)\n", 87 | "\n", 88 | "edges = pd.DataFrame({\n", 89 | " 'source': source,\n", 90 | " 'destination': destination,\n", 91 | " 'weight': weight,\n", 92 | " 'depth': depth,\n", 93 | "})\n", 94 | "\n", 95 | "node_id = list(nodes_centers)\n", 96 | "label = [node_colors[y_train[i]] for i in node_id]\n", 97 | "\n", 98 | "nodes = pd.DataFrame({\n", 99 | " \"node_id\" : node_id,\n", 100 | " \"label\" : label\n", 101 | "})" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "URL_PARAMS = {'play': 5000, 'edgeCurvature': 0.1, 'precisionVsSpeed': -3}\n", 111 | "g = graphistry.nodes(nodes).edges(edges).bind(source='source', destination='destination', node=\"node_id\", 
point_color=\"label\").settings(url_params=URL_PARAMS)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "g.bind(edge_weight='weight').settings(url_params={**URL_PARAMS, 'expansionRatio': 40, 'edgeInfluence': 7}).plot(render=True)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [] 129 | } 130 | ] 131 | } -------------------------------------------------------------------------------- /pointcloud/src/metrics/l1_f32.rs: -------------------------------------------------------------------------------- 1 | //! f32 implementations of the L1 metric. 2 | 3 | use super::L1; 4 | use crate::base_traits::Metric; 5 | use crate::points::*; 6 | use packed_simd::*; 7 | use std::ops::Deref; 8 | 9 | impl Metric<[f32]> for L1 { 10 | fn dist(x: &[f32], y: &[f32]) -> f32 { 11 | l1_dense_f32(x.deref(), y.deref()).sqrt() 12 | } 13 | } 14 | 15 | impl<'a> Metric> for L1 { 16 | fn dist(x: &RawSparse, y: &RawSparse) -> f32 { 17 | l1_sparse_f32_f32(x.indexes(), x.values(), y.indexes(), y.values()).sqrt() 18 | } 19 | } 20 | 21 | impl<'a> Metric> for L1 { 22 | fn dist(x: &RawSparse, y: &RawSparse) -> f32 { 23 | l1_sparse_f32_f32(x.indexes(), x.values(), y.indexes(), y.values()).sqrt() 24 | } 25 | } 26 | 27 | impl<'a> Metric> for L1 { 28 | fn dist(x: &RawSparse, y: &RawSparse) -> f32 { 29 | l1_sparse_f32_f32(x.indexes(), x.values(), y.indexes(), y.values()).sqrt() 30 | } 31 | } 32 | 33 | /// 34 | pub fn l1_dense_f32(mut x: &[f32], mut y: &[f32]) -> f32 { 35 | let mut d_acc_16 = f32x16::splat(0.0); 36 | while y.len() > 16 { 37 | let y_simd = f32x16::from_slice_unaligned(y); 38 | let x_simd = f32x16::from_slice_unaligned(x); 39 | let diff = x_simd - y_simd; 40 | d_acc_16 += diff.abs(); 41 | y = &y[16..]; 42 | x = &x[16..]; 43 | } 44 | let mut d_acc_8 = f32x8::splat(0.0); 45 | if y.len() > 8 { 46 | let y_simd = f32x8::from_slice_unaligned(y); 47 | let x_simd = f32x8::from_slice_unaligned(x); 48 | let diff = x_simd - y_simd; 49 | d_acc_8 += diff.abs(); 50 | y = &y[8..]; 51 | x = &x[8..]; 52 | } 53 | let leftover = y 54 | .iter() 55 | .zip(x) 56 | .map(|(xi, yi)| (xi - yi).abs()) 57 | .fold(0.0, |acc, y| acc + y); 58 | leftover + d_acc_8.sum() + d_acc_16.sum() 59 | } 60 | 61 | /// 62 | #[inline] 63 | pub fn l1_norm_f32(mut x: &[f32]) -> f32 { 64 | let mut d_acc_16 = f32x16::splat(0.0); 65 | while x.len() > 16 { 66 | let x_simd = f32x16::from_slice_unaligned(x); 67 | d_acc_16 += x_simd.abs(); 68 | x = &x[16..]; 69 | } 70 | let mut d_acc_8 = f32x8::splat(0.0); 71 | if x.len() > 8 { 72 | let x_simd = f32x8::from_slice_unaligned(x); 73 | d_acc_8 += x_simd.abs(); 74 | x = &x[8..]; 75 | } 76 | let leftover = x.iter().map(|xi| xi.abs()).fold(0.0, |acc, xi| acc + xi); 77 | leftover + d_acc_8.sum() + d_acc_16.sum() 78 | } 79 | 80 | /// 81 | pub fn l1_sparse_f32_f32(x_ind: &[S], x_val: &[f32], y_ind: &[S], y_val: &[f32]) -> f32 82 | where 83 | S: Ord, 84 | { 85 | if x_val.is_empty() || y_val.is_empty() { 86 | if x_val.is_empty() && y_val.is_empty() { 87 | return 0.0; 88 | } 89 | if !x_val.is_empty() && y_val.is_empty() { 90 | l1_norm_f32(x_val) 91 | } else { 92 | l1_norm_f32(y_val) 93 | } 94 | } else { 95 | let mut total = 0.0; 96 | let (short_iter, mut long_iter) = if x_ind.len() > y_ind.len() { 97 | (y_ind.iter().zip(y_val), x_ind.iter().zip(x_val)) 98 | } else { 99 | (x_ind.iter().zip(x_val), y_ind.iter().zip(y_val)) 100 | }; 101 | 102 | 
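// Merge-join over the two index lists (both assumed sorted ascending): an index
// present in both vectors contributes |x_i - y_i|, while an index present in
// only one contributes the absolute value of that lone entry, exactly as if the
// missing coordinate were zero.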
let mut l_tr: Option<(&S, &f32)> = long_iter.next(); 103 | for (si, sv) in short_iter { 104 | while let Some((li, lv)) = l_tr { 105 | if li < si { 106 | total += lv.abs(); 107 | l_tr = long_iter.next(); 108 | } else { 109 | break; 110 | } 111 | } 112 | if let Some((li, lv)) = l_tr { 113 | if li == si { 114 | let val = sv - lv; 115 | total += val.abs(); 116 | l_tr = long_iter.next(); 117 | } else { 118 | total += sv.abs(); 119 | } 120 | } else { 121 | total += sv.abs(); 122 | } 123 | } 124 | while let Some((_li, lv)) = l_tr { 125 | total += lv.abs(); 126 | l_tr = long_iter.next(); 127 | } 128 | total 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /pointcloud/src/metrics/l2_f32.rs: -------------------------------------------------------------------------------- 1 | //! f32 implementations of the L1 metric. 2 | 3 | use super::L2; 4 | use crate::base_traits::Metric; 5 | use crate::points::*; 6 | use packed_simd::*; 7 | use std::ops::Deref; 8 | 9 | impl Metric<[f32]> for L2 { 10 | fn dist(x: &[f32], y: &[f32]) -> f32 { 11 | sq_l2_dense_f32(x.deref(), y.deref()).sqrt() 12 | } 13 | } 14 | 15 | impl<'a> Metric> for L2 { 16 | fn dist(x: &RawSparse, y: &RawSparse) -> f32 { 17 | sq_l2_sparse_f32_f32(x.indexes(), x.values(), y.indexes(), y.values()).sqrt() 18 | } 19 | } 20 | 21 | impl<'a> Metric> for L2 { 22 | fn dist(x: &RawSparse, y: &RawSparse) -> f32 { 23 | sq_l2_sparse_f32_f32(x.indexes(), x.values(), y.indexes(), y.values()).sqrt() 24 | } 25 | } 26 | 27 | impl<'a> Metric> for L2 { 28 | fn dist(x: &RawSparse, y: &RawSparse) -> f32 { 29 | sq_l2_sparse_f32_f32(x.indexes(), x.values(), y.indexes(), y.values()).sqrt() 30 | } 31 | } 32 | 33 | /// basic sparse function 34 | pub fn sq_l2_sparse_f32_f32(x_ind: &[S], x_val: &[f32], y_ind: &[S], y_val: &[f32]) -> f32 35 | where 36 | S: Ord, 37 | { 38 | if x_val.is_empty() || y_val.is_empty() { 39 | if x_val.is_empty() && y_val.is_empty() { 40 | return 0.0; 41 | } 42 | if !x_val.is_empty() && y_val.is_empty() { 43 | sq_l2_norm_f32(x_val) 44 | } else { 45 | sq_l2_norm_f32(y_val) 46 | } 47 | } else { 48 | let mut total = 0.0; 49 | let (short_iter, mut long_iter) = if x_ind.len() > y_ind.len() { 50 | (y_ind.iter().zip(y_val), x_ind.iter().zip(x_val)) 51 | } else { 52 | (x_ind.iter().zip(x_val), y_ind.iter().zip(y_val)) 53 | }; 54 | 55 | let mut l_tr: Option<(&S, &f32)> = long_iter.next(); 56 | for (si, sv) in short_iter { 57 | while let Some((li, lv)) = l_tr { 58 | if li < si { 59 | total += *lv * *lv; 60 | l_tr = long_iter.next(); 61 | } else { 62 | break; 63 | } 64 | } 65 | if let Some((li, lv)) = l_tr { 66 | if li == si { 67 | let val = *sv - *lv; 68 | total += val * val; 69 | l_tr = long_iter.next(); 70 | } else { 71 | total += *sv * *sv; 72 | } 73 | } else { 74 | total += *sv * *sv; 75 | } 76 | } 77 | while let Some((_li, lv)) = l_tr { 78 | total += *lv * *lv; 79 | l_tr = long_iter.next(); 80 | } 81 | total 82 | } 83 | } 84 | 85 | /// 86 | #[inline] 87 | pub fn sq_l2_dense_f32(mut x: &[f32], mut y: &[f32]) -> f32 { 88 | let mut d_acc_16 = f32x16::splat(0.0); 89 | while y.len() > 16 { 90 | let x_simd = f32x16::from_slice_unaligned(x); 91 | let y_simd = f32x16::from_slice_unaligned(y); 92 | let diff = x_simd - y_simd; 93 | d_acc_16 += diff * diff; 94 | y = &y[16..]; 95 | x = &x[16..]; 96 | } 97 | let mut d_acc_8 = f32x8::splat(0.0); 98 | if y.len() > 8 { 99 | let x_simd = f32x8::from_slice_unaligned(x); 100 | let y_simd = f32x8::from_slice_unaligned(y); 101 | let diff = x_simd - y_simd; 102 | d_acc_8 
+= diff * diff; 103 | y = &y[8..]; 104 | x = &x[8..]; 105 | } 106 | let leftover = y 107 | .iter() 108 | .zip(x) 109 | .map(|(xi, yi)| (xi - yi) * (xi - yi)) 110 | .fold(0.0, |acc, y| acc + y); 111 | leftover + d_acc_8.sum() + d_acc_16.sum() 112 | } 113 | 114 | /// 115 | #[inline] 116 | pub fn sq_l2_norm_f32(mut x: &[f32]) -> f32 { 117 | let mut d_acc_16 = f32x16::splat(0.0); 118 | while x.len() > 16 { 119 | let x_simd = f32x16::from_slice_unaligned(x); 120 | d_acc_16 += x_simd * x_simd; 121 | x = &x[16..]; 122 | } 123 | let mut d_acc_8 = f32x8::splat(0.0); 124 | if x.len() > 8 { 125 | let x_simd = f32x8::from_slice_unaligned(x); 126 | d_acc_8 += x_simd * x_simd; 127 | x = &x[8..]; 128 | } 129 | let leftover = x.iter().map(|xi| xi * xi).fold(0.0, |acc, xi| acc + xi); 130 | leftover + d_acc_8.sum() + d_acc_16.sum() 131 | } 132 | -------------------------------------------------------------------------------- /goko/src/plugins/labels.rs: -------------------------------------------------------------------------------- 1 | //! Plugin for labels and metadata 2 | 3 | use super::*; 4 | use crate::covertree::node::CoverNode; 5 | use crate::covertree::CoverTreeReader; 6 | //use pointcloud::*; 7 | use std::sync::Arc; 8 | 9 | /// Wrapper around the summary found in the point cloud 10 | #[derive(Debug, Default)] 11 | pub struct NodeLabelSummary { 12 | /// The summary object, refenced counted to eliminate duplicates 13 | pub summary: Arc>, 14 | } 15 | 16 | impl Clone for NodeLabelSummary { 17 | fn clone(&self) -> Self { 18 | NodeLabelSummary { 19 | summary: Arc::clone(&self.summary), 20 | } 21 | } 22 | } 23 | 24 | impl NodePlugin for NodeLabelSummary {} 25 | 26 | /// Plug in that allows for summaries of labels to be attached to 27 | #[derive(Debug, Clone, Default)] 28 | pub struct LabelSummaryPlugin {} 29 | 30 | impl GokoPlugin for LabelSummaryPlugin { 31 | type NodeComponent = NodeLabelSummary; 32 | fn node_component( 33 | _parameters: &Self, 34 | my_node: &CoverNode, 35 | my_tree: &CoverTreeReader, 36 | ) -> Option { 37 | let mut bucket = my_tree 38 | .parameters() 39 | .point_cloud 40 | .label_summary(my_node.singletons()) 41 | .unwrap(); 42 | // If we're a routing node then grab the childen's values 43 | if let Some((nested_scale, child_addresses)) = my_node.children() { 44 | my_tree.get_node_plugin_and::( 45 | (nested_scale, *my_node.center_index()), 46 | |p| bucket.combine(p.summary.as_ref()), 47 | ); 48 | 49 | for ca in child_addresses { 50 | my_tree.get_node_plugin_and::(*ca, |p| { 51 | bucket.combine(p.summary.as_ref()) 52 | }); 53 | } 54 | } else { 55 | bucket.add( 56 | my_tree 57 | .parameters() 58 | .point_cloud 59 | .label(*my_node.center_index()), 60 | ); 61 | } 62 | Some(NodeLabelSummary { 63 | summary: Arc::new(bucket), 64 | }) 65 | } 66 | } 67 | 68 | /// Wrapper around the summary found in the point cloud 69 | #[derive(Debug, Default)] 70 | pub struct NodeMetaSummary { 71 | /// The summary object, refenced counted to eliminate duplicates 72 | pub summary: Arc>, 73 | } 74 | 75 | impl Clone for NodeMetaSummary { 76 | fn clone(&self) -> Self { 77 | NodeMetaSummary { 78 | summary: Arc::clone(&self.summary), 79 | } 80 | } 81 | } 82 | 83 | impl NodePlugin for NodeMetaSummary {} 84 | 85 | /// Plug in that allows for summaries of Metas to be attached to 86 | #[derive(Debug, Clone, Default)] 87 | pub struct MetaSummaryPlugin {} 88 | 89 | impl GokoPlugin for MetaSummaryPlugin { 90 | type NodeComponent = NodeMetaSummary; 91 | fn node_component( 92 | _parameters: &Self, 93 | my_node: &CoverNode, 
94 | my_tree: &CoverTreeReader, 95 | ) -> Option { 96 | let mut bucket = my_tree 97 | .parameters() 98 | .point_cloud 99 | .metasummary(my_node.singletons()) 100 | .unwrap(); 101 | // If we're a routing node then grab the childen's values 102 | if let Some((nested_scale, child_addresses)) = my_node.children() { 103 | my_tree.get_node_plugin_and::( 104 | (nested_scale, *my_node.center_index()), 105 | |p| bucket.combine(p.summary.as_ref()), 106 | ); 107 | 108 | for ca in child_addresses { 109 | my_tree.get_node_plugin_and::(*ca, |p| { 110 | bucket.combine(p.summary.as_ref()) 111 | }); 112 | } 113 | } else { 114 | bucket.add( 115 | my_tree 116 | .parameters() 117 | .point_cloud 118 | .metadata(*my_node.center_index()), 119 | ); 120 | } 121 | Some(NodeMetaSummary { 122 | summary: Arc::new(bucket), 123 | }) 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /goko/src/covertree/query_tools/query_items.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to Elasticsearch B.V. under one or more contributor 3 | * license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright 5 | * ownership. Elasticsearch B.V. licenses this file to you under 6 | * the Apache License, Version 2.0 (the "License"); you may 7 | * not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | use crate::NodeAddress; 21 | use std::cmp::Ordering::{self, Less}; 22 | use std::f32; 23 | 24 | #[derive(Clone, Copy, Debug)] 25 | pub(crate) struct QueryAddress { 26 | pub(crate) min_dist: f32, 27 | pub(crate) dist_to_center: f32, 28 | pub(crate) address: NodeAddress, 29 | } 30 | 31 | impl PartialEq for QueryAddress { 32 | fn eq(&self, other: &QueryAddress) -> bool { 33 | other.address == self.address 34 | } 35 | } 36 | 37 | impl Eq for QueryAddress {} 38 | 39 | impl Ord for QueryAddress { 40 | fn cmp(&self, other: &QueryAddress) -> Ordering { 41 | self.partial_cmp(&other).unwrap_or(Ordering::Less) 42 | } 43 | } 44 | 45 | impl PartialOrd for QueryAddress { 46 | fn partial_cmp(&self, other: &QueryAddress) -> Option { 47 | // Backwards to make it a max heap. 
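// std's BinaryHeap pops its largest element, so comparing `other` against
// `self` inverts the ordering: the heap then yields the address with the
// smallest `min_dist` first, with the layer index and the distance to the
// center as successive tie-breakers.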
48 | match other 49 | .min_dist 50 | .partial_cmp(&self.min_dist) 51 | .unwrap_or(Ordering::Equal) 52 | { 53 | Ordering::Greater => Some(Ordering::Greater), 54 | Ordering::Less => Some(Ordering::Less), 55 | Ordering::Equal => match other.address.0.cmp(&self.address.0) { 56 | Ordering::Greater => Some(Ordering::Greater), 57 | Ordering::Less => Some(Ordering::Less), 58 | Ordering::Equal => other.dist_to_center.partial_cmp(&self.dist_to_center), 59 | }, 60 | } 61 | } 62 | } 63 | 64 | #[derive(Clone, Copy, Debug)] 65 | pub(crate) struct QueryAddressRev { 66 | pub(crate) min_dist: f32, 67 | pub(crate) dist_to_center: f32, 68 | pub(crate) address: NodeAddress, 69 | } 70 | 71 | impl PartialEq for QueryAddressRev { 72 | fn eq(&self, other: &QueryAddressRev) -> bool { 73 | other.address == self.address 74 | } 75 | } 76 | 77 | impl Eq for QueryAddressRev {} 78 | 79 | impl Ord for QueryAddressRev { 80 | fn cmp(&self, other: &QueryAddressRev) -> Ordering { 81 | self.partial_cmp(&other).unwrap_or(Ordering::Less) 82 | } 83 | } 84 | 85 | impl PartialOrd for QueryAddressRev { 86 | fn partial_cmp(&self, other: &QueryAddressRev) -> Option<Ordering> { 87 | // Backwards to make it a max heap. 88 | match self 89 | .min_dist 90 | .partial_cmp(&other.min_dist) 91 | .unwrap_or(Ordering::Equal) 92 | { 93 | Ordering::Greater => Some(Ordering::Greater), 94 | Ordering::Less => Some(Ordering::Less), 95 | Ordering::Equal => match self.address.0.cmp(&other.address.0) { 96 | Ordering::Greater => Some(Ordering::Greater), 97 | Ordering::Less => Some(Ordering::Less), 98 | Ordering::Equal => self.dist_to_center.partial_cmp(&other.dist_to_center), 99 | }, 100 | } 101 | } 102 | } 103 | 104 | #[derive(Clone, Copy, Debug)] 105 | pub(crate) struct QuerySingleton { 106 | pub(crate) dist: f32, 107 | pub(crate) index: usize, 108 | } 109 | 110 | impl QuerySingleton { 111 | pub(crate) fn new(index: usize, dist: f32) -> QuerySingleton { 112 | QuerySingleton { dist, index } 113 | } 114 | } 115 | 116 | impl PartialEq for QuerySingleton { 117 | fn eq(&self, other: &QuerySingleton) -> bool { 118 | other.index == self.index 119 | } 120 | } 121 | 122 | impl Eq for QuerySingleton {} 123 | 124 | impl Ord for QuerySingleton { 125 | fn cmp(&self, other: &QuerySingleton) -> Ordering { 126 | self.partial_cmp(&other).unwrap_or(Less) 127 | } 128 | } 129 | 130 | impl PartialOrd for QuerySingleton { 131 | fn partial_cmp(&self, other: &QuerySingleton) -> Option<Ordering> { 132 | self.dist.partial_cmp(&other.dist) 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /serve_goko/src/api/tracker.rs: -------------------------------------------------------------------------------- 1 | use pointcloud::*; 2 | use goko::{NodeAddress, CoverTreeReader}; 3 | use goko::plugins::discrete::tracker::BayesCategoricalTracker; 4 | use crate::core::internal_service::*; 5 | use goko::errors::GokoError; 6 | use std::ops::Deref; 7 | 8 | use serde::{Deserialize, Serialize}; 9 | use std::collections::HashMap; 10 | 11 | use super::{TrackingRequest, TrackingRequestChoice, TrackingResponse}; 12 | 13 | #[derive(Deserialize, Serialize)] 14 | pub struct TrackPointRequest { 15 | pub point: T, 16 | } 17 | 18 | #[derive(Deserialize, Serialize)] 19 | pub struct TrackPathRequest { 20 | pub path: Vec<(f32, NodeAddress)>, 21 | } 22 | 23 | #[derive(Deserialize, Serialize)] 24 | pub struct TrackPathResponse { 25 | pub success: bool, 26 | } 27 | 28 | #[derive(Deserialize, Serialize)] 29 | pub struct AddTrackerRequest { 30 | pub window_size: usize, 31 | } 32 | 
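// Illustrative sketch, not part of the original source: with the plain serde
// derives above, these request/response types travel as JSON objects keyed by
// field name, so a client can build them like this (`serde_json` assumed to
// be available):
//
//     let req = AddTrackerRequest { window_size: 5000 };
//     assert_eq!(serde_json::to_string(&req).unwrap(), r#"{"window_size":5000}"#);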
#[derive(Deserialize, Serialize)] 33 | pub struct AddTrackerResponse { 34 | pub success: bool, 35 | } 36 | 37 | #[derive(Deserialize, Serialize)] 38 | pub struct CurrentStatsRequest { 39 | pub window_size: usize, 40 | } 41 | 42 | #[derive(Deserialize, Serialize)] 43 | pub struct CurrentStatsResponse { 44 | pub kl_div: f64, 45 | pub max: f64, 46 | pub min: f64, 47 | pub nz_count: u64, 48 | pub moment1_nz: f64, 49 | pub moment2_nz: f64, 50 | pub sequence_len: usize, 51 | } 52 | 53 | 54 | pub struct TrackerWorker { 55 | reader: CoverTreeReader, 56 | trackers: HashMap>, 57 | } 58 | 59 | impl TrackerWorker { 60 | pub fn new(reader: CoverTreeReader) -> TrackerWorker { 61 | TrackerWorker { 62 | reader, 63 | trackers: HashMap::new(), 64 | } 65 | } 66 | 67 | pub(crate) fn operator + Send + Sync + 'static>(reader: CoverTreeReader) -> InternalServiceOperator, TrackingResponse> { 68 | let worker = TrackerWorker { 69 | reader, 70 | trackers: HashMap::new(), 71 | }; 72 | InternalServiceOperator::new(worker) 73 | } 74 | } 75 | 76 | impl + Send + Sync> InternalService, TrackingResponse> for TrackerWorker { 77 | fn process(&mut self, request: TrackingRequest) -> Result { 78 | use TrackingRequestChoice::*; 79 | match request.request { 80 | TrackPoint(req) => { 81 | let path = self.reader.path(&req.point)?; 82 | for tracker in self.trackers.values_mut() { 83 | tracker.add_path(path.clone()); 84 | } 85 | 86 | Ok(TrackingResponse::TrackPath(TrackPathResponse { 87 | success: !self.trackers.is_empty(), 88 | })) 89 | } 90 | TrackPath(req) => { 91 | for tracker in self.trackers.values_mut() { 92 | tracker.add_path(req.path.clone()); 93 | } 94 | Ok(TrackingResponse::TrackPath(TrackPathResponse { 95 | success: true, 96 | })) 97 | } 98 | AddTracker(req) => { 99 | if self.trackers.contains_key(&req.window_size) { 100 | Ok(TrackingResponse::AddTracker(AddTrackerResponse { 101 | success: false, 102 | })) 103 | } else { 104 | self.trackers.insert(req.window_size, BayesCategoricalTracker::new(req.window_size, self.reader.clone())); 105 | Ok(TrackingResponse::AddTracker(AddTrackerResponse { 106 | success: true, 107 | })) 108 | } 109 | } 110 | CurrentStats(req) => { 111 | if let Some(tracker) = self.trackers.get(&req.window_size) { 112 | let stats = tracker.kl_div_stats(); 113 | let kl_div = tracker.kl_div(); 114 | Ok(TrackingResponse::CurrentStats(CurrentStatsResponse { 115 | kl_div, 116 | max: stats.max, 117 | min: stats.min, 118 | nz_count: stats.nz_count, 119 | moment1_nz: stats.moment1_nz, 120 | moment2_nz: stats.moment2_nz, 121 | sequence_len: stats.sequence_len, 122 | })) 123 | } else { 124 | Ok(TrackingResponse::Unknown(request.tracker_name.clone(),Some(req.window_size))) 125 | } 126 | } 127 | } 128 | } 129 | } -------------------------------------------------------------------------------- /goko/src/plugins/mod.rs: -------------------------------------------------------------------------------- 1 | //! # Plugin System 2 | //! 3 | //! To implement a plugin you need to write 2 components one which implements `NodePlugin` and another that implements `TreePlugin`. 4 | //! Finally you need to create an object that implements the parent trait that glues the two objects together. 5 | //! 6 | //! The `NodePlugin` is attached to each node. It is created by the `node_component` function in the parent trait when the plugin is 7 | //! attached to the tree. It can access the `TreePlugin` component, and the tree. These are created recursively, so you can access the 8 | //! plugin for the child nodes. 9 | //! 10 | //! 
None of this is parallelized yet. We need to move to Tokio to take advantage of its async computation and parallelize this. 11 | 12 | use crate::covertree::node::CoverNode; 13 | use crate::covertree::CoverTreeReader; 14 | use crate::*; 15 | use std::fmt::Debug; 16 | use type_map::concurrent::TypeMap; 17 | 18 | pub mod discrete; 19 | pub mod gaussians; 20 | pub mod labels; 21 | pub mod utils; 22 | 23 | /// Mockup for the plugin interface attached to the node. These are meant to be functions that Goko uses to maintain the plugin. 24 | pub trait NodePlugin: Send + Sync + Debug {} 25 | 26 | /// Parent trait that makes this all work. Ideally this should be included in the `TreePlugin` but rust doesn't like it. 27 | pub trait GokoPlugin: Send + Sync + Debug + Clone + 'static { 28 | /// The node component of this plugin; these are attached to each node recursively when the plugin is attached to the tree. 29 | type NodeComponent: NodePlugin + Clone + 'static; 30 | /// This is called just before we build the tree to prepare it for the upcoming plugin creations. 31 | fn prepare_tree(_parameters: &Self, _my_tree: &mut CoverTreeWriter) {} 32 | /// The function that actually builds the node components. 33 | fn node_component( 34 | parameters: &Self, 35 | my_node: &CoverNode, 36 | my_tree: &CoverTreeReader, 37 | ) -> Option<Self::NodeComponent>; 38 | } 39 | 40 | pub(crate) type NodePluginSet = TypeMap; 41 | pub(crate) type TreePluginSet = TypeMap; 42 | 43 | #[cfg(test)] 44 | pub(crate) mod tests { 45 | use super::*; 46 | use crate::covertree::tests::build_basic_tree; 47 | 48 | #[derive(Debug, Clone)] 49 | struct DumbNode1 { 50 | id: u32, 51 | pi: usize, 52 | cover_count: usize, 53 | } 54 | 55 | impl NodePlugin for DumbNode1 {} 56 | 57 | #[derive(Debug, Clone)] 58 | struct DumbGoko1 { 59 | id: u32, 60 | } 61 | 62 | impl GokoPlugin for DumbGoko1 { 63 | type NodeComponent = DumbNode1; 64 | fn node_component( 65 | parameters: &Self, 66 | my_node: &CoverNode, 67 | my_tree: &CoverTreeReader, 68 | ) -> Option<Self::NodeComponent> { 69 | println!( 70 | "Building Dumb Plugin for {:?}", 71 | (my_node.scale_index(), my_node.center_index()) 72 | ); 73 | let cover_count = match my_node.children() { 74 | None => my_node.singletons_len(), 75 | Some((nested_scale, child_addresses)) => { 76 | println!( 77 | "trying to get at the nodes at {:?}", 78 | (nested_scale, child_addresses) 79 | ); 80 | let mut cover_count = my_tree 81 | .get_node_plugin_and::( 82 | (nested_scale, *my_node.center_index()), 83 | |p| p.cover_count, 84 | ) 85 | .unwrap(); 86 | for ca in child_addresses { 87 | cover_count += my_tree 88 | .get_node_plugin_and::(*ca, |p| { 89 | p.cover_count 90 | }) 91 | .unwrap(); 92 | } 93 | cover_count 94 | } 95 | }; 96 | Some(DumbNode1 { 97 | id: parameters.id, 98 | pi: *my_node.center_index(), 99 | cover_count, 100 | }) 101 | } 102 | } 103 | 104 | #[test] 105 | fn dumb_plugins() { 106 | let d = DumbGoko1 { id: 1 }; 107 | let mut tree = build_basic_tree(); 108 | tree.add_plugin::(d); 109 | println!("{:?}", tree.reader().len()); 110 | for (si, layer) in tree.reader().layers() { 111 | println!("Scale Index: {:?}", si); 112 | layer.for_each_node(|pi, n| { 113 | println!("Node: {:?}", n); 114 | n.get_plugin_and::(|dp| { 115 | println!("DumbNodes: {:?}", dp); 116 | assert_eq!(*pi, dp.pi); 117 | }); 118 | }); 119 | } 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /serve_goko/src/errors.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::fmt; 3
| use goko::errors::GokoError; 4 | use tokio::sync::oneshot; 5 | 6 | pub enum InternalServiceError { 7 | Other(GokoError), 8 | FailedSend, 9 | FailedRecv, 10 | FailedRespSend, 11 | DoubleRead, 12 | ClientDropped, 13 | } 14 | 15 | impl From<oneshot::error::RecvError> for InternalServiceError { 16 | fn from(_e: oneshot::error::RecvError) -> InternalServiceError { 17 | InternalServiceError::FailedRecv 18 | } 19 | } 20 | 21 | impl From<GokoError> for InternalServiceError { 22 | fn from(e: GokoError) -> InternalServiceError { 23 | InternalServiceError::Other(e) 24 | } 25 | } 26 | 27 | impl fmt::Display for InternalServiceError { 28 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 29 | use InternalServiceError::*; 30 | match *self { 31 | FailedSend => f.pad("Send Failed"), 32 | Other(ref se) => fmt::Display::fmt(&se, f), 33 | FailedRecv => f.pad("Recv Failed"), 34 | FailedRespSend => f.pad("Unable to Respond, client hung up."), 35 | DoubleRead => f.pad("Attempted to read a message twice"), 36 | ClientDropped => f.pad("Client Dropped"), 37 | } 38 | } 39 | } 40 | 41 | impl fmt::Debug for InternalServiceError { 42 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 43 | use InternalServiceError::*; 44 | match *self { 45 | FailedSend => f.pad("SendFailed"), 46 | Other(ref se) => write!(f, "Other({:?})", se), 47 | FailedRecv => f.pad("RecvFailed"), 48 | FailedRespSend => f.pad("FailedRespSend"), 49 | DoubleRead => f.pad("DoubleRead"), 50 | ClientDropped => f.pad("ClientDropped"), 51 | } 52 | } 53 | } 54 | 55 | impl Error for InternalServiceError { 56 | fn source(&self) -> Option<&(dyn Error + 'static)> { 57 | use InternalServiceError::*; 58 | match *self { 59 | FailedSend => None, 60 | Other(ref e) => e.source(), 61 | FailedRecv => None, 62 | FailedRespSend => None, 63 | DoubleRead => None, 64 | ClientDropped => None, 65 | } 66 | } 67 | } 68 | 69 | //use serde::{Deserialize, Serialize}; 70 | // 71 | pub enum GokoClientError { 72 | Underlying(InternalServiceError), 73 | MalformedQuery(&'static str), 74 | Http(hyper::Error), 75 | Parse(Box), 76 | MissingBody, 77 | } 78 | 79 | impl GokoClientError { 80 | pub fn parse(err: Box) -> Self { 81 | GokoClientError::Parse(err) 82 | } 83 | } 84 | 85 | impl From<GokoError> for GokoClientError { 86 | fn from(e: GokoError) -> GokoClientError { 87 | GokoClientError::Underlying(InternalServiceError::Other(e)) 88 | } 89 | } 90 | 91 | impl From<oneshot::error::RecvError> for GokoClientError { 92 | fn from(_e: oneshot::error::RecvError) -> GokoClientError { 93 | GokoClientError::Underlying(InternalServiceError::FailedRecv) 94 | } 95 | } 96 | 97 | impl From<InternalServiceError> for GokoClientError { 98 | fn from(e: InternalServiceError) -> GokoClientError { 99 | GokoClientError::Underlying(e) 100 | } 101 | } 102 | 103 | impl From<hyper::Error> for GokoClientError { 104 | fn from(e: hyper::Error) -> GokoClientError { 105 | GokoClientError::Http(e) 106 | } 107 | } 108 | 109 | impl fmt::Display for GokoClientError { 110 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 111 | match *self { 112 | GokoClientError::Underlying(ref se) => fmt::Display::fmt(se, f), 113 | GokoClientError::MalformedQuery(ref se) => fmt::Display::fmt(se, f), 114 | GokoClientError::Http(ref se) => fmt::Display::fmt(se, f), 115 | GokoClientError::Parse(ref se) => fmt::Display::fmt(se, f), 116 | GokoClientError::MissingBody => f.pad("Body Missing"), 117 | } 118 | } 119 | } 120 | 121 | impl fmt::Debug for GokoClientError { 122 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 123 | match *self { 124 | GokoClientError::Underlying(ref se) => write!(f, "Underlying({:?})",
se), 125 | GokoClientError::MalformedQuery(ref se) => write!(f, "MalformedQuery({:?})", se), 126 | GokoClientError::Http(ref se) => write!(f, "Http({:?})", se), 127 | GokoClientError::Parse(ref se) => write!(f, "Parse({:?})", se), 128 | GokoClientError::MissingBody => f.pad("MissingBody"), 129 | } 130 | } 131 | } 132 | 133 | impl Error for GokoClientError { 134 | fn source(&self) -> Option<&(dyn Error + 'static)> { 135 | match *self { 136 | GokoClientError::Underlying(ref se) => Some(se), 137 | GokoClientError::Http(ref se) => Some(se), 138 | GokoClientError::Parse(ref se) => se.source(), 139 | GokoClientError::MalformedQuery(_) => None, 140 | GokoClientError::MissingBody => None, 141 | } 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /serve_goko/src/core/internal_service.rs: -------------------------------------------------------------------------------- 1 | use tokio::sync::{mpsc, oneshot}; 2 | use pin_project::pin_project; 3 | use goko::errors::GokoError; 4 | use core::task::Context; 5 | use std::future::Future; 6 | use std::pin::Pin; 7 | use std::task::Poll; 8 | 9 | use crate::errors::*; 10 | 11 | use std::sync::{atomic, Arc, Mutex}; 12 | 13 | pub(crate) type CoreRequestSender = mpsc::UnboundedSender>; 14 | pub(crate) type CoreRequestReceiver = mpsc::UnboundedReceiver>; 15 | pub(crate) type CoreResponseSender = oneshot::Sender>; 16 | pub(crate) type CoreResponseReceiver = oneshot::Receiver>; 17 | 18 | 19 | #[pin_project] 20 | pub(crate) struct Message { 21 | pub(crate) request: Option, 22 | pub(crate) reply: Option>, 23 | pub(crate) global_error: Arc>>>, 24 | } 25 | 26 | impl Message { 27 | pub(crate) fn request(&mut self) -> Option { 28 | self.request.take() 29 | } 30 | 31 | pub(crate) fn respond(&mut self, response: Result) { 32 | match self.reply.take() { 33 | Some(reply) => { 34 | match reply.send(response.map_err(|e| InternalServiceError::from(e))) { 35 | Ok(_) => (), 36 | Err(_) => { 37 | *self.global_error.lock().unwrap() = Some(Box::new(InternalServiceError::FailedRespSend)); 38 | } 39 | } 40 | } 41 | None => *self.global_error.lock().unwrap() = Some(Box::new(InternalServiceError::DoubleRead)), 42 | } 43 | } 44 | pub(crate) fn error(&mut self, error: impl std::error::Error + Send + 'static) { 45 | *self.global_error.lock().unwrap() = Some(Box::new(error)); 46 | } 47 | } 48 | 49 | #[pin_project] 50 | pub struct ResponseFuture { 51 | #[pin] 52 | pub(crate) response: oneshot::Receiver>, 53 | pub(crate) flight_counter: Arc, 54 | pub(crate) error: Option, 55 | } 56 | 57 | impl Future for ResponseFuture { 58 | type Output = Result; 59 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 60 | let this = self.project(); 61 | if let Some(err) = this.error.take() { 62 | return core::task::Poll::Ready(Err(err)); 63 | } 64 | else { 65 | let res = this.response.poll(cx).map(|r| { 66 | match r { 67 | Ok(r) => r.map_err(|e| e.into()), 68 | Err(e) => Err(e.into()) 69 | } 70 | }); 71 | this.flight_counter.fetch_sub(1, atomic::Ordering::SeqCst); 72 | res 73 | } 74 | } 75 | } 76 | 77 | pub trait InternalService<T, R>: Send { 78 | fn process(&mut self, request: T) -> Result<R, GokoError>; 79 | } 80 | 81 | #[derive(Clone)] 82 | pub(crate) struct InternalServiceOperator { 83 | in_flight: Arc, 84 | request_snd: CoreRequestSender, 85 | global_error: Arc>>>, 86 | } 87 | 88 | impl InternalServiceOperator { 89 | 90 | pub(crate) fn new + 'static>(mut server: P) -> InternalServiceOperator { 91 | let (request_snd, mut request_rcv): (CoreRequestSender,
CoreRequestReceiver) = 92 | mpsc::unbounded_channel(); 93 | tokio::spawn(async move { 94 | while let Some(mut msg) = request_rcv.recv().await { 95 | if let Some(request) = msg.request() { 96 | let response = server.process(request); 97 | msg.respond(response); 98 | } else { 99 | msg.error(InternalServiceError::DoubleRead) 100 | } 101 | } 102 | }); 103 | let global_error = Arc::new(Mutex::new(None)); 104 | let in_flight = Arc::new(atomic::AtomicU32::new(0)); 105 | InternalServiceOperator { 106 | in_flight, 107 | request_snd, 108 | global_error, 109 | } 110 | } 111 | 112 | 113 | pub(crate) fn message(&self, request: T) -> ResponseFuture { 114 | let flight_counter = Arc::clone(&self.in_flight); 115 | self.in_flight.fetch_add(1, atomic::Ordering::SeqCst); 116 | let (reply, response): (CoreResponseSender, CoreResponseReceiver) = oneshot::channel(); 117 | 118 | let msg = Message { 119 | request: Some(request), 120 | reply: Some(reply), 121 | global_error: Arc::clone(&self.global_error), 122 | }; 123 | 124 | let error = self.request_snd.send(msg).err().map(|_e| InternalServiceError::FailedSend); 125 | ResponseFuture { 126 | response, 127 | flight_counter, 128 | error, 129 | } 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /goko/src/errors.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to Elasticsearch B.V. under one or more contributor 3 | * license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright 5 | * ownership. Elasticsearch B.V. licenses this file to you under 6 | * the Apache License, Version 2.0 (the "License"); you may 7 | * not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | //! The errors that can occur when a cover tree is loading, working or saving. 21 | //! Most errors are floated up from `PointCloud` as that's the i/o layer. 22 | 23 | use pointcloud::pc_errors::PointCloudError; 24 | use protobuf::ProtobufError; 25 | use std::error::Error; 26 | use std::fmt; 27 | use std::io; 28 | use std::str; 29 | 30 | /// Helper type for a call that could go wrong. 31 | pub type GokoResult<T> = Result<T, GokoError>; 32 | 33 | /// Error type for Goko. Mostly this is a wrapper around `PointCloudError`, as the data i/o layer is where most errors happen. 34 | #[derive(Debug)] 35 | pub enum GokoError { 36 | /// Unable to retrieve some data point (given by index) in a file (slice name) 37 | PointCloudError(PointCloudError), 38 | /// Most common error, the given point index isn't present in the training data 39 | IndexNotInTree(usize), 40 | /// Parsing error when loading a protobuf file 41 | ProtobufError(ProtobufError), 42 | /// An I/O error when reading or writing a file 43 | IoError(io::Error), 44 | /// The probability distribution you are trying to sample from is invalid, probably because it was inferred from 0 points.
45 | InvalidProbDistro, 46 | /// Inserted a nested node into a node that already had a nested child 47 | DoubleNest, 48 | /// Inserted a node before you changed it from a leaf node into a normal node. Insert the nested child first. 49 | InsertBeforeNest, 50 | } 51 | 52 | impl fmt::Display for GokoError { 53 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 54 | match *self { 55 | GokoError::PointCloudError(ref e) => write!(f, "{}", e), 56 | GokoError::ProtobufError(ref e) => write!(f, "{}", e), 57 | GokoError::IoError(ref e) => write!(f, "{}", e), 58 | GokoError::IndexNotInTree { .. } => { 59 | write!(f, "the given index is not present in the tree") 60 | } 61 | GokoError::DoubleNest => write!( 62 | f, 63 | "Inserted a nested node into a node that already had a nested child" 64 | ), 65 | GokoError::InvalidProbDistro => write!( 66 | f, 67 | "The probability distribution you are trying to sample from is invalid, probably because it was inferred from 0 points." 68 | ), 69 | GokoError::InsertBeforeNest => write!( 70 | f, 71 | "Inserted a node into a node that does not have a nested child" 72 | ), 73 | } 74 | } 75 | } 76 | 77 | #[allow(deprecated)] 78 | impl Error for GokoError { 79 | fn description(&self) -> &str { 80 | match *self { 81 | GokoError::PointCloudError(ref e) => e.description(), 82 | GokoError::ProtobufError(ref e) => e.description(), 83 | GokoError::IoError(ref e) => e.description(), 84 | GokoError::IndexNotInTree { .. } => { 85 | "the given index is not present in the tree" 86 | } 87 | GokoError::DoubleNest => { 88 | "Inserted a nested node into a node that already had a nested child" 89 | } 90 | GokoError::InsertBeforeNest => { 91 | "Inserted a node into a node that does not have a nested child" 92 | } 93 | GokoError::InvalidProbDistro => { 94 | "The probability distribution you are trying to sample from is invalid, probably because it was inferred from 0 points." 95 | } 96 | } 97 | } 98 | 99 | fn cause(&self) -> Option<&dyn Error> { 100 | match *self { 101 | GokoError::PointCloudError(ref e) => Some(e), 102 | GokoError::ProtobufError(ref e) => Some(e), 103 | GokoError::IoError(ref e) => Some(e), 104 | GokoError::IndexNotInTree { ..
} => None, 105 | GokoError::DoubleNest => None, 106 | GokoError::InsertBeforeNest => None, 107 | GokoError::InvalidProbDistro => None, 108 | } 109 | } 110 | } 111 | 112 | impl From<PointCloudError> for GokoError { 113 | fn from(err: PointCloudError) -> Self { 114 | GokoError::PointCloudError(err) 115 | } 116 | } 117 | 118 | impl From<ProtobufError> for GokoError { 119 | fn from(err: ProtobufError) -> Self { 120 | GokoError::ProtobufError(err) 121 | } 122 | } 123 | 124 | impl From<io::Error> for GokoError { 125 | fn from(err: io::Error) -> Self { 126 | GokoError::IoError(err) 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /examples/mnist_knn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | "file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": "3.8.5-final" 14 | }, 15 | "orig_nbformat": 2, 16 | "kernelspec": { 17 | "name": "python3", 18 | "display_name": "Python 3" 19 | } 20 | }, 21 | "nbformat": 4, 22 | "nbformat_minor": 2, 23 | "cells": [ 24 | { 25 | "source": [ 26 | "# Basic MNIST Example\n", 27 | "\n", 28 | "This basic example shows loading from a YAML file. You can specify all the parameters in the yaml file, but we're going to load the raw data using tensorflow.\n" 29 | ], 30 | "cell_type": "markdown", 31 | "metadata": {} 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "import numpy as np\n", 40 | "import tensorflow as tf\n", 41 | "from pygoko import CoverTree\n", 42 | "mnist = tf.keras.datasets.mnist\n", 43 | "\n", 44 | "(x_train, y_train),(x_test, y_test) = mnist.load_data()\n", 45 | "x_train, x_test = x_train / 255.0, x_test / 255.0\n", 46 | "x_train = x_train.astype(np.float32)\n", 47 | "y_train = y_train.astype(np.int64)\n", 48 | "x_train = x_train.reshape(-1, 28*28)" 49 | ] 50 | }, 51 | { 52 | "source": [ 53 | "Here we build the covertree, with a leaf cutoff and a minimum resolution index to control the size of the tree. \n", 54 | "\n", 55 | "The minimum resolution index is the scale at which the tree stops splitting. This can be used to control the L2 error (we use the standard fast implementation, which is not the most accurate), or to specify a scale at which the KNN doesn't matter to you. \n", 56 | "\n", 57 | "The leaf cutoff controls the size of individual leaves of the tree. If a node covers fewer than this number of points, the splitting stops and the node becomes a leaf. \n", 58 | "\n", 59 | "The scale base controls the down-step of each split. 1.3 is a good default. It is usually close to the fastest at creating the tree but can be hard to reason about. Another popular choice is 2, which means the radius halves at each step. " 60 | ], 61 | "cell_type": "markdown", 62 | "metadata": {} 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "tree = CoverTree()\n", 71 | "tree.set_leaf_cutoff(10)\n", 72 | "tree.set_scale_base(1.3)\n", 73 | "tree.set_min_res_index(-20)\n", 74 | "tree.fit(x_train,y_train)" 75 | ] 76 | }, 77 | { 78 | "source": [ 79 | "Here's the basic KNN for this data structure. 
" 80 | ], 81 | "cell_type": "markdown", 82 | "metadata": {} 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "point = np.zeros([784], dtype=np.float32)\n", 91 | "tree.knn(point,5)" 92 | ] 93 | }, 94 | { 95 | "source": [ 96 | "The nodes are addressable by specifying the scale index, and the point index (in the originating dataset). This errors out if you supply an address that isn't known tot he tree. (Currently this is rust panicing about you unwrapping an option that is a None). Only use this creation method with known, correct, addresses." 97 | ], 98 | "cell_type": "markdown", 99 | "metadata": {} 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "root = tree.root()\n", 108 | "print(f\"Root address: {root.address()}\")\n", 109 | "for child in root.children():\n", 110 | " child_address = child.address()\n", 111 | " # The following is the same node as the child:\n", 112 | " copy_of_child = tree.node(child_address)\n", 113 | " print(f\" Child address: {copy_of_child.address()}\")" 114 | ] 115 | }, 116 | { 117 | "source": [ 118 | "If a query point were to belong to the dataset that the tree was constructed from, but was never selected as a routing node, then it would end up at a particular leaf node. This leaf node is deterministic (given the pre-built tree). The path for the query point is the addresses of the nodes from the root node to this leaf." 119 | ], 120 | "cell_type": "markdown", 121 | "metadata": {} 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "path = tree.path(point)\n", 130 | "print(path)\n", 131 | "\n", 132 | "print(\"Summary of the labels of points covered by the node at address\")\n", 133 | "for dist, address in path:\n", 134 | " node = tree.node(address)\n", 135 | " label_summary = node.label_summary()\n", 136 | " print(f\"Address: {address} \\t Summary: {label_summary}\")" 137 | ] 138 | }, 139 | { 140 | "source": [ 141 | "We can also query for the path of known points, by index in the original dataset." 142 | ], 143 | "cell_type": "markdown", 144 | "metadata": {} 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "path = tree.known_path(40000)\n", 153 | "\n", 154 | "print(\"Summary of the labels of points covered by the node at address\")\n", 155 | "for dist, address in path:\n", 156 | " node = tree.node(address)\n", 157 | " label_summary = node.label_summary()\n", 158 | " print(f\"Address: {address} \\t Summary: {label_summary}\")\n", 159 | "\n" 160 | ] 161 | } 162 | ] 163 | } -------------------------------------------------------------------------------- /examples/ember_chronological_drift.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | This is an example that tracks chronological drift in the ember dataset. We train on the ember dataset on data before 2018-07, 5 | and then run everything through it. There's a massive increase in the total KL-div after the cutoff, so this does detect a 6 | shift in the dataset. 
7 | """ 8 | 9 | import os 10 | import pandas as pd 11 | import ember 12 | import argparse 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | from pygoko import CoverTree 16 | 17 | 18 | def main(): 19 | prog = "ember_drift_calc" 20 | descr = "Train an ember model from a directory with raw feature files" 21 | parser = argparse.ArgumentParser(prog=prog, description=descr) 22 | parser.add_argument("datadir", metavar="DATADIR", type=str, help="Directory with raw features") 23 | args = parser.parse_args() 24 | 25 | training_data, all_data, X_month = sort_ember_dataset(datadir = args.datadir, split_date = "2018-07") 26 | 27 | # Build the tree 28 | tree = CoverTree() 29 | tree.set_leaf_cutoff(50) 30 | tree.set_scale_base(1.5) 31 | tree.set_min_res_index(0) 32 | tree.fit(training_data) 33 | 34 | # Gather a baseline 35 | prior_weight = 1.0 36 | observation_weight = 1.3 37 | # 0 sets the window to be infinite, otherwise the "dataset" you're computing against is only the last `window_size` elements 38 | window_size = 5000 39 | # We don't use this, our sequences are windowed so we only compute the KL Div on (at most) the last window_size elements 40 | sequence_len = 800000 41 | # Actually computes the KL div this often. All other values are linearly interpolated between these sample points. 42 | # It's too slow to calculate each value and this is accurate enough. 43 | sample_rate = 10 44 | # Gets the mean and variance over this number of simulated sequence. 45 | sequence_count = 50 46 | 47 | ''' 48 | We gather a baseline object. When you feed the entire dataset the covertree was created from to itself, 49 | you will get a non-zero KL-Div on any node that is non-trivial. This process will weight the node's posterior Dirichlet distribution, 50 | multiplying the internal weights by (prior_weight + observation_weight). This posterior distribution has a lower variance than the prior and 51 | the expected KL-divergence between the unknown distributions we're modeling is thus non-zero. 52 | 53 | This slowly builds up, but we expect a non-zero KL-div over the nodes as we feed in-distribution data in. This object estimates that, and 54 | allows us to normalize this natural variance away. 55 | ''' 56 | baseline = tree.kl_div_dirichlet_baseline( 57 | prior_weight, 58 | observation_weight, 59 | window_size, 60 | sequence_count, 61 | sample_rate) 62 | goko_divs = {} 63 | 64 | """ 65 | This is the actual object that computes the KL Divergence statistics between the samples we feed in and the new samples. 66 | 67 | Internally, it is an evidence hashmap containing categorical distributions, and a queue of paths. 68 | The sample's path is computed, we then push it onto the queue and update the evidence by incrementing the correct buckets 69 | in the evidence hashmap. If the queue is full, we pop off the oldest path and decrement the correct paths in the queue. 
70 | """ 71 | run_tracker = tree.kl_div_dirichlet( 72 | prior_weight, 73 | observation_weight, 74 | window_size) 75 | 76 | total_kl_div = [] 77 | 78 | for i,datum in enumerate(all_data): 79 | run_tracker.push(datum) 80 | if i % 500 == 0: 81 | goko_divs[i] = normalize(baseline,run_tracker.stats()) 82 | total_kl_div.append(goko_divs[i]['moment1_nz']) 83 | 84 | 85 | fig, ax = plt.subplots() 86 | ax.plot(list(range(0,len(all_data),500)),total_kl_div) 87 | ax.set_ylabel('KL Divergence') 88 | ax.set_xlabel('Sample Timestamp') 89 | tick_len = 0 90 | cutoff_len = 0 91 | tick_locations = [] 92 | dates = [d for d in X_month.keys()] 93 | for date in dates: 94 | if date == "2018-07": 95 | cutoff_len = tick_len 96 | tick_len += len(X_month[date]) 97 | tick_locations.append(tick_len) 98 | ax.set_xticks(tick_locations) 99 | ax.set_xticklabels(dates) 100 | ax.axvline(x=cutoff_len, linewidth=4, color='r') 101 | fig.tight_layout() 102 | fig.savefig("drift.png", bbox_inches='tight') 103 | plt.show() 104 | plt.close() 105 | 106 | def normalize(baseline,stats): 107 | """ 108 | Grabs the mean and variance from the baseline and normalizes the stats object passed in by subtracting 109 | the norm and dividing by the standard deviation. 110 | """ 111 | basesline_stats = baseline.stats(stats["sequence_len"]) 112 | normalized = {} 113 | for k in basesline_stats.keys(): 114 | n = (stats[k]-basesline_stats[k]["mean"]) 115 | if basesline_stats[k]["var"] > 0: 116 | n = n/np.sqrt(basesline_stats[k]["var"]) 117 | normalized[k] = n 118 | return normalized 119 | 120 | def sort_ember_dataset(datadir,split_date): 121 | """ 122 | Opens the dataset and creates a training dataset consisting of everything before the split date. 123 | 124 | Returns the training dataset and all data 125 | """ 126 | X, _ = ember.read_vectorized_features(datadir,"train") 127 | metadata = pd.read_csv(os.path.join(datadir, "train_metadata.csv"), index_col=0) 128 | dates = list(set(metadata['appeared'])) 129 | dates.sort() 130 | 131 | X_month = {k:X[metadata['appeared'] == k] for k in dates} 132 | 133 | training_dates = [d for d in dates if d < split_date] 134 | all_dates = [d for d in dates] 135 | 136 | training_data = np.concatenate([X_month[k] for k in training_dates]) 137 | training_data = np.ascontiguousarray(training_data) 138 | 139 | all_data = np.concatenate([X_month[k] for k in all_dates]) 140 | all_data = np.ascontiguousarray(all_data) 141 | 142 | return training_data, all_data, X_month 143 | 144 | 145 | if __name__ == '__main__': 146 | main() 147 | -------------------------------------------------------------------------------- /serve_goko/src/api/mod.rs: -------------------------------------------------------------------------------- 1 | use std::ops::Deref; 2 | use pointcloud::{PointCloud, SummaryCounter, Summary}; 3 | use crate::errors::InternalServiceError; 4 | use crate::core::CoreReader; 5 | 6 | use serde::{Deserialize, Serialize}; 7 | //use std::convert::Infallible; 8 | 9 | mod parameters; 10 | mod path; 11 | mod knn; 12 | mod tracker; 13 | 14 | pub use parameters::*; 15 | pub use path::*; 16 | pub use tracker::*; 17 | pub use knn::*; 18 | 19 | /// A summary for a small number of categories. 20 | #[derive(Deserialize, Serialize)] 21 | pub enum GokoRequest { 22 | /// With the HTTP server, send a `GET` request to `/` for this. 
23 | /// 24 | /// Response: [`ParametersResponse`] 25 | Parameters(ParametersRequest), 26 | /// With the HTTP server, send a `GET` request to `/knn?k=5` with a set of features in the body for this query, 27 | /// will respond with the nearest 5 neighbors. 28 | /// 29 | /// See the chosen body parser for how to encode the body. 30 | /// 31 | /// Response: [`KnnResponse`] 32 | Knn(KnnRequest), 33 | /// With the HTTP server, send a `GET` request to `/routing_knn?k=5` with a set of features in the body for this query, will respond with the nearest 5 routing neighbors. 34 | /// 35 | /// See the chosen body parser for how to encode the body. 36 | /// 37 | /// Response: [`KnnResponse`] 38 | RoutingKnn(RoutingKnnRequest), 39 | /// With the HTTP server, send a `GET` request to `/path` with a set of features in the body for this query, will respond with the path to the node this point belongs to. 40 | /// 41 | /// See the chosen body parser for how to encode the body. 42 | /// 43 | /// Response: [`PathResponse`] 44 | Path(PathRequest), 45 | /// The queries to manipulate the trackers, all under /track/ 46 | /// 47 | /// See : [`TrackingRequest`] 48 | Tracking(TrackingRequest), 49 | /// The catch-all for errors 50 | Unknown(String, u16), 51 | } 52 | #[derive(Deserialize, Serialize)] 53 | pub struct TrackingRequest { 54 | pub tracker_name: Option<String>, 55 | pub request: TrackingRequestChoice, 56 | } 57 | 58 | 59 | #[derive(Deserialize, Serialize)] 60 | pub enum TrackingRequestChoice { 61 | /// Track a point, send a `POST` request to `/track/point?tracker_name=TRACKER_NAME` with a set of features in the body for this query. 62 | /// Omit the `TRACKER_NAME` query to use the default. 63 | /// 64 | /// See the chosen body parser for how to encode the body. 65 | /// 66 | /// Response: [`TrackPathResponse`] 67 | TrackPoint(TrackPointRequest), 68 | /// Unsupported for HTTP 69 | /// 70 | /// Response: [`TrackPathResponse`] 71 | TrackPath(TrackPathRequest), 72 | /// Add a tracker, send a `POST` request to `/track/add?window_size=WINDOW_SIZE&tracker_name=TRACKER_NAME` with a set of features in the body for this query. 73 | /// Omit the `TRACKER_NAME` query to use the default. 74 | /// 75 | /// Response: [`AddTrackerResponse`] 76 | AddTracker(AddTrackerRequest), 77 | /// Get the status of a tracker, send a `GET` request to `/track/stats?window_size=WINDOW_SIZE&tracker_name=TRACKER_NAME`. 78 | /// Omit the `TRACKER_NAME` query to use the default. 79 | /// 80 | /// Response: [`CurrentStatsResponse`] 81 | CurrentStats(CurrentStatsRequest), 82 | } 83 | 84 | /// The response one gets back from the core server loop.
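/// With the plain serde derives below (no `#[serde(tag = ...)]` attribute),
/// a response serializes externally tagged by variant name, e.g.
/// (illustrative) `{"Parameters": { ... }}` for a `Parameters` response.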
85 | #[derive(Deserialize, Serialize)] 86 | pub enum GokoResponse { 87 | Parameters(ParametersResponse), 88 | Knn(KnnResponse), 89 | RoutingKnn(RoutingKnnResponse), 90 | Path(PathResponse), 91 | Tracking(TrackingResponse), 92 | Unknown(String, u16), 93 | } 94 | 95 | #[derive(Deserialize, Serialize)] 96 | pub enum TrackingResponse { 97 | TrackPath(TrackPathResponse), 98 | AddTracker(AddTrackerResponse), 99 | CurrentStats(CurrentStatsResponse), 100 | Unknown(Option<String>, Option<usize>), 101 | } 102 | 103 | /// Response for KNN type queries, usually in a vec 104 | #[derive(Deserialize, Serialize)] 105 | pub struct NamedDistance { 106 | /// The name of the point we're referring to 107 | pub name: String, 108 | /// Distance to that point 109 | pub distance: f32, 110 | } 111 | 112 | /// Response for queries that include distances to nodes, usually in a vec 113 | #[derive(Deserialize, Serialize)] 114 | pub struct NodeDistance { 115 | /// The name of the center point of the node we're referring to 116 | pub name: String, 117 | /// The level the node is at 118 | pub layer: i32, 119 | /// The distance to the central node 120 | pub distance: f32, 121 | pub label_summary: Option>, 122 | } 123 | 124 | impl CoreReader 125 | where P: Deref + Send + Sync + 'static { 126 | pub async fn process(&mut self, request: GokoRequest
) -> Result<GokoResponse, InternalServiceError> { 127 | match request { 128 | GokoRequest::Parameters(p) => p.process(self).map(|p| GokoResponse::Parameters(p)).map_err(|e| e.into()), 129 | GokoRequest::Knn(p) => p.process(self).map(|p| GokoResponse::Knn(p)).map_err(|e| e.into()), 130 | GokoRequest::RoutingKnn(p) => p.process(self).map(|p| GokoResponse::RoutingKnn(p)).map_err(|e| e.into()), 131 | GokoRequest::Path(p) => p.process(self).map(|p| GokoResponse::Path(p)).map_err(|e| e.into()), 132 | GokoRequest::Unknown(response_string, status) => { 133 | Ok(GokoResponse::Unknown(response_string, status)) 134 | }, 135 | GokoRequest::Tracking(p) => { 136 | if let Some(tracker_name) = &p.tracker_name { 137 | if let TrackingRequestChoice::AddTracker(_) = p.request { 138 | self.trackers.write().await.entry(tracker_name.clone()).or_insert_with(|| TrackerWorker::operator(self.tree.clone())); 139 | } 140 | match self.trackers.read().await.get(tracker_name) { 141 | Some(t) => t.message(p).await.map(|r| GokoResponse::Tracking(r)), 142 | None => Ok(GokoResponse::Tracking(TrackingResponse::Unknown(Some(tracker_name.clone()), None))), 143 | } 144 | } else { 145 | self.main_tracker.message(p).await.map(|r| GokoResponse::Tracking(r)) 146 | } 147 | } 148 | } 149 | } 150 | } -------------------------------------------------------------------------------- /goko/src/utils.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to Elasticsearch B.V. under one or more contributor 3 | * license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright 5 | * ownership. Elasticsearch B.V. licenses this file to you under 6 | * the Apache License, Version 2.0 (the "License"); you may 7 | * not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | //! Utility functions for i/o 21 | 22 | use crate::errors::{GokoError, GokoResult}; 23 | use crate::tree_file_format::*; 24 | use protobuf::{CodedInputStream, CodedOutputStream, Message}; 25 | use std::fs::File; 26 | use std::fs::{read_to_string, remove_file, OpenOptions}; 27 | use std::path::Path; 28 | use std::sync::Arc; 29 | use yaml_rust::YamlLoader; 30 | 31 | use crate::builders::CoverTreeBuilder; 32 | 33 | use crate::CoverTreeWriter; 34 | 35 | use pointcloud::loaders::{labeled_ram_from_yaml, ram_from_yaml}; 36 | use pointcloud::*; 37 | 38 | /// Given a yaml file on disk, it builds a covertree.
39 | /// 40 | /// ```yaml 41 | /// --- 42 | /// leaf_cutoff: 5 43 | /// min_res_index: -10 44 | /// scale_base: 1.3 45 | /// data_path: DATAMEMMAPs 46 | /// labels_path: LABELS_CSV 47 | /// count: NUMBER_OF_DATA_POINTS 48 | /// data_dim: 784 49 | /// labels_index: 3 50 | /// ``` 51 | pub fn cover_tree_from_labeled_yaml<P: AsRef<Path>>( 52 | path: P, 53 | ) -> GokoResult>> { 54 | let config = read_to_string(&path).expect("Unable to read config file"); 55 | 56 | let params_files = YamlLoader::load_from_str(&config).unwrap(); 57 | let params = &params_files[0]; 58 | 59 | let point_cloud = labeled_ram_from_yaml::<_, L2>(&path)?; 60 | if let Some(count) = params["count"].as_i64() { 61 | if count as usize != point_cloud.len() { 62 | panic!( 63 | "We expected {:?} points, but the file has {:?} points at dim {:?}", 64 | count, 65 | point_cloud.len(), 66 | point_cloud.dim() 67 | ); 68 | } 69 | } 70 | 71 | let builder = CoverTreeBuilder::from_yaml(&path); 72 | println!( 73 | "Loaded dataset, building a cover tree with scale base {}, leaf_cutoff {}, min_res_index {}, and use_singletons {}", 74 | &builder.scale_base, &builder.leaf_cutoff, &builder.min_res_index, &builder.use_singletons 75 | ); 76 | Ok(builder.build(Arc::new(point_cloud))?) 77 | } 78 | 79 | /// Given a yaml file on disk, it builds a covertree. 80 | /// 81 | /// ```yaml 82 | /// --- 83 | /// leaf_cutoff: 5 84 | /// min_res_index: -10 85 | /// scale_base: 1.3 86 | /// data_path: DATAMEMMAPs 87 | /// count: NUMBER_OF_DATA_POINTS 88 | /// data_dim: 784 89 | /// ``` 90 | 91 | pub fn cover_tree_from_yaml<P: AsRef<Path>>( 92 | path: P, 93 | ) -> GokoResult>> { 94 | let config = read_to_string(&path).expect("Unable to read config file"); 95 | 96 | let params_files = YamlLoader::load_from_str(&config).unwrap(); 97 | let params = &params_files[0]; 98 | 99 | let point_cloud = ram_from_yaml::<_, L2>(&path)?; 100 | if let Some(count) = params["count"].as_i64() { 101 | if count as usize != point_cloud.len() { 102 | panic!( 103 | "We expected {:?} points, but the file has {:?} points at dim {:?}", 104 | count, 105 | point_cloud.len(), 106 | point_cloud.dim() 107 | ); 108 | } 109 | } 110 | let builder = CoverTreeBuilder::from_yaml(&path); 111 | println!( 112 | "Loaded dataset, building a cover tree with scale base {}, leaf_cutoff {}, min_res_index {}, and use_singletons {}", 113 | &builder.scale_base, &builder.leaf_cutoff, &builder.min_res_index, &builder.use_singletons 114 | ); 115 | Ok(builder.build(Arc::new(point_cloud))?) 116 | } 117 | 118 | /// Helper function that handles the file I/O and protobuf decoding for you. 119 | pub fn load_tree<P: AsRef<Path>, D: PointCloud>( 120 | tree_path: P, 121 | point_cloud: Arc, 122 | ) -> GokoResult> { 123 | let tree_path_ref: &Path = tree_path.as_ref(); 124 | println!("\nLoading tree from : {}", tree_path_ref.to_string_lossy()); 125 | 126 | if !tree_path_ref.exists() { 127 | let tree_path_str = match tree_path_ref.to_str() { 128 | Some(expr) => expr, 129 | None => panic!("Unicode error with the tree path"), 130 | }; 131 | panic!("{} does not exist\n", tree_path_str); 132 | } 133 | 134 | let mut cover_proto = CoreProto::new(); 135 | 136 | let mut file = File::open(&tree_path_ref).map_err(GokoError::from)?; 137 | let mut cis = CodedInputStream::new(&mut file); 138 | if let Err(e) = cover_proto.merge_from(&mut cis) { 139 | panic!("Protobuf was unable to read {:#?}", e) 140 | } 141 | 142 | CoverTreeWriter::load(&cover_proto, point_cloud) 143 | } 144 | 145 | /// Helper function that handles the file I/O and protobuf encoding for you.
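/// Illustrative round trip with [`load_tree`] above (the path, writer, and
/// point cloud here are hypothetical):
///
/// ```ignore
/// save_tree("my_tree.gk", &writer)?;
/// let reloaded = load_tree("my_tree.gk", point_cloud)?;
/// ```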
146 | pub fn save_tree<P: AsRef<Path>, D: PointCloud>( 147 | tree_path: P, 148 | cover_tree: &CoverTreeWriter, 149 | ) -> GokoResult<()> { 150 | let tree_path_ref: &Path = tree_path.as_ref(); 151 | 152 | println!("Saving tree to : {}", tree_path_ref.to_string_lossy()); 153 | if tree_path_ref.exists() { 154 | let tree_path_str = match tree_path_ref.to_str() { 155 | Some(expr) => expr, 156 | None => panic!("Unicode error with the tree path"), 157 | }; 158 | println!("\t \t {:?} exists, removing", tree_path_str); 159 | remove_file(&tree_path).map_err(GokoError::from)?; 160 | } 161 | 162 | let cover_proto = cover_tree.save(); 163 | 164 | let mut core_file = OpenOptions::new() 165 | .read(true) 166 | .write(true) 167 | .create(true) 168 | .open(&tree_path) 169 | .unwrap(); 170 | 171 | let mut cos = CodedOutputStream::new(&mut core_file); 172 | cover_proto.write_to(&mut cos).map_err(GokoError::from)?; 173 | cos.flush().map_err(GokoError::from)?; 174 | Ok(()) 175 | } 176 | -------------------------------------------------------------------------------- /goko/src/covertree/query_tools/trace_query_heap.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to Elasticsearch B.V. under one or more contributor 3 | * license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright 5 | * ownership. Elasticsearch B.V. licenses this file to you under 6 | * the Apache License, Version 2.0 (the "License"); you may 7 | * not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | //! Tools and data structures for assisting cover tree queries. 21 | 22 | use super::*; 23 | use crate::NodeAddress; 24 | use std::collections::{BinaryHeap, HashMap}; 25 | use std::f32; 26 | 27 | use super::query_items::{QueryAddress, QueryAddressRev}; 28 | 29 | /// This is used to find the closest `k` nodes to the query point, either to get summary statistics out, or 30 | /// to restrict a Gaussian Mixture Model 31 | #[derive(Debug)] 32 | pub struct MultiscaleQueryHeap { 33 | layer_max_heaps: HashMap<i32, BinaryHeap<QueryAddressRev>>, 34 | layer_min_heaps: HashMap<i32, BinaryHeap<QueryAddress>>, 35 | k: usize, 36 | scale_base: f32, 37 | } 38 | 39 | impl RoutingQueryHeap for MultiscaleQueryHeap { 40 | /// Shoves data in here.
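/// For every `(scale_index, point_index)` address pushed, the heaps also
/// record a lower bound on the distance from the query to anything the node
/// covers: `(dist_to_center - scale_base^scale_index).max(0.0)`, i.e. the
/// distance to the node's covering ball.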
41 | fn push_nodes( 42 | &mut self, 43 | indexes: &[NodeAddress], 44 | dists: &[f32], 45 | _parent_address: Option, 46 | ) { 47 | for ((si, pi), d) in indexes.iter().zip(dists) { 48 | let emd = (d - self.scale_base.powi(*si)).max(0.0); 49 | 50 | println!("\t Inserting {:?} into max heap", ((si, pi), d)); 51 | let max_heap = self 52 | .layer_max_heaps 53 | .entry(*si) 54 | .or_insert_with(BinaryHeap::new); 55 | max_heap.push(QueryAddressRev { 56 | address: (*si, *pi), 57 | dist_to_center: *d, 58 | min_dist: emd, 59 | }); 60 | while max_heap.len() > self.k + 1 { 61 | max_heap.pop(); 62 | } 63 | 64 | let min_heap = self 65 | .layer_min_heaps 66 | .entry(*si) 67 | .or_insert_with(BinaryHeap::new); 68 | min_heap.push(QueryAddress { 69 | address: (*si, *pi), 70 | dist_to_center: *d, 71 | min_dist: emd, 72 | }); 73 | } 74 | } 75 | } 76 | 77 | impl MultiscaleQueryHeap { 78 | /// Creates a new set of heaps, hashmaps, and parameters designed to do multiscale KNN 79 | pub fn new(k: usize, scale_base: f32) -> MultiscaleQueryHeap { 80 | MultiscaleQueryHeap { 81 | layer_max_heaps: HashMap::new(), 82 | layer_min_heaps: HashMap::new(), 83 | k, 84 | scale_base, 85 | } 86 | } 87 | 88 | /// Gives us the closest unqueried node on a particular layer 89 | pub fn pop_closest_unqueried(&mut self, scale_index: i32) -> Option<(f32, NodeAddress)> { 90 | match self.layer_min_heaps.get_mut(&scale_index) { 91 | Some(heap) => match heap.pop() { 92 | None => None, 93 | Some(qa) => { 94 | let max_dist = self 95 | .furthest_node(scale_index) 96 | .map(|(d, _)| d) 97 | .unwrap_or(0.0); 98 | if max_dist <= qa.min_dist { 99 | Some((qa.min_dist, qa.address)) 100 | } else { 101 | None 102 | } 103 | } 104 | }, 105 | None => None, 106 | } 107 | } 108 | 109 | /// Unpacks this to a digestible format 110 | pub fn unpack(mut self) -> HashMap> { 111 | self.layer_max_heaps 112 | .drain() 113 | .map(|(si, heap)| { 114 | let mut v: Vec<(f32, NodeAddress)> = heap 115 | .into_iter_sorted() 116 | .map(|qa| (qa.dist_to_center, qa.address)) 117 | .collect(); 118 | v.reverse(); 119 | (si, v) 120 | }) 121 | .collect() 122 | } 123 | 124 | /// returns the node on a layer that is the furthest away. 
This returns None if the heap isn't full (less than K elements) 125 | pub fn furthest_node(&self, scale_index: i32) -> Option<(f32, NodeAddress)> { 126 | self.layer_max_heaps 127 | .get(&scale_index) 128 | .map(|max_heap| { 129 | if max_heap.len() < self.k { 130 | None 131 | } else { 132 | max_heap.peek().map(|x| (x.dist_to_center, x.address)) 133 | } 134 | }) 135 | .flatten() 136 | } 137 | 138 | /// The count at a layer 139 | pub fn count(&self, si: i32) -> usize { 140 | match self.layer_max_heaps.get(&si) { 141 | Some(heap) => heap.len(), 142 | None => 0, 143 | } 144 | } 145 | } 146 | 147 | //Tested in the node file too 148 | #[cfg(test)] 149 | pub(crate) mod tests { 150 | use super::*; 151 | #[test] 152 | fn multiscale_insertion_unpacks_correctly() { 153 | let mut trace_heap = MultiscaleQueryHeap::new(5, 2.0); 154 | let dists = [0.1, 0.2, 0.4, 0.5, 0.1, 0.2, 0.4, 0.5, 0.05]; 155 | let addresses = [ 156 | (0, 0), 157 | (0, 1), 158 | (0, 3), 159 | (0, 4), 160 | (1, 0), 161 | (1, 1), 162 | (1, 3), 163 | (1, 4), 164 | (1, 2), 165 | ]; 166 | trace_heap.push_nodes(&addresses, &dists, None); 167 | println!("{:#?}", trace_heap); 168 | let results = trace_heap.unpack(); 169 | let layer_0 = results.get(&0).unwrap(); 170 | assert_eq!(layer_0[0], (0.1, (0, 0))); 171 | assert_eq!(layer_0[1], (0.2, (0, 1))); 172 | 173 | let layer_1 = results.get(&1).unwrap(); 174 | assert_eq!(layer_1[0], (0.05, (1, 2))); 175 | assert_eq!(layer_1[1], (0.1, (1, 0))); 176 | assert_eq!(layer_1[2], (0.2, (1, 1))); 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /pointcloud/src/summaries/mod.rs: -------------------------------------------------------------------------------- 1 | //! Summaries for some label types 2 | 3 | use hashbrown::HashMap; 4 | use std::default::Default; 5 | use std::iter::Iterator; 6 | 7 | use smallvec::SmallVec; 8 | 9 | use crate::base_traits::*; 10 | use serde::{Deserialize, Serialize}; 11 | use std::fmt::Debug; 12 | 13 | /// A summary for a small number of categories. 
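/// Illustrative behavior of `add`/`combine` below: adding the labels
/// `1, 1, 2` leaves `items == [(1, 2), (2, 1)]` and `count() == 3`.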
14 | #[derive(Clone, Debug, Deserialize, Serialize)] 15 | pub struct CategorySummary { 16 | /// Hashmap that counts how many of each instance of string there is 17 | pub items: SmallVec<[(i64, usize); 4]>, 18 | } 19 | 20 | impl Default for CategorySummary { 21 | fn default() -> Self { 22 | CategorySummary { 23 | items: SmallVec::new(), 24 | } 25 | } 26 | } 27 | 28 | impl Summary for CategorySummary { 29 | type Label = i64; 30 | fn add(&mut self, val: &i64) { 31 | let mut added_to_existing = false; 32 | for (stored_val, totals) in self.items.iter_mut() { 33 | if val == stored_val { 34 | *totals += 1; 35 | added_to_existing = true; 36 | break; 37 | } 38 | } 39 | if !added_to_existing { 40 | self.items.push((*val, 1)); 41 | } 42 | } 43 | 44 | fn combine(&mut self, other: &CategorySummary) { 45 | for (val, count) in other.items.iter() { 46 | let mut added_to_existing = false; 47 | for (stored_val, totals) in self.items.iter_mut() { 48 | if val == stored_val { 49 | *totals += count; 50 | added_to_existing = true; 51 | break; 52 | } 53 | } 54 | if !added_to_existing { 55 | self.items.push((*val, *count)); 56 | } 57 | } 58 | } 59 | 60 | fn count(&self) -> usize { 61 | self.items.iter().map(|(_a, b)| b).sum() 62 | } 63 | } 64 | 65 | /// Summary of vectors 66 | #[derive(Clone, Debug, Serialize, Deserialize, Default)] 67 | pub struct VecSummary { 68 | /// First moment, see 69 | pub moment1: Vec, 70 | /// Second moment, see 71 | pub moment2: Vec, 72 | /// The count of the number of labels included 73 | pub count: usize, 74 | } 75 | 76 | impl Summary for VecSummary { 77 | type Label = [f32]; 78 | 79 | fn add(&mut self, val: &[f32]) { 80 | if !self.moment1.is_empty() { 81 | if self.moment1.len() == val.len() { 82 | self.moment1.iter_mut().zip(val).for_each(|(m, x)| *m += x); 83 | self.moment2 84 | .iter_mut() 85 | .zip(val) 86 | .for_each(|(m, x)| *m += x * x); 87 | self.count += 1; 88 | } else { 89 | panic!( 90 | "Combining a vec of len {:?} and of len {:?}", 91 | self.moment1.len(), 92 | val.len() 93 | ); 94 | } 95 | } else { 96 | self.moment1.extend(val); 97 | self.moment2.extend(val.iter().map(|x| x * x)) 98 | } 99 | } 100 | fn combine(&mut self, other: &VecSummary) { 101 | self.moment1 102 | .iter_mut() 103 | .zip(&other.moment1) 104 | .for_each(|(x, y)| *x += y); 105 | self.moment2 106 | .iter_mut() 107 | .zip(&other.moment2) 108 | .for_each(|(x, y)| *x += y); 109 | self.count += other.count; 110 | } 111 | 112 | fn count(&self) -> usize { 113 | self.count 114 | } 115 | } 116 | 117 | /// Summary of a bunch of underlying floats 118 | #[derive(Clone, Debug, Serialize, Deserialize, Default)] 119 | pub struct FloatSummary { 120 | /// First moment, see 121 | pub moment1: f64, 122 | /// Second moment, see 123 | pub moment2: f64, 124 | /// The count of the number of labels included 125 | pub count: usize, 126 | } 127 | 128 | impl Summary for FloatSummary { 129 | type Label = f64; 130 | 131 | fn add(&mut self, val: &f64) { 132 | self.moment1 += val; 133 | self.moment2 += val * val; 134 | self.count += 1; 135 | } 136 | fn combine(&mut self, other: &FloatSummary) { 137 | self.moment1 += other.moment1; 138 | self.moment2 += other.moment2; 139 | self.count += other.count; 140 | } 141 | 142 | fn count(&self) -> usize { 143 | self.count 144 | } 145 | } 146 | 147 | /// Summary of a bunch of underlying integers, more accurate for int than the float summary 148 | #[derive(Clone, Debug, Serialize, Deserialize, Default)] 149 | pub struct IntSummary { 150 | /// First moment, see 151 | pub moment1: i64, 152 | /// 
/// Summary of a bunch of underlying integers; more accurate for integer
/// labels than the float summary.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct IntSummary {
    /// First moment: the running sum of the labels.
    pub moment1: i64,
    /// Second moment: the running sum of the squared labels.
    pub moment2: i64,
    /// The count of the number of labels included
    pub count: usize,
}

impl Summary for IntSummary {
    type Label = i64;

    fn add(&mut self, val: &i64) {
        self.moment1 += val;
        self.moment2 += val * val;
        self.count += 1;
    }
    fn combine(&mut self, other: &IntSummary) {
        self.moment1 += other.moment1;
        self.moment2 += other.moment2;
        self.count += other.count;
    }

    fn count(&self) -> usize {
        self.count
    }
}

/// A summary for a small number of string categories.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StringSummary {
    /// Hashmap that counts how many instances of each string there are.
    pub items: HashMap<String, usize>,
}

impl Default for StringSummary {
    fn default() -> Self {
        StringSummary {
            items: HashMap::new(),
        }
    }
}

impl Summary for StringSummary {
    type Label = String;
    fn add(&mut self, val: &String) {
        *self.items.entry(val.to_string()).or_insert(0) += 1;
    }

    fn combine(&mut self, other: &StringSummary) {
        for (val, count) in other.items.iter() {
            *self.items.entry(val.to_string()).or_insert(0) += count;
        }
    }

    fn count(&self) -> usize {
        self.items.values().sum()
    }
}
--------------------------------------------------------------------------------
/pointcloud/src/metrics/l1_misc.rs:
--------------------------------------------------------------------------------
//! Various implementations of the L1 metric for types that can be easily converted to f32.

use super::L1;
use crate::base_traits::Metric;
use crate::points::*;
use packed_simd::*;
use std::ops::Deref;
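// One macro stamps out the whole L1 kernel family (dense distance, norm,
// and sparse distance) for each small integer type, widening lanes to f32
// SIMD vectors along the way. The `$simd_16_base`/`$simd_8_base` arguments
// name the packed_simd types used for the 16- and 8-lane passes.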
make_l1_distance { 10 | ($base:ident, $simd_16_base:ident, $simd_8_base:ident, $sparse_base:ident, $dist_base:ident, $norm_base:ident) => { 11 | /// 12 | #[inline] 13 | pub fn $dist_base(mut x: &[$base], mut y: &[$base]) -> f32 { 14 | let mut d_acc_16 = f32x16::splat(0.0); 15 | while y.len() > 16 { 16 | let x_simd = $simd_16_base::from_slice_unaligned(x); 17 | let y_simd = $simd_16_base::from_slice_unaligned(y); 18 | let x_simd_f32 = f32x16::from_cast(x_simd); 19 | let y_simd_f32 = f32x16::from_cast(y_simd); 20 | let diff = x_simd_f32 - y_simd_f32; 21 | d_acc_16 += diff.abs(); 22 | y = &y[16..]; 23 | x = &x[16..]; 24 | } 25 | let mut d_acc_8 = f32x8::splat(0.0); 26 | if y.len() > 8 { 27 | let x_simd = $simd_8_base::from_slice_unaligned(x); 28 | let y_simd = $simd_8_base::from_slice_unaligned(y); 29 | let x_simd_f32 = f32x8::from_cast(x_simd); 30 | let y_simd_f32 = f32x8::from_cast(y_simd); 31 | let diff = x_simd_f32 - y_simd_f32; 32 | d_acc_8 += diff.abs(); 33 | y = &y[8..]; 34 | x = &x[8..]; 35 | } 36 | let leftover = y 37 | .iter() 38 | .zip(x) 39 | .map(|(xi, yi)| (*xi as f32 - *yi as f32).abs()) 40 | .fold(0.0, |acc, y| acc + y); 41 | leftover + d_acc_8.sum() + d_acc_16.sum() 42 | } 43 | 44 | /// 45 | #[inline] 46 | pub fn $norm_base(mut x: &[$base]) -> f32 { 47 | let mut d_acc_16 = f32x16::splat(0.0); 48 | while x.len() > 16 { 49 | let x_simd = $simd_16_base::from_slice_unaligned(x); 50 | let x_simd_f32 = f32x16::from_cast(x_simd); 51 | d_acc_16 += x_simd_f32.abs(); 52 | x = &x[16..]; 53 | } 54 | let mut d_acc_8 = f32x8::splat(0.0); 55 | if x.len() > 8 { 56 | let x_simd = $simd_8_base::from_slice_unaligned(x); 57 | let x_simd_f32 = f32x8::from_cast(x_simd); 58 | d_acc_8 += x_simd_f32.abs(); 59 | x = &x[8..]; 60 | } 61 | let leftover = x 62 | .iter() 63 | .map(|xi| (*xi as f32).abs()) 64 | .fold(0.0, |acc, y| acc + y); 65 | leftover + d_acc_8.sum() + d_acc_16.sum() 66 | } 67 | 68 | /// basic sparse function 69 | pub fn $sparse_base(x_ind: &[S], x_val: &[$base], y_ind: &[S], y_val: &[$base]) -> f32 70 | where 71 | S: Ord, 72 | { 73 | if x_val.is_empty() || y_val.is_empty() { 74 | if x_val.is_empty() && y_val.is_empty() { 75 | return 0.0; 76 | } 77 | if !x_val.is_empty() && y_val.is_empty() { 78 | $norm_base(x_val) 79 | } else { 80 | $norm_base(y_val) 81 | } 82 | } else { 83 | let mut total = 0.0; 84 | let (short_iter, mut long_iter) = if x_ind.len() > y_ind.len() { 85 | (y_ind.iter().zip(y_val), x_ind.iter().zip(x_val)) 86 | } else { 87 | (x_ind.iter().zip(x_val), y_ind.iter().zip(y_val)) 88 | }; 89 | 90 | let mut l_tr: Option<(&S, &$base)> = long_iter.next(); 91 | for (si, sv) in short_iter { 92 | while let Some((li, lv)) = l_tr { 93 | if li < si { 94 | total += (*lv as f32).abs(); 95 | l_tr = long_iter.next(); 96 | } else { 97 | break; 98 | } 99 | } 100 | if let Some((li, lv)) = l_tr { 101 | if li == si { 102 | let val = (*sv as f32) - (*lv as f32); 103 | total += val.abs(); 104 | l_tr = long_iter.next(); 105 | } else { 106 | total += (*sv as f32).abs(); 107 | } 108 | } else { 109 | total += (*sv as f32).abs(); 110 | } 111 | } 112 | while let Some((_li, lv)) = l_tr { 113 | total += (*lv as f32).abs(); 114 | l_tr = long_iter.next(); 115 | } 116 | total 117 | } 118 | } 119 | impl Metric<[$base]> for L1 { 120 | fn dist(x: &[$base], y: &[$base]) -> f32 { 121 | $dist_base(x.deref(), y.deref()).sqrt() 122 | } 123 | } 124 | 125 | impl<'a> Metric> for L1 { 126 | fn dist(x: &RawSparse<$base, u32>, y: &RawSparse<$base, u32>) -> f32 { 127 | $sparse_base(x.indexes(), x.values(), y.indexes(), 
        impl Metric<[$base]> for L1 {
            fn dist(x: &[$base], y: &[$base]) -> f32 {
                // The kernel already returns the L1 distance; unlike the
                // squared-L2 kernels, no final square root is needed.
                $dist_base(x.deref(), y.deref())
            }
        }

        impl Metric<RawSparse<$base, u32>> for L1 {
            fn dist(x: &RawSparse<$base, u32>, y: &RawSparse<$base, u32>) -> f32 {
                $sparse_base(x.indexes(), x.values(), y.indexes(), y.values())
            }
        }

        impl Metric<RawSparse<$base, u16>> for L1 {
            fn dist(x: &RawSparse<$base, u16>, y: &RawSparse<$base, u16>) -> f32 {
                $sparse_base(x.indexes(), x.values(), y.indexes(), y.values())
            }
        }

        impl Metric<RawSparse<$base, u8>> for L1 {
            fn dist(x: &RawSparse<$base, u8>, y: &RawSparse<$base, u8>) -> f32 {
                $sparse_base(x.indexes(), x.values(), y.indexes(), y.values())
            }
        }
    };
}

make_l1_distance!(i8, i8x16, i8x8, l1_sparse_i8_f32, l1_dense_i8, l1_norm_i8);
make_l1_distance!(u8, u8x16, u8x8, l1_sparse_u8_f32, l1_dense_u8, l1_norm_u8);
make_l1_distance!(
    i16,
    i16x16,
    i16x8,
    l1_sparse_i16_f32,
    l1_dense_i16,
    l1_norm_i16
);
make_l1_distance!(
    u16,
    u16x16,
    u16x8,
    l1_sparse_u16_f32,
    l1_dense_u16,
    l1_norm_u16
);
make_l1_distance!(
    i32,
    i32x16,
    i32x8,
    l1_sparse_i32_f32,
    l1_dense_i32,
    l1_norm_i32
);
make_l1_distance!(
    u32,
    u32x16,
    u32x8,
    l1_sparse_u32_f32,
    l1_dense_u32,
    l1_norm_u32
);
--------------------------------------------------------------------------------
/pygoko/src/node.rs:
--------------------------------------------------------------------------------
use pyo3::prelude::*;

use ndarray::{Array1, Array2};
use numpy::{IntoPyArray, PyArray1, PyArray2};
use pyo3::PyIterProtocol;

use goko::plugins::discrete::prelude::*;
use goko::plugins::gaussians::*;
use goko::*;
use pointcloud::*;
use std::sync::Arc;

use pyo3::types::PyDict;

#[pyclass(unsendable)]
pub struct IterLayerNode {
    pub parameters: Arc<CoverTreeParameters<DefaultLabeledCloud<L2>>>,
    pub addresses: Vec<NodeAddress>,
    pub tree: CoverTreeReader<DefaultLabeledCloud<L2>>,
    pub index: usize,
}

impl std::iter::Iterator for IterLayerNode {
    type Item = PyNode;
    fn next(&mut self) -> Option<PyNode> {
        if self.index < self.addresses.len() {
            let index = self.index;
            self.index += 1;
            Some(PyNode {
                parameters: Arc::clone(&self.parameters),
                address: self.addresses[index],
                tree: self.tree.clone(),
            })
        } else {
            None
        }
    }
}

#[pyproto]
impl PyIterProtocol for IterLayerNode {
    fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<IterLayerNode>> {
        Ok(slf.into())
    }
    fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<PyNode>> {
        Ok(slf.next())
    }
}

#[pyclass(unsendable)]
pub struct PyNode {
    pub parameters: Arc<CoverTreeParameters<DefaultLabeledCloud<L2>>>,
    pub address: NodeAddress,
    pub tree: CoverTreeReader<DefaultLabeledCloud<L2>>,
}

#[pymethods]
impl PyNode {
    pub fn address(&self) -> (i32, usize) {
        self.address
    }

    pub fn is_leaf(&self) -> bool {
        self.tree
            .get_node_and(self.address, |n| n.is_leaf())
            .unwrap()
    }

    pub fn coverage_count(&self) -> usize {
        self.tree
            .get_node_and(self.address, |n| n.coverage_count())
            .unwrap()
    }

    pub fn children(&self) -> Vec<PyNode> {
        self.children_addresses()
            .iter()
            .map(|address| PyNode {
                parameters: Arc::clone(&self.parameters),
                address: *address,
                tree: self.tree.clone(),
            })
            .collect()
    }

    pub fn children_probs(&self) -> Option<(Vec<((i32, usize), f64)>, f64)> {
        self.tree
            .get_node_plugin_and(self.address, |p: &Dirichlet| p.prob_vector())
            .flatten()
    }
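    // A `NodeAddress` is the pair `(scale_index, center point index)`: the
    // layer the node lives on plus the index of its center in the point
    // cloud. For example, `(-4, 1023)` names the node on layer -4 centered
    // on point 1023.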
    pub fn children_addresses(&self) -> Vec<(i32, usize)> {
        self.tree
            .get_node_and(self.address, |n| {
                n.children().map(|(nested_scale, children)| {
                    let mut py_nodes: Vec<(i32, usize)> = Vec::from(children);
                    py_nodes.push((nested_scale, *n.center_index()));
                    py_nodes
                })
            })
            .flatten()
            .unwrap_or_default()
    }

    pub fn fractal_dim(&self) -> f32 {
        self.tree.node_fractal_dim(self.address)
    }

    pub fn weighted_fractal_dim(&self) -> f32 {
        self.tree.node_weighted_fractal_dim(self.address)
    }

    pub fn singletons(&self) -> PyResult<Py<PyArray2<f32>>> {
        let dim = self.parameters.point_cloud.dim();
        let len = self.coverage_count();
        let mut ret_matrix = Vec::with_capacity(len * dim);
        self.tree.get_node_and(self.address, |n| {
            n.singletons().iter().for_each(|pi| {
                if let Ok(p) = self.parameters.point_cloud.point(*pi) {
                    ret_matrix.extend(p.dense_iter());
                }
            });

            if n.is_leaf() {
                if let Ok(p) = self.parameters.point_cloud.point(*n.center_index()) {
                    ret_matrix.extend(p.dense_iter());
                }
            }
        });

        // Rows actually gathered (the singletons, plus the center for a
        // leaf); deriving this from the buffer keeps the reshape in sync.
        let rows = ret_matrix.len() / dim;
        let ret_matrix = Array2::from_shape_vec((rows, dim), ret_matrix).unwrap();
        let gil = pyo3::Python::acquire_gil();
        let py = gil.python();
        Ok(ret_matrix.into_pyarray(py).to_owned())
    }

    pub fn singletons_indexes(&self) -> Vec<usize> {
        self.tree
            .get_node_and(self.address, |n| Vec::from(n.singletons()))
            .unwrap_or_default()
    }

    pub fn cover_mean(&self) -> PyResult<Option<Py<PyArray1<f32>>>> {
        let dim = self.parameters.point_cloud.dim();
        let gil = pyo3::Python::acquire_gil();
        let py = gil.python();
        let mean = self
            .tree
            .get_node_plugin_and(self.address, |p: &DiagGaussian| p.mean())
            .map(|m| {
                Array1::from_shape_vec((dim,), m)
                    .unwrap()
                    .into_pyarray(py)
                    .to_owned()
            });

        Ok(mean)
    }

    pub fn cover_diag_var(&self) -> PyResult<Py<PyArray1<f32>>> {
        let dim = self.parameters.point_cloud.dim();
        let var = self
            .tree
            .get_node_plugin_and(self.address, |p: &DiagGaussian| p.var())
            .unwrap();
        let py_var = Array1::from_shape_vec((dim,), var).unwrap();
        let gil = pyo3::Python::acquire_gil();
        let py = gil.python();
        Ok(py_var.into_pyarray(py).to_owned())
    }

    /*
    pub fn get_singular_values(&self) -> PyResult<Option<Py<PyArray1<f32>>>> {
        let gil = pyo3::Python::acquire_gil();
        let py = gil.python();
        Ok(self
            .tree
            .get_node_plugin_and(self.address, |p: &SvdGaussian| {
                p.singular_vals.clone().into_pyarray(py).to_owned()
            }))
    }
    */

    pub fn label_summary(&self) -> PyResult<Option<PyObject>> {
        let gil = pyo3::Python::acquire_gil();
        let py = gil.python();
        let dict = PyDict::new(py);
        match self.tree.get_node_label_summary(self.address) {
            Some(s) => {
                dict.set_item("errors", s.errors)?;
                dict.set_item("nones", s.nones)?;
                dict.set_item("items", s.summary.items.to_vec())?;
                Ok(Some(dict.into()))
            }
            None => Ok(None),
        }
    }
}
--------------------------------------------------------------------------------
/pygoko/src/layer.rs:
--------------------------------------------------------------------------------
use pyo3::prelude::*;

use ndarray::{Array, Array2};
use numpy::{IntoPyArray, PyArray1, PyArray2};
use pyo3::PyIterProtocol;

use goko::layer::*;
use goko::*;
use pointcloud::*;
use std::sync::Arc;

use crate::node::*;
#[pyclass(unsendable)]
pub struct IterLayers {
    pub parameters: Arc<CoverTreeParameters<DefaultLabeledCloud<L2>>>,
    pub tree: CoverTreeReader<DefaultLabeledCloud<L2>>,
    pub scale_indexes: Vec<i32>,
    pub index: usize,
}

impl std::iter::Iterator for IterLayers {
    type Item = PyLayer;
    fn next(&mut self) -> Option<PyLayer> {
        if self.index < self.scale_indexes.len() {
            self.index += 1;
            Some(PyLayer {
                parameters: Arc::clone(&self.parameters),
                tree: self.tree.clone(),
                scale_index: self.scale_indexes[self.index - 1],
            })
        } else {
            None
        }
    }
}

#[pyproto]
impl PyIterProtocol for IterLayers {
    fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<IterLayers>> {
        Ok(slf.into())
    }
    fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<PyLayer>> {
        Ok(slf.next())
    }
}

#[pyclass(unsendable)]
pub struct PyLayer {
    pub parameters: Arc<CoverTreeParameters<DefaultLabeledCloud<L2>>>,
    pub tree: CoverTreeReader<DefaultLabeledCloud<L2>>,
    pub scale_index: i32,
}

impl PyLayer {
    fn layer(&self) -> &CoverLayerReader<DefaultLabeledCloud<L2>> {
        self.tree.layer(self.scale_index)
    }
}

#[pymethods]
impl PyLayer {
    pub fn radius(&self) -> f32 {
        self.parameters.scale_base.powi(self.layer().scale_index())
    }
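    // The covering radius is scale_base^scale_index. For example, with a
    // scale_base of 2.0, layer 3 covers radius 2.0f32.powi(3) = 8.0 and
    // layer -3 covers radius 2.0f32.powi(-3) = 0.125.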
    pub fn scale_index(&self) -> i32 {
        self.scale_index
    }
    pub fn len(&self) -> usize {
        self.layer().len()
    }
    pub fn center_indexes(&self) -> Vec<usize> {
        self.layer().map_nodes(|pi, _n| *pi)
    }
    pub fn child_addresses(&self, point_index: usize) -> Option<Vec<NodeAddress>> {
        self.layer()
            .get_node_children_and(point_index, |nested_address, child_addresses| {
                let mut v = vec![nested_address];
                v.extend(child_addresses);
                v
            })
    }
    pub fn singleton_indexes(&self, point_index: usize) -> Option<Vec<usize>> {
        self.layer()
            .get_node_and(point_index, |n| Vec::from(n.singletons()))
    }

    pub fn is_leaf(&self, point_index: usize) -> Option<bool> {
        self.layer().get_node_and(point_index, |n| n.is_leaf())
    }

    pub fn fractal_dim(&self) -> f32 {
        self.tree.layer_fractal_dim(self.scale_index)
    }

    pub fn weighted_fractal_dim(&self) -> f32 {
        self.tree.layer_weighted_fractal_dim(self.scale_index)
    }

    pub fn centers(&self) -> PyResult<(Py<PyArray1<usize>>, Py<PyArray2<f32>>)> {
        let mut centers =
            Vec::with_capacity(self.layer().len() * self.parameters.point_cloud.dim());
        let mut centers_indexes = Vec::with_capacity(self.layer().len());
        self.layer().for_each_node(|pi, _n| {
            centers_indexes.push(*pi);
            centers.extend(self.parameters.point_cloud.point(*pi).unwrap().dense_iter());
        });
        let py_center_indexes = Array::from(centers_indexes);
        let py_centers = Array2::from_shape_vec(
            (self.layer().len(), self.parameters.point_cloud.dim()),
            centers,
        )
        .unwrap();
        let gil = pyo3::Python::acquire_gil();
        let py = gil.python();
        Ok((
            py_center_indexes.into_pyarray(py).to_owned(),
            py_centers.into_pyarray(py).to_owned(),
        ))
    }

    pub fn child_points(&self, point_index: usize) -> PyResult<Option<Py<PyArray2<f32>>>> {
        let dim = self.parameters.point_cloud.dim();
        Ok(self
            .layer()
            .get_node_children_and(point_index, |nested_address, child_addresses| {
                // The nested child shares the parent's center, so it is
                // emitted first, followed by the other children.
                let count = child_addresses.len() + 1;
                let mut centers: Vec<f32> = Vec::with_capacity(count * dim);
                centers.extend(
                    self.parameters
                        .point_cloud
                        .point(nested_address.1)
                        .unwrap()
                        .dense_iter(),
                );
                for na in child_addresses {
                    centers.extend(
                        self.parameters
                            .point_cloud
                            .point(na.1)
                            .unwrap()
                            .dense_iter(),
                    );
                }
                let py_centers = Array2::from_shape_vec((count, dim), centers).unwrap();
                let gil = pyo3::Python::acquire_gil();
                let py = gil.python();
                py_centers.into_pyarray(py).to_owned()
            }))
    }
    pub fn singleton_points(&self, point_index: usize) -> PyResult<Option<Py<PyArray2<f32>>>> {
        let dim = self.parameters.point_cloud.dim();
        Ok(self.layer().get_node_and(point_index, |node| {
            let singletons = node.singletons();
            let mut centers: Vec<f32> = Vec::with_capacity(singletons.len() * dim);
            for pi in singletons {
                centers.extend(self.parameters.point_cloud.point(*pi).unwrap().dense_iter());
            }
            let py_centers = Array2::from_shape_vec((singletons.len(), dim), centers).unwrap();
            let gil = pyo3::Python::acquire_gil();
            let py = gil.python();
            py_centers.into_pyarray(py).to_owned()
        }))
    }

    pub fn node(&self, center_index: usize) -> PyResult<PyNode> {
        Ok(PyNode {
            parameters: Arc::clone(&self.parameters),
            address: (self.scale_index, center_index),
            tree: self.tree.clone(),
        })
    }

    pub fn nodes(&self) -> PyResult<IterLayerNode> {
        Ok(IterLayerNode {
            parameters: Arc::clone(&self.parameters),
            addresses: self
                .layer()
                .node_center_indexes()
                .iter()
                .map(|pi| (self.scale_index, *pi))
                .collect(),
            tree: self.tree.clone(),
            index: 0,
        })
    }
}
--------------------------------------------------------------------------------
/pointcloud/src/metrics/l2_misc.rs:
--------------------------------------------------------------------------------
//! Various implementations of the L2 metric for types that can be easily converted to f32.

use super::L2;
use crate::base_traits::Metric;
use crate::points::*;
use packed_simd::*;
use std::ops::Deref;

// Same shape as `make_l1_distance`: one macro generates the dense, norm,
// and sparse kernels per integer type; the `sq_` prefix on the generated
// names signals that these functions return *squared* L2 quantities.
make_l2_distance { 10 | ($base:ident, $simd_16_base:ident, $simd_8_base:ident, $sparse_base:ident, $dist_base:ident, $norm_base:ident) => { 11 | /// 12 | #[inline] 13 | pub fn $dist_base(mut x: &[$base], mut y: &[$base]) -> f32 { 14 | let mut d_acc_16 = f32x16::splat(0.0); 15 | while y.len() > 16 { 16 | let x_simd = $simd_16_base::from_slice_unaligned(x); 17 | let y_simd = $simd_16_base::from_slice_unaligned(y); 18 | let x_simd_f32 = f32x16::from_cast(x_simd); 19 | let y_simd_f32 = f32x16::from_cast(y_simd); 20 | let diff = x_simd_f32 - y_simd_f32; 21 | d_acc_16 += diff * diff; 22 | y = &y[16..]; 23 | x = &x[16..]; 24 | } 25 | let mut d_acc_8 = f32x8::splat(0.0); 26 | if y.len() > 8 { 27 | let x_simd = $simd_8_base::from_slice_unaligned(x); 28 | let y_simd = $simd_8_base::from_slice_unaligned(y); 29 | let x_simd_f32 = f32x8::from_cast(x_simd); 30 | let y_simd_f32 = f32x8::from_cast(y_simd); 31 | let diff = x_simd_f32 - y_simd_f32; 32 | d_acc_8 += diff * diff; 33 | y = &y[8..]; 34 | x = &x[8..]; 35 | } 36 | let leftover = y 37 | .iter() 38 | .zip(x) 39 | .map(|(xi, yi)| (*xi as f32 - *yi as f32) * (*xi as f32 - *yi as f32)) 40 | .fold(0.0, |acc, y| acc + y); 41 | leftover + d_acc_8.sum() + d_acc_16.sum() 42 | } 43 | 44 | /// 45 | #[inline] 46 | pub fn $norm_base(mut x: &[$base]) -> f32 { 47 | let mut d_acc_16 = f32x16::splat(0.0); 48 | while x.len() > 16 { 49 | let x_simd = $simd_16_base::from_slice_unaligned(x); 50 | let x_simd_f32 = f32x16::from_cast(x_simd); 51 | d_acc_16 += x_simd_f32 * x_simd_f32; 52 | x = &x[16..]; 53 | } 54 | let mut d_acc_8 = f32x8::splat(0.0); 55 | if x.len() > 8 { 56 | let x_simd = $simd_8_base::from_slice_unaligned(x); 57 | let x_simd_f32 = f32x8::from_cast(x_simd); 58 | d_acc_8 += x_simd_f32 * x_simd_f32; 59 | x = &x[8..]; 60 | } 61 | let leftover = x 62 | .iter() 63 | .map(|xi| (*xi as f32) * (*xi as f32)) 64 | .fold(0.0, |acc, y| acc + y); 65 | leftover + d_acc_8.sum() + d_acc_16.sum() 66 | } 67 | 68 | /// basic sparse function 69 | pub fn $sparse_base(x_ind: &[S], x_val: &[$base], y_ind: &[S], y_val: &[$base]) -> f32 70 | where 71 | S: Ord, 72 | { 73 | if x_val.is_empty() || y_val.is_empty() { 74 | if x_val.is_empty() && y_val.is_empty() { 75 | return 0.0; 76 | } 77 | if !x_val.is_empty() && y_val.is_empty() { 78 | $norm_base(x_val) 79 | } else { 80 | $norm_base(y_val) 81 | } 82 | } else { 83 | let mut total = 0.0; 84 | let (short_iter, mut long_iter) = if x_ind.len() > y_ind.len() { 85 | (y_ind.iter().zip(y_val), x_ind.iter().zip(x_val)) 86 | } else { 87 | (x_ind.iter().zip(x_val), y_ind.iter().zip(y_val)) 88 | }; 89 | 90 | let mut l_tr: Option<(&S, &$base)> = long_iter.next(); 91 | for (si, sv) in short_iter { 92 | while let Some((li, lv)) = l_tr { 93 | if li < si { 94 | total += (*lv as f32) * (*lv as f32); 95 | l_tr = long_iter.next(); 96 | } else { 97 | break; 98 | } 99 | } 100 | if let Some((li, lv)) = l_tr { 101 | if li == si { 102 | let val = (*sv as f32) - (*lv as f32); 103 | total += val * val; 104 | l_tr = long_iter.next(); 105 | } else { 106 | total += (*sv as f32) * (*sv as f32); 107 | } 108 | } else { 109 | total += (*sv as f32) * (*sv as f32); 110 | } 111 | } 112 | while let Some((_li, lv)) = l_tr { 113 | total += (*lv as f32) * (*lv as f32); 114 | l_tr = long_iter.next(); 115 | } 116 | total 117 | } 118 | } 119 | impl Metric<[$base]> for L2 { 120 | fn dist(x: &[$base], y: &[$base]) -> f32 { 121 | $dist_base(x.deref(), y.deref()).sqrt() 122 | } 123 | } 124 | 125 | impl<'a> Metric> for L2 { 126 | fn dist(x: &RawSparse<$base, u32>, y: 
        impl Metric<RawSparse<$base, u32>> for L2 {
            fn dist(x: &RawSparse<$base, u32>, y: &RawSparse<$base, u32>) -> f32 {
                $sparse_base(x.indexes(), x.values(), y.indexes(), y.values()).sqrt()
            }
        }

        impl Metric<RawSparse<$base, u16>> for L2 {
            fn dist(x: &RawSparse<$base, u16>, y: &RawSparse<$base, u16>) -> f32 {
                $sparse_base(x.indexes(), x.values(), y.indexes(), y.values()).sqrt()
            }
        }

        impl Metric<RawSparse<$base, u8>> for L2 {
            fn dist(x: &RawSparse<$base, u8>, y: &RawSparse<$base, u8>) -> f32 {
                $sparse_base(x.indexes(), x.values(), y.indexes(), y.values()).sqrt()
            }
        }
    };
}

make_l2_distance!(
    i8,
    i8x16,
    i8x8,
    sq_l2_sparse_i8_f32,
    sq_l2_dense_i8,
    sq_l2_norm_i8
);
make_l2_distance!(
    u8,
    u8x16,
    u8x8,
    sq_l2_sparse_u8_f32,
    sq_l2_dense_u8,
    sq_l2_norm_u8
);
make_l2_distance!(
    i16,
    i16x16,
    i16x8,
    sq_l2_sparse_i16_f32,
    sq_l2_dense_i16,
    sq_l2_norm_i16
);
make_l2_distance!(
    u16,
    u16x16,
    u16x8,
    sq_l2_sparse_u16_f32,
    sq_l2_dense_u16,
    sq_l2_norm_u16
);
make_l2_distance!(
    i32,
    i32x16,
    i32x8,
    sq_l2_sparse_i32_f32,
    sq_l2_dense_i32,
    sq_l2_norm_i32
);
make_l2_distance!(
    u32,
    u32x16,
    u32x8,
    sq_l2_sparse_u32_f32,
    sq_l2_dense_u32,
    sq_l2_norm_u32
);
--------------------------------------------------------------------------------
/goko/src/query_interface/mod.rs:
--------------------------------------------------------------------------------
//! Interfaces that simplify bulk queries

//use crossbeam_channel::unbounded;
use crate::*;
use ndarray::ArrayView2;
use rayon::iter::repeatn;
use rayon::prelude::*;
use std::ops::Deref;

/// Interface for bulk queries. Handles cloning the readers for you.
pub struct BulkInterface<D: PointCloud> {
    reader: CoverTreeReader<D>,
}

impl<D: PointCloud> BulkInterface<D> {
    /// Creates a new one.
    pub fn new(reader: CoverTreeReader<D>) -> Self {
        BulkInterface { reader }
    }

    /// Applies the passed-in fn to the passed-in indexes and collects the
    /// results in a vector. Core function for this struct.
    pub fn index_map_with_reader<F, T>(&self, point_indexes: &[usize], f: F) -> Vec<T>
    where
        F: Fn(&CoverTreeReader<D>, usize) -> T + Send + Sync,
        T: Send + Sync,
    {
        let indexes_iter = point_indexes.par_chunks(100);
        let reader_copies = indexes_iter.len();
        let mut chunked_results: Vec<Vec<T>> = indexes_iter
            .zip(repeatn(self.reader.clone(), reader_copies))
            .map(|(chunk_indexes, reader)| chunk_indexes.iter().map(|p| f(&reader, *p)).collect())
            .collect();
        chunked_results
            .drain(..)
            .reduce(|mut a, mut x| {
                a.extend(x.drain(..));
                a
            })
            .unwrap()
    }
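    // The chunked form above is morally equivalent to the serial
    //     point_indexes.iter().map(|p| f(&self.reader, *p)).collect()
    // but it hands each 100-index chunk its own clone of the reader, so
    // rayon workers never contend over a single reader.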
    /// Applies the passed-in fn to the passed-in points and collects the
    /// results in a vector. Core function for this struct.
    pub fn point_map_with_reader<P: Deref<Target = D::Point> + Send + Sync, F, T>(
        &self,
        points: &[P],
        f: F,
    ) -> Vec<T>
    where
        F: Fn(&CoverTreeReader<D>, &P) -> T + Send + Sync,
        T: Send + Sync,
    {
        let point_iter = points.par_chunks(100);
        let reader_copies = point_iter.len();
        let mut chunked_results: Vec<Vec<T>> = point_iter
            .zip(repeatn(self.reader.clone(), reader_copies))
            .map(|(chunk_points, reader)| chunk_points.iter().map(|p| f(&reader, p)).collect())
            .collect();
        chunked_results
            .drain(..)
            .reduce(|mut a, mut x| {
                a.extend(x.drain(..));
                a
            })
            .unwrap()
    }

    /// Bulk known path
    pub fn known_path(&self, point_indexes: &[usize]) -> Vec<GokoResult<Vec<(f32, NodeAddress)>>> {
        self.index_map_with_reader(point_indexes, |reader, i| reader.known_path(i))
    }

    /// Bulk known path, with a post-processing closure applied to each result
    pub fn known_path_and<F, T>(&self, point_indexes: &[usize], f: F) -> Vec<T>
    where
        F: Fn(&CoverTreeReader<D>, GokoResult<Vec<(f32, NodeAddress)>>) -> T + Send + Sync,
        T: Send + Sync,
    {
        self.index_map_with_reader(point_indexes, |reader, i| f(&reader, reader.known_path(i)))
    }

    /// Bulk path
    pub fn path<P: Deref<Target = D::Point> + Send + Sync>(
        &self,
        points: &[P],
    ) -> Vec<GokoResult<Vec<(f32, NodeAddress)>>> {
        self.point_map_with_reader(points, |reader, p| reader.path(p))
    }

    /// Bulk knn
    pub fn knn<P: Deref<Target = D::Point> + Send + Sync>(
        &self,
        points: &[P],
        k: usize,
    ) -> Vec<GokoResult<Vec<(f32, usize)>>> {
        self.point_map_with_reader(points, |reader, p| reader.knn(p, k))
    }

    /// Bulk routing knn
    pub fn routing_knn<P: Deref<Target = D::Point> + Send + Sync>(
        &self,
        points: &[P],
        k: usize,
    ) -> Vec<GokoResult<Vec<(f32, usize)>>> {
        self.point_map_with_reader(points, |reader, p| reader.routing_knn(p, k))
    }
}

impl<D: PointCloud<Point = [f32]>> BulkInterface<D> {
    /// Applies the passed-in fn to the rows of the passed-in array and
    /// collects the results in a vector. Core function for this struct.
    pub fn array_map_with_reader<'a, F, T>(&self, points: ArrayView2<'a, f32>, f: F) -> Vec<T>
    where
        F: Fn(&CoverTreeReader<D>, &&[f32]) -> T + Send + Sync,
        T: Send + Sync,
    {
        let indexes: Vec<usize> = (0..points.nrows()).collect();
        let point_iter = indexes.par_chunks(100);
        let reader_copies = point_iter.len();

        let mut chunked_results: Vec<Vec<T>> = point_iter
            .zip(repeatn(self.reader.clone(), reader_copies))
            .map(|(chunk_points, reader)| {
                chunk_points
                    .iter()
                    .map(|i| f(&reader, &points.row(*i).as_slice().unwrap()))
                    .collect()
            })
            .collect();
        chunked_results
            .drain(..)
            .reduce(|mut a, mut x| {
                a.extend(x.drain(..));
                a
            })
            .unwrap()
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use std::env;

    use crate::covertree::tests::build_mnist_tree;

    #[test]
    fn bulk_path() {
        if env::var("TRAVIS_RUST_VERSION").is_err() {
            let tree = build_mnist_tree();
            let reader = tree.reader();
            let interface = BulkInterface::new(tree.reader());
            let cloud = reader.point_cloud();

            let points: Vec<&[f32]> = (0..100).map(|i| cloud.point(i).unwrap()).collect();

            let path_results = interface.path(&points);
            for (i, path) in path_results.iter().enumerate() {
                let old_path = reader.path(&cloud.point(i).unwrap()).unwrap();
                for ((d1, a1), (d2, a2)) in (path.as_ref().unwrap()).iter().zip(old_path) {
                    assert_approx_eq!(*d1, d2);
                    assert_eq!(*a1, a2);
                }
            }
        }
    }

    #[test]
    fn bulk_knn() {
        if env::var("TRAVIS_RUST_VERSION").is_err() {
            let tree = build_mnist_tree();
            let reader = tree.reader();
            let interface = BulkInterface::new(tree.reader());
            let cloud = reader.point_cloud();

            let points: Vec<&[f32]> = (0..10).map(|i| cloud.point(i).unwrap()).collect();

            let knn_results = interface.knn(&points, 5);
            for (i, knn) in knn_results.iter().enumerate() {
                let old_knn = reader.knn(&cloud.point(i).unwrap(), 5).unwrap();
                for ((d1, a1), (d2, a2)) in (knn.as_ref().unwrap()).iter().zip(old_knn) {
                    assert_approx_eq!(*d1, d2);
                    assert_eq!(*a1, a2);
                }
            }
        }
    }
}
--------------------------------------------------------------------------------