├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── compose.yaml ├── dria_hnsw ├── .dockerignore ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── Dockerfile ├── Makefile ├── README.md └── src │ ├── db │ ├── conversions.rs │ ├── env.rs │ ├── mod.rs │ ├── redis_client.rs │ └── rocksdb_client.rs │ ├── errors │ ├── errors.rs │ └── mod.rs │ ├── filter │ ├── mod.rs │ └── text_based.rs │ ├── hnsw │ ├── index.rs │ ├── mod.rs │ ├── scalar.rs │ ├── sync_map.rs │ └── utils.rs │ ├── lib.rs │ ├── main.rs │ ├── middlewares │ ├── cache.rs │ └── mod.rs │ ├── models │ ├── mod.rs │ └── request_models.rs │ ├── proto │ ├── hnsw_comm.proto │ ├── index.proto │ ├── index_buffer.rs │ ├── insert.proto │ ├── insert_buffer.rs │ ├── mod.rs │ ├── request.proto │ └── request_buffer.rs │ ├── responses │ ├── mod.rs │ └── responses.rs │ └── worker.rs ├── hollowdb ├── .dockerignore ├── .env.example ├── .gitignore ├── .yarnrc.yml ├── Dockerfile ├── README.md ├── config │ └── .gitignore ├── jest.config.ts ├── package.json ├── src │ ├── clients │ │ ├── hollowdb.ts │ │ └── rocksdb.ts │ ├── configurations │ │ └── index.ts │ ├── controllers │ │ ├── read.ts │ │ ├── values.ts │ │ └── write.ts │ ├── global.d.ts │ ├── index.ts │ ├── schemas │ │ └── index.ts │ ├── server.ts │ ├── types │ │ └── index.ts │ └── utilities │ │ ├── download.ts │ │ └── refresh.ts ├── test │ ├── index.test.ts │ ├── res │ │ ├── contractSource.ts │ │ └── initialState.ts │ └── util │ │ └── index.ts ├── tsconfig.build.json ├── tsconfig.json └── yarn.lock └── hollowdb_wait ├── Dockerfile ├── README.md └── wait.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | wallet.json 3 | dump.rdb 4 | .vscode 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 FirstBatch 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | ######### DRIA HNSW ######### 3 | .PHONY: build-hnsw run-hnsw push-hnsw 4 | 5 | build-hnsw: 6 | docker build ./dria_hnsw -t dria-hnsw 7 | 8 | run-hnsw: 9 | docker run \ 10 | -e CONTRACT_ID=WbcY2a-KfDpk7fsgumUtLC2bu4NQcVzNlXWi13fPMlU \ 11 | -e REDIS_URL=redis://default:redispw@localhost:6379 \ 12 | -p 8080:8080 \ 13 | dria-hnsw 14 | 15 | push-hnsw: 16 | docker buildx build \ 17 | --platform=linux/amd64,linux/arm64,linux/arm ./dria_hnsw \ 18 | -t firstbatch/dria-hnsw:latest \ 19 | --builder=dria-builder --push 20 | 21 | ######### HOLLOWDB ########## 22 | .PHONY: build-hollowdb run-hollowdb push-hollowdb 23 | 24 | build-hollowdb: 25 | docker build ./hollowdb -t dria-hollowdb 26 | 27 | run-hollowdb: 28 | docker run \ 29 | -e CONTRACT_ID=WbcY2a-KfDpk7fsgumUtLC2bu4NQcVzNlXWi13fPMlU \ 30 | -e REDIS_URL=redis://default:redispw@localhost:6379 \ 31 | -p 3030:3030 \ 32 | dria-hollowdb 33 | 34 | push-hollowdb: 35 | docker build ./hollowdb -t firstbatch/dria-hollowdb:latest --push 36 | 37 | ####### HOLLOWDB WAIT ####### 38 | .PHONY: build-hollowdb-wait run-hollowdb-wait push-hollowdb-wait 39 | 40 | build-hollowdb-wait: 41 | docker build ./hollowdb_wait -t hollowdb-wait-for 42 | 43 | run-hollowdb-wait: 44 | docker run hollowdb-wait-for 45 | 46 | push-hollowdb-wait: 47 | docker build ./hollowdb_wait -t firstbatch/dria-hollowdb-wait-for:latest --push 48 | 49 | ########### REDIS ########### 50 | pull-redis: 51 | docker pull redis:alpine 52 | 53 | run-redis: 54 | docker run --name dria-redis -p 6379:6379 redis:alpine 55 | 56 | ########## BUILDER ########## 57 | .PHONY: dria-builder 58 | 59 | # see: https://docs.docker.com/build/building/multi-platform/#cross-compilation 60 | dria-builder: 61 | docker buildx create --name dria-builder --bootstrap --use 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | logo
3 |
7 | # Dria Docker
8 |
10 | Dria Docker is an all-in-one environment to use Dria, the collective knowledge for AI.
11 |
13 | 14 | ## Setup 15 | 16 | To use Dria Docker, you need: 17 | 18 | - [Docker](https://www.docker.com/) installed on your machine. 19 | - A Dria contract ID 20 | 21 | A Dria contract is a piece of knowledge deployed on Arweave; the contract ID can be seen on each knowledge deployed to [Dria](https://dria.co/). For example, consider the Dria knowledge of [The Rust Programming Language](https://dria.co/knowledge/7EZMw0vAAFaKVMNOmu2rFgFCFjRD2C2F0kI_N5Cv6QQ): 22 | 23 | - https://dria.co/knowledge/7EZMw0vAAFaKVMNOmu2rFgFCFjRD2C2F0kI_N5Cv6QQ 24 | 25 | The base64 string at the end of that URL is our contract ID, and it can also be seen at the top of the page at that link. 26 | 27 | ### Using Dria CLI 28 | 29 | The preferred method of using Dria Docker is via the [Dria CLI](https://github.com/firstbatchxyz/dria-cli/), which is an NPM package. 30 | 31 | ```sh 32 | npm i -g dria-cli 33 | ``` 34 | 35 | You can see the available commands with: 36 | 37 | ```sh 38 | dria help 39 | ``` 40 | 41 | See the [docs](https://github.com/firstbatchxyz/dria-cli/?tab=readme-ov-file#usage) of Dria CLI for more. 42 | 43 | ### Using Compose 44 | 45 | Download the Docker compose file: 46 | 47 | ```sh 48 | curl -o compose.yaml -L https://raw.githubusercontent.com/firstbatchxyz/dria-docker/master/compose.yaml 49 | ``` 50 | 51 | You can start a Dria container with the following command, where the contract ID is provided as an environment variable: 52 | 53 | ```sh 54 | CONTRACT=contract-id docker compose up 55 | ``` 56 | 57 | ## Usage 58 | 59 | When everything is up, you will have access to both Dria and HollowDB on your local network! 60 | 61 | - Dria HNSW will be live at `localhost:8080`, see endpoints [here](./dria_hnsw/README.md#endpoints). 62 | - HollowDB API will be live at `localhost:3030`, see endpoints [here](./hollowdb/README.md#endpoints). 63 | 64 | These host ports can also be changed within the [compose file](./compose.yaml) if you have them reserved for other applications. 65 | 66 | > [!TIP] 67 | > 68 | > You can also connect to a terminal on the Redis container and use `redis-cli` if you would like to examine the keys. 69 | 70 | ## License 71 | 72 | Dria Docker is licensed under [Apache 2.0](./LICENSE).
73 | -------------------------------------------------------------------------------- /compose.yaml: -------------------------------------------------------------------------------- 1 | version: "3.8" 2 | 3 | services: 4 | ### Dria HNSW Rust code ### 5 | dria-hnsw: 6 | #build: ./dria_hnsw 7 | image: "firstbatch/dria-hnsw" 8 | environment: 9 | - PORT=8080 10 | - ROCKSDB_PATH=/data/${CONTRACT} 11 | - REDIS_URL=redis://default:redispw@redis:6379 12 | - CONTRACT_ID=${CONTRACT} 13 | volumes: 14 | - ${HOME}/.dria/data:/data 15 | ports: 16 | - "8080:8080" 17 | depends_on: 18 | hollowdb-wait-for: 19 | condition: service_completed_successfully 20 | 21 | ### HollowDBs API 'wait-for' script ### 22 | hollowdb-wait-for: 23 | # build: ./hollowdb_wait 24 | image: "firstbatch/dria-hollowdb-wait-for" 25 | environment: 26 | - TARGET=hollowdb:3000 27 | depends_on: 28 | - hollowdb 29 | 30 | ### HollowDB API ### 31 | hollowdb: 32 | #build: ./hollowdb 33 | image: "firstbatch/dria-hollowdb" 34 | ports: 35 | - "3000:3000" 36 | expose: 37 | - "3000" # used by HollowDB wait-for script 38 | volumes: 39 | - ${HOME}/.dria/data:/app/data 40 | environment: 41 | - PORT=3000 42 | - CONTRACT_ID=${CONTRACT} 43 | - ROCKSDB_PATH=/app/data/${CONTRACT} 44 | - REDIS_URL=redis://default:redispw@redis:6379 45 | - USE_BUNDLR=true # true if your contract uses Bundlr 46 | - USE_HTX=true # true if your contract stores values as `hash.txid` 47 | - BUNDLR_FBS=80 # batch size for downloading bundled values from Arweave 48 | depends_on: 49 | - redis 50 | 51 | ### Redis Container ### 52 | redis: 53 | image: "redis:alpine" 54 | expose: 55 | - "6379" 56 | # prettier-ignore 57 | command: [ 58 | 'redis-server', 59 | '--port', '6379', 60 | '--maxmemory', '100mb', 61 | '--maxmemory-policy', 'allkeys-lru', 62 | '--appendonly', 'no', 63 | '--dbfilename', '${CONTRACT}.rdb', 64 | '--dir', '/tmp' 65 | ] 66 | -------------------------------------------------------------------------------- /dria_hnsw/.dockerignore: -------------------------------------------------------------------------------- 1 | target 2 | Dockerfile 3 | manifests 4 | .dockerignore 5 | .git 6 | .gitignore 7 | .github 8 | .DS_Store 9 | README.md 10 | -------------------------------------------------------------------------------- /dria_hnsw/.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | -------------------------------------------------------------------------------- /dria_hnsw/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "dria_hnsw" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | redis = { version = "0.23.0", features = ["cluster"] } 10 | prost = "0.10.0" 11 | prost-build = "0.10.0" 12 | prost-types = "0.10.0" 13 | prost-derive = "0.10.0" 14 | base64 = "0.13.0" 15 | rand = "0.8.5" 16 | serde = { version = "1.0", features = ["derive"] } 17 | serde_json = "1.0" 18 | chrono = "0.4.19" 19 | hashbrown = "0.14.0" 20 | url = "2.4.0" 21 | actix-web = "4.3.1" 22 | reqwest = "0.11.18" 23 | actix-cors = "0.6.4" 24 | tokio = { version = "1", features = ["full"] } 25 | futures-util = "0.3.28" 26 | proc-macro-error = "1.0.4" 27 | derive_more = "0.99.2" 28 | tdigest = "0.2.3" 29 | rayon = "1.7.0" 30 | ahash = "0.8.6" 31 | rocksdb = "0.21.0" 32 | parking_lot = "0.12.1" 33 | crossbeam-channel = "0.5.8" 34 | mini-moka = "0.10.1" 35 | dashmap = 
"5.5.0" 36 | log = "0.4" 37 | simple_logger = "4.2.0" 38 | simsimd = "3.8.0" 39 | probly-search = "2.0.0" 40 | 41 | [dev-dependencies] 42 | simple-home-dir = "0.3.2" 43 | -------------------------------------------------------------------------------- /dria_hnsw/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM --platform=$BUILDPLATFORM rust:1.71 as builder 2 | 3 | # https://docs.docker.com/engine/reference/builder/#automatic-platform-args-in-the-global-scope 4 | # 5 | # offical rust image supports the following archs: 6 | # amd64 (AMD & Intel 64-bit) 7 | # arm32/v7 (ARMv7 32-bit) 8 | # arm64/v8 (ARMv8 64-bit) 9 | # i386 (Intel 32-bit 8086) 10 | # 11 | # our builds will be for platforms: 12 | # linux/amd64 13 | # linux/arm64/v8 14 | # linux/arm32/v7 15 | # linux/i386 16 | # 17 | # however, for small image size we use distroless, which allow 18 | # linux/amd64 19 | # linux/arm64 20 | # linux/arm 21 | # 22 | # To build an image & push them to Docker hub for this Dockerfile: 23 | # 24 | # docker buildx build --platform=linux/amd64,linux/arm64,linux/arm . -t firstbatch/dria-hnsw:latest --builder=dria-builder --push 25 | 26 | ARG BUILDPLATFORM 27 | ARG TARGETPLATFORM 28 | RUN echo "Build platform: $BUILDPLATFORM" 29 | RUN echo "Target platform: $TARGETPLATFORM" 30 | 31 | # install Cmake 32 | RUN apt-get update 33 | RUN apt-get install -y cmake 34 | 35 | # libclang needed by rocksdb 36 | RUN apt-get install -y clang 37 | 38 | # use nightly Rust 39 | RUN rustup install nightly-2023-07-25 40 | RUN rustup default nightly-2023-07-25 41 | 42 | # build release binary 43 | WORKDIR /usr/src/app 44 | COPY . . 45 | RUN cargo build --release 46 | 47 | # copy release binary to distroless 48 | FROM --platform=$BUILDPLATFORM gcr.io/distroless/cc 49 | COPY --from=builder /usr/src/app/target/release/dria_hnsw / 50 | 51 | EXPOSE 8080 52 | 53 | CMD ["./dria_hnsw"] 54 | -------------------------------------------------------------------------------- /dria_hnsw/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: docs 2 | docs: 3 | cargo doc --open --no-deps 4 | 5 | .PHONY: test 6 | test: 7 | cargo test 8 | 9 | .PHONY: lint 10 | lint: 11 | cargo clippy 12 | 13 | -------------------------------------------------------------------------------- /dria_hnsw/README.md: -------------------------------------------------------------------------------- 1 | # Dria HNSW 2 | 3 | Dria HNSW is an API that allows you to permissionlessly search knowledge uploaded to Dria. It works over values downloaded from Arweave to a Redis cache, and reads these values directly from Redis. 4 | It is written in Rust, and several functions respect the machine architecture for efficiency. 5 | 6 | ## Setup 7 | 8 | To run the server, you need to provide a contract ID along with a RocksDB path: 9 | 10 | ```sh 11 | CONTRACT_ID= ROCKSDB_PATH="/path/to/rocksdb" cargo run 12 | ``` 13 | 14 | Dria HNSW is available as a container: 15 | 16 | ```sh 17 | docker pull firstbatch/dria-hnsw 18 | ``` 19 | 20 | > [!TIP] 21 | > 22 | > The docker image is cross-compiled & built for multiple architectures, so when you pull the image the most efficient code per your architecture will be downloaded! 23 | 24 | To see the available endpoints, refer to [this section](#endpoints) below. 
25 | 26 | ## Endpoints 27 | 28 | Dria is an [Actix](https://actix.rs/) server with the following endpoints: 29 | 30 | - [`health`](#health) 31 | - [`fetch`](#fetch) 32 | - [`query`](#query) 33 | - [`insert_vector`](#insert_vector) 34 | 35 | All endpoints return a response in the following format: 36 | 37 | - `success`: a boolean indicating the success of the request 38 | - `code`: status code 39 | - `data`: response data 40 | 41 | > [!TIP] 42 | > 43 | > If `success` is false, the error message will be written in `data` as a string. 44 | 45 | ### `HEALTH` 46 | 47 | 48 | ```ts 49 | GET /health 50 | ``` 51 | 52 | **A simple healthcheck to see if the server is up.** 53 | 54 | Response data: 55 | 56 | - A string `"hello world!"`. 57 | 58 | ### `FETCH` 59 | 60 | 61 | ```ts 62 | POST /fetch 63 | ``` 64 | 65 | **Given a list of ids, fetches the metadata of their corresponding vectors.** 66 | 67 | Request body: 68 | 69 | - `id`: an array of integers 70 | 71 | Response data: 72 | 73 | - An array of metadata objects, where index `i` corresponds to the metadata of the vector with ID `id[i]`. 74 | 75 | ### `QUERY` 76 | 77 | 78 | ```ts 79 | POST /query 80 | ``` 81 | 82 | **Given an embedding vector, returns the most similar vectors in the index.** See the [example](#example) at the end of this document for a sample request. 83 | 84 | Request body: 85 | 86 | - `vector`: an array of floats corresponding to the embedding vector 87 | - `top_n`: number of results to return 88 | - `query`: (_optional_) the text that belongs to the given embedding; yields better results by looking for this text within the results 89 | - `level`: (_optional_) an integer value in the range [0, 4] that defines the intensity of the search; a larger value takes more time to complete but has higher recall 90 | 91 | Response data: 92 | 93 | - An array of objects with the following keys: 94 | - `id`: id of the returned vector 95 | - `score`: relevance score 96 | - `metadata`: metadata of the vector 97 | 98 | ### `INSERT_VECTOR` 99 | 100 | 101 | ```ts 102 | POST /insert_vector 103 | ``` 104 | 105 | **Inserts a new vector into the HNSW index.** 106 | 107 | Request body: 108 | 109 | - `vector`: an array of floats corresponding to the embedding vector 110 | - `metadata`: (_optional_) a JSON object that represents the metadata for this vector 111 | 112 | Response data: 113 | 114 | - A string `"Success"`. 115 | 116 | ## Testing 117 | 118 | We have several tests that you can run with: 119 | 120 | ```sh 121 | cargo test 122 | ``` 123 | 124 | Some tests expect a RocksDB folder to be present at `$HOME/.dria/data/WbcY2a-KfDpk7fsgumUtLC2bu4NQcVzNlXWi13fPMlU`, which can easily be downloaded with the [Dria CLI](https://github.com/firstbatchxyz/dria-cli/) if you do not have it: 125 | 126 | ```sh 127 | dria pull WbcY2a-KfDpk7fsgumUtLC2bu4NQcVzNlXWi13fPMlU 128 | ``` 129 | 130 | The said knowledge is rather lightweight, which makes it useful for testing.
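## Example

To make the request and response formats above concrete, here is a minimal `curl` sketch for the `query` endpoint. It assumes the server is running locally on the default port `8080`; the embedding is truncated to three floats for brevity (a real vector must match the dimension of the indexed embeddings), and the metadata shown in the response is hypothetical.

```sh
# Query the index for the 3 most similar vectors, with a mid-intensity search level.
curl -X POST http://localhost:8080/query \
  -H "Content-Type: application/json" \
  -d '{"vector": [0.12, -0.03, 0.44], "top_n": 3, "level": 2}'
```

A successful call returns the common response envelope described above, along the lines of:

```json
{
  "success": true,
  "code": 200,
  "data": [
    { "id": 42, "score": 0.91, "metadata": { "text": "..." } }
  ]
}
```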
131 | -------------------------------------------------------------------------------- /dria_hnsw/src/db/conversions.rs: -------------------------------------------------------------------------------- 1 | use crate::proto::index_buffer::{LayerNode, Point}; 2 | use crate::proto::insert_buffer::{BatchStr, BatchVec, SingletonStr, SingletonVec}; 3 | use prost::Message; 4 | 5 | pub fn point_to_base64(point: &Point) -> String { 6 | let mut bytes = Vec::new(); 7 | point.encode(&mut bytes).expect("Failed to encode message"); 8 | let enc = base64::encode(&bytes); // Convert bytes to string if needed 9 | enc 10 | } 11 | 12 | pub fn base64_to_point(e_point: &str) -> Point { 13 | let bytes = base64::decode(e_point).unwrap(); 14 | let point = Point::decode(bytes.as_slice()).unwrap(); // Deserialize 15 | point 16 | } 17 | 18 | pub fn node_to_base64(node: &LayerNode) -> String { 19 | let mut bytes = Vec::new(); 20 | node.encode(&mut bytes).expect("Failed to encode message"); 21 | let enc = base64::encode(&bytes); // Convert bytes to string if needed 22 | enc 23 | } 24 | 25 | pub fn base64_to_node(e_node: &str) -> LayerNode { 26 | let bytes = base64::decode(e_node).unwrap(); 27 | let node = LayerNode::decode(bytes.as_slice()).unwrap(); // Deserialize 28 | node 29 | } 30 | 31 | //*************** Batch to Singleton ***************// 32 | 33 | pub fn base64_to_batch_vec(batch: &str) -> BatchVec { 34 | let bytes = base64::decode(batch).unwrap(); 35 | let node = BatchVec::decode(bytes.as_slice()).unwrap(); // Deserialize 36 | node 37 | } 38 | 39 | pub fn base64_to_singleton_vec(singleton: &str) -> SingletonVec { 40 | let bytes = base64::decode(singleton).unwrap(); 41 | let node = SingletonVec::decode(bytes.as_slice()).unwrap(); // Deserialize 42 | node 43 | } 44 | 45 | pub fn base64_to_batch_str(batch: &str) -> BatchStr { 46 | let bytes = base64::decode(batch).unwrap(); 47 | let node = BatchStr::decode(bytes.as_slice()).unwrap(); // Deserialize 48 | node 49 | } 50 | 51 | pub fn base64_to_singleton_str(singleton: &str) -> SingletonStr { 52 | let bytes = base64::decode(singleton).unwrap(); 53 | let node = SingletonStr::decode(bytes.as_slice()).unwrap(); // Deserialize 54 | node 55 | } 56 | 57 | #[cfg(test)] 58 | mod tests { 59 | use super::*; 60 | 61 | #[test] 62 | fn test_point_to_base64() { 63 | let point = Point { 64 | idx: 1, 65 | v: vec![1.0, 2.0, 3.0], 66 | }; 67 | let enc = point_to_base64(&point); 68 | let dec = base64_to_point(&enc); 69 | assert_eq!(point, dec); 70 | } 71 | 72 | #[test] 73 | fn test_node_to_base64() { 74 | let node = LayerNode { 75 | level: 1, 76 | idx: 1, 77 | visible: true, 78 | neighbors: std::collections::HashMap::new(), 79 | }; 80 | let enc = node_to_base64(&node); 81 | let dec = base64_to_node(&enc); 82 | assert_eq!(node, dec); 83 | } 84 | 85 | #[test] 86 | fn test_node_to_base64_from_string() { 87 | let node = LayerNode { 88 | level: 1, 89 | idx: 1, 90 | visible: true, 91 | neighbors: std::collections::HashMap::new(), 92 | }; 93 | let enc = "CAEQARgB".to_string(); 94 | let dec = base64_to_node(&enc); 95 | assert_eq!(node, dec); 96 | } 97 | 98 | #[test] 99 | fn test_batch_to_singleton() { 100 | let singleton = SingletonVec { 101 | v: vec![1.0, 2.0, 3.0], 102 | map: std::collections::HashMap::new(), 103 | }; 104 | let batch = BatchVec { 105 | s: vec![singleton.clone()], 106 | }; 107 | let enc = base64::encode(&batch.encode_to_vec()); 108 | let dec = base64_to_batch_vec(&enc); 109 | assert_eq!(batch, dec); 110 | 111 | let singleton = SingletonStr { 112 | v: "test".to_string(), 113 
| map: std::collections::HashMap::new(), 114 | }; 115 | let batch = BatchStr { 116 | s: vec![singleton.clone()], 117 | }; 118 | let enc = base64::encode(&batch.encode_to_vec()); 119 | let dec = base64_to_batch_str(&enc); 120 | assert_eq!(batch, dec); 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /dria_hnsw/src/db/env.rs: -------------------------------------------------------------------------------- 1 | use serde::Deserialize; 2 | use std::env; 3 | 4 | #[derive(Debug, Deserialize)] 5 | pub struct Config { 6 | env: String, 7 | debug: bool, 8 | pk_length: u32, 9 | sk_length: u32, 10 | rate_limit: String, 11 | global_rate_limit: String, 12 | logging_level: String, 13 | pub contract_id: String, 14 | pub redis_url: String, 15 | pub port: String, 16 | pub rocksdb_path: String, 17 | } 18 | 19 | impl Config { 20 | pub fn new() -> Config { 21 | let rocksdb_path = match env::var("ROCKSDB_PATH") { 22 | Ok(val) => val, 23 | Err(_) => "/tmp/rocksdb".to_string(), 24 | }; 25 | 26 | let port = match env::var("PORT") { 27 | Ok(val) => val, 28 | Err(_) => "8080".to_string(), 29 | }; 30 | 31 | let contract_id = match env::var("CONTRACT_ID") { 32 | Ok(val) => val, 33 | Err(_) => { 34 | println!("CONTRACT_ID not found, using default"); 35 | "default".to_string() 36 | } 37 | }; 38 | 39 | Config { 40 | env: "development".to_string(), 41 | debug: true, 42 | pk_length: 0, 43 | sk_length: 0, 44 | rate_limit: "".to_string(), 45 | global_rate_limit: "".to_string(), 46 | logging_level: "DEBUG".to_string(), 47 | contract_id, 48 | redis_url: "redis://127.0.0.1/".to_string(), 49 | port, 50 | rocksdb_path, 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /dria_hnsw/src/db/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod conversions; 2 | pub mod env; 3 | pub mod redis_client; 4 | pub mod rocksdb_client; 5 | -------------------------------------------------------------------------------- /dria_hnsw/src/db/redis_client.rs: -------------------------------------------------------------------------------- 1 | extern crate redis; 2 | use redis::{Client, Commands, Connection}; 3 | 4 | use crate::db::conversions::{base64_to_node, node_to_base64, point_to_base64}; 5 | use crate::db::env::Config; 6 | use crate::errors::errors::DeserializeError; 7 | use crate::proto::index_buffer::{LayerNode, Point}; 8 | use prost::Message; 9 | use serde_json::Value; 10 | 11 | pub struct RedisClient { 12 | client: Client, 13 | connection: Connection, 14 | tag: String, 15 | } 16 | 17 | impl RedisClient { 18 | pub fn new(contract_id: String) -> Result { 19 | let cfg = Config::new(); 20 | let client = 21 | Client::open(cfg.redis_url).map_err(|_| DeserializeError::RedisConnectionError)?; 22 | let connection = client 23 | .get_connection() 24 | .map_err(|_| DeserializeError::RedisConnectionError)?; 25 | 26 | Ok(RedisClient { 27 | client, 28 | connection, 29 | tag: contract_id, 30 | }) 31 | } 32 | 33 | pub fn set(&mut self, key: String, value: String) -> Result<(), DeserializeError> { 34 | let keys_local = format!("{}.value.{}", self.tag, key); 35 | let _: () = self 36 | .connection 37 | .set(&keys_local, &value) 38 | .map_err(|_| DeserializeError::RedisConnectionError)?; 39 | Ok(()) 40 | } 41 | 42 | pub fn get_neighbor( 43 | &mut self, 44 | layer: usize, 45 | idx: usize, 46 | ) -> Result { 47 | let key = format!( 48 | "{}.value.{}:{}", 49 | self.tag, 50 | layer.to_string(), 51 | 
idx.to_string() 52 | ); 53 | 54 | let node_str = match self 55 | .connection 56 | .get::<_, String>(key) 57 | .map_err(|_| DeserializeError::RedisConnectionError) 58 | { 59 | Ok(node_str) => node_str, 60 | Err(_) => { 61 | return Err(DeserializeError::RedisConnectionError); 62 | } 63 | }; 64 | 65 | //let node_:Value = serde_json::from_str(&node_str).unwrap(); 66 | Ok(base64_to_node(&node_str)) 67 | } 68 | 69 | pub fn get_neighbors( 70 | &mut self, 71 | layer: usize, 72 | indices: Vec, 73 | ) -> Result, DeserializeError> { 74 | let keys = indices 75 | .iter() 76 | .map(|x| format!("{}.value.{}:{}", self.tag, layer.to_string(), x.to_string())) 77 | .collect::>(); 78 | 79 | let values_str: Vec = self 80 | .connection 81 | .mget(keys) 82 | .map_err(|_| DeserializeError::RedisConnectionError)?; 83 | 84 | let neighbors = values_str 85 | .iter() 86 | .map(|s| { 87 | let bytes = base64::decode(s).unwrap(); 88 | let p = LayerNode::decode(bytes.as_slice()).unwrap(); // Deserialize 89 | Ok(p) 90 | }) 91 | .collect::, DeserializeError>>()?; 92 | 93 | Ok(neighbors) 94 | } 95 | 96 | pub fn upsert_neighbor(&mut self, node: LayerNode) -> Result<(), DeserializeError> { 97 | let key = format!("{}.value.{}:{}", self.tag, node.level, node.idx); 98 | 99 | let node_str = node_to_base64(&node); 100 | self.set(key, node_str)?; 101 | 102 | Ok(()) 103 | } 104 | 105 | pub fn upsert_neighbors(&mut self, nodes: Vec) -> Result<(), DeserializeError> { 106 | let mut pairs = Vec::new(); 107 | for node in nodes { 108 | let key = format!("{}.value.{}:{}", self.tag, node.level, node.idx); 109 | let node_str = node_to_base64(&node); 110 | pairs.push((key, node_str)); 111 | } 112 | 113 | let _ = self 114 | .connection 115 | .mset(pairs.as_slice()) 116 | .map_err(|_| DeserializeError::RedisConnectionError)?; 117 | 118 | Ok(()) 119 | } 120 | 121 | pub fn get_points(&mut self, indices: &Vec) -> Result, DeserializeError> { 122 | let keys = indices 123 | .iter() 124 | .map(|x| format!("{}.value.{}", self.tag, x)) 125 | .collect::>(); 126 | 127 | if keys.is_empty() { 128 | return Ok(vec![]); 129 | } 130 | let values_str: Vec = self 131 | .connection 132 | .mget(keys) 133 | .map_err(|_| DeserializeError::RedisConnectionError)?; 134 | 135 | let points = values_str 136 | .into_iter() 137 | .map(|s| { 138 | let bytes = base64::decode(s).unwrap(); 139 | let p = Point::decode(bytes.as_slice()).unwrap(); // Deserialize 140 | Ok(p) 141 | }) 142 | .collect::, DeserializeError>>()?; 143 | Ok(points) 144 | } 145 | 146 | pub fn add_points_batch( 147 | &mut self, 148 | v: &Vec>, 149 | start_idx: usize, 150 | ) -> Result<(), DeserializeError> { 151 | let mut pairs = Vec::new(); 152 | 153 | for (i, p) in v.iter().enumerate() { 154 | let idx = start_idx + i; 155 | let p = Point::new(p.clone(), idx); 156 | let p_str = point_to_base64(&p); 157 | pairs.push((format!("{}.value.{}", self.tag, idx), p_str)); 158 | } 159 | 160 | let _ = self 161 | .connection 162 | .mset(pairs.as_slice()) 163 | .map_err(|_| DeserializeError::RedisConnectionError)?; 164 | Ok(()) 165 | } 166 | 167 | pub fn add_points(&mut self, v: Vec, idx: usize) -> Result<(), DeserializeError> { 168 | let p = Point::new(v, idx); 169 | let p_str = point_to_base64(&p); 170 | 171 | self.set(format!("{}.value.{}", self.tag, idx), p_str)?; 172 | Ok(()) 173 | } 174 | 175 | pub fn set_datasize(&mut self, datasize: usize) -> Result<(), DeserializeError> { 176 | let _: () = self 177 | .connection 178 | .set(format!("{}.value.datasize", self.tag), datasize.to_string()) 179 | .map_err(|_| 
DeserializeError::RedisConnectionError)?; 180 | Ok(()) 181 | } 182 | 183 | pub fn get_datasize(&mut self) -> Result { 184 | let datasize: String = self 185 | .connection 186 | .get(format!("{}.value.datasize", self.tag)) 187 | .map_err(|_| DeserializeError::RedisConnectionError)?; 188 | let datasize = datasize.parse::().unwrap(); 189 | Ok(datasize) 190 | } 191 | 192 | pub fn set_num_layers(&mut self, num_layers: usize) -> Result<(), DeserializeError> { 193 | let _: () = self 194 | .connection 195 | .set( 196 | format!("{}.value.num_layers", self.tag), 197 | num_layers.to_string(), 198 | ) 199 | .map_err(|_| DeserializeError::RedisConnectionError)?; 200 | Ok(()) 201 | } 202 | 203 | pub fn get_num_layers(&mut self) -> Result { 204 | let num_layers: String = self 205 | .connection 206 | .get(format!("{}.value.num_layers", self.tag)) 207 | .map_err(|_| DeserializeError::RedisConnectionError)?; 208 | let num_layers = num_layers.parse::().unwrap(); 209 | Ok(num_layers) 210 | } 211 | 212 | pub fn set_ep(&mut self, ep: usize) -> Result<(), DeserializeError> { 213 | let _: () = self 214 | .connection 215 | .set(format!("{}.value.ep", self.tag), ep.to_string()) 216 | .map_err(|_| DeserializeError::RedisConnectionError)?; 217 | Ok(()) 218 | } 219 | 220 | pub fn get_ep(&mut self) -> Result { 221 | let ep: String = self 222 | .connection 223 | .get(format!("{}.value.ep", self.tag)) 224 | .map_err(|_| DeserializeError::RedisConnectionError)?; 225 | let ep = ep.parse::().unwrap(); 226 | Ok(ep) 227 | } 228 | 229 | pub fn set_metadata_batch( 230 | &mut self, 231 | metadata: Vec, 232 | idx: usize, 233 | ) -> Result<(), DeserializeError> { 234 | let mut pairs = Vec::new(); 235 | for (i, m) in metadata.iter().enumerate() { 236 | let key = format!("{}.value.m:{}", self.tag, idx + i); 237 | let metadata_str = serde_json::to_string(&m).unwrap(); 238 | pairs.push((key, metadata_str)); 239 | } 240 | let _ = self 241 | .connection 242 | .mset(pairs.as_slice()) 243 | .map_err(|_| DeserializeError::RedisConnectionError)?; 244 | Ok(()) 245 | } 246 | 247 | pub fn set_metadata(&mut self, metadata: Value, idx: usize) -> Result<(), DeserializeError> { 248 | let key = format!("{}.value.m:{}", self.tag, idx); 249 | let metadata_str = serde_json::to_string(&metadata).unwrap(); 250 | self.set(key, metadata_str)?; 251 | Ok(()) 252 | } 253 | 254 | pub fn get_metadata(&mut self, idx: usize) -> Result { 255 | let key = format!("{}.value.m:{}", self.tag, idx); 256 | let metadata_str: String = self 257 | .connection 258 | .get(&key) 259 | .map_err(|_| DeserializeError::RedisConnectionError)?; 260 | let metadata: Value = serde_json::from_str(&metadata_str).unwrap(); 261 | Ok(metadata) 262 | } 263 | 264 | pub fn get_metadatas(&mut self, indices: Vec) -> Result, DeserializeError> { 265 | let keys = indices 266 | .iter() 267 | .map(|x| format!("{}.value.m:{}", self.tag, x)) 268 | .collect::>(); 269 | 270 | let metadata_str: Vec = self 271 | .connection 272 | .mget(&keys) 273 | .map_err(|_| DeserializeError::RedisConnectionError)?; 274 | 275 | let metadata = metadata_str 276 | .into_iter() 277 | .map(|s| { 278 | let m: Value = serde_json::from_str(&s).unwrap(); 279 | Ok(m) 280 | }) 281 | .collect::, DeserializeError>>()?; 282 | 283 | Ok(metadata) 284 | } 285 | } 286 | -------------------------------------------------------------------------------- /dria_hnsw/src/db/rocksdb_client.rs: -------------------------------------------------------------------------------- 1 | use crate::db::conversions::{base64_to_node, base64_to_point, 
node_to_base64, point_to_base64}; 2 | use crate::db::env::Config; 3 | use crate::errors::errors::DeserializeError; 4 | use crate::proto::index_buffer::{LayerNode, Point}; 5 | use prost::Message; 6 | use rocksdb; 7 | use rocksdb::{DBWithThreadMode, Options, WriteBatch, DB}; 8 | use serde_json::Value; 9 | 10 | #[derive(Debug)] 11 | pub struct RocksdbClient { 12 | tag: String, 13 | client: DB, 14 | } 15 | 16 | impl RocksdbClient { 17 | pub fn new(contract_id: String) -> Result { 18 | let cfg = Config::new(); 19 | // Create a new database options instance. 20 | let mut opts = Options::default(); 21 | opts.create_if_missing(true); // Creates a database if it does not exist. 22 | //let x = DBWithThreadMode::open(&opts, cfg.rocksdb_path); 23 | let db = DB::open(&opts, cfg.rocksdb_path).unwrap(); 24 | 25 | Ok(RocksdbClient { 26 | tag: contract_id, 27 | client: db, 28 | }) 29 | } 30 | 31 | pub fn set(&self, key: String, value: String) -> Result<(), DeserializeError> { 32 | let _: () = self 33 | .client 34 | .put(key.as_bytes(), value.as_bytes()) 35 | .map_err(|_| DeserializeError::RedisConnectionError)?; 36 | Ok(()) 37 | } 38 | 39 | pub fn get_neighbor(&self, layer: usize, idx: usize) -> Result { 40 | let key = format!("{}.value.{}:{}", self.tag, layer, idx); 41 | 42 | let value = self 43 | .client 44 | .get(key.as_bytes()) 45 | .map_err(|_| DeserializeError::RocksDBConnectionError)?; // Handle RocksDB errors appropriately 46 | 47 | let node_str = match value { 48 | Some(value) => String::from_utf8(value).map_err(|_| DeserializeError::InvalidForm)?, // Convert bytes to String and handle UTF-8 error 49 | None => return Err(DeserializeError::MissingKey), // Handle case where key is not found 50 | }; 51 | 52 | Ok(base64_to_node(&node_str)) 53 | } 54 | 55 | pub fn get_neighbors( 56 | &self, 57 | layer: usize, 58 | indices: Vec, 59 | ) -> Result, DeserializeError> { 60 | // Collect keys as a Vec> for multi_get 61 | let keys = indices 62 | .iter() 63 | .map(|&x| format!("{}.value.{}:{}", self.tag, layer, x).into_bytes()) 64 | .collect::>>(); 65 | 66 | // Use multi_get to fetch values for all keys at once 67 | let values = self.client.multi_get(keys); 68 | 69 | let mut neighbors = Vec::new(); 70 | for value_result in values { 71 | // Correctly handle the Result>, E> for each value 72 | match value_result { 73 | Ok(Some(v)) => { 74 | let node_str = 75 | String::from_utf8(v).map_err(|_| DeserializeError::InvalidForm)?; // Convert bytes to String and handle UTF-8 error 76 | let node = base64_to_node(&node_str); // Convert String to LayerNode and handle base64 error 77 | neighbors.push(node); 78 | } 79 | Ok(None) => return Err(DeserializeError::MissingKey), // Handle case where key is not found 80 | Err(_) => return Err(DeserializeError::RocksDBConnectionError), // Handle error in fetching value 81 | } 82 | } 83 | 84 | Ok(neighbors) 85 | } 86 | 87 | pub fn upsert_neighbor(&self, node: LayerNode) -> Result<(), DeserializeError> { 88 | let key = format!("{}.value.{}:{}", self.tag, node.level, node.idx); 89 | 90 | let node_str = node_to_base64(&node); 91 | self.set(key, node_str)?; 92 | 93 | Ok(()) 94 | } 95 | 96 | pub fn upsert_neighbors(&self, nodes: Vec) -> Result<(), DeserializeError> { 97 | let mut batch = WriteBatch::default(); 98 | for node in nodes { 99 | let key = format!("{}.value.{}:{}", self.tag, node.level, node.idx); 100 | let node_str = node_to_base64(&node); 101 | batch.put(key.as_bytes(), node_str.as_bytes()); 102 | } 103 | 104 | let _ = self 105 | .client 106 | .write(batch) 107 | .map_err(|_| 
DeserializeError::RocksDBConnectionError)?; 108 | 109 | Ok(()) 110 | } 111 | 112 | pub fn get_points(&self, indices: &Vec) -> Result, DeserializeError> { 113 | let keys = indices 114 | .iter() 115 | .map(|x| format!("{}.value.{}", self.tag, x).into_bytes()) 116 | .collect::>>(); 117 | 118 | if keys.is_empty() { 119 | return Ok(vec![]); 120 | } 121 | 122 | // Assuming multi_get directly returns Vec>, E>> 123 | let values = self.client.multi_get(keys); 124 | 125 | let mut points = Vec::new(); 126 | for value_result in values { 127 | match value_result { 128 | Ok(Some(value)) => { 129 | let point_str = 130 | String::from_utf8(value).map_err(|_| DeserializeError::InvalidForm)?; // Handle UTF-8 conversion error 131 | let point = base64_to_point(&point_str); // Handle potential error from base64_to_point 132 | points.push(point); 133 | } 134 | Ok(None) => return Err(DeserializeError::MissingKey), // Key not found 135 | Err(_) => return Err(DeserializeError::RocksDBConnectionError), // Error fetching from RocksDB 136 | } 137 | } 138 | 139 | Ok(points) 140 | } 141 | 142 | pub fn add_points(&self, v: Vec, idx: usize) -> Result<(), DeserializeError> { 143 | let p = Point::new(v, idx); 144 | let p_str = point_to_base64(&p); 145 | let key = format!("{}.value.{}", self.tag, idx).into_bytes(); 146 | self.client 147 | .put(key, p_str.as_bytes()) 148 | .map_err(|_| DeserializeError::RocksDBConnectionError)?; 149 | //self.put_multi_hashtag(&[idx.to_string()], &[json!(p_str)], false)?; 150 | Ok(()) 151 | } 152 | 153 | pub fn add_points_batch( 154 | &self, 155 | v: &Vec>, 156 | start_idx: usize, 157 | ) -> Result<(), DeserializeError> { 158 | let mut batch = WriteBatch::default(); 159 | for (i, p) in v.iter().enumerate() { 160 | let idx = start_idx + i; 161 | let p = Point::new(p.clone(), idx); 162 | let p_str = point_to_base64(&p); 163 | let key = format!("{}.value.{}", self.tag, idx).into_bytes(); 164 | 165 | //keys.push(idx.to_string()); 166 | //values.push(json!(p_str)); 167 | 168 | batch.put(key, p_str.as_bytes()); 169 | } 170 | 171 | self.client 172 | .write(batch) 173 | .map_err(|_| DeserializeError::RocksDBConnectionError)?; 174 | 175 | Ok(()) 176 | } 177 | 178 | pub fn set_datasize(&self, datasize: usize) -> Result<(), DeserializeError> { 179 | //self.put_multi_hashtag(&["datasize".to_string()], &[json!(datasize)], false)?; 180 | self.client 181 | .put( 182 | format!("{}.value.datasize", self.tag).into_bytes(), 183 | datasize.to_string().as_bytes(), 184 | ) 185 | .map_err(|_| DeserializeError::RocksDBConnectionError)?; 186 | Ok(()) 187 | } 188 | 189 | pub fn get_datasize(&self) -> Result { 190 | let datasize_key: String = format!("{}.value.datasize", self.tag); 191 | let value = self 192 | .client 193 | .get(datasize_key.as_bytes()) 194 | .map_err(|_| DeserializeError::RocksDBConnectionError)?; 195 | 196 | let datasize = match value { 197 | Some(value_bytes) => { 198 | let value_str = 199 | String::from_utf8(value_bytes).map_err(|_| DeserializeError::InvalidForm)?; // Handle UTF-8 error gracefully 200 | value_str 201 | .parse::() 202 | .map_err(|_| DeserializeError::InvalidForm)? 
// Handle parse error gracefully 203 | } 204 | None => return Err(DeserializeError::MissingKey), // Handle case where key is not found 205 | }; 206 | Ok(datasize) 207 | } 208 | 209 | pub fn get_num_layers(&self) -> Result { 210 | let num_layers_key: String = format!("{}.value.num_layers", self.tag); 211 | let value = self 212 | .client 213 | .get(num_layers_key.as_bytes()) 214 | .map_err(|_| DeserializeError::RocksDBConnectionError)?; 215 | 216 | let num_layers = match value { 217 | Some(value_bytes) => { 218 | let value_str = 219 | String::from_utf8(value_bytes).map_err(|_| DeserializeError::InvalidForm)?; // Handle UTF-8 error gracefully 220 | value_str 221 | .parse::() 222 | .map_err(|_| DeserializeError::InvalidForm)? // Handle parse error gracefully 223 | } 224 | None => return Err(DeserializeError::MissingKey), // Handle case where key is not found 225 | }; 226 | Ok(num_layers) 227 | } 228 | 229 | pub fn set_num_layers(&self, num_layers: usize, expire: bool) -> Result<(), DeserializeError> { 230 | self.client 231 | .put( 232 | format!("{}.value.num_layers", self.tag).into_bytes(), 233 | num_layers.to_string().as_bytes(), 234 | ) 235 | .map_err(|_| DeserializeError::RocksDBConnectionError)?; 236 | Ok(()) 237 | } 238 | 239 | pub fn set_ep(&self, ep: usize, expire: bool) -> Result<(), DeserializeError> { 240 | self.client 241 | .put( 242 | format!("{}.value.ep", self.tag).into_bytes(), 243 | ep.to_string().as_bytes(), 244 | ) 245 | .map_err(|_| DeserializeError::RocksDBConnectionError)?; 246 | Ok(()) 247 | } 248 | 249 | pub fn get_ep(&self) -> Result { 250 | let ep_key: String = format!("{}.value.ep", self.tag); 251 | let value = self 252 | .client 253 | .get(ep_key.as_bytes()) 254 | .map_err(|_| DeserializeError::RocksDBConnectionError)?; 255 | 256 | // Attempt to convert the fetched value from bytes to String, then parse it into usize 257 | let ep_usize = match value { 258 | Some(value_bytes) => { 259 | // Convert bytes to String 260 | let value_str = 261 | String::from_utf8(value_bytes).map_err(|_| DeserializeError::InvalidForm)?; // Handle UTF-8 error gracefully 262 | // Parse String to usize 263 | value_str 264 | .parse::() 265 | .map_err(|_| DeserializeError::InvalidForm)? 
// Handle parse error gracefully 266 | } 267 | None => return Err(DeserializeError::MissingKey), // Handle case where key is not found 268 | }; 269 | Ok(ep_usize) 270 | } 271 | 272 | pub fn set_metadata(&self, metadata: Value, idx: usize) -> Result<(), DeserializeError> { 273 | let key = format!("{}.value.m:{}", self.tag, idx); 274 | let metadata_str = serde_json::to_vec(&metadata).unwrap(); 275 | self.client 276 | .put(key.as_bytes(), metadata_str) 277 | .map_err(|_| DeserializeError::RocksDBConnectionError)?; 278 | Ok(()) 279 | } 280 | 281 | pub fn set_metadata_batch( 282 | &self, 283 | metadata: Vec, 284 | idx: usize, 285 | ) -> Result<(), DeserializeError> { 286 | let mut batch = WriteBatch::default(); 287 | 288 | for (i, m) in metadata.iter().enumerate() { 289 | let key = format!("{}.value.m:{}", self.tag, idx + i); 290 | let metadata_str = serde_json::to_vec(&m).unwrap(); 291 | batch.put(key.as_bytes(), metadata_str); 292 | } 293 | self.client 294 | .write(batch) 295 | .map_err(|_| DeserializeError::RocksDBConnectionError)?; 296 | Ok(()) 297 | } 298 | 299 | pub fn get_metadata(&self, idx: usize) -> Result { 300 | let key = format!("{}.value.m:{}", self.tag, idx); 301 | 302 | let value = self 303 | .client 304 | .get(key.as_bytes()) 305 | .map_err(|_| DeserializeError::RocksDBConnectionError)?; 306 | 307 | let metadata = match value { 308 | Some(value) => serde_json::from_slice(&value).unwrap() 309 | , // Convert bytes to String and handle UTF-8 error 310 | None => return Err(DeserializeError::MissingKey), // Handle case where key is not found 311 | }; 312 | Ok(metadata) 313 | } 314 | 315 | pub fn get_metadatas(&self, indices: Vec) -> Result, DeserializeError> { 316 | let keys = indices 317 | .iter() 318 | .map(|x| format!("{}.value.m:{}", self.tag, x).as_bytes().to_vec()) 319 | .collect::>>(); 320 | 321 | // Assuming multi_get returns Vec>, E>> directly 322 | let values = self.client.multi_get(&keys); 323 | 324 | let mut metadata = Vec::new(); 325 | for value_result in values { 326 | match value_result { 327 | Ok(Some(v)) => { 328 | // Properly handle potential serde_json deserialization errors 329 | match serde_json::from_slice::(&v) { 330 | Ok(meta) => metadata.push(meta), 331 | Err(_) => return Err(DeserializeError::InvalidForm), // Add a DeserializationError variant if not already present 332 | } 333 | } 334 | Ok(None) => return Err(DeserializeError::MissingKey), // Key not found 335 | Err(_) => return Err(DeserializeError::RocksDBConnectionError), // Error fetching from RocksDB 336 | } 337 | } 338 | 339 | Ok(metadata) 340 | } 341 | } 342 | -------------------------------------------------------------------------------- /dria_hnsw/src/errors/errors.rs: -------------------------------------------------------------------------------- 1 | use actix_web::{ 2 | error::ResponseError, 3 | http::{header::ContentType, StatusCode}, 4 | HttpResponse, 5 | }; 6 | use derive_more::{Display, Error}; 7 | use std::fmt; 8 | 9 | #[derive(Debug)] 10 | pub enum DeserializeError { 11 | MissingKey, 12 | InvalidForm, 13 | RocksDBConnectionError, 14 | RedisConnectionError, 15 | DNSResolverError, 16 | ClusterConnectionError, 17 | } 18 | 19 | impl fmt::Display for DeserializeError { 20 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 21 | match *self { 22 | DeserializeError::MissingKey => write!(f, "Key is missing in the response"), 23 | DeserializeError::InvalidForm => write!(f, "Value is not in the expected format"), 24 | DeserializeError::RocksDBConnectionError => write!(f, "Error connecting to 
RocksDB"), 25 | DeserializeError::RedisConnectionError => write!(f, "Error connecting to Redis"), 26 | DeserializeError::DNSResolverError => write!(f, "Error resolving DNS"), 27 | DeserializeError::ClusterConnectionError => { 28 | write!(f, "Error connecting to Cluster at init") 29 | } 30 | } 31 | } 32 | } 33 | 34 | impl std::error::Error for DeserializeError {} 35 | 36 | #[derive(Debug, Display, Error)] 37 | pub enum MiddlewareError { 38 | #[display(fmt = "internal error")] 39 | InternalError, 40 | 41 | #[display(fmt = "api key not found")] 42 | APIKeyError, 43 | 44 | #[display(fmt = "timeout")] 45 | Timeout, 46 | } 47 | 48 | impl ResponseError for MiddlewareError { 49 | fn error_response(&self) -> HttpResponse { 50 | HttpResponse::build(self.status_code()) 51 | .insert_header(ContentType::html()) 52 | .body(self.to_string()) 53 | } 54 | 55 | fn status_code(&self) -> StatusCode { 56 | match *self { 57 | MiddlewareError::InternalError => StatusCode::INTERNAL_SERVER_ERROR, 58 | MiddlewareError::APIKeyError => StatusCode::UNAUTHORIZED, 59 | MiddlewareError::Timeout => StatusCode::GATEWAY_TIMEOUT, 60 | } 61 | } 62 | } 63 | 64 | #[derive(Debug)] 65 | pub struct ValidationError(pub String); 66 | 67 | impl fmt::Display for ValidationError { 68 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 69 | write!(f, "{}", self.0) 70 | } 71 | } 72 | 73 | impl std::error::Error for ValidationError {} 74 | -------------------------------------------------------------------------------- /dria_hnsw/src/errors/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod errors; 2 | -------------------------------------------------------------------------------- /dria_hnsw/src/filter/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod text_based; -------------------------------------------------------------------------------- /dria_hnsw/src/filter/text_based.rs: -------------------------------------------------------------------------------- 1 | use probly_search::score::ScoreCalculator; 2 | use probly_search::score::{bm25, zero_to_one}; 3 | use probly_search::Index; 4 | use serde_json::{json, Value}; 5 | use std::borrow::Cow; 6 | 7 | pub struct Doc { 8 | pub id: usize, 9 | pub text: String, 10 | } 11 | 12 | fn tokenizer(s: &str) -> Vec> { 13 | s.split(' ').map(Cow::from).collect::>() 14 | } 15 | 16 | fn text_extract(d: &Doc) -> Vec<&str> { 17 | vec![d.text.as_str()] 18 | } 19 | 20 | pub fn create_index_from_docs( 21 | index: &mut Index, 22 | query: &str, 23 | metadata: Vec, 24 | ) -> Vec { 25 | let mut wikis = Vec::new(); 26 | let mut query_results = Vec::new(); 27 | let mut ids = Vec::new(); 28 | let mut iter = 0; 29 | 30 | for value in metadata.iter() { 31 | let mut text = value["metadata"]["text"].as_str(); 32 | if text.is_none() { 33 | // use whole metadata as text 34 | text = value["metadata"].as_str(); 35 | } 36 | 37 | let id_doc = value["id"].as_u64().unwrap() as usize; 38 | // let url = value["metadata"]["url"].as_str(); 39 | 40 | let t = text.unwrap().to_string(); 41 | 42 | let sentences = t.split("."); 43 | for sentence in sentences { 44 | let wiki = Doc { 45 | id: iter, 46 | text: sentence.to_string(), 47 | }; 48 | wikis.push(wiki); 49 | 50 | let mut value_x = value.clone(); 51 | value_x["metadata"]["text"] = json!(sentence.to_string()); 52 | query_results.push(value_x); 53 | ids.push(id_doc); 54 | iter += 1; 55 | } 56 | } 57 | if wikis.len() == 0 { 58 | return metadata; 59 | } 60 | 61 | for wiki in 
wikis.iter() { 62 | index.add_document(&[text_extract], tokenizer, wiki.id.clone(), &wiki); 63 | } 64 | 65 | let results = index.query(query, &mut zero_to_one::new(), tokenizer, &[1.]); 66 | let mut results_as_wiki = vec![]; 67 | for res in results.iter() { 68 | let val = json!({"id": ids[res.key], "metadata": query_results[res.key].clone(), "score": res.score}); 69 | results_as_wiki.push(val); 70 | } 71 | return results_as_wiki; 72 | } 73 | -------------------------------------------------------------------------------- /dria_hnsw/src/hnsw/index.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | 3 | extern crate redis; 4 | 5 | use actix_web::web::Data; 6 | use hashbrown::HashSet; 7 | use mini_moka::sync::Cache; 8 | use redis::Commands; 9 | use std::borrow::Borrow; 10 | use std::cmp::Reverse; 11 | use std::collections::HashMap; 12 | use std::sync::atomic::{AtomicIsize, AtomicUsize, Ordering}; 13 | use std::sync::Arc; 14 | 15 | use simsimd::SimSIMD; 16 | 17 | use rand::{thread_rng, Rng, SeedableRng}; 18 | 19 | use crate::proto::index_buffer::{LayerNode, Point}; 20 | use prost::Message; 21 | 22 | use crate::errors::errors::DeserializeError; 23 | use crate::hnsw::utils::{create_max_heap, create_min_heap, IntoHeap, IntoMap, Numeric}; 24 | 25 | use crate::hnsw::scalar::ScalarQuantizer; 26 | use rayon::prelude::*; 27 | use serde_json::{json, Value}; 28 | 29 | use crate::db::rocksdb_client::RocksdbClient; 30 | use crate::hnsw::sync_map::SynchronizedNodes; 31 | 32 | pub const SINGLE_THREADED_HNSW_BUILD_THRESHOLD: usize = 256; 33 | 34 | /* 35 | Redis Scheme 36 | 37 | Points: "0", "1" ... 38 | Graph: graph_level.idx : "2.5320" layer 2, node idx 5320 39 | */ 40 | pub struct HNSW { 41 | pub m: usize, 42 | pub m_max0: usize, 43 | pub rng_seed: u64, 44 | pub ml: f32, 45 | pub ef_construction: usize, 46 | pub ef: usize, 47 | pub db: Data, 48 | quantizer: ScalarQuantizer, 49 | metric: Option, 50 | } 51 | 52 | impl HNSW { 53 | pub fn new( 54 | M: usize, 55 | ef_construction: usize, 56 | ef: usize, 57 | //contract_id: String, 58 | metric: Option, 59 | db: Data, 60 | ) -> HNSW { 61 | let m = M; 62 | let m_max0 = M * 2; 63 | let ml = 1.0 / (M as f32).ln(); 64 | //let db = RocksdbClient::new(contract_id).expect("Error creating RocksdbClient"); 65 | let sq = ScalarQuantizer::new(256, 1000, 256); 66 | 67 | HNSW { 68 | m, 69 | m_max0, 70 | rng_seed: 0, 71 | ml, 72 | ef_construction, 73 | ef, 74 | db, 75 | quantizer: sq, 76 | metric, 77 | } 78 | } 79 | 80 | pub fn set_rng_seed(&mut self, seed: u64) { 81 | self.rng_seed = seed; 82 | } 83 | 84 | pub fn set_ef(&mut self, ef: usize) { 85 | self.ef = ef; 86 | } 87 | 88 | pub fn select_layer(&self) -> usize { 89 | let mut random = thread_rng(); 90 | let rand_float: f32 = random.gen_range(1e-6..1.0); // Avoid very small values 91 | let result = (-1.0 * rand_float.ln() * self.ml) as usize; 92 | 93 | // Optionally clamp to a maximum value if applicable 94 | let max_layer = 1000; // Example maximum layer 95 | std::cmp::min(result, max_layer) 96 | } 97 | 98 | fn distance(&self, x: &[f32], y: &[f32], dist: &Option) -> f32 { 99 | let dist = match dist.as_ref().map(String::as_str) { 100 | Some("sqeuclidean") => SimSIMD::sqeuclidean(x, y), 101 | Some("inner") => SimSIMD::inner(x, y), 102 | Some("cosine") | None => SimSIMD::cosine(x, y), 103 | _ => panic!("Unsupported distance metric"), 104 | }; 105 | if dist.is_none() { 106 | println!("Error in distance"); //make the error propagate 107 | } 108 | dist.unwrap() 
109 | } 110 | 111 | fn get_points_w_memory( 112 | &self, 113 | indices: &Vec, 114 | point_map: Cache, 115 | ) -> Vec { 116 | // Initialize points with None to reserve the space and maintain order 117 | let mut points: Vec> = vec![None; indices.len()]; 118 | 119 | // Track missing indices and their positions 120 | let mut missing_indices_with_pos: Vec<(usize, u32)> = Vec::new(); 121 | 122 | for (pos, idx) in indices.iter().enumerate() { 123 | let key = format!("p:{}", idx); 124 | if let Some(point) = point_map.get(&key) { 125 | points[pos] = Some(point.clone()); 126 | } else { 127 | missing_indices_with_pos.push((pos, *idx)); 128 | } 129 | } 130 | 131 | if !missing_indices_with_pos.is_empty() { 132 | let missing_indices: Vec = missing_indices_with_pos 133 | .iter() 134 | .map(|&(_, idx)| idx) 135 | .collect(); 136 | 137 | let fetched_points = self.db.get_points(&missing_indices); 138 | 139 | if fetched_points.is_err() { 140 | println!( 141 | "Error getting points, get points _w _memory {:?}", 142 | &missing_indices 143 | ); 144 | } 145 | 146 | let fetched_points = fetched_points.unwrap(); 147 | 148 | for point in fetched_points { 149 | let key = format!("p:{}", point.idx); 150 | point_map.insert(key, point.clone()); 151 | 152 | if let Some(&(pos, _)) = missing_indices_with_pos 153 | .iter() 154 | .find(|&&(_, idx)| idx == point.idx) 155 | { 156 | points[pos] = Some(point); 157 | } 158 | } 159 | } 160 | points.into_iter().filter_map(|p| p).collect() 161 | } 162 | 163 | fn get_neighbors_w_memory( 164 | &self, 165 | layer: usize, 166 | indices: &Vec, 167 | node_map: Arc, 168 | ) -> Vec { 169 | let mut nodes = Vec::with_capacity(indices.len()); 170 | let mut missing_indices = Vec::new(); 171 | 172 | // First pass: Fill in nodes from node_map or mark as missing 173 | for &idx in indices { 174 | let key = format!("{}:{}", layer, idx); 175 | match node_map.get_or_wait_opt(&key) { 176 | Some(node) => nodes.push(node.clone()), 177 | None => { 178 | missing_indices.push(idx); 179 | nodes.push(LayerNode::new(0, 0)); // Placeholder for missing nodes 180 | } 181 | } 182 | } 183 | if !missing_indices.is_empty() { 184 | let fetched_nodes = self 185 | .db 186 | .get_neighbors(layer, missing_indices) 187 | .expect("Error getting neighbors"); 188 | 189 | for fetched_node in fetched_nodes.iter() { 190 | let index = indices.iter().position(|&i| i == fetched_node.idx).unwrap(); 191 | nodes[index] = fetched_node.clone(); 192 | } 193 | node_map.insert_batch_and_notify(fetched_nodes); 194 | } 195 | 196 | nodes 197 | } 198 | 199 | fn get_neighbor_w_memory( 200 | &self, 201 | layer: usize, 202 | idx: usize, 203 | node_map: Arc, 204 | ) -> LayerNode { 205 | let key = format!("{}:{}", layer, idx); 206 | 207 | let node_option = node_map.get_or_wait_opt(&key); 208 | return if let Some(node) = node_option { 209 | node.clone() 210 | } else { 211 | let node_ = self.db.get_neighbor(layer, idx); 212 | if node_.is_err() { 213 | println!("Sync issue, awaiting notification..."); 214 | let value = node_map.get_or_wait(&key); 215 | return value.clone(); 216 | } 217 | let node_ = node_.unwrap(); 218 | node_map.insert_and_notify(&node_); 219 | node_ 220 | }; 221 | } 222 | 223 | pub fn insert_w_preset( 224 | &self, 225 | idx: usize, 226 | node_map: Arc, 227 | point_map: Cache, 228 | nl: Arc, 229 | epa: Arc, 230 | ) -> Result<(), DeserializeError> { 231 | let mut W = HashMap::new(); 232 | 233 | let mut ep_index = None; 234 | let ep_index_ = epa.load(Ordering::SeqCst); 235 | if ep_index_ != -1 { 236 | ep_index = Some(ep_index_ as 
u32); 237 | } 238 | 239 | let mut num_layers = nl.load(Ordering::Relaxed); 240 | 241 | let L = if num_layers == 0 { 0 } else { num_layers - 1 }; 242 | let l = self.select_layer(); 243 | 244 | let qs = self.get_points_w_memory(&vec![idx as u32], point_map.clone()); 245 | let q = qs[0].v.clone(); 246 | 247 | if ep_index.is_some() { 248 | let ep_index_ = ep_index.unwrap(); 249 | 250 | let points = self.get_points_w_memory(&vec![ep_index_], point_map.clone()); 251 | let point = points.first().unwrap(); 252 | let dist = self.distance(&q, &point.v, &self.metric); 253 | let mut ep = HashMap::from([(ep_index_, dist)]); 254 | 255 | for i in ((l + 1)..=L).rev() { 256 | W = self.search_layer(&q, ep.clone(), 1, i, node_map.clone(), point_map.clone())?; 257 | 258 | if let Some((_, value)) = W.iter().next() { 259 | if &dist < value { 260 | ep = W; 261 | } 262 | } 263 | } 264 | 265 | for l_c in (0..=std::cmp::min(L, l)).rev() { 266 | W = self.search_layer( 267 | &q, 268 | ep, 269 | self.ef_construction, 270 | l_c, 271 | node_map.clone(), 272 | point_map.clone(), 273 | )?; 274 | 275 | //upsert expire = true by default, populate upserted_keys for replication 276 | node_map.insert_and_notify(&LayerNode::new(l_c, idx)); 277 | 278 | ep = W.clone(); 279 | let neighbors = self.select_neighbors(&q, W, l_c, true); 280 | 281 | let M = if l_c == 0 { self.m_max0 } else { self.m }; 282 | 283 | //read neighbors of all nodes in selected neighbors 284 | let mut indices = neighbors.iter().map(|x| *x.0).collect::>(); 285 | indices.push(idx as u32); 286 | let idx_i = indices.len() - 1; 287 | 288 | let mut nodes = self.get_neighbors_w_memory(l_c, &indices, node_map.clone()); 289 | 290 | for (i, (e_i, dist)) in neighbors.iter().enumerate() { 291 | if i == idx_i { 292 | // We want to skip last layernode, which is idx -> layernode 293 | continue; 294 | } 295 | 296 | nodes[i].neighbors.insert(idx as u32, *dist); 297 | nodes[idx_i].neighbors.insert(*e_i, *dist); 298 | } 299 | 300 | // TODO: remove redundant 301 | for (i, (e_i, dist)) in neighbors.iter().enumerate() { 302 | if i == idx_i { 303 | // We want to skip last layernode, which is idx -> layernode 304 | continue; 305 | } 306 | let eConn = nodes[i].neighbors.clone(); 307 | if eConn.len() > M { 308 | let eNewConn = self.select_neighbors(&q, eConn, l_c, true); 309 | nodes[i].neighbors = eNewConn.clone(); 310 | } 311 | } 312 | 313 | node_map.insert_batch_and_notify(nodes); 314 | } 315 | } 316 | 317 | for i in num_layers..=l { 318 | node_map.insert_and_notify(&LayerNode::new(i, idx)); 319 | let _ = epa.fetch_update(Ordering::SeqCst, Ordering::Relaxed, |x| Some(idx as isize)); 320 | } 321 | 322 | let _ = nl.fetch_update(Ordering::SeqCst, Ordering::Acquire, |v| { 323 | if l + 1 > v { 324 | Some(l + 1) 325 | } else { 326 | None 327 | } 328 | }); 329 | 330 | Ok(()) 331 | } 332 | 333 | fn search_layer( 334 | &self, 335 | q: &Vec, 336 | ep: HashMap, 337 | ef: usize, 338 | l_c: usize, 339 | node_map: Arc, 340 | point_map: Cache, 341 | ) -> Result, DeserializeError> { 342 | let mut v = HashSet::new(); 343 | 344 | for (k, _) in ep.iter() { 345 | v.insert(k.clone()); 346 | } 347 | 348 | let mut C = ep.clone().into_minheap(); 349 | let mut W = ep.into_maxheap(); 350 | 351 | while !C.is_empty() { 352 | let c = C.pop().unwrap().0; 353 | let f_value = W.peek().unwrap().0 .0; 354 | 355 | if c.0 .0 > f_value { 356 | break; 357 | } 358 | 359 | let layernd = self.get_neighbor_w_memory(l_c, c.1 as usize, node_map.clone()); 360 | 361 | let mut pairs: Vec<_> = 
layernd.neighbors.into_iter().collect(); 362 | //pairs.sort_by(|&(_, a), &(_, b)| a.partial_cmp(&b).unwrap()); 363 | pairs.sort_by(|&(_, a), &(_, b)| { 364 | if a.is_nan() || b.is_nan() { 365 | println!("NaN value detected: a = {}, b = {}", a, b); 366 | std::cmp::Ordering::Greater 367 | } else { 368 | a.partial_cmp(&b).unwrap_or_else(|| { 369 | println!("Unexpected comparison error: a = {}, b = {}", a, b); 370 | std::cmp::Ordering::Greater 371 | }) 372 | } 373 | }); 374 | let sorted_keys: Vec = pairs.into_iter().map(|(k, _)| k).collect(); 375 | 376 | let neighbors: Vec = sorted_keys 377 | .into_iter() 378 | .filter_map(|x| if !v.contains(&x) { Some(x) } else { None }) 379 | .collect(); 380 | 381 | let points = self.get_points_w_memory(&neighbors, point_map.clone()); 382 | 383 | let distances = points 384 | .iter() 385 | .map(|x| self.distance(&q, &x.v, &self.metric)) 386 | .collect::>(); 387 | 388 | for (i, d) in neighbors.iter().zip(distances.iter()) { 389 | v.insert(i.clone()); 390 | if d < &f_value || W.len() < ef { 391 | C.push(Reverse((Numeric(d.clone()), i.clone()))); 392 | W.push((Numeric(d.clone()), i.clone())); 393 | if W.len() > ef { 394 | W.pop(); 395 | } 396 | } 397 | } 398 | } 399 | 400 | if ef == 1 { 401 | if W.len() > 0 { 402 | let W_map = W.into_map(); 403 | let mut W_min = W_map.into_minheap(); 404 | let mut single_map = HashMap::new(); 405 | let min_val = W_min.pop().unwrap().0; 406 | single_map.insert(min_val.1, min_val.0 .0); 407 | return Ok(single_map); 408 | } else { 409 | return Ok(HashMap::new()); 410 | } 411 | } 412 | 413 | Ok(W.into_map()) 414 | } 415 | 416 | fn select_neighbors( 417 | &self, 418 | q: &Vec, 419 | C: HashMap, 420 | l_c: usize, 421 | k_p_c: bool, 422 | ) -> HashMap { 423 | let mut R = create_min_heap(); 424 | let mut W = C.into_minheap(); 425 | 426 | let mut M = 0; 427 | 428 | if l_c > 0 { 429 | M = self.m 430 | } else { 431 | M = self.m_max0; 432 | } 433 | 434 | let mut W_d = create_min_heap(); 435 | while W.len() > 0 && R.len() < M { 436 | let e = W.pop().unwrap().0; 437 | 438 | if R.len() == 0 || e.0 < R.peek().unwrap().0 .0 { 439 | R.push(Reverse(e)); 440 | } else { 441 | W_d.push(Reverse(e)); 442 | } 443 | } 444 | 445 | if k_p_c { 446 | while W_d.len() > 0 && R.len() < M { 447 | R.push(W_d.pop().unwrap()); 448 | } 449 | } 450 | R.into_map() 451 | } 452 | 453 | pub fn knn_search( 454 | &self, 455 | q: &Vec, 456 | K: usize, 457 | node_map: Arc, 458 | point_map: Cache, 459 | ) -> Vec { 460 | let mut W = HashMap::new(); 461 | 462 | let ep_index = self.db.get_ep().expect("") as u32; 463 | let num_layers = self.db.get_num_layers().expect("Error getting num_layers"); 464 | 465 | let points = self.get_points_w_memory(&vec![ep_index], point_map.clone()); 466 | let point = points.first().unwrap(); 467 | let dist = self.distance(&q, &point.v, &self.metric); 468 | let mut ep = HashMap::from([(ep_index, dist)]); 469 | 470 | for l_c in (1..=num_layers - 1).rev() { 471 | W = self 472 | .search_layer(&q, ep, 1, l_c, node_map.clone(), point_map.clone()) 473 | .expect("Error searching layer"); 474 | ep = W; 475 | } 476 | 477 | let ep_ = self 478 | .search_layer(q, ep, self.ef, 0, node_map.clone(), point_map.clone()) 479 | .expect("Error searching layer"); 480 | 481 | let mut heap = ep_.into_minheap(); 482 | let mut sorted_vec = Vec::new(); 483 | while !heap.is_empty() && sorted_vec.len() < K { 484 | let item = heap.pop().unwrap().0; 485 | sorted_vec.push((item.1, 1.0 - item.0 .0)); 486 | } 487 | let indices = sorted_vec.iter().map(|x| x.0).collect::>(); 488 | 
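// `sorted_vec` holds (id, score) pairs where score = 1.0 - distance, so with the
// default cosine metric a larger score means a closer match; the ids collected
// above are used below to fetch the stored metadata and build the JSON response.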
let metadata = self 489 | .db 490 | .get_metadatas(indices) 491 | .expect("Error getting metadatas"); 492 | 493 | let result = sorted_vec 494 | .iter() 495 | .zip(metadata.iter()) 496 | .map(|(x, y)| json!({"id":x.0, "score":x.1, "metadata":y.clone()})) 497 | .collect::>(); 498 | result 499 | } 500 | } 501 | -------------------------------------------------------------------------------- /dria_hnsw/src/hnsw/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod index; 2 | pub mod scalar; 3 | pub mod sync_map; 4 | pub mod utils; 5 | -------------------------------------------------------------------------------- /dria_hnsw/src/hnsw/scalar.rs: -------------------------------------------------------------------------------- 1 | extern crate core; 2 | extern crate serde_json; 3 | 4 | use serde::{Deserialize, Serialize}; 5 | use tdigest::TDigest; 6 | 7 | #[derive(Debug, Clone)] 8 | pub struct ScalarQuantizer { 9 | levels: usize, 10 | quantiles: Vec, 11 | t_digest: TDigest, 12 | dim: usize, 13 | } 14 | 15 | impl ScalarQuantizer { 16 | pub fn new(levels: usize, size: usize, dim: usize) -> Self { 17 | let quantiles = vec![]; 18 | let t_digest = TDigest::new_with_size(size); 19 | ScalarQuantizer { 20 | levels, 21 | quantiles, 22 | t_digest, 23 | dim, 24 | } 25 | } 26 | 27 | pub fn merge(&mut self, matrix: Vec>) { 28 | let flattened: Vec = matrix.into_iter().flatten().collect(); 29 | self.t_digest = self.t_digest.merge_unsorted(flattened); // Pass a reference to flattened 30 | self.quantiles = vec![]; //reset 31 | for i in 0..self.levels { 32 | self.quantiles.push( 33 | self.t_digest 34 | .estimate_quantile((i as f64) / (self.levels.clone() as f64)), 35 | ); 36 | } 37 | } 38 | 39 | fn __quantize_scalar(&self, scalar: &f64) -> usize { 40 | let ax = &self.quantiles; 41 | ax.into_iter() 42 | .enumerate() 43 | .find(|&(_, q)| scalar < q) 44 | .map_or(256 - 1, |(i, _)| i) 45 | } 46 | 47 | pub fn quantize_vectors(&self, vecs: Vec>) -> Vec> { 48 | let single: Vec = vecs.into_iter().flatten().collect(); 49 | let quantized = self.quantize(single.as_slice()); 50 | quantized 51 | .chunks(self.dim.clone()) 52 | .map(|chunk| chunk.to_vec()) 53 | .collect() 54 | } 55 | 56 | pub fn dequantize(&self, qv: &[usize]) -> Vec { 57 | qv.iter().map(|&val| self.quantiles[val.min(255)]).collect() 58 | } 59 | 60 | pub fn quantize(&self, v: &[f64]) -> Vec { 61 | v.iter() 62 | .map(|value| self.__quantize_scalar(value)) 63 | .collect() 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /dria_hnsw/src/hnsw/sync_map.rs: -------------------------------------------------------------------------------- 1 | use crate::proto::index_buffer::LayerNode; 2 | use crossbeam_channel::{bounded, Receiver, Sender}; 3 | use dashmap::DashMap; 4 | use hashbrown::HashSet; 5 | use std::collections::HashMap; 6 | use std::sync::Arc; 7 | //use tokio::sync::{Mutex, RwLock, RwLockWriteGuard}; 8 | use parking_lot::{Mutex, RwLock}; 9 | use std::time::Duration; 10 | 11 | //static ref RESET_SIZE = 120_000; 12 | 13 | pub struct SynchronizedNodes { 14 | pub map: Arc>, //Cache, //Arc>, 15 | pub lock_map: Arc>>, 16 | wait_map: Mutex, Receiver<()>)>>, 17 | } 18 | 19 | impl SynchronizedNodes { 20 | pub fn new() -> Self { 21 | SynchronizedNodes { 22 | map: Arc::new(DashMap::new()), //Arc::new(DashMap::new()),cache 23 | lock_map: Arc::new(DashMap::new()), 24 | wait_map: Mutex::new(HashMap::new()), 25 | } 26 | } 27 | 28 | pub fn reset(&self) { 29 | if self.map.len() > 
120_000 { 30 | self.map.clear(); 31 | } 32 | } 33 | 34 | pub fn insert_and_notify(&self, node: &LayerNode) { 35 | let key = format!("{}:{}", node.level, node.idx); 36 | 37 | { 38 | let node_lock = self 39 | .lock_map 40 | .entry(key.clone()) 41 | .or_insert_with(|| RwLock::new(())); 42 | 43 | let _write_guard = node_lock.write(); 44 | // Insert or update the node in the DashMap 45 | self.map.insert(key.clone(), node.clone()); 46 | } 47 | 48 | // Notify all waiting threads registered for this key 49 | self.notify(&key); 50 | } 51 | 52 | pub fn insert_batch_and_notify(&self, nodes: Vec) { 53 | //let mut wait_map_guard = self.wait_map.lock().unwrap(); 54 | let mut keys_to_notify = HashSet::new(); 55 | 56 | for node in nodes.iter() { 57 | let key = format!("{}:{}", node.level, node.idx); 58 | 59 | { 60 | let node_lock = self 61 | .lock_map 62 | .entry(key.clone()) 63 | .or_insert_with(|| RwLock::new(())); 64 | let _write_guard = node_lock.write(); // Lock for writing 65 | self.map.insert(key.clone(), node.clone()); 66 | } 67 | 68 | keys_to_notify.insert(key); 69 | //drop(_write_guard); 70 | } 71 | 72 | for key in keys_to_notify { 73 | self.notify(&key); 74 | } 75 | } 76 | 77 | pub fn get_or_wait(&self, key: &str) -> LayerNode { 78 | loop { 79 | // Register for notification before checking if a node is being inserted 80 | 81 | if let Some(value) = self.map.get(&key.to_string()) { 82 | return value.value().clone(); 83 | } 84 | 85 | let receiver = self.register_for_notification(key); 86 | 87 | // A secondary check 88 | if let Some(value) = self.map.get(&key.to_string()) { 89 | println!("Second check grabbed key: {}", key); 90 | return value.value().clone(); 91 | } 92 | 93 | //receiver.recv().unwrap(); // Block the thread until notification is received 94 | match receiver.recv_timeout(Duration::from_millis(500)) { 95 | Ok(_) => { /* Handle reception */ } 96 | Err(e) => { 97 | // Handle timeout or other errors 98 | eprintln!("Error or timeout waiting for message: {:?}", e); 99 | continue; // or handle differently 100 | } 101 | } 102 | } 103 | } 104 | 105 | pub fn get_or_wait_opt(&self, key: &str) -> Option { 106 | loop { 107 | if let Some(value) = self.map.get(&key.to_string()) { 108 | return Some(value.value().clone()); 109 | } 110 | 111 | // Check if this key is expected to be updated soon 112 | let receiver = { 113 | let wait_map_guard = self.wait_map.lock(); 114 | if let Some((_sender, receiver)) = wait_map_guard.get(key) { 115 | Some(receiver.clone()) // Clone the receiver 116 | } else { 117 | None // Key is not expected to be updated soon 118 | } 119 | }; 120 | 121 | if let Some(receiver) = receiver { 122 | // Wait for the notification if the key is expected to be updated 123 | match receiver.recv() { 124 | Ok(_) => { 125 | // Handle the received message 126 | } 127 | Err(e) => { 128 | // Handle the error, e.g., log it or perform a fallback action 129 | eprintln!("Error receiving message: {:?}", e); 130 | return None; 131 | } 132 | } 133 | } else { 134 | // If the key is not expected to be updated soon, return None 135 | return None; 136 | } 137 | } 138 | } 139 | 140 | pub fn register_for_notification(&self, key: &str) -> Receiver<()> { 141 | let mut wait_map_guard = self.wait_map.lock(); 142 | if !wait_map_guard.contains_key(key) { 143 | // Create a new sender/receiver pair if it doesn't exist 144 | let (sender, receiver) = bounded(10); 145 | wait_map_guard.insert(key.to_string(), (sender, receiver.clone())); 146 | receiver 147 | } else { 148 | // If it already exists, return the existing 
receiver 149 | wait_map_guard.get(key).unwrap().1.clone() 150 | } 151 | } 152 | 153 | pub fn notify(&self, key: &str) { 154 | let mut wait_map_guard = self.wait_map.lock(); 155 | if let Some((sender, _)) = wait_map_guard.remove(key) { 156 | drop(wait_map_guard); // Drop the lock before sending to avoid deadlocks 157 | let _ = sender.send(()); // It's safe to ignore the send result 158 | } 159 | } 160 | } 161 | 162 | #[cfg(test)] 163 | mod tests { 164 | // use super::*; 165 | // use crate::db::conversions::{base64_to_node, node_to_base64}; 166 | // use crate::proto::index_buffer::LayerNode; 167 | // use crate::proto::index_buffer::Point; 168 | // use std::sync::Arc; 169 | // use std::thread; 170 | // use std::time::Duration; 171 | 172 | #[test] 173 | #[ignore = "todo"] 174 | fn test_synchronized_nodes() {} 175 | } 176 | -------------------------------------------------------------------------------- /dria_hnsw/src/hnsw/utils.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::{Ordering, Reverse}; 2 | use std::collections::BinaryHeap; //node_metadata 3 | use std::collections::HashMap; 4 | 5 | #[derive(PartialEq, Debug)] 6 | pub struct Numeric(pub f32); 7 | 8 | impl Eq for Numeric {} 9 | 10 | impl PartialOrd for Numeric { 11 | fn partial_cmp(&self, other: &Self) -> Option { 12 | self.0.partial_cmp(&other.0) 13 | } 14 | } 15 | 16 | impl Ord for Numeric { 17 | fn cmp(&self, other: &Self) -> Ordering { 18 | self.0.partial_cmp(&other.0).unwrap_or(Ordering::Equal) 19 | } 20 | } 21 | 22 | pub trait IntoMap { 23 | fn into_map(self) -> HashMap; 24 | } 25 | 26 | impl IntoMap for BinaryHeap<(Numeric, u32)> { 27 | fn into_map(self) -> HashMap { 28 | self.into_iter().map(|(d, i)| (i, d.0)).collect() 29 | } 30 | } 31 | 32 | impl IntoMap for BinaryHeap> { 33 | fn into_map(self) -> HashMap { 34 | self.into_iter().map(|Reverse((d, i))| (i, d.0)).collect() 35 | } 36 | } 37 | 38 | pub trait IntoHeap { 39 | fn into_maxheap(self) -> BinaryHeap<(Numeric, u32)>; 40 | fn into_minheap(self) -> BinaryHeap>; 41 | } 42 | 43 | impl IntoHeap for HashMap { 44 | fn into_maxheap(self) -> BinaryHeap<(Numeric, u32)> { 45 | self.into_iter().map(|(i, d)| (Numeric(d), i)).collect() 46 | } 47 | 48 | fn into_minheap(self) -> BinaryHeap> { 49 | self.into_iter() 50 | .map(|(i, d)| Reverse((Numeric(d), i))) 51 | .collect() 52 | } 53 | } 54 | 55 | impl IntoHeap for Vec<(f32, u32)> { 56 | fn into_maxheap(self) -> BinaryHeap<(Numeric, u32)> { 57 | self.into_iter().map(|(d, i)| (Numeric(d), i)).collect() 58 | } 59 | 60 | fn into_minheap(self) -> BinaryHeap> { 61 | self.into_iter() 62 | .map(|(d, i)| Reverse((Numeric(d), i))) 63 | .collect() 64 | } 65 | } 66 | 67 | pub fn create_min_heap() -> BinaryHeap> { 68 | let q: BinaryHeap> = BinaryHeap::new(); 69 | q 70 | } 71 | 72 | pub fn create_max_heap() -> BinaryHeap<(Numeric, u32)> { 73 | let q: BinaryHeap<(Numeric, u32)> = BinaryHeap::new(); 74 | q 75 | } 76 | -------------------------------------------------------------------------------- /dria_hnsw/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod db; 2 | pub mod errors; 3 | pub mod hnsw; 4 | pub mod middlewares; 5 | pub mod models; 6 | pub mod proto; 7 | pub mod responses; 8 | pub mod worker; 9 | pub mod filter; 10 | -------------------------------------------------------------------------------- /dria_hnsw/src/main.rs: -------------------------------------------------------------------------------- 1 | use actix_cors::Cors; 2 | use 
actix_web::middleware::Logger; 3 | use actix_web::{web, App, HttpServer}; 4 | use dria_hnsw::db::env::Config; 5 | use dria_hnsw::db::rocksdb_client::RocksdbClient; 6 | use dria_hnsw::middlewares::cache::{NodeCache, PointCache}; 7 | use dria_hnsw::worker::{fetch, get_health_status, insert_vector, query}; 8 | 9 | pub fn config(conf: &mut web::ServiceConfig) { 10 | conf.service(get_health_status); 11 | conf.service(query); 12 | conf.service(fetch); 13 | conf.service(insert_vector); 14 | } 15 | 16 | #[actix_web::main] 17 | async fn main() -> std::io::Result<()> { 18 | let node_cache = web::Data::new(NodeCache::new()); 19 | let point_cache = web::Data::new(PointCache::new()); 20 | let cfg = Config::new(); 21 | 22 | let rocksdb_client = RocksdbClient::new(cfg.contract_id.clone()); 23 | 24 | if rocksdb_client.is_err() { 25 | println!("Rocksdb client failed to initialize"); 26 | return Err(std::io::Error::new( 27 | std::io::ErrorKind::Other, 28 | "Rocksdb client failed to initialize", 29 | )); 30 | } 31 | let rdb = rocksdb_client.unwrap(); 32 | let rocksdb_client = web::Data::new(rdb); 33 | 34 | let factory = move || { 35 | App::new() 36 | .app_data(web::JsonConfig::default().limit(152428800)) 37 | .app_data(node_cache.clone()) 38 | .app_data(rocksdb_client.clone()) 39 | .app_data(point_cache.clone()) 40 | .configure(config) 41 | .wrap(Logger::default()) 42 | .wrap(Cors::permissive()) 43 | }; 44 | 45 | let url = format!("0.0.0.0:{}", cfg.port); 46 | println!("Dria HNSW listening at {}", url); 47 | HttpServer::new(factory).bind(url)?.run().await?; 48 | Ok(()) 49 | } 50 | -------------------------------------------------------------------------------- /dria_hnsw/src/middlewares/cache.rs: -------------------------------------------------------------------------------- 1 | use crate::hnsw::sync_map::SynchronizedNodes; 2 | use crate::proto::index_buffer::Point; 3 | use mini_moka::sync::Cache; 4 | use std::sync::Arc; 5 | use std::time::Duration; 6 | 7 | pub struct NodeCache { 8 | pub caches: Cache>, 9 | } 10 | 11 | /// If a key within cache is not used (i.e. get or insert) for the given duration (seconds), expire that key. 12 | const NODE_CACHE_EXPIRE: u64 = 48 * 60 * 60; // 2 days 13 | /// Maximum capacity of the cache, in number of keys. 14 | const NODE_CACHE_CAPACITY: u64 = 5_000; 15 | 16 | impl NodeCache { 17 | pub fn new() -> Self { 18 | let cache = Cache::builder() 19 | .time_to_idle(Duration::from_secs(NODE_CACHE_EXPIRE)) 20 | .max_capacity(NODE_CACHE_CAPACITY) 21 | .build(); 22 | 23 | NodeCache { caches: cache } 24 | } 25 | 26 | pub fn get_cache(&self, key: String) -> Arc { 27 | let my_cache = self.caches.clone(); 28 | let node_cache = my_cache.get(&key).unwrap_or_else(|| { 29 | let new_cache = Arc::new(SynchronizedNodes::new()); 30 | my_cache.insert(key.to_string(), new_cache.clone()); 31 | new_cache 32 | }); 33 | 34 | // TODO: clone required here? 35 | node_cache.clone() 36 | } 37 | 38 | pub fn add_cache(&self, key: &str, cache: Arc) { 39 | let my_cache = self.caches.clone(); 40 | my_cache.insert(key.to_string(), cache); 41 | } 42 | } 43 | 44 | /// If a key within cache is not used (i.e. get or insert) for the given duration (seconds), expire that key. 45 | const POINT_CACHE_EXPIRE: u64 = 24 * 60 * 60; // 1 days 46 | /// Maximum capacity of the cache, in number of keys. 
47 | const POINT_CACHE_CAPACITY: u64 = 5_000; 48 | 49 | pub struct PointCache { 50 | pub caches: Cache>, //Cache>>, 51 | } 52 | 53 | impl PointCache { 54 | pub fn new() -> Self { 55 | let cache = Cache::builder() 56 | .time_to_idle(Duration::from_secs(POINT_CACHE_EXPIRE)) 57 | .max_capacity(POINT_CACHE_CAPACITY) // around 106MB for 1536 dim vectors 58 | .build(); 59 | 60 | PointCache { caches: cache } 61 | } 62 | 63 | pub fn get_cache(&self, key: String) -> Cache { 64 | //let cache = self.caches.entry(key.to_string()).or_insert_with(|| Arc::new(DashMap::new())); 65 | let my_cache = self.caches.clone(); 66 | let point_cache = my_cache.get(&key).unwrap_or_else(|| { 67 | //let new_cache = Arc::new(DashMap::new()); 68 | let new_cache = Cache::builder() 69 | //if a key is not used (get or insert) for 2 hour, expire it 70 | //.time_to_live(Duration::from_secs(1 * 60 * 60)) 71 | .max_capacity(200_000) // around 2060MB for 1536 dim vectors 72 | .build(); 73 | my_cache.insert(key.to_string(), new_cache.clone()); 74 | new_cache 75 | }); 76 | 77 | // TODO: clone required here? 78 | point_cache.clone() 79 | } 80 | 81 | pub fn add_cache(&self, key: &str, cache: Cache) { 82 | self.caches.insert(key.to_string(), cache); 83 | } 84 | } 85 | 86 | #[cfg(test)] 87 | mod tests { 88 | use super::*; 89 | use std::sync::atomic::{AtomicU32, Ordering}; 90 | 91 | #[test] 92 | fn test_cache() { 93 | let cache = Cache::builder() 94 | //if a key is not used (get or insert) for 2 hour, expire it 95 | .time_to_idle(Duration::from_secs(120 * 60)) 96 | .max_capacity(100) 97 | .build(); 98 | 99 | for i in 0..105 { 100 | cache.insert(i.to_string(), i); 101 | } 102 | 103 | let ix = cache.get(&"0".to_string()); 104 | assert_eq!(ix, Some(0)); 105 | } 106 | 107 | #[test] 108 | fn test_weighted_cache() { 109 | let current_weight = AtomicU32::new(1); // Start weights from 1 to avoid assigning a weight of 0 110 | 111 | let cache = Cache::builder() 112 | .weigher(move |_key, _value: &String| -> u32 { 113 | // Use the current weight and increment for the next use 114 | current_weight.fetch_add(1, Ordering::SeqCst) 115 | }) 116 | // Assuming a simple numeric weight limit for demonstration purposes 117 | .max_capacity(100) 118 | .build(); 119 | 120 | // Example inserts - in a real scenario, make sure to manage the size and weights appropriately 121 | for i in 0..105 { 122 | cache.insert(i, format!("Value {}", i)); 123 | } 124 | 125 | let ix = cache.get(&0); 126 | assert_eq!(ix, Some("Value 0".to_string())); 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /dria_hnsw/src/middlewares/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod cache; 2 | -------------------------------------------------------------------------------- /dria_hnsw/src/models/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod request_models; 2 | -------------------------------------------------------------------------------- /dria_hnsw/src/models/request_models.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use serde_json::Value; 3 | 4 | use crate::errors::errors::ValidationError; 5 | 6 | #[derive(Serialize, Deserialize, Debug)] 7 | pub struct InsertModel { 8 | pub vector: Vec, 9 | pub metadata: Value, 10 | } 11 | 12 | #[derive(Serialize, Deserialize, Debug)] 13 | pub struct InsertBatchModel { 14 | pub data: Vec, 15 | } 16 | 17 | 
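// For illustration only (key names follow the structs above, the values are
// hypothetical): a batch insert sent to the `/insert_vector` endpoint deserializes
// into `InsertBatchModel`, e.g.
//
//   { "data": [ { "vector": [0.1, 0.2, 0.3], "metadata": { "text": "hello" } } ] }
//
// where each element of `data` is an `InsertModel` carrying an embedding vector
// and arbitrary JSON metadata.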
#[derive(Serialize, Deserialize, Debug)] 18 | pub struct FetchModel { 19 | pub id: Vec, // TODO: rename this to `ids` 20 | } 21 | 22 | #[derive(Serialize, Deserialize, Debug)] 23 | pub struct QueryModel { 24 | pub vector: Vec, 25 | pub top_n: usize, 26 | pub query: Option, 27 | pub level: Option, 28 | } 29 | 30 | impl QueryModel { 31 | pub fn new( 32 | vector: Vec, 33 | top_n: usize, 34 | query: Option, 35 | level: Option, 36 | ) -> Result { 37 | Self::validate_top_n(top_n)?; 38 | Self::validate_level(level)?; 39 | 40 | Ok(QueryModel { 41 | vector, 42 | top_n, 43 | query, 44 | level, 45 | }) 46 | } 47 | 48 | fn validate_top_n(top_n: usize) -> Result<(), ValidationError> { 49 | if top_n > 20 { 50 | Err(ValidationError("Top N cannot be more than 20.".to_string())) 51 | } else { 52 | Ok(()) 53 | } 54 | } 55 | 56 | fn validate_level(level: Option) -> Result<(), ValidationError> { 57 | if level.is_some() { 58 | match level.unwrap() { 59 | 1 | 2 | 3 | 4 => Ok(()), 60 | _ => Err(ValidationError( 61 | "Level should be 1, 2, 3, or 4.".to_string(), 62 | )), 63 | } 64 | } else { 65 | Ok(()) 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /dria_hnsw/src/proto/hnsw_comm.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package index_buffer; 4 | 5 | 6 | message LayerNode { 7 | uint32 level = 1; // Using uint64 as a safe alternative for usize 8 | uint32 idx = 2; // Using uint64 as a safe alternative for usize 9 | bool visible = 3; // Whether the node is visible 10 | map neighbors = 4; // Neighbor idx and its distance 11 | } 12 | 13 | 14 | message Point { 15 | uint32 idx = 1; // Using uint64 as a safe alternative for usize 16 | repeated float v = 2; // Vector of floats 17 | } 18 | 19 | message PointQuant { 20 | uint32 idx = 1; // Using uint64 as a safe alternative for usize 21 | repeated uint32 v = 2; // Vector of ints 22 | } -------------------------------------------------------------------------------- /dria_hnsw/src/proto/index.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package index_buffer; 4 | 5 | 6 | message LayerNode { 7 | uint32 level = 1; // Using uint64 as a safe alternative for usize 8 | uint32 idx = 2; // Using uint64 as a safe alternative for usize 9 | bool visible = 3; // Whether the node is visible 10 | map neighbors = 4; // Neighbor idx and its distance 11 | } 12 | 13 | 14 | message Point { 15 | uint32 idx = 1; // Using uint64 as a safe alternative for usize 16 | repeated float v = 2; // Vector of floats 17 | } 18 | 19 | message PointQuant { 20 | uint32 idx = 1; // Using uint64 as a safe alternative for usize 21 | repeated uint32 v = 2; // Vector of ints 22 | } -------------------------------------------------------------------------------- /dria_hnsw/src/proto/index_buffer.rs: -------------------------------------------------------------------------------- 1 | #[derive(Clone, PartialEq, ::prost::Message)] 2 | pub struct LayerNode { 3 | /// Using uint64 as a safe alternative for usize 4 | #[prost(uint32, tag = "1")] 5 | pub level: u32, 6 | /// Using uint64 as a safe alternative for usize 7 | #[prost(uint32, tag = "2")] 8 | pub idx: u32, 9 | /// Whether the node is visible 10 | #[prost(bool, tag = "3")] 11 | pub visible: bool, 12 | /// Neighbor idx and its distance 13 | #[prost(map = "uint32, float", tag = "4")] 14 | pub neighbors: ::std::collections::HashMap, 15 | } 16 | impl LayerNode { 
17 | pub fn new(level: usize, idx: usize) -> LayerNode { 18 | LayerNode { 19 | level: level as u32, 20 | idx: idx as u32, 21 | visible: true, 22 | neighbors: ::std::collections::HashMap::new(), 23 | } 24 | } 25 | } 26 | 27 | #[derive(Clone, PartialEq, ::prost::Message)] 28 | pub struct Point { 29 | /// Using uint64 as a safe alternative for usize 30 | #[prost(uint32, tag = "1")] 31 | pub idx: u32, 32 | /// Vector of floats 33 | #[prost(float, repeated, tag = "2")] 34 | pub v: ::prost::alloc::vec::Vec, 35 | } 36 | impl Point { 37 | pub fn new(vec: Vec, idx: usize) -> Point { 38 | Point { 39 | idx: idx as u32, 40 | v: vec, 41 | } 42 | } 43 | } 44 | 45 | #[derive(Clone, PartialEq, ::prost::Message)] 46 | pub struct PointQuant { 47 | /// Using uint64 as a safe alternative for usize 48 | #[prost(uint32, tag = "1")] 49 | pub idx: u32, 50 | /// Vector of ints 51 | #[prost(uint32, repeated, tag = "2")] 52 | pub v: ::prost::alloc::vec::Vec, 53 | } 54 | impl PointQuant { 55 | pub fn new(vec: Vec, idx: usize) -> PointQuant { 56 | PointQuant { 57 | idx: idx as u32, 58 | v: vec, 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /dria_hnsw/src/proto/insert.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package insert_buffer; 4 | 5 | message MetadataValue { 6 | oneof value_type { 7 | float float_value = 1; 8 | int64 int_value = 2; 9 | string string_value = 3; 10 | bool bool_value = 4; 11 | } 12 | } 13 | 14 | message SingletonVec { 15 | repeated float v = 1; // Vector of floats 16 | map map = 2; 17 | 18 | } 19 | 20 | message BatchVec { 21 | repeated SingletonVec s = 1; 22 | } 23 | 24 | 25 | 26 | message SingletonStr { 27 | string v = 1; // Vector of strings 28 | map map = 2; 29 | 30 | } 31 | 32 | message BatchStr { 33 | repeated SingletonStr s = 1; 34 | } 35 | -------------------------------------------------------------------------------- /dria_hnsw/src/proto/insert_buffer.rs: -------------------------------------------------------------------------------- 1 | use serde::ser::{Serialize, SerializeStruct, Serializer}; 2 | 3 | #[derive(Clone, PartialEq, ::prost::Message)] 4 | pub struct MetadataValue { 5 | #[prost(oneof = "metadata_value::ValueType", tags = "1, 2, 3, 4")] 6 | pub value_type: ::core::option::Option, 7 | } 8 | /// Nested message and enum types in `MetadataValue`. 
9 | pub mod metadata_value { 10 | #[derive(Clone, PartialEq, ::prost::Oneof)] 11 | pub enum ValueType { 12 | #[prost(float, tag = "1")] 13 | FloatValue(f32), 14 | #[prost(int64, tag = "2")] 15 | IntValue(i64), 16 | #[prost(string, tag = "3")] 17 | StringValue(::prost::alloc::string::String), 18 | #[prost(bool, tag = "4")] 19 | BoolValue(bool), 20 | } 21 | } 22 | #[derive(Clone, PartialEq, ::prost::Message)] 23 | pub struct SingletonVec { 24 | /// Vector of floats 25 | #[prost(float, repeated, tag = "1")] 26 | pub v: ::prost::alloc::vec::Vec, 27 | #[prost(map = "string, message", tag = "2")] 28 | pub map: ::std::collections::HashMap<::prost::alloc::string::String, MetadataValue>, 29 | } 30 | #[derive(Clone, PartialEq, ::prost::Message)] 31 | pub struct BatchVec { 32 | #[prost(message, repeated, tag = "1")] 33 | pub s: ::prost::alloc::vec::Vec, 34 | } 35 | #[derive(Clone, PartialEq, ::prost::Message)] 36 | pub struct SingletonStr { 37 | /// Vector of strings 38 | #[prost(string, tag = "1")] 39 | pub v: ::prost::alloc::string::String, 40 | #[prost(map = "string, message", tag = "2")] 41 | pub map: ::std::collections::HashMap<::prost::alloc::string::String, MetadataValue>, 42 | } 43 | #[derive(Clone, PartialEq, ::prost::Message)] 44 | pub struct BatchStr { 45 | #[prost(message, repeated, tag = "1")] 46 | pub s: ::prost::alloc::vec::Vec, 47 | } 48 | 49 | impl Serialize for metadata_value::ValueType { 50 | fn serialize(&self, serializer: S) -> Result 51 | where 52 | S: Serializer, 53 | { 54 | // Serialize the value directly without wrapping it in a structure. 55 | match *self { 56 | metadata_value::ValueType::FloatValue(f) => serializer.serialize_f32(f), 57 | metadata_value::ValueType::IntValue(i) => serializer.serialize_i64(i), 58 | metadata_value::ValueType::StringValue(ref s) => serializer.serialize_str(s), 59 | metadata_value::ValueType::BoolValue(b) => serializer.serialize_bool(b), 60 | } 61 | } 62 | } 63 | 64 | impl Serialize for MetadataValue { 65 | fn serialize(&self, serializer: S) -> Result 66 | where 67 | S: Serializer, 68 | { 69 | // Use the inner value's serialization directly. 70 | if let Some(ref value_type) = self.value_type { 71 | value_type.serialize(serializer) 72 | } else { 73 | // Decide how you want to handle MetadataValue when it's None. 74 | // For example, you might serialize it as null or an empty object. 
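// Here `serialize_none()` is used, which serde_json renders as JSON `null`.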
75 | serializer.serialize_none() 76 | } 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /dria_hnsw/src/proto/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod index_buffer; 2 | pub mod insert_buffer; 3 | -------------------------------------------------------------------------------- /dria_hnsw/src/proto/request.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package request_buffer; 4 | 5 | message Singleton { 6 | repeated float v = 1; // Vector of floats 7 | map metadata = 2; // Neighbor idx and its distance 8 | } 9 | 10 | message Batch { 11 | repeated string b = 1; 12 | } 13 | -------------------------------------------------------------------------------- /dria_hnsw/src/proto/request_buffer.rs: -------------------------------------------------------------------------------- 1 | #[derive(Clone, PartialEq, ::prost::Message)] 2 | pub struct Singleton { 3 | /// Vector of floats 4 | #[prost(float, repeated, tag = "1")] 5 | pub v: ::prost::alloc::vec::Vec, 6 | /// Neighbor idx and its distance 7 | #[prost(map = "string, string", tag = "2")] 8 | pub metadata: 9 | ::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>, 10 | } 11 | #[derive(Clone, PartialEq, ::prost::Message)] 12 | pub struct Batch { 13 | #[prost(string, repeated, tag = "1")] 14 | pub b: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, 15 | } 16 | -------------------------------------------------------------------------------- /dria_hnsw/src/responses/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod responses; 2 | -------------------------------------------------------------------------------- /dria_hnsw/src/responses/responses.rs: -------------------------------------------------------------------------------- 1 | use serde::Serialize; 2 | 3 | #[derive(Serialize)] 4 | pub struct CustomResponse { 5 | pub(crate) success: bool, 6 | pub(crate) data: T, 7 | pub(crate) code: u32, 8 | } 9 | -------------------------------------------------------------------------------- /dria_hnsw/src/worker.rs: -------------------------------------------------------------------------------- 1 | use crate::db::env::Config; 2 | use crate::db::rocksdb_client::RocksdbClient; 3 | use crate::hnsw::index::HNSW; 4 | use crate::hnsw::sync_map::SynchronizedNodes; 5 | use crate::middlewares::cache::{NodeCache, PointCache}; 6 | use crate::models::request_models::{FetchModel, InsertBatchModel, QueryModel}; 7 | use crate::proto::index_buffer::{LayerNode, Point}; 8 | use crate::responses::responses::CustomResponse; 9 | use actix_web::web::{Data, Json}; 10 | use actix_web::{get, post, web, HttpMessage, HttpRequest, HttpResponse}; 11 | use log::error; 12 | use mini_moka::sync::Cache; 13 | use rayon::prelude::*; 14 | use serde::{Deserialize, Serialize}; 15 | use serde_json::{json, Value}; 16 | use std::borrow::Borrow; 17 | use std::sync::atomic::{AtomicIsize, AtomicUsize, Ordering}; 18 | use std::sync::Arc; 19 | use tokio::task; 20 | 21 | use crate::filter::text_based::create_index_from_docs; 22 | use probly_search::Index; 23 | 24 | pub const SINGLE_THREADED_HNSW_BUILD_THRESHOLD: usize = 256; 25 | 26 | #[get("/health")] 27 | pub async fn get_health_status() -> HttpResponse { 28 | let response = CustomResponse { 29 | success: true, 30 | data: "hello world!".to_string(), 31 | code: 200, 32 | }; 33 | 
HttpResponse::Ok().json(response) 34 | } 35 | 36 | #[post("/query")] 37 | pub async fn query(req: HttpRequest, payload: Json) -> HttpResponse { 38 | let ind: HNSW; 39 | 40 | let cfg = Config::new(); 41 | 42 | let rocksdb_client = req 43 | .app_data::>() 44 | .expect("Error getting rocksdb client"); 45 | 46 | ind = HNSW::new( 47 | 16, 48 | 128, 49 | ef_helper(payload.level), 50 | None, 51 | rocksdb_client.clone(), 52 | ); 53 | let node_cache = req 54 | .app_data::>() 55 | .expect("Error getting node cache"); //Arc = Arc::new(SynchronizedNodes::new()); 56 | let point_cache = req 57 | .app_data::>() 58 | .expect("Error getting point cache"); //Arc> = Arc::new(DashMap::new()); 59 | 60 | let node_map = node_cache.get_cache(cfg.contract_id.clone()); //Arc = Arc::new(SynchronizedNodes::new()); 61 | let point_map = point_cache.get_cache(cfg.contract_id.clone()); 62 | let res = ind.knn_search(&payload.vector, payload.top_n, node_map, point_map); 63 | 64 | if payload.query.is_some() { 65 | let mut index = Index::::new(1); 66 | let results = 67 | create_index_from_docs(&mut index, &payload.query.clone().unwrap(), res.clone()); 68 | let response = CustomResponse { 69 | success: true, 70 | data: json!(results), 71 | code: 200, 72 | }; 73 | return HttpResponse::Ok().json(response); 74 | } 75 | 76 | let response = CustomResponse { 77 | success: true, 78 | data: json!(res), 79 | code: 200, 80 | }; 81 | HttpResponse::Ok().json(response) 82 | } 83 | 84 | #[post("/fetch")] 85 | pub async fn fetch(req: HttpRequest, payload: Json) -> HttpResponse { 86 | let ind: HNSW; 87 | 88 | let rocksdb_client = req 89 | .app_data::>() 90 | .expect("Error getting rocksdb client"); 91 | 92 | ind = HNSW::new(16, 128, 0, None, rocksdb_client.clone()); 93 | 94 | let res = ind.db.get_metadatas(payload.id.clone()); 95 | 96 | if res.is_err() { 97 | let response = CustomResponse { 98 | success: false, 99 | data: "Error fetching metadata".to_string(), 100 | code: 500, 101 | }; 102 | return HttpResponse::InternalServerError().json(response); 103 | } 104 | 105 | let res = res.unwrap(); 106 | 107 | let response = CustomResponse { 108 | success: true, 109 | data: json!(res), 110 | code: 200, 111 | }; 112 | HttpResponse::Ok().json(response) 113 | } 114 | 115 | #[post("/insert_vector")] 116 | pub async fn insert_vector(req: HttpRequest, payload: Json) -> HttpResponse { 117 | let cfg = Config::new(); 118 | let cid = cfg.contract_id.clone(); 119 | 120 | let rocksdb_client = req 121 | .app_data::>() 122 | .expect("Error getting rocksdb client"); 123 | 124 | let rocksdb_client = rocksdb_client.clone(); 125 | 126 | let mut vectors = Vec::new(); 127 | let mut metadata_batch = Vec::new(); 128 | for d in payload.data.iter() { 129 | vectors.push(d.vector.clone()); 130 | metadata_batch.push(d.metadata.clone()); 131 | } 132 | 133 | if vectors.len() > 2500 { 134 | let response = CustomResponse { 135 | success: false, 136 | data: "Batch size should be smaller than 2500.".to_string(), 137 | code: 401, 138 | }; 139 | // TODO: fix status code 140 | return HttpResponse::InternalServerError().json(response); 141 | } 142 | 143 | let node_cache = req 144 | .app_data::>() 145 | .expect("Error getting node cache"); //Arc = Arc::new(SynchronizedNodes::new()); 146 | let point_cache = req 147 | .app_data::>() 148 | .expect("Error getting point cache"); //Arc> = Arc::new(DashMap::new()); 149 | 150 | let node_map = node_cache.get_cache(cid.clone()); //Arc = Arc::new(SynchronizedNodes::new()); 151 | let point_map = point_cache.get_cache(cid.clone()); //Arc> = 
Arc::new(DashMap::new()); 152 | let cid_clone = cid.clone(); 153 | let result = task::spawn_blocking(move || { 154 | train_worker( 155 | vectors, 156 | metadata_batch, 157 | node_map, 158 | point_map, 159 | rocksdb_client.clone(), 160 | 10_000, 161 | ) 162 | }) 163 | .await; 164 | 165 | let node_map = node_cache.get_cache(cid_clone); //Arc = Arc::new(SynchronizedNodes::new()); 166 | node_map.reset(); 167 | 168 | let (res, code) = result.expect("Error getting result"); 169 | 170 | if code != 200 { 171 | return HttpResponse::InternalServerError().json(CustomResponse { 172 | success: false, 173 | data: res, 174 | code: code as u32, 175 | }); 176 | } 177 | return HttpResponse::Ok().json(CustomResponse { 178 | success: true, 179 | data: "Values are successfully added to index.".to_string(), 180 | code: 200, 181 | }); 182 | } 183 | 184 | fn ef_helper(ef: Option) -> usize { 185 | let level = ef.clone().unwrap_or(1); 186 | 20 + (level * 30) 187 | } 188 | 189 | fn train_worker( 190 | vectors: Vec>, 191 | metadata_batch: Vec, 192 | node_map: Arc, 193 | point_map: Cache, 194 | rocksdb_client: Data, 195 | batch_size: usize, 196 | ) -> (String, u16) { 197 | let ind = HNSW::new(16, 128, ef_helper(Some(1)), None, rocksdb_client.clone()); 198 | 199 | let mut ds = 0; 200 | let nl = ind.db.get_num_layers(); 201 | let num_layers = Arc::new(AtomicUsize::new(0)); 202 | 203 | if nl.is_err() { 204 | error!("{}", nl.err().unwrap()); 205 | ind.db.set_datasize(0).expect("Error setting datasize"); 206 | } else { 207 | let nl_value = nl.expect("").clone(); 208 | ds = ind.db.get_datasize().expect("Error getting datasize"); 209 | 210 | let res = num_layers.fetch_update(Ordering::SeqCst, Ordering::Relaxed, |x| Some(nl_value)); 211 | 212 | if res.is_err() { 213 | error!("{}", res.err().unwrap()); 214 | return ("Error setting num layers, atomic".to_string(), 500); 215 | } 216 | } 217 | 218 | let r1 = ind.db.add_points_batch(&vectors, ds); 219 | let r2 = ind.db.set_metadata_batch(metadata_batch, ds); 220 | let r3 = ind.db.set_datasize(ds + vectors.len()); 221 | 222 | if r1.is_err() || r2.is_err() || r3.is_err() { 223 | error!("Error adding points as batch"); 224 | return ("Error adding points as batch".to_string(), 500); 225 | } 226 | 227 | let epa = Arc::new(AtomicIsize::new(-1)); 228 | 229 | let ep = ind.db.get_ep(); 230 | 231 | if ep.is_ok() { 232 | let ep_value = ep.expect("").clone(); 233 | let res = epa.fetch_update(Ordering::SeqCst, Ordering::Relaxed, |x| { 234 | // Your update logic here 235 | Some(ep_value as isize) 236 | }); 237 | if res.is_err() { 238 | error!("{}", res.err().unwrap()); 239 | return ("Error setting ep, atomic".to_string(), 500); 240 | } 241 | } 242 | let pool = rayon::ThreadPoolBuilder::new() 243 | .thread_name(|idx| format!("hnsw-build-{idx}")) 244 | .num_threads(8) 245 | .build() 246 | .expect("Error building threadpool"); 247 | 248 | if ds < SINGLE_THREADED_HNSW_BUILD_THRESHOLD { 249 | let iter_ind = vectors.len().min(SINGLE_THREADED_HNSW_BUILD_THRESHOLD - ds); 250 | for i in 0..iter_ind { 251 | ind.insert_w_preset( 252 | ds + i, 253 | node_map.clone(), 254 | point_map.clone(), 255 | num_layers.clone(), 256 | epa.clone(), 257 | ) 258 | .expect("Error inserting"); 259 | } 260 | 261 | if SINGLE_THREADED_HNSW_BUILD_THRESHOLD < (vectors.len() + ds) { 262 | pool.install(|| { 263 | (iter_ind..vectors.len()) 264 | .into_par_iter() 265 | .try_for_each(|item| { 266 | ind.insert_w_preset( 267 | ds + item, 268 | node_map.clone(), 269 | point_map.clone(), 270 | num_layers.clone(), 271 | epa.clone(), 272 
| ) 273 | }) 274 | }) 275 | .expect("Error inserting"); 276 | } 277 | } else { 278 | pool.install(|| { 279 | (0..vectors.len()).into_par_iter().try_for_each(|item| { 280 | ind.insert_w_preset( 281 | item + ds, 282 | node_map.clone(), 283 | point_map.clone(), 284 | num_layers.clone(), 285 | epa.clone(), 286 | ) 287 | }) 288 | }) 289 | .expect("Error inserting"); 290 | } 291 | 292 | //replicate neighbors 293 | let values: Vec = node_map 294 | .clone() 295 | .map 296 | .iter() 297 | .map(|entry| entry.value().clone()) 298 | .collect(); 299 | 300 | let ep_value = epa.clone().load(Ordering::Relaxed); 301 | let num_layers = num_layers.clone().load(Ordering::Relaxed); 302 | 303 | let batch_size: usize = batch_size; 304 | 305 | for chunk in values.chunks(batch_size) { 306 | let r1 = ind.db.upsert_neighbors(chunk.to_vec()); 307 | if r1.is_err() { 308 | error!("{}", r1.err().unwrap()); 309 | return ("Error writing batch to blockchain".to_string(), 500); 310 | } 311 | } 312 | let r2 = ind.db.set_ep(ep_value as usize, false); 313 | let r3 = ind.db.set_num_layers(num_layers, false); 314 | 315 | if r2.is_err() || r3.is_err() { 316 | error!("Error writing to blockchain"); 317 | return ("Error writing batch to blockchain".to_string(), 500); 318 | } 319 | return ("Values are successfully added to index.".to_string(), 200); 320 | } 321 | 322 | #[cfg(test)] 323 | mod tests { 324 | use super::*; 325 | use actix_web::{http::header::ContentType, test, App}; 326 | use rand::{self, Rng}; 327 | use simple_home_dir::home_dir; 328 | use std::env; 329 | 330 | fn prepare_env() -> () { 331 | const CONTRACT_ID: &str = "WbcY2a-KfDpk7fsgumUtLC2bu4NQcVzNlXWi13fPMlU"; 332 | 333 | env::set_var("CONTRACT_ID", CONTRACT_ID); 334 | let mut rocksdb_path = home_dir().unwrap().to_str().unwrap().to_owned(); 335 | rocksdb_path.push_str("/.dria/data/"); 336 | rocksdb_path.push_str(CONTRACT_ID); 337 | env::set_var("ROCKSDB_PATH", rocksdb_path); 338 | } 339 | 340 | fn prepare_rocksdb() -> Data { 341 | let cfg = Config::new(); 342 | let rocksdb_client = RocksdbClient::new(cfg.contract_id.clone()).unwrap(); 343 | web::Data::new(rocksdb_client) 344 | } 345 | 346 | #[actix_web::test] 347 | async fn test_health() { 348 | let app = test::init_service(App::new().configure(|conf| { 349 | conf.service(get_health_status); 350 | })) 351 | .await; 352 | 353 | let req = test::TestRequest::get() 354 | .uri("/health") 355 | .insert_header(ContentType::plaintext()) 356 | .to_request(); 357 | let resp = test::call_service(&app, req).await; 358 | assert!(resp.status().is_success()); 359 | } 360 | 361 | #[actix_web::test] 362 | async fn test_fetch_and_query() { 363 | prepare_env(); 364 | let rocksdb_client = prepare_rocksdb(); 365 | let node_cache = web::Data::new(NodeCache::new()); 366 | let point_cache = web::Data::new(PointCache::new()); 367 | let app = test::init_service( 368 | App::new() 369 | .app_data(rocksdb_client) 370 | .app_data(node_cache) 371 | .app_data(point_cache) 372 | .configure(|conf| { 373 | conf.service(query).service(fetch); 374 | }), 375 | ) 376 | .await; 377 | 378 | // fetch 379 | let req = test::TestRequest::post() 380 | .uri("/fetch") 381 | .set_json(json!({ "id": [0] })) 382 | .to_request(); 383 | let resp = test::call_service(&app, req).await; 384 | assert!(resp.status().is_success()); 385 | 386 | // query 387 | let query_vector: Vec = (0..768).map(|_| rand::thread_rng().gen()).collect(); 388 | let req = test::TestRequest::post() 389 | .uri("/query") 390 | .set_json(json!({ 391 | "vector": query_vector, 392 | "top_n": 10 393 | 
})) 394 | .to_request(); 395 | let resp = test::call_service(&app, req).await; 396 | assert!(resp.status().is_success()); 397 | } 398 | } 399 | -------------------------------------------------------------------------------- /hollowdb/.dockerignore: -------------------------------------------------------------------------------- 1 | test 2 | config 3 | .vscode 4 | .github 5 | .env.example 6 | .gitignore 7 | .mocharc.json 8 | node_modules 9 | -------------------------------------------------------------------------------- /hollowdb/.env.example: -------------------------------------------------------------------------------- 1 | # contract to connect to 2 | CONTRACT_ID= 3 | 4 | # path to write rocksdb 5 | ROCKSDB_PATH= 6 | 7 | # Redis URL for the caching layer 8 | REDIS_URL= 9 | 10 | # Bundlr usage 11 | USE_BUNDLR=true 12 | 13 | # Treat values as `hash.txid` 14 | USE_HTX=true 15 | 16 | # Warp's log level 17 | WARP_LOG_LEVEL=info 18 | 19 | # Batch size 20 | BUNDLR_FBS=80 21 | -------------------------------------------------------------------------------- /hollowdb/.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | build 3 | cache 4 | logs 5 | 6 | dump.rdb 7 | 8 | .vscode 9 | .DS_Store 10 | 11 | !src/cache 12 | 13 | .env 14 | tmp.md 15 | 16 | data 17 | 18 | .yarn 19 | -------------------------------------------------------------------------------- /hollowdb/.yarnrc.yml: -------------------------------------------------------------------------------- 1 | nodeLinker: node-modules 2 | -------------------------------------------------------------------------------- /hollowdb/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM node:20 as base 2 | WORKDIR /app 3 | 4 | COPY . . 5 | 6 | RUN yarn set version berry 7 | 8 | # yarn install might take time when a non-readonly volume is attached (see https://github.com/yarnpkg/yarn/issues/7747) 9 | # ~700 seconds on an M2 Macbook Air for example 10 | RUN yarn install 11 | 12 | # Build code 13 | FROM base as builder 14 | WORKDIR /app 15 | RUN yarn build 16 | 17 | # Install prod dependencies 18 | FROM base as dependencies 19 | WORKDIR /app 20 | RUN yarn workspaces focus --production 21 | 22 | # Slim has GLIBC needed by RocksDB, Alpine does not. 23 | FROM node:20-slim 24 | # RUN apk add gcompat 25 | 26 | WORKDIR /app 27 | COPY --from=builder /app/build ./build 28 | COPY --from=dependencies /app/node_modules ./node_modules 29 | 30 | EXPOSE 3000 31 | 32 | CMD ["node", "./build/index.js"] 33 | -------------------------------------------------------------------------------- /hollowdb/README.md: -------------------------------------------------------------------------------- 1 | # HollowDB API 2 | 3 | Each Dria knowledge is a smart-contract on Arweave, which serves the knowledge on permaweb via a key-value interface. To download them, we require HollowDB client, which is a Node package. Here we have a Fastify server that acts as an API for HollowDB, which we primarily use to download & unbundle the values for Dria HNSW to use. 
4 | 5 | ## Installation 6 | 7 | Install the packages: 8 | 9 | ```sh 10 | yarn install 11 | ``` 12 | 13 | ## Usage 14 | 15 | To run the server, you need to provide a contract ID along with a RocksDB path: 16 | 17 | ```sh 18 | CONTRACT_ID= ROCKSDB_PATH="path/to/rocksdb" yarn start 19 | ``` 20 | 21 | HollowDB API is available as a container: 22 | 23 | ```sh 24 | docker pull firstbatch/dria-hollowdb 25 | ``` 26 | 27 | In both cases, you will need a Redis container running at the URL defined by the `REDIS_URL` environment variable. 28 | 29 | ### Configurations 30 | 31 | There are several environment variables to configure the server. You can provide them on the command line, or via a `.env` file. An example is given [here](./.env.example). 32 | 33 | - `REDIS_URL=`
You need a Redis server running before you start the HollowDB server; its connection URL can be provided with the `REDIS_URL` environment variable and defaults to `redis://default:redispw@localhost:6379`. 34 | 35 | - `WARP_LOG_LEVEL=`
By default, Warp logs at the `info` level, but you can change it via the `WARP_LOG_LEVEL` environment variable. Options are the known levels: `debug`, `error`, `fatal`, `info`, `none`, `silly`, `trace` and `warn`. 36 | 37 | - `USE_BUNDLR=`
You can treat the stored values as Arweave transaction ids by setting the `USE_BUNDLR` environment variable to `true`. When this is the case, `REFRESH` will actually fetch the uploaded data from Arweave and store it locally (in RocksDB). 38 | 39 | > [!WARNING] 40 | > 41 | > Uploading to Bundlr via `PUT` or `PUT_MANY` is not yet implemented. 42 | 43 | - `USE_HTX=`
When we have `USE_BUNDLR=true`, we treat the stored values as transaction ids; however, HollowDB may use an alternative layout where values are stored as `hash.txid` (see [this implementation](https://github.com/firstbatchxyz/hollowdb/blob/master/src/contracts/hollowdb-htx.contract.ts)). To comply with this layout, set `USE_HTX=true`. 44 | 45 | - `BUNDLR_FBS=`
When using Bundlr, downloading values from Arweave cannot be done in a huge `Promise.all`, as it causes timeouts. We instead download values in batches, defaulting to 40 values per batch. To override the batch size, you can provide an integer value to this variable. 46 | 47 | ## Endpoints 48 | 49 | HollowDB API exposes the following endpoints: 50 | 51 | - GET [`/state`](#state) 52 | - POST [`/get`](#get) 53 | - POST [`/getRaw`](#getraw) 54 | - POST [`/getMany`](#getmany) 55 | - POST [`/getManyRaw`](#getmanyraw) 56 | - POST [`/put`](#put) 57 | - POST [`/putMany`](#putmany) 58 | - POST [`/update`](#update) 59 | - POST [`/remove`](#remove) 60 | - POST [`/refresh`](#refresh) 61 | - POST [`/clear`](#clear) 62 | 63 | ### `get` 64 | 65 | ```ts 66 | interface { 67 | key: string 68 | } 69 | 70 | // response body 71 | interface { 72 | value: any 73 | } 74 | ``` 75 | 76 | Returns the value at the given `key`. 77 | 78 | ### `getRaw` 79 | 80 | ```ts 81 | interface { key: string } 82 | // response body 83 | interface { 84 | value: any 85 | } 86 | ``` 87 | 88 | Returns the value at the given `key`, directly from the cache layer, without involving Warp or Arweave. 89 | 90 | ### `getMany` 91 | 92 | ```ts 93 | interface { 94 | keys: string[] 95 | } 96 | 97 | // response body 98 | interface { 99 | values: any[] 100 | } 101 | ``` 102 | 103 | Returns the values at the given `keys`. 104 | 105 | ### `getManyRaw` 106 | 107 | ```ts 108 | interface { 109 | keys: string[] 110 | } 111 | 112 | // response body 113 | interface { 114 | values: any[] 115 | } 116 | ``` 117 | 118 | Returns the values at the given `keys`, reading directly from storage. 119 | 120 | This has the advantage of not being bound to the interaction size limit; however, you must verify with your own methods that the data is fresh. 121 | Furthermore, you must make a call to `REFRESH` before using this endpoint, and subsequent calls to `REFRESH` will update the data with the new on-chain values. 122 | 123 | ### `put` 124 | 125 | ```ts 126 | interface { 127 | key: string, 128 | value: any 129 | } 130 | ``` 131 | 132 | Puts `value` at the given `key`. The key must not exist already, or it must have `null` stored at it. 133 | 134 | ### `putMany` 135 | 136 | ```ts 137 | interface { 138 | keys: string[], 139 | values: any[] 140 | } 141 | ``` 142 | 143 | Puts the provided `values` at the given `keys`. None of the keys may already exist in the database. 144 | 145 | ### `update` 146 | 147 | ```ts 148 | interface { 149 | key: string, 150 | value: any, 151 | proof?: object 152 | } 153 | ``` 154 | 155 | Updates a `key` with the provided `value` and an optional `proof`. 156 | 157 | ### `remove` 158 | 159 | ```ts 160 | interface { 161 | key: string, 162 | proof?: object 163 | } 164 | ``` 165 | 166 | Removes the value at `key`, along with an optional `proof`. 167 | 168 | ### `state` 169 | 170 | Syncs & fetches the latest contract state, and returns it. 171 | 172 | ### `refresh` 173 | 174 | Syncs & fetches the latest state and stores the latest sort key for each key in the database. Returns the number of keys refreshed for diagnostic purposes. 175 | 176 | ### `clear` 177 | 178 | ```ts 179 | interface { 180 | keys?: string[] 181 | } 182 | ``` 183 | 184 | Clears the locally stored values written by the `REFRESH` endpoint for the given `keys`. This is useful when you want to refresh some keys again without flushing the entire database. Returns the number of keys cleared for diagnostic purposes.
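For illustration (this example is not part of the repository), the refresh/clear flow above can be exercised with `curl` roughly as follows, assuming the server runs locally on the default `PORT` of `3030` and using `example-key` as a placeholder key:

```sh
# sync the local cache with the latest on-chain values
# (responds with the number of refreshed keys)
curl -X POST http://localhost:3030/refresh

# read the value directly from local storage, without involving Warp or Arweave
# ("example-key" is a placeholder)
curl -X POST http://localhost:3030/getRaw \
  -H "Content-Type: application/json" \
  -d '{"key": "example-key"}'

# clear the locally refreshed value for that key only
curl -X POST http://localhost:3030/clear \
  -H "Content-Type: application/json" \
  -d '{"keys": ["example-key"]}'
```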
185 | 186 | > [!TIP] 187 | > 188 | > If no `keys` are given to the `CLEAR` endpoint (i.e. `keys = undefined`) then this will clear **all keys**. 189 | 190 | ## Testing 191 | 192 | We have tests that roll a local Arweave and run tests on them with the micro server in the middle. 193 | 194 | > [!NOTE] 195 | > 196 | > You need a Redis server running in the background for the tests to work. 197 | 198 | To run tests, do: 199 | 200 | ```sh 201 | yarn test 202 | ``` 203 | -------------------------------------------------------------------------------- /hollowdb/config/.gitignore: -------------------------------------------------------------------------------- 1 | # hide yo wallet 2 | *.json 3 | -------------------------------------------------------------------------------- /hollowdb/jest.config.ts: -------------------------------------------------------------------------------- 1 | import type { JestConfigWithTsJest } from "ts-jest"; 2 | 3 | const config: JestConfigWithTsJest = { 4 | // ts-jest defaults 5 | preset: "ts-jest", 6 | testEnvironment: "node", 7 | transform: { 8 | "^.+\\.(ts|js)$": "ts-jest", 9 | }, 10 | // environment setup & teardown scripts 11 | // setup & teardown for spinning up arlocal 12 | // globalSetup: "/tests/environment/setup.ts", 13 | // globalTeardown: "/tests/environment/teardown.ts", 14 | // timeout should be rather large, especially for the workflows 15 | testTimeout: 60000, 16 | // warp & arlocal takes some time to close, so make this 10 secs 17 | openHandlesTimeout: 10000, 18 | // print everything like Mocha 19 | verbose: true, 20 | // tests may hang randomly (not known why yet, it was fixed before) 21 | // that will cause workflow to run all the way, so we might force exit 22 | forceExit: true, 23 | }; 24 | 25 | export default config; 26 | -------------------------------------------------------------------------------- /hollowdb/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "author": "FirstBatch Team ", 3 | "private": true, 4 | "contributors": [ 5 | "Erhan Tezcan " 6 | ], 7 | "version": "0.1.0", 8 | "main": "index.js", 9 | "scripts": { 10 | "build": "tsc -p tsconfig.build.json", 11 | "clean": "rm -rf ./build", 12 | "start": "yarn build && node ./build/index.js", 13 | "test": "jest" 14 | }, 15 | "dependencies": { 16 | "@fastify/type-provider-typebox": "^4.0.0", 17 | "@sinclair/typebox": "^0.32.14", 18 | "axios": "^1.6.7", 19 | "fastify": "^4.26.1", 20 | "hollowdb": "^1.4.1", 21 | "ioredis": "^5.3.2", 22 | "levelup": "^5.1.1", 23 | "pino-pretty": "^10.3.1", 24 | "rocksdb": "^5.2.1", 25 | "warp-contracts": "^1.4.34", 26 | "warp-contracts-redis": "^0.4.2" 27 | }, 28 | "devDependencies": { 29 | "@types/jest": "^29.5.12", 30 | "@types/levelup": "^5.1.5", 31 | "@types/node": "^20.11.17", 32 | "@types/rocksdb": "^3.0.5", 33 | "arlocal": "^1.1.65", 34 | "jest": "^29.7.0", 35 | "lorem-ipsum": "^2.0.8", 36 | "ts-jest": "^29.1.2", 37 | "ts-node": "^10.9.2", 38 | "typescript": "^5.3.3", 39 | "warp-contracts-plugin-deploy": "^1.0.13" 40 | }, 41 | "packageManager": "yarn@4.1.0", 42 | "prettier": { 43 | "printWidth": 120 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /hollowdb/src/clients/hollowdb.ts: -------------------------------------------------------------------------------- 1 | import { WarpFactory, Warp } from "warp-contracts"; 2 | import type { CacheOptions } from "warp-contracts"; 3 | import type { Redis } from "ioredis"; 4 | import { RedisCache, type 
RedisOptions } from "warp-contracts-redis"; 5 | import type { CacheTypes } from "../types"; 6 | 7 | /** 8 | * Utility to create Warp Redis caches. 9 | * 10 | * @param contractTxId contract transaction id to be used as prefix in the keys 11 | * @param client Redis client to use a self-managed cache 12 | * @returns caches 13 | */ 14 | export function createCaches(contractTxId: string, client: Redis): CacheTypes { 15 | const defaultCacheOptions: CacheOptions = { 16 | inMemory: true, 17 | subLevelSeparator: "|", 18 | dbLocation: "redis.hollowdb", 19 | }; 20 | 21 | // if a client exists, use it; otherwise connect via URL 22 | const redisOptions: RedisOptions = { client }; 23 | 24 | return { 25 | state: new RedisCache( 26 | { 27 | ...defaultCacheOptions, 28 | dbLocation: `${contractTxId}.state`, 29 | }, 30 | redisOptions 31 | ), 32 | contract: new RedisCache( 33 | { 34 | ...defaultCacheOptions, 35 | dbLocation: `${contractTxId}.contract`, 36 | }, 37 | redisOptions 38 | ), 39 | src: new RedisCache( 40 | { 41 | ...defaultCacheOptions, 42 | dbLocation: `${contractTxId}.src`, 43 | }, 44 | redisOptions 45 | ), 46 | kvFactory: (contractTxId: string) => 47 | new RedisCache( 48 | { 49 | ...defaultCacheOptions, 50 | dbLocation: `${contractTxId}.kv`, 51 | }, 52 | redisOptions 53 | ), 54 | }; 55 | } 56 | 57 | /** Creates a Warp instance connected to mainnet. */ 58 | export function makeWarp(caches: CacheTypes): Warp { 59 | return WarpFactory.forMainnet() 60 | .useStateCache(caches.state) 61 | .useContractCache(caches.contract, caches.src) 62 | .useKVStorageFactory(caches.kvFactory); 63 | } 64 | -------------------------------------------------------------------------------- /hollowdb/src/clients/rocksdb.ts: -------------------------------------------------------------------------------- 1 | import Levelup from "levelup"; 2 | import Rocksdb from "rocksdb"; 3 | import { toValueKey } from "../utilities/download"; 4 | import { existsSync, mkdirSync } from "fs"; 5 | 6 | export class RocksdbClient { 7 | db: ReturnType; 8 | contractTxId: string; 9 | 10 | constructor(path: string, contractTxId: string) { 11 | if (!existsSync(path)) { 12 | mkdirSync(path, { recursive: true }); 13 | } 14 | 15 | this.db = Levelup(Rocksdb(path)); 16 | this.contractTxId = contractTxId; 17 | } 18 | 19 | async close() { 20 | if (!this.db.isClosed()) { 21 | await this.db.close(); 22 | } 23 | } 24 | 25 | async open() { 26 | if (!this.db.isOpen()) { 27 | await this.db.open(); 28 | } 29 | } 30 | 31 | async get(key: string) { 32 | const value = await this.db.get(toValueKey(this.contractTxId, key)); 33 | return this.tryParse(value); 34 | } 35 | 36 | async getMany(keys: string[]) { 37 | const values = await this.db.getMany(keys.map((k) => toValueKey(this.contractTxId, k))); 38 | return values.map((v) => this.tryParse(v)); 39 | } 40 | 41 | async set(key: string, value: string) { 42 | await this.db.put(toValueKey(this.contractTxId, key), value); 43 | } 44 | 45 | async setMany(pairs: [string, string][]) { 46 | await this.db.batch( 47 | pairs.map(([key, value]) => ({ 48 | type: "put", 49 | key: toValueKey(this.contractTxId, key), 50 | value: value, 51 | })) 52 | ); 53 | } 54 | 55 | async remove(key: string) { 56 | await this.db.del(toValueKey(this.contractTxId, key)); 57 | } 58 | 59 | async removeMany(keys: string[]) { 60 | await this.db.batch(keys.map((key) => ({ type: "del", key: toValueKey(this.contractTxId, key) }))); 61 | } 62 | 63 | /** 64 | * Given a value, tries to `JSON.parse` it and if parsing fails 65 | * it will return the value as-is. 
66 | * 67 | * This is particularly useful when there are some stringified values, 68 | * and some other non-stringified strings together. 69 | * 70 | * @param value a stringified or null value 71 | * @template V type of the expected value 72 | * @returns parsed value or `null` 73 | */ 74 | private tryParse(value: Rocksdb.Bytes | null): V | null { 75 | let result = null; 76 | 77 | if (value) { 78 | try { 79 | result = JSON.parse(value.toString()); 80 | } catch (err) { 81 | // FIXME: return null here? 82 | result = value; 83 | } 84 | } 85 | 86 | return result as V; 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /hollowdb/src/configurations/index.ts: -------------------------------------------------------------------------------- 1 | import type { LogLevel } from "fastify"; 2 | import type { LoggerFactory } from "warp-contracts"; 3 | 4 | export default { 5 | /** Port that is listened by HollowDB. */ 6 | PORT: parseInt(process.env.PORT ?? "3030"), 7 | /** Redis URL to connect to. Defaults to `redis://default:redispw@localhost:6379`. */ 8 | REDIS_URL: process.env.REDIS_URL ?? "redis://default:redispw@localhost:6379", 9 | /** Path to Rocksdb storage. */ 10 | ROCKSDB_PATH: process.env.ROCKSDB_PATH ?? "./data/values", 11 | /** Treat values as Bundlr txIds. */ 12 | USE_BUNDLR: process.env.USE_BUNDLR ? process.env.USE_BUNDLR === "true" : process.env.NODE_ENV !== "test", 13 | /** Use the optimized [`hash.txid`](https://github.com/firstbatchxyz/hollowdb/blob/master/src/contracts/hollowdb-htx.contract.ts) layout for values. */ 14 | USE_HTX: process.env.USE_HTX ? process.env.USE_HTX === "true" : process.env.NODE_ENV !== "test", 15 | /** Log level for underlying Warp. */ 16 | WARP_LOG_LEVEL: (process.env.WARP_LOG_LEVEL ?? "info") as Parameters[0], 17 | /** How many fetches at once should be made to download Bundlr data? FBS stands for "Fetch Batch Size". */ 18 | BUNDLR_FBS: parseInt(process.env.BUNDLR_FBS ?? "40"), 19 | /** Configurations for Bundlr downloads. */ 20 | DOWNLOAD: { 21 | /** Download URL for the bundled data. */ 22 | BASE_URL: "https://arweave.net", 23 | /** Max allowed timeout (milliseconds). */ 24 | TIMEOUT: 50_000, 25 | /** Max attempts to retry on caught errors. */ 26 | MAX_ATTEMPTS: 5, 27 | /** Time to sleep (ms) between each attempt. */ 28 | ATTEMPT_SLEEP: 1000, 29 | }, 30 | /** Logging stuff for the server. 
*/ 31 | LOG: { 32 | LEVEL: "info" satisfies LogLevel as LogLevel, // for some reason, :LogLevel doesnt work well 33 | REDACT: [ 34 | "reqId", 35 | "res.remoteAddress", 36 | "res.remotePort", 37 | "res.hostname", 38 | "req.remoteAddress", 39 | "req.remotePort", 40 | "req.hostname", 41 | ] as string[], 42 | }, 43 | } as const; 44 | -------------------------------------------------------------------------------- /hollowdb/src/controllers/read.ts: -------------------------------------------------------------------------------- 1 | import type { RouteHandler } from "fastify"; 2 | import type { Get, GetMany } from "../schemas"; 3 | import { RocksdbClient } from "../clients/rocksdb"; 4 | 5 | export const get: RouteHandler<{ Body: Get }> = async ({ server, body }) => { 6 | const value = await server.hollowdb.get(body.key); 7 | return { value }; 8 | }; 9 | 10 | export const getRaw: RouteHandler<{ Body: Get }> = async ({ server, body }) => { 11 | const rocksdb = new RocksdbClient(server.rocksdbPath, server.hollowdb.contractTxId); 12 | 13 | await rocksdb.open(); 14 | const value = await rocksdb.get(body.key); 15 | await rocksdb.close(); 16 | 17 | return { value }; 18 | }; 19 | 20 | export const getMany: RouteHandler<{ Body: GetMany }> = async ({ server, body }) => { 21 | const values = await server.hollowdb.getMany(body.keys); 22 | return { values }; 23 | }; 24 | 25 | export const getManyRaw: RouteHandler<{ Body: GetMany }> = async ({ server, body }) => { 26 | const rocksdb = new RocksdbClient(server.rocksdbPath, server.hollowdb.contractTxId); 27 | await rocksdb.open(); 28 | const values = await rocksdb.getMany(body.keys); 29 | await rocksdb.close(); 30 | 31 | return { values }; 32 | }; 33 | 34 | export const state: RouteHandler = async ({ server }) => { 35 | return await server.hollowdb.getState(); 36 | }; 37 | -------------------------------------------------------------------------------- /hollowdb/src/controllers/values.ts: -------------------------------------------------------------------------------- 1 | import { RouteHandler } from "fastify"; 2 | import type { Clear } from "../schemas"; 3 | import type { Redis } from "ioredis"; 4 | import { lastPossibleSortKey } from "warp-contracts"; 5 | import { RocksdbClient } from "../clients/rocksdb"; 6 | import { toSortKeyKey } from "../utilities/download"; 7 | import { refreshKeys } from "../utilities/refresh"; 8 | 9 | export const refresh: RouteHandler = async ({ server }, reply) => { 10 | const numKeysRefreshed = await refreshKeys(server); 11 | return reply.send(numKeysRefreshed); 12 | }; 13 | 14 | export const clear: RouteHandler<{ Body: Clear }> = async ({ server, body }, reply) => { 15 | const kv = server.hollowdb.base.warp.kvStorageFactory(server.hollowdb.contractTxId); 16 | const redis = kv.storage(); 17 | 18 | // get all existing keys (without sortKey) 19 | const keys = body.keys ?? 
(await kv.keys(lastPossibleSortKey)); 20 | 21 | // delete the sortKey mappings from Redis 22 | await redis.del(...keys.map((key) => toSortKeyKey(server.hollowdb.contractTxId, key))); 23 | 24 | // delete the values from Rocksdb 25 | const rocksdb = new RocksdbClient(server.rocksdbPath, server.hollowdb.contractTxId); 26 | await rocksdb.open(); 27 | await rocksdb.removeMany(keys); 28 | await rocksdb.close(); 29 | 30 | return reply.send(); 31 | }; 32 | -------------------------------------------------------------------------------- /hollowdb/src/controllers/write.ts: -------------------------------------------------------------------------------- 1 | import type { RouteHandler } from "fastify"; 2 | import type { Put, PutMany, Remove, Set, SetMany, Update } from "../schemas"; 3 | 4 | export const put: RouteHandler<{ Body: Put }> = async ({ server, body }, reply) => { 5 | await server.hollowdb.put(body.key, body.value); 6 | return reply.send(); 7 | }; 8 | 9 | export const putMany: RouteHandler<{ Body: PutMany }> = async ({ server, body }, reply) => { 10 | await server.hollowdb.putMany(body.keys, body.values); 11 | return reply.send(); 12 | }; 13 | 14 | export const set: RouteHandler<{ Body: Set }> = async ({ server, body }, reply) => { 15 | await server.hollowdb.set(body.key, body.value); 16 | return reply.send(); 17 | }; 18 | 19 | export const setMany: RouteHandler<{ Body: SetMany }> = async ({ server, body }, reply) => { 20 | await server.hollowdb.setMany(body.keys, body.values); 21 | return reply.send(); 22 | }; 23 | 24 | export const update: RouteHandler<{ Body: Update }> = async ({ server, body }, reply) => { 25 | await server.hollowdb.update(body.key, body.value, body.proof); 26 | return reply.send(); 27 | }; 28 | 29 | export const remove: RouteHandler<{ Body: Remove }> = async ({ server, body }, reply) => { 30 | await server.hollowdb.remove(body.key, body.proof); 31 | return reply.send(); 32 | }; 33 | -------------------------------------------------------------------------------- /hollowdb/src/global.d.ts: -------------------------------------------------------------------------------- 1 | import fastify from "fastify"; 2 | import { SetSDK } from "hollowdb"; 3 | import http from "http"; 4 | 5 | declare module "fastify" { 6 | export interface FastifyInstance< 7 | HttpServer = http.Server, 8 | HttpRequest = http.IncomingMessage, 9 | HttpResponse = http.ServerResponse 10 | > { 11 | /** HollowDB decorator. */ 12 | hollowdb: SetSDK; 13 | /** Contract ID. */ 14 | contractTxId: string; 15 | /** RocksDB Path. 
*/ 16 | rocksdbPath: string; 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /hollowdb/src/index.ts: -------------------------------------------------------------------------------- 1 | import { SetSDK } from "hollowdb"; 2 | import { Redis } from "ioredis"; 3 | import { makeServer } from "./server"; 4 | import configurations from "./configurations"; 5 | import { createCaches, makeWarp } from "./clients/hollowdb"; 6 | 7 | async function main() { 8 | const contractId = process.env.CONTRACT_ID; 9 | if (!contractId) { 10 | throw new Error("Please provide CONTRACT_ID environment variable."); 11 | } 12 | if (Buffer.from(contractId, "base64").toString("hex").length !== 64) { 13 | throw new Error("Invalid CONTRACT_ID."); 14 | } 15 | 16 | // ping redis to make sure connection is there before moving on 17 | const redisClient = new Redis(configurations.REDIS_URL); 18 | await redisClient.ping(); 19 | 20 | // create Redis caches & use them for Warp 21 | const caches = createCaches(contractId, redisClient); 22 | const warp = makeWarp(caches); 23 | 24 | // create a random wallet, which is ok since we only make read operations 25 | // TODO: or, we can use a dummy throw-away wallet every time? 26 | const wallet = await warp.generateWallet(); 27 | 28 | const hollowdb = new SetSDK(wallet.jwk, contractId, warp); 29 | const server = await makeServer(hollowdb, configurations.ROCKSDB_PATH); 30 | const addr = await server.listen({ 31 | port: configurations.PORT, 32 | // host is set to listen on all interfaces to allow Docker internal network to work 33 | // see: https://fastify.dev/docs/latest/Reference/Server/#listentextresolver 34 | host: "::", 35 | listenTextResolver: (address) => { 36 | return `HollowDB is listening at ${address}`; 37 | }, 38 | }); 39 | server.log.info(`Listening at: ${addr}`); 40 | } 41 | 42 | main(); 43 | -------------------------------------------------------------------------------- /hollowdb/src/schemas/index.ts: -------------------------------------------------------------------------------- 1 | import { Static, Type } from "@sinclair/typebox"; 2 | 3 | export const Get = Type.Object({ 4 | key: Type.String(), 5 | }); 6 | export type Get = Static; 7 | 8 | export const GetMany = Type.Object({ 9 | keys: Type.Array(Type.String()), 10 | }); 11 | export type GetMany = Static; 12 | 13 | export const Put = Type.Object({ 14 | key: Type.String(), 15 | value: Type.Any(), 16 | }); 17 | export type Put = Static; 18 | 19 | export const PutMany = Type.Object({ 20 | keys: Type.Array(Type.String()), 21 | values: Type.Array(Type.Any()), 22 | }); 23 | export type PutMany = Static; 24 | 25 | export const Remove = Type.Object({ 26 | key: Type.String(), 27 | proof: Type.Optional(Type.Any()), 28 | }); 29 | export type Remove = Static; 30 | 31 | export const Update = Type.Object({ 32 | key: Type.String(), 33 | value: Type.Any(), 34 | proof: Type.Optional(Type.Any()), 35 | }); 36 | export type Update = Static; 37 | 38 | export const Set = Type.Object({ 39 | key: Type.String(), 40 | value: Type.Any(), 41 | }); 42 | export type Set = Static; 43 | 44 | export const SetMany = Type.Object({ 45 | keys: Type.Array(Type.String()), 46 | values: Type.Array(Type.Any()), 47 | }); 48 | export type SetMany = Static; 49 | 50 | export const Clear = Type.Object({ 51 | keys: Type.Optional(Type.Array(Type.String())), 52 | }); 53 | export type Clear = Static; 54 | -------------------------------------------------------------------------------- /hollowdb/src/server.ts: 
-------------------------------------------------------------------------------- 1 | import fastify, { LogLevel } from "fastify"; 2 | import type { TypeBoxTypeProvider } from "@fastify/type-provider-typebox"; 3 | import { get, getMany, getManyRaw, getRaw, state } from "./controllers/read"; 4 | import { put, putMany, remove, set, setMany, update } from "./controllers/write"; 5 | import { clear, refresh } from "./controllers/values"; 6 | import { Clear, Get, GetMany, Put, PutMany, Remove, Set, SetMany, Update } from "./schemas"; 7 | import { SetSDK } from "hollowdb"; 8 | import { LoggerFactory } from "warp-contracts"; 9 | import configurations from "./configurations"; 10 | import { refreshKeys } from "./utilities/refresh"; 11 | import { Redis } from "ioredis"; 12 | 13 | export async function makeServer(hollowdb: SetSDK, rocksdbPath: string) { 14 | const server = fastify({ 15 | logger: { 16 | level: configurations.LOG.LEVEL, 17 | transport: { target: "pino-pretty" }, 18 | redact: { 19 | paths: configurations.LOG.REDACT, 20 | remove: true, 21 | }, 22 | }, 23 | }).withTypeProvider(); 24 | LoggerFactory.INST.logLevel(configurations.LOG.LEVEL === "silent" ? "none" : configurations.LOG.LEVEL); 25 | 26 | server.decorate("hollowdb", hollowdb); 27 | server.decorate("contractTxId", hollowdb.contractTxId); 28 | server.decorate("rocksdbPath", rocksdbPath); // TODO: store RocksDB itself here maybe? 29 | 30 | // check redis 31 | await server.hollowdb.base.warp.kvStorageFactory(server.hollowdb.contractTxId).storage().ping(); 32 | 33 | // refresh keys 34 | server.log.info("Waiting for cache to be loaded."); 35 | const numKeysRefreshed = await refreshKeys(server); 36 | 37 | server.log.info(`Server synced & ready! (${numKeysRefreshed} keys refreshed)`); 38 | server.log.info(`> Redis: ${configurations.REDIS_URL}`); 39 | server.log.info(`> RocksDB: ${configurations.ROCKSDB_PATH}`); 40 | server.log.info(`> Download URL: ${configurations.DOWNLOAD.BASE_URL} (timeout ${configurations.DOWNLOAD.TIMEOUT})`); 41 | server.log.info(`> Contract: https://sonar.warp.cc/#/app/contract/${server.contractTxId}`); 42 | 43 | server.get("/state", state); 44 | server.post("/get", { schema: { body: Get } }, get); 45 | server.post("/getRaw", { schema: { body: Get } }, getRaw); 46 | server.post("/getMany", { schema: { body: GetMany } }, getMany); 47 | server.post("/getManyRaw", { schema: { body: GetMany } }, getManyRaw); 48 | server.post("/put", { schema: { body: Put } }, put); 49 | server.post("/putMany", { schema: { body: PutMany } }, putMany); 50 | server.post("/set", { schema: { body: Set } }, set); 51 | server.post("/setMany", { schema: { body: SetMany } }, setMany); 52 | server.post("/update", { schema: { body: Update } }, update); 53 | server.post("/remove", { schema: { body: Remove } }, remove); 54 | server.post("/clear", { schema: { body: Clear } }, clear); 55 | server.post("/refresh", refresh); 56 | 57 | server.addHook("onError", (request, reply, error, done) => { 58 | if (error.message.startsWith("Contract Error")) { 59 | reply.status(400); 60 | } 61 | done(); 62 | }); 63 | 64 | return server; 65 | } 66 | -------------------------------------------------------------------------------- /hollowdb/src/types/index.ts: -------------------------------------------------------------------------------- 1 | import type { SortKeyCacheResult } from "warp-contracts"; 2 | import type { RedisCache } from "warp-contracts-redis"; 3 | 4 | /** Cache types used by `Warp`. 
*/ 5 | export type CacheTypes = { 6 | state: C; 7 | contract: C; 8 | src: C; 9 | kvFactory: (contractTxId: string) => C; 10 | }; 11 | 12 | /** 13 | * A `SortKeyCacheResult` with its respective `key` attached to it. 14 | * 15 | * - `key` is accessed as `.key` 16 | * - `sortKey` is accessed as `.sortKeyCacheResult.sortKey` 17 | * - `value` is accessed as `.sortKeyCacheResult.value` 18 | */ 19 | export type KeyedSortKeyCacheResult = { sortKeyCacheResult: SortKeyCacheResult; key: string }; 20 | -------------------------------------------------------------------------------- /hollowdb/src/utilities/download.ts: -------------------------------------------------------------------------------- 1 | import axios from "axios"; 2 | import configurations from "../configurations"; 3 | import { FastifyBaseLogger } from "fastify"; 4 | import { sleep } from "warp-contracts"; 5 | 6 | /** Downloads an object from Bundlr network w.r.t transaction id. 7 | * 8 | * If USE_HTX is enabled, it means that values are stored as `hash.txId` (as a string), 9 | * so to get the txid we split by `.` and then get the second element. 10 | * 11 | * @param txid transaction ID on Arweave 12 | * @template V type of the value 13 | * @returns unbundled raw value 14 | */ 15 | export async function downloadFromBundlr(txid: string, log: FastifyBaseLogger) { 16 | if (configurations.USE_HTX) { 17 | const split = txid.split("."); 18 | if (split.length !== 2) { 19 | log.warn("Expected value to be a hash.txid, received: " + txid); 20 | } 21 | txid = split[1]; 22 | } 23 | 24 | const url = `${configurations.DOWNLOAD.BASE_URL}/${txid}`; 25 | 26 | // try for a few times 27 | for (let a = 0; a <= configurations.DOWNLOAD.MAX_ATTEMPTS; ++a) { 28 | try { 29 | const response = await axios.get(url, { 30 | timeout: configurations.DOWNLOAD.TIMEOUT, 31 | }); 32 | if (response.status !== 200) { 33 | throw new Error(`Bundlr failed with ${response.status}`); 34 | } 35 | return response.data as V; 36 | } catch (err) { 37 | if (a === configurations.DOWNLOAD.MAX_ATTEMPTS) { 38 | throw err; 39 | } else { 40 | log.warn( 41 | `(tries: ${a + 1}/${configurations.DOWNLOAD.MAX_ATTEMPTS})` + 42 | `\tError downloading ${url}: ${(err as Error).message}` 43 | ); 44 | await sleep(configurations.DOWNLOAD.ATTEMPT_SLEEP); 45 | } 46 | } 47 | } 48 | 49 | throw new Error("All attempts failed."); 50 | } 51 | 52 | /** Returns a pretty string about the current download progress. 53 | * @param cur current value, can be more than `max` 54 | * @param max maximum value 55 | * @param decimals (optional) number of decimals for the percentage (default: 2) 56 | * @returns progress description 57 | */ 58 | export function progressString(cur: number, max: number, decimals: number = 2) { 59 | const val = Math.min(cur, max); 60 | const percentage = (val / max) * 100; 61 | return `[${val} / ${max}] (${percentage.toFixed(decimals)}%)`; 62 | } 63 | 64 | /** 65 | * Map a given key to a value key. 66 | * @param contractTxId contract txID 67 | * @param key key 68 | * @returns value key 69 | */ 70 | export function toValueKey(contractTxId: string, key: string) { 71 | return `${contractTxId}.value.${key}`; 72 | } 73 | 74 | /** 75 | * Map a given key to a sortKey key. 
76 | * @param contractTxId contract txID 77 | * @param key key 78 | * @returns sortKey key 79 | */ 80 | export function toSortKeyKey(contractTxId: string, key: string) { 81 | return `${contractTxId}.sortKey.${key}`; 82 | } 83 | -------------------------------------------------------------------------------- /hollowdb/src/utilities/refresh.ts: -------------------------------------------------------------------------------- 1 | import type { FastifyInstance } from "fastify"; 2 | import type { Redis } from "ioredis"; 3 | import { lastPossibleSortKey } from "warp-contracts"; 4 | import type { KeyedSortKeyCacheResult } from "../types"; 5 | import { downloadFromBundlr, progressString, toSortKeyKey } from "./download"; 6 | import { RocksdbClient } from "../clients/rocksdb"; 7 | import configurations from "../configurations"; 8 | 9 | /** 10 | * Refresh keys. 11 | * @param server HollowDB server 12 | * @returns number of refreshed keys 13 | */ 14 | export async function refreshKeys(server: FastifyInstance): Promise { 15 | server.log.info(`\nRefreshing keys (${server.hollowdb.contractTxId})\n`); 16 | await server.hollowdb.base.readState(); // get to the latest state 17 | 18 | const kv = server.hollowdb.base.warp.kvStorageFactory(server.hollowdb.contractTxId); 19 | const redis = kv.storage(); 20 | 21 | // get all keys 22 | const keys = await kv.keys(lastPossibleSortKey); 23 | 24 | // return early if there are no keys 25 | if (keys.length === 0) { 26 | server.log.info("All keys are up-to-date."); 27 | return 0; 28 | } 29 | 30 | // get the last sortKey for each key 31 | const sortKeyCacheResults = await Promise.all(keys.map((key) => kv.getLast(key))); 32 | 33 | // from these values, get the ones that are out-of-date (i.e. stale) 34 | const latestSortKeys: (string | null)[] = sortKeyCacheResults.map((skcr) => (skcr ? skcr.sortKey : null)); 35 | const existingSortKeys: (string | null)[] = await redis.mget( 36 | ...keys.map((key) => toSortKeyKey(server.contractTxId, key)) 37 | ); 38 | const staleResults: KeyedSortKeyCacheResult[] = sortKeyCacheResults 39 | .map((skcr, i) => 40 | // filter out existing sortKeys 41 | // also store the respective `key` with the result 42 | latestSortKeys[i] !== existingSortKeys[i] ? { sortKeyCacheResult: skcr, key: keys[i] } : null 43 | ) 44 | // this filter will filter out both existing null values, and matching sortKeys 45 | .filter((res): res is KeyedSortKeyCacheResult => res !== null); 46 | 47 | // return early if everything is up-to-date 48 | if (staleResults.length === 0) { 49 | return 0; 50 | } 51 | 52 | const rocksdb = new RocksdbClient(server.rocksdbPath, server.contractTxId); 53 | await rocksdb.open(); 54 | 55 | const refreshValues = async (results: KeyedSortKeyCacheResult[], values?: unknown[]) => { 56 | if (values && values.length !== results.length) { 57 | throw new Error("array length mismatch"); 58 | } 59 | 60 | // create [key, value] pairs for with stringified values 61 | const valuePairs = results.map(({ key, sortKeyCacheResult: { cachedValue } }, i) => { 62 | const val = values 63 | ? // override with given value 64 | typeof values[i] === "string" 65 | ? 
(values[i] as string) 66 | : JSON.stringify(values[i]) 67 | : // use own value 68 | JSON.stringify(cachedValue); 69 | return [key, val] as [string, string]; 70 | }); 71 | 72 | // write values to disk (as they may be too much for the memory) 73 | await rocksdb.setMany(valuePairs); 74 | 75 | // store the `sortKey`s for later refreshes to see if a `value` is stale 76 | const sortKeyPairs = results.map( 77 | ({ key, sortKeyCacheResult: { sortKey } }) => 78 | [toSortKeyKey(server.contractTxId, key), sortKey] as [string, string] 79 | ); 80 | await redis.mset(...sortKeyPairs.flat()); 81 | }; 82 | 83 | // update values in Redis 84 | 85 | const { USE_BUNDLR, BUNDLR_FBS } = configurations; 86 | if (USE_BUNDLR) { 87 | const progress: [number, number] = [0, 0]; 88 | 89 | server.log.info("Starting batched Bundlr downloads:"); 90 | progress[1] = staleResults.length; 91 | for (let b = 0; b < staleResults.length; b += BUNDLR_FBS) { 92 | const batchResults = staleResults.slice(b, b + BUNDLR_FBS); 93 | 94 | progress[0] = Math.min(b + BUNDLR_FBS, staleResults.length); 95 | 96 | const startTime = performance.now(); 97 | const batchValues = await Promise.all( 98 | batchResults.map((result) => 99 | downloadFromBundlr<{ data: any }>(result.sortKeyCacheResult.cachedValue as string, server.log) 100 | ) 101 | ); 102 | const endTime = performance.now(); 103 | server.log.info( 104 | `${progressString(progress[0], progress[1])} values downloaded (${(endTime - startTime).toFixed(2)} ms)` 105 | ); 106 | 107 | await refreshValues( 108 | batchResults, 109 | // our Bundlr service uploads as "{data: payload}" so we parse it here 110 | batchValues.map((val) => val.data) 111 | ); 112 | } 113 | server.log.info("Downloaded & refreshed all stale values."); 114 | } else { 115 | await refreshValues(staleResults); 116 | } 117 | 118 | await rocksdb.close(); 119 | 120 | return staleResults.length; 121 | } 122 | -------------------------------------------------------------------------------- /hollowdb/test/index.test.ts: -------------------------------------------------------------------------------- 1 | import ArLocal from "arlocal"; 2 | import { ArWallet, LoggerFactory, sleep } from "warp-contracts"; 3 | 4 | import { Redis } from "ioredis"; 5 | import { SetSDK } from "hollowdb"; 6 | 7 | import { makeServer } from "../src/server"; 8 | import config from "../src/configurations"; 9 | import { createCaches } from "../src/clients/hollowdb"; 10 | import { deploy, FetchClient, randomKeyValue, makeLocalWarp } from "./util"; 11 | import { Get, GetMany, Put, PutMany, Update } from "../src/schemas"; 12 | import { randomBytes } from "crypto"; 13 | import { rmSync } from "fs"; 14 | 15 | describe("crud operations", () => { 16 | let arlocal: ArLocal; 17 | let redisClient: Redis; 18 | let client: FetchClient; 19 | let url: string; 20 | 21 | const DATA_PATH = "./test/data"; 22 | const ARWEAVE_PORT = 3169; 23 | const VALUE = randomBytes(16).toString("hex"); 24 | const NEW_VALUE = randomBytes(16).toString("hex"); 25 | const KEY = randomBytes(16).toString("hex"); 26 | 27 | beforeAll(async () => { 28 | console.log("Starting..."); 29 | 30 | // create a local Arweave instance 31 | arlocal = new ArLocal(ARWEAVE_PORT, false); 32 | await arlocal.start(); 33 | 34 | // deploy a contract locally and generate a wallet 35 | redisClient = new Redis(config.REDIS_URL, { lazyConnect: false }); 36 | let caches = createCaches("testing-setup", redisClient); 37 | let warp = makeLocalWarp(ARWEAVE_PORT, caches); 38 | const owner: ArWallet = (await 
warp.generateWallet()).jwk; 39 | const { contractTxId } = await deploy(owner, warp); 40 | 41 | // start the server & connect to the contract 42 | caches = createCaches(contractTxId, redisClient); 43 | warp = makeLocalWarp(ARWEAVE_PORT, caches); 44 | const hollowdb = new SetSDK(owner, contractTxId, warp); 45 | const server = await makeServer(hollowdb, `${DATA_PATH}/${contractTxId}`); 46 | url = await server.listen({ port: config.PORT }); 47 | LoggerFactory.INST.logLevel("none"); 48 | 49 | client = new FetchClient(url); 50 | 51 | // TODO: wait a bit due to state syncing 52 | console.log("waiting a bit for the server to be ready..."); 53 | await sleep(1200); 54 | console.log("done"); 55 | }); 56 | 57 | describe("basic CRUD", () => { 58 | it("should put & get a value", async () => { 59 | const putResponse = await client.post("/put", { key: KEY, value: VALUE }); 60 | expect(putResponse.status).toBe(200); 61 | 62 | const getResponse = await client.post("/get", { key: KEY }); 63 | expect(getResponse.status).toBe(200); 64 | expect(await getResponse.json().then((body) => body.value)).toBe(VALUE); 65 | }); 66 | 67 | it("should NOT put to an existing key", async () => { 68 | const putResponse = await client.post("/put", { key: KEY, value: VALUE }); 69 | expect(putResponse.status).toBe(400); 70 | const body = await putResponse.json(); 71 | expect(body.message).toBe("Contract Error [put]: Key already exists."); 72 | }); 73 | 74 | it("should update & get the new value", async () => { 75 | const updateResponse = await client.post("/update", { key: KEY, value: NEW_VALUE }); 76 | expect(updateResponse.status).toBe(200); 77 | 78 | const getResponse = await client.post("/get", { key: KEY }); 79 | expect(getResponse.status).toBe(200); 80 | expect(await getResponse.json().then((body) => body.value)).toBe(NEW_VALUE); 81 | }); 82 | 83 | it("should remove the new value & get null", async () => { 84 | const removeResponse = await client.post("/remove", { 85 | key: KEY, 86 | }); 87 | expect(removeResponse.status).toBe(200); 88 | 89 | const getResponse = await client.post("/get", { key: KEY }); 90 | expect(getResponse.status).toBe(200); 91 | expect(await getResponse.json().then((body) => body.value)).toBe(null); 92 | }); 93 | }); 94 | 95 | describe("batch gets and puts", () => { 96 | const LENGTH = 10; 97 | const KEY_VALUES = Array.from({ length: LENGTH }, () => randomKeyValue({ numVals: 768 })); 98 | 99 | it("should put many values", async () => { 100 | const putResponse = await client.post("/putMany", { 101 | keys: KEY_VALUES.map((kv) => kv.key), 102 | values: KEY_VALUES.map((kv) => kv.value), 103 | }); 104 | expect(putResponse.status).toBe(200); 105 | }); 106 | 107 | it("should get many values", async () => { 108 | const getManyResponse = await client.post("/getMany", { 109 | keys: KEY_VALUES.map((kv) => kv.key), 110 | }); 111 | expect(getManyResponse.status).toBe(200); 112 | const body = await getManyResponse.json(); 113 | for (let i = 0; i < KEY_VALUES.length; i++) { 114 | const expected = KEY_VALUES[i].value; 115 | const result = body.values[i] as (typeof KEY_VALUES)[0]["value"]; 116 | expect(result.metadata.text).toBe(expected.metadata.text); 117 | expect(result.f).toBe(expected.f); 118 | } 119 | }); 120 | 121 | it("should refresh the cache for raw GET operations", async () => { 122 | const refreshResponse = await client.post("/refresh"); 123 | expect(refreshResponse.status).toBe(200); 124 | }); 125 | 126 | it("should do a raw GET", async () => { 127 | const { key, value } = KEY_VALUES[0]; 128 | 129 | const 
getRawResponse = await client.post("/getRaw", { key }); 130 | expect(getRawResponse.status).toBe(200); 131 | const body = await getRawResponse.json(); 132 | const result = body.value as typeof value; 133 | expect(result.metadata.text).toBe(value.metadata.text); 134 | expect(result.f).toBe(value.f); 135 | }); 136 | 137 | it("should do a raw GET many", async () => { 138 | const getManyRawResponse = await client.post("/getManyRaw", { 139 | keys: KEY_VALUES.map((kv) => kv.key), 140 | }); 141 | expect(getManyRawResponse.status).toBe(200); 142 | const body = await getManyRawResponse.json(); 143 | for (let i = 0; i < KEY_VALUES.length; i++) { 144 | const expected = KEY_VALUES[i].value; 145 | const result = body.values[i] as (typeof KEY_VALUES)[0]["value"]; 146 | expect(result.metadata.text).toBe(expected.metadata.text); 147 | expect(result.f).toBe(expected.f); 148 | } 149 | }); 150 | 151 | it("should refresh a newly PUT key", async () => { 152 | const { key, value } = randomKeyValue(); 153 | const putResponse = await client.post("/put", { key, value }); 154 | expect(putResponse.status).toBe(200); 155 | 156 | // we expect only 1 new key to be added via REFRESH 157 | const refreshResponse = await client.post("/refresh"); 158 | expect(refreshResponse.status).toBe(200); 159 | expect(await refreshResponse.text()).toBe("1"); 160 | }); 161 | 162 | it("should refresh with 0 keys when no additions are made", async () => { 163 | const refreshResponse = await client.post("/refresh"); 164 | expect(refreshResponse.status).toBe(200); 165 | expect(await refreshResponse.text()).toBe("0"); 166 | }); 167 | 168 | it("should clear all keys", async () => { 169 | // we expect 0 keys as nothing has changed since the last refresh 170 | const clearResponse = await client.post("/clear"); 171 | expect(clearResponse.status).toBe(200); 172 | 173 | const getManyRawResponse = await client.post("/getManyRaw", { 174 | keys: KEY_VALUES.map((kv) => kv.key), 175 | }); 176 | expect(getManyRawResponse.status).toBe(200); 177 | const body = await getManyRawResponse.json(); 178 | 179 | (body.values as (typeof KEY_VALUES)[0]["value"][]).forEach((val) => expect(val).toBe(null)); 180 | }); 181 | }); 182 | 183 | afterAll(async () => { 184 | console.log("waiting a bit before closing..."); 185 | await sleep(1500); 186 | 187 | rmSync(DATA_PATH, { recursive: true }); 188 | await arlocal.stop(); 189 | await redisClient.quit(); 190 | }); 191 | }); 192 | -------------------------------------------------------------------------------- /hollowdb/test/res/contractSource.ts: -------------------------------------------------------------------------------- 1 | // you can use any build under https://github.com/firstbatchxyz/hollowdb/tree/master/src/contracts/build 2 | export default ` 3 | // src/contracts/errors/index.ts 4 | var KeyExistsError = new ContractError("Key already exists."); 5 | var KeyNotExistsError = new ContractError("Key does not exist."); 6 | var CantEvolveError = new ContractError("Evolving is disabled."); 7 | var NoVerificationKeyError = new ContractError("No verification key."); 8 | var UnknownProtocolError = new ContractError("Unknown protocol."); 9 | var NotWhitelistedError = new ContractError("Not whitelisted."); 10 | var InvalidProofError = new ContractError("Invalid proof."); 11 | var ExpectedProofError = new ContractError("Expected a proof."); 12 | var NullValueError = new ContractError("Value cant be null, use remove instead."); 13 | var NotOwnerError = new ContractError("Not contract owner."); 14 | var InvalidFunctionError = 
new ContractError("Invalid function."); 15 | var ArrayLengthMismatchError = new ContractError("Key and value counts mismatch."); 16 | 17 | // src/contracts/utils/index.ts 18 | var verifyProof = async (proof, psignals, verificationKey) => { 19 | if (!verificationKey) { 20 | throw NoVerificationKeyError; 21 | } 22 | if (verificationKey.protocol !== "groth16" && verificationKey.protocol !== "plonk") { 23 | throw UnknownProtocolError; 24 | } 25 | return await SmartWeave.extensions[verificationKey.protocol].verify(verificationKey, psignals, proof); 26 | }; 27 | var hashToGroup = (value) => { 28 | if (value) { 29 | return BigInt(SmartWeave.extensions.ethers.utils.ripemd160(Buffer.from(JSON.stringify(value)))); 30 | } else { 31 | return BigInt(0); 32 | } 33 | }; 34 | 35 | // src/contracts/modifiers/index.ts 36 | var onlyOwner = (caller, input, state) => { 37 | if (caller !== state.owner) { 38 | throw NotOwnerError; 39 | } 40 | return input; 41 | }; 42 | var onlyNonNullValue = (_, input) => { 43 | if (input.value === null) { 44 | throw NullValueError; 45 | } 46 | return input; 47 | }; 48 | var onlyNonNullValues = (_, input) => { 49 | if (input.values.some((val) => val === null)) { 50 | throw NullValueError; 51 | } 52 | return input; 53 | }; 54 | var onlyWhitelisted = (list) => { 55 | return (caller, input, state) => { 56 | if (!state.isWhitelistRequired[list]) { 57 | return input; 58 | } 59 | if (!state.whitelists[list][caller]) { 60 | throw NotWhitelistedError; 61 | } 62 | return input; 63 | }; 64 | }; 65 | var onlyProofVerified = (proofName, prepareInputs) => { 66 | return async (caller, input, state) => { 67 | if (!state.isProofRequired[proofName]) { 68 | return input; 69 | } 70 | if (!input.proof) { 71 | throw ExpectedProofError; 72 | } 73 | const ok = await verifyProof( 74 | input.proof, 75 | await prepareInputs(caller, input, state), 76 | state.verificationKeys[proofName] 77 | ); 78 | if (!ok) { 79 | throw InvalidProofError; 80 | } 81 | return input; 82 | }; 83 | }; 84 | async function apply(caller, input, state, ...modifiers) { 85 | for (const modifier of modifiers) { 86 | input = await modifier(caller, input, state); 87 | } 88 | return input; 89 | } 90 | 91 | // src/contracts/hollowdb-set.contract.ts 92 | var handle = async (state, action) => { 93 | const { caller, input } = action; 94 | switch (input.function) { 95 | case "get": { 96 | const { key } = await apply(caller, input.value, state); 97 | return { result: await SmartWeave.kv.get(key) }; 98 | } 99 | case "getMany": { 100 | const { keys } = await apply(caller, input.value, state); 101 | const values = await Promise.all(keys.map((key) => SmartWeave.kv.get(key))); 102 | return { result: values }; 103 | } 104 | case "set": { 105 | const { key, value } = await apply(caller, input.value, state, onlyWhitelisted("set"), onlyNonNullValue); 106 | await SmartWeave.kv.put(key, value); 107 | return { state }; 108 | } 109 | case "setMany": { 110 | const { keys, values } = await apply(caller, input.value, state, onlyWhitelisted("set"), onlyNonNullValues); 111 | if (keys.length !== values.length) { 112 | throw new ContractError("Key and value counts mismatch"); 113 | } 114 | await Promise.all(keys.map((key, i) => SmartWeave.kv.put(key, values[i]))); 115 | return { state }; 116 | } 117 | case "getKeys": { 118 | const { options } = await apply(caller, input.value, state); 119 | return { result: await SmartWeave.kv.keys(options) }; 120 | } 121 | case "getKVMap": { 122 | const { options } = await apply(caller, input.value, state); 123 | return { 
result: await SmartWeave.kv.kvMap(options) }; 124 | } 125 | case "put": { 126 | const { key, value } = await apply(caller, input.value, state, onlyWhitelisted("put"), onlyNonNullValue); 127 | if (await SmartWeave.kv.get(key) !== null) { 128 | throw KeyExistsError; 129 | } 130 | await SmartWeave.kv.put(key, value); 131 | return { state }; 132 | } 133 | case "putMany": { 134 | const { keys, values } = await apply(caller, input.value, state, onlyWhitelisted("put"), onlyNonNullValues); 135 | if (keys.length !== values.length) { 136 | throw new ContractError("Key and value counts mismatch"); 137 | } 138 | if (await Promise.all(keys.map((key) => SmartWeave.kv.get(key))).then((values2) => values2.some((val) => val !== null))) { 139 | throw KeyExistsError; 140 | } 141 | await Promise.all(keys.map((key, i) => SmartWeave.kv.put(key, values[i]))); 142 | return { state }; 143 | } 144 | case "update": { 145 | const { key, value } = await apply( 146 | caller, 147 | input.value, 148 | state, 149 | onlyNonNullValue, 150 | onlyWhitelisted("update"), 151 | onlyProofVerified("auth", async (_, input2) => { 152 | const oldValue = await SmartWeave.kv.get(input2.key); 153 | return [hashToGroup(oldValue), hashToGroup(input2.value), BigInt(input2.key)]; 154 | }) 155 | ); 156 | await SmartWeave.kv.put(key, value); 157 | return { state }; 158 | } 159 | case "remove": { 160 | const { key } = await apply( 161 | caller, 162 | input.value, 163 | state, 164 | onlyWhitelisted("update"), 165 | onlyProofVerified("auth", async (_, input2) => { 166 | const oldValue = await SmartWeave.kv.get(input2.key); 167 | return [hashToGroup(oldValue), BigInt(0), BigInt(input2.key)]; 168 | }) 169 | ); 170 | await SmartWeave.kv.del(key); 171 | return { state }; 172 | } 173 | case "updateOwner": { 174 | const { newOwner } = await apply(caller, input.value, state, onlyOwner); 175 | state.owner = newOwner; 176 | return { state }; 177 | } 178 | case "updateProofRequirement": { 179 | const { name, value } = await apply(caller, input.value, state, onlyOwner); 180 | state.isProofRequired[name] = value; 181 | return { state }; 182 | } 183 | case "updateVerificationKey": { 184 | const { name, verificationKey } = await apply(caller, input.value, state, onlyOwner); 185 | state.verificationKeys[name] = verificationKey; 186 | return { state }; 187 | } 188 | case "updateWhitelistRequirement": { 189 | const { name, value } = await apply(caller, input.value, state, onlyOwner); 190 | state.isWhitelistRequired[name] = value; 191 | return { state }; 192 | } 193 | case "updateWhitelist": { 194 | const { add, remove, name } = await apply(caller, input.value, state, onlyOwner); 195 | add.forEach((user) => { 196 | state.whitelists[name][user] = true; 197 | }); 198 | remove.forEach((user) => { 199 | delete state.whitelists[name][user]; 200 | }); 201 | return { state }; 202 | } 203 | case "evolve": { 204 | const srcTxId = await apply(caller, input.value, state, onlyOwner); 205 | if (!state.canEvolve) { 206 | throw CantEvolveError; 207 | } 208 | state.evolve = srcTxId; 209 | return { state }; 210 | } 211 | default: 212 | input; 213 | throw InvalidFunctionError; 214 | } 215 | }; 216 | ` as string; 217 | -------------------------------------------------------------------------------- /hollowdb/test/res/initialState.ts: -------------------------------------------------------------------------------- 1 | export default { 2 | owner: "", 3 | verificationKeys: { 4 | auth: null, 5 | }, 6 | isProofRequired: { 7 | auth: false, 8 | }, 9 | canEvolve: false, 10 | whitelist: { 
11 | put: {}, 12 | update: {}, 13 | set: {}, 14 | }, 15 | isWhitelistRequired: { 16 | put: false, 17 | update: false, 18 | set: false, 19 | }, 20 | } as const; 21 | -------------------------------------------------------------------------------- /hollowdb/test/util/index.ts: -------------------------------------------------------------------------------- 1 | import initialState from "../res/initialState"; 2 | import contractSource from "../res/contractSource"; 3 | import { randomUUID } from "crypto"; 4 | import { loremIpsum } from "lorem-ipsum"; 5 | import { JWKInterface, WarpFactory, Warp } from "warp-contracts"; 6 | import { DeployPlugin } from "warp-contracts-plugin-deploy"; 7 | import { CacheTypes } from "../../src/types"; 8 | /** 9 | * Returns the size of a given data in bytes. 10 | * - To convert to KBs: `size / (1 << 10)` 11 | * - To convert to MBs: `size / (1 << 20)` 12 | * @param data data, such as `JSON.stringify(body)` for a POST request. 13 | * @returns data size in bytes 14 | */ 15 | export function size(data: string) { 16 | return new Blob([data]).size; 17 | } 18 | 19 | /** A tiny API wrapper. */ 20 | export class FetchClient { 21 | constructor(readonly baseUrl: string) {} 22 | 23 | /** 24 | * Generic POST utility for HollowDB micro. Depending on the 25 | * request, call `response.json()` or `response.text()` to parse 26 | * the returned body. 27 | * @param url url 28 | * @param data body 29 | * @returns response object 30 | */ 31 | async post(url: string, data?: Body) { 32 | const body = JSON.stringify(data ?? {}); 33 | return fetch(this.baseUrl + url, { 34 | method: "POST", 35 | headers: { 36 | "Content-Type": "application/json; charset=utf-8", 37 | }, 38 | body, 39 | }); 40 | } 41 | } 42 | 43 | /** 44 | * Creates a local warp instance, also uses the `DeployPlugin`. 45 | * 46 | * WARNING: Do not use `useStateCache` and `useContractCache` together with `forLocal`. 47 | */ 48 | export function makeLocalWarp(port: number, caches: CacheTypes): Warp { 49 | return WarpFactory.forLocal(port).use(new DeployPlugin()).useKVStorageFactory(caches.kvFactory); 50 | } 51 | 52 | /** Returns a random key-value pair related to our internal usage. */ 53 | export function randomKeyValue(options?: { numVals?: number; numChildren?: number }): { 54 | key: string; 55 | value: { 56 | children: number[]; 57 | f: number; 58 | is_leaf: boolean; 59 | n_descendants: number; 60 | metadata: { 61 | text: string; 62 | }; 63 | v: number[]; 64 | }; 65 | } { 66 | const numChildren = options?.numChildren || Math.round(Math.random() * 5); 67 | const numVals = options?.numVals || Math.round(Math.random() * 500 + 100); 68 | 69 | return { 70 | key: randomUUID(), 71 | value: { 72 | children: Array.from({ length: numChildren }, () => Math.round(Math.random() * 50)), 73 | f: Math.round(Math.random() * 100), 74 | is_leaf: Math.random() < 0.5, 75 | metadata: { 76 | text: loremIpsum({ count: 4 }), 77 | }, 78 | n_descendants: Math.round(Math.random() * 50), 79 | v: Array.from({ length: numVals }, () => Math.random() * 2 - 1), 80 | }, 81 | }; 82 | } 83 | 84 | /** 85 | * Deploy a new contract via the provided Warp instance. 
86 | * @param owner owner wallet 87 | * @param warp a `Warp` instance 88 | * */ 89 | export async function deploy( 90 | owner: JWKInterface, 91 | warp: Warp 92 | ): Promise<{ contractTxId: string; srcTxId: string | undefined }> { 93 | if (warp.environment !== "local") { 94 | throw new Error("Expected a local Warp environment."); 95 | } 96 | 97 | const { contractTxId, srcTxId } = await warp.deploy( 98 | { 99 | wallet: owner, 100 | initState: JSON.stringify(initialState), 101 | src: contractSource, 102 | evaluationManifest: { 103 | evaluationOptions: { 104 | allowBigInt: true, 105 | useKVStorage: true, 106 | }, 107 | }, 108 | }, 109 | true // disable bundling in test environment 110 | ); 111 | 112 | return { contractTxId, srcTxId }; 113 | } 114 | -------------------------------------------------------------------------------- /hollowdb/tsconfig.build.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "./tsconfig.json", 3 | "compilerOptions": { 4 | "rootDir": "./src" 5 | }, 6 | "exclude": ["node_modules", "jest.config.ts", "test"] 7 | } 8 | -------------------------------------------------------------------------------- /hollowdb/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | /* Visit https://aka.ms/tsconfig to read more about this file */ 4 | 5 | /* Projects */ 6 | // "incremental": true, /* Save .tsbuildinfo files to allow for incremental compilation of projects. */ 7 | // "composite": true, /* Enable constraints that allow a TypeScript project to be used with project references. */ 8 | // "tsBuildInfoFile": "./.tsbuildinfo", /* Specify the path to .tsbuildinfo incremental compilation file. */ 9 | // "disableSourceOfProjectReferenceRedirect": true, /* Disable preferring source files instead of declaration files when referencing composite projects. */ 10 | // "disableSolutionSearching": true, /* Opt a project out of multi-project reference checking when editing. */ 11 | // "disableReferencedProjectLoad": true, /* Reduce the number of projects loaded automatically by TypeScript. */ 12 | 13 | /* Language and Environment */ 14 | "target": "ES2022" /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */, 15 | // "lib": [], /* Specify a set of bundled library declaration files that describe the target runtime environment. */ 16 | // "jsx": "preserve", /* Specify what JSX code is generated. */ 17 | // "experimentalDecorators": true, /* Enable experimental support for legacy experimental decorators. */ 18 | // "emitDecoratorMetadata": true, /* Emit design-type metadata for decorated declarations in source files. */ 19 | // "jsxFactory": "", /* Specify the JSX factory function used when targeting React JSX emit, e.g. 'React.createElement' or 'h'. */ 20 | // "jsxFragmentFactory": "", /* Specify the JSX Fragment reference used for fragments when targeting React JSX emit e.g. 'React.Fragment' or 'Fragment'. */ 21 | // "jsxImportSource": "", /* Specify module specifier used to import the JSX factory functions when using 'jsx: react-jsx*'. */ 22 | // "reactNamespace": "", /* Specify the object invoked for 'createElement'. This only applies when targeting 'react' JSX emit. */ 23 | // "noLib": true, /* Disable including any library files, including the default lib.d.ts. */ 24 | // "useDefineForClassFields": true, /* Emit ECMAScript-standard-compliant class fields. 
*/ 25 | // "moduleDetection": "auto", /* Control what method is used to detect module-format JS files. */ 26 | 27 | /* Modules */ 28 | "module": "commonjs" /* Specify what module code is generated. */, 29 | // "rootDir": "./src" /* Specify the root folder within your source files. */, 30 | // "moduleResolution": "node10", /* Specify how TypeScript looks up a file from a given module specifier. */ 31 | // "baseUrl": "./", /* Specify the base directory to resolve non-relative module names. */ 32 | // "paths": {}, /* Specify a set of entries that re-map imports to additional lookup locations. */ 33 | // "rootDirs": [], /* Allow multiple folders to be treated as one when resolving modules. */ 34 | // "typeRoots": [], /* Specify multiple folders that act like './node_modules/@types'. */ 35 | // "types": [], /* Specify type package names to be included without being referenced in a source file. */ 36 | // "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */ 37 | // "moduleSuffixes": [], /* List of file name suffixes to search when resolving a module. */ 38 | // "allowImportingTsExtensions": true, /* Allow imports to include TypeScript file extensions. Requires '--moduleResolution bundler' and either '--noEmit' or '--emitDeclarationOnly' to be set. */ 39 | // "resolvePackageJsonExports": true, /* Use the package.json 'exports' field when resolving package imports. */ 40 | // "resolvePackageJsonImports": true, /* Use the package.json 'imports' field when resolving imports. */ 41 | // "customConditions": [], /* Conditions to set in addition to the resolver-specific defaults when resolving imports. */ 42 | // "resolveJsonModule": true, /* Enable importing .json files. */ 43 | // "allowArbitraryExtensions": true, /* Enable importing files with any extension, provided a declaration file is present. */ 44 | // "noResolve": true, /* Disallow 'import's, 'require's or ''s from expanding the number of files TypeScript should add to a project. */ 45 | 46 | /* JavaScript Support */ 47 | // "allowJs": true, /* Allow JavaScript files to be a part of your program. Use the 'checkJS' option to get errors from these files. */ 48 | // "checkJs": true, /* Enable error reporting in type-checked JavaScript files. */ 49 | // "maxNodeModuleJsDepth": 1, /* Specify the maximum folder depth used for checking JavaScript files from 'node_modules'. Only applicable with 'allowJs'. */ 50 | 51 | /* Emit */ 52 | // "declaration": true, /* Generate .d.ts files from TypeScript and JavaScript files in your project. */ 53 | // "declarationMap": true, /* Create sourcemaps for d.ts files. */ 54 | // "emitDeclarationOnly": true, /* Only output d.ts files and not JavaScript files. */ 55 | // "sourceMap": true, /* Create source map files for emitted JavaScript files. */ 56 | // "inlineSourceMap": true, /* Include sourcemap files inside the emitted JavaScript. */ 57 | // "outFile": "./", /* Specify a file that bundles all outputs into one JavaScript file. If 'declaration' is true, also designates a file that bundles all .d.ts output. */ 58 | "outDir": "./build" /* Specify an output folder for all emitted files. */, 59 | // "removeComments": true, /* Disable emitting comments. */ 60 | // "noEmit": true, /* Disable emitting files from a compilation. */ 61 | // "importHelpers": true, /* Allow importing helper functions from tslib once per project, instead of including them per-file. */ 62 | // "importsNotUsedAsValues": "remove", /* Specify emit/checking behavior for imports that are only used for types. 
*/ 63 | // "downlevelIteration": true, /* Emit more compliant, but verbose and less performant JavaScript for iteration. */ 64 | // "sourceRoot": "", /* Specify the root path for debuggers to find the reference source code. */ 65 | // "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */ 66 | // "inlineSources": true, /* Include source code in the sourcemaps inside the emitted JavaScript. */ 67 | // "emitBOM": true, /* Emit a UTF-8 Byte Order Mark (BOM) in the beginning of output files. */ 68 | // "newLine": "crlf", /* Set the newline character for emitting files. */ 69 | // "stripInternal": true, /* Disable emitting declarations that have '@internal' in their JSDoc comments. */ 70 | // "noEmitHelpers": true, /* Disable generating custom helper functions like '__extends' in compiled output. */ 71 | // "noEmitOnError": true, /* Disable emitting files if any type checking errors are reported. */ 72 | // "preserveConstEnums": true, /* Disable erasing 'const enum' declarations in generated code. */ 73 | // "declarationDir": "./", /* Specify the output directory for generated declaration files. */ 74 | // "preserveValueImports": true, /* Preserve unused imported values in the JavaScript output that would otherwise be removed. */ 75 | 76 | /* Interop Constraints */ 77 | // "isolatedModules": true, /* Ensure that each file can be safely transpiled without relying on other imports. */ 78 | // "verbatimModuleSyntax": true, /* Do not transform or elide any imports or exports not marked as type-only, ensuring they are written in the output file's format based on the 'module' setting. */ 79 | // "allowSyntheticDefaultImports": true, /* Allow 'import x from y' when a module doesn't have a default export. */ 80 | "esModuleInterop": true /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */, 81 | // "preserveSymlinks": true, /* Disable resolving symlinks to their realpath. This correlates to the same flag in node. */ 82 | "forceConsistentCasingInFileNames": true /* Ensure that casing is correct in imports. */, 83 | 84 | /* Type Checking */ 85 | "strict": true /* Enable all strict type-checking options. */, 86 | // "noImplicitAny": true, /* Enable error reporting for expressions and declarations with an implied 'any' type. */ 87 | // "strictNullChecks": true, /* When type checking, take into account 'null' and 'undefined'. */ 88 | // "strictFunctionTypes": true, /* When assigning functions, check to ensure parameters and the return values are subtype-compatible. */ 89 | // "strictBindCallApply": true, /* Check that the arguments for 'bind', 'call', and 'apply' methods match the original function. */ 90 | // "strictPropertyInitialization": true, /* Check for class properties that are declared but not set in the constructor. */ 91 | // "noImplicitThis": true, /* Enable error reporting when 'this' is given the type 'any'. */ 92 | // "useUnknownInCatchVariables": true, /* Default catch clause variables as 'unknown' instead of 'any'. */ 93 | // "alwaysStrict": true, /* Ensure 'use strict' is always emitted. */ 94 | // "noUnusedLocals": true, /* Enable error reporting when local variables aren't read. */ 95 | // "noUnusedParameters": true, /* Raise an error when a function parameter isn't read. */ 96 | // "exactOptionalPropertyTypes": true, /* Interpret optional property types as written, rather than adding 'undefined'. 
*/ 97 | // "noImplicitReturns": true, /* Enable error reporting for codepaths that do not explicitly return in a function. */ 98 | // "noFallthroughCasesInSwitch": true, /* Enable error reporting for fallthrough cases in switch statements. */ 99 | // "noUncheckedIndexedAccess": true, /* Add 'undefined' to a type when accessed using an index. */ 100 | // "noImplicitOverride": true, /* Ensure overriding members in derived classes are marked with an override modifier. */ 101 | // "noPropertyAccessFromIndexSignature": true, /* Enforces using indexed accessors for keys declared using an indexed type. */ 102 | // "allowUnusedLabels": true, /* Disable error reporting for unused labels. */ 103 | // "allowUnreachableCode": true, /* Disable error reporting for unreachable code. */ 104 | 105 | /* Completeness */ 106 | // "skipDefaultLibCheck": true, /* Skip type checking .d.ts files that are included with TypeScript. */ 107 | "skipLibCheck": true /* Skip type checking all .d.ts files. */ 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /hollowdb_wait/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM alpine 2 | WORKDIR /app 3 | 4 | RUN apk update 5 | RUN apk add curl 6 | 7 | COPY wait.sh /app/wait.sh 8 | RUN chmod +x wait.sh 9 | 10 | CMD ./wait.sh 11 | -------------------------------------------------------------------------------- /hollowdb_wait/README.md: -------------------------------------------------------------------------------- 1 | # HollowDB API `wait-for` 2 | 3 | This is a simple shell script container that another container can depend "on completion", such that the script will finish when HollowDB container finished downloading & refreshing keys. 4 | 5 | The containers are expected to launch in the following order: 6 | 7 | 1. **Redis**: This is the first container to launch. 8 | 2. **HollowDB**: This starts when Redis is live, and immediately begins to download values from Arweave & store them in memory for Dria to access efficiently. 9 | 3. **Dria HNSW**: Dria waits for the HollowDB API's download the complete via [this wait-for script](./wait.sh), and once that is complete; it launches & starts listening at it's port. 10 | 11 | The script is available on Docker Hub: 12 | 13 | ```sh 14 | docker pull firstbatch/dria-hollowdb-wait-for 15 | ``` 16 | 17 | ## Wait-For Script 18 | 19 | The script makes use of the following cURL command: 20 | 21 | ```sh 22 | curl -f -d '{"route": "STATE"}' $TARGET 23 | ``` 24 | 25 | If the cache is still loading, HollowDB will respond with status `503` and cURL will return a non-zero code, causing the shell script to wait for a while and try again later. The body of the response also contains the percentage of keys loaded, if you are to make the request yourself. 26 | -------------------------------------------------------------------------------- /hollowdb_wait/wait.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Connection parameters 4 | TARGET=${TARGET:-"http://localhost:3030"} 5 | 6 | # How long to sleep (seconds) before each attempt. 7 | SLEEPS=${SLEEPS:-2} 8 | 9 | echo "Polling $TARGET" 10 | 11 | while true; do 12 | curl --fail --silent "$TARGET/state" 13 | 14 | # check if exit code of curl is 0 15 | if [ $? -eq 0 ]; then 16 | echo "" 17 | echo "HollowDB API is ready!" 
18 | exit 0 19 | fi 20 | 21 | sleep "$SLEEPS" 22 | done 23 | --------------------------------------------------------------------------------
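As referenced in `hollowdb_wait/README.md`, other services are meant to gate their startup on this wait-for container exiting successfully. The snippet below is only a minimal sketch of that wiring: the service names (`hollowdb`, `hollowdb-wait`, `dria-hnsw`) are hypothetical and the authoritative definitions live in the repository's `compose.yaml`, but it shows how the `TARGET`/`SLEEPS` variables from `wait.sh` combine with a `service_completed_successfully` dependency condition:

```yaml
# Sketch only: service names are hypothetical; see compose.yaml for the real setup.
services:
  hollowdb-wait:
    image: firstbatch/dria-hollowdb-wait-for
    environment:
      - TARGET=http://hollowdb:3030 # endpoint polled by wait.sh
      - SLEEPS=5                    # seconds to sleep between polls
    depends_on:
      - hollowdb

  dria-hnsw:
    # Dria starts only after wait.sh exits with code 0,
    # i.e. after HollowDB has finished loading its keys.
    depends_on:
      hollowdb-wait:
        condition: service_completed_successfully
```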