├── .gitattributes ├── .gitignore ├── Klibisz ElastiK Nearest Neighbors.pdf ├── Klibisz ElastiK Nearest Neighbors.pptx ├── README.md ├── demo ├── README.md ├── architecture.png ├── pipeline │ ├── .gitignore │ ├── README.md │ ├── batch_es_aknn_create.py │ ├── batch_es_aknn_index.py │ ├── ec2_es_setup.sh │ ├── ingest_twitter_images.py │ ├── requirements.txt │ ├── stream_compute_image_features.py │ ├── stream_produce_image_pointers.py │ └── twitter-credentials.template.json ├── screencast.gif └── webapp │ ├── .gitignore │ ├── README.md │ ├── app.py │ └── templates │ └── index.html ├── elasticsearch-aknn ├── .gitignore ├── LICENSE.txt ├── NOTICE.txt ├── README.md ├── benchmark │ ├── .gitignore │ ├── README.md │ ├── aknn.py │ ├── figures.ipynb │ ├── glove_download.sh │ ├── glove_preprocess.py │ └── metrics │ │ ├── .gitignore │ │ ├── fig_corpus_vs_time.png │ │ └── fig_recall_vs_time.png ├── build.gradle ├── settings.gradle ├── src │ ├── main │ │ ├── java │ │ │ └── org │ │ │ │ └── elasticsearch │ │ │ │ └── plugin │ │ │ │ └── aknn │ │ │ │ ├── AknnPlugin.java │ │ │ │ ├── AknnRestAction.java │ │ │ │ └── LshModel.java │ │ └── plugin-metadata │ │ │ └── plugin-security.policy │ └── test │ │ └── java │ │ └── org │ │ └── elasticsearch │ │ └── plugin │ │ └── aknn │ │ ├── AknnSimpleIT.java │ │ └── AknnSimpleTests.java └── testplugin.sh └── scratch ├── README.md ├── elasticsearch-plugin ├── .gitignore ├── commands-ann_processor.txt ├── commands-ann_search.txt ├── elasticsearch-aknn │ └── .idea │ │ └── workspace.xml ├── glove-hashing-in-python │ ├── glove_test.py │ └── lsh_model.py ├── glove_create_ann.py └── glove_index_ann.py ├── elasticsearch-tweets ├── .gitignore ├── es_index_tweets.py ├── get_tweet_texts.py └── readme.md ├── es-lsh-glove ├── .gitignore ├── dummy_lsh.py ├── get_glove.py ├── glove_exact.py ├── glove_lsh_es_index.py ├── glove_lsh_es_query.py └── readme.md ├── es-lsh-images ├── .gitignore ├── get_imagenet_vectors_labels.py ├── get_twitter_vectors.py ├── imagenet_es_lsh.ipynb ├── imagenet_knn_exact.ipynb ├── readme.md └── twitter_knn_exact.ipynb ├── image-search-streaming-pipeline ├── .gitignore ├── pom.xml ├── src │ └── main │ │ ├── java │ │ └── ImageSearchStreamingPipeline │ │ │ └── FeatureExtractor.java │ │ └── resources │ │ └── log4j.properties └── target │ ├── classes │ ├── ImageSearchStreamingPipeline │ │ ├── FeatureExtractor$1.class │ │ └── FeatureExtractor.class │ └── log4j.properties │ ├── image-search-streaming-pipeline-0.1.jar │ ├── maven-archiver │ └── pom.properties │ └── maven-status │ └── maven-compiler-plugin │ └── compile │ └── default-compile │ ├── createdFiles.lst │ └── inputFiles.lst ├── kafka-streaming ├── .gitignore ├── imagenet-pizza.JPEG ├── imagenet-ref.py ├── pyconsumer.py ├── pyproducer.py ├── s3-images.txt └── streams.examples │ ├── .gitignore │ ├── pom.xml │ ├── readme.md │ └── src │ └── main │ ├── java │ └── myapps │ │ ├── ImageInfoConsumer.java │ │ ├── ImagePrediction.java │ │ ├── LineSplit.java │ │ ├── ND4JPlayground.java │ │ ├── Pipe.java │ │ └── Wordcount.java │ └── resources │ └── log4j.properties ├── lsh-experiments ├── .gitignore ├── lsh-complexity.ipynb ├── lsh-explore.ipynb ├── lsh-linear-algebra-1.ipynb ├── lsh-linear-algebra-2.ipynb └── metrics.ipynb ├── mvp-big ├── batch_feature_vectors.py ├── kafka_convnet_consumer.py ├── kafka_image_producer.py ├── kafka_reset.sh ├── kafka_watch.sh └── requirements.txt ├── mvp ├── index.html ├── kafka_glove_elasticsearch_insert.py ├── kafka_glove_feature_vectors.py ├── kafka_glove_lsh_vectors.py ├── 
kafka_image_elasticsearch_insert.py ├── kafka_image_feature_vectors.py ├── kafka_image_lsh_vectors.py ├── kafka_image_s3_keys.py ├── readme.md └── s3_keys_test.txt └── twitter-images ├── .gitignore └── ingest.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Ignore scratch files for repo language statistics. 2 | **/*.ipynb linguist-detectable=false 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | -------------------------------------------------------------------------------- /Klibisz ElastiK Nearest Neighbors.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexklibisz/elastik-nearest-neighbors/fe69992c133bbb56a81ea6067217820b6f30b6e6/Klibisz ElastiK Nearest Neighbors.pdf -------------------------------------------------------------------------------- /Klibisz ElastiK Nearest Neighbors.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexklibisz/elastik-nearest-neighbors/fe69992c133bbb56a81ea6067217820b6f30b6e6/Klibisz ElastiK Nearest Neighbors.pptx -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ElastiK Nearest Neighbors 2 | 3 | [Insight Data Engineering](https://www.insightdataengineering.com/) Project, Boston, April - June 2018 4 | 5 | ## Updates 6 | 7 | **April 2020** 8 | 9 | DEPRECATED -- please see the improved implemenation: https://github.com/alexklibisz/elastiknn 10 | 11 | **January 2020** 12 | 13 | I've made the improved implementation public: https://github.com/alexklibisz/elastiknn 14 | 15 | **November 2019** 16 | 17 | I am actively working on an improved implementation of this plugin. I hope to open-source and release it late 2019 or early 2020 18 | 19 | *** 20 | 21 | | | 22 | |---| 23 | | *Image similarity search demo running searches on a cluster of 4 Elasticsearch nodes* | 24 | 25 | ## Overview 26 | 27 | I built Elasticsearch-Aknn (EsAknn), an Elasticsearch plugin which implements 28 | approximate K-nearest-neighbors search for dense, floating-point vectors in 29 | Elasticsearch. This allows data engineers to avoid rebuilding an infrastructure 30 | for large-scale KNN and instead leverage Elasticsearch's proven distributed 31 | infrastructure. 32 | 33 | To demonstrate the plugin, I used it to implement an image similarity search 34 | engine for a corpus of 6.7 million Twitter images. I transformed each image 35 | into a 1000-dimensional floating-point feature vector using a convolutional 36 | neural network. I used EsAknn to store the vectors and search for nearest 37 | neighbors on an Elasticsearch cluster. 
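To make that step concrete, the sketch below shows roughly how each image becomes a 1000-dimensional vector. It is a simplified illustration of the production worker in `demo/pipeline/stream_compute_image_features.py`: the vector is the output of MobileNet's `conv_preds` layer, and the preprocessing shown here is abbreviated.

```python
import numpy as np
from keras.applications import MobileNet
from keras.applications.imagenet_utils import preprocess_input
from keras.models import Model

# Reuse the ImageNet-pretrained MobileNet, but take the output of the
# 'conv_preds' layer (just before the softmax) as the feature vector.
base = MobileNet(weights="imagenet")
extractor = Model(base.input, base.get_layer("conv_preds").output)

def featurize(batch):
    """batch: float array of shape (n, 224, 224, 3) holding RGB images."""
    x = preprocess_input(batch.astype(np.float32), mode="tf")
    feats = extractor.predict(x)            # shape (n, 1, 1, 1000)
    return np.squeeze(feats, axis=(1, 2))   # shape (n, 1000)
```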
38 | 39 | The repository is structured: 40 | 41 | - `demo` directory: Twitter Image Similarity search pipeline and web-application 42 | - `elasticsearch-aknn` directory: EsAknn implementation and benchmarks 43 | - `scratch` directory: several smaller projects implemented while prototyping 44 | 45 | ## Demo 46 | 47 | - ~[Web-application](http://elastiknn.klibisz.com/twitter_images/twitter_image/demo)~ (Taken down at the end of the Insight program) 48 | - [Screencast demo on Youtube](https://www.youtube.com/watch?v=HqvbbwmY-0c) 49 | - [Presentation on Google Slides](https://docs.google.com/presentation/d/1AyIyBqzCqKhytZWcQfSEhtBRN-iHUldBQn14MGGKpr8/edit?usp=sharing) 50 | 51 | ## Elasticsearch-Aknn 52 | 53 | ### Usecase 54 | 55 | EsAknn is useful for problems roughly characterized: 56 | 57 | 1. Have a large corpus of feature vectors with dimensionality ~50-1000. 58 | 2. Need to run similarity searches using K-Nearest-Neighbors. 59 | 3. Need to scale horizontally to support many concurrent similarity searches. 60 | 4. Need to support a growing corpus with near-real-time insertions. I.e., 61 | when a new vector is created/ingested, it should be available for searching in 62 | less than 10 minutes. 63 | 64 | ### How does it compare to other approximate-nearest-neighbors libraries? 65 | 66 | Tldr: If you need to quickly run KNN on an extremely large corpus in an offline 67 | job, use one of the libraries from [Ann-Benchmarks](https://github.com/erikbern/ann-benchmarks). 68 | If you need KNN in an online setting with support for horizontally-scalable searching and 69 | indexing new vectors in near-real-time, consider EsAknn (especially if you already 70 | use Elasticsearch). 71 | 72 | There are about a dozen high-quality open-source approximate-nearest-neighbors libraries. 73 | The [Ann-Benchmarks](https://github.com/erikbern/ann-benchmarks) project is a great place to 74 | compare them. Most of them take a large corpus of vectors, build an index, and 75 | expose an interface to run very fast nearest-neighbors search on that fixed corpus. 76 | 77 | Unfortunately they offer very little infrastructure for deploying your 78 | nearest-neighbors search in an online setting. Specifically, you still have to consider: 79 | 80 | 1. Where do you store millions of vectors and the index? 81 | 2. How do you handle many concurrent searches? 82 | 3. How do you handle a growing corpus? See [this issue](https://github.com/erikbern/ann-benchmarks/issues/36) 83 | on the lack of support for adding to an index. 84 | 4. How do you distribute the index and make searches fault tolerant? 85 | 5. Who manages all the infrastructure you've built for such a simple algorithm? 86 | 87 | Elasticsearch already solves the non-trivial infrastrcture problems, and 88 | EsAknn implements approximate nearest-neighbors indexing and search atop 89 | this proven infrastructure. 90 | 91 | EsAknn's LSH implementation is very simplistic in the grand scheme of 92 | approximate-nearest-neighbors approaches, but it maps well to Elasticsearch and still 93 | yields high recall. EsAknn's speed for serial queries is much slower than other 94 | approximate nearest-neighbor libraries, but it's also not designed for serial 95 | querying. Instead it's designed to serve many concurrent searches over a convenient 96 | HTTP endpoint, index new vectors in near-real-time, and scale horizontally 97 | with Elasticsearch. For specific performance numbers, see the performance section 98 | below and the slides linked in the demo section. 
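To make the LSH trade-off above concrete, here is a minimal NumPy sketch of the random-hyperplane hashing idea EsAknn is built on. It assumes each hyperplane is derived from a pair of sampled vectors (consistent with the `_aknn_midpoints` and `_aknn_normals` fields the plugin's model mapping stores), but the exact construction in `LshModel.java` may differ; treat this as an illustration, not the plugin's code.

```python
import numpy as np

def fit_lsh(sample, nb_tables=64, nb_bits=18, seed=0):
    """Build one hyperplane per (table, bit) from a sample of vectors."""
    rng = np.random.default_rng(seed)
    sample = np.asarray(sample, dtype=np.float32)
    # Pair up sampled vectors; each pair defines a midpoint and a normal.
    idx = rng.permutation(len(sample))[:2 * nb_tables * nb_bits]
    a, b = sample[idx[0::2]], sample[idx[1::2]]
    midpoints = ((a + b) / 2).reshape(nb_tables, nb_bits, -1)
    normals = (b - a).reshape(nb_tables, nb_bits, -1)
    return normals, midpoints

def hash_vector(vec, normals, midpoints):
    """Return one bit-string hash per table, e.g. '01101...'."""
    # Each bit records which side of a hyperplane the vector falls on.
    bits = np.einsum("tbd,tbd->tb", normals, vec - midpoints) > 0
    return ["".join("1" if b else "0" for b in row) for row in bits]
```

At search time, documents that share a hash with the query vector in any table become the `k1` candidates, which are then re-ranked by exact distance to produce the final `k2` results.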
99 | 100 | ### API 101 | 102 | #### Create LSH Model 103 | 104 | Given a sample of vectors, create a locality-sensitive-hashing (LSH) model 105 | and store it as an Elasticsearch document. 106 | 107 | ``` 108 | POST :9200/_aknn_create 109 | 110 | { 111 | "_index": "aknn_models", 112 | "_type": "aknn_model", 113 | "_id": "twitter_image_search", 114 | "_source": { 115 | "_aknn_description": "LSH model for Twitter image similarity search", 116 | "_aknn_nb_tables": 64, 117 | "_aknn_nb_bits_per_table": 18, 118 | "_aknn_nb_dimensions": 1000 119 | }, 120 | "_aknn_vector_sample": [ 121 | # Provide a sample of 2 * _aknn_nb_tables * _aknn_nb_bits_per_table vectors 122 | [0.11, 0.22, ...], 123 | [0.22, 0.33, ...], 124 | ... 125 | [0.88, 0.99, ...] 126 | ] 127 | } 128 | ``` 129 | 130 | This returns: 131 | 132 | ``` 133 | { "took": <int> } 134 | ``` 135 | 136 | #### Index New Vectors 137 | 138 | Given a batch of new vectors, hash each vector using a pre-defined LSH model 139 | and store its raw and hashed values in an Elasticsearch document. 140 | 141 | ``` 142 | POST :9200/_aknn_index 143 | 144 | { 145 | "_index": "twitter_images", 146 | "_type": "twitter_image", 147 | "_aknn_uri": "aknn_models/aknn_model/twitter_image_search", 148 | "_aknn_docs": [ 149 | { 150 | "_id": 1, 151 | "_source": { 152 | "_aknn_vector": [0.12, 0.23, ...], 153 | 154 | # Any other fields you want... 155 | } 156 | }, ... 157 | ] 158 | } 159 | ``` 160 | 161 | This returns: 162 | 163 | ``` 164 | { "took": <int>, "size": <int> } 165 | ``` 166 | 167 | #### Similarity Search 168 | 169 | Given a vector in the index, search for and return its nearest neighbors. 170 | 171 | ``` 172 | GET :9200/twitter_images/twitter_image/1/_aknn_search?k1=1000&k2=10 173 | ``` 174 | 175 | This returns: 176 | 177 | ``` 178 | { 179 | "took": <int>, 180 | "timed_out": false, 181 | 182 | "hits": { 183 | "max_score": 0, 184 | "total": <int>, 185 | "hits": [ 186 | { 187 | "_id": "...", 188 | "_index": "twitter_images", 189 | "_score": <float>, 190 | "_source": { 191 | # All of the document fields except for the potentially 192 | # large fields containing the vector and hashes. 193 | } 194 | }, ... 195 | ] 196 | 197 | } 198 | 199 | } 200 | ``` 201 | 202 | ### Implementation 203 | 204 | The key things to know about the implementation are: 205 | 206 | 1. EsAknn runs entirely in an existing Elasticsearch cluster/node. It operates 207 | effectively as a set of HTTP endpoint handlers and talks to Elasticsearch via 208 | the [Java Client API.](https://www.elastic.co/guide/en/elasticsearch/client/java-api/current/client.html) 209 | 2. Searches can run in parallel. New vectors can be indexed on multiple nodes 210 | in parallel using a round-robin strategy. Parallel indexing on a single node has 211 | not been tested extensively. 212 | 3. EsAknn uses [Locality Sensitive Hashing](https://en.wikipedia.org/wiki/Locality-sensitive_hashing) 213 | to convert a floating-point vector into a discrete representation which can be 214 | efficiently indexed and retrieved in Elasticsearch. 215 | 4. EsAknn stores the LSH models and the vectors as standard documents. 216 | 5. EsAknn uses a [Bool Query](https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-bool-query.html) 217 | to find `k1` approximate nearest neighbors based on discrete hashes. It then 218 | computes the exact distance to each of these approximate neighbors and returns 219 | the `k2` closest. For example, you might set `k1 = 1000` and `k2 = 10`. 220 | 6.
EsAknn currently only implements euclidean distance, but any distance 221 | function compatible with LSH can be added. 222 | 223 | ### Performance 224 | 225 | #### Speed 226 | 227 | EsAknn's speed is generally characterized: 228 | 229 | 1. Create a new LSH model: < 1 minute. 230 | 2. Index new vectors: hundreds to low thousands per second. 231 | 3. Search for a vector's neighbors: < 500 milliseconds. Search time scales 232 | sub-linearly with the size of the corpus. 233 | 234 | The corpus vs. search time generally follows a sub-linear pattern like this: 235 | 236 | 238 | 239 | Beyond that, speed is a function of: 240 | 241 | 1. The vectors' dimensionality. 242 | 2. The number of tables (a.k.a. hash functions or trees) in the LSH model. 243 | 3. The number of bits in the LSH model's hashes. 244 | 4. The number of approximate neighbors retrieved, `k1`. 245 | 5. The number of exact neighbors returned, `k2`. 246 | 247 | In the image similarity search engine, you can see that searches against an 248 | index of 6.7 million 1000-dimensional vectors rarely exceed 200 milliseconds. 249 | 250 | #### Recall 251 | 252 | Recall is defined as the proportion of true nearest neighbors returned for 253 | a search and can be evaluated at various values of `k2`. For example, if 254 | you know your application needs to retrieve the top ten most similar items, 255 | you should evaluate recall at `k2 = 10`. 256 | 257 | Similar to speed, recall depends on the LSH configuration. Increasing `k1` 258 | is typically the easiest way to increase recall, but the number of tables and 259 | bits also play an important role. Finding a configuration to maximize 260 | recall and minimize search time can be considered a form of hyper-parameter 261 | optimization. 262 | 263 | The figure below demonstrates that it is possible to find a configuration 264 | with high-recall and low search-time at various corpus sizes. The points plotted 265 | represent the "frontier" of recall/search-time. That is, I ran benchmarks on 266 | many configurations and chose the configurations with the lowest median search 267 | time for each median recall across three corpus sizes. 268 | 269 | 271 | 272 | The table below shows the best configuration for each combination of corpus size, 273 | median recall, median search time with a median recall >= 0.5. 274 | 275 | | | Corpus size | Med. recall | Med. 
search time | k1 | _aknn_nb_tables | _aknn_nb_bits_per_table | 276 | |----|---------------|---------------|--------------------|------|-------------------|---------------------------| 277 | | 0 | 1000000 | 1 | 191 | 500 | 200 | 12 | 278 | | 1 | 1000000 | 0.9 | 100 | 500 | 100 | 14 | 279 | | 2 | 1000000 | 0.8 | 62 | 1000 | 50 | 16 | 280 | | 3 | 1000000 | 0.7 | 49 | 500 | 50 | 16 | 281 | | 4 | 1000000 | 0.6 | 43 | 250 | 50 | 16 | 282 | | 5 | 1000000 | 0.5 | 50 | 250 | 50 | 19 | 283 | | 6 | 100000 | 1 | 26 | 250 | 100 | 12 | 284 | | 7 | 100000 | 0.9 | 21 | 500 | 50 | 14 | 285 | | 8 | 100000 | 0.8 | 14 | 250 | 50 | 18 | 286 | | 9 | 100000 | 0.7 | 11 | 100 | 50 | 14 | 287 | | 10 | 100000 | 0.6 | 11 | 100 | 50 | 19 | 288 | | 11 | 100000 | 0.5 | 14 | 500 | 10 | 8 | 289 | | 12 | 10000 | 1 | 8 | 100 | 100 | 8 | 290 | | 13 | 10000 | 0.9 | 5 | 100 | 50 | 12 | 291 | | 14 | 10000 | 0.8 | 5 | 100 | 50 | 18 | 292 | | 15 | 10000 | 0.7 | 6 | 250 | 10 | 8 | 293 | | 16 | 10000 | 0.6 | 6 | 15 | 100 | 18 | 294 | | 17 | 10000 | 0.5 | 3 | 15 | 50 | 14 | 295 | 296 | ## Image Processing Pipeline 297 | 298 | ### Implementation 299 | 300 | 301 | 302 | The image processing pipeline consists of the following components, shown in 303 | pink and green above: 304 | 305 | 1. Python program ingests images from the Twitter public stream and stores in S3. 306 | 2. Python program publishes batches of references to images stored in S3 to a 307 | Kafka topic. 308 | 3. Python program consumes batches of image references, computes feature 309 | vectors from the images, stores them on S3, publishes references to Kafka. 310 | I use the `conv_pred` layer from 311 | [Keras pre-trained MobileNet](https://keras.io/applications/#mobilenet) 312 | to compute the 1000-dimensional feature vectors. 313 | 4. Python program consumes image features from Kafka/S3 and indexes them in 314 | Elasticsearch via EsAknn. 315 | 316 | ### Performance 317 | 318 | Image feature extraction is the main bottleneck in this pipeline. It's 319 | embarrassingly parallel but still requires thoughtful optimization. In the end 320 | I was able to compute: 321 | 322 | 1. 40 images / node / second on EC2 P2.xlarge (K80 GPU, $0.3/hr spot instance). 323 | 2. 33 images / node / second on EC2 C5.9xlarge (36-core CPU, $0.6/hr spot instance). 324 | 325 | My first-pass plateaued at about 2 images / node / second. I was able to improve 326 | throughput with the following optimizations: 327 | 328 | 1. Produce image references to Kafka instead of full images. This allows 329 | many workers to download the images in parallel from S3. If you send the full 330 | images through Kafka, it quickly becomes a bottleneck. 331 | 2. Workers use thread pools to download images in parallel from S3. 332 | 3. Workers use process pools to crop and resize images for use with Keras. 333 | 4. Workers use the [Lycon library](https://github.com/ethereon/lycon) for fast image resizing. 334 | 3. Workers use Keras/Tensorflow to compute feature vectors on large batches of 335 | images instead of single images. This is a standard deep learning optimization. 336 | 337 | ### Elasticsearch Versions 338 | 339 | - The plugin was developed on Elasticsearch version 6.2.4. 340 | - User [mingruimingrui has a fork fork for version 5.6.6](https://github.com/mingruimingrui/elastik-nearest-neighbors). 
341 | - User [mattiasarro has a fork for version 6.3](https://github.com/mattiasarro/elastik-nearest-neighbors) 342 | 343 | ## Helpful Resources 344 | 345 | Here are a handful of resources I found particularly helpful for this project: 346 | 347 | 1. [Locality Sensitive Hashing lectures by Victor Lavrenko](https://www.youtube.com/watch?v=Arni-zkqMBA) 348 | 2. [Elasticsearch Plugin Starter by Alexander Reelsen](https://github.com/spinscale/cookiecutter-elasticsearch-ingest-processor) 349 | 3. [Elasticsearch OpenNLP Plugin by Alexander Reelsen](https://github.com/spinscale/elasticsearch-ingest-opennlp) 350 | 4. Discussions about similarity search in Elasticsearch: [one](https://discuss.elastic.co/t/what-is-the-best-scheme-for-similarity-search-based-on-binary-codes-with-elasticsearch-context-is-image-search-with-deep-nets/42915/6), [two](https://stackoverflow.com/questions/32785803/similar-image-search-by-phash-distance-in-elasticsearch). 351 | -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | Code for the pipeline used to setup and run the Twitter Image Similarity search demo. 2 | 3 | Please see the top-level readme and the readmes in `pipeline` and `webapp` for 4 | more details. -------------------------------------------------------------------------------- /demo/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexklibisz/elastik-nearest-neighbors/fe69992c133bbb56a81ea6067217820b6f30b6e6/demo/architecture.png -------------------------------------------------------------------------------- /demo/pipeline/.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | twitter-credentials.json -------------------------------------------------------------------------------- /demo/pipeline/README.md: -------------------------------------------------------------------------------- 1 | 2 | Workers and scripts for the Twitter Image Similarity search demo. 3 | 4 | ## Overview 5 | 6 | Below is a terse overview of the functionality for each program in the pipeline. 7 | See individual programs for more detail. 8 | 9 | 1. `ingest_twitter_images.py` ingests tweets from Twitter's streaming API and 10 | saves posted images locally and on S3. This ingests between 500K and 700K images 11 | per day. `twitter-credentials.template.json` should be updated with your Twitter 12 | API credentials to run this program. 13 | 2. `stream_produce_image_pointers.py` produces pointers to images to a Kafka topic. 14 | A pointer is simply the S3 bucket and key where the image file is stored. 15 | 3. `stream_compute_image_features.py` consumes images pointers and computes 16 | a floating-point feature vector for each image. It stores the feature vectors 17 | on S3 and publishes a pointer to the features to a Kafka topic. This program was 18 | designed such that many instances can be run in parallel to speed up computation. 19 | As long as they are all in the same Kafka consumer group, each one will get 20 | independent chunks of the processing load. 21 | 4. `batch_es_aknn_create.py` creates an LSH model in Elasticsearch via the Elasticsearch-Aknn plugin. 22 | 5. `batch_es_aknn_index.py` indexes feature vectors in Elasticsearch via the 23 | Elasticsearch-Aknn plugin. 
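For reference, the "pointer" passed from step 2 to step 3 is just a small JSON object naming the S3 location of an image (see `S3Pointer` in the scripts below). The values here are illustrative:

```python
# One image pointer, as produced by stream_produce_image_pointers.py.
# The id is the tweet/status ID; the key is the image file uploaded by
# ingest_twitter_images.py (values shown are examples only).
pointer = {
    "id": "988221425063530502",
    "s3_bucket": "klibisz-twitter-stream",
    "s3_key": "988221425063530502.jpg",
}
# Pointers are sent to Kafka in JSON-encoded batches (1000 by default),
# keyed "batch-<first id>-<last id>".
```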
24 | 25 | ## Usage 26 | 27 | Install dependencies: `pip3 install -r requirements.txt` 28 | 29 | All Python programs implement an argparse CLI, so you can run `python .py --help` to see the exact parameters. 30 | 31 | Most of the programs require a Kafka cluster and an Elasticsearch cluster. 32 | Instructions to set them up is beyond the scope of this brief documentation, 33 | however the `ec2_es_setup.sh` script should be helpful for Elasticsearch. -------------------------------------------------------------------------------- /demo/pipeline/batch_es_aknn_create.py: -------------------------------------------------------------------------------- 1 | """ 2 | Create an Elasticsearch-Aknn model from feature documents stored on disk or S3. 3 | """ 4 | 5 | from argparse import ArgumentParser 6 | from io import BytesIO 7 | from numpy import mean, std 8 | from pprint import pformat 9 | from sys import stderr 10 | from time import time 11 | import boto3 12 | import json 13 | import gzip 14 | import os 15 | import random 16 | import requests 17 | 18 | 19 | def iter_docs(source_str): 20 | 21 | if source_str.startswith("s3://"): 22 | bucket = boto3.resource("s3").Bucket(source_str.replace("s3://", '')) 23 | for obj in bucket.objects.all(): 24 | body = obj.get().get('Body') 25 | buff = BytesIO(body.read()) 26 | with gzip.open(buff) as fp: 27 | yield json.loads(fp.read().decode()) 28 | else: 29 | for fobj in os.scandir(source_str): 30 | with gzip.open(fobj.path) as fp: 31 | yield json.loads(fp.read().decode()) 32 | 33 | 34 | if __name__ == "__main__": 35 | 36 | ap = ArgumentParser(description="See script") 37 | ap.add_argument("features_source", 38 | help="Directory or S3 bucket containing image feature docs.") 39 | ap.add_argument("--es_host", default="http://localhost:9200", 40 | help="URL of single elasticsearch server.") 41 | ap.add_argument("--aknn_tables", type=int, default=64) 42 | ap.add_argument("--aknn_bits", type=int, default=18) 43 | ap.add_argument("--aknn_dimensions", type=int, default=1000) 44 | ap.add_argument("-p", type=float, default=0.2, 45 | help="Prob. of accepting a feature document as a sample.") 46 | args = vars(ap.parse_args()) 47 | 48 | # Prepare the Aknn model mapping. 49 | mapping = { 50 | "properties": { 51 | "_aknn_midpoints": { 52 | "type": "half_float", 53 | "index": False 54 | }, 55 | "_aknn_normals": { 56 | "type": "half_float", 57 | "index": False 58 | }, 59 | "_aknn_nb_bits_per_table": { 60 | "type": "short", 61 | "index": False 62 | }, 63 | "_aknn_nb_dimensions": { 64 | "type": "short", 65 | "index": False 66 | }, 67 | "_aknn_nb_tables": { 68 | "type": "short", 69 | "index": False 70 | } 71 | } 72 | } 73 | 74 | # Body for posting new vectors. 75 | body = { 76 | "_index": "aknn_models", 77 | "_type": "aknn_model", 78 | "_id": "twitter_images", 79 | "_source": { 80 | "_aknn_description": "AKNN model for images on the twitter public stream", 81 | "_aknn_nb_dimensions": args["aknn_dimensions"], 82 | "_aknn_nb_tables": args["aknn_tables"], 83 | "_aknn_nb_bits_per_table": args["aknn_bits"] 84 | }, 85 | "_aknn_vector_sample": [ 86 | # Populated below. 87 | ] 88 | } 89 | 90 | # Delete and remake the index. 
91 | print("Deleting index %s" % body["_index"]) 92 | index_url = "%s/%s" % (args["es_host"], body["_index"]) 93 | req = requests.delete(index_url) 94 | assert req.status_code == 200, "Failed to delete index: %s" % json.dumps(req.json()) 95 | 96 | print("Creating index %s" % body["_index"]) 97 | req = requests.put(index_url) 98 | assert req.status_code == 200, "Failed to create index: %s" % json.dumps(req.json()) 99 | 100 | # Put the mapping. This can fail if you already have this index/type setup. 101 | print("Creating mapping for index %s" % body["_index"]) 102 | mapping_url = "%s/%s/%s/_mapping" % (args["es_host"], body["_index"], body["_type"]) 103 | req = requests.put(mapping_url, json=mapping) 104 | assert req.status_code == 200, "Failed to create mapping: %s" % json.dumps(req.json()) 105 | 106 | # Create an iterable over the feature documents. 107 | docs = iter_docs(args["features_source"]) 108 | 109 | # Populate the vector sample by randomly sampling vectors from iterable. 110 | nb_samples = 2 * args["aknn_bits"] * args["aknn_tables"] 111 | print("Sampling %d feature vectors from %s" % (nb_samples, args["features_source"])) 112 | while len(body["_aknn_vector_sample"]) < nb_samples: 113 | vec = next(docs)["feature_vector"] 114 | if random.random() <= args["p"]: 115 | body["_aknn_vector_sample"].append(vec) 116 | 117 | print("Sample mean, std = %.3lf, %.3lf" % ( 118 | mean(body["_aknn_vector_sample"]), 119 | std(body["_aknn_vector_sample"]))) 120 | 121 | print("Posting to Elasticsearch") 122 | t0 = time() 123 | res = requests.post("%s/_aknn_create" % args["es_host"], json=body) 124 | if res.status_code == requests.codes.ok: 125 | print("Successfully built model in %d seconds" % (time() - t0)) 126 | print(pformat(res.json())) 127 | else: 128 | print("Failed with error code %d" % res.status_code, file=stderr) 129 | -------------------------------------------------------------------------------- /demo/pipeline/batch_es_aknn_index.py: -------------------------------------------------------------------------------- 1 | """Populate Elasticsearch-aknn documents from feature docs on disk or S3. 
2 | """ 3 | 4 | from argparse import ArgumentParser 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from io import BytesIO 7 | from itertools import cycle 8 | from more_itertools import chunked 9 | from numpy import mean, std 10 | from pprint import pformat 11 | from time import time 12 | import boto3 13 | import json 14 | import gzip 15 | import os 16 | import pdb 17 | import random 18 | import requests 19 | import sys 20 | 21 | 22 | def iter_docs(src, skip=0, alphasort=False): 23 | 24 | if src.startswith("s3://"): 25 | bucket = boto3.resource("s3").Bucket(src.replace("s3://", '')) 26 | for i, obj in enumerate(bucket.objects.all()): 27 | if i < skip: 28 | continue 29 | body = obj.get().get('Body') 30 | buff = BytesIO(body.read()) 31 | with gzip.open(buff) as fp: 32 | yield json.loads(fp.read().decode()) 33 | else: 34 | 35 | if alphasort: 36 | iter_ = sorted(os.scandir(src), key=lambda f: f.path) 37 | else: 38 | iter_ = os.scandir(src) 39 | 40 | for i, fobj in enumerate(iter_): 41 | if i < skip: 42 | continue 43 | with gzip.open(fobj.path) as fp: 44 | yield json.loads(fp.read().decode()) 45 | 46 | 47 | if __name__ == "__main__": 48 | 49 | ap = ArgumentParser(description="See script") 50 | ap.add_argument("features_src", 51 | help="Directory or S3 bucket containing image feature docs.") 52 | ap.add_argument("--es_hosts", default="http://localhost:9200", 53 | help="Comma-separated elasticsearch host URLs.") 54 | ap.add_argument("-b", "--batch_size", type=int, default=1000, 55 | help="Batch size for elasticsearch indexing.") 56 | args = vars(ap.parse_args()) 57 | 58 | # Parse multiple hosts. 59 | es_hosts = args["es_hosts"].split(",") 60 | es_hosts_cycle = cycle(es_hosts) 61 | 62 | # Prepare the document structure. 63 | body = { 64 | "_index": "twitter_images", 65 | "_type": "twitter_image", 66 | "_aknn_uri": "aknn_models/aknn_model/twitter_images", 67 | "_aknn_docs": [ 68 | # Populated below with structure: 69 | # { 70 | # "_id": "...", 71 | # "_source": { 72 | # "any_fields_you_want": "...", 73 | # "_aknn_vector": [0.1, 0.2, ...] 74 | # } 75 | # }, ... 76 | ] 77 | } 78 | 79 | mapping = { 80 | "properties": { 81 | "_aknn_vector": { 82 | "type": "half_float", 83 | "index": False 84 | } 85 | } 86 | } 87 | 88 | # Check if the index exists and get its count. 89 | count_url = "%s/%s/%s/_count" % (next(es_hosts_cycle), body["_index"], body["_type"]) 90 | req = requests.get(count_url) 91 | count = 0 if req.status_code == 404 else req.json()["count"] 92 | print("Found %d existing documents in index" % count) 93 | 94 | # If the index does not exist, create its mapping. 95 | if req.status_code == 404: 96 | print("Creating index %s" % body["_index"]) 97 | index_url = "%s/%s" % (next(es_hosts_cycle), body["_index"]) 98 | req = requests.put(index_url) 99 | assert req.status_code == 200, json.dumps(req.json()) 100 | 101 | print("Creating mapping for type %s" % body["_type"]) 102 | mapping_url = "%s/%s/%s/_mapping" % ( 103 | next(es_hosts_cycle), body["_index"], body["_type"]) 104 | requests.put(mapping_url, json=mapping) 105 | assert req.status_code == 200, json.dumps(req.json()) 106 | 107 | # Create an iterable over the feature documents. 108 | docs = iter_docs(args["features_src"], count, True) 109 | 110 | # Bookkeeping for round-robin indexing. 
111 | docs_batch = [] 112 | tpool = ThreadPoolExecutor(max_workers=len(es_hosts)) 113 | nb_round_robin_rem = len(es_hosts) * args["batch_size"] 114 | nb_indexed = 0 115 | T0 = -1 116 | 117 | for doc in docs: 118 | 119 | if T0 < 0: 120 | T0 = time() 121 | 122 | aknn_doc = { 123 | "_id": doc["id"], 124 | "_source": { 125 | "twitter_url": "https://twitter.com/statuses/%s" % doc["id"], 126 | "imagenet_labels": doc["imagenet_labels"], 127 | "s3_url": "https://s3.amazonaws.com/%s/%s" % ( 128 | doc["img_pointer"]["s3_bucket"], doc["img_pointer"]["s3_key"]), 129 | "_aknn_vector": doc["feature_vector"] 130 | } 131 | } 132 | 133 | docs_batch.append(aknn_doc) 134 | nb_round_robin_rem -= 1 135 | if nb_round_robin_rem > 0: 136 | continue 137 | 138 | futures = [] 139 | for h, d in zip(es_hosts, chunked(docs_batch, args["batch_size"])): 140 | body["_aknn_docs"] = d 141 | post_url = "%s/_aknn_index" % h 142 | futures.append(tpool.submit(requests.post, post_url, json=body)) 143 | print("Posting %d docs to host %s" % (len(body["_aknn_docs"]), h)) 144 | 145 | for f, h in zip(as_completed(futures), es_hosts): 146 | res = f.result() 147 | if res.status_code != 200: 148 | print("Error at host: %s" % h, res.json(), file=sys.stderr) 149 | sys.exit(1) 150 | print("Response %d from host %s:" % (res.status_code, h), res.json()) 151 | nb_indexed += res.json()["size"] 152 | 153 | print("Indexed %d docs in %d seconds = %.2lf docs / second" % ( 154 | nb_indexed, time() - T0, nb_indexed / (time() - T0))) 155 | 156 | # Reset bookkeeping. 157 | nb_round_robin_rem = len(es_hosts) * args["batch_size"] 158 | docs_batch = [] 159 | -------------------------------------------------------------------------------- /demo/pipeline/ec2_es_setup.sh: -------------------------------------------------------------------------------- 1 | # !/bin/sh 2 | # Script to setup Elasticsearch on an Ubuntu EC2 instance as part of a cluster. 3 | # Usage: ./ 4 | 5 | set -e 6 | 7 | clustername=$1 8 | esdir="$HOME/ES624" 9 | cnf="$esdir/config/elasticsearch.yml" 10 | 11 | # Increase memory setting for Elasticsearch. 12 | sudo sysctl -w vm.max_map_count=262144 13 | 14 | # Remove, re-download, unzip Elasticsearch binaries. 15 | rm -rf $esdir 16 | wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-6.2.4.tar.gz 17 | tar xvf elasticsearch-6.2.4.tar.gz 18 | rm elasticsearch-6.2.4.tar.gz 19 | mv elasticsearch-6.2.4 $esdir 20 | 21 | # Install/update JVM. 22 | sudo apt-get update -y 23 | sudo apt-get install -y default-jre htop 24 | 25 | # Build a simple config file. 26 | echo "" > $cnf 27 | echo "cluster.name: $clustername" >> $cnf 28 | echo "node.name: $(cat /etc/hostname)" >> $cnf 29 | echo "path.data: $HOME/esdata" >> $cnf 30 | echo "path.logs: $HOME/eslogs" >> $cnf 31 | echo "network.host: 0.0.0.0" >> $cnf 32 | echo "action.destructive_requires_name: true" >> $cnf 33 | echo "http.cors.enabled: true" >> $cnf 34 | echo "http.cors.allow-origin: /(null)|(https?:\/\/localhost(:[0-9]+)?)/" >> $cnf 35 | 36 | # Note: to get ec2 discovery working, either assign an IAM role with EC2 permissions 37 | # to the instances running elasticsearch, or set the AWS_ACCESS_KEY_ID and 38 | # AWS_SECRET_ACCESS_KEY environment variables on the instance. 39 | echo "discovery.zen.hosts_provider: ec2" >> $cnf 40 | bash $esdir/bin/elasticsearch-plugin install -b discovery-ec2 41 | 42 | # Print useful information about the configuration. 
43 | echo "----" 44 | cat $cnf 45 | echo "----" 46 | which java 47 | java -version 48 | echo "----" 49 | echo "Done" 50 | -------------------------------------------------------------------------------- /demo/pipeline/ingest_twitter_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ingest images from Twitter stream to S3. 3 | 4 | Input: listening to the Twitter stream API for statuses containing an image. 5 | Compute: save a local copy of the image and status, upload the image and status 6 | to an S3 bucket. 7 | 8 | Note that this worker uses thread-based concurrency to ingest statuses and images 9 | without blocking the Twitter stream. That is, as soon as a Tweet with an 10 | image is detected, it gets handed off to a separate thread to download the 11 | image, and store the image and status locally and on S3. 12 | 13 | """ 14 | 15 | from argparse import ArgumentParser 16 | from io import BytesIO 17 | from tweepy import OAuthHandler, API, Stream, StreamListener 18 | from threading import Thread 19 | from time import time 20 | import boto3 21 | import json 22 | import gzip 23 | import urllib.request 24 | import os 25 | import pdb 26 | import sys 27 | 28 | 29 | class Listener(StreamListener): 30 | """Extension on the Tweepy StreamListener that implements some simple logic 31 | for ingesting images from Twitter. 32 | 33 | Arguments 34 | s3_bucket: a boto3 Bucket instance for bucket where images, statuses are stored. 35 | images_dir: local directory where images are stored. 36 | statuses_dir: local directory where statuses are stored. 37 | 38 | """ 39 | 40 | def __init__(self, s3_bucket, images_dir, statuses_dir, **kwargs): 41 | self.cnt_all = len(os.listdir(images_dir)) 42 | self.cnt_new = 0 43 | self.t0 = time() 44 | self.s3_bucket = s3_bucket 45 | self.images_dir = images_dir 46 | self.statuses_dir = statuses_dir 47 | super().__init__(kwargs) 48 | 49 | def _ingest_status(self, status): 50 | """Internal function to ingest a single status 51 | 52 | Can be invoked as a thread to prevent blocking the main loop. 53 | 54 | Arguments 55 | status: Tweepy Status instance containing at least one image. 56 | """ 57 | 58 | t0 = time() 59 | 60 | # Download first image to disk and upload it to S3. 61 | # Some statuses have > 1 image, but it"s very rare. 62 | item = status.entities["media"][0] 63 | ext = item["media_url"].split(".")[-1] 64 | image_key = "%d.%s" % (status.id, ext) 65 | local_path = "%s/%s" % (self.images_dir, image_key) 66 | urllib.request.urlretrieve(item["media_url"], local_path) 67 | self.s3_bucket.upload_file(local_path, image_key) 68 | 69 | # Save status to disk as gzipped JSON. 70 | status_key = "%d.json.gz" % status.id 71 | local_path = "%s/%s" % (self.statuses_dir, status_key) 72 | with gzip.open(local_path, "wb") as fp: 73 | fp.write(json.dumps(status._json).encode()) 74 | 75 | self.s3_bucket.upload_file(local_path, status_key) 76 | print("%.3lf %d" % (time() - t0, status.id)) 77 | 78 | def on_status(self, status): 79 | """Implementation of the function invoked for every new status.""" 80 | 81 | # Skip any status not containing images. 82 | if "media" not in status.entities: 83 | return 84 | 85 | # Create and start a thread to ingest the status. 86 | t = Thread(target=self._ingest_status, args=(status,)) 87 | t.start() 88 | 89 | # Book-keeping and logging. 
90 | self.cnt_new += 1 91 | self.cnt_all += 1 92 | time_sec = time() - self.t0 93 | time_day = time_sec / (24 * 60 * 60) 94 | print("%d total, %d new, %d per day" % ( 95 | self.cnt_all, self.cnt_new, self.cnt_new / time_day)) 96 | 97 | 98 | if __name__ == "__main__": 99 | 100 | ap = ArgumentParser(description="See script header") 101 | ap.add_argument("--statuses_dir", 102 | help="Local directory where statuses are stored", 103 | default="data/twitter_stream/statuses") 104 | ap.add_argument("--images_dir", 105 | help="Local directory where images are stored", 106 | default="data/twitter_stream/images") 107 | ap.add_argument("--s3_bucket", 108 | default="klibisz-twitter-stream", 109 | help="Name of AWS S3 bucket where images and statuses are stored") 110 | ap.add_argument("--twitter_credentials_path", 111 | default="twitter-credentials.json", 112 | help="Path to JSON file containing Twitter API credentials") 113 | args = vars(ap.parse_args()) 114 | 115 | # Setup Twitter API client. 116 | with open(args["twitter_credentials_path"]) as fp: 117 | twcreds = json.load(fp) 118 | auth = OAuthHandler(twcreds["consumer_key"], twcreds["consumer_secret"]) 119 | auth.set_access_token(twcreds["access_token"], twcreds["token_secret"]) 120 | twitter = API(auth, wait_on_rate_limit=True, wait_on_rate_limit_notify=True, 121 | retry_count=10, retry_delay=1) 122 | 123 | # Setup S3 API client using credentials in $HOME/.aws or env. variables. 124 | s3_bucket = boto3.resource("s3").Bucket(args["s3_bucket"]) 125 | 126 | # Setup and run stream listener. 127 | listener = Listener(s3_bucket, args["images_dir"], args["statuses_dir"]) 128 | stream = Stream(auth=twitter.auth, listener=listener) 129 | stream.sample() 130 | -------------------------------------------------------------------------------- /demo/pipeline/requirements.txt: -------------------------------------------------------------------------------- 1 | boto3 2 | kafka-python 3 | lycon 4 | numpy 5 | pillow 6 | imageio 7 | tqdm 8 | tweepy 9 | -------------------------------------------------------------------------------- /demo/pipeline/stream_compute_image_features.py: -------------------------------------------------------------------------------- 1 | """ 2 | Worker to compute features from images. 3 | 4 | Input: pointers to images in S3, consumed from a Kafka topic. 5 | Compute: download and preprocess images, compute their features (imagenet 6 | labels and 1000-dimensional floating-point feature vectors). 7 | Output: Upload features to S3, publish a pointer to the features to a Kafka topic. 8 | 9 | Note that this worker is heavily optimized for concurrency/parallelism: 10 | 1. Threadpool to download images from S3 in parallel. 11 | 2. MultiProcessing pool to resize and preprocess images for compatibility with Keras. 12 | 3. Compute features in large batches via Keras/Tensorflow. 13 | 4. Many instances of this worker can be run in parallel, as long as they all 14 | use the same Kafka group ID. 
15 | 16 | """ 17 | 18 | from argparse import ArgumentParser 19 | from concurrent.futures import ThreadPoolExecutor, wait 20 | from kafka import KafkaConsumer, KafkaProducer 21 | from io import BytesIO 22 | from multiprocessing import Pool, cpu_count 23 | from imageio import imread 24 | from lycon import resize 25 | from pprint import pformat 26 | from sys import stderr 27 | from time import time 28 | import boto3 29 | import gzip 30 | import json 31 | import numpy as np 32 | 33 | from keras.models import Model 34 | from keras.applications import MobileNet 35 | from keras.applications.imagenet_utils import preprocess_input, decode_predictions 36 | 37 | 38 | def S3Pointer(id, s3_bucket, s3_key): 39 | """Function to formalize the structure of an S3 pointer as a dictionary 40 | which can be serialized and passed to Kafka or S3.""" 41 | return dict(id=id, s3_bucket=s3_bucket, s3_key=s3_key) 42 | 43 | 44 | def FeaturesObject(id, img_pointer, imagenet_labels, feature_vector): 45 | """Function to formalize the structure of a feature object as a dictionary 46 | which can be serialized and passed to Kafka or S3.""" 47 | if not isinstance(feature_vector, list): 48 | feature_vector = list(map(float, feature_vector)) 49 | return dict(id=id, img_pointer=img_pointer, 50 | imagenet_labels=imagenet_labels, 51 | feature_vector=feature_vector) 52 | 53 | 54 | class Convnet(object): 55 | """Wrapper around a Keras network; produces image labels and floating-point 56 | feature vectors.""" 57 | 58 | def __init__(self): 59 | self.preprocess_mode = 'tf' 60 | model = MobileNet(weights='imagenet') 61 | self.model = Model( 62 | model.input, [model.output, model.get_layer('conv_preds').output]) 63 | 64 | def get_labels_and_vecs(self, imgs_iter): 65 | """Compute the labels and floating-point feature vectors for a batch of 66 | images. 67 | 68 | Arguments 69 | imgs_iter: an iterable of numpy array images which have been pre-processed 70 | to a size useable with the Keras model. An iterable is used because it 71 | allows the calling code to pass a map(f, data) which this function 72 | executes over, or a regular list of already extracted images. 73 | 74 | Returns 75 | labels: a list of strings, one per image. Each string contains the ten most 76 | probable labels for the image, separated by spaces. 77 | vecs: a numpy array with shape (number of images, feature vector shape). 78 | For example, a batch of 512 images with a feature vector shape (1000,) 79 | would mean vecs has shape (512, 1000). 80 | 81 | """ 82 | 83 | imgs = np.array(imgs_iter) 84 | imgs = preprocess_input(imgs.astype(np.float32), 85 | mode=self.preprocess_mode) 86 | 87 | clsf, vecs = self.model.predict(imgs) 88 | labels = [' '.join([y[1].lower() for y in x]) 89 | for x in decode_predictions(clsf, top=10)] 90 | vecs = np.squeeze(vecs) 91 | 92 | return labels, vecs 93 | 94 | 95 | def _get_img_bytes_from_s3(args): 96 | """Download the raw image bytes from S3. 97 | 98 | It's generally safe and much faster to call this function from a 99 | thread pool of 10 - 20 threads. 100 | 101 | Arguments 102 | args: a tuple containing the bucket (string), key (string), and s3client 103 | (boto3 client). Using a tuple to support calling this method via 104 | parallelized map() function. 105 | 106 | Returns 107 | object body: bytes from the object downloaded from S3. 
108 | """ 109 | bucket, key, s3client = args 110 | obj = s3client.get_object(Bucket=bucket, Key=key) 111 | return obj['Body'].read() 112 | 113 | 114 | def _preprocess_img(img_bytes): 115 | """Load and transform image from its raw bytes to a keras-friendly np array. 116 | 117 | If the image cannot be read, it prints an error message and returns an 118 | array of all zeros. 119 | 120 | Arguments 121 | img_bytes: a Bytes object containing the bytes for a single image. 122 | 123 | Returns 124 | img: numpy array with shape (224, 224, 3). 125 | 126 | """ 127 | 128 | # Read from bytes to numpy array. 129 | try: 130 | img = imread(BytesIO(img_bytes)) 131 | assert isinstance(img, np.ndarray) 132 | except (ValueError, AssertionError) as ex: 133 | print("Error reading image, returning zeros:", ex, file=stderr) 134 | return np.zeros((224, 224, 3), dtype=np.uint8) 135 | 136 | # Extremely fast resize using lycon library. 137 | img = resize(img, 224, 224, interpolation=0) 138 | 139 | # Regular image: return. 140 | if img.shape[-1] == 3: 141 | return img 142 | 143 | # Grayscale image: repeat up to 3 channels. 144 | elif len(img.shape) == 2: 145 | return np.repeat(img[:, :, np.newaxis], 3, -1) 146 | 147 | # Other image: repeat first channel 3 times. 148 | return np.repeat(img[:, :, :1], 3, -1) 149 | 150 | 151 | def _str_to_gzipped_bytes(s): 152 | """Convert a single string to compressed Gzipped bytes, useful for uploading 153 | to S3.""" 154 | b = s.encode() 155 | g = gzip.compress(b) 156 | return BytesIO(g) 157 | 158 | 159 | if __name__ == "__main__": 160 | 161 | ap = ArgumentParser(description="See script header") 162 | ap.add_argument("--kafka_sub_topic", 163 | help="Name of topic from which images are consumed", 164 | default="aknn-demo.image-pointers") 165 | ap.add_argument("--kafka_pub_topic", 166 | help="Name of topic to which feature vectors get produced", 167 | default="aknn-demo.feature-pointers") 168 | ap.add_argument("--s3_pub_bucket", 169 | help="Name of bucket to which feature vectors get saved", 170 | default="klibisz-aknn-demo") 171 | ap.add_argument("--kafka_sub_offset", 172 | help="Where to start reading from topic", 173 | default="earliest", choices=["earliest", "latest"]) 174 | ap.add_argument("--kafka_servers", 175 | help="Bootstrap servers for Kafka", 176 | default="ip-172-31-19-114.ec2.internal:9092") 177 | ap.add_argument("--kafka_group", 178 | help="Group ID for Kafka consumer", 179 | default="aknn-demo.compute-image-features") 180 | 181 | args = vars(ap.parse_args()) 182 | print("Parsed command-line arguments:\n%s" % pformat(args)) 183 | 184 | # Kafka consumer/producer setup. 185 | consumer = KafkaConsumer( 186 | args["kafka_sub_topic"], 187 | bootstrap_servers=args["kafka_servers"], 188 | group_id=args["kafka_group"], 189 | auto_offset_reset=args["kafka_sub_offset"], 190 | key_deserializer=lambda k: k.decode(), 191 | value_deserializer=lambda v: json.loads(v.decode()) 192 | ) 193 | producer = KafkaProducer( 194 | bootstrap_servers=args["kafka_servers"], 195 | compression_type='gzip', 196 | key_serializer=str.encode, 197 | value_serializer=str.encode) 198 | 199 | # S3 connection. 200 | s3client = boto3.client('s3') 201 | 202 | # Convolutional network for feature extraction. 203 | convnet = Convnet() 204 | 205 | # Process pool and thread pool for parallelism. 206 | pool = Pool(cpu_count()) 207 | tpex = ThreadPoolExecutor(max_workers=min(cpu_count() * 4, 20)) 208 | 209 | print("Consuming from %s..." 
% args["kafka_sub_topic"]) 210 | 211 | for msg in consumer: 212 | 213 | print("-" * 80) 214 | print("Received batch %s with %d images" % (msg.key, len(msg.value))) 215 | T0 = time() 216 | 217 | # Download images from S3 into memory using thread parallelism. 218 | t0 = time() 219 | try: 220 | def f(p): return (p['s3_bucket'], p['s3_key'], s3client) 221 | data = map(f, msg.value) 222 | imgs_bytes = list(tpex.map(_get_img_bytes_from_s3, data)) 223 | except Exception as ex: 224 | print("Error downloading images:", ex, file=stderr) 225 | continue 226 | print("Download images from S3: %.2lf seconds" % (time() - t0)) 227 | 228 | # Preprocess the raw bytes using process parallelism. 229 | t0 = time() 230 | try: 231 | imgs_iter = pool.map(_preprocess_img, imgs_bytes) 232 | except Exception as ex: 233 | print("Error preprocessing images:", ex, file=stderr) 234 | continue 235 | print("Preprocess images: %.2lf seconds" % (time() - t0)) 236 | 237 | # Compute image labels and feature vectors. 238 | t0 = time() 239 | try: 240 | labels, vecs = convnet.get_labels_and_vecs(imgs_iter) 241 | except Exception as ex: 242 | print("Error computing features:", ex, file=stderr) 243 | continue 244 | print("Compute features: %.2lf seconds" % (time() - t0)) 245 | print("Vectors shape, mean, std = %s, %.5lf, %.5lf" % ( 246 | vecs.shape, vecs.mean(), vecs.std())) 247 | 248 | t0 = time() 249 | s3_futures = [] 250 | for img_pointer, label, vec in zip(msg.value, labels, vecs): 251 | 252 | # Create features object which will be uploaded to S3. 253 | features_object = FeaturesObject( 254 | id=img_pointer["id"], img_pointer=img_pointer, 255 | imagenet_labels=label, feature_vector=vec) 256 | 257 | # Create S3 Pointer which will be passed along in Kafka. 258 | features_pointer = S3Pointer( 259 | id=img_pointer["id"], s3_bucket=args["s3_pub_bucket"], 260 | s3_key="img-features-%s.json.gz" % img_pointer["id"]) 261 | 262 | # Upload features object to S3 by submitting to the thread pool. 263 | try: 264 | s3_args = dict( 265 | Body=_str_to_gzipped_bytes(json.dumps(features_object)), 266 | Bucket=features_pointer['s3_bucket'], 267 | Key=features_pointer['s3_key']) 268 | s3_futures.append(tpex.submit(s3client.put_object, **s3_args)) 269 | except Exception as ex: 270 | print("Error uploading to S3:", ex, file=stderr) 271 | continue 272 | 273 | # Publish to Kafka. 274 | try: 275 | producer.send(args["kafka_pub_topic"], 276 | key=features_pointer['id'], 277 | value=json.dumps(features_pointer)) 278 | except Exception as ex: 279 | print("Error publishing to Kafka:", ex, file=stderr) 280 | continue 281 | 282 | # Wait for all s3 requests to complete. 283 | try: 284 | wait(s3_futures, timeout=30) 285 | except Exception as ex: 286 | print("Error resolving upload futures:", ex, file=stderr) 287 | continue 288 | print("Upload features: %.2lf seconds", (time() - t0)) 289 | 290 | print("Finished batch %s: %.2lf seconds" % (msg.key, time() - T0)) 291 | 292 | producer.flush() 293 | -------------------------------------------------------------------------------- /demo/pipeline/stream_produce_image_pointers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Produces pointers to images in S3 to a Kafka Topic. 3 | Each pointer contains the information necessary for a downstream 4 | consumer to retrieve the specific image from S3. 
5 | """ 6 | 7 | from argparse import ArgumentParser 8 | from kafka import KafkaProducer 9 | from time import time 10 | from tqdm import tqdm 11 | import boto3 12 | import json 13 | import pdb 14 | 15 | 16 | def S3Pointer(id, s3_bucket, s3_key): 17 | return dict(id=id, s3_bucket=s3_bucket, s3_key=s3_key) 18 | 19 | 20 | if __name__ == "__main__": 21 | 22 | ap = ArgumentParser(description="See script") 23 | ap.add_argument("--bucket", help="S3 bucket name", 24 | default="klibisz-twitter-stream") 25 | ap.add_argument("--kafka_pub_topic", 26 | help="Topic to which image events get published", 27 | default="aknn-demo.image-pointers") 28 | ap.add_argument("--kafka_server", 29 | help="Bootstrap server for producer", 30 | default="ip-172-31-19-114.ec2.internal:9092") 31 | ap.add_argument("-b", "--batch_size", type=int, default=1000, 32 | help="Size of batches produced") 33 | 34 | args = vars(ap.parse_args()) 35 | 36 | bucket = boto3.resource("s3").Bucket(args["bucket"]) 37 | producer = KafkaProducer( 38 | bootstrap_servers=args["kafka_server"], 39 | compression_type="gzip", 40 | key_serializer=str.encode, 41 | value_serializer=str.encode) 42 | 43 | t0 = time() 44 | nb_produced = 0 45 | batch = [] 46 | for obj in bucket.objects.all(): 47 | 48 | # TODO: this was a bad design in the ingestion. 49 | # There should be a way to distinguish statues 50 | # from images without ever reading in the statuses. 51 | if obj.key.endswith(".json.gz"): 52 | continue 53 | 54 | batch.append(S3Pointer( 55 | id=obj.key.split('.')[0], 56 | s3_bucket=obj.bucket_name, s3_key=obj.key)) 57 | 58 | if len(batch) < args["batch_size"]: 59 | continue 60 | 61 | key = "batch-%s-%s" % (batch[0]['id'], batch[-1]['id']) 62 | value = json.dumps(batch) 63 | producer.send(args["kafka_pub_topic"], key=key, value=value) 64 | nb_produced += len(batch) 65 | batch = [] 66 | 67 | print("%d produced - %d / second" % ( 68 | nb_produced, nb_produced / (time() - t0))) 69 | -------------------------------------------------------------------------------- /demo/pipeline/twitter-credentials.template.json: -------------------------------------------------------------------------------- 1 | { 2 | "consumer_key": "...", 3 | "consumer_secret": "...", 4 | "access_token": "...", 5 | "token_secret": "..." 6 | } 7 | -------------------------------------------------------------------------------- /demo/screencast.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexklibisz/elastik-nearest-neighbors/fe69992c133bbb56a81ea6067217820b6f30b6e6/demo/screencast.gif -------------------------------------------------------------------------------- /demo/webapp/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /demo/webapp/README.md: -------------------------------------------------------------------------------- 1 | Flask web application for Twitter Image Similarity search. 2 | 3 | The web app consists of a single endpoint which executes a similarity search 4 | against an external Elasticsearch node and serves a web page containing the 5 | results for that search, as well as several random images from the index so 6 | the user can continue browsing. 7 | 8 | ## Usage 9 | 10 | ``` 11 | # Install Flask 12 | pip3 install flask 13 | 14 | # Define one or more comma-separated Elasticsearch host urls 15 | # http://localhost:9200 is the default. 
16 | export ESHOSTS="http://localhost:9200" 17 | 18 | # Define and run the app (in development mode). 19 | export FLASK_APP=app.py 20 | python3 -m flask run -p 9999 -h 0.0.0.0 21 | ``` 22 | -------------------------------------------------------------------------------- /demo/webapp/app.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimal Flask web app to demonstrate Elasticsearch-Aknn functionality 3 | on corpus of Twitter image features. 4 | """ 5 | 6 | from flask import Flask, request, render_template, redirect 7 | from itertools import cycle 8 | from pprint import pprint 9 | import os 10 | import random 11 | import requests 12 | 13 | # Get elasticsearch hosts from environment variable. 14 | ESHOSTS = cycle(["http://localhost:9200"]) 15 | if "ESHOSTS" in os.environ: 16 | ESHOSTS = cycle(os.environ["ESHOSTS"].split(",")) 17 | 18 | # Define a set of images to cycle through for the /demo endpoint. 19 | DEMO_IDS = [ 20 | "988221425063530502", # Mountain scenery 21 | "990013386929917953", # Car 22 | "991780208138055681", # Screenshot 23 | "989646964148133889", # Male actors (DiCaprio) 24 | "988889393158115329", # Male athlete (C. Ronaldo) 25 | "988487255877718017", # Signs 26 | "991004064748978177", # Female selfie 27 | "988237522810503168", # Cartoon character 28 | "989808637773135873", # North/south Korean politicians 29 | "989144784341229568", # Leo Messi 30 | "989655776363921409", # Dog 31 | "991484266415443968", # Some kids playing with a racoon 32 | "989836022384156672", # Mountain scenery 33 | "990578938505146368", # Race cars 34 | "988526279665205248", # Store fronts 35 | "989477367486672896", 36 | "988531509954011139", 37 | "990159780726665216", 38 | "990678809081823232", 39 | "992379356071825410", 40 | "988788327217119233", 41 | "989065251919458304", 42 | "989617448843403264", 43 | "990863324890869760", 44 | "989664366319484928", 45 | "989951906809344001", 46 | "988674636417249281", 47 | "988426706888216576", 48 | "991450758120902656", 49 | "990226717607415808", 50 | "988902080923529217", 51 | "990372146735087616", 52 | "989678396274814976", 53 | "988867339516022784", 54 | "990713839892119552", 55 | "992056122050662400", 56 | "989161016280875008", 57 | "990594050557231104", 58 | "992186954941980673", 59 | "988825283204558848", 60 | "989350699472490497", 61 | "990430615324450816" 62 | ] 63 | 64 | app = Flask(__name__) 65 | 66 | 67 | @app.route("/slides") 68 | def slides(): 69 | return redirect("https://docs.google.com/presentation/d/1AyIyBqzCqKhytZWcQfSEhtBRN-iHUldBQn14MGGKpr8/present", 70 | code=302) 71 | 72 | @app.route("/") 73 | @app.route("/demo") 74 | def demo(): 75 | return redirect("/twitter_images/twitter_image/demo", code=302) 76 | 77 | @app.route("///") 78 | def images(es_index, es_type, es_id): 79 | 80 | 81 | 82 | # Parse elasticsearch ID. If "demo", pick a random demo image ID. 83 | if es_id.lower() == "demo": 84 | es_id = random.choice(DEMO_IDS) 85 | 86 | elif es_id.lower() == "random": 87 | body = { 88 | "_source": ["s3_url"], 89 | "size": 1, 90 | "query": { 91 | "function_score": { 92 | "query": {"match_all": {}}, 93 | "boost": 5, 94 | "random_score": {}, 95 | "boost_mode": "multiply" 96 | } 97 | } 98 | } 99 | req_url = "%s/%s/%s/_search" % (next(ESHOSTS), es_index, es_type) 100 | req = requests.get(req_url, json=body) 101 | es_id = req.json()["hits"]["hits"][0]["_id"] 102 | 103 | # Get number of docs in corpus. 
104 | req_url = "%s/%s/%s/_count" % (next(ESHOSTS), es_index, es_type) 105 | req = requests.get(req_url) 106 | count = req.json()["count"] 107 | 108 | # Get the nearest neighbors for the query image, which includes the image. 109 | image_id = request.args.get("image_id") 110 | req_url = "%s/%s/%s/%s/_aknn_search?k1=100&k2=10" % ( 111 | next(ESHOSTS), es_index, es_type, es_id) 112 | req = requests.get(req_url) 113 | hits = req.json()["hits"]["hits"] 114 | took_ms = req.json()["took"] 115 | query_img, neighbor_imgs = hits[0], hits[1:] 116 | 117 | # Render template. 118 | return render_template( 119 | "index.html", 120 | es_index=es_index, 121 | es_type=es_type, 122 | took_ms=took_ms, 123 | count=count, 124 | query_img=query_img, 125 | neighbor_imgs=neighbor_imgs) 126 | -------------------------------------------------------------------------------- /demo/webapp/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | ElastiK Nearest Neighbors Demo 7 | 8 | 9 | 29 | 30 | 31 | 32 | 33 |
[Jinja2 template; the HTML markup was not preserved in this flattened copy. Recoverable content: a "Query Image" panel with a "View the original Tweet" link; a "Most Visually Similar Nearest Neighbors" section, captioned "Click an image to view its original tweet.", which renders the hits passed in from app.py with {% for img in neighbor_imgs %} ... {% endfor %}; and a "View another demo image" link.]
90 | 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /elasticsearch-aknn/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .gradle/ 3 | build/ 4 | out/ 5 | *.swp 6 | *-execution-hints.log 7 | *-execution-times.log 8 | 9 | // ignore the downloaded model files in git 10 | src/test/resources/models/ 11 | 12 | // intellij 13 | *.iml 14 | *.ipr 15 | *.iws 16 | 17 | // eclipse 18 | .project 19 | .classpath 20 | eclipse-build 21 | */.project 22 | */.classpath 23 | */eclipse-build 24 | .settings 25 | !/.settings/org.eclipse.core.resources.prefs 26 | !/.settings/org.eclipse.jdt.core.prefs 27 | !/.settings/org.eclipse.jdt.ui.prefs 28 | -------------------------------------------------------------------------------- /elasticsearch-aknn/LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /elasticsearch-aknn/NOTICE.txt: -------------------------------------------------------------------------------- 1 | Elasticsearch-Aknn 2 | 3 | Copyright 2018 Alex Klibisz 4 | 5 | This product includes software developed by The Apache Software 6 | Foundation (http://www.apache.org/). 
-------------------------------------------------------------------------------- /elasticsearch-aknn/README.md: -------------------------------------------------------------------------------- 1 | # Elasticsearch-Aknn 2 | 3 | Elasticsearch plugin for approximate K-nearest-neighbor queries on floating-point 4 | vectors using locality-sensitive hashing. 5 | 6 | The API for the three main endpoints and the main implementation details are 7 | documented at the root of this repository. 8 | 9 | See the `testplugin.sh` script for an outline of building and installing the plugin. 10 | 11 | See the `benchmark` directory for examples of interacting with the plugin 12 | programmatically via Python and the requests library. 13 | 14 | The long-term plan for this plugin is to extract it, polish it up, and move 15 | it to its own repository. I've begun doing this on the [dev branch of the 16 | elasticsearch-aknn repository](https://github.com/alexklibisz/elasticsearch-aknn/tree/dev). 17 | 18 | ## Planned Improvements 19 | 20 | 1. Implement integration tests. Elasticsearch has some nice integration testing 21 | functionality, but the documentation is very scarce. 22 | 2. Add proper error checking and error responses to the endpoints to prevent 23 | silent/ambiguous errors. For example, Elasticsearch rejects uppercase index 24 | names and fails to index such a document, but the endpoint still returns 200. 25 | 3. Clean up the JSON-Java serialization and deserialization, especially 26 | the conversion of JSON lists of lists to Java `List<List<Double>>` to 27 | Java `Double[][]` to `RealMatrix`. 28 | 4. Enforce an explicit mapping and types for new Aknn LSH models. For example, the LSH 29 | hyperplanes should not be indexed and can likely be stored as `half_float` / Java `float` 30 | to save space / network latency. 31 | 5. Enforce an explicit mapping and types for `_aknn_vector` and `_aknn_hashes` 32 | entries. For example, `_aknn_vector` should not be indexed and can likely be 33 | stored as a `half_float` / Java `float`. 34 | 6. Determine a proper place for defining/changing plugin configurations. For 35 | example, the names of the vector and hashes fields. 36 | 7. Implement alternative distance functions, starting with cosine distance. 37 | -------------------------------------------------------------------------------- /elasticsearch-aknn/benchmark/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | glove*txt 3 | glove*zip 4 | 5 | -------------------------------------------------------------------------------- /elasticsearch-aknn/benchmark/README.md: -------------------------------------------------------------------------------- 1 | # Elasticsearch-Aknn Benchmarks 2 | 3 | This directory contains code for benchmarking an Elasticsearch-Aknn installation 4 | using Glove word vectors. 5 | 6 | The `aknn.py` script should give an idea of how to programmatically interact 7 | with the Elasticsearch-Aknn plugin. 8 | 9 | The `figures.ipynb` notebook should give some ideas on performance for various 10 | configurations.
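For quick orientation before running the full benchmark, the three endpoints can also be exercised directly with the `requests` library. The sketch below is a minimal illustration, not part of `aknn.py`: the host URL, index/type/model names, table/bit/dimension counts, and the random vectors are assumptions, while the payload keys follow the fields parsed in `AknnRestAction.java`.

```
# Minimal sketch of the Elasticsearch-Aknn endpoints (host, names, and data are assumed).
import numpy as np
import requests

ES = "http://localhost:9200"               # assumed Elasticsearch host
NB_TABLES, NB_BITS, NB_DIMS = 16, 16, 25   # LSH tables, bits per table, vector dimensions

# 1. Fit and store an LSH model from a sample of 2 * tables * bits vectors.
sample = np.random.randn(2 * NB_TABLES * NB_BITS, NB_DIMS)
create_body = {
    "_index": "aknn_models", "_type": "aknn_model", "_id": "example_model",
    "_source": {
        "_aknn_description": "example LSH model",
        "_aknn_nb_tables": NB_TABLES,
        "_aknn_nb_bits_per_table": NB_BITS,
        "_aknn_nb_dimensions": NB_DIMS
    },
    "_aknn_vector_sample": sample.tolist()
}
print(requests.post("%s/_aknn_create" % ES, json=create_body).json())

# 2. Hash and index a batch of vectors against that model.
index_body = {
    "_index": "example_vectors", "_type": "example_vector",
    "_aknn_uri": "aknn_models/aknn_model/example_model",
    "_aknn_docs": [
        {"_id": "vec_%d" % i,
         "_source": {"_aknn_vector": np.random.randn(NB_DIMS).tolist()}}
        for i in range(100)
    ]
}
print(requests.post("%s/_aknn_index" % ES, json=index_body).json())

# 3. Search: gather k1 candidates by matching LSH hashes, then return the k2
#    closest candidates ranked by exact Euclidean distance.
res = requests.get("%s/example_vectors/example_vector/vec_0/_aknn_search?k1=100&k2=10" % ES)
print(res.json()["hits"]["hits"])
```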
11 | 12 | -------------------------------------------------------------------------------- /elasticsearch-aknn/benchmark/glove_download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -e 3 | wget http://nlp.stanford.edu/data/glove.twitter.27B.zip 4 | unzip glove.twitter.27B.zip 5 | rm glove.twitter.27B.50d.txt glove.twitter.27B.100d.txt glove.twitter.27B.200d.txt 6 | echo "Done" 7 | -------------------------------------------------------------------------------- /elasticsearch-aknn/benchmark/glove_preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | assert len(sys.argv) > 1, "Usage: path-to-unzipped-glove-vecs.txt" 5 | 6 | for i, line in enumerate(open(sys.argv[1])): 7 | tkns = line.split(" ") 8 | word = tkns[0] 9 | vector = list(map(lambda x: round(float(x), 5), tkns[1:])) 10 | doc = { 11 | "_id": "word_%d" % i, 12 | "_source": { 13 | "word": word, 14 | "_aknn_vector": vector 15 | } 16 | } 17 | print(json.dumps(doc)) 18 | -------------------------------------------------------------------------------- /elasticsearch-aknn/benchmark/metrics/.gitignore: -------------------------------------------------------------------------------- 1 | *.json 2 | -------------------------------------------------------------------------------- /elasticsearch-aknn/benchmark/metrics/fig_corpus_vs_time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexklibisz/elastik-nearest-neighbors/fe69992c133bbb56a81ea6067217820b6f30b6e6/elasticsearch-aknn/benchmark/metrics/fig_corpus_vs_time.png -------------------------------------------------------------------------------- /elasticsearch-aknn/benchmark/metrics/fig_recall_vs_time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexklibisz/elastik-nearest-neighbors/fe69992c133bbb56a81ea6067217820b6f30b6e6/elasticsearch-aknn/benchmark/metrics/fig_recall_vs_time.png -------------------------------------------------------------------------------- /elasticsearch-aknn/build.gradle: -------------------------------------------------------------------------------- 1 | import org.elasticsearch.gradle.test.RestIntegTestTask 2 | 3 | buildscript { 4 | repositories { 5 | mavenCentral() 6 | jcenter() 7 | } 8 | 9 | dependencies { 10 | classpath "org.elasticsearch.gradle:build-tools:6.2.4" 11 | } 12 | } 13 | 14 | group = 'org.elasticsearch.plugin.aknn' 15 | version = '0.0.1-SNAPSHOT' 16 | 17 | apply plugin: 'java' 18 | apply plugin: 'elasticsearch.esplugin' 19 | apply plugin: 'idea' 20 | 21 | // license of this project 22 | licenseFile = rootProject.file('LICENSE.txt') 23 | 24 | // copyright notices 25 | noticeFile = rootProject.file('NOTICE.txt') 26 | 27 | esplugin { 28 | name 'elasticsearch-aknn' 29 | description 'Elasticsearch plugin for approximate K-nearest-neighbors search' 30 | classname 'org.elasticsearch.plugin.aknn.AknnPlugin' 31 | // license of the plugin, may be different than the above license 32 | licenseFile rootProject.file('LICENSE.txt') 33 | // copyright notices, may be different than the above notice 34 | noticeFile rootProject.file('NOTICE.txt') 35 | } 36 | 37 | // In this section you declare the dependencies for your production and test code 38 | // TODO: See https://youtu.be/7alCuE7cNVQ?t=12m41s for a potential solution to original ND4J problems. 
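// commons-math3 supplies the RealMatrix/RealVector linear algebra that LshModel uses to fit and apply the LSH hyperplanes.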
39 | dependencies { 40 | compile 'org.apache.commons:commons-math3:3.6.1' 41 | } 42 | 43 | test { 44 | enabled true 45 | testLoggingConfig.setOutputMode( 46 | com.carrotsearch.gradle.junit4.TestLoggingConfiguration.OutputMode.ALWAYS 47 | ) 48 | } 49 | 50 | integTestRunner { 51 | enabled true 52 | testLoggingConfig.setOutputMode( 53 | com.carrotsearch.gradle.junit4.TestLoggingConfiguration.OutputMode.ALWAYS 54 | ) 55 | } 56 | 57 | // Set to false to not use elasticsearch checkstyle rules. 58 | checkstyleMain.enabled = true 59 | checkstyleTest.enabled = true 60 | 61 | // FIXME dependency license check needs to be enabled 62 | dependencyLicenses.enabled = false 63 | 64 | // FIXME thirdparty audit needs to be enabled 65 | thirdPartyAudit.enabled = false 66 | -------------------------------------------------------------------------------- /elasticsearch-aknn/settings.gradle: -------------------------------------------------------------------------------- 1 | rootProject.name = 'elasticsearch-aknn' 2 | -------------------------------------------------------------------------------- /elasticsearch-aknn/src/main/java/org/elasticsearch/plugin/aknn/AknnPlugin.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright [2018] [Alex Klibisz] 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package org.elasticsearch.plugin.aknn; 19 | 20 | import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; 21 | import org.elasticsearch.cluster.node.DiscoveryNodes; 22 | import org.elasticsearch.common.settings.ClusterSettings; 23 | import org.elasticsearch.common.settings.IndexScopedSettings; 24 | import org.elasticsearch.common.settings.Setting; 25 | import org.elasticsearch.common.settings.Settings; 26 | import org.elasticsearch.common.settings.SettingsFilter; 27 | import org.elasticsearch.plugins.ActionPlugin; 28 | import org.elasticsearch.plugins.Plugin; 29 | import org.elasticsearch.rest.RestController; 30 | import org.elasticsearch.rest.RestHandler; 31 | 32 | import java.util.Arrays; 33 | import java.util.List; 34 | import java.util.function.Supplier; 35 | 36 | public class AknnPlugin extends Plugin implements ActionPlugin { 37 | 38 | private static final Setting SETTINGS = 39 | new Setting<>("aknn.sample.setting", "foo", (value) -> value, Setting.Property.NodeScope); 40 | 41 | @Override 42 | public List> getSettings() { 43 | return Arrays.asList(SETTINGS); 44 | } 45 | 46 | @Override 47 | public List getRestHandlers(final Settings settings, 48 | final RestController restController, 49 | final ClusterSettings clusterSettings, 50 | final IndexScopedSettings indexScopedSettings, 51 | final SettingsFilter settingsFilter, 52 | final IndexNameExpressionResolver indexNameExpressionResolver, 53 | final Supplier nodesInCluster) { 54 | return Arrays.asList(new AknnRestAction(settings, restController)); 55 | } 56 | } -------------------------------------------------------------------------------- /elasticsearch-aknn/src/main/java/org/elasticsearch/plugin/aknn/AknnRestAction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright [2018] [Alex Klibisz] 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | package org.elasticsearch.plugin.aknn; 18 | 19 | import org.elasticsearch.action.bulk.BulkRequestBuilder; 20 | import org.elasticsearch.action.bulk.BulkResponse; 21 | import org.elasticsearch.action.get.GetResponse; 22 | import org.elasticsearch.action.index.IndexResponse; 23 | import org.elasticsearch.action.search.SearchResponse; 24 | import org.elasticsearch.client.node.NodeClient; 25 | import org.elasticsearch.common.StopWatch; 26 | import org.elasticsearch.common.inject.Inject; 27 | import org.elasticsearch.common.settings.Settings; 28 | import org.elasticsearch.common.xcontent.XContentBuilder; 29 | import org.elasticsearch.common.xcontent.XContentHelper; 30 | import org.elasticsearch.common.xcontent.XContentParser; 31 | import org.elasticsearch.index.query.BoolQueryBuilder; 32 | import org.elasticsearch.index.query.QueryBuilder; 33 | import org.elasticsearch.index.query.QueryBuilders; 34 | import org.elasticsearch.rest.BaseRestHandler; 35 | import org.elasticsearch.rest.BytesRestResponse; 36 | import org.elasticsearch.rest.RestController; 37 | import org.elasticsearch.rest.RestRequest; 38 | import org.elasticsearch.rest.RestStatus; 39 | import org.elasticsearch.search.SearchHit; 40 | 41 | import java.io.IOException; 42 | import java.util.ArrayList; 43 | import java.util.Comparator; 44 | import java.util.HashMap; 45 | import java.util.List; 46 | import java.util.Map; 47 | 48 | import static java.lang.Math.min; 49 | import static org.elasticsearch.rest.RestRequest.Method.GET; 50 | import static org.elasticsearch.rest.RestRequest.Method.POST; 51 | 52 | public class AknnRestAction extends BaseRestHandler { 53 | 54 | public static String NAME = "_aknn"; 55 | private final String NAME_SEARCH = "_aknn_search"; 56 | private final String NAME_INDEX = "_aknn_index"; 57 | private final String NAME_CREATE = "_aknn_create"; 58 | 59 | // TODO: check how parameters should be defined at the plugin level. 60 | private final String HASHES_KEY = "_aknn_hashes"; 61 | private final String VECTOR_KEY = "_aknn_vector"; 62 | private final Integer K1_DEFAULT = 99; 63 | private final Integer K2_DEFAULT = 10; 64 | 65 | // TODO: add an option to the index endpoint handler that empties the cache. 
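// Cache of deserialized LSH models keyed by their "_aknn_uri", populated lazily by the index handler so that repeated bulk indexing requests do not re-fetch and re-parse the model document.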
66 | private Map lshModelCache = new HashMap<>(); 67 | 68 | @Inject 69 | public AknnRestAction(Settings settings, RestController controller) { 70 | super(settings); 71 | controller.registerHandler(GET, "/{index}/{type}/{id}/" + NAME_SEARCH, this); 72 | controller.registerHandler(POST, NAME_INDEX, this); 73 | controller.registerHandler(POST, NAME_CREATE, this); 74 | } 75 | 76 | @Override 77 | public String getName() { 78 | return NAME; 79 | } 80 | 81 | @Override 82 | protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { 83 | if (restRequest.path().endsWith(NAME_SEARCH)) 84 | return handleSearchRequest(restRequest, client); 85 | else if (restRequest.path().endsWith(NAME_INDEX)) 86 | return handleIndexRequest(restRequest, client); 87 | else 88 | return handleCreateRequest(restRequest, client); 89 | } 90 | 91 | public static Double euclideanDistance(List A, List B) { 92 | Double squaredDistance = 0.; 93 | for (Integer i = 0; i < A.size(); i++) 94 | squaredDistance += Math.pow(A.get(i) - B.get(i), 2); 95 | return Math.sqrt(squaredDistance); 96 | } 97 | 98 | private RestChannelConsumer handleSearchRequest(RestRequest restRequest, NodeClient client) throws IOException { 99 | 100 | StopWatch stopWatch = new StopWatch("StopWatch to Time Search Request"); 101 | 102 | // Parse request parameters. 103 | stopWatch.start("Parse request parameters"); 104 | final String index = restRequest.param("index"); 105 | final String type = restRequest.param("type"); 106 | final String id = restRequest.param("id"); 107 | final Integer k1 = restRequest.paramAsInt("k1", K1_DEFAULT); 108 | final Integer k2 = restRequest.paramAsInt("k2", K2_DEFAULT); 109 | stopWatch.stop(); 110 | 111 | logger.info("Get query document at {}/{}/{}", index, type, id); 112 | stopWatch.start("Get query document"); 113 | GetResponse queryGetResponse = client.prepareGet(index, type, id).get(); 114 | Map baseSource = queryGetResponse.getSource(); 115 | stopWatch.stop(); 116 | 117 | logger.info("Parse query document hashes"); 118 | stopWatch.start("Parse query document hashes"); 119 | @SuppressWarnings("unchecked") 120 | Map queryHashes = (Map) baseSource.get(HASHES_KEY); 121 | stopWatch.stop(); 122 | 123 | stopWatch.start("Parse query document vector"); 124 | @SuppressWarnings("unchecked") 125 | List queryVector = (List) baseSource.get(VECTOR_KEY); 126 | stopWatch.stop(); 127 | 128 | // Retrieve the documents with most matching hashes. https://stackoverflow.com/questions/10773581 129 | logger.info("Build boolean query from hashes"); 130 | stopWatch.start("Build boolean query from hashes"); 131 | QueryBuilder queryBuilder = QueryBuilders.boolQuery(); 132 | for (Map.Entry entry : queryHashes.entrySet()) { 133 | String termKey = HASHES_KEY + "." + entry.getKey(); 134 | ((BoolQueryBuilder) queryBuilder).should(QueryBuilders.termQuery(termKey, entry.getValue())); 135 | } 136 | stopWatch.stop(); 137 | 138 | logger.info("Execute boolean search"); 139 | stopWatch.start("Execute boolean search"); 140 | SearchResponse approximateSearchResponse = client 141 | .prepareSearch(index) 142 | .setTypes(type) 143 | .setFetchSource("*", HASHES_KEY) 144 | .setQuery(queryBuilder) 145 | .setSize(k1) 146 | .get(); 147 | stopWatch.stop(); 148 | 149 | // Compute exact KNN on the approximate neighbors. 150 | // Recreate the SearchHit structure, but remove the vector and hashes. 
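// Note that "_score" is reused here to carry the exact Euclidean distance to the query vector, so smaller values are better; the hits are sorted in ascending order of this distance below.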
151 | logger.info("Compute exact distance and construct search hits"); 152 | stopWatch.start("Compute exact distance and construct search hits"); 153 | List> modifiedSortedHits = new ArrayList<>(); 154 | for (SearchHit hit: approximateSearchResponse.getHits()) { 155 | Map hitSource = hit.getSourceAsMap(); 156 | @SuppressWarnings("unchecked") 157 | List hitVector = (List) hitSource.get(VECTOR_KEY); 158 | hitSource.remove(VECTOR_KEY); 159 | hitSource.remove(HASHES_KEY); 160 | modifiedSortedHits.add(new HashMap() {{ 161 | put("_index", hit.getIndex()); 162 | put("_id", hit.getId()); 163 | put("_type", hit.getType()); 164 | put("_score", euclideanDistance(queryVector, hitVector)); 165 | put("_source", hitSource); 166 | }}); 167 | } 168 | stopWatch.stop(); 169 | 170 | logger.info("Sort search hits by exact distance"); 171 | stopWatch.start("Sort search hits by exact distance"); 172 | modifiedSortedHits.sort(Comparator.comparingDouble(x -> (Double) x.get("_score"))); 173 | stopWatch.stop(); 174 | 175 | logger.info("Timing summary\n {}", stopWatch.prettyPrint()); 176 | 177 | return channel -> { 178 | XContentBuilder builder = channel.newBuilder(); 179 | builder.startObject(); 180 | builder.field("took", stopWatch.totalTime().getMillis()); 181 | builder.field("timed_out", false); 182 | builder.startObject("hits"); 183 | builder.field("max_score", 0); 184 | 185 | // In some cases there will not be enough approximate matches to return *k2* hits. For example, this could 186 | // be the case if the number of bits per table in the LSH model is too high, over-partioning the space. 187 | builder.field("total", min(k2, modifiedSortedHits.size())); 188 | builder.field("hits", modifiedSortedHits.subList(0, min(k2, modifiedSortedHits.size()))); 189 | builder.endObject(); 190 | builder.endObject(); 191 | channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); 192 | }; 193 | } 194 | 195 | private RestChannelConsumer handleCreateRequest(RestRequest restRequest, NodeClient client) throws IOException { 196 | 197 | StopWatch stopWatch = new StopWatch("StopWatch to time create request"); 198 | logger.info("Parse request"); 199 | stopWatch.start("Parse request"); 200 | 201 | XContentParser xContentParser = XContentHelper.createParser( 202 | restRequest.getXContentRegistry(), restRequest.content(), restRequest.getXContentType()); 203 | Map contentMap = xContentParser.mapOrdered(); 204 | @SuppressWarnings("unchecked") 205 | Map sourceMap = (Map) contentMap.get("_source"); 206 | 207 | final String _index = (String) contentMap.get("_index"); 208 | final String _type = (String) contentMap.get("_type"); 209 | final String _id = (String) contentMap.get("_id"); 210 | final String description = (String) sourceMap.get("_aknn_description"); 211 | final Integer nbTables = (Integer) sourceMap.get("_aknn_nb_tables"); 212 | final Integer nbBitsPerTable = (Integer) sourceMap.get("_aknn_nb_bits_per_table"); 213 | final Integer nbDimensions = (Integer) sourceMap.get("_aknn_nb_dimensions"); 214 | @SuppressWarnings("unchecked") 215 | final List> vectorSample = (List>) contentMap.get("_aknn_vector_sample"); 216 | stopWatch.stop(); 217 | 218 | logger.info("Fit LSH model from sample vectors"); 219 | stopWatch.start("Fit LSH model from sample vectors"); 220 | LshModel lshModel = new LshModel(nbTables, nbBitsPerTable, nbDimensions, description); 221 | lshModel.fitFromVectorSample(vectorSample); 222 | stopWatch.stop(); 223 | 224 | logger.info("Serialize LSH model"); 225 | stopWatch.start("Serialize LSH model"); 226 | Map 
lshSerialized = lshModel.toMap(); 227 | stopWatch.stop(); 228 | 229 | logger.info("Index LSH model"); 230 | stopWatch.start("Index LSH model"); 231 | IndexResponse indexResponse = client.prepareIndex(_index, _type, _id) 232 | .setSource(lshSerialized) 233 | .get(); 234 | stopWatch.stop(); 235 | 236 | logger.info("Timing summary\n {}", stopWatch.prettyPrint()); 237 | 238 | return channel -> { 239 | XContentBuilder builder = channel.newBuilder(); 240 | builder.startObject(); 241 | builder.field("took", stopWatch.totalTime().getMillis()); 242 | builder.endObject(); 243 | channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); 244 | }; 245 | } 246 | 247 | private RestChannelConsumer handleIndexRequest(RestRequest restRequest, NodeClient client) throws IOException { 248 | 249 | StopWatch stopWatch = new StopWatch("StopWatch to time bulk indexing request"); 250 | 251 | logger.info("Parse request parameters"); 252 | stopWatch.start("Parse request parameters"); 253 | XContentParser xContentParser = XContentHelper.createParser( 254 | restRequest.getXContentRegistry(), restRequest.content(), restRequest.getXContentType()); 255 | Map contentMap = xContentParser.mapOrdered(); 256 | final String index = (String) contentMap.get("_index"); 257 | final String type = (String) contentMap.get("_type"); 258 | final String aknnURI = (String) contentMap.get("_aknn_uri"); 259 | @SuppressWarnings("unchecked") 260 | final List> docs = (List>) contentMap.get("_aknn_docs"); 261 | logger.info("Received {} docs for indexing", docs.size()); 262 | stopWatch.stop(); 263 | 264 | // TODO: check if the index exists. If not, create a mapping which does not index continuous values. 265 | // This is rather low priority, as I tried it via Python and it doesn't make much difference. 266 | 267 | // Check if the LshModel has been cached. If not, retrieve the Aknn document and use it to populate the model. 268 | LshModel lshModel; 269 | if (! lshModelCache.containsKey(aknnURI)) { 270 | 271 | // Get the Aknn document. 272 | logger.info("Get Aknn model document from {}", aknnURI); 273 | stopWatch.start("Get Aknn model document"); 274 | String[] annURITokens = aknnURI.split("/"); 275 | GetResponse aknnGetResponse = client.prepareGet(annURITokens[0], annURITokens[1], annURITokens[2]).get(); 276 | stopWatch.stop(); 277 | 278 | // Instantiate LSH from the source map. 279 | logger.info("Parse Aknn model document"); 280 | stopWatch.start("Parse Aknn model document"); 281 | lshModel = LshModel.fromMap(aknnGetResponse.getSourceAsMap()); 282 | stopWatch.stop(); 283 | 284 | // Save for later. 285 | lshModelCache.put(aknnURI, lshModel); 286 | 287 | } else { 288 | logger.info("Get Aknn model document from local cache"); 289 | stopWatch.start("Get Aknn model document from local cache"); 290 | lshModel = lshModelCache.get(aknnURI); 291 | stopWatch.stop(); 292 | } 293 | 294 | // Prepare documents for batch indexing. 
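// Each incoming doc's _aknn_vector is hashed with the (possibly cached) LSH model, and the resulting table-to-hash map is stored under _aknn_hashes alongside the vector in the doc's source.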
295 | logger.info("Hash documents for indexing"); 296 | stopWatch.start("Hash documents for indexing"); 297 | BulkRequestBuilder bulkIndexRequest = client.prepareBulk(); 298 | for (Map doc: docs) { 299 | @SuppressWarnings("unchecked") 300 | Map source = (Map) doc.get("_source"); 301 | @SuppressWarnings("unchecked") 302 | List vector = (List) source.get(VECTOR_KEY); 303 | source.put(HASHES_KEY, lshModel.getVectorHashes(vector)); 304 | bulkIndexRequest.add(client 305 | .prepareIndex(index, type, (String) doc.get("_id")) 306 | .setSource(source)); 307 | } 308 | stopWatch.stop(); 309 | 310 | logger.info("Execute bulk indexing"); 311 | stopWatch.start("Execute bulk indexing"); 312 | BulkResponse bulkIndexResponse = bulkIndexRequest.get(); 313 | stopWatch.stop(); 314 | 315 | logger.info("Timing summary\n {}", stopWatch.prettyPrint()); 316 | 317 | if (bulkIndexResponse.hasFailures()) { 318 | logger.error("Indexing failed with message: {}", bulkIndexResponse.buildFailureMessage()); 319 | return channel -> { 320 | XContentBuilder builder = channel.newBuilder(); 321 | builder.startObject(); 322 | builder.field("took", stopWatch.totalTime().getMillis()); 323 | builder.field("error", bulkIndexResponse.buildFailureMessage()); 324 | builder.endObject(); 325 | channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, builder)); 326 | }; 327 | } 328 | 329 | logger.info("Indexed {} docs successfully", docs.size()); 330 | return channel -> { 331 | XContentBuilder builder = channel.newBuilder(); 332 | builder.startObject(); 333 | builder.field("size", docs.size()); 334 | builder.field("took", stopWatch.totalTime().getMillis()); 335 | builder.endObject(); 336 | channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); 337 | }; 338 | } 339 | } 340 | -------------------------------------------------------------------------------- /elasticsearch-aknn/src/main/java/org/elasticsearch/plugin/aknn/LshModel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright [2018] [Alex Klibisz] 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | package org.elasticsearch.plugin.aknn; 18 | 19 | import org.apache.commons.math3.linear.ArrayRealVector; 20 | import org.apache.commons.math3.linear.MatrixUtils; 21 | import org.apache.commons.math3.linear.RealMatrix; 22 | import org.apache.commons.math3.linear.RealVector; 23 | 24 | import java.util.ArrayList; 25 | import java.util.HashMap; 26 | import java.util.List; 27 | import java.util.Map; 28 | import java.util.stream.Collectors; 29 | 30 | public class LshModel { 31 | 32 | private Integer nbTables; 33 | private Integer nbBitsPerTable; 34 | private Integer nbDimensions; 35 | private String description; 36 | private List midpoints; 37 | private List normals; 38 | private List normalsTransposed; 39 | private List thresholds; 40 | 41 | public LshModel(Integer nbTables, Integer nbBitsPerTable, Integer nbDimensions, String description) { 42 | this.nbTables = nbTables; 43 | this.nbBitsPerTable = nbBitsPerTable; 44 | this.nbDimensions = nbDimensions; 45 | this.description = description; 46 | this.midpoints = new ArrayList<>(); 47 | this.normals = new ArrayList<>(); 48 | this.normalsTransposed = new ArrayList<>(); 49 | this.thresholds = new ArrayList<>(); 50 | } 51 | 52 | public void fitFromVectorSample(List> vectorSample) { 53 | 54 | RealMatrix vectorsA, vectorsB, midpoint, normal, vectorSampleMatrix; 55 | vectorSampleMatrix = MatrixUtils.createRealMatrix(vectorSample.size(), this.nbDimensions); 56 | 57 | for (int i = 0; i < vectorSample.size(); i++) 58 | for (int j = 0; j < this.nbDimensions; j++) 59 | vectorSampleMatrix.setEntry(i, j, vectorSample.get(i).get(j)); 60 | 61 | for (int i = 0; i < vectorSampleMatrix.getRowDimension(); i += (nbBitsPerTable * 2)) { 62 | // Select two subsets of nbBitsPerTable vectors. 63 | vectorsA = vectorSampleMatrix.getSubMatrix(i, i + nbBitsPerTable - 1, 0, nbDimensions - 1); 64 | vectorsB = vectorSampleMatrix.getSubMatrix(i + nbBitsPerTable, i + 2 * nbBitsPerTable - 1, 0, nbDimensions - 1); 65 | 66 | // Compute the midpoint between each pair of vectors. 67 | midpoint = vectorsA.add(vectorsB).scalarMultiply(0.5); 68 | midpoints.add(midpoint); 69 | 70 | // Compute the normal vectors for each pair of vectors. 71 | normal = vectorsB.subtract(midpoint); 72 | normals.add(normal); 73 | } 74 | 75 | } 76 | 77 | public Map getVectorHashes(List vector) { 78 | 79 | RealMatrix xDotNT, vectorAsMatrix; 80 | RealVector threshold; 81 | Map hashes = new HashMap<>(); 82 | Long hash; 83 | Integer i, j; 84 | 85 | // Have to convert the vector to a matrix to support multiplication below. 86 | // TODO: if the List vector argument can be changed to an Array double[] or float[], this would be faster. 87 | vectorAsMatrix = MatrixUtils.createRealMatrix(1, nbDimensions); 88 | for (i = 0; i < nbDimensions; i++) 89 | vectorAsMatrix.setEntry(0, i, vector.get(i)); 90 | 91 | // Compute the hash for this vector with respect to each table. 
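// Bit j of table i's hash is set when the vector lies on the positive side of hyperplane j, i.e. when the projection x . n_j exceeds the threshold n_j . m_j precomputed in fromMap().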
92 | for (i = 0; i < nbTables; i++) { 93 | xDotNT = vectorAsMatrix.multiply(normalsTransposed.get(i)); 94 | threshold = thresholds.get(i); 95 | hash = 0L; 96 | for (j = 0; j < nbBitsPerTable; j++) 97 | if (xDotNT.getEntry(0, j) > threshold.getEntry(j)) 98 | hash += (long) Math.pow(2, j); 99 | hashes.put(i.toString(), hash); 100 | } 101 | 102 | return hashes; 103 | } 104 | 105 | @SuppressWarnings("unchecked") 106 | public static LshModel fromMap(Map serialized) { 107 | 108 | LshModel lshModel = new LshModel( 109 | (Integer) serialized.get("_aknn_nb_tables"), (Integer) serialized.get("_aknn_nb_bits_per_table"), 110 | (Integer) serialized.get("_aknn_nb_dimensions"), (String) serialized.get("_aknn_description")); 111 | 112 | // TODO: figure out how to cast directly to List or double[][][] and use MatrixUtils.createRealMatrix. 113 | List>> midpointsRaw = (List>>) serialized.get("_aknn_midpoints"); 114 | List>> normalsRaw = (List>>) serialized.get("_aknn_normals"); 115 | for (int i = 0; i < lshModel.nbTables; i++) { 116 | RealMatrix midpoint = MatrixUtils.createRealMatrix(lshModel.nbBitsPerTable, lshModel.nbDimensions); 117 | RealMatrix normal = MatrixUtils.createRealMatrix(lshModel.nbBitsPerTable, lshModel.nbDimensions); 118 | for (int j = 0; j < lshModel.nbBitsPerTable; j++) { 119 | for (int k = 0; k < lshModel.nbDimensions; k++) { 120 | midpoint.setEntry(j, k, midpointsRaw.get(i).get(j).get(k)); 121 | normal.setEntry(j, k, normalsRaw.get(i).get(j).get(k)); 122 | } 123 | } 124 | lshModel.midpoints.add(midpoint); 125 | lshModel.normals.add(normal); 126 | lshModel.normalsTransposed.add(normal.transpose()); 127 | } 128 | 129 | for (int i = 0; i < lshModel.nbTables; i++) { 130 | RealMatrix normal = lshModel.normals.get(i); 131 | RealMatrix midpoint = lshModel.midpoints.get(i); 132 | RealVector threshold = new ArrayRealVector(lshModel.nbBitsPerTable); 133 | for (int j = 0; j < lshModel.nbBitsPerTable; j++) 134 | threshold.setEntry(j, normal.getRowVector(j).dotProduct(midpoint.getRowVector(j))); 135 | lshModel.thresholds.add(threshold); 136 | } 137 | 138 | return lshModel; 139 | } 140 | 141 | public Map toMap() { 142 | return new HashMap() {{ 143 | put("_aknn_nb_tables", nbTables); 144 | put("_aknn_nb_bits_per_table", nbBitsPerTable); 145 | put("_aknn_nb_dimensions", nbDimensions); 146 | put("_aknn_description", description); 147 | put("_aknn_midpoints", midpoints.stream().map(realMatrix -> realMatrix.getData()).collect(Collectors.toList())); 148 | put("_aknn_normals", normals.stream().map(normals -> normals.getData()).collect(Collectors.toList())); 149 | }}; 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /elasticsearch-aknn/src/main/plugin-metadata/plugin-security.policy: -------------------------------------------------------------------------------- 1 | grant { 2 | }; 3 | -------------------------------------------------------------------------------- /elasticsearch-aknn/src/test/java/org/elasticsearch/plugin/aknn/AknnSimpleIT.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright [2018] [Alex Klibisz] 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package org.elasticsearch.plugin.aknn; 19 | 20 | import org.apache.http.util.EntityUtils; 21 | import org.elasticsearch.client.Client; 22 | import org.elasticsearch.client.Response; 23 | import org.elasticsearch.client.RestClient; 24 | import org.elasticsearch.test.ESIntegTestCase; 25 | import org.junit.Before; 26 | 27 | import java.io.IOException; 28 | 29 | public class AknnSimpleIT extends ESIntegTestCase { 30 | 31 | private Client client; 32 | private RestClient restClient; 33 | 34 | @Before 35 | public void setUp() throws Exception { 36 | super.setUp(); 37 | client = client(); 38 | restClient = getRestClient(); 39 | } 40 | 41 | /** 42 | * Test that the plugin was installed correctly by hitting the _cat/plugins endpoint. 43 | * @throws IOException 44 | */ 45 | public void testPluginInstallation() throws IOException { 46 | Response response = restClient.performRequest("GET", "_cat/plugins"); 47 | String body = EntityUtils.toString(response.getEntity()); 48 | logger.info(body); 49 | assertTrue(body.contains("elasticsearch-aknn")); 50 | } 51 | 52 | } 53 | -------------------------------------------------------------------------------- /elasticsearch-aknn/src/test/java/org/elasticsearch/plugin/aknn/AknnSimpleTests.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright [2018] [Alex Klibisz] 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package org.elasticsearch.plugin.aknn; 19 | 20 | import org.elasticsearch.test.ESTestCase; 21 | 22 | public class AknnSimpleTests extends ESTestCase { 23 | 24 | // Note: tests must start with the word "test"... 25 | 26 | public void testSomethingTrivial() { 27 | assertTrue(true); 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /elasticsearch-aknn/testplugin.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Script for quickly recompiling and testing the elasticsearch-aknn plugin. 3 | set -e 4 | 5 | ESBIN="$HOME/Downloads/elasticsearch-6.2.4/bin" 6 | PLUGINPATH="file:build/distributions/elasticsearch-aknn-0.0.1-SNAPSHOT.zip" 7 | 8 | # TODO: fix the code so that skipping these tasks is not necessary. 
9 | gradle clean build -x integTestRunner -x test 10 | $ESBIN/elasticsearch-plugin remove elasticsearch-aknn | true 11 | $ESBIN/elasticsearch-plugin install -b $PLUGINPATH 12 | 13 | sudo sysctl -w vm.max_map_count=262144 14 | export ES_HEAP_SIZE=12g 15 | $ESBIN/elasticsearch 16 | 17 | 18 | -------------------------------------------------------------------------------- /scratch/README.md: -------------------------------------------------------------------------------- 1 | This directory contains small, messy prototypes I built while building the larger, 2 | more refined project. 3 | -------------------------------------------------------------------------------- /scratch/elasticsearch-plugin/.gitignore: -------------------------------------------------------------------------------- 1 | glove_artifacts 2 | !glove_artifacts/glove_knn.txt 3 | 4 | __pycache__ 5 | 6 | ingest-awesome 7 | -------------------------------------------------------------------------------- /scratch/elasticsearch-plugin/commands-ann_processor.txt: -------------------------------------------------------------------------------- 1 | DELETE _ingest/pipeline/ann-pipeline 2 | 3 | PUT _ingest/pipeline/ann-pipeline 4 | { 5 | "description": "A pipeline to store vectors for ANN search", 6 | "processors": [ 7 | { 8 | "ann_processor" : { 9 | "vector_field" : "vector", 10 | "hashes_field": "hashes" 11 | } 12 | } 13 | ] 14 | } 15 | 16 | DELETE ann_ingest_test 17 | 18 | PUT /ann_ingest_test/hashed_vector/1?pipeline=ann-pipeline 19 | { 20 | "vector" : [0.1, 1.2, 2.3, 3.4, 4.5, 5.6] 21 | } 22 | 23 | 24 | GET /ann_ingest_test/hashed_vector/1 25 | -------------------------------------------------------------------------------- /scratch/elasticsearch-plugin/commands-ann_search.txt: -------------------------------------------------------------------------------- 1 | DELETE hashed_vectors 2 | 3 | POST hashed_vectors/hashed_vector/0 4 | { 5 | "description": "Vector 0", 6 | "hashes": { 7 | "0": 1, 8 | "1": 7, 9 | "2": 4 10 | } 11 | } 12 | 13 | POST hashed_vectors/hashed_vector/1 14 | { 15 | "description": "Vector 1", 16 | "hashes": { 17 | "0": 1, 18 | "1": 2, 19 | "2": 4 20 | } 21 | } 22 | 23 | POST hashed_vectors/hashed_vector/2 24 | { 25 | "description": "Vector 2", 26 | "hashes": { 27 | "0": 1, 28 | "1": 4, 29 | "2": 7 30 | } 31 | } 32 | 33 | GET hashed_vectors/hashed_vector/_search 34 | 35 | GET hashed_vectors/hashed_vector/0/_search_ann 36 | 37 | GET /hashed_vectors/hashed_vector/_search 38 | { 39 | "query": { 40 | "bool" : { 41 | "should" : [ 42 | { "term" : { "hashes.0" : 1 }}, 43 | { "term" : { "hashes.1" : 2 }}, 44 | { "term" : { "hashes.2" : 4 }} 45 | ], 46 | "minimum_should_match" : 1, 47 | "boost" : 1.0 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /scratch/elasticsearch-plugin/glove-hashing-in-python/glove_test.py: -------------------------------------------------------------------------------- 1 | from elasticsearch import Elasticsearch, helpers 2 | from sklearn.neighbors import NearestNeighbors 3 | from time import time 4 | from tqdm import tqdm 5 | import json 6 | import os 7 | import numpy as np 8 | import pdb 9 | 10 | from lsh_model import LSHModel 11 | 12 | N = 1000000 # Number of vectors. 13 | D = 300 # Dimension of each vector. 14 | L = 32 # Number of LSH models. 15 | H = 16 # Number of buckets in each LSH model. 
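# (Each model hashes with H hyperplanes, so a vector falls into one of 2**H possible buckets per model.)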
16 | RAW_GLOVE_PATH = os.path.expanduser("~") + "/Downloads/glove.840B.300d.txt" 17 | GLOVE_VOC_PATH = "glove_artifacts/glove_vocab.txt" 18 | GLOVE_VEC_PATH = "glove_artifacts/glove_vectors.npy" 19 | GLOVE_KNN_PATH = "glove_artifacts/glove_knn.txt" 20 | GLOVE_TEST_WORDS = ["obama", "quantum", "neural", "olympics", "san", "york", 21 | "nuclear", "data", "pink", "monday"] 22 | LSH_HASHES_PATH = "glove_artifacts/glove_lsh_hashes.txt" 23 | 24 | # Convert the raw glove data into a vocab file and a numpy array file. 25 | if not (os.path.exists(GLOVE_VOC_PATH) and os.path.exists(GLOVE_VEC_PATH)): 26 | 27 | words, vecs = [], np.zeros((N, D)) 28 | 29 | with open(RAW_GLOVE_PATH) as fp: 30 | for i, line in tqdm(enumerate(fp), desc="Processing raw Glove data"): 31 | if i == N: 32 | break 33 | tkns = line.split(" ") 34 | words.append(tkns[0]) 35 | vecs[i] = np.array(list(map(float, tkns[1:]))) 36 | 37 | with open(GLOVE_VOC_PATH, "w") as fp: 38 | fp.write("\n".join(words)) 39 | 40 | np.save(GLOVE_VEC_PATH, vecs.astype(np.float32)) 41 | 42 | 43 | # Compute the real nearest neighbors for a set of test words. 44 | if not (os.path.exists(GLOVE_KNN_PATH)): 45 | 46 | with open(GLOVE_VOC_PATH) as fp: 47 | words = list(map(str.strip, fp)) 48 | word2idx = {w: i for i, w in enumerate(words)} 49 | 50 | vecs = np.load(GLOVE_VEC_PATH) 51 | knn = NearestNeighbors(n_neighbors=5, algorithm='brute', metric='euclidean') 52 | knn.fit(vecs) 53 | 54 | test_ii = list(map(word2idx.get, GLOVE_TEST_WORDS)) 55 | nbrs = knn.kneighbors(vecs[test_ii], return_distance=False) 56 | 57 | with open(GLOVE_KNN_PATH, "w") as fp: 58 | for word, nbrs_ in zip(GLOVE_TEST_WORDS, nbrs): 59 | fp.write("%s %s\n" % (word, " ".join([words[i] for i in nbrs_]))) 60 | 61 | # Fit LSH models and compute the hash from each model on each word vector. 62 | if not os.path.exists(LSH_HASHES_PATH): 63 | 64 | with open(GLOVE_VOC_PATH) as fp: 65 | words = list(map(str.strip, fp)) 66 | word2idx = {w: i for i, w in enumerate(words)} 67 | 68 | vecs = np.load(GLOVE_VEC_PATH) 69 | 70 | lsh_models = [LSHModel(seed=i, H=H).fit(vecs) for i in range(L)] 71 | 72 | for i, lsh_model in enumerate(lsh_models): 73 | print("model %d mean hash = %.3lf" % (i, lsh_model.get_hash(vecs).mean())) 74 | 75 | lines = [] 76 | 77 | for word, vec in tqdm(zip(words, vecs), desc="Computing LSH hashes for each vector"): 78 | lines.append(word) 79 | for lsh_model in lsh_models: 80 | hash_arr = lsh_model.get_hash(vec) 81 | hash_str = ''.join(map(str, hash_arr)) 82 | hash_int = int(hash_str, 2) 83 | lines[-1] = "%s %d" % (lines[-1], hash_int) 84 | 85 | with open(LSH_HASHES_PATH, "w") as fp: 86 | fp.write("\n".join(lines)) 87 | 88 | 89 | # Finally, insert the documents to elasticsearch. 
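# The mapping below indexes each hashes.<i> field as an integer so it can be matched with term queries, while the description and raw vector are stored but not indexed.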
90 | 91 | es = Elasticsearch() 92 | actions = [] 93 | 94 | hashes = ",".join(['"%d": {"type": "integer"}' % i for i in range(L)]) 95 | body = json.loads("""{ 96 | "mappings": { 97 | "hashed_vector": { 98 | "properties": { 99 | "description": { 100 | "type": "text", 101 | "index": false 102 | }, 103 | "hashes": { 104 | "properties": { %s } 105 | }, 106 | "vector": { 107 | "type": "float", 108 | "index": false 109 | } 110 | } 111 | } 112 | } 113 | }""" % hashes) 114 | 115 | es.indices.delete(index="glove_hashed_vectors", ignore=[400, 404]) 116 | es.indices.create(index="glove_hashed_vectors", body=body) 117 | 118 | vecs = np.load(GLOVE_VEC_PATH) 119 | 120 | with open(LSH_HASHES_PATH) as fp: 121 | 122 | for i, line in enumerate(map(str.strip, fp)): 123 | tkns = line.split(" ") 124 | word, hashes = tkns[0], tkns[1:] 125 | 126 | actions.append({ 127 | "_index": "glove_hashed_vectors", 128 | "_type": "hashed_vector", 129 | "_id": word[:100], 130 | "_source": { 131 | "description": word, 132 | "hashes": {str(j): int(h) for j, h in enumerate(hashes)}, 133 | "vector": vecs[i].tolist() 134 | } 135 | }) 136 | 137 | if len(actions) == 10000: 138 | helpers.bulk(es, actions) 139 | print("Inserted %d of %d" % (i, N)) 140 | actions = [] 141 | -------------------------------------------------------------------------------- /scratch/elasticsearch-plugin/glove-hashing-in-python/lsh_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LSHModel(object): 5 | 6 | def __init__(self, seed=865, H=16): 7 | self.rng = np.random.RandomState(seed) 8 | self.H = H 9 | self.N = None 10 | self.NdotM = None 11 | 12 | def fit(self, X): 13 | 14 | # Fit by picking *bits* pairs of points and computing the planes 15 | # equidistant between them. 16 | X_sample = self.rng.choice(X.ravel(), size=(2, self.H, X.shape[-1])) 17 | 18 | # Midpoints for each pair of points. 19 | M = (X_sample[0, ...] + X_sample[1, ...]) / 2. 20 | 21 | # Normal vector for each pair of points. 22 | N = X_sample[-1, ...] - M 23 | 24 | # Keep them around for later. 25 | self.N = N 26 | self.NdotM = (N * M).sum(-1) 27 | 28 | return self 29 | 30 | def get_hash(self, X): 31 | XdotN = X.dot(self.N.T) 32 | H = (XdotN >= self.NdotM).astype(np.uint8) 33 | return H 34 | -------------------------------------------------------------------------------- /scratch/elasticsearch-plugin/glove_create_ann.py: -------------------------------------------------------------------------------- 1 | from elasticsearch import Elasticsearch, helpers 2 | import json 3 | import pdb 4 | import requests 5 | import numpy as np 6 | 7 | D = 300 # Dimension of each vector. 8 | L = 32 # Number of LSH models. 9 | H = 16 # Number of buckets in each LSH model. 
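# (The _aknn_create endpoint fits L hash tables from a sample of 2 * L * H vectors -- two sample vectors per hyperplane.)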
10 | 11 | ES_URL = "http://localhost:9200/_aknn_create" 12 | GLOVE_VEC_PATH = "glove-hashing-in-python/glove_artifacts/glove_vectors.npy" 13 | 14 | np.random.seed(1) 15 | vecs = np.load(GLOVE_VEC_PATH) 16 | sample_ii = np.random.permutation(len(vecs))[:2 * L * H] 17 | vecs_sample = vecs[sample_ii] 18 | 19 | print("Sampled %d vectors" % len(vecs_sample)) 20 | 21 | vector_sample_csv = "" 22 | for vec in vecs_sample: 23 | vector_sample_csv += ",".join(map(lambda x: "%.8lf" % x, vec)) + "\n" 24 | 25 | data = { 26 | "_index": "aknn_models", 27 | "_type": "aknn_model", 28 | "_id": "glove_840B_300D", 29 | "description": "AKNN model for Glove Common Crawl 840B (Glove.840B.300d.zip)", 30 | 31 | "nb_tables": L, 32 | "nb_bits_per_table": H, 33 | "nb_dimensions": D, 34 | 35 | "vector_sample_csv": vector_sample_csv 36 | } 37 | 38 | print("Posting to Elasticsearch...") 39 | response = requests.post(ES_URL, json=data) 40 | print(response.json()) 41 | -------------------------------------------------------------------------------- /scratch/elasticsearch-plugin/glove_index_ann.py: -------------------------------------------------------------------------------- 1 | from elasticsearch import Elasticsearch, helpers 2 | import json 3 | import pdb 4 | import requests 5 | import numpy as np 6 | import os 7 | import re 8 | 9 | ES_URL = "http://localhost:9200/_aknn_index" 10 | RAW_GLOVE_PATH = os.path.expanduser("~") + "/Downloads/glove.840B.300d.txt" 11 | BATCH_SIZE = 10000 12 | 13 | INDEX = 'glove_word_vectors' 14 | TYPE = 'word_vector' 15 | ANN_URI = 'aknn_models/aknn_model/glove_840B_300D' 16 | 17 | 18 | es = Elasticsearch() 19 | 20 | body = json.loads("""{ 21 | "mappings": { 22 | "%s": { 23 | "properties": { 24 | "description": { 25 | "type": "text", 26 | "index": false 27 | }, 28 | "vector": { 29 | "type": "float", 30 | "index": false 31 | } 32 | } 33 | } 34 | } 35 | }""" % TYPE) 36 | 37 | es.indices.delete(index=INDEX, ignore=[400, 404]) 38 | es.indices.create(index=INDEX, body=body) 39 | 40 | data = dict(_index=INDEX, _type=TYPE, _ann_uri=ANN_URI, docs=[]) 41 | 42 | regex = re.compile('[^a-zA-Z]') 43 | 44 | for line in open(RAW_GLOVE_PATH): 45 | tkns = line.split(" ") 46 | word, vector = tkns[0], list(map(float, tkns[1:])) 47 | word = regex.sub('', word) 48 | if len(word) == 0: 49 | continue 50 | 51 | data['docs'].append(dict(_id=word, _source=dict( 52 | description=word, _aknn_vector=vector))) 53 | 54 | if len(data['docs']) == BATCH_SIZE: 55 | response = requests.post(ES_URL, json=data) 56 | print(response.json()) 57 | data['docs'] = [] 58 | -------------------------------------------------------------------------------- /scratch/elasticsearch-tweets/.gitignore: -------------------------------------------------------------------------------- 1 | tweet_texts.txt 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /scratch/elasticsearch-tweets/es_index_tweets.py: -------------------------------------------------------------------------------- 1 | """Read cleaned twitter statuses from disk and insert them 2 | to local elasticsearch instance.""" 3 | from tqdm import tqdm 4 | from elasticsearch import Elasticsearch, helpers 5 | import pdb 6 | 7 | if __name__ == "__main__": 8 | es = Elasticsearch() 9 | 10 | # # ~45 minutes to insert ~498K documents. 11 | # for line in tqdm(open("tweet_texts.txt")): 12 | # es.index(index="tweets", doc_type="tweet", body={"text": line}) 13 | 14 | # ~40 seconds to bulk insert in batches of 10000 documents. 
15 | actions = [] 16 | 17 | for i, line in tqdm(enumerate(open("tweet_texts.txt"))): 18 | 19 | actions.append({ 20 | "_index": "tweets2", 21 | "_type": "tweet", 22 | "_id": i, 23 | "_source": { 24 | "text": line 25 | } 26 | }) 27 | 28 | if len(actions) == 10000: 29 | helpers.bulk(es, actions) 30 | actions = [] 31 | 32 | helpers.bulk(es, actions) 33 | actions = [] 34 | -------------------------------------------------------------------------------- /scratch/elasticsearch-tweets/get_tweet_texts.py: -------------------------------------------------------------------------------- 1 | """Read twitter statuses from disk and print out only their text""" 2 | from glob import glob 3 | import json 4 | 5 | if __name__ == "__main__": 6 | 7 | statuses_dir = '/home/alex/Desktop/statuses' 8 | 9 | for path in glob('%s/*' % statuses_dir): 10 | 11 | # Attempt to read status as JSON file. Doesn't always work. 12 | try: 13 | with open(path) as fp: 14 | status = json.load(fp) 15 | except json.decoder.JSONDecodeError as ex: 16 | continue 17 | 18 | if 'full_text' in status: 19 | text = status['full_text'] 20 | else: 21 | text = status['text'] 22 | 23 | print(text.replace('\n', '')) 24 | -------------------------------------------------------------------------------- /scratch/elasticsearch-tweets/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ## Metrics 3 | 4 | - 45 minutes to insert 498K documents as individual requests. 5 | - 40 seconds to insert 498K documents in batches of 10K documents. 6 | 7 | ## Results 8 | 9 | Simple query in Kibana: 10 | 11 | ``` 12 | GET tweets/_search 13 | { 14 | "query": { 15 | "match": { 16 | "text": "Fake news media" 17 | } 18 | } 19 | } 20 | ``` 21 | 22 | ``` 23 | { 24 | "took": 17, 25 | "timed_out": false, 26 | "_shards": { 27 | "total": 5, 28 | "successful": 5, 29 | "skipped": 0, 30 | "failed": 0 31 | }, 32 | "hits": { 33 | "total": 21021, 34 | "max_score": 16.087315, 35 | "hits": [ 36 | { 37 | "_index": "tweets", 38 | "_type": "tweet", 39 | "_id": "pGD83mIBXudwSMXjAVz3", 40 | "_score": 16.087315, 41 | "_source": { 42 | "text": "Fake news regularly appears in the media: https://t.co/4ty4SJAUHt\n" 43 | } 44 | }, 45 | { 46 | "_index": "tweets", 47 | "_type": "tweet", 48 | "_id": "wWQR32IBXudwSMXjf0-r", 49 | "_score": 15.683937, 50 | "_source": { 51 | "text": "Fake news media covering up ANTIFA terrorist attack https://t.co/TUNv2saNdq\n" 52 | } 53 | }, 54 | { 55 | "_index": "tweets", 56 | "_type": "tweet", 57 | "_id": "_WUZ32IBXudwSMXjRr_c", 58 | "_score": 15.528734, 59 | "_source": { 60 | "text": "https://t.co/A9li9olSUA This is from Thursday. MEDIA LIES. FAKE NEWS\n" 61 | } 62 | }, 63 | { 64 | "_index": "tweets", 65 | "_type": "tweet", 66 | "_id": "AWD93mIBXudwSMXj-bqt", 67 | "_score": 15.437178, 68 | "_source": { 69 | "text": """ 70 | "The fake news media has been spreading fake news unbelievably. But we're making America great again. Right, folks? 
I think so" #SOTU 71 | 72 | """ 73 | } 74 | }, 75 | { 76 | "_index": "tweets", 77 | "_type": "tweet", 78 | "_id": "02QP32IBXudwSMXj5gPx", 79 | "_score": 15.33066, 80 | "_source": { 81 | "text": "Cernovich reads hater articles from the fake news media https://t.co/gkZvcsuoky\n" 82 | } 83 | }, 84 | { 85 | "_index": "tweets", 86 | "_type": "tweet", 87 | "_id": "lWIH32IBXudwSMXj4Yeh", 88 | "_score": 15.300294, 89 | "_source": { 90 | "text": "Fake News Panic and the Silencing of Dissident Media’ https://t.co/ZBUL6jE9JY\n" 91 | } 92 | }, 93 | { 94 | "_index": "tweets", 95 | "_type": "tweet", 96 | "_id": "BmD63mIBXudwSMXjPgl5", 97 | "_score": 15.279158, 98 | "_source": { 99 | "text": "Cernovich reads hater articles from the fake news media https://t.co/gkZvcsuoky\n" 100 | } 101 | }, 102 | { 103 | "_index": "tweets", 104 | "_type": "tweet", 105 | "_id": "MmED32IBXudwSMXjNrHy", 106 | "_score": 15.279158, 107 | "_source": { 108 | "text": "Fake and failing news media.@Newsweek #FakeNews 👉https://t.co/ey1yzYBK16https://t.co/xa1XgWCO9h\n" 109 | } 110 | }, 111 | { 112 | "_index": "tweets", 113 | "_type": "tweet", 114 | "_id": "omUX32IBXudwSMXjo3Ij", 115 | "_score": 15.14844, 116 | "_source": { 117 | "text": "@seanhannity @nytimes @FBI Speaking of the fake news media... https://t.co/s4DpS9fIyU\n" 118 | } 119 | }, 120 | { 121 | "_index": "tweets", 122 | "_type": "tweet", 123 | "_id": "IGMM32IBXudwSMXjOlXv", 124 | "_score": 14.964441, 125 | "_source": { 126 | "text": "Trump: Mainstream Media Should Compete for Fake News Trophy - https://t.co/JLV7Vn43oE #CNNisFakeNews 🏆\n" 127 | } 128 | } 129 | ] 130 | } 131 | } 132 | ``` -------------------------------------------------------------------------------- /scratch/es-lsh-glove/.gitignore: -------------------------------------------------------------------------------- 1 | glove_vocab.txt 2 | glove_vecs.npy 3 | glove_knn_exact_cosine.txt 4 | glove_knn_exact_cosine.npy 5 | -------------------------------------------------------------------------------- /scratch/es-lsh-glove/dummy_lsh.py: -------------------------------------------------------------------------------- 1 | """Based on: http://www.bogotobogo.com/Algorithms/Locality_Sensitive_Hashing_LSH_using_Cosine_Distance_Similarity.php.. 2 | Using numpy arrays instead of ints with bitwise arithmetic.""" 3 | 4 | import numpy as np 5 | import math 6 | import pdb 7 | 8 | 9 | def get_signature(data, planes): 10 | """ 11 | LSH signature generation using random projection 12 | Returns the signature bits for two data points. 13 | The signature bits of the two points are different 14 | only for the plane that divides the two points. 
15 | """ 16 | sig = np.zeros(len(planes), dtype=np.uint8) 17 | for i, p in enumerate(planes): 18 | sig[i] = int(np.dot(data, p) >= 0) 19 | return sig 20 | 21 | 22 | def get_bitcount(xorsig): 23 | return xorsig.sum() 24 | 25 | 26 | def get_xor(sig1, sig2): 27 | return np.bitwise_xor(sig1, sig2) 28 | 29 | 30 | def length(v): 31 | """returns the length of a vector""" 32 | return math.sqrt(np.dot(v, v)) 33 | 34 | 35 | if __name__ == '__main__': 36 | 37 | dim = 50 # dimension of data points (# of features) 38 | bits = 1024 # number of bits (planes) per signature 39 | run = 999 # number of runs 40 | avg = 0 41 | 42 | # reference planes as many as bits (= signature bits) 43 | ref_planes = np.random.randn(bits, dim).astype(np.float16) 44 | 45 | for r in range(run): 46 | 47 | # Generate two data points p1, p2 48 | pt1 = np.random.randn(dim) 49 | pt2 = np.random.randn(dim) 50 | 51 | # signature bits for two data points 52 | sig1 = get_signature(pt1, ref_planes) 53 | sig2 = get_signature(pt2, ref_planes) 54 | 55 | # Calculates exact angle difference 56 | cosine = np.dot(pt1, pt2) / length(pt1) / length(pt2) 57 | exact = 1 - math.acos(cosine) / math.pi 58 | 59 | # Calculates angle difference using LSH based on cosine distance 60 | # It's using signature bits' count 61 | cosine_hash = 1 - get_bitcount(get_xor(sig1, sig2)) / bits 62 | 63 | # Difference between exact and LSH 64 | diff = abs(cosine_hash - exact) / exact 65 | avg += diff 66 | print('exact %.3f, hash %.3f, diff %.3f' % (exact, cosine_hash, diff)) 67 | 68 | print('avg diff = %.3f' % (avg / run)) 69 | -------------------------------------------------------------------------------- /scratch/es-lsh-glove/get_glove.py: -------------------------------------------------------------------------------- 1 | """Reads the Glove text file and creates a single numpy matrix 2 | containing the vectors and a text file containing the words.""" 3 | 4 | import numpy as np 5 | 6 | glove_path = "/home/alex/Downloads/glove.6B.50d.txt" 7 | 8 | fp_glove = open(glove_path) 9 | fp_words = open("glove_vocab.txt", "w") 10 | words = [] 11 | vecs = np.zeros((400000, 50)) 12 | 13 | for i, line in enumerate(fp_glove): 14 | tkns = line.split() 15 | words.append(tkns[0]) 16 | vecs[i, :] = np.array([float(x) for x in tkns[1:]]) 17 | print(i, words[-1]) 18 | 19 | 20 | fp_words.write('\n'.join(words)) 21 | np.save('glove_vecs.npy', vecs.astype(np.float32)) 22 | -------------------------------------------------------------------------------- /scratch/es-lsh-glove/glove_exact.py: -------------------------------------------------------------------------------- 1 | """Compute the exact KNN for the first 1000 Glove words""" 2 | from sklearn.neighbors import NearestNeighbors 3 | from tqdm import tqdm 4 | import numpy as np 5 | import pdb 6 | import sys 7 | 8 | 9 | if __name__ == "__main__": 10 | 11 | K = 100 12 | B = 500 13 | metric = sys.argv[1] if len(sys.argv) > 1 else 'cosine' 14 | 15 | vecs = np.load('glove_vecs.npy') 16 | vocab_w2i = {w.strip(): i for i, w in enumerate(open('glove_vocab.txt'))} 17 | vocab_i2w = {i: w for w, i in vocab_w2i.items()} 18 | 19 | N = np.zeros((len(vecs), K)) 20 | 21 | knn = NearestNeighbors(n_neighbors=K, algorithm='brute', metric=metric) 22 | knn.fit(vecs) 23 | 24 | for i in tqdm(range(0, len(vecs), B)): 25 | query_words = [vocab_i2w[j] for j in range(i, i + B)] 26 | nbrs = knn.kneighbors(vecs[i:i + B], return_distance=False) 27 | for j, (w, ii) in enumerate(zip(query_words, nbrs)): 28 | # print('%s: %s' % (w, ' '.join([vocab_i2w[k] for k in ii]))) 
29 | N[i + j, :] = ii 30 | 31 | np.save('glove_neighbors_exact.npy', N) 32 | -------------------------------------------------------------------------------- /scratch/es-lsh-glove/glove_lsh_es_index.py: -------------------------------------------------------------------------------- 1 | """Hash each word's vector and insert it to elastic search.""" 2 | from elasticsearch import Elasticsearch, helpers 3 | from tqdm import tqdm 4 | import numpy as np 5 | import math 6 | import pdb 7 | 8 | 9 | def get_signature(data, planes): 10 | sig = np.zeros(len(planes), dtype=np.uint8) 11 | for i, p in enumerate(planes): 12 | sig[i] = int(np.dot(data, p) >= 0) 13 | return sig 14 | 15 | 16 | def signature_to_text(sig): 17 | d = "" 18 | for i, v in enumerate(sig): 19 | d += "%d_%d " % (i, v) 20 | return d 21 | 22 | 23 | if __name__ == '__main__': 24 | 25 | dim = 50 # Feature vector dimension. 26 | bits = 1024 # number of bits (planes) per signature 27 | 28 | lsh_planes = np.random.randn(bits, dim) 29 | 30 | vecs = np.load('glove_vecs.npy') 31 | vocab_w2i = {w.strip(): i for i, w in enumerate(open('glove_vocab.txt'))} 32 | vocab_i2w = {i: w for w, i in vocab_w2i.items()} 33 | 34 | es = Elasticsearch() 35 | actions = [] 36 | 37 | for i, vec in tqdm(enumerate(vecs)): 38 | 39 | sgtr = get_signature(vec, lsh_planes) 40 | text = signature_to_text(sgtr) 41 | 42 | actions.append({ 43 | "_index": "glove50", 44 | "_type": "word", 45 | "_id": i, 46 | "_source": { 47 | "word": vocab_i2w[i], 48 | "text": text 49 | } 50 | }) 51 | 52 | if len(actions) == 10000: 53 | helpers.bulk(es, actions) 54 | actions = [] 55 | -------------------------------------------------------------------------------- /scratch/es-lsh-glove/glove_lsh_es_query.py: -------------------------------------------------------------------------------- 1 | """Take a word, execute query against ES, show nearest words""" 2 | 3 | from elasticsearch import Elasticsearch, helpers 4 | import pdb 5 | import sys 6 | 7 | if __name__ == '__main__': 8 | 9 | assert len(sys.argv) > 1 10 | index = sys.argv[1] 11 | doc_type = sys.argv[2] 12 | word = sys.argv[3] 13 | 14 | es = Elasticsearch() 15 | res = es.search( 16 | index=index, doc_type=doc_type, 17 | body={"query": {"match": {"word": word}}}) 18 | 19 | text = res['hits']['hits'][0]['_source']['text'] 20 | res = es.search( 21 | index=index, doc_type=doc_type, 22 | body={"query": {"match": {"text": text}}}) 23 | 24 | print(word) 25 | for hit in res['hits']['hits']: 26 | print(' %s' % hit['_source']['word']) 27 | -------------------------------------------------------------------------------- /scratch/es-lsh-glove/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ## Introduction 3 | 4 | 1. Use [GloVe: Global Vectors for Word Representations](https://nlp.stanford.edu/projects/glove/), specifically the 6B-50D vectors with 400K terms. 5 | 2. Run an exact KNN with cosine distance. 6 | 3. Run a very simple LSH on the Glove vectors, insert the hashes in ES as text documents. Try similarity search 7 | 8 | ## Results 9 | 10 | Insert hashed Glove vectors to Elasticsearch and look them up using the hash text. 11 | 12 | For example, if the hash vector is [0 1 0 1 1 1], the document has tokens ["0_0", "1_1", "2_0", "3_1", "4_1", "5_1"]. 
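The encoding is easiest to see as a small sketch (essentially `signature_to_text` in `glove_lsh_es_index.py`):

```
def signature_to_text(sig):
    # [0, 1, 0, 1, 1, 1] -> "0_0 1_1 2_0 3_1 4_1 5_1"
    return " ".join("%d_%d" % (i, v) for i, v in enumerate(sig))
```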
13 | 14 | Documents look like this: 15 | 16 | ``` 17 | GET glove50/_search 18 | { 19 | "size": 1, 20 | "query": { 21 | "term": { 22 | "word": "cat" 23 | } 24 | } 25 | } 26 | ``` 27 | 28 | ``` 29 | { 30 | "took": 1, 31 | "timed_out": false, 32 | "_shards": { 33 | "total": 5, 34 | "successful": 5, 35 | "skipped": 0, 36 | "failed": 0 37 | }, 38 | "hits": { 39 | "total": 4, 40 | "max_score": 11.294465, 41 | "hits": [ 42 | { 43 | "_index": "glove50", 44 | "_type": "word", 45 | "_id": "5450", 46 | "_score": 11.294465, 47 | "_source": { 48 | "text": "0_0 1_1 2_0 3_1 4_1 5_0 6_1 7_0 8_0 9_1 10_0 11_0 12_0 13_1 14_1 15_1 16_0 17_1 18_0 19_0 20_0 21_0 22_0 23_0 24_1 25_1 26_0 27_0 28_1 29_1 30_1 31_0 32_0 33_0 34_0 35_1 36_1 37_1 38_1 39_0 40_0 41_0 42_0 43_0 44_1 45_0 46_1 47_0 48_0 49_0 50_1 51_1 52_0 53_0 54_1 55_1 56_1 57_0 58_1 59_0 60_1 61_1 62_0 63_1 64_1 65_0 66_1 67_1 68_1 69_1 70_0 71_0 72_0 73_1 74_1 75_1 76_0 77_1 78_0 79_0 80_0 81_0 82_0 83_0 84_0 85_0 86_1 87_0 88_1 89_0 90_0 91_0 92_0 93_1 94_0 95_1 96_1 97_0 98_1 99_0 100_0 101_1 102_0 103_0 104_0 49 | ... 50 | 1001_1 1002_1 1003_1 1004_1 1005_0 1006_1 1007_0 1008_0 1009_0 1010_0 1011_1 1012_0 1013_1 1014_1 1015_0 1016_1 1017_0 1018_1 1019_1 1020_1 1021_0 1022_1 1023_0 ", 51 | "word": "cat" 52 | } 53 | } 54 | ] 55 | } 56 | } 57 | ``` 58 | 59 | Similarity query results look like this: 60 | 61 | ``` 62 | (insight) alex@ltp:knn-python$ python glove_lsh_es_query.py glove50 word quantum 63 | quantum 64 | quantum 65 | theory 66 | gravity 67 | relativity 68 | evolution 69 | dynamics 70 | particle 71 | computation 72 | mathematical 73 | molecular 74 | ``` 75 | 76 | ``` 77 | (insight) alex@ltp:knn-python$ python glove_lsh_es_query.py glove50 word music 78 | music 79 | music 80 | musical 81 | recording 82 | studio 83 | pop 84 | artists 85 | songs 86 | contemporary 87 | best 88 | well 89 | ``` 90 | 91 | ``` 92 | (insight) alex@ltp:knn-python$ python glove_lsh_es_query.py glove50 word tennis 93 | tennis 94 | tennis 95 | tournament 96 | golf 97 | volleyball 98 | wimbledon 99 | soccer 100 | open 101 | finals 102 | semi 103 | championships 104 | ``` 105 | 106 | ``` 107 | (insight) alex@ltp:knn-python$ python glove_lsh_es_query.py glove50 word fox 108 | fox 109 | fox 110 | show 111 | abc 112 | nbc 113 | shows 114 | cbs 115 | tv 116 | television 117 | cnn 118 | 's 119 | ``` 120 | 121 | -------------------------------------------------------------------------------- /scratch/es-lsh-images/.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | *.npy 3 | *.txt 4 | -------------------------------------------------------------------------------- /scratch/es-lsh-images/get_imagenet_vectors_labels.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from keras.applications import Xception 3 | from keras.applications.imagenet_utils import preprocess_input, decode_predictions 4 | from keras.models import Model 5 | from scipy.misc import imread, imresize, imshow 6 | from tqdm import tqdm 7 | import numpy as np 8 | import pdb 9 | 10 | 11 | def get_img(img_path, dim=224): 12 | 13 | assert dim % 2 == 0 14 | 15 | img = imread(img_path) 16 | 17 | # Handle grayscale. 18 | if len(img.shape) < 3: 19 | tmp = np.zeros(img.shape + (3,)) 20 | tmp[:, :, 0] = img 21 | tmp[:, :, 1] = img 22 | tmp[:, :, 2] = img 23 | img = tmp 24 | 25 | # Resize image. 
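# Scale the image so its shorter side equals `dim` while preserving the aspect ratio;
# the center crop further down then trims the longer side to a dim x dim square.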
26 | h0, w0, d = img.shape 27 | h1, w1 = h0, w0 28 | if h0 < w0: 29 | h1 = dim 30 | w1 = int(dim * w0 / h0) 31 | else: 32 | w1 = dim 33 | h1 = int(dim * h0 / w0) 34 | assert abs((h0 / w0) - (h1 / w1)) <= 0.1, "%d %d %d %d" % (h0, w0, h1, w1) 35 | 36 | img = imresize(img, (h1, w1, d)) 37 | 38 | # Crop image at the center. 39 | # Width > height. 40 | if w1 > h1: 41 | c = int(w1 / 2) 42 | o = int(dim / 2) 43 | img = img[:, c - o: c + o, :] 44 | # Height > width. 45 | elif h1 > w1: 46 | c = int(h1 / 2) 47 | o = int(dim / 2) 48 | img = img[c - o: c + o, :, :] 49 | 50 | assert img.shape == (dim, dim, 3), '%s, %s' % (img_path, img.shape) 51 | 52 | return img 53 | 54 | 55 | imgs_dir = "/home/alex/Downloads/ILSVRC/Data/DET/test/" 56 | vecs_path = 'imagenet_vectors.npy' 57 | 58 | img_paths = sorted(glob('%s/*.JPEG' % imgs_dir)) 59 | fp_paths = open('imagenet_paths.txt', 'w') 60 | 61 | dim = 224 62 | batch_size = 500 63 | imgs_batch = np.zeros((batch_size, dim, dim, 3)) 64 | 65 | # Instantiate model and chop off some layers. 66 | vector_layer = "avg_pool" 67 | m1 = Xception() 68 | m2 = Model(inputs=m1.input, outputs=m1.get_layer(vector_layer).output) 69 | 70 | vecs = np.zeros((len(img_paths), m2.output_shape[-1])) 71 | 72 | for i in range(0, len(img_paths), batch_size): 73 | 74 | for j in range(batch_size): 75 | imgs_batch[j] = get_img(img_paths[i + j], dim) 76 | 77 | imgs_batch = preprocess_input(imgs_batch, mode='tf') 78 | prds_batch = m2.predict(imgs_batch) 79 | vecs[i:i + batch_size] = prds_batch 80 | fp_paths.write('\n'.join(img_paths[i:i + batch_size]) + '\n') 81 | 82 | print('%d-%d %.3lf %.3lf %.3lf' % ( 83 | i, i + batch_size, prds_batch.min(), 84 | np.median(prds_batch), prds_batch.max())) 85 | 86 | np.save(vecs_path, vecs) 87 | -------------------------------------------------------------------------------- /scratch/es-lsh-images/get_twitter_vectors.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from keras.applications import Xception 3 | from keras.applications.imagenet_utils import preprocess_input, decode_predictions 4 | from keras.models import Model 5 | from scipy.misc import imread, imresize, imshow 6 | from tqdm import tqdm 7 | import numpy as np 8 | import pdb 9 | 10 | 11 | def get_img(img_path, dim=224): 12 | 13 | assert dim % 2 == 0 14 | 15 | img = imread(img_path) 16 | 17 | # Handle grayscale. 18 | if len(img.shape) < 3: 19 | tmp = np.zeros(img.shape + (3,)) 20 | tmp[:, :, 0] = img 21 | tmp[:, :, 1] = img 22 | tmp[:, :, 2] = img 23 | img = tmp 24 | 25 | # Resize image. 26 | h0, w0, d = img.shape 27 | h1, w1 = h0, w0 28 | if h0 < w0: 29 | h1 = dim 30 | w1 = int(dim * w0 / h0) 31 | else: 32 | w1 = dim 33 | h1 = int(dim * h0 / w0) 34 | assert abs((h0 / w0) - (h1 / w1)) <= 0.1, "%d %d %d %d" % (h0, w0, h1, w1) 35 | 36 | img = imresize(img, (h1, w1, d)) 37 | 38 | # Crop image at the center. 39 | # Width > height. 40 | if w1 > h1: 41 | c = int(w1 / 2) 42 | o = int(dim / 2) 43 | img = img[:, c - o: c + o, :] 44 | # Height > width. 
45 | elif h1 > w1: 46 | c = int(h1 / 2) 47 | o = int(dim / 2) 48 | img = img[c - o: c + o, :, :] 49 | 50 | assert img.shape == (dim, dim, 3), '%s, %s' % (img_path, img.shape) 51 | 52 | return img 53 | 54 | 55 | imgs_dir = "/mnt/data/datasets/insight-twitter-images/images/" 56 | vecs_path = 'twitter_vectors.npy' 57 | 58 | img_paths = glob('%s/*.jpg' % imgs_dir)[:20000] 59 | fp_paths = open('twitter_paths.txt', 'w') 60 | 61 | dim = 224 62 | batch_size = 500 63 | imgs_batch = np.zeros((batch_size, dim, dim, 3)) 64 | 65 | # Instantiate model and chop off some layers. 66 | vector_layer = "avg_pool" 67 | m1 = Xception() 68 | m2 = Model(inputs=m1.input, outputs=m1.get_layer(vector_layer).output) 69 | 70 | vecs = np.zeros((len(img_paths), m2.output_shape[-1])) 71 | 72 | for i in range(0, len(img_paths), batch_size): 73 | 74 | for j in range(batch_size): 75 | imgs_batch[j] = get_img(img_paths[i + j], dim) 76 | 77 | imgs_batch = preprocess_input(imgs_batch, mode='tf') 78 | prds_batch = m2.predict(imgs_batch) 79 | vecs[i:i + batch_size] = prds_batch 80 | fp_paths.write('\n'.join(img_paths[i:i + batch_size]) + '\n') 81 | np.save(vecs_path, vecs[:i+batch_size]) 82 | 83 | print('%d-%d %.3lf %.3lf %.3lf' % ( 84 | i, i + batch_size, prds_batch.min(), 85 | np.median(prds_batch), prds_batch.max())) 86 | 87 | np.save(vecs_path, vecs) 88 | -------------------------------------------------------------------------------- /scratch/es-lsh-images/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ## Introduction 3 | 4 | 1. Use Imagenet test set downloaded from Kaggle Imagenet competition. 5 | 2. Use Xception architecture with Imagenet weights to product 2048-dimensional feature vectors. 6 | 3. Run an exact KNN with cosine distance. See results in `imagenet_knn_exact.ipynb`. 7 | 4. Run a very simple LSH on the vectors. Insert the hashes in ES as text documents. Run similarity search. See very promising results in `imagenet_es_lsh.ipynb`. 8 | 9 | ## Results 10 | 11 | See the two notebooks for visual results. 
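For reference, the hashing-and-indexing step (4. above) follows the same pattern as `glove_lsh_es_index.py`. Below is a minimal sketch, assuming the `imagenet_vectors.npy` and `imagenet_paths.txt` files written by `get_imagenet_vectors_labels.py`; the index, type, and field names here are illustrative, not necessarily the ones used in the notebooks:

```
"""Hash Xception vectors with random-projection LSH and index the hash
tokens as text in Elasticsearch."""
from elasticsearch import Elasticsearch, helpers
import numpy as np

vecs = np.load('imagenet_vectors.npy')                   # (N, 2048) feature vectors
paths = [p.strip() for p in open('imagenet_paths.txt')]  # one image path per vector

bits = 1024
planes = np.random.randn(bits, vecs.shape[1])            # one random hyperplane per bit


def signature_to_text(vec):
    # Sign of the projection onto each hyperplane -> "0_1 1_0 2_1 ..." tokens.
    sig = (vec.dot(planes.T) >= 0).astype(np.uint8)
    return ' '.join('%d_%d' % (i, v) for i, v in enumerate(sig))


es = Elasticsearch()
actions = []

for i, (path, vec) in enumerate(zip(paths, vecs)):
    actions.append({
        "_index": "imagenet2048",
        "_type": "image",
        "_id": i,
        "_source": {"path": path, "text": signature_to_text(vec)}
    })
    if len(actions) == 10000:
        helpers.bulk(es, actions)
        actions = []

if actions:
    helpers.bulk(es, actions)
```

Similarity search then works as in `glove_lsh_es_query.py`: fetch the query image's token string and run a `match` query against the `text` field.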
-------------------------------------------------------------------------------- /scratch/image-search-streaming-pipeline/.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.iml 3 | *.ipr 4 | *.iws 5 | -------------------------------------------------------------------------------- /scratch/image-search-streaming-pipeline/pom.xml: -------------------------------------------------------------------------------- 1 | 17 | 18 | 21 | 4.0.0 22 | 23 | image-search-streaming-pipeline 24 | image-search-streaming-pipeline 25 | 0.1 26 | jar 27 | 28 | Image Search Streaming Pipeline 29 | 30 | 31 | UTF-8 32 | 1.1.0 33 | 1.7.7 34 | 1.2.17 35 | 1.0.0-alpha 36 | 1.0.0-alpha 37 | 38 | 39 | 40 | 41 | apache.snapshots 42 | Apache Development Snapshot Repository 43 | https://repository.apache.org/content/repositories/snapshots/ 44 | 45 | false 46 | 47 | 48 | true 49 | 50 | 51 | 52 | 53 | 57 | 58 | 59 | 60 | 61 | org.apache.maven.plugins 62 | maven-compiler-plugin 63 | 3.1 64 | 65 | 1.8 66 | 1.8 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | maven-compiler-plugin 75 | 76 | 1.8 77 | 1.8 78 | jdt 79 | 80 | 81 | 82 | org.eclipse.tycho 83 | tycho-compiler-jdt 84 | 0.21.0 85 | 86 | 87 | 88 | 89 | org.eclipse.m2e 90 | lifecycle-mapping 91 | 1.0.0 92 | 93 | 94 | 95 | 96 | 97 | org.apache.maven.plugins 98 | maven-assembly-plugin 99 | [2.4,) 100 | 101 | single 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | org.apache.maven.plugins 111 | maven-compiler-plugin 112 | [3.1,) 113 | 114 | testCompile 115 | compile 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | org.apache.kafka 134 | kafka-streams 135 | ${kafka.version} 136 | 137 | 138 | 139 | 140 | org.apache.kafka 141 | kafka-streams 142 | ${kafka.version} 143 | 144 | 145 | 146 | 147 | com.amazonaws 148 | aws-java-sdk 149 | 1.11.52 150 | 151 | 152 | 153 | 154 | org.deeplearning4j 155 | deeplearning4j-core 156 | ${dl4j.version} 157 | 158 | 159 | 160 | org.deeplearning4j 161 | deeplearning4j-zoo 162 | ${dl4j.version} 163 | 164 | 165 | 166 | org.nd4j 167 | nd4j-native-platform 168 | ${nd4j.version} 169 | 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /scratch/image-search-streaming-pipeline/src/main/java/ImageSearchStreamingPipeline/FeatureExtractor.java: -------------------------------------------------------------------------------- 1 | /** 2 | * TODO: rewrite this to following spec 3 | * - Move logic out of mapValues and into another class. This will require some renaming/repackaging. 4 | * - Write binary vectors to output topic. 
5 | */ 6 | 7 | package ImageSearchStreamingPipeline; 8 | 9 | import com.amazonaws.regions.Region; 10 | import com.amazonaws.regions.Regions; 11 | import com.amazonaws.services.s3.AmazonS3; 12 | import com.amazonaws.services.s3.AmazonS3Client; 13 | import com.amazonaws.services.s3.model.GetObjectRequest; 14 | import com.amazonaws.services.s3.model.S3Object; 15 | import org.apache.kafka.common.serialization.Serdes; 16 | import org.apache.kafka.streams.KafkaStreams; 17 | import org.apache.kafka.streams.StreamsBuilder; 18 | import org.apache.kafka.streams.StreamsConfig; 19 | import org.apache.kafka.streams.Topology; 20 | import org.apache.kafka.streams.kstream.KStream; 21 | import org.datavec.image.loader.NativeImageLoader; 22 | import org.deeplearning4j.nn.graph.ComputationGraph; 23 | import org.deeplearning4j.nn.transferlearning.TransferLearning; 24 | import org.deeplearning4j.zoo.PretrainedType; 25 | import org.deeplearning4j.zoo.ZooModel; 26 | import org.deeplearning4j.zoo.model.ResNet50; 27 | import org.nd4j.linalg.api.ndarray.INDArray; 28 | 29 | import javax.imageio.ImageIO; 30 | import java.awt.image.BufferedImage; 31 | import java.io.*; 32 | import java.util.Properties; 33 | import java.util.concurrent.CountDownLatch; 34 | 35 | public class FeatureExtractor { 36 | 37 | public static void main(String[] args) throws Exception { 38 | 39 | // Configuration. 40 | final String bootstrapServer = "localhost:9092"; 41 | final String appID = "streams-image-info-consumer"; 42 | final String inputTopic = "streams-plaintext-input"; 43 | final String outputTopic = "streams-image-info-output"; 44 | final String bucketName = "klibisz-twitter-stream"; 45 | final String awsRegion = "us-east-1"; 46 | 47 | // Map specifying the stream execution configuration. 48 | Properties props = new Properties(); 49 | props.put(StreamsConfig.APPLICATION_ID_CONFIG, appID); 50 | props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServer); 51 | 52 | // Specify serialization and deserialization libraries. 53 | props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); 54 | props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); 55 | 56 | // S3 client used to download images. 57 | AmazonS3 s3 = new AmazonS3Client(); 58 | Region usEast1 = Region.getRegion(Regions.US_EAST_1); 59 | s3.setRegion(usEast1); 60 | 61 | // Convnet setup. 62 | final NativeImageLoader imageLoader = new NativeImageLoader(224, 224, 3, true); 63 | final ZooModel zooModel = new ResNet50(); 64 | final ComputationGraph fullConvNet = (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET); 65 | final ComputationGraph truncatedConvNet = new TransferLearning.GraphBuilder(fullConvNet) 66 | .removeVertexAndConnections("fc1000").setOutputs("flatten_3").build(); 67 | 68 | // Define computational logic of the streams application as a topology of nodes. 69 | final StreamsBuilder builder = new StreamsBuilder(); 70 | 71 | // Create source stream from specific Kafka topic containing key-value pairs. 72 | KStream imageKeys = builder.stream(inputTopic); 73 | 74 | KStream imageInfos = imageKeys.mapValues((imageKey) -> { 75 | 76 | // Download image from S3 bucket into memory. 
77 | System.out.println(String.format("Downloading %s from S3", imageKey)); 78 | S3Object object = s3.getObject(new GetObjectRequest(bucketName, imageKey)); 79 | System.out.println("Content-Type: " + object.getObjectMetadata().getContentType()); 80 | 81 | String imageInfo = "No image information available"; 82 | String vectorInfo = "No vector information available"; 83 | 84 | // Read image into an n-dimensional array. 85 | try { 86 | INDArray image = imageLoader.asMatrix(object.getObjectContent()); 87 | int[] shape = image.shape(); 88 | imageInfo = String.format( 89 | "Key = %s, shape = (%d x %d x %d), mean intensity = %.3f", 90 | imageKey, shape[2], shape[3], shape[1], image.meanNumber()); 91 | System.out.println("Image info: " + imageInfo); 92 | 93 | INDArray featureVector = truncatedConvNet.outputSingle(image); 94 | vectorInfo = String.format("" + 95 | "Shape = %s, min = %.3f, mean = %.3f, max = %.3f", 96 | featureVector.shapeInfoToString(), featureVector.minNumber(), 97 | featureVector.meanNumber(), featureVector.maxNumber()); 98 | System.out.println("Vector info: " + vectorInfo); 99 | 100 | } catch(IOException ex) { 101 | System.out.println("Problem reading image " + ex); 102 | } 103 | 104 | return imageInfo; 105 | }); 106 | 107 | // Write the image information to output topic. 108 | imageInfos.to(outputTopic); 109 | 110 | // Finalize and describe topology. 111 | final Topology topology = builder.build(); 112 | System.out.println(topology.describe()); 113 | 114 | // Define the stream. 115 | final KafkaStreams streams = new KafkaStreams(topology, props); 116 | 117 | // Define shutdown handler with a countdown. 118 | final CountDownLatch latch = new CountDownLatch(3); 119 | 120 | Runtime.getRuntime().addShutdownHook(new Thread("streams-shutdown-hook") { 121 | @Override 122 | public void run() { 123 | streams.close(); 124 | latch.countDown(); 125 | } 126 | }); 127 | 128 | // Start running. 129 | try { 130 | streams.start(); 131 | latch.await(); 132 | } catch (Throwable e) { 133 | System.exit(1); 134 | } 135 | System.exit(0); 136 | } 137 | } -------------------------------------------------------------------------------- /scratch/image-search-streaming-pipeline/src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one or more 2 | # contributor license agreements. See the NOTICE file distributed with 3 | # this work for additional information regarding copyright ownership. 4 | # The ASF licenses this file to You under the Apache License, Version 2.0 5 | # (the "License"); you may not use this file except in compliance with 6 | # the License. You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | log4j.rootLogger=INFO, console 16 | 17 | log4j.appender.console=org.apache.log4j.ConsoleAppender 18 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 19 | log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} %-5p %-60c %x - %m%n -------------------------------------------------------------------------------- /scratch/image-search-streaming-pipeline/target/classes/ImageSearchStreamingPipeline/FeatureExtractor$1.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexklibisz/elastik-nearest-neighbors/fe69992c133bbb56a81ea6067217820b6f30b6e6/scratch/image-search-streaming-pipeline/target/classes/ImageSearchStreamingPipeline/FeatureExtractor$1.class -------------------------------------------------------------------------------- /scratch/image-search-streaming-pipeline/target/classes/ImageSearchStreamingPipeline/FeatureExtractor.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexklibisz/elastik-nearest-neighbors/fe69992c133bbb56a81ea6067217820b6f30b6e6/scratch/image-search-streaming-pipeline/target/classes/ImageSearchStreamingPipeline/FeatureExtractor.class -------------------------------------------------------------------------------- /scratch/image-search-streaming-pipeline/target/classes/log4j.properties: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one or more 2 | # contributor license agreements. See the NOTICE file distributed with 3 | # this work for additional information regarding copyright ownership. 4 | # The ASF licenses this file to You under the Apache License, Version 2.0 5 | # (the "License"); you may not use this file except in compliance with 6 | # the License. You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | log4j.rootLogger=INFO, console 16 | 17 | log4j.appender.console=org.apache.log4j.ConsoleAppender 18 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 19 | log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} %-5p %-60c %x - %m%n -------------------------------------------------------------------------------- /scratch/image-search-streaming-pipeline/target/image-search-streaming-pipeline-0.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexklibisz/elastik-nearest-neighbors/fe69992c133bbb56a81ea6067217820b6f30b6e6/scratch/image-search-streaming-pipeline/target/image-search-streaming-pipeline-0.1.jar -------------------------------------------------------------------------------- /scratch/image-search-streaming-pipeline/target/maven-archiver/pom.properties: -------------------------------------------------------------------------------- 1 | #Generated by Maven 2 | #Wed Apr 25 11:31:45 EDT 2018 3 | version=0.1 4 | groupId=image-search-streaming-pipeline 5 | artifactId=image-search-streaming-pipeline 6 | -------------------------------------------------------------------------------- /scratch/image-search-streaming-pipeline/target/maven-status/maven-compiler-plugin/compile/default-compile/createdFiles.lst: -------------------------------------------------------------------------------- 1 | ImageSearchStreamingPipeline/FeatureExtractor$1.class 2 | ImageSearchStreamingPipeline/FeatureExtractor.class 3 | -------------------------------------------------------------------------------- /scratch/image-search-streaming-pipeline/target/maven-status/maven-compiler-plugin/compile/default-compile/inputFiles.lst: -------------------------------------------------------------------------------- 1 | /home/alex/Documents/dev/approximate-vector-search/pipeline/stream-processing/image-search-streaming-pipeline/src/main/java/ImageSearchStreamingPipeline/FeatureExtractor.java 2 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/.gitignore: -------------------------------------------------------------------------------- 1 | out.JPEG 2 | imagenet*.json 3 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/imagenet-pizza.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexklibisz/elastik-nearest-neighbors/fe69992c133bbb56a81ea6067217820b6f30b6e6/scratch/kafka-streaming/imagenet-pizza.JPEG -------------------------------------------------------------------------------- /scratch/kafka-streaming/imagenet-ref.py: -------------------------------------------------------------------------------- 1 | from keras.models import Model 2 | from keras.applications import ResNet50 3 | from keras.applications.imagenet_utils import preprocess_input, decode_predictions 4 | from scipy.misc import imread, imshow 5 | import numpy as np 6 | 7 | img_path = "./imagenet-pizza.JPEG" 8 | img_path = "./out.JPEG" 9 | 10 | img = imread(img_path).astype(np.float32) 11 | #img = img[:224, :224, :3] 12 | 13 | img_batch = img[np.newaxis,...] 
14 | #img_batch = preprocess_input(img_batch, mode='caffe') 15 | 16 | model = ResNet50() 17 | prds = model.predict(img_batch) 18 | print(decode_predictions(prds)) 19 | 20 | vector_layer = "avg_pool" 21 | model2 = Model(inputs=model.input, outputs=model.get_layer(vector_layer).output) 22 | prds = model2.predict(img_batch) 23 | print(prds.shape, prds.min(), prds.mean(), prds.max()) 24 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/pyconsumer.py: -------------------------------------------------------------------------------- 1 | from kafka import KafkaConsumer 2 | 3 | consumer = KafkaConsumer('test') 4 | for msg in consumer: 5 | print (msg) 6 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/pyproducer.py: -------------------------------------------------------------------------------- 1 | # Example for sending very simple messages to the "test" topic. 2 | from kafka import KafkaProducer 3 | from time import time 4 | 5 | producer = KafkaProducer(bootstrap_servers='localhost:9092') 6 | t0 = time() 7 | 8 | for i in range(10): 9 | m = 'message %d' % (t0 + i) 10 | producer.send('test', m.encode()) 11 | producer.flush() 12 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/streams.examples/.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | .idea 3 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/streams.examples/pom.xml: -------------------------------------------------------------------------------- 1 | 17 | 18 | 21 | 4.0.0 22 | 23 | streams.examples 24 | streams.examples 25 | 0.1 26 | jar 27 | 28 | Kafka Streams Quickstart :: Java 29 | 30 | 31 | UTF-8 32 | 1.1.0 33 | 1.7.7 34 | 1.2.17 35 | 1.0.0-alpha 36 | 1.0.0-alpha 37 | 38 | 39 | 40 | 41 | apache.snapshots 42 | Apache Development Snapshot Repository 43 | https://repository.apache.org/content/repositories/snapshots/ 44 | 45 | false 46 | 47 | 48 | true 49 | 50 | 51 | 52 | 53 | 57 | 58 | 59 | 60 | 61 | org.apache.maven.plugins 62 | maven-compiler-plugin 63 | 3.1 64 | 65 | 1.8 66 | 1.8 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | maven-compiler-plugin 75 | 76 | 1.8 77 | 1.8 78 | jdt 79 | 80 | 81 | 82 | org.eclipse.tycho 83 | tycho-compiler-jdt 84 | 0.21.0 85 | 86 | 87 | 88 | 89 | org.eclipse.m2e 90 | lifecycle-mapping 91 | 1.0.0 92 | 93 | 94 | 95 | 96 | 97 | org.apache.maven.plugins 98 | maven-assembly-plugin 99 | [2.4,) 100 | 101 | single 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | org.apache.maven.plugins 111 | maven-compiler-plugin 112 | [3.1,) 113 | 114 | testCompile 115 | compile 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | org.apache.kafka 135 | kafka-streams 136 | ${kafka.version} 137 | 138 | 139 | 140 | 141 | com.amazonaws 142 | aws-java-sdk 143 | 1.11.52 144 | 145 | 146 | 147 | 148 | org.deeplearning4j 149 | deeplearning4j-core 150 | ${dl4j.version} 151 | 152 | 153 | 154 | org.deeplearning4j 155 | deeplearning4j-zoo 156 | ${dl4j.version} 157 | 158 | 159 | 160 | org.nd4j 161 | nd4j-native-platform 162 | ${nd4j.version} 163 | 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/streams.examples/readme.md: -------------------------------------------------------------------------------- 1 | Apache Kafka 
and Kafka streams tutorials based on: 2 | 3 | - [Tutorial on Apache streams website](https://kafka.apache.org/11/documentation/streams/tutorial) 4 | - [AWS Java SDK Sample repo](https://github.com/aws-samples/aws-java-sample) 5 | 6 | The code at `myapps/ImageInfoConsumer.java` reads from a topic where S3 keys for an image are published, and writes some information about the image (shape and mean pixel intensity) to an output topic. 7 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/streams.examples/src/main/java/myapps/ImageInfoConsumer.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Input: Key for of an image stored in S3 bucket. 3 | * Processing: 4 | * - Download image from S3. 5 | * - Load image into an ND-array (without writing to disk). 6 | * - Compute shape. 7 | * - Compute mean pixel intensity. 8 | * Output: String containing (Image key, Image shape, Mean pixel intensity). 9 | */ 10 | 11 | package myapps; 12 | 13 | import com.amazonaws.regions.Region; 14 | import com.amazonaws.regions.Regions; 15 | import com.amazonaws.services.s3.AmazonS3; 16 | import com.amazonaws.services.s3.AmazonS3Client; 17 | import com.amazonaws.services.s3.model.GetObjectRequest; 18 | import com.amazonaws.services.s3.model.S3Object; 19 | import org.apache.kafka.common.serialization.Serdes; 20 | import org.apache.kafka.streams.KafkaStreams; 21 | import org.apache.kafka.streams.StreamsBuilder; 22 | import org.apache.kafka.streams.StreamsConfig; 23 | import org.apache.kafka.streams.Topology; 24 | import org.apache.kafka.streams.kstream.KStream; 25 | import org.datavec.image.loader.NativeImageLoader; 26 | import org.deeplearning4j.nn.graph.ComputationGraph; 27 | import org.deeplearning4j.nn.transferlearning.TransferLearning; 28 | import org.deeplearning4j.zoo.PretrainedType; 29 | import org.deeplearning4j.zoo.ZooModel; 30 | import org.deeplearning4j.zoo.model.ResNet50; 31 | import org.nd4j.linalg.api.ndarray.INDArray; 32 | 33 | import javax.imageio.ImageIO; 34 | import java.awt.image.BufferedImage; 35 | import java.io.*; 36 | import java.util.Properties; 37 | import java.util.concurrent.CountDownLatch; 38 | 39 | public class ImageInfoConsumer { 40 | 41 | public static void main(String[] args) throws Exception { 42 | 43 | // Configuration. 44 | final String bootstrapServer = "localhost:9092"; 45 | final String appID = "streams-image-info-consumer"; 46 | final String inputTopic = "streams-plaintext-input"; 47 | final String outputTopic = "streams-image-info-output"; 48 | final String bucketName = "klibisz-twitter-stream"; 49 | final String awsRegion = "us-east-1"; 50 | 51 | // Map specifying the stream execution configuration. 52 | Properties props = new Properties(); 53 | props.put(StreamsConfig.APPLICATION_ID_CONFIG, appID); 54 | props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServer); 55 | 56 | // Specify serialization and deserialization libraries. 57 | props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); 58 | props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); 59 | 60 | // S3 client used to download images. 61 | AmazonS3 s3 = new AmazonS3Client(); 62 | Region usEast1 = Region.getRegion(Regions.US_EAST_1); 63 | s3.setRegion(usEast1); 64 | 65 | // Convnet setup. 
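// Load a pretrained ResNet50 from the DL4J model zoo, then use TransferLearning to remove the
// fc1000 classification layer and set flatten_3 as the output, so outputSingle() below returns
// a feature vector rather than class probabilities.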
66 | final NativeImageLoader imageLoader = new NativeImageLoader(224, 224, 3, true); 67 | final ZooModel zooModel = new ResNet50(); 68 | final ComputationGraph fullConvNet = (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET); 69 | final ComputationGraph truncatedConvNet = new TransferLearning.GraphBuilder(fullConvNet) 70 | .removeVertexAndConnections("fc1000").setOutputs("flatten_3").build(); 71 | 72 | // Define computational logic of the streams application as a topology of nodes. 73 | final StreamsBuilder builder = new StreamsBuilder(); 74 | 75 | // Create source stream from specific Kafka topic containing key-value pairs. 76 | KStream imageKeys = builder.stream(inputTopic); 77 | 78 | KStream imageInfos = imageKeys.mapValues((imageKey) -> { 79 | 80 | // Download image from S3 bucket into memory. 81 | System.out.println(String.format("Downloading %s from S3", imageKey)); 82 | S3Object object = s3.getObject(new GetObjectRequest(bucketName, imageKey)); 83 | System.out.println("Content-Type: " + object.getObjectMetadata().getContentType()); 84 | 85 | String imageInfo = "No image information available"; 86 | String vectorInfo = "No vector information available"; 87 | 88 | // Read image into an n-dimensional array. 89 | try { 90 | INDArray image = imageLoader.asMatrix(object.getObjectContent()); 91 | int[] shape = image.shape(); 92 | imageInfo = String.format( 93 | "Key = %s, shape = (%d x %d x %d), mean intensity = %.3f", 94 | imageKey, shape[2], shape[3], shape[1], image.meanNumber()); 95 | System.out.println("Image info: " + imageInfo); 96 | 97 | INDArray featureVector = truncatedConvNet.outputSingle(image); 98 | vectorInfo = String.format("" + 99 | "Shape = %s, min = %.3f, mean = %.3f, max = %.3f", 100 | featureVector.shapeInfoToString(), featureVector.minNumber(), 101 | featureVector.meanNumber(), featureVector.maxNumber()); 102 | System.out.println("Vector info: " + vectorInfo); 103 | 104 | } catch(IOException ex) { 105 | System.out.println("Problem reading image " + ex); 106 | } 107 | 108 | return imageInfo; 109 | }); 110 | 111 | // Write the image information to output topic. 112 | imageInfos.to(outputTopic); 113 | 114 | // Finalize and describe topology. 115 | final Topology topology = builder.build(); 116 | System.out.println(topology.describe()); 117 | 118 | // Define the stream. 119 | final KafkaStreams streams = new KafkaStreams(topology, props); 120 | 121 | // Define shutdown handler with a countdown. 122 | final CountDownLatch latch = new CountDownLatch(3); 123 | 124 | Runtime.getRuntime().addShutdownHook(new Thread("streams-shutdown-hook") { 125 | @Override 126 | public void run() { 127 | streams.close(); 128 | latch.countDown(); 129 | } 130 | }); 131 | 132 | // Start running. 133 | try { 134 | streams.start(); 135 | latch.await(); 136 | } catch (Throwable e) { 137 | System.exit(1); 138 | } 139 | System.exit(0); 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/streams.examples/src/main/java/myapps/ImagePrediction.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Given an image, use a pre-trained deep convolutional neural network 3 | * to compute its feature vector. 
4 | */ 5 | 6 | 7 | package myapps; 8 | 9 | import org.datavec.image.loader.Java2DNativeImageLoader; 10 | import org.datavec.image.loader.NativeImageLoader; 11 | import org.deeplearning4j.nn.graph.ComputationGraph; 12 | import org.deeplearning4j.nn.graph.vertex.GraphVertex; 13 | import org.deeplearning4j.nn.transferlearning.TransferLearning; 14 | import org.deeplearning4j.zoo.model.ResNet50; 15 | import org.deeplearning4j.zoo.*; 16 | import org.nd4j.linalg.api.ndarray.INDArray; 17 | 18 | import javax.imageio.ImageIO; 19 | import java.awt.image.BufferedImage; 20 | import java.io.*; 21 | 22 | public class ImagePrediction { 23 | 24 | public static void main(String[] args) throws Exception { 25 | 26 | // Load pre-trained Convnet. 27 | ZooModel zooModel = new ResNet50(); 28 | ComputationGraph pretrained = (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET); 29 | 30 | // Load input image. 31 | final String imgPath = "../imagenet-pizza.JPEG"; 32 | File inputFile = new File(imgPath); 33 | BufferedImage inputImage = ImageIO.read(inputFile); 34 | 35 | // Convert input image into format used by Convnet.. 36 | NativeImageLoader inputLoader = new NativeImageLoader(224, 224, 3, true); 37 | INDArray imageMatrix = inputLoader.asMatrix(inputImage); 38 | 39 | // Write image back to disk (to see exactly how the previous step pre-processed the image. 40 | Java2DNativeImageLoader outputLoader = new Java2DNativeImageLoader(); 41 | BufferedImage outputImage = outputLoader.asBufferedImage(imageMatrix); 42 | ImageIO.write(outputImage, "jpg", new File("../out.JPEG")); 43 | 44 | // Make prediction and print most probable class, e.g. 963. 45 | INDArray output = pretrained.outputSingle(imageMatrix); 46 | System.out.println(output.argMax()); 47 | 48 | ComputationGraph model = new TransferLearning.GraphBuilder(pretrained) 49 | .removeVertexAndConnections("fc1000") 50 | .setOutputs("flatten_3") 51 | .build(); 52 | 53 | GraphVertex[] vertices = model.getVertices(); 54 | for (int i = 0; i < vertices.length; i++) { 55 | GraphVertex v = vertices[i]; 56 | // System.out.println(String.format("%d %s", v.getVertexIndex(), v.getVertexName())); 57 | } 58 | 59 | output = model.outputSingle(imageMatrix); 60 | System.out.println(output.shapeInfoToString()); 61 | System.out.println(String.format("%.3f %.3f %.3f", 62 | output.minNumber(), output.meanNumber(), output.maxNumber())); 63 | 64 | System.exit(0); 65 | 66 | } 67 | 68 | } 69 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/streams.examples/src/main/java/myapps/LineSplit.java: -------------------------------------------------------------------------------- 1 | package myapps; 2 | 3 | import org.apache.kafka.common.serialization.Serdes; 4 | import org.apache.kafka.streams.KafkaStreams; 5 | import org.apache.kafka.streams.StreamsBuilder; 6 | import org.apache.kafka.streams.StreamsConfig; 7 | import org.apache.kafka.streams.Topology; 8 | import org.apache.kafka.streams.kstream.KStream; 9 | import org.apache.kafka.streams.kstream.ValueMapper; 10 | 11 | import java.util.Arrays; 12 | import java.util.Properties; 13 | import java.util.concurrent.CountDownLatch; 14 | 15 | public class LineSplit { 16 | public static void main(String[] args) throws Exception { 17 | 18 | // Map specifying the stream execution configuration. 
19 | Properties props = new Properties(); 20 | props.put(StreamsConfig.APPLICATION_ID_CONFIG, "streams-linesplit"); 21 | props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); 22 | 23 | // Specify serialization and deserialization libraries. 24 | props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); 25 | props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); 26 | 27 | // Define computational logic of the streams application as a topology of nodes. 28 | final StreamsBuilder builder = new StreamsBuilder(); 29 | 30 | // Create source stream from specific Kafka topic containing key-value pairs. 31 | KStream source = builder.stream("streams-plaintext-input"); 32 | 33 | // Apply a flatmap that just splits each line into its constituent words. 34 | KStream words = source.flatMapValues(value -> Arrays.asList(value.split("\\W+"))); 35 | 36 | // Write the word stream to another Kafka topic. 37 | words.to("streams-linesplit-output"); 38 | 39 | // Finalize and describe topology. 40 | final Topology topology = builder.build(); 41 | System.out.println(topology.describe()); 42 | 43 | // Define the stream. 44 | final KafkaStreams streams = new KafkaStreams(topology, props); 45 | 46 | // Define shutdown handler with a countdown. 47 | final CountDownLatch latch = new CountDownLatch(3); 48 | 49 | Runtime.getRuntime().addShutdownHook(new Thread("streams-shutdown-hook") { 50 | @Override 51 | public void run() { 52 | streams.close(); 53 | latch.countDown(); 54 | } 55 | }); 56 | 57 | // Start running. 58 | try { 59 | streams.start(); 60 | latch.await(); 61 | } catch (Throwable e) { 62 | System.exit(1); 63 | } 64 | System.exit(0); 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/streams.examples/src/main/java/myapps/ND4JPlayground.java: -------------------------------------------------------------------------------- 1 | import org.datavec.image.loader.NativeImageLoader; 2 | import org.nd4j.linalg.api.ndarray.INDArray; 3 | 4 | import javax.imageio.ImageIO; 5 | import java.awt.image.BufferedImage; 6 | import java.io.BufferedInputStream; 7 | import java.io.FileInputStream; 8 | 9 | 10 | public class ND4JPlayground { 11 | public static void main(String[] args) throws Exception { 12 | 13 | final String imgPath = "/home/alex/tmp/test.jpg"; 14 | 15 | FileInputStream fs = new FileInputStream(imgPath); 16 | // BufferedInputStream bs = new BufferedInputStream(fs); 17 | // BufferedImage bimg = ImageIO.read(bs); 18 | 19 | NativeImageLoader loader = new NativeImageLoader(); 20 | INDArray img = loader.asMatrix(fs); 21 | 22 | System.out.println(img); 23 | 24 | // INDArray img = INDArray() 25 | 26 | // System.out.println(bs); 27 | // System.out.println(bimg); 28 | 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/streams.examples/src/main/java/myapps/Pipe.java: -------------------------------------------------------------------------------- 1 | package myapps; 2 | 3 | import org.apache.kafka.common.serialization.Serdes; 4 | import org.apache.kafka.streams.KafkaStreams; 5 | import org.apache.kafka.streams.StreamsBuilder; 6 | import org.apache.kafka.streams.StreamsConfig; 7 | import org.apache.kafka.streams.Topology; 8 | import org.apache.kafka.streams.kstream.KStream; 9 | 10 | import java.util.Properties; 11 | import java.util.concurrent.CountDownLatch; 12 | 13 | public class Pipe { 14 | public static void 
main(String[] args) throws Exception { 15 | 16 | // Map specifying the stream execution configuration. 17 | Properties props = new Properties(); 18 | 19 | // Identify this application vs. others talking to Kafka. 20 | props.put(StreamsConfig.APPLICATION_ID_CONFIG, "streams-pipe"); 21 | 22 | // Specify host/port to establish connection to local Kafka instance. 23 | props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); 24 | 25 | // Specify serialization and deserialization libraries. 26 | props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); 27 | props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); 28 | 29 | // Define computational logic of the streams application as a topology of nodes. 30 | final StreamsBuilder builder = new StreamsBuilder(); 31 | 32 | // Create source stream from specific Kafka topic containing key-value pairs. 33 | KStream source = builder.stream("streams-plaintext-input"); 34 | 35 | // Just write the source to another Kafka topic. 36 | source.to("streams-pipe-output"); 37 | 38 | // Finalize the topology. 39 | final Topology topology = builder.build(); 40 | 41 | // Print description of the topology. 42 | System.out.println(topology.describe()); 43 | 44 | // Define the stream. 45 | final KafkaStreams streams = new KafkaStreams(topology, props); 46 | 47 | // Define shutdown handler with a countdown. 48 | final CountDownLatch latch = new CountDownLatch(3); 49 | 50 | Runtime.getRuntime().addShutdownHook(new Thread("streams-shutdown-hook") { 51 | @Override 52 | public void run() { 53 | streams.close(); 54 | latch.countDown(); 55 | } 56 | }); 57 | 58 | // Start running. 59 | try { 60 | streams.start(); 61 | latch.await(); 62 | } catch (Throwable e) { 63 | System.exit(1); 64 | } 65 | System.exit(0); 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/streams.examples/src/main/java/myapps/Wordcount.java: -------------------------------------------------------------------------------- 1 | package myapps; 2 | 3 | import org.apache.kafka.common.serialization.Serde; 4 | import org.apache.kafka.common.serialization.Serdes; 5 | import org.apache.kafka.common.utils.Bytes; 6 | import org.apache.kafka.streams.KafkaStreams; 7 | import org.apache.kafka.streams.StreamsBuilder; 8 | import org.apache.kafka.streams.StreamsConfig; 9 | import org.apache.kafka.streams.Topology; 10 | import org.apache.kafka.streams.kstream.*; 11 | import org.apache.kafka.streams.state.KeyValueStore; 12 | 13 | import java.util.Arrays; 14 | import java.util.Properties; 15 | import java.util.concurrent.CountDownLatch; 16 | 17 | public class Wordcount { 18 | public static void main(String[] args) throws Exception { 19 | 20 | // Map specifying the stream execution configuration. 21 | Properties props = new Properties(); 22 | props.put(StreamsConfig.APPLICATION_ID_CONFIG, "streams-wordcount"); 23 | props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); 24 | 25 | // Specify serialization and deserialization libraries. 26 | props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); 27 | props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); 28 | 29 | // Define computational logic of the streams application as a topology of nodes. 30 | final StreamsBuilder builder = new StreamsBuilder(); 31 | 32 | // Create source stream from specific Kafka topic containing key-value pairs. 
33 | KStream source = builder.stream("streams-plaintext-input"); 34 | 35 | // Apply a flatmap that just splits each line into its constituent words. 36 | KStream words = source.flatMapValues(value -> Arrays.asList(value.toLowerCase().split("\\W+"))); 37 | 38 | // To do counting aggregation, you have to group items by a key to maintain their state. 39 | // In this case, the key is just the value, which is the word from the flatmap above. 40 | // The groupBy is followed by a count operation which stores its state in the "counts-store". 41 | KTable counts = words.groupBy((key, value) -> value).count(); 42 | 43 | // Specify the output location and format. 44 | counts.toStream().to("streams-wordcount-output", Produced.with(Serdes.String(), Serdes.Long())); 45 | 46 | // Create and describe topology. 47 | final Topology topology = builder.build(); 48 | System.out.println(topology.describe()); 49 | 50 | // Define the stream. 51 | final KafkaStreams streams = new KafkaStreams(topology, props); 52 | 53 | // Define shutdown handler with a countdown. 54 | final CountDownLatch latch = new CountDownLatch(3); 55 | 56 | Runtime.getRuntime().addShutdownHook(new Thread("streams-shutdown-hook") { 57 | @Override 58 | public void run() { 59 | streams.close(); 60 | latch.countDown(); 61 | } 62 | }); 63 | 64 | // Start running. 65 | try { 66 | streams.start(); 67 | latch.await(); 68 | } catch (Throwable e) { 69 | System.exit(1); 70 | } 71 | System.exit(0); 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /scratch/kafka-streaming/streams.examples/src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one or more 2 | # contributor license agreements. See the NOTICE file distributed with 3 | # this work for additional information regarding copyright ownership. 4 | # The ASF licenses this file to You under the Apache License, Version 2.0 5 | # (the "License"); you may not use this file except in compliance with 6 | # the License. You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | log4j.rootLogger=INFO, console 16 | 17 | log4j.appender.console=org.apache.log4j.ConsoleAppender 18 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 19 | log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} %-5p %-60c %x - %m%n -------------------------------------------------------------------------------- /scratch/lsh-experiments/.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | -------------------------------------------------------------------------------- /scratch/mvp-big/batch_feature_vectors.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from glob import glob 3 | from io import BytesIO 4 | from keras.models import Model 5 | from keras.applications import MobileNet 6 | from keras.applications.imagenet_utils import preprocess_input, decode_predictions 7 | from scipy.misc import imread, imsave 8 | from skimage.transform import resize 9 | from time import time 10 | from tqdm import tqdm 11 | import base64 12 | import numpy as np 13 | import pdb 14 | 15 | 16 | class ImageVectorizer(object): 17 | 18 | def __init__(self): 19 | net = MobileNet() 20 | self.convnet = Model(net.input, net.get_layer('conv_preds').output) 21 | 22 | def get_feature_vectors(self, img_list): 23 | 24 | t0 = time() 25 | 26 | img_arr = np.zeros((len(img_list), 224, 224, 3)) 27 | 28 | for i, img in enumerate(img_list): 29 | if len(img.shape) == 2: 30 | img = np.repeat(img[:, :, np.newaxis], 3, -1) 31 | img_arr[i] = resize(img[:, :, :3], (224, 224), 32 | mode='reflect', 33 | preserve_range=True) 34 | 35 | # Pre-process batch for keras. 36 | img_arr = preprocess_input(img_arr, mode='caffe') 37 | 38 | print('Preprocessing', time() - t0) 39 | 40 | # Compute, return feature vectors. 
41 | t0 = time() 42 | vec_arr = self.convnet.predict(img_arr) 43 | vec_arr = np.squeeze(vec_arr) 44 | print('Computing', time() - t0) 45 | return vec_arr 46 | 47 | 48 | if __name__ == "__main__": 49 | 50 | ap = ArgumentParser(description="Compute feature vectors for a large batch of images") 51 | ap.add_argument('images_dir', help='full path to directory where images are stored') 52 | ap.add_argument('index_path', help='path to index file') 53 | ap.add_argument('features_dir', help='full path to directory where features will be stored') 54 | ap.add_argument('-b', '--batch', help='batch size', default=128, type=int) 55 | args = vars(ap.parse_args()) 56 | 57 | img_vectorizer = ImageVectorizer() 58 | img_list = [] 59 | id_list = [] 60 | 61 | # with open(args['index_path']) as fp: 62 | # lines = fp.read().split('\n') 63 | 64 | for i, twitter_id, ext in map(str.split, open(args['index_path'])): 65 | 66 | # i, twitter_id, ext = line.split(' ') 67 | img_path = "%s/%s.%s" % (args['images_dir'], twitter_id, ext) 68 | 69 | t0 = time() 70 | 71 | try: 72 | img_list.append(imread(img_path)) 73 | id_list.append(twitter_id) 74 | except Exception as ex: 75 | print("Error reading image %s" % img_path, ex) 76 | 77 | if len(img_list) < args['batch']: 78 | continue 79 | 80 | print('Reading', time() - t0) 81 | 82 | try: 83 | vec_arr = img_vectorizer.get_feature_vectors(img_list) 84 | img_list = [] 85 | id_list = [] 86 | print("---") 87 | except Exception as ex: 88 | print("Error computing vectors", ex) 89 | -------------------------------------------------------------------------------- /scratch/mvp-big/kafka_convnet_consumer.py: -------------------------------------------------------------------------------- 1 | from kafka import KafkaConsumer, KafkaProducer 2 | from io import BytesIO 3 | from keras.models import Model 4 | from keras.applications import ResNet50 5 | from keras.applications.imagenet_utils import preprocess_input, decode_predictions 6 | from scipy.misc import imread, imsave 7 | from skimage.transform import resize 8 | from time import time 9 | from tqdm import tqdm 10 | import base64 11 | import numpy as np 12 | import pdb 13 | 14 | KAFKA_SERVERS = [ 15 | "ip-172-31-19-114.ec2.internal:9092", 16 | "ip-172-31-18-192.ec2.internal:9092", 17 | "ip-172-31-20-205.ec2.internal:9092" 18 | ] 19 | KAFKA_SUB_TOPIC = "aknn-demo-twitter-images-base64" 20 | KAFKA_PUB_TOPIC = "aknn-demo-feature-vectors" 21 | KAFKA_GROUP_ID = "aknn-demo-convnet-consumers" 22 | 23 | 24 | class ImageVectorizer(object): 25 | 26 | def __init__(self): 27 | self.resnet50 = ResNet50(include_top=False) 28 | 29 | def get_feature_vector(self, img): 30 | 31 | # Pre-process image for keras. 32 | img = resize(img[:, :, :3], (224, 224), preserve_range=True) 33 | img_batch = preprocess_input(img[np.newaxis, ...], mode='caffe') 34 | 35 | # Compute keras output. 
36 | vec_batch = self.resnet50.predict(img_batch) 37 | vec = vec_batch.reshape((vec_batch.shape[-1])) 38 | 39 | # Return numpy feature vector 40 | return vec 41 | 42 | 43 | if __name__ == "__main__": 44 | 45 | consumer = KafkaConsumer( 46 | KAFKA_SUB_TOPIC, 47 | bootstrap_servers=",".join(KAFKA_SERVERS), 48 | group_id=KAFKA_GROUP_ID) 49 | 50 | producer = KafkaProducer(bootstrap_servers=",".join(KAFKA_SERVERS)) 51 | 52 | image_vectorizer = ImageVectorizer() 53 | 54 | pbar = tqdm(consumer) 55 | for msg in pbar: 56 | try: 57 | bod = BytesIO(base64.decodebytes(msg.value)) 58 | img = imread(bod) 59 | vec = image_vectorizer.get_feature_vector(img).astype(np.float16) 60 | producer.send(KAFKA_PUB_TOPIC, key=msg.key, value=vec.tostring()) 61 | pbar.set_description("%s: %.2lf, %.2lf, %.2lf" % (msg.key.decode(), vec.min(), vec.mean(), vec.max())) 62 | except Exception as ex: 63 | print("Exception", msg, ex) 64 | 65 | producer.flush() 66 | -------------------------------------------------------------------------------- /scratch/mvp-big/kafka_image_producer.py: -------------------------------------------------------------------------------- 1 | from kafka import KafkaProducer 2 | from tqdm import tqdm 3 | import base64 4 | import sys 5 | import os 6 | import pdb 7 | import random 8 | 9 | KAFKA_SERVER = "ip-172-31-19-114.ec2.internal:9092" 10 | KAFKA_PUB_TOPIC = "aknn-demo-twitter-images-base64" 11 | 12 | if __name__ == "__main__": 13 | 14 | images_dir = sys.argv[1] 15 | N = int(sys.argv[2]) 16 | 17 | producer = KafkaProducer(bootstrap_servers=KAFKA_SERVER) 18 | image_paths = os.listdir(images_dir) 19 | 20 | pbar = tqdm(random.sample(image_paths, N)) 21 | for image_fname in pbar: 22 | with open("%s/%s" % (images_dir, image_fname), "rb") as fp: 23 | b64 = base64.b64encode(fp.read()) 24 | producer.send(KAFKA_PUB_TOPIC, key=image_fname.encode(), value=b64) 25 | pbar.set_description(image_fname) 26 | 27 | producer.flush() 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /scratch/mvp-big/kafka_reset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | ZK="ip-172-31-82-223.ec2.internal:2181" 4 | T1="aknn-demo-twitter-images-base64" 5 | T2="aknn-demo-feature-vectors" 6 | 7 | kafka-topics.sh --zookeeper $ZK --delete --topic $T1 8 | kafka-topics.sh --zookeeper $ZK --delete --topic $T2 9 | kafka-topics.sh --zookeeper $ZK --create --topic $T1 --replication-factor 2 --partitions 10 10 | kafka-topics.sh --zookeeper $ZK --create --topic $T2 --replication-factor 1 --partitions 10 11 | kafka-topics.sh --zookeeper $ZK --list 12 | 13 | -------------------------------------------------------------------------------- /scratch/mvp-big/kafka_watch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | SERVER=$1 4 | TOPIC=$2 5 | 6 | echo $SERVER 7 | echo $TOPIC 8 | 9 | kafka-console-consumer.sh --bootstrap-server "$SERVER" --topic "$TOPIC" \ 10 | --from-beginning \ 11 | --property key.deserializer=org.apache.kafka.common.serialization.StringDeserializer \ 12 | --property print.key=true \ 13 | --property print.value=false 14 | -------------------------------------------------------------------------------- /scratch/mvp-big/requirements.txt: -------------------------------------------------------------------------------- 1 | kafka-python 2 | keras 3 | tensorflow 4 | numpy 5 | scipy 6 | scikit-image 7 | tqdm 8 | 
-------------------------------------------------------------------------------- /scratch/mvp/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | Approximate Vector Search Demo 4 | 5 | 16 | 17 | 18 | 19 | 20 | 21 |
22 |

This primitive web app picks a random image indexed in Elasticsearch, 23 | queries for its nearest neighbors, and shows them.

24 |

The image outlined in red is the "query" image. The remaining images are its nine 25 | nearest neighbors.

26 |

There are about 6,000 images indexed in Elasticsearch in total, all taken from 27 | the ImageNet test set downloaded from Kaggle.

28 |

Below the images you can see the query response. Note how each image is 29 | represented by the "text" field, which is a string of tokens corresponding to 30 | the document's approximate location in the vector space of images.
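
To make that concrete, here is a minimal sketch of the flow the page describes, written against the same index that the insertion consumer further down (`kafka_image_elasticsearch_insert.py`) writes to. The page's own JavaScript is not reproduced in this dump, so the `random_score` pick and the shape of the match query are assumptions about how it works, not a copy of it.

```python
# Sketch only, not the page's actual JavaScript (which is not reproduced in
# this dump). The index name, the "text" field, and the "i_b" token encoding
# come from kafka_image_elasticsearch_insert.py below; the random_score pick
# and the match-query body are assumptions about how the page queries.
from elasticsearch import Elasticsearch

es = Elasticsearch()

# 1. Pick a random indexed image (one possible way: a random_score query).
rand = es.search(index="imagenet_images", body={
    "size": 1,
    "query": {"function_score": {"random_score": {}}},
})
query_doc = rand["hits"]["hits"][0]

# 2. Use its stored hash tokens ("0_1 1_0 2_1 ...") as an ordinary match
#    query; documents sharing more hash bits score higher, which serves as
#    the approximate nearest-neighbor ranking.
res = es.search(index="imagenet_images", body={
    "size": 10,
    "query": {"match": {"text": query_doc["_source"]["text"]}},
})
for hit in res["hits"]["hits"]:
    print(hit["_id"], hit["_score"])
```

With size 10, the top hit is the query document itself (every token matches) and the next nine hits are the neighbors the page displays. The appeal of the design is that once hashes are stored as text tokens, Elasticsearch's ordinary inverted index and scoring do all the work at query time.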

31 |
32 |
33 | 34 |
35 | 36 |
37 | 38 |
39 | 40 | 41 | 42 | 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /scratch/mvp/kafka_glove_elasticsearch_insert.py: -------------------------------------------------------------------------------- 1 | # Kafka consumer with following responsibilities: 2 | # - Read LSH vector from topic. 3 | # - Insert LSH vector to Elasticsearch. 4 | from elasticsearch import Elasticsearch, helpers 5 | from confluent_kafka import Consumer, KafkaError, Producer 6 | from time import time 7 | import numpy as np 8 | import pdb 9 | import os 10 | 11 | 12 | def vec_to_text(vec): 13 | tokens = [] 14 | for i, b in enumerate(vec): 15 | tokens.append("%d_%d" % (i, b)) 16 | return " ".join(tokens) 17 | 18 | 19 | if __name__ == "__main__": 20 | 21 | K_SERVER = "localhost:9092" 22 | K_SUB_TOPIC = "glove-hash-vectors" 23 | 24 | settings = { 25 | 'bootstrap.servers': K_SERVER, 26 | 'group.id': 'TODO', 27 | 'client.id': 'client-%d' % time(), 28 | 'enable.auto.commit': True, 29 | 'session.timeout.ms': 6000, 30 | 'default.topic.config': { 31 | 'auto.offset.reset': 'smallest' 32 | } 33 | } 34 | 35 | consumer = Consumer(settings) 36 | consumer.subscribe([K_SUB_TOPIC]) 37 | 38 | es = Elasticsearch() 39 | actions = [] 40 | 41 | while True: 42 | 43 | # TODO: is it really best practice to poll like this? 44 | msg = consumer.poll(0.1) 45 | if msg is None: 46 | continue 47 | 48 | if msg.error(): 49 | print('Error: %s' % msg.error().str()) 50 | continue 51 | 52 | key = msg.key().decode() 53 | hsh = np.fromstring(msg.value(), dtype=np.uint8) 54 | 55 | actions.append({ 56 | "_index": "glove_vectors", 57 | "_type": "word", 58 | "_id": key, 59 | "_source": { 60 | "key": key, 61 | "text": vec_to_text(hsh) 62 | } 63 | }) 64 | 65 | # TODO: actually do a bulk insertion... 66 | helpers.bulk(es, actions) 67 | actions = [] 68 | 69 | print('%s %s %.3lf' % (key, str(hsh.shape), hsh.mean())) 70 | 71 | # TODO: use atexit to make sure it flushes if the script fails. 72 | producer.flush() 73 | -------------------------------------------------------------------------------- /scratch/mvp/kafka_glove_feature_vectors.py: -------------------------------------------------------------------------------- 1 | # Kafka consumer with following responsibilities: 2 | # - Read S3 key from Kafka topic. 3 | # - Compute feature vector via pre-trained convnet. 4 | # - Publish feature vector to topic. 
5 | # See below link for description of configuration: 6 | # https://www.confluent.io/blog/introduction-to-apache-kafka-for-python-programmers/ 7 | 8 | from confluent_kafka import Consumer, KafkaError, Producer 9 | from time import time 10 | import boto3 11 | import numpy as np 12 | import pdb 13 | 14 | if __name__ == "__main__": 15 | 16 | K_SERVER = "localhost:9092" 17 | K_PUB_TOPIC = "glove-feature-vectors" 18 | 19 | producer = Producer({"bootstrap.servers": K_SERVER}) 20 | 21 | glove_dir = '/home/alex/dev/approximate-vector-search/scratch/es-lsh-glove' 22 | glove_keys = [l.strip() for l in open('%s/glove_vocab.txt' % glove_dir)] 23 | glove_vecs = np.load('%s/glove_vecs.npy' % glove_dir).astype(np.float32) 24 | 25 | glove_keys = glove_keys[:90000] 26 | glove_vecs = glove_vecs[:90000] 27 | 28 | for i, (key, vec) in enumerate(zip(glove_keys, glove_vecs)): 29 | producer.produce(K_PUB_TOPIC, key=key, value=vec.tostring()) 30 | print(i, key, vec.mean()) 31 | 32 | producer.flush() 33 | -------------------------------------------------------------------------------- /scratch/mvp/kafka_glove_lsh_vectors.py: -------------------------------------------------------------------------------- 1 | # Kafka consumer with following responsibilities: 2 | # - Read feature vector from Kafka topic. 3 | # - Compute vector hash via LSH. 4 | # - Publish vector hash to topic. 5 | 6 | from confluent_kafka import Consumer, KafkaError, Producer 7 | from time import time 8 | import numpy as np 9 | import pdb 10 | import os 11 | 12 | 13 | class SimpleLSH(object): 14 | 15 | def __init__(self, seed=865, bits=1024): 16 | self.bits = bits 17 | self.rng = np.random.RandomState(seed) 18 | self.planes = None 19 | self.M = None 20 | self.N = None 21 | self.NdotM = None 22 | 23 | def save(self, model_path): 24 | np.savez(model_path, self.M, self.N, self.NdotM) 25 | return self 26 | 27 | def load(self, model_path): 28 | arrs = np.load(model_path) 29 | self.M = arrs['arr_0'] 30 | self.N = arrs['arr_1'] 31 | self.NdotM = arrs['arr_2'] 32 | return self 33 | 34 | def fit(self, X): 35 | sample_ii = self.rng.choice(range(len(X)), 2 * self.bits) 36 | X_sample = X[sample_ii].reshape(2, self.bits, X.shape[-1]) 37 | self.M = (X_sample[0, ...] + X_sample[1, ...]) / 2 38 | self.N = X_sample[-1, ...] - self.M 39 | self.NdotM = (self.N * self.M).sum(-1) 40 | return self 41 | 42 | def get_vector_hash(self, X): 43 | XdotN = X.dot(self.N.T) 44 | return (XdotN >= self.NdotM).astype(np.uint8) 45 | 46 | 47 | if __name__ == "__main__": 48 | 49 | K_SERVER = "localhost:9092" 50 | K_SUB_TOPIC = "glove-feature-vectors" 51 | K_PUB_TOPIC = "glove-hash-vectors" 52 | 53 | settings = { 54 | 'bootstrap.servers': K_SERVER, 55 | 'group.id': 'TODO', 56 | 'client.id': 'client-%d' % time(), 57 | 'enable.auto.commit': True, 58 | 'session.timeout.ms': 6000, 59 | 'default.topic.config': { 60 | 'auto.offset.reset': 'smallest' 61 | } 62 | } 63 | 64 | consumer = Consumer(settings) 65 | producer = Producer({"bootstrap.servers": K_SERVER}) 66 | consumer.subscribe([K_SUB_TOPIC]) 67 | 68 | glove_dir = '/home/alex/dev/approximate-vector-search/scratch/es-lsh-glove' 69 | vecs = np.load('%s/glove_vecs.npy' % glove_dir).astype(np.float32) 70 | simple_lsh = SimpleLSH(bits=1024, seed=865) 71 | simple_lsh.fit(vecs) 72 | 73 | while True: 74 | 75 | # TODO: is it really best practice to poll like this? 
76 | msg = consumer.poll(0.1) 77 | if msg is None: 78 | continue 79 | 80 | if msg.error(): 81 | print('Error: %s' % msg.error().str()) 82 | continue 83 | 84 | key = msg.key().decode() 85 | vec = np.fromstring(msg.value(), dtype=np.float32) 86 | hsh = simple_lsh.get_vector_hash(vec[np.newaxis, :]) 87 | 88 | print('%s %s %.3lf %s %.3lf' % ( 89 | key, str(vec.shape), vec.mean(), str(hsh.shape), hsh.mean())) 90 | 91 | producer.produce(K_PUB_TOPIC, key=key, value=hsh.tostring()) 92 | 93 | # TODO: use atexit to make sure it flushes if the script fails. 94 | producer.flush() 95 | -------------------------------------------------------------------------------- /scratch/mvp/kafka_image_elasticsearch_insert.py: -------------------------------------------------------------------------------- 1 | # Kafka consumer with following responsibilities: 2 | # - Read LSH vector from topic. 3 | # - Insert LSH vector to Elasticsearch. 4 | from elasticsearch import Elasticsearch, helpers 5 | from confluent_kafka import Consumer, KafkaError, Producer 6 | from time import time 7 | import numpy as np 8 | import pdb 9 | import os 10 | 11 | 12 | def vec_to_text(vec): 13 | tokens = [] 14 | for i, b in enumerate(vec): 15 | tokens.append("%d_%d" % (i, b)) 16 | return " ".join(tokens) 17 | 18 | 19 | if __name__ == "__main__": 20 | 21 | K_SERVER = "localhost:9092" 22 | K_SUB_TOPIC = "image-hash-vectors" 23 | 24 | settings = { 25 | 'bootstrap.servers': K_SERVER, 26 | 'group.id': 'TODO', 27 | 'client.id': 'client-%d' % time(), 28 | 'enable.auto.commit': True, 29 | 'session.timeout.ms': 6000, 30 | 'default.topic.config': { 31 | 'auto.offset.reset': 'smallest' 32 | } 33 | } 34 | 35 | consumer = Consumer(settings) 36 | consumer.subscribe([K_SUB_TOPIC]) 37 | 38 | es = Elasticsearch() 39 | actions = [] 40 | 41 | while True: 42 | 43 | # TODO: is it really best practice to poll like this? 44 | msg = consumer.poll(0.1) 45 | if msg is None: 46 | continue 47 | 48 | if msg.error(): 49 | print('Error: %s' % msg.error().str()) 50 | continue 51 | 52 | image_key = msg.key().decode() 53 | hsh = np.fromstring(msg.value(), dtype=np.uint8) 54 | 55 | actions.append({ 56 | "_index": "imagenet_images", 57 | "_type": "image", 58 | "_id": image_key, 59 | "_source": { 60 | "image_key": image_key, 61 | "text": vec_to_text(hsh) 62 | } 63 | }) 64 | 65 | # TODO: actually do a bulk insertion... 66 | helpers.bulk(es, actions) 67 | actions = [] 68 | 69 | print('%s %s %.3lf' % (image_key, str(hsh.shape), hsh.mean())) 70 | 71 | # TODO: use atexit to make sure it flushes if the script fails. 72 | producer.flush() 73 | -------------------------------------------------------------------------------- /scratch/mvp/kafka_image_feature_vectors.py: -------------------------------------------------------------------------------- 1 | # Kafka consumer with following responsibilities: 2 | # - Read S3 key from Kafka topic. 3 | # - Compute feature vector via pre-trained convnet. 4 | # - Publish feature vector to topic. 
5 | # See below link for description of configuration: 6 | # https://www.confluent.io/blog/introduction-to-apache-kafka-for-python-programmers/ 7 | 8 | from confluent_kafka import Consumer, KafkaError, Producer 9 | from io import BytesIO 10 | from keras.models import Model 11 | from keras.applications import ResNet50 12 | from keras.applications.imagenet_utils import preprocess_input, decode_predictions 13 | from scipy.misc import imread 14 | from skimage.transform import resize 15 | from time import time 16 | import boto3 17 | import numpy as np 18 | import pdb 19 | 20 | 21 | class ImageVectorizer(object): 22 | 23 | def __init__(self): 24 | self.resnet50 = ResNet50(include_top=False) 25 | self.s3_client = boto3.client('s3') 26 | 27 | def get_feature_vector(self, s3_key): 28 | 29 | # Download image from S3 into memory. 30 | obj = self.s3_client.get_object(Bucket='klibisz-test', Key=s3_key) 31 | bod = BytesIO(obj['Body'].read()) 32 | img = imread(bod) 33 | 34 | # Pre-process image for keras. 35 | img = resize(img, (224, 224), preserve_range=True) 36 | img_batch = preprocess_input(img[np.newaxis, ...], mode='caffe') 37 | 38 | # Compute keras output. 39 | vec_batch = self.resnet50.predict(img_batch) 40 | vec = vec_batch.reshape((vec_batch.shape[-1])) 41 | 42 | # Return numpy feature vector 43 | return vec 44 | 45 | 46 | if __name__ == "__main__": 47 | 48 | K_SERVER = "localhost:9092" 49 | K_SUB_TOPIC = "image-s3-keys" 50 | K_PUB_TOPIC = "image-feature-vectors" 51 | 52 | settings = { 53 | 'bootstrap.servers': K_SERVER, 54 | 'group.id': 'TODO', 55 | 'client.id': 'client-%d' % time(), 56 | 'enable.auto.commit': True, 57 | 'session.timeout.ms': 6000, 58 | 'default.topic.config': { 59 | 'auto.offset.reset': 'smallest' 60 | } 61 | } 62 | 63 | consumer = Consumer(settings) 64 | producer = Producer({"bootstrap.servers": K_SERVER}) 65 | consumer.subscribe([K_SUB_TOPIC]) 66 | 67 | image_vectorizer = ImageVectorizer() 68 | 69 | while True: 70 | 71 | # TODO: is it really best practice to poll like this? 72 | msg = consumer.poll(0.1) 73 | if msg is None: 74 | continue 75 | 76 | if msg.error(): 77 | print('Error: %s' % msg.error().str()) 78 | continue 79 | 80 | # Compute feature vector. 81 | try: 82 | image_key = msg.value().decode() 83 | vec = image_vectorizer.get_feature_vector(image_key) 84 | print('%s %s %.3lf %.3lf %.3lf' % ( 85 | image_key, str(vec.shape), vec.min(), vec.mean(), vec.max())) 86 | producer.produce(K_PUB_TOPIC, key=image_key, value=vec.tostring()) 87 | except Exception as ex: 88 | print('Exception processing image key %s: %s' % (image_key, ex)) 89 | 90 | # TODO: use atexit to make sure it flushes if the script fails. 91 | producer.flush() 92 | -------------------------------------------------------------------------------- /scratch/mvp/kafka_image_lsh_vectors.py: -------------------------------------------------------------------------------- 1 | # Kafka consumer with following responsibilities: 2 | # - Read feature vector from Kafka topic. 3 | # - Compute vector hash via LSH. 4 | # - Publish vector hash to topic. 
5 | 6 | from confluent_kafka import Consumer, KafkaError, Producer 7 | from time import time 8 | import numpy as np 9 | import pdb 10 | import os 11 | 12 | 13 | class SimpleLSH(object): 14 | 15 | def __init__(self, seed=865, bits=1024): 16 | self.bits = bits 17 | self.rng = np.random.RandomState(seed) 18 | self.planes = None 19 | self.M = None 20 | self.N = None 21 | self.NdotM = None 22 | 23 | def fit(self, X): 24 | sample_ii = self.rng.choice(range(len(X)), 2 * self.bits) 25 | X_sample = X[sample_ii].reshape(2, self.bits, X.shape[-1]) 26 | self.M = (X_sample[0, ...] + X_sample[1, ...]) / 2 27 | self.N = X_sample[-1, ...] - self.M 28 | self.NdotM = (self.N * self.M).sum(-1) 29 | return self 30 | 31 | def get_vector_hash(self, X): 32 | XdotN = X.dot(self.N.T) 33 | return (XdotN >= self.NdotM).astype(np.uint8) 34 | 35 | 36 | if __name__ == "__main__": 37 | 38 | K_SERVER = "localhost:9092" 39 | K_SUB_TOPIC = "image-feature-vectors" 40 | K_PUB_TOPIC = "image-hash-vectors" 41 | FIT_VECS_PATH = "../es-lsh-images/imagenet_vectors.npy" 42 | 43 | settings = { 44 | 'bootstrap.servers': K_SERVER, 45 | 'group.id': 'TODO', 46 | 'client.id': 'client-%d' % time(), 47 | 'enable.auto.commit': True, 48 | 'session.timeout.ms': 6000, 49 | 'default.topic.config': { 50 | 'auto.offset.reset': 'smallest' 51 | } 52 | } 53 | 54 | consumer = Consumer(settings) 55 | producer = Producer({"bootstrap.servers": K_SERVER}) 56 | consumer.subscribe([K_SUB_TOPIC]) 57 | 58 | vecs = np.load(FIT_VECS_PATH) 59 | simple_lsh = SimpleLSH(bits=1024, seed=865) 60 | simple_lsh.fit(vecs) 61 | 62 | while True: 63 | 64 | # TODO: is it really best practice to poll like this? 65 | msg = consumer.poll(0.1) 66 | if msg is None: 67 | continue 68 | 69 | if msg.error(): 70 | print('Error: %s' % msg.error().str()) 71 | continue 72 | 73 | image_key = msg.key().decode() 74 | vec = np.fromstring(msg.value(), dtype=np.float32) 75 | hsh = simple_lsh.get_vector_hash(vec[np.newaxis, :]) 76 | 77 | print('%s %s %.3lf %s %.3lf' % ( 78 | image_key, str(vec.shape), vec.mean(), str(hsh.shape), hsh.mean())) 79 | 80 | producer.produce(K_PUB_TOPIC, key=image_key, value=hsh.tostring()) 81 | 82 | # TODO: use atexit to make sure it flushes if the script fails. 83 | producer.flush() 84 | -------------------------------------------------------------------------------- /scratch/mvp/kafka_image_s3_keys.py: -------------------------------------------------------------------------------- 1 | # Reads S3 image keys from a text file and produces them to a Kafka topic. 
2 | from confluent_kafka import Producer 3 | from time import sleep 4 | 5 | if __name__ == "__main__": 6 | 7 | S3_KEYS_PATH = 's3_keys_test.txt' 8 | KAFKA_SERVER = 'localhost:9092' 9 | OUTPUT_TOPIC = 'image-s3-keys' 10 | 11 | producer = Producer({"bootstrap.servers": KAFKA_SERVER}) 12 | 13 | for i, s3_key in enumerate(map(str.strip, open(S3_KEYS_PATH))): 14 | producer.produce(OUTPUT_TOPIC, key=s3_key, value=s3_key) 15 | print('%d, key %s' % (i, s3_key)) 16 | sleep(0.0001) 17 | 18 | producer.flush() 19 | -------------------------------------------------------------------------------- /scratch/mvp/readme.md: -------------------------------------------------------------------------------- 1 | # MVP 2 | 3 | [Demo video](https://www.youtube.com/watch?v=qyMeh0R4xCU&feature=youtu.be) -------------------------------------------------------------------------------- /scratch/twitter-images/.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | -------------------------------------------------------------------------------- /scratch/twitter-images/ingest.py: -------------------------------------------------------------------------------- 1 | from tweepy import OAuthHandler, API, Stream, StreamListener 2 | from threading import Thread, active_count 3 | from time import time 4 | import json 5 | import gzip 6 | import urllib.request 7 | import os 8 | import pdb 9 | import sys 10 | 11 | STATUSES_DIR = 'data/statuses' 12 | IMAGES_DIR = 'data/images' 13 | TWCREDS = { 14 | "consumer_key": os.getenv("TWCK"), 15 | "consumer_secret": os.getenv("TWCS"), 16 | "access_token": os.getenv("TWAT"), 17 | "token_secret": os.getenv("TWTS") 18 | } 19 | 20 | def download(status): 21 | 22 | for i, item in enumerate(status.entities['media']): 23 | t0 = time() 24 | ext = item['media_url'].split('.')[-1] 25 | local_path = '%s/%d_%d.%s' % (IMAGES_DIR, status.id, i, ext) 26 | urllib.request.urlretrieve(item['media_url'], local_path) 27 | print('%.2lf %s' % (time() - t0, local_path)) 28 | 29 | with gzip.open('%s/%d.json.gz' % (STATUSES_DIR, status.id), 'wb') as fp: 30 | fp.write(json.dumps(status._json).encode()) 31 | 32 | class MyStreamListener(StreamListener): 33 | 34 | def __init__(self, **kwargs): 35 | self.cnt_tot = len(os.listdir(IMAGES_DIR)) 36 | self.cnt_new = 0 37 | self.t0 = time() 38 | super().__init__(kwargs) 39 | 40 | def on_status(self, status): 41 | 42 | if 'media' not in status.entities: 43 | return 44 | 45 | thr = Thread(target=download, args=(status,)) 46 | thr.start() 47 | 48 | self.cnt_new += 1 49 | self.cnt_tot += len(status.entities['media']) 50 | 51 | time_sec = time() - self.t0 52 | time_day = time_sec / (24 * 60 * 60) 53 | 54 | print('%d total, %d new, %.3lf per day, %d active threads' % ( 55 | self.cnt_tot, self.cnt_new, self.cnt_new / time_day, active_count())) 56 | 57 | if __name__ == "__main__": 58 | 59 | auth = OAuthHandler(TWCREDS['consumer_key'], TWCREDS['consumer_secret']) 60 | auth.set_access_token(TWCREDS['access_token'], TWCREDS['token_secret']) 61 | twitter = API( 62 | auth, wait_on_rate_limit=True, wait_on_rate_limit_notify=True, 63 | retry_count=10, retry_delay=1) 64 | 65 | myStreamListener = MyStreamListener() 66 | myStream = Stream(auth=twitter.auth, listener=myStreamListener) 67 | myStream.sample() 68 | 69 | --------------------------------------------------------------------------------