├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── compose.yaml
├── dria_hnsw
│   ├── .dockerignore
│   ├── .gitignore
│   ├── Cargo.lock
│   ├── Cargo.toml
│   ├── Dockerfile
│   ├── Makefile
│   ├── README.md
│   └── src
│       ├── db
│       │   ├── conversions.rs
│       │   ├── env.rs
│       │   ├── mod.rs
│       │   ├── redis_client.rs
│       │   └── rocksdb_client.rs
│       ├── errors
│       │   ├── errors.rs
│       │   └── mod.rs
│       ├── filter
│       │   ├── mod.rs
│       │   └── text_based.rs
│       ├── hnsw
│       │   ├── index.rs
│       │   ├── mod.rs
│       │   ├── scalar.rs
│       │   ├── sync_map.rs
│       │   └── utils.rs
│       ├── lib.rs
│       ├── main.rs
│       ├── middlewares
│       │   ├── cache.rs
│       │   └── mod.rs
│       ├── models
│       │   ├── mod.rs
│       │   └── request_models.rs
│       ├── proto
│       │   ├── hnsw_comm.proto
│       │   ├── index.proto
│       │   ├── index_buffer.rs
│       │   ├── insert.proto
│       │   ├── insert_buffer.rs
│       │   ├── mod.rs
│       │   ├── request.proto
│       │   └── request_buffer.rs
│       ├── responses
│       │   ├── mod.rs
│       │   └── responses.rs
│       └── worker.rs
├── hollowdb
│   ├── .dockerignore
│   ├── .env.example
│   ├── .gitignore
│   ├── .yarnrc.yml
│   ├── Dockerfile
│   ├── README.md
│   ├── config
│   │   └── .gitignore
│   ├── jest.config.ts
│   ├── package.json
│   ├── src
│   │   ├── clients
│   │   │   ├── hollowdb.ts
│   │   │   └── rocksdb.ts
│   │   ├── configurations
│   │   │   └── index.ts
│   │   ├── controllers
│   │   │   ├── read.ts
│   │   │   ├── values.ts
│   │   │   └── write.ts
│   │   ├── global.d.ts
│   │   ├── index.ts
│   │   ├── schemas
│   │   │   └── index.ts
│   │   ├── server.ts
│   │   ├── types
│   │   │   └── index.ts
│   │   └── utilities
│   │       ├── download.ts
│   │       └── refresh.ts
│   ├── test
│   │   ├── index.test.ts
│   │   ├── res
│   │   │   ├── contractSource.ts
│   │   │   └── initialState.ts
│   │   └── util
│   │       └── index.ts
│   ├── tsconfig.build.json
│   ├── tsconfig.json
│   └── yarn.lock
└── hollowdb_wait
    ├── Dockerfile
    ├── README.md
    └── wait.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | wallet.json
3 | dump.rdb
4 | .vscode
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 FirstBatch
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 |
2 | ######### DRIA HNSW #########
3 | .PHONY: build-hnsw run-hnsw push-hnsw
4 |
5 | build-hnsw:
6 | docker build ./dria_hnsw -t dria-hnsw
7 |
8 | run-hnsw:
9 | docker run \
10 | -e CONTRACT_ID=WbcY2a-KfDpk7fsgumUtLC2bu4NQcVzNlXWi13fPMlU \
11 | -e REDIS_URL=redis://default:redispw@localhost:6379 \
12 | -p 8080:8080 \
13 | dria-hnsw
14 |
15 | push-hnsw:
16 | docker buildx build \
17 | --platform=linux/amd64,linux/arm64,linux/arm ./dria_hnsw \
18 | -t firstbatch/dria-hnsw:latest \
19 | --builder=dria-builder --push
20 |
21 | ######### HOLLOWDB ##########
22 | .PHONY: build-hollowdb run-hollowdb push-hollowdb
23 |
24 | build-hollowdb:
25 | docker build ./hollowdb -t dria-hollowdb
26 |
27 | run-hollowdb:
28 | docker run \
29 | -e CONTRACT_ID=WbcY2a-KfDpk7fsgumUtLC2bu4NQcVzNlXWi13fPMlU \
30 | -e REDIS_URL=redis://default:redispw@localhost:6379 \
31 | -p 3030:3030 \
32 | dria-hollowdb
33 |
34 | push-hollowdb:
35 | docker build ./hollowdb -t firstbatch/dria-hollowdb:latest --push
36 |
37 | ####### HOLLOWDB WAIT #######
38 | .PHONY: build-hollowdb-wait run-hollowdb-wait push-hollowdb-wait
39 |
40 | build-hollowdb-wait:
41 | docker build ./hollowdb_wait -t hollowdb-wait-for
42 |
43 | run-hollowdb-wait:
44 | docker run hollowdb-wait-for
45 |
46 | push-hollowdb-wait:
47 | docker build ./hollowdb_wait -t firstbatch/dria-hollowdb-wait-for:latest --push
48 |
49 | ########### REDIS ###########
50 | pull-redis:
51 | docker pull redis:alpine
52 |
53 | run-redis:
54 | docker run --name dria-redis -p 6379:6379 redis:alpine
55 |
56 | ########## BUILDER ##########
57 | .PHONY: dria-builder
58 |
59 | # see: https://docs.docker.com/build/building/multi-platform/#cross-compilation
60 | dria-builder:
61 | docker buildx create --name dria-builder --bootstrap --use
62 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | # Dria Docker
8 |
9 |
10 | Dria Docker is an all-in-one environment for using Dria, the collective knowledge for AI.
11 |
12 |
13 |
14 | ## Setup
15 |
16 | To use Dria Docker, you need:
17 |
18 | - [Docker](https://www.docker.com/) installed on your machine.
19 | - A Dria contract ID
20 |
21 | A Dria contract is a knowledge base deployed on Arweave; the contract ID is shown for each knowledge deployed to [Dria](https://dria.co/). For example, consider the Dria knowledge of [The Rust Programming Language](https://dria.co/knowledge/7EZMw0vAAFaKVMNOmu2rFgFCFjRD2C2F0kI_N5Cv6QQ):
22 |
23 |
24 |
25 | The base64 identifier at the end of that URL is the contract ID, and it can also be seen at the top of the page at that link.
26 |
27 | ### Using Dria CLI
28 |
29 | The preferred method of using Dria Docker is via the [Dria CLI](https://github.com/firstbatchxyz/dria-cli/), which is an NPM package.
30 |
31 | ```sh
32 | npm i -g dria-cli
33 | ```
34 |
35 | You can see available commands with:
36 |
37 | ```sh
38 | dria help
39 | ```
40 |
41 | See the [docs](https://github.com/firstbatchxyz/dria-cli/?tab=readme-ov-file#usage) of Dria CLI for more.
42 |
43 | ### Using Compose
44 |
45 | Download the Docker compose file:
46 |
47 | ```sh
48 | curl -o compose.yaml -L https://raw.githubusercontent.com/firstbatchxyz/dria-docker/master/compose.yaml
49 | ```
50 |
51 | You can start a Dria container with the following command, where the contract ID is provided as an environment variable.
52 |
53 | ```sh
54 | CONTRACT=contract-id docker compose up
55 | ```
56 |
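For instance, with the knowledge contract used by this repository's Makefile and tests (any other contract ID works the same way):

```sh
CONTRACT=WbcY2a-KfDpk7fsgumUtLC2bu4NQcVzNlXWi13fPMlU docker compose up
```
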
57 | ## Usage
58 |
59 | When everything is up, you will have access to both Dria and HollowDB on your local network!
60 |
61 | - Dria HNSW will be live at `localhost:8080`, see endpoints [here](./dria_hnsw/README.md#endpoints).
62 | - HollowDB API will be live at `localhost:3030`, see endpoints [here](./hollowdb/README.md#endpoints).
63 |
64 | These host ports can also be changed within the [compose file](./compose.yaml), if you have them reserved for other applications.
65 |
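To quickly check that the services are reachable, you can call the Dria HNSW healthcheck endpoint (assuming the default port mapping above):

```sh
curl http://localhost:8080/health
```
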
66 | > [!TIP]
67 | >
68 | > You can also connect to a terminal on the Redis container and use `redis-cli` if you would like to examine the keys.
69 |
70 | ## License
71 |
72 | Dria Docker is licensed under [Apache 2.0](./LICENSE).
73 |
--------------------------------------------------------------------------------
/compose.yaml:
--------------------------------------------------------------------------------
1 | version: "3.8"
2 |
3 | services:
4 | ### Dria HNSW Rust code ###
5 | dria-hnsw:
6 | #build: ./dria_hnsw
7 | image: "firstbatch/dria-hnsw"
8 | environment:
9 | - PORT=8080
10 | - ROCKSDB_PATH=/data/${CONTRACT}
11 | - REDIS_URL=redis://default:redispw@redis:6379
12 | - CONTRACT_ID=${CONTRACT}
13 | volumes:
14 | - ${HOME}/.dria/data:/data
15 | ports:
16 | - "8080:8080"
17 | depends_on:
18 | hollowdb-wait-for:
19 | condition: service_completed_successfully
20 |
21 | ### HollowDB API 'wait-for' script ###
22 | hollowdb-wait-for:
23 | # build: ./hollowdb_wait
24 | image: "firstbatch/dria-hollowdb-wait-for"
25 | environment:
26 | - TARGET=hollowdb:3000
27 | depends_on:
28 | - hollowdb
29 |
30 | ### HollowDB API ###
31 | hollowdb:
32 | #build: ./hollowdb
33 | image: "firstbatch/dria-hollowdb"
34 | ports:
35 | - "3000:3000"
36 | expose:
37 | - "3000" # used by HollowDB wait-for script
38 | volumes:
39 | - ${HOME}/.dria/data:/app/data
40 | environment:
41 | - PORT=3000
42 | - CONTRACT_ID=${CONTRACT}
43 | - ROCKSDB_PATH=/app/data/${CONTRACT}
44 | - REDIS_URL=redis://default:redispw@redis:6379
45 | - USE_BUNDLR=true # true if your contract uses Bundlr
46 | - USE_HTX=true # true if your contract stores values as `hash.txid`
47 | - BUNDLR_FBS=80 # batch size for downloading bundled values from Arweave
48 | depends_on:
49 | - redis
50 |
51 | ### Redis Container ###
52 | redis:
53 | image: "redis:alpine"
54 | expose:
55 | - "6379"
56 | # prettier-ignore
57 | command: [
58 | 'redis-server',
59 | '--port', '6379',
60 | '--maxmemory', '100mb',
61 | '--maxmemory-policy', 'allkeys-lru',
62 | '--appendonly', 'no',
63 | '--dbfilename', '${CONTRACT}.rdb',
64 | '--dir', '/tmp'
65 | ]
66 |
--------------------------------------------------------------------------------
/dria_hnsw/.dockerignore:
--------------------------------------------------------------------------------
1 | target
2 | Dockerfile
3 | manifests
4 | .dockerignore
5 | .git
6 | .gitignore
7 | .github
8 | .DS_Store
9 | README.md
10 |
--------------------------------------------------------------------------------
/dria_hnsw/.gitignore:
--------------------------------------------------------------------------------
1 | target
2 |
--------------------------------------------------------------------------------
/dria_hnsw/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "dria_hnsw"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7 |
8 | [dependencies]
9 | redis = { version = "0.23.0", features = ["cluster"] }
10 | prost = "0.10.0"
11 | prost-build = "0.10.0"
12 | prost-types = "0.10.0"
13 | prost-derive = "0.10.0"
14 | base64 = "0.13.0"
15 | rand = "0.8.5"
16 | serde = { version = "1.0", features = ["derive"] }
17 | serde_json = "1.0"
18 | chrono = "0.4.19"
19 | hashbrown = "0.14.0"
20 | url = "2.4.0"
21 | actix-web = "4.3.1"
22 | reqwest = "0.11.18"
23 | actix-cors = "0.6.4"
24 | tokio = { version = "1", features = ["full"] }
25 | futures-util = "0.3.28"
26 | proc-macro-error = "1.0.4"
27 | derive_more = "0.99.2"
28 | tdigest = "0.2.3"
29 | rayon = "1.7.0"
30 | ahash = "0.8.6"
31 | rocksdb = "0.21.0"
32 | parking_lot = "0.12.1"
33 | crossbeam-channel = "0.5.8"
34 | mini-moka = "0.10.1"
35 | dashmap = "5.5.0"
36 | log = "0.4"
37 | simple_logger = "4.2.0"
38 | simsimd = "3.8.0"
39 | probly-search = "2.0.0"
40 |
41 | [dev-dependencies]
42 | simple-home-dir = "0.3.2"
43 |
--------------------------------------------------------------------------------
/dria_hnsw/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM --platform=$BUILDPLATFORM rust:1.71 as builder
2 |
3 | # https://docs.docker.com/engine/reference/builder/#automatic-platform-args-in-the-global-scope
4 | #
5 | # official rust image supports the following archs:
6 | # amd64 (AMD & Intel 64-bit)
7 | # arm32/v7 (ARMv7 32-bit)
8 | # arm64/v8 (ARMv8 64-bit)
9 | # i386 (Intel 32-bit 8086)
10 | #
11 | # our builds will be for platforms:
12 | # linux/amd64
13 | # linux/arm64/v8
14 | # linux/arm32/v7
15 | # linux/i386
16 | #
17 | # however, for a small image size we use distroless, which allows
18 | # linux/amd64
19 | # linux/arm64
20 | # linux/arm
21 | #
22 | # To build this image for the above platforms & push it to Docker Hub:
23 | #
24 | # docker buildx build --platform=linux/amd64,linux/arm64,linux/arm . -t firstbatch/dria-hnsw:latest --builder=dria-builder --push
25 |
26 | ARG BUILDPLATFORM
27 | ARG TARGETPLATFORM
28 | RUN echo "Build platform: $BUILDPLATFORM"
29 | RUN echo "Target platform: $TARGETPLATFORM"
30 |
31 | # install Cmake
32 | RUN apt-get update
33 | RUN apt-get install -y cmake
34 |
35 | # libclang needed by rocksdb
36 | RUN apt-get install -y clang
37 |
38 | # use nightly Rust
39 | RUN rustup install nightly-2023-07-25
40 | RUN rustup default nightly-2023-07-25
41 |
42 | # build release binary
43 | WORKDIR /usr/src/app
44 | COPY . .
45 | RUN cargo build --release
46 |
47 | # copy release binary to distroless
48 | FROM --platform=$BUILDPLATFORM gcr.io/distroless/cc
49 | COPY --from=builder /usr/src/app/target/release/dria_hnsw /
50 |
51 | EXPOSE 8080
52 |
53 | CMD ["./dria_hnsw"]
54 |
--------------------------------------------------------------------------------
/dria_hnsw/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: docs
2 | docs:
3 | cargo doc --open --no-deps
4 |
5 | .PHONY: test
6 | test:
7 | cargo test
8 |
9 | .PHONY: lint
10 | lint:
11 | cargo clippy
12 |
13 |
--------------------------------------------------------------------------------
/dria_hnsw/README.md:
--------------------------------------------------------------------------------
1 | # Dria HNSW
2 |
3 | Dria HNSW is an API that allows you to permissionlessly search knowledge uploaded to Dria. It works over values downloaded from Arweave to a Redis cache, and reads these values directly from Redis.
4 | It is written in Rust, and several of its functions use architecture-specific optimizations (such as SIMD) for efficiency.
5 |
6 | ## Setup
7 |
8 | To run the server, you need to provide a contract ID along with a RocksDB path:
9 |
10 | ```sh
11 | CONTRACT_ID=contract-id ROCKSDB_PATH="/path/to/rocksdb" cargo run
12 | ```
13 |
14 | Dria HNSW is available as a container:
15 |
16 | ```sh
17 | docker pull firstbatch/dria-hnsw
18 | ```
19 |
20 | > [!TIP]
21 | >
22 | > The Docker image is cross-compiled & built for multiple architectures, so when you pull the image, the build best suited to your architecture will be downloaded!
23 |
24 | To see the available endpoints, refer to [this section](#endpoints) below.
25 |
26 | ## Endpoints
27 |
28 | Dria HNSW is an [Actix](https://actix.rs/) server with the following endpoints:
29 |
30 | - [`health`](#health)
31 | - [`fetch`](#fetch)
32 | - [`query`](#query)
33 | - [`insert_vector`](#insert_vector)
34 |
35 | All endpoints return a response in the following format:
36 |
37 | - `success`: a boolean indicating the success of the request
38 | - `code`: status code
39 | - `data`: response data
40 |
41 | > [!TIP]
42 | >
43 | > If `success` is false, the error message will be written in `data` as a string.
44 |
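Put together, a successful response is a JSON envelope along these lines (a sketch of the expected shape, shown here for the healthcheck endpoint below and assuming the default port):

```sh
curl http://localhost:8080/health
# {"success":true,"code":200,"data":"hello world!"}
```
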
45 | ### `HEALTH`
46 |
47 |
48 | ```ts
49 | GET /health
50 | ```
51 |
52 | **A simple healthcheck to see if the server is up.**
53 |
54 | Response data:
55 |
56 | - A string `"hello world!"`.
57 |
58 | ### `FETCH`
59 |
60 |
61 | ```ts
62 | POST /fetch
63 | ```
64 |
65 | **Given a list of IDs, fetches the metadata of the corresponding vectors.**
66 |
67 | Request body:
68 |
69 | - `id`: an array of integers
70 |
71 | Response data:
72 |
73 | - An array of metadata objects, where index `i` corresponds to the metadata of the vector with ID `id[i]`.
74 |
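A minimal `fetch` request could look like this (the IDs are arbitrary examples):

```sh
curl -X POST http://localhost:8080/fetch \
  -H "Content-Type: application/json" \
  -d '{"id": [0, 1, 2]}'
```
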
75 | ### `QUERY`
76 |
77 |
78 | ```ts
79 | POST /query
80 | ```
81 |
82 | **Given an embedding vector, queries the index and returns the most relevant vectors.**
83 |
84 | Request body:
85 |
86 | - `vector`: an array of floats corresponding to the embedding vector
87 | - `top_n`: number of results to return
88 | - `query`: (_optional_) the text that the given embedding was computed from; providing it yields better results by searching for this text within the results
89 | - `level`: (_optional_) an integer in the range [0, 4] that defines the intensity of the search; larger values take more time to complete but yield higher recall
90 |
91 | Response data:
92 |
93 | - An array of objects with the following keys:
94 | - `id`: id of the returned vector
95 | - `score`: relevance score
96 | - `metadata`: metadata of the vector
97 |
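A `query` request could look like the following sketch (the vector is truncated for brevity; its dimension must match the embeddings stored in the knowledge):

```sh
curl -X POST http://localhost:8080/query \
  -H "Content-Type: application/json" \
  -d '{"vector": [0.12, -0.45, 0.33], "top_n": 5, "level": 2}'
```
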
98 | ### `INSERT_VECTOR`
99 |
100 |
101 | ```ts
102 | POST /insert_vector
103 | ```
104 |
105 | **Inserts a new vector into the HNSW index.**
106 |
107 | Request body:
108 |
109 | - `vector`: an array of floats corresponding to the embedding vector
110 | - `metadata`: (_optional_) a JSON object that represents the metadata for this vector
111 |
112 | Response data:
113 |
114 | - A string `"Success"`.
115 |
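A minimal `insert_vector` request could look like this (again with a truncated example vector and arbitrary metadata):

```sh
curl -X POST http://localhost:8080/insert_vector \
  -H "Content-Type: application/json" \
  -d '{"vector": [0.12, -0.45, 0.33], "metadata": {"text": "hello world"}}'
```
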
116 | ## Testing
117 |
118 | We have several tests that you can run with:
119 |
120 | ```sh
121 | cargo test
122 | ```
123 |
124 | Some tests expect a RocksDB folder present at `$HOME/.dria/data/WbcY2a-KfDpk7fsgumUtLC2bu4NQcVzNlXWi13fPMlU`, which can easily be downloaded with the [Dria CLI](https://github.com/firstbatchxyz/dria-cli/) if you do not have it:
125 |
126 | ```sh
127 | dria pull WbcY2a-KfDpk7fsgumUtLC2bu4NQcVzNlXWi13fPMlU
128 | ```
129 |
130 | This knowledge base is rather lightweight, which makes it useful for testing.
131 |
--------------------------------------------------------------------------------
/dria_hnsw/src/db/conversions.rs:
--------------------------------------------------------------------------------
1 | use crate::proto::index_buffer::{LayerNode, Point};
2 | use crate::proto::insert_buffer::{BatchStr, BatchVec, SingletonStr, SingletonVec};
3 | use prost::Message;
4 |
5 | pub fn point_to_base64(point: &Point) -> String {
6 | let mut bytes = Vec::new();
7 | point.encode(&mut bytes).expect("Failed to encode message");
8 | let enc = base64::encode(&bytes); // Convert bytes to string if needed
9 | enc
10 | }
11 |
12 | pub fn base64_to_point(e_point: &str) -> Point {
13 | let bytes = base64::decode(e_point).unwrap();
14 | let point = Point::decode(bytes.as_slice()).unwrap(); // Deserialize
15 | point
16 | }
17 |
18 | pub fn node_to_base64(node: &LayerNode) -> String {
19 | let mut bytes = Vec::new();
20 | node.encode(&mut bytes).expect("Failed to encode message");
21 | let enc = base64::encode(&bytes); // Convert bytes to string if needed
22 | enc
23 | }
24 |
25 | pub fn base64_to_node(e_node: &str) -> LayerNode {
26 | let bytes = base64::decode(e_node).unwrap();
27 | let node = LayerNode::decode(bytes.as_slice()).unwrap(); // Deserialize
28 | node
29 | }
30 |
31 | //*************** Batch to Singleton ***************//
32 |
33 | pub fn base64_to_batch_vec(batch: &str) -> BatchVec {
34 | let bytes = base64::decode(batch).unwrap();
35 | let node = BatchVec::decode(bytes.as_slice()).unwrap(); // Deserialize
36 | node
37 | }
38 |
39 | pub fn base64_to_singleton_vec(singleton: &str) -> SingletonVec {
40 | let bytes = base64::decode(singleton).unwrap();
41 | let node = SingletonVec::decode(bytes.as_slice()).unwrap(); // Deserialize
42 | node
43 | }
44 |
45 | pub fn base64_to_batch_str(batch: &str) -> BatchStr {
46 | let bytes = base64::decode(batch).unwrap();
47 | let node = BatchStr::decode(bytes.as_slice()).unwrap(); // Deserialize
48 | node
49 | }
50 |
51 | pub fn base64_to_singleton_str(singleton: &str) -> SingletonStr {
52 | let bytes = base64::decode(singleton).unwrap();
53 | let node = SingletonStr::decode(bytes.as_slice()).unwrap(); // Deserialize
54 | node
55 | }
56 |
57 | #[cfg(test)]
58 | mod tests {
59 | use super::*;
60 |
61 | #[test]
62 | fn test_point_to_base64() {
63 | let point = Point {
64 | idx: 1,
65 | v: vec![1.0, 2.0, 3.0],
66 | };
67 | let enc = point_to_base64(&point);
68 | let dec = base64_to_point(&enc);
69 | assert_eq!(point, dec);
70 | }
71 |
72 | #[test]
73 | fn test_node_to_base64() {
74 | let node = LayerNode {
75 | level: 1,
76 | idx: 1,
77 | visible: true,
78 | neighbors: std::collections::HashMap::new(),
79 | };
80 | let enc = node_to_base64(&node);
81 | let dec = base64_to_node(&enc);
82 | assert_eq!(node, dec);
83 | }
84 |
85 | #[test]
86 | fn test_node_to_base64_from_string() {
87 | let node = LayerNode {
88 | level: 1,
89 | idx: 1,
90 | visible: true,
91 | neighbors: std::collections::HashMap::new(),
92 | };
93 | let enc = "CAEQARgB".to_string();
94 | let dec = base64_to_node(&enc);
95 | assert_eq!(node, dec);
96 | }
97 |
98 | #[test]
99 | fn test_batch_to_singleton() {
100 | let singleton = SingletonVec {
101 | v: vec![1.0, 2.0, 3.0],
102 | map: std::collections::HashMap::new(),
103 | };
104 | let batch = BatchVec {
105 | s: vec![singleton.clone()],
106 | };
107 | let enc = base64::encode(&batch.encode_to_vec());
108 | let dec = base64_to_batch_vec(&enc);
109 | assert_eq!(batch, dec);
110 |
111 | let singleton = SingletonStr {
112 | v: "test".to_string(),
113 | map: std::collections::HashMap::new(),
114 | };
115 | let batch = BatchStr {
116 | s: vec![singleton.clone()],
117 | };
118 | let enc = base64::encode(&batch.encode_to_vec());
119 | let dec = base64_to_batch_str(&enc);
120 | assert_eq!(batch, dec);
121 | }
122 | }
123 |
--------------------------------------------------------------------------------
/dria_hnsw/src/db/env.rs:
--------------------------------------------------------------------------------
1 | use serde::Deserialize;
2 | use std::env;
3 |
4 | #[derive(Debug, Deserialize)]
5 | pub struct Config {
6 | env: String,
7 | debug: bool,
8 | pk_length: u32,
9 | sk_length: u32,
10 | rate_limit: String,
11 | global_rate_limit: String,
12 | logging_level: String,
13 | pub contract_id: String,
14 | pub redis_url: String,
15 | pub port: String,
16 | pub rocksdb_path: String,
17 | }
18 |
19 | impl Config {
20 | pub fn new() -> Config {
21 | let rocksdb_path = match env::var("ROCKSDB_PATH") {
22 | Ok(val) => val,
23 | Err(_) => "/tmp/rocksdb".to_string(),
24 | };
25 |
26 | let port = match env::var("PORT") {
27 | Ok(val) => val,
28 | Err(_) => "8080".to_string(),
29 | };
30 |
31 | let contract_id = match env::var("CONTRACT_ID") {
32 | Ok(val) => val,
33 | Err(_) => {
34 | println!("CONTRACT_ID not found, using default");
35 | "default".to_string()
36 | }
37 | };
38 |
39 | Config {
40 | env: "development".to_string(),
41 | debug: true,
42 | pk_length: 0,
43 | sk_length: 0,
44 | rate_limit: "".to_string(),
45 | global_rate_limit: "".to_string(),
46 | logging_level: "DEBUG".to_string(),
47 | contract_id,
48 | redis_url: env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1/".to_string()),
49 | port,
50 | rocksdb_path,
51 | }
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/dria_hnsw/src/db/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod conversions;
2 | pub mod env;
3 | pub mod redis_client;
4 | pub mod rocksdb_client;
5 |
--------------------------------------------------------------------------------
/dria_hnsw/src/db/redis_client.rs:
--------------------------------------------------------------------------------
1 | extern crate redis;
2 | use redis::{Client, Commands, Connection};
3 |
4 | use crate::db::conversions::{base64_to_node, node_to_base64, point_to_base64};
5 | use crate::db::env::Config;
6 | use crate::errors::errors::DeserializeError;
7 | use crate::proto::index_buffer::{LayerNode, Point};
8 | use prost::Message;
9 | use serde_json::Value;
10 |
11 | pub struct RedisClient {
12 | client: Client,
13 | connection: Connection,
14 | tag: String,
15 | }
16 |
17 | impl RedisClient {
18 | pub fn new(contract_id: String) -> Result<RedisClient, DeserializeError> {
19 | let cfg = Config::new();
20 | let client =
21 | Client::open(cfg.redis_url).map_err(|_| DeserializeError::RedisConnectionError)?;
22 | let connection = client
23 | .get_connection()
24 | .map_err(|_| DeserializeError::RedisConnectionError)?;
25 |
26 | Ok(RedisClient {
27 | client,
28 | connection,
29 | tag: contract_id,
30 | })
31 | }
32 |
33 | pub fn set(&mut self, key: String, value: String) -> Result<(), DeserializeError> {
34 | let keys_local = format!("{}.value.{}", self.tag, key);
35 | let _: () = self
36 | .connection
37 | .set(&keys_local, &value)
38 | .map_err(|_| DeserializeError::RedisConnectionError)?;
39 | Ok(())
40 | }
41 |
42 | pub fn get_neighbor(
43 | &mut self,
44 | layer: usize,
45 | idx: usize,
46 | ) -> Result<LayerNode, DeserializeError> {
47 | let key = format!(
48 | "{}.value.{}:{}",
49 | self.tag,
50 | layer.to_string(),
51 | idx.to_string()
52 | );
53 |
54 | let node_str = match self
55 | .connection
56 | .get::<_, String>(key)
57 | .map_err(|_| DeserializeError::RedisConnectionError)
58 | {
59 | Ok(node_str) => node_str,
60 | Err(_) => {
61 | return Err(DeserializeError::RedisConnectionError);
62 | }
63 | };
64 |
65 | //let node_:Value = serde_json::from_str(&node_str).unwrap();
66 | Ok(base64_to_node(&node_str))
67 | }
68 |
69 | pub fn get_neighbors(
70 | &mut self,
71 | layer: usize,
72 | indices: Vec<usize>,
73 | ) -> Result<Vec<LayerNode>, DeserializeError> {
74 | let keys = indices
75 | .iter()
76 | .map(|x| format!("{}.value.{}:{}", self.tag, layer.to_string(), x.to_string()))
77 | .collect::<Vec<String>>();
78 |
79 | let values_str: Vec<String> = self
80 | .connection
81 | .mget(keys)
82 | .map_err(|_| DeserializeError::RedisConnectionError)?;
83 |
84 | let neighbors = values_str
85 | .iter()
86 | .map(|s| {
87 | let bytes = base64::decode(s).unwrap();
88 | let p = LayerNode::decode(bytes.as_slice()).unwrap(); // Deserialize
89 | Ok(p)
90 | })
91 | .collect::<Result<Vec<LayerNode>, DeserializeError>>()?;
92 |
93 | Ok(neighbors)
94 | }
95 |
96 | pub fn upsert_neighbor(&mut self, node: LayerNode) -> Result<(), DeserializeError> {
97 | let key = format!("{}.value.{}:{}", self.tag, node.level, node.idx);
98 |
99 | let node_str = node_to_base64(&node);
100 | self.set(key, node_str)?;
101 |
102 | Ok(())
103 | }
104 |
105 | pub fn upsert_neighbors(&mut self, nodes: Vec<LayerNode>) -> Result<(), DeserializeError> {
106 | let mut pairs = Vec::new();
107 | for node in nodes {
108 | let key = format!("{}.value.{}:{}", self.tag, node.level, node.idx);
109 | let node_str = node_to_base64(&node);
110 | pairs.push((key, node_str));
111 | }
112 |
113 | let _ = self
114 | .connection
115 | .mset(pairs.as_slice())
116 | .map_err(|_| DeserializeError::RedisConnectionError)?;
117 |
118 | Ok(())
119 | }
120 |
121 | pub fn get_points(&mut self, indices: &Vec<usize>) -> Result<Vec<Point>, DeserializeError> {
122 | let keys = indices
123 | .iter()
124 | .map(|x| format!("{}.value.{}", self.tag, x))
125 | .collect::<Vec<String>>();
126 |
127 | if keys.is_empty() {
128 | return Ok(vec![]);
129 | }
130 | let values_str: Vec<String> = self
131 | .connection
132 | .mget(keys)
133 | .map_err(|_| DeserializeError::RedisConnectionError)?;
134 |
135 | let points = values_str
136 | .into_iter()
137 | .map(|s| {
138 | let bytes = base64::decode(s).unwrap();
139 | let p = Point::decode(bytes.as_slice()).unwrap(); // Deserialize
140 | Ok(p)
141 | })
142 | .collect::<Result<Vec<Point>, DeserializeError>>()?;
143 | Ok(points)
144 | }
145 |
146 | pub fn add_points_batch(
147 | &mut self,
148 | v: &Vec<Vec<f32>>,
149 | start_idx: usize,
150 | ) -> Result<(), DeserializeError> {
151 | let mut pairs = Vec::new();
152 |
153 | for (i, p) in v.iter().enumerate() {
154 | let idx = start_idx + i;
155 | let p = Point::new(p.clone(), idx);
156 | let p_str = point_to_base64(&p);
157 | pairs.push((format!("{}.value.{}", self.tag, idx), p_str));
158 | }
159 |
160 | let _ = self
161 | .connection
162 | .mset(pairs.as_slice())
163 | .map_err(|_| DeserializeError::RedisConnectionError)?;
164 | Ok(())
165 | }
166 |
167 | pub fn add_points(&mut self, v: Vec<f32>, idx: usize) -> Result<(), DeserializeError> {
168 | let p = Point::new(v, idx);
169 | let p_str = point_to_base64(&p);
170 |
171 | self.set(format!("{}.value.{}", self.tag, idx), p_str)?;
172 | Ok(())
173 | }
174 |
175 | pub fn set_datasize(&mut self, datasize: usize) -> Result<(), DeserializeError> {
176 | let _: () = self
177 | .connection
178 | .set(format!("{}.value.datasize", self.tag), datasize.to_string())
179 | .map_err(|_| DeserializeError::RedisConnectionError)?;
180 | Ok(())
181 | }
182 |
183 | pub fn get_datasize(&mut self) -> Result<usize, DeserializeError> {
184 | let datasize: String = self
185 | .connection
186 | .get(format!("{}.value.datasize", self.tag))
187 | .map_err(|_| DeserializeError::RedisConnectionError)?;
188 | let datasize = datasize.parse::<usize>().unwrap();
189 | Ok(datasize)
190 | }
191 |
192 | pub fn set_num_layers(&mut self, num_layers: usize) -> Result<(), DeserializeError> {
193 | let _: () = self
194 | .connection
195 | .set(
196 | format!("{}.value.num_layers", self.tag),
197 | num_layers.to_string(),
198 | )
199 | .map_err(|_| DeserializeError::RedisConnectionError)?;
200 | Ok(())
201 | }
202 |
203 | pub fn get_num_layers(&mut self) -> Result<usize, DeserializeError> {
204 | let num_layers: String = self
205 | .connection
206 | .get(format!("{}.value.num_layers", self.tag))
207 | .map_err(|_| DeserializeError::RedisConnectionError)?;
208 | let num_layers = num_layers.parse::<usize>().unwrap();
209 | Ok(num_layers)
210 | }
211 |
212 | pub fn set_ep(&mut self, ep: usize) -> Result<(), DeserializeError> {
213 | let _: () = self
214 | .connection
215 | .set(format!("{}.value.ep", self.tag), ep.to_string())
216 | .map_err(|_| DeserializeError::RedisConnectionError)?;
217 | Ok(())
218 | }
219 |
220 | pub fn get_ep(&mut self) -> Result<usize, DeserializeError> {
221 | let ep: String = self
222 | .connection
223 | .get(format!("{}.value.ep", self.tag))
224 | .map_err(|_| DeserializeError::RedisConnectionError)?;
225 | let ep = ep.parse::<usize>().unwrap();
226 | Ok(ep)
227 | }
228 |
229 | pub fn set_metadata_batch(
230 | &mut self,
231 | metadata: Vec<Value>,
232 | idx: usize,
233 | ) -> Result<(), DeserializeError> {
234 | let mut pairs = Vec::new();
235 | for (i, m) in metadata.iter().enumerate() {
236 | let key = format!("{}.value.m:{}", self.tag, idx + i);
237 | let metadata_str = serde_json::to_string(&m).unwrap();
238 | pairs.push((key, metadata_str));
239 | }
240 | let _ = self
241 | .connection
242 | .mset(pairs.as_slice())
243 | .map_err(|_| DeserializeError::RedisConnectionError)?;
244 | Ok(())
245 | }
246 |
247 | pub fn set_metadata(&mut self, metadata: Value, idx: usize) -> Result<(), DeserializeError> {
248 | let key = format!("{}.value.m:{}", self.tag, idx);
249 | let metadata_str = serde_json::to_string(&metadata).unwrap();
250 | self.set(key, metadata_str)?;
251 | Ok(())
252 | }
253 |
254 | pub fn get_metadata(&mut self, idx: usize) -> Result<Value, DeserializeError> {
255 | let key = format!("{}.value.m:{}", self.tag, idx);
256 | let metadata_str: String = self
257 | .connection
258 | .get(&key)
259 | .map_err(|_| DeserializeError::RedisConnectionError)?;
260 | let metadata: Value = serde_json::from_str(&metadata_str).unwrap();
261 | Ok(metadata)
262 | }
263 |
264 | pub fn get_metadatas(&mut self, indices: Vec<usize>) -> Result<Vec<Value>, DeserializeError> {
265 | let keys = indices
266 | .iter()
267 | .map(|x| format!("{}.value.m:{}", self.tag, x))
268 | .collect::<Vec<String>>();
269 |
270 | let metadata_str: Vec<String> = self
271 | .connection
272 | .mget(&keys)
273 | .map_err(|_| DeserializeError::RedisConnectionError)?;
274 |
275 | let metadata = metadata_str
276 | .into_iter()
277 | .map(|s| {
278 | let m: Value = serde_json::from_str(&s).unwrap();
279 | Ok(m)
280 | })
281 | .collect::<Result<Vec<Value>, DeserializeError>>()?;
282 |
283 | Ok(metadata)
284 | }
285 | }
286 |
--------------------------------------------------------------------------------
/dria_hnsw/src/db/rocksdb_client.rs:
--------------------------------------------------------------------------------
1 | use crate::db::conversions::{base64_to_node, base64_to_point, node_to_base64, point_to_base64};
2 | use crate::db::env::Config;
3 | use crate::errors::errors::DeserializeError;
4 | use crate::proto::index_buffer::{LayerNode, Point};
5 | use prost::Message;
6 | use rocksdb;
7 | use rocksdb::{DBWithThreadMode, Options, WriteBatch, DB};
8 | use serde_json::Value;
9 |
10 | #[derive(Debug)]
11 | pub struct RocksdbClient {
12 | tag: String,
13 | client: DB,
14 | }
15 |
16 | impl RocksdbClient {
17 | pub fn new(contract_id: String) -> Result<RocksdbClient, DeserializeError> {
18 | let cfg = Config::new();
19 | // Create a new database options instance.
20 | let mut opts = Options::default();
21 | opts.create_if_missing(true); // Creates a database if it does not exist.
22 | //let x = DBWithThreadMode::open(&opts, cfg.rocksdb_path);
23 | let db = DB::open(&opts, cfg.rocksdb_path).unwrap();
24 |
25 | Ok(RocksdbClient {
26 | tag: contract_id,
27 | client: db,
28 | })
29 | }
30 |
31 | pub fn set(&self, key: String, value: String) -> Result<(), DeserializeError> {
32 | let _: () = self
33 | .client
34 | .put(key.as_bytes(), value.as_bytes())
35 | .map_err(|_| DeserializeError::RedisConnectionError)?;
36 | Ok(())
37 | }
38 |
39 | pub fn get_neighbor(&self, layer: usize, idx: usize) -> Result<LayerNode, DeserializeError> {
40 | let key = format!("{}.value.{}:{}", self.tag, layer, idx);
41 |
42 | let value = self
43 | .client
44 | .get(key.as_bytes())
45 | .map_err(|_| DeserializeError::RocksDBConnectionError)?; // Handle RocksDB errors appropriately
46 |
47 | let node_str = match value {
48 | Some(value) => String::from_utf8(value).map_err(|_| DeserializeError::InvalidForm)?, // Convert bytes to String and handle UTF-8 error
49 | None => return Err(DeserializeError::MissingKey), // Handle case where key is not found
50 | };
51 |
52 | Ok(base64_to_node(&node_str))
53 | }
54 |
55 | pub fn get_neighbors(
56 | &self,
57 | layer: usize,
58 | indices: Vec<usize>,
59 | ) -> Result<Vec<LayerNode>, DeserializeError> {
60 | // Collect keys as a Vec<Vec<u8>> for multi_get
61 | let keys = indices
62 | .iter()
63 | .map(|&x| format!("{}.value.{}:{}", self.tag, layer, x).into_bytes())
64 | .collect::<Vec<Vec<u8>>>();
65 |
66 | // Use multi_get to fetch values for all keys at once
67 | let values = self.client.multi_get(keys);
68 |
69 | let mut neighbors = Vec::new();
70 | for value_result in values {
71 | // Correctly handle the Result<Option<Vec<u8>>, E> for each value
72 | match value_result {
73 | Ok(Some(v)) => {
74 | let node_str =
75 | String::from_utf8(v).map_err(|_| DeserializeError::InvalidForm)?; // Convert bytes to String and handle UTF-8 error
76 | let node = base64_to_node(&node_str); // Convert String to LayerNode and handle base64 error
77 | neighbors.push(node);
78 | }
79 | Ok(None) => return Err(DeserializeError::MissingKey), // Handle case where key is not found
80 | Err(_) => return Err(DeserializeError::RocksDBConnectionError), // Handle error in fetching value
81 | }
82 | }
83 |
84 | Ok(neighbors)
85 | }
86 |
87 | pub fn upsert_neighbor(&self, node: LayerNode) -> Result<(), DeserializeError> {
88 | let key = format!("{}.value.{}:{}", self.tag, node.level, node.idx);
89 |
90 | let node_str = node_to_base64(&node);
91 | self.set(key, node_str)?;
92 |
93 | Ok(())
94 | }
95 |
96 | pub fn upsert_neighbors(&self, nodes: Vec<LayerNode>) -> Result<(), DeserializeError> {
97 | let mut batch = WriteBatch::default();
98 | for node in nodes {
99 | let key = format!("{}.value.{}:{}", self.tag, node.level, node.idx);
100 | let node_str = node_to_base64(&node);
101 | batch.put(key.as_bytes(), node_str.as_bytes());
102 | }
103 |
104 | let _ = self
105 | .client
106 | .write(batch)
107 | .map_err(|_| DeserializeError::RocksDBConnectionError)?;
108 |
109 | Ok(())
110 | }
111 |
112 | pub fn get_points(&self, indices: &Vec<usize>) -> Result<Vec<Point>, DeserializeError> {
113 | let keys = indices
114 | .iter()
115 | .map(|x| format!("{}.value.{}", self.tag, x).into_bytes())
116 | .collect::<Vec<Vec<u8>>>();
117 |
118 | if keys.is_empty() {
119 | return Ok(vec![]);
120 | }
121 |
122 | // Assuming multi_get directly returns Vec<Result<Option<Vec<u8>>, E>>
123 | let values = self.client.multi_get(keys);
124 |
125 | let mut points = Vec::new();
126 | for value_result in values {
127 | match value_result {
128 | Ok(Some(value)) => {
129 | let point_str =
130 | String::from_utf8(value).map_err(|_| DeserializeError::InvalidForm)?; // Handle UTF-8 conversion error
131 | let point = base64_to_point(&point_str); // Handle potential error from base64_to_point
132 | points.push(point);
133 | }
134 | Ok(None) => return Err(DeserializeError::MissingKey), // Key not found
135 | Err(_) => return Err(DeserializeError::RocksDBConnectionError), // Error fetching from RocksDB
136 | }
137 | }
138 |
139 | Ok(points)
140 | }
141 |
142 | pub fn add_points(&self, v: Vec<f32>, idx: usize) -> Result<(), DeserializeError> {
143 | let p = Point::new(v, idx);
144 | let p_str = point_to_base64(&p);
145 | let key = format!("{}.value.{}", self.tag, idx).into_bytes();
146 | self.client
147 | .put(key, p_str.as_bytes())
148 | .map_err(|_| DeserializeError::RocksDBConnectionError)?;
149 | //self.put_multi_hashtag(&[idx.to_string()], &[json!(p_str)], false)?;
150 | Ok(())
151 | }
152 |
153 | pub fn add_points_batch(
154 | &self,
155 | v: &Vec<Vec<f32>>,
156 | start_idx: usize,
157 | ) -> Result<(), DeserializeError> {
158 | let mut batch = WriteBatch::default();
159 | for (i, p) in v.iter().enumerate() {
160 | let idx = start_idx + i;
161 | let p = Point::new(p.clone(), idx);
162 | let p_str = point_to_base64(&p);
163 | let key = format!("{}.value.{}", self.tag, idx).into_bytes();
164 |
165 | //keys.push(idx.to_string());
166 | //values.push(json!(p_str));
167 |
168 | batch.put(key, p_str.as_bytes());
169 | }
170 |
171 | self.client
172 | .write(batch)
173 | .map_err(|_| DeserializeError::RocksDBConnectionError)?;
174 |
175 | Ok(())
176 | }
177 |
178 | pub fn set_datasize(&self, datasize: usize) -> Result<(), DeserializeError> {
179 | //self.put_multi_hashtag(&["datasize".to_string()], &[json!(datasize)], false)?;
180 | self.client
181 | .put(
182 | format!("{}.value.datasize", self.tag).into_bytes(),
183 | datasize.to_string().as_bytes(),
184 | )
185 | .map_err(|_| DeserializeError::RocksDBConnectionError)?;
186 | Ok(())
187 | }
188 |
189 | pub fn get_datasize(&self) -> Result<usize, DeserializeError> {
190 | let datasize_key: String = format!("{}.value.datasize", self.tag);
191 | let value = self
192 | .client
193 | .get(datasize_key.as_bytes())
194 | .map_err(|_| DeserializeError::RocksDBConnectionError)?;
195 |
196 | let datasize = match value {
197 | Some(value_bytes) => {
198 | let value_str =
199 | String::from_utf8(value_bytes).map_err(|_| DeserializeError::InvalidForm)?; // Handle UTF-8 error gracefully
200 | value_str
201 | .parse::<usize>()
202 | .map_err(|_| DeserializeError::InvalidForm)? // Handle parse error gracefully
203 | }
204 | None => return Err(DeserializeError::MissingKey), // Handle case where key is not found
205 | };
206 | Ok(datasize)
207 | }
208 |
209 | pub fn get_num_layers(&self) -> Result<usize, DeserializeError> {
210 | let num_layers_key: String = format!("{}.value.num_layers", self.tag);
211 | let value = self
212 | .client
213 | .get(num_layers_key.as_bytes())
214 | .map_err(|_| DeserializeError::RocksDBConnectionError)?;
215 |
216 | let num_layers = match value {
217 | Some(value_bytes) => {
218 | let value_str =
219 | String::from_utf8(value_bytes).map_err(|_| DeserializeError::InvalidForm)?; // Handle UTF-8 error gracefully
220 | value_str
221 | .parse::<usize>()
222 | .map_err(|_| DeserializeError::InvalidForm)? // Handle parse error gracefully
223 | }
224 | None => return Err(DeserializeError::MissingKey), // Handle case where key is not found
225 | };
226 | Ok(num_layers)
227 | }
228 |
229 | pub fn set_num_layers(&self, num_layers: usize, expire: bool) -> Result<(), DeserializeError> {
230 | self.client
231 | .put(
232 | format!("{}.value.num_layers", self.tag).into_bytes(),
233 | num_layers.to_string().as_bytes(),
234 | )
235 | .map_err(|_| DeserializeError::RocksDBConnectionError)?;
236 | Ok(())
237 | }
238 |
239 | pub fn set_ep(&self, ep: usize, expire: bool) -> Result<(), DeserializeError> {
240 | self.client
241 | .put(
242 | format!("{}.value.ep", self.tag).into_bytes(),
243 | ep.to_string().as_bytes(),
244 | )
245 | .map_err(|_| DeserializeError::RocksDBConnectionError)?;
246 | Ok(())
247 | }
248 |
249 | pub fn get_ep(&self) -> Result<usize, DeserializeError> {
250 | let ep_key: String = format!("{}.value.ep", self.tag);
251 | let value = self
252 | .client
253 | .get(ep_key.as_bytes())
254 | .map_err(|_| DeserializeError::RocksDBConnectionError)?;
255 |
256 | // Attempt to convert the fetched value from bytes to String, then parse it into usize
257 | let ep_usize = match value {
258 | Some(value_bytes) => {
259 | // Convert bytes to String
260 | let value_str =
261 | String::from_utf8(value_bytes).map_err(|_| DeserializeError::InvalidForm)?; // Handle UTF-8 error gracefully
262 | // Parse String to usize
263 | value_str
264 | .parse::<usize>()
265 | .map_err(|_| DeserializeError::InvalidForm)? // Handle parse error gracefully
266 | }
267 | None => return Err(DeserializeError::MissingKey), // Handle case where key is not found
268 | };
269 | Ok(ep_usize)
270 | }
271 |
272 | pub fn set_metadata(&self, metadata: Value, idx: usize) -> Result<(), DeserializeError> {
273 | let key = format!("{}.value.m:{}", self.tag, idx);
274 | let metadata_str = serde_json::to_vec(&metadata).unwrap();
275 | self.client
276 | .put(key.as_bytes(), metadata_str)
277 | .map_err(|_| DeserializeError::RocksDBConnectionError)?;
278 | Ok(())
279 | }
280 |
281 | pub fn set_metadata_batch(
282 | &self,
283 | metadata: Vec<Value>,
284 | idx: usize,
285 | ) -> Result<(), DeserializeError> {
286 | let mut batch = WriteBatch::default();
287 |
288 | for (i, m) in metadata.iter().enumerate() {
289 | let key = format!("{}.value.m:{}", self.tag, idx + i);
290 | let metadata_str = serde_json::to_vec(&m).unwrap();
291 | batch.put(key.as_bytes(), metadata_str);
292 | }
293 | self.client
294 | .write(batch)
295 | .map_err(|_| DeserializeError::RocksDBConnectionError)?;
296 | Ok(())
297 | }
298 |
299 | pub fn get_metadata(&self, idx: usize) -> Result<Value, DeserializeError> {
300 | let key = format!("{}.value.m:{}", self.tag, idx);
301 |
302 | let value = self
303 | .client
304 | .get(key.as_bytes())
305 | .map_err(|_| DeserializeError::RocksDBConnectionError)?;
306 |
307 | let metadata = match value {
308 | Some(value) => serde_json::from_slice(&value).unwrap()
309 | , // Convert bytes to String and handle UTF-8 error
310 | None => return Err(DeserializeError::MissingKey), // Handle case where key is not found
311 | };
312 | Ok(metadata)
313 | }
314 |
315 | pub fn get_metadatas(&self, indices: Vec<usize>) -> Result<Vec<Value>, DeserializeError> {
316 | let keys = indices
317 | .iter()
318 | .map(|x| format!("{}.value.m:{}", self.tag, x).as_bytes().to_vec())
319 | .collect::<Vec<Vec<u8>>>();
320 |
321 | // Assuming multi_get returns Vec<Result<Option<Vec<u8>>, E>> directly
322 | let values = self.client.multi_get(&keys);
323 |
324 | let mut metadata = Vec::new();
325 | for value_result in values {
326 | match value_result {
327 | Ok(Some(v)) => {
328 | // Properly handle potential serde_json deserialization errors
329 | match serde_json::from_slice::<Value>(&v) {
330 | Ok(meta) => metadata.push(meta),
331 | Err(_) => return Err(DeserializeError::InvalidForm), // Add a DeserializationError variant if not already present
332 | }
333 | }
334 | Ok(None) => return Err(DeserializeError::MissingKey), // Key not found
335 | Err(_) => return Err(DeserializeError::RocksDBConnectionError), // Error fetching from RocksDB
336 | }
337 | }
338 |
339 | Ok(metadata)
340 | }
341 | }
342 |
--------------------------------------------------------------------------------
/dria_hnsw/src/errors/errors.rs:
--------------------------------------------------------------------------------
1 | use actix_web::{
2 | error::ResponseError,
3 | http::{header::ContentType, StatusCode},
4 | HttpResponse,
5 | };
6 | use derive_more::{Display, Error};
7 | use std::fmt;
8 |
9 | #[derive(Debug)]
10 | pub enum DeserializeError {
11 | MissingKey,
12 | InvalidForm,
13 | RocksDBConnectionError,
14 | RedisConnectionError,
15 | DNSResolverError,
16 | ClusterConnectionError,
17 | }
18 |
19 | impl fmt::Display for DeserializeError {
20 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
21 | match *self {
22 | DeserializeError::MissingKey => write!(f, "Key is missing in the response"),
23 | DeserializeError::InvalidForm => write!(f, "Value is not in the expected format"),
24 | DeserializeError::RocksDBConnectionError => write!(f, "Error connecting to RocksDB"),
25 | DeserializeError::RedisConnectionError => write!(f, "Error connecting to Redis"),
26 | DeserializeError::DNSResolverError => write!(f, "Error resolving DNS"),
27 | DeserializeError::ClusterConnectionError => {
28 | write!(f, "Error connecting to Cluster at init")
29 | }
30 | }
31 | }
32 | }
33 |
34 | impl std::error::Error for DeserializeError {}
35 |
36 | #[derive(Debug, Display, Error)]
37 | pub enum MiddlewareError {
38 | #[display(fmt = "internal error")]
39 | InternalError,
40 |
41 | #[display(fmt = "api key not found")]
42 | APIKeyError,
43 |
44 | #[display(fmt = "timeout")]
45 | Timeout,
46 | }
47 |
48 | impl ResponseError for MiddlewareError {
49 | fn error_response(&self) -> HttpResponse {
50 | HttpResponse::build(self.status_code())
51 | .insert_header(ContentType::html())
52 | .body(self.to_string())
53 | }
54 |
55 | fn status_code(&self) -> StatusCode {
56 | match *self {
57 | MiddlewareError::InternalError => StatusCode::INTERNAL_SERVER_ERROR,
58 | MiddlewareError::APIKeyError => StatusCode::UNAUTHORIZED,
59 | MiddlewareError::Timeout => StatusCode::GATEWAY_TIMEOUT,
60 | }
61 | }
62 | }
63 |
64 | #[derive(Debug)]
65 | pub struct ValidationError(pub String);
66 |
67 | impl fmt::Display for ValidationError {
68 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
69 | write!(f, "{}", self.0)
70 | }
71 | }
72 |
73 | impl std::error::Error for ValidationError {}
74 |
--------------------------------------------------------------------------------
/dria_hnsw/src/errors/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod errors;
2 |
--------------------------------------------------------------------------------
/dria_hnsw/src/filter/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod text_based;
--------------------------------------------------------------------------------
/dria_hnsw/src/filter/text_based.rs:
--------------------------------------------------------------------------------
1 | use probly_search::score::ScoreCalculator;
2 | use probly_search::score::{bm25, zero_to_one};
3 | use probly_search::Index;
4 | use serde_json::{json, Value};
5 | use std::borrow::Cow;
6 |
7 | pub struct Doc {
8 | pub id: usize,
9 | pub text: String,
10 | }
11 |
12 | fn tokenizer(s: &str) -> Vec<Cow<str>> {
13 | s.split(' ').map(Cow::from).collect::<Vec<Cow<str>>>()
14 | }
15 |
16 | fn text_extract(d: &Doc) -> Vec<&str> {
17 | vec![d.text.as_str()]
18 | }
19 |
20 | pub fn create_index_from_docs(
21 | index: &mut Index<usize>,
22 | query: &str,
23 | metadata: Vec<Value>,
24 | ) -> Vec<Value> {
25 | let mut wikis = Vec::new();
26 | let mut query_results = Vec::new();
27 | let mut ids = Vec::new();
28 | let mut iter = 0;
29 |
30 | for value in metadata.iter() {
31 | let mut text = value["metadata"]["text"].as_str();
32 | if text.is_none() {
33 | // use whole metadata as text
34 | text = value["metadata"].as_str();
35 | }
36 |
37 | let id_doc = value["id"].as_u64().unwrap() as usize;
38 | // let url = value["metadata"]["url"].as_str();
39 |
40 | let t = text.unwrap().to_string();
41 |
42 | let sentences = t.split(".");
43 | for sentence in sentences {
44 | let wiki = Doc {
45 | id: iter,
46 | text: sentence.to_string(),
47 | };
48 | wikis.push(wiki);
49 |
50 | let mut value_x = value.clone();
51 | value_x["metadata"]["text"] = json!(sentence.to_string());
52 | query_results.push(value_x);
53 | ids.push(id_doc);
54 | iter += 1;
55 | }
56 | }
57 | if wikis.len() == 0 {
58 | return metadata;
59 | }
60 |
61 | for wiki in wikis.iter() {
62 | index.add_document(&[text_extract], tokenizer, wiki.id.clone(), &wiki);
63 | }
64 |
65 | let results = index.query(query, &mut zero_to_one::new(), tokenizer, &[1.]);
66 | let mut results_as_wiki = vec![];
67 | for res in results.iter() {
68 | let val = json!({"id": ids[res.key], "metadata": query_results[res.key].clone(), "score": res.score});
69 | results_as_wiki.push(val);
70 | }
71 | return results_as_wiki;
72 | }
73 |
--------------------------------------------------------------------------------
/dria_hnsw/src/hnsw/index.rs:
--------------------------------------------------------------------------------
1 | #![allow(non_snake_case)]
2 |
3 | extern crate redis;
4 |
5 | use actix_web::web::Data;
6 | use hashbrown::HashSet;
7 | use mini_moka::sync::Cache;
8 | use redis::Commands;
9 | use std::borrow::Borrow;
10 | use std::cmp::Reverse;
11 | use std::collections::HashMap;
12 | use std::sync::atomic::{AtomicIsize, AtomicUsize, Ordering};
13 | use std::sync::Arc;
14 |
15 | use simsimd::SimSIMD;
16 |
17 | use rand::{thread_rng, Rng, SeedableRng};
18 |
19 | use crate::proto::index_buffer::{LayerNode, Point};
20 | use prost::Message;
21 |
22 | use crate::errors::errors::DeserializeError;
23 | use crate::hnsw::utils::{create_max_heap, create_min_heap, IntoHeap, IntoMap, Numeric};
24 |
25 | use crate::hnsw::scalar::ScalarQuantizer;
26 | use rayon::prelude::*;
27 | use serde_json::{json, Value};
28 |
29 | use crate::db::rocksdb_client::RocksdbClient;
30 | use crate::hnsw::sync_map::SynchronizedNodes;
31 |
32 | pub const SINGLE_THREADED_HNSW_BUILD_THRESHOLD: usize = 256;
33 |
34 | /*
35 | Redis Scheme
36 |
37 | Points: "0", "1" ...
38 | Graph: graph_level.idx : "2.5320" layer 2, node idx 5320
39 | */
40 | pub struct HNSW {
41 | pub m: usize,
42 | pub m_max0: usize,
43 | pub rng_seed: u64,
44 | pub ml: f32,
45 | pub ef_construction: usize,
46 | pub ef: usize,
47 | pub db: Data<RocksdbClient>,
48 | quantizer: ScalarQuantizer,
49 | metric: Option<String>,
50 | }
51 |
52 | impl HNSW {
53 | pub fn new(
54 | M: usize,
55 | ef_construction: usize,
56 | ef: usize,
57 | //contract_id: String,
58 | metric: Option<String>,
59 | db: Data<RocksdbClient>,
60 | ) -> HNSW {
61 | let m = M;
62 | let m_max0 = M * 2;
63 | let ml = 1.0 / (M as f32).ln();
64 | //let db = RocksdbClient::new(contract_id).expect("Error creating RocksdbClient");
65 | let sq = ScalarQuantizer::new(256, 1000, 256);
66 |
67 | HNSW {
68 | m,
69 | m_max0,
70 | rng_seed: 0,
71 | ml,
72 | ef_construction,
73 | ef,
74 | db,
75 | quantizer: sq,
76 | metric,
77 | }
78 | }
79 |
80 | pub fn set_rng_seed(&mut self, seed: u64) {
81 | self.rng_seed = seed;
82 | }
83 |
84 | pub fn set_ef(&mut self, ef: usize) {
85 | self.ef = ef;
86 | }
87 |
88 | pub fn select_layer(&self) -> usize {
89 | let mut random = thread_rng();
90 | let rand_float: f32 = random.gen_range(1e-6..1.0); // Avoid very small values
91 | let result = (-1.0 * rand_float.ln() * self.ml) as usize;
92 |
93 | // Optionally clamp to a maximum value if applicable
94 | let max_layer = 1000; // Example maximum layer
95 | std::cmp::min(result, max_layer)
96 | }
97 |
98 | fn distance(&self, x: &[f32], y: &[f32], dist: &Option<String>) -> f32 {
99 | let dist = match dist.as_ref().map(String::as_str) {
100 | Some("sqeuclidean") => SimSIMD::sqeuclidean(x, y),
101 | Some("inner") => SimSIMD::inner(x, y),
102 | Some("cosine") | None => SimSIMD::cosine(x, y),
103 | _ => panic!("Unsupported distance metric"),
104 | };
105 | if dist.is_none() {
106 | println!("Error in distance"); //make the error propagate
107 | }
108 | dist.unwrap()
109 | }
110 |
111 | fn get_points_w_memory(
112 | &self,
113 | indices: &Vec<u32>,
114 | point_map: Cache<String, Point>,
115 | ) -> Vec<Point> {
116 | // Initialize points with None to reserve the space and maintain order
117 | let mut points: Vec<Option<Point>> = vec![None; indices.len()];
118 |
119 | // Track missing indices and their positions
120 | let mut missing_indices_with_pos: Vec<(usize, u32)> = Vec::new();
121 |
122 | for (pos, idx) in indices.iter().enumerate() {
123 | let key = format!("p:{}", idx);
124 | if let Some(point) = point_map.get(&key) {
125 | points[pos] = Some(point.clone());
126 | } else {
127 | missing_indices_with_pos.push((pos, *idx));
128 | }
129 | }
130 |
131 | if !missing_indices_with_pos.is_empty() {
132 | let missing_indices: Vec<u32> = missing_indices_with_pos
133 | .iter()
134 | .map(|&(_, idx)| idx)
135 | .collect();
136 |
137 | let fetched_points = self.db.get_points(&missing_indices);
138 |
139 | if fetched_points.is_err() {
140 | println!(
141 | "Error getting points, get points _w _memory {:?}",
142 | &missing_indices
143 | );
144 | }
145 |
146 | let fetched_points = fetched_points.unwrap();
147 |
148 | for point in fetched_points {
149 | let key = format!("p:{}", point.idx);
150 | point_map.insert(key, point.clone());
151 |
152 | if let Some(&(pos, _)) = missing_indices_with_pos
153 | .iter()
154 | .find(|&&(_, idx)| idx == point.idx)
155 | {
156 | points[pos] = Some(point);
157 | }
158 | }
159 | }
160 | points.into_iter().filter_map(|p| p).collect()
161 | }
162 |
163 | fn get_neighbors_w_memory(
164 | &self,
165 | layer: usize,
166 | indices: &Vec<u32>,
167 | node_map: Arc<SynchronizedNodes>,
168 | ) -> Vec<LayerNode> {
169 | let mut nodes = Vec::with_capacity(indices.len());
170 | let mut missing_indices = Vec::new();
171 |
172 | // First pass: Fill in nodes from node_map or mark as missing
173 | for &idx in indices {
174 | let key = format!("{}:{}", layer, idx);
175 | match node_map.get_or_wait_opt(&key) {
176 | Some(node) => nodes.push(node.clone()),
177 | None => {
178 | missing_indices.push(idx);
179 | nodes.push(LayerNode::new(0, 0)); // Placeholder for missing nodes
180 | }
181 | }
182 | }
183 | if !missing_indices.is_empty() {
184 | let fetched_nodes = self
185 | .db
186 | .get_neighbors(layer, missing_indices)
187 | .expect("Error getting neighbors");
188 |
189 | for fetched_node in fetched_nodes.iter() {
190 | let index = indices.iter().position(|&i| i == fetched_node.idx).unwrap();
191 | nodes[index] = fetched_node.clone();
192 | }
193 | node_map.insert_batch_and_notify(fetched_nodes);
194 | }
195 |
196 | nodes
197 | }
198 |
199 | fn get_neighbor_w_memory(
200 | &self,
201 | layer: usize,
202 | idx: usize,
203 | node_map: Arc<SynchronizedNodes>,
204 | ) -> LayerNode {
205 | let key = format!("{}:{}", layer, idx);
206 |
207 | let node_option = node_map.get_or_wait_opt(&key);
208 | return if let Some(node) = node_option {
209 | node.clone()
210 | } else {
211 | let node_ = self.db.get_neighbor(layer, idx);
212 | if node_.is_err() {
213 | println!("Sync issue, awaiting notification...");
214 | let value = node_map.get_or_wait(&key);
215 | return value.clone();
216 | }
217 | let node_ = node_.unwrap();
218 | node_map.insert_and_notify(&node_);
219 | node_
220 | };
221 | }
222 |
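/// Inserts the point at `idx` into the HNSW graph using shared state preset by
/// the caller: `nl` holds the current number of layers and `epa` the entry-point
/// index (-1 when the index is empty); both are atomics shared by the worker
/// threads that build the index concurrently.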
223 | pub fn insert_w_preset(
224 | &self,
225 | idx: usize,
226 | node_map: Arc<SynchronizedNodes>,
227 | point_map: Cache<String, Point>,
228 | nl: Arc<AtomicUsize>,
229 | epa: Arc<AtomicIsize>,
230 | ) -> Result<(), DeserializeError> {
231 | let mut W = HashMap::new();
232 |
233 | let mut ep_index = None;
234 | let ep_index_ = epa.load(Ordering::SeqCst);
235 | if ep_index_ != -1 {
236 | ep_index = Some(ep_index_ as u32);
237 | }
238 |
239 | let mut num_layers = nl.load(Ordering::Relaxed);
240 |
241 | let L = if num_layers == 0 { 0 } else { num_layers - 1 };
242 | let l = self.select_layer();
243 |
244 | let qs = self.get_points_w_memory(&vec![idx as u32], point_map.clone());
245 | let q = qs[0].v.clone();
246 |
247 | if ep_index.is_some() {
248 | let ep_index_ = ep_index.unwrap();
249 |
250 | let points = self.get_points_w_memory(&vec![ep_index_], point_map.clone());
251 | let point = points.first().unwrap();
252 | let dist = self.distance(&q, &point.v, &self.metric);
253 | let mut ep = HashMap::from([(ep_index_, dist)]);
254 |
255 | for i in ((l + 1)..=L).rev() {
256 | W = self.search_layer(&q, ep.clone(), 1, i, node_map.clone(), point_map.clone())?;
257 |
258 | if let Some((_, value)) = W.iter().next() {
259 | if &dist < value {
260 | ep = W;
261 | }
262 | }
263 | }
264 |
265 | for l_c in (0..=std::cmp::min(L, l)).rev() {
266 | W = self.search_layer(
267 | &q,
268 | ep,
269 | self.ef_construction,
270 | l_c,
271 | node_map.clone(),
272 | point_map.clone(),
273 | )?;
274 |
275 | //upsert expire = true by default, populate upserted_keys for replication
276 | node_map.insert_and_notify(&LayerNode::new(l_c, idx));
277 |
278 | ep = W.clone();
279 | let neighbors = self.select_neighbors(&q, W, l_c, true);
280 |
281 | let M = if l_c == 0 { self.m_max0 } else { self.m };
282 |
283 | //read neighbors of all nodes in selected neighbors
284 | let mut indices = neighbors.iter().map(|x| *x.0).collect::<Vec<u32>>();
285 | indices.push(idx as u32);
286 | let idx_i = indices.len() - 1;
287 |
288 | let mut nodes = self.get_neighbors_w_memory(l_c, &indices, node_map.clone());
289 |
290 | for (i, (e_i, dist)) in neighbors.iter().enumerate() {
291 | if i == idx_i {
292 | // We want to skip last layernode, which is idx -> layernode
293 | continue;
294 | }
295 |
296 | nodes[i].neighbors.insert(idx as u32, *dist);
297 | nodes[idx_i].neighbors.insert(*e_i, *dist);
298 | }
299 |
300 | // TODO: remove redundant
301 | for (i, (e_i, dist)) in neighbors.iter().enumerate() {
302 | if i == idx_i {
303 | // We want to skip last layernode, which is idx -> layernode
304 | continue;
305 | }
306 | let eConn = nodes[i].neighbors.clone();
307 | if eConn.len() > M {
308 | let eNewConn = self.select_neighbors(&q, eConn, l_c, true);
309 | nodes[i].neighbors = eNewConn.clone();
310 | }
311 | }
312 |
313 | node_map.insert_batch_and_notify(nodes);
314 | }
315 | }
316 |
317 | for i in num_layers..=l {
318 | node_map.insert_and_notify(&LayerNode::new(i, idx));
319 | let _ = epa.fetch_update(Ordering::SeqCst, Ordering::Relaxed, |x| Some(idx as isize));
320 | }
321 |
322 | let _ = nl.fetch_update(Ordering::SeqCst, Ordering::Acquire, |v| {
323 | if l + 1 > v {
324 | Some(l + 1)
325 | } else {
326 | None
327 | }
328 | });
329 |
330 | Ok(())
331 | }
332 |
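/// Greedy best-first search within layer `l_c` (Algorithm 2 of the HNSW paper):
/// candidates are expanded from the min-heap C while the max-heap W keeps the
/// `ef` closest points found so far.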
333 | fn search_layer(
334 | &self,
335 | q: &Vec<f32>,
336 | ep: HashMap<u32, f32>,
337 | ef: usize,
338 | l_c: usize,
339 | node_map: Arc<SynchronizedNodes>,
340 | point_map: Cache<String, Point>,
341 | ) -> Result<HashMap<u32, f32>, DeserializeError> {
342 | let mut v = HashSet::new();
343 |
344 | for (k, _) in ep.iter() {
345 | v.insert(k.clone());
346 | }
347 |
348 | let mut C = ep.clone().into_minheap();
349 | let mut W = ep.into_maxheap();
350 |
351 | while !C.is_empty() {
352 | let c = C.pop().unwrap().0;
353 | let f_value = W.peek().unwrap().0 .0;
354 |
355 | if c.0 .0 > f_value {
356 | break;
357 | }
358 |
359 | let layernd = self.get_neighbor_w_memory(l_c, c.1 as usize, node_map.clone());
360 |
361 | let mut pairs: Vec<_> = layernd.neighbors.into_iter().collect();
362 | //pairs.sort_by(|&(_, a), &(_, b)| a.partial_cmp(&b).unwrap());
363 | pairs.sort_by(|&(_, a), &(_, b)| {
364 | if a.is_nan() || b.is_nan() {
365 | println!("NaN value detected: a = {}, b = {}", a, b);
366 | std::cmp::Ordering::Greater
367 | } else {
368 | a.partial_cmp(&b).unwrap_or_else(|| {
369 | println!("Unexpected comparison error: a = {}, b = {}", a, b);
370 | std::cmp::Ordering::Greater
371 | })
372 | }
373 | });
374 | let sorted_keys: Vec<u32> = pairs.into_iter().map(|(k, _)| k).collect();
375 |
376 | let neighbors: Vec<u32> = sorted_keys
377 | .into_iter()
378 | .filter_map(|x| if !v.contains(&x) { Some(x) } else { None })
379 | .collect();
380 |
381 | let points = self.get_points_w_memory(&neighbors, point_map.clone());
382 |
383 | let distances = points
384 | .iter()
385 | .map(|x| self.distance(&q, &x.v, &self.metric))
386 | .collect::<Vec<f32>>();
387 |
388 | for (i, d) in neighbors.iter().zip(distances.iter()) {
389 | v.insert(i.clone());
390 | if d < &f_value || W.len() < ef {
391 | C.push(Reverse((Numeric(d.clone()), i.clone())));
392 | W.push((Numeric(d.clone()), i.clone()));
393 | if W.len() > ef {
394 | W.pop();
395 | }
396 | }
397 | }
398 | }
399 |
400 | if ef == 1 {
401 | if W.len() > 0 {
402 | let W_map = W.into_map();
403 | let mut W_min = W_map.into_minheap();
404 | let mut single_map = HashMap::new();
405 | let min_val = W_min.pop().unwrap().0;
406 | single_map.insert(min_val.1, min_val.0 .0);
407 | return Ok(single_map);
408 | } else {
409 | return Ok(HashMap::new());
410 | }
411 | }
412 |
413 | Ok(W.into_map())
414 | }
415 |
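/// Selects up to M neighbors for level `l_c` from the candidate set C; when
/// `k_p_c` (keep pruned connections) is true, discarded candidates are used to
/// back-fill the result up to M.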
416 | fn select_neighbors(
417 | &self,
418 | q: &Vec<f32>,
419 | C: HashMap<u32, f32>,
420 | l_c: usize,
421 | k_p_c: bool,
422 | ) -> HashMap<u32, f32> {
423 | let mut R = create_min_heap();
424 | let mut W = C.into_minheap();
425 |
426 | let mut M = 0;
427 |
428 | if l_c > 0 {
429 | M = self.m
430 | } else {
431 | M = self.m_max0;
432 | }
433 |
434 | let mut W_d = create_min_heap();
435 | while W.len() > 0 && R.len() < M {
436 | let e = W.pop().unwrap().0;
437 |
438 | if R.len() == 0 || e.0 < R.peek().unwrap().0 .0 {
439 | R.push(Reverse(e));
440 | } else {
441 | W_d.push(Reverse(e));
442 | }
443 | }
444 |
445 | if k_p_c {
446 | while W_d.len() > 0 && R.len() < M {
447 | R.push(W_d.pop().unwrap());
448 | }
449 | }
450 | R.into_map()
451 | }
452 |
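/// K-nearest-neighbor query: descends from the entry point through the upper
/// layers with ef=1, searches layer 0 with `self.ef`, and returns the top-K
/// results as JSON objects carrying id, score (1 - distance) and metadata.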
453 | pub fn knn_search(
454 | &self,
455 | q: &Vec<f32>,
456 | K: usize,
457 | node_map: Arc<SynchronizedNodes>,
458 | point_map: Cache<String, Point>,
459 | ) -> Vec<Value> {
460 | let mut W = HashMap::new();
461 |
462 | let ep_index = self.db.get_ep().expect("Error getting entry point") as u32;
463 | let num_layers = self.db.get_num_layers().expect("Error getting num_layers");
464 |
465 | let points = self.get_points_w_memory(&vec![ep_index], point_map.clone());
466 | let point = points.first().unwrap();
467 | let dist = self.distance(&q, &point.v, &self.metric);
468 | let mut ep = HashMap::from([(ep_index, dist)]);
469 |
470 | for l_c in (1..=num_layers - 1).rev() {
471 | W = self
472 | .search_layer(&q, ep, 1, l_c, node_map.clone(), point_map.clone())
473 | .expect("Error searching layer");
474 | ep = W;
475 | }
476 |
477 | let ep_ = self
478 | .search_layer(q, ep, self.ef, 0, node_map.clone(), point_map.clone())
479 | .expect("Error searching layer");
480 |
481 | let mut heap = ep_.into_minheap();
482 | let mut sorted_vec = Vec::new();
483 | while !heap.is_empty() && sorted_vec.len() < K {
484 | let item = heap.pop().unwrap().0;
485 | sorted_vec.push((item.1, 1.0 - item.0 .0));
486 | }
487 | let indices = sorted_vec.iter().map(|x| x.0).collect::<Vec<u32>>();
488 | let metadata = self
489 | .db
490 | .get_metadatas(indices)
491 | .expect("Error getting metadatas");
492 |
493 | let result = sorted_vec
494 | .iter()
495 | .zip(metadata.iter())
496 | .map(|(x, y)| json!({"id":x.0, "score":x.1, "metadata":y.clone()}))
497 | .collect::<Vec<Value>>();
498 | result
499 | }
500 | }
501 |
--------------------------------------------------------------------------------
/dria_hnsw/src/hnsw/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod index;
2 | pub mod scalar;
3 | pub mod sync_map;
4 | pub mod utils;
5 |
--------------------------------------------------------------------------------
/dria_hnsw/src/hnsw/scalar.rs:
--------------------------------------------------------------------------------
1 | extern crate core;
2 | extern crate serde_json;
3 |
4 | use serde::{Deserialize, Serialize};
5 | use tdigest::TDigest;
6 |
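/// Scalar quantizer backed by a t-digest: `merge` re-estimates the quantile
/// boundaries from new vectors, `quantize` maps each value to the index of its
/// quantile bucket, and `dequantize` maps bucket indices back to quantile values.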
7 | #[derive(Debug, Clone)]
8 | pub struct ScalarQuantizer {
9 | levels: usize,
10 | quantiles: Vec<f64>,
11 | t_digest: TDigest,
12 | dim: usize,
13 | }
14 |
15 | impl ScalarQuantizer {
16 | pub fn new(levels: usize, size: usize, dim: usize) -> Self {
17 | let quantiles = vec![];
18 | let t_digest = TDigest::new_with_size(size);
19 | ScalarQuantizer {
20 | levels,
21 | quantiles,
22 | t_digest,
23 | dim,
24 | }
25 | }
26 |
27 | pub fn merge(&mut self, matrix: Vec<Vec<f64>>) {
28 | let flattened: Vec<f64> = matrix.into_iter().flatten().collect();
29 | self.t_digest = self.t_digest.merge_unsorted(flattened); // Pass a reference to flattened
30 | self.quantiles = vec![]; //reset
31 | for i in 0..self.levels {
32 | self.quantiles.push(
33 | self.t_digest
34 | .estimate_quantile((i as f64) / (self.levels.clone() as f64)),
35 | );
36 | }
37 | }
38 |
39 | fn __quantize_scalar(&self, scalar: &f64) -> usize {
40 | let ax = &self.quantiles;
41 | ax.into_iter()
42 | .enumerate()
43 | .find(|&(_, q)| scalar < q)
44 | .map_or(256 - 1, |(i, _)| i)
45 | }
46 |
47 | pub fn quantize_vectors(&self, vecs: Vec<Vec<f64>>) -> Vec<Vec<usize>> {
48 | let single: Vec<f64> = vecs.into_iter().flatten().collect();
49 | let quantized = self.quantize(single.as_slice());
50 | quantized
51 | .chunks(self.dim.clone())
52 | .map(|chunk| chunk.to_vec())
53 | .collect()
54 | }
55 |
56 | pub fn dequantize(&self, qv: &[usize]) -> Vec<f64> {
57 | qv.iter().map(|&val| self.quantiles[val.min(255)]).collect()
58 | }
59 |
60 | pub fn quantize(&self, v: &[f64]) -> Vec<usize> {
61 | v.iter()
62 | .map(|value| self.__quantize_scalar(value))
63 | .collect()
64 | }
65 | }
66 |
--------------------------------------------------------------------------------
/dria_hnsw/src/hnsw/sync_map.rs:
--------------------------------------------------------------------------------
1 | use crate::proto::index_buffer::LayerNode;
2 | use crossbeam_channel::{bounded, Receiver, Sender};
3 | use dashmap::DashMap;
4 | use hashbrown::HashSet;
5 | use std::collections::HashMap;
6 | use std::sync::Arc;
7 | //use tokio::sync::{Mutex, RwLock, RwLockWriteGuard};
8 | use parking_lot::{Mutex, RwLock};
9 | use std::time::Duration;
10 |
11 | //static ref RESET_SIZE = 120_000;
12 |
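/// A concurrent store of layer nodes keyed by "level:idx". Readers that miss a
/// key can register for a notification and block until a writer inserts that
/// key and notifies them via a per-key channel.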
13 | pub struct SynchronizedNodes {
14 | pub map: Arc<DashMap<String, LayerNode>>,
15 | pub lock_map: Arc<DashMap<String, RwLock<()>>>,
16 | wait_map: Mutex<HashMap<String, (Sender<()>, Receiver<()>)>>,
17 | }
18 |
19 | impl SynchronizedNodes {
20 | pub fn new() -> Self {
21 | SynchronizedNodes {
22 | map: Arc::new(DashMap::new()), //Arc::new(DashMap::new()),cache
23 | lock_map: Arc::new(DashMap::new()),
24 | wait_map: Mutex::new(HashMap::new()),
25 | }
26 | }
27 |
28 | pub fn reset(&self) {
29 | if self.map.len() > 120_000 {
30 | self.map.clear();
31 | }
32 | }
33 |
34 | pub fn insert_and_notify(&self, node: &LayerNode) {
35 | let key = format!("{}:{}", node.level, node.idx);
36 |
37 | {
38 | let node_lock = self
39 | .lock_map
40 | .entry(key.clone())
41 | .or_insert_with(|| RwLock::new(()));
42 |
43 | let _write_guard = node_lock.write();
44 | // Insert or update the node in the DashMap
45 | self.map.insert(key.clone(), node.clone());
46 | }
47 |
48 | // Notify all waiting threads registered for this key
49 | self.notify(&key);
50 | }
51 |
52 | pub fn insert_batch_and_notify(&self, nodes: Vec<LayerNode>) {
53 | //let mut wait_map_guard = self.wait_map.lock().unwrap();
54 | let mut keys_to_notify = HashSet::new();
55 |
56 | for node in nodes.iter() {
57 | let key = format!("{}:{}", node.level, node.idx);
58 |
59 | {
60 | let node_lock = self
61 | .lock_map
62 | .entry(key.clone())
63 | .or_insert_with(|| RwLock::new(()));
64 | let _write_guard = node_lock.write(); // Lock for writing
65 | self.map.insert(key.clone(), node.clone());
66 | }
67 |
68 | keys_to_notify.insert(key);
69 | //drop(_write_guard);
70 | }
71 |
72 | for key in keys_to_notify {
73 | self.notify(&key);
74 | }
75 | }
76 |
77 | pub fn get_or_wait(&self, key: &str) -> LayerNode {
78 | loop {
79 | // Register for notification before checking if a node is being inserted
80 |
81 | if let Some(value) = self.map.get(&key.to_string()) {
82 | return value.value().clone();
83 | }
84 |
85 | let receiver = self.register_for_notification(key);
86 |
87 | // A secondary check
88 | if let Some(value) = self.map.get(&key.to_string()) {
89 | println!("Second check grabbed key: {}", key);
90 | return value.value().clone();
91 | }
92 |
93 | //receiver.recv().unwrap(); // Block the thread until notification is received
94 | match receiver.recv_timeout(Duration::from_millis(500)) {
95 | Ok(_) => { /* Handle reception */ }
96 | Err(e) => {
97 | // Handle timeout or other errors
98 | eprintln!("Error or timeout waiting for message: {:?}", e);
99 | continue; // or handle differently
100 | }
101 | }
102 | }
103 | }
104 |
105 | pub fn get_or_wait_opt(&self, key: &str) -> Option<LayerNode> {
106 | loop {
107 | if let Some(value) = self.map.get(&key.to_string()) {
108 | return Some(value.value().clone());
109 | }
110 |
111 | // Check if this key is expected to be updated soon
112 | let receiver = {
113 | let wait_map_guard = self.wait_map.lock();
114 | if let Some((_sender, receiver)) = wait_map_guard.get(key) {
115 | Some(receiver.clone()) // Clone the receiver
116 | } else {
117 | None // Key is not expected to be updated soon
118 | }
119 | };
120 |
121 | if let Some(receiver) = receiver {
122 | // Wait for the notification if the key is expected to be updated
123 | match receiver.recv() {
124 | Ok(_) => {
125 | // Handle the received message
126 | }
127 | Err(e) => {
128 | // Handle the error, e.g., log it or perform a fallback action
129 | eprintln!("Error receiving message: {:?}", e);
130 | return None;
131 | }
132 | }
133 | } else {
134 | // If the key is not expected to be updated soon, return None
135 | return None;
136 | }
137 | }
138 | }
139 |
140 | pub fn register_for_notification(&self, key: &str) -> Receiver<()> {
141 | let mut wait_map_guard = self.wait_map.lock();
142 | if !wait_map_guard.contains_key(key) {
143 | // Create a new sender/receiver pair if it doesn't exist
144 | let (sender, receiver) = bounded(10);
145 | wait_map_guard.insert(key.to_string(), (sender, receiver.clone()));
146 | receiver
147 | } else {
148 | // If it already exists, return the existing receiver
149 | wait_map_guard.get(key).unwrap().1.clone()
150 | }
151 | }
152 |
153 | pub fn notify(&self, key: &str) {
154 | let mut wait_map_guard = self.wait_map.lock();
155 | if let Some((sender, _)) = wait_map_guard.remove(key) {
156 | drop(wait_map_guard); // Drop the lock before sending to avoid deadlocks
157 | let _ = sender.send(()); // It's safe to ignore the send result
158 | }
159 | }
160 | }
161 |
162 | #[cfg(test)]
163 | mod tests {
164 | // use super::*;
165 | // use crate::db::conversions::{base64_to_node, node_to_base64};
166 | // use crate::proto::index_buffer::LayerNode;
167 | // use crate::proto::index_buffer::Point;
168 | // use std::sync::Arc;
169 | // use std::thread;
170 | // use std::time::Duration;
171 |
172 | #[test]
173 | #[ignore = "todo"]
174 | fn test_synchronized_nodes() {}
175 | }
176 |
--------------------------------------------------------------------------------
/dria_hnsw/src/hnsw/utils.rs:
--------------------------------------------------------------------------------
1 | use std::cmp::{Ordering, Reverse};
2 | use std::collections::BinaryHeap; //node_metadata
3 | use std::collections::HashMap;
4 |
5 | #[derive(PartialEq, Debug)]
6 | pub struct Numeric(pub f32);
7 |
8 | impl Eq for Numeric {}
9 |
10 | impl PartialOrd for Numeric {
11 | fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
12 | self.0.partial_cmp(&other.0)
13 | }
14 | }
15 |
16 | impl Ord for Numeric {
17 | fn cmp(&self, other: &Self) -> Ordering {
18 | self.0.partial_cmp(&other.0).unwrap_or(Ordering::Equal)
19 | }
20 | }
21 |
22 | pub trait IntoMap {
23 | fn into_map(self) -> HashMap<u32, f32>;
24 | }
25 |
26 | impl IntoMap for BinaryHeap<(Numeric, u32)> {
27 | fn into_map(self) -> HashMap<u32, f32> {
28 | self.into_iter().map(|(d, i)| (i, d.0)).collect()
29 | }
30 | }
31 |
32 | impl IntoMap for BinaryHeap<Reverse<(Numeric, u32)>> {
33 | fn into_map(self) -> HashMap<u32, f32> {
34 | self.into_iter().map(|Reverse((d, i))| (i, d.0)).collect()
35 | }
36 | }
37 |
38 | pub trait IntoHeap {
39 | fn into_maxheap(self) -> BinaryHeap<(Numeric, u32)>;
40 | fn into_minheap(self) -> BinaryHeap<Reverse<(Numeric, u32)>>;
41 | }
42 |
43 | impl IntoHeap for HashMap<u32, f32> {
44 | fn into_maxheap(self) -> BinaryHeap<(Numeric, u32)> {
45 | self.into_iter().map(|(i, d)| (Numeric(d), i)).collect()
46 | }
47 |
48 | fn into_minheap(self) -> BinaryHeap<Reverse<(Numeric, u32)>> {
49 | self.into_iter()
50 | .map(|(i, d)| Reverse((Numeric(d), i)))
51 | .collect()
52 | }
53 | }
54 |
55 | impl IntoHeap for Vec<(f32, u32)> {
56 | fn into_maxheap(self) -> BinaryHeap<(Numeric, u32)> {
57 | self.into_iter().map(|(d, i)| (Numeric(d), i)).collect()
58 | }
59 |
60 | fn into_minheap(self) -> BinaryHeap<Reverse<(Numeric, u32)>> {
61 | self.into_iter()
62 | .map(|(d, i)| Reverse((Numeric(d), i)))
63 | .collect()
64 | }
65 | }
66 |
67 | pub fn create_min_heap() -> BinaryHeap<Reverse<(Numeric, u32)>> {
68 | let q: BinaryHeap<Reverse<(Numeric, u32)>> = BinaryHeap::new();
69 | q
70 | }
71 |
72 | pub fn create_max_heap() -> BinaryHeap<(Numeric, u32)> {
73 | let q: BinaryHeap<(Numeric, u32)> = BinaryHeap::new();
74 | q
75 | }
76 |
--------------------------------------------------------------------------------
/dria_hnsw/src/lib.rs:
--------------------------------------------------------------------------------
1 | pub mod db;
2 | pub mod errors;
3 | pub mod hnsw;
4 | pub mod middlewares;
5 | pub mod models;
6 | pub mod proto;
7 | pub mod responses;
8 | pub mod worker;
9 | pub mod filter;
10 |
--------------------------------------------------------------------------------
/dria_hnsw/src/main.rs:
--------------------------------------------------------------------------------
1 | use actix_cors::Cors;
2 | use actix_web::middleware::Logger;
3 | use actix_web::{web, App, HttpServer};
4 | use dria_hnsw::db::env::Config;
5 | use dria_hnsw::db::rocksdb_client::RocksdbClient;
6 | use dria_hnsw::middlewares::cache::{NodeCache, PointCache};
7 | use dria_hnsw::worker::{fetch, get_health_status, insert_vector, query};
8 |
9 | pub fn config(conf: &mut web::ServiceConfig) {
10 | conf.service(get_health_status);
11 | conf.service(query);
12 | conf.service(fetch);
13 | conf.service(insert_vector);
14 | }
15 |
16 | #[actix_web::main]
17 | async fn main() -> std::io::Result<()> {
18 | let node_cache = web::Data::new(NodeCache::new());
19 | let point_cache = web::Data::new(PointCache::new());
20 | let cfg = Config::new();
21 |
22 | let rocksdb_client = RocksdbClient::new(cfg.contract_id.clone());
23 |
24 | if rocksdb_client.is_err() {
25 | println!("Rocksdb client failed to initialize");
26 | return Err(std::io::Error::new(
27 | std::io::ErrorKind::Other,
28 | "Rocksdb client failed to initialize",
29 | ));
30 | }
31 | let rdb = rocksdb_client.unwrap();
32 | let rocksdb_client = web::Data::new(rdb);
33 |
34 | let factory = move || {
35 | App::new()
36 | .app_data(web::JsonConfig::default().limit(152428800))
37 | .app_data(node_cache.clone())
38 | .app_data(rocksdb_client.clone())
39 | .app_data(point_cache.clone())
40 | .configure(config)
41 | .wrap(Logger::default())
42 | .wrap(Cors::permissive())
43 | };
44 |
45 | let url = format!("0.0.0.0:{}", cfg.port);
46 | println!("Dria HNSW listening at {}", url);
47 | HttpServer::new(factory).bind(url)?.run().await?;
48 | Ok(())
49 | }
50 |
--------------------------------------------------------------------------------
/dria_hnsw/src/middlewares/cache.rs:
--------------------------------------------------------------------------------
1 | use crate::hnsw::sync_map::SynchronizedNodes;
2 | use crate::proto::index_buffer::Point;
3 | use mini_moka::sync::Cache;
4 | use std::sync::Arc;
5 | use std::time::Duration;
6 |
7 | pub struct NodeCache {
8 | pub caches: Cache<String, Arc<SynchronizedNodes>>,
9 | }
10 |
11 | /// If a key within cache is not used (i.e. get or insert) for the given duration (seconds), expire that key.
12 | const NODE_CACHE_EXPIRE: u64 = 48 * 60 * 60; // 2 days
13 | /// Maximum capacity of the cache, in number of keys.
14 | const NODE_CACHE_CAPACITY: u64 = 5_000;
15 |
16 | impl NodeCache {
17 | pub fn new() -> Self {
18 | let cache = Cache::builder()
19 | .time_to_idle(Duration::from_secs(NODE_CACHE_EXPIRE))
20 | .max_capacity(NODE_CACHE_CAPACITY)
21 | .build();
22 |
23 | NodeCache { caches: cache }
24 | }
25 |
26 | pub fn get_cache(&self, key: String) -> Arc<SynchronizedNodes> {
27 | let my_cache = self.caches.clone();
28 | let node_cache = my_cache.get(&key).unwrap_or_else(|| {
29 | let new_cache = Arc::new(SynchronizedNodes::new());
30 | my_cache.insert(key.to_string(), new_cache.clone());
31 | new_cache
32 | });
33 |
34 | // TODO: clone required here?
35 | node_cache.clone()
36 | }
37 |
38 | pub fn add_cache(&self, key: &str, cache: Arc<SynchronizedNodes>) {
39 | let my_cache = self.caches.clone();
40 | my_cache.insert(key.to_string(), cache);
41 | }
42 | }
43 |
44 | /// If a key within cache is not used (i.e. get or insert) for the given duration (seconds), expire that key.
45 | const POINT_CACHE_EXPIRE: u64 = 24 * 60 * 60; // 1 day
46 | /// Maximum capacity of the cache, in number of keys.
47 | const POINT_CACHE_CAPACITY: u64 = 5_000;
48 |
49 | pub struct PointCache {
50 | pub caches: Cache<String, Cache<String, Point>>,
51 | }
52 |
53 | impl PointCache {
54 | pub fn new() -> Self {
55 | let cache = Cache::builder()
56 | .time_to_idle(Duration::from_secs(POINT_CACHE_EXPIRE))
57 | .max_capacity(POINT_CACHE_CAPACITY) // around 106MB for 1536 dim vectors
58 | .build();
59 |
60 | PointCache { caches: cache }
61 | }
62 |
63 | pub fn get_cache(&self, key: String) -> Cache<String, Point> {
64 | //let cache = self.caches.entry(key.to_string()).or_insert_with(|| Arc::new(DashMap::new()));
65 | let my_cache = self.caches.clone();
66 | let point_cache = my_cache.get(&key).unwrap_or_else(|| {
67 | //let new_cache = Arc::new(DashMap::new());
68 | let new_cache = Cache::builder()
69 | //if a key is not used (get or insert) for 2 hour, expire it
70 | //.time_to_live(Duration::from_secs(1 * 60 * 60))
71 | .max_capacity(200_000) // around 2060MB for 1536 dim vectors
72 | .build();
73 | my_cache.insert(key.to_string(), new_cache.clone());
74 | new_cache
75 | });
76 |
77 | // TODO: clone required here?
78 | point_cache.clone()
79 | }
80 |
81 | pub fn add_cache(&self, key: &str, cache: Cache<String, Point>) {
82 | self.caches.insert(key.to_string(), cache);
83 | }
84 | }
85 |
86 | #[cfg(test)]
87 | mod tests {
88 | use super::*;
89 | use std::sync::atomic::{AtomicU32, Ordering};
90 |
91 | #[test]
92 | fn test_cache() {
93 | let cache = Cache::builder()
94 | //if a key is not used (get or insert) for 2 hour, expire it
95 | .time_to_idle(Duration::from_secs(120 * 60))
96 | .max_capacity(100)
97 | .build();
98 |
99 | for i in 0..105 {
100 | cache.insert(i.to_string(), i);
101 | }
102 |
103 | let ix = cache.get(&"0".to_string());
104 | assert_eq!(ix, Some(0));
105 | }
106 |
107 | #[test]
108 | fn test_weighted_cache() {
109 | let current_weight = AtomicU32::new(1); // Start weights from 1 to avoid assigning a weight of 0
110 |
111 | let cache = Cache::builder()
112 | .weigher(move |_key, _value: &String| -> u32 {
113 | // Use the current weight and increment for the next use
114 | current_weight.fetch_add(1, Ordering::SeqCst)
115 | })
116 | // Assuming a simple numeric weight limit for demonstration purposes
117 | .max_capacity(100)
118 | .build();
119 |
120 | // Example inserts - in a real scenario, make sure to manage the size and weights appropriately
121 | for i in 0..105 {
122 | cache.insert(i, format!("Value {}", i));
123 | }
124 |
125 | let ix = cache.get(&0);
126 | assert_eq!(ix, Some("Value 0".to_string()));
127 | }
128 | }
129 |
--------------------------------------------------------------------------------
/dria_hnsw/src/middlewares/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod cache;
2 |
--------------------------------------------------------------------------------
/dria_hnsw/src/models/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod request_models;
2 |
--------------------------------------------------------------------------------
/dria_hnsw/src/models/request_models.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | use serde_json::Value;
3 |
4 | use crate::errors::errors::ValidationError;
5 |
6 | #[derive(Serialize, Deserialize, Debug)]
7 | pub struct InsertModel {
8 | pub vector: Vec<f32>,
9 | pub metadata: Value,
10 | }
11 |
12 | #[derive(Serialize, Deserialize, Debug)]
13 | pub struct InsertBatchModel {
14 | pub data: Vec<InsertModel>,
15 | }
16 |
17 | #[derive(Serialize, Deserialize, Debug)]
18 | pub struct FetchModel {
19 | pub id: Vec<u32>, // TODO: rename this to `ids`
20 | }
21 |
22 | #[derive(Serialize, Deserialize, Debug)]
23 | pub struct QueryModel {
24 | pub vector: Vec<f32>,
25 | pub top_n: usize,
26 | pub query: Option<String>,
27 | pub level: Option<usize>,
28 | }
29 |
30 | impl QueryModel {
31 | pub fn new(
32 | vector: Vec<f32>,
33 | top_n: usize,
34 | query: Option<String>,
35 | level: Option<usize>,
36 | ) -> Result<Self, ValidationError> {
37 | Self::validate_top_n(top_n)?;
38 | Self::validate_level(level)?;
39 |
40 | Ok(QueryModel {
41 | vector,
42 | top_n,
43 | query,
44 | level,
45 | })
46 | }
47 |
48 | fn validate_top_n(top_n: usize) -> Result<(), ValidationError> {
49 | if top_n > 20 {
50 | Err(ValidationError("Top N cannot be more than 20.".to_string()))
51 | } else {
52 | Ok(())
53 | }
54 | }
55 |
56 | fn validate_level(level: Option<usize>) -> Result<(), ValidationError> {
57 | if level.is_some() {
58 | match level.unwrap() {
59 | 1 | 2 | 3 | 4 => Ok(()),
60 | _ => Err(ValidationError(
61 | "Level should be 1, 2, 3, or 4.".to_string(),
62 | )),
63 | }
64 | } else {
65 | Ok(())
66 | }
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/dria_hnsw/src/proto/hnsw_comm.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package index_buffer;
4 |
5 |
6 | message LayerNode {
7 | uint32 level = 1; // Using uint64 as a safe alternative for usize
8 | uint32 idx = 2; // Using uint64 as a safe alternative for usize
9 | bool visible = 3; // Whether the node is visible
10 | map<uint32, float> neighbors = 4; // Neighbor idx and its distance
11 | }
12 |
13 |
14 | message Point {
15 | uint32 idx = 1; // Using uint64 as a safe alternative for usize
16 | repeated float v = 2; // Vector of floats
17 | }
18 |
19 | message PointQuant {
20 | uint32 idx = 1; // Using uint64 as a safe alternative for usize
21 | repeated uint32 v = 2; // Vector of ints
22 | }
--------------------------------------------------------------------------------
/dria_hnsw/src/proto/index.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package index_buffer;
4 |
5 |
6 | message LayerNode {
7 | uint32 level = 1; // Using uint64 as a safe alternative for usize
8 | uint32 idx = 2; // Using uint64 as a safe alternative for usize
9 | bool visible = 3; // Whether the node is visible
10 | map<uint32, float> neighbors = 4; // Neighbor idx and its distance
11 | }
12 |
13 |
14 | message Point {
15 | uint32 idx = 1; // Using uint64 as a safe alternative for usize
16 | repeated float v = 2; // Vector of floats
17 | }
18 |
19 | message PointQuant {
20 | uint32 idx = 1; // Using uint64 as a safe alternative for usize
21 | repeated uint32 v = 2; // Vector of ints
22 | }
--------------------------------------------------------------------------------
/dria_hnsw/src/proto/index_buffer.rs:
--------------------------------------------------------------------------------
1 | #[derive(Clone, PartialEq, ::prost::Message)]
2 | pub struct LayerNode {
3 | /// Using uint64 as a safe alternative for usize
4 | #[prost(uint32, tag = "1")]
5 | pub level: u32,
6 | /// Using uint64 as a safe alternative for usize
7 | #[prost(uint32, tag = "2")]
8 | pub idx: u32,
9 | /// Whether the node is visible
10 | #[prost(bool, tag = "3")]
11 | pub visible: bool,
12 | /// Neighbor idx and its distance
13 | #[prost(map = "uint32, float", tag = "4")]
14 | pub neighbors: ::std::collections::HashMap<u32, f32>,
15 | }
16 | impl LayerNode {
17 | pub fn new(level: usize, idx: usize) -> LayerNode {
18 | LayerNode {
19 | level: level as u32,
20 | idx: idx as u32,
21 | visible: true,
22 | neighbors: ::std::collections::HashMap::new(),
23 | }
24 | }
25 | }
26 |
27 | #[derive(Clone, PartialEq, ::prost::Message)]
28 | pub struct Point {
29 | /// Using uint64 as a safe alternative for usize
30 | #[prost(uint32, tag = "1")]
31 | pub idx: u32,
32 | /// Vector of floats
33 | #[prost(float, repeated, tag = "2")]
34 | pub v: ::prost::alloc::vec::Vec<f32>,
35 | }
36 | impl Point {
37 | pub fn new(vec: Vec<f32>, idx: usize) -> Point {
38 | Point {
39 | idx: idx as u32,
40 | v: vec,
41 | }
42 | }
43 | }
44 |
45 | #[derive(Clone, PartialEq, ::prost::Message)]
46 | pub struct PointQuant {
47 | /// Using uint64 as a safe alternative for usize
48 | #[prost(uint32, tag = "1")]
49 | pub idx: u32,
50 | /// Vector of ints
51 | #[prost(uint32, repeated, tag = "2")]
52 | pub v: ::prost::alloc::vec::Vec<u32>,
53 | }
54 | impl PointQuant {
55 | pub fn new(vec: Vec<u32>, idx: usize) -> PointQuant {
56 | PointQuant {
57 | idx: idx as u32,
58 | v: vec,
59 | }
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/dria_hnsw/src/proto/insert.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package insert_buffer;
4 |
5 | message MetadataValue {
6 | oneof value_type {
7 | float float_value = 1;
8 | int64 int_value = 2;
9 | string string_value = 3;
10 | bool bool_value = 4;
11 | }
12 | }
13 |
14 | message SingletonVec {
15 | repeated float v = 1; // Vector of floats
16 | map<string, MetadataValue> map = 2;
17 |
18 | }
19 |
20 | message BatchVec {
21 | repeated SingletonVec s = 1;
22 | }
23 |
24 |
25 |
26 | message SingletonStr {
27 | string v = 1; // Vector of strings
28 | map<string, MetadataValue> map = 2;
29 |
30 | }
31 |
32 | message BatchStr {
33 | repeated SingletonStr s = 1;
34 | }
35 |
--------------------------------------------------------------------------------
/dria_hnsw/src/proto/insert_buffer.rs:
--------------------------------------------------------------------------------
1 | use serde::ser::{Serialize, SerializeStruct, Serializer};
2 |
3 | #[derive(Clone, PartialEq, ::prost::Message)]
4 | pub struct MetadataValue {
5 | #[prost(oneof = "metadata_value::ValueType", tags = "1, 2, 3, 4")]
6 | pub value_type: ::core::option::Option<metadata_value::ValueType>,
7 | }
8 | /// Nested message and enum types in `MetadataValue`.
9 | pub mod metadata_value {
10 | #[derive(Clone, PartialEq, ::prost::Oneof)]
11 | pub enum ValueType {
12 | #[prost(float, tag = "1")]
13 | FloatValue(f32),
14 | #[prost(int64, tag = "2")]
15 | IntValue(i64),
16 | #[prost(string, tag = "3")]
17 | StringValue(::prost::alloc::string::String),
18 | #[prost(bool, tag = "4")]
19 | BoolValue(bool),
20 | }
21 | }
22 | #[derive(Clone, PartialEq, ::prost::Message)]
23 | pub struct SingletonVec {
24 | /// Vector of floats
25 | #[prost(float, repeated, tag = "1")]
26 | pub v: ::prost::alloc::vec::Vec<f32>,
27 | #[prost(map = "string, message", tag = "2")]
28 | pub map: ::std::collections::HashMap<::prost::alloc::string::String, MetadataValue>,
29 | }
30 | #[derive(Clone, PartialEq, ::prost::Message)]
31 | pub struct BatchVec {
32 | #[prost(message, repeated, tag = "1")]
33 | pub s: ::prost::alloc::vec::Vec<SingletonVec>,
34 | }
35 | #[derive(Clone, PartialEq, ::prost::Message)]
36 | pub struct SingletonStr {
37 | /// Vector of strings
38 | #[prost(string, tag = "1")]
39 | pub v: ::prost::alloc::string::String,
40 | #[prost(map = "string, message", tag = "2")]
41 | pub map: ::std::collections::HashMap<::prost::alloc::string::String, MetadataValue>,
42 | }
43 | #[derive(Clone, PartialEq, ::prost::Message)]
44 | pub struct BatchStr {
45 | #[prost(message, repeated, tag = "1")]
46 | pub s: ::prost::alloc::vec::Vec<SingletonStr>,
47 | }
48 |
49 | impl Serialize for metadata_value::ValueType {
50 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
51 | where
52 | S: Serializer,
53 | {
54 | // Serialize the value directly without wrapping it in a structure.
55 | match *self {
56 | metadata_value::ValueType::FloatValue(f) => serializer.serialize_f32(f),
57 | metadata_value::ValueType::IntValue(i) => serializer.serialize_i64(i),
58 | metadata_value::ValueType::StringValue(ref s) => serializer.serialize_str(s),
59 | metadata_value::ValueType::BoolValue(b) => serializer.serialize_bool(b),
60 | }
61 | }
62 | }
63 |
64 | impl Serialize for MetadataValue {
65 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
66 | where
67 | S: Serializer,
68 | {
69 | // Use the inner value's serialization directly.
70 | if let Some(ref value_type) = self.value_type {
71 | value_type.serialize(serializer)
72 | } else {
73 | // Decide how you want to handle MetadataValue when it's None.
74 | // For example, you might serialize it as null or an empty object.
75 | serializer.serialize_none()
76 | }
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
/dria_hnsw/src/proto/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod index_buffer;
2 | pub mod insert_buffer;
3 |
--------------------------------------------------------------------------------
/dria_hnsw/src/proto/request.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package request_buffer;
4 |
5 | message Singleton {
6 | repeated float v = 1; // Vector of floats
7 | map<string, string> metadata = 2; // Metadata key-value pairs
8 | }
9 |
10 | message Batch {
11 | repeated string b = 1;
12 | }
13 |
--------------------------------------------------------------------------------
/dria_hnsw/src/proto/request_buffer.rs:
--------------------------------------------------------------------------------
1 | #[derive(Clone, PartialEq, ::prost::Message)]
2 | pub struct Singleton {
3 | /// Vector of floats
4 | #[prost(float, repeated, tag = "1")]
5 | pub v: ::prost::alloc::vec::Vec<f32>,
6 | /// Metadata key-value pairs
7 | #[prost(map = "string, string", tag = "2")]
8 | pub metadata:
9 | ::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>,
10 | }
11 | #[derive(Clone, PartialEq, ::prost::Message)]
12 | pub struct Batch {
13 | #[prost(string, repeated, tag = "1")]
14 | pub b: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
15 | }
16 |
--------------------------------------------------------------------------------
/dria_hnsw/src/responses/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod responses;
2 |
--------------------------------------------------------------------------------
/dria_hnsw/src/responses/responses.rs:
--------------------------------------------------------------------------------
1 | use serde::Serialize;
2 |
3 | #[derive(Serialize)]
4 | pub struct CustomResponse<T> {
5 | pub(crate) success: bool,
6 | pub(crate) data: T,
7 | pub(crate) code: u32,
8 | }
9 |
--------------------------------------------------------------------------------
/dria_hnsw/src/worker.rs:
--------------------------------------------------------------------------------
1 | use crate::db::env::Config;
2 | use crate::db::rocksdb_client::RocksdbClient;
3 | use crate::hnsw::index::HNSW;
4 | use crate::hnsw::sync_map::SynchronizedNodes;
5 | use crate::middlewares::cache::{NodeCache, PointCache};
6 | use crate::models::request_models::{FetchModel, InsertBatchModel, QueryModel};
7 | use crate::proto::index_buffer::{LayerNode, Point};
8 | use crate::responses::responses::CustomResponse;
9 | use actix_web::web::{Data, Json};
10 | use actix_web::{get, post, web, HttpMessage, HttpRequest, HttpResponse};
11 | use log::error;
12 | use mini_moka::sync::Cache;
13 | use rayon::prelude::*;
14 | use serde::{Deserialize, Serialize};
15 | use serde_json::{json, Value};
16 | use std::borrow::Borrow;
17 | use std::sync::atomic::{AtomicIsize, AtomicUsize, Ordering};
18 | use std::sync::Arc;
19 | use tokio::task;
20 |
21 | use crate::filter::text_based::create_index_from_docs;
22 | use probly_search::Index;
23 |
24 | pub const SINGLE_THREADED_HNSW_BUILD_THRESHOLD: usize = 256;
25 |
26 | #[get("/health")]
27 | pub async fn get_health_status() -> HttpResponse {
28 | let response = CustomResponse {
29 | success: true,
30 | data: "hello world!".to_string(),
31 | code: 200,
32 | };
33 | HttpResponse::Ok().json(response)
34 | }
35 |
36 | #[post("/query")]
37 | pub async fn query(req: HttpRequest, payload: Json<QueryModel>) -> HttpResponse {
38 | let ind: HNSW;
39 |
40 | let cfg = Config::new();
41 |
42 | let rocksdb_client = req
43 | .app_data::<Data<RocksdbClient>>()
44 | .expect("Error getting rocksdb client");
45 |
46 | ind = HNSW::new(
47 | 16,
48 | 128,
49 | ef_helper(payload.level),
50 | None,
51 | rocksdb_client.clone(),
52 | );
53 | let node_cache = req
54 | .app_data::<Data<NodeCache>>()
55 | .expect("Error getting node cache");
56 | let point_cache = req
57 | .app_data::<Data<PointCache>>()
58 | .expect("Error getting point cache");
59 |
60 | let node_map = node_cache.get_cache(cfg.contract_id.clone());
61 | let point_map = point_cache.get_cache(cfg.contract_id.clone());
62 | let res = ind.knn_search(&payload.vector, payload.top_n, node_map, point_map);
63 |
64 | if payload.query.is_some() {
65 | let mut index = Index::<usize>::new(1);
66 | let results =
67 | create_index_from_docs(&mut index, &payload.query.clone().unwrap(), res.clone());
68 | let response = CustomResponse {
69 | success: true,
70 | data: json!(results),
71 | code: 200,
72 | };
73 | return HttpResponse::Ok().json(response);
74 | }
75 |
76 | let response = CustomResponse {
77 | success: true,
78 | data: json!(res),
79 | code: 200,
80 | };
81 | HttpResponse::Ok().json(response)
82 | }
83 |
84 | #[post("/fetch")]
85 | pub async fn fetch(req: HttpRequest, payload: Json<FetchModel>) -> HttpResponse {
86 | let ind: HNSW;
87 |
88 | let rocksdb_client = req
89 | .app_data::<Data<RocksdbClient>>()
90 | .expect("Error getting rocksdb client");
91 |
92 | ind = HNSW::new(16, 128, 0, None, rocksdb_client.clone());
93 |
94 | let res = ind.db.get_metadatas(payload.id.clone());
95 |
96 | if res.is_err() {
97 | let response = CustomResponse {
98 | success: false,
99 | data: "Error fetching metadata".to_string(),
100 | code: 500,
101 | };
102 | return HttpResponse::InternalServerError().json(response);
103 | }
104 |
105 | let res = res.unwrap();
106 |
107 | let response = CustomResponse {
108 | success: true,
109 | data: json!(res),
110 | code: 200,
111 | };
112 | HttpResponse::Ok().json(response)
113 | }
114 |
115 | #[post("/insert_vector")]
116 | pub async fn insert_vector(req: HttpRequest, payload: Json<InsertBatchModel>) -> HttpResponse {
117 | let cfg = Config::new();
118 | let cid = cfg.contract_id.clone();
119 |
120 | let rocksdb_client = req
121 | .app_data::<Data<RocksdbClient>>()
122 | .expect("Error getting rocksdb client");
123 |
124 | let rocksdb_client = rocksdb_client.clone();
125 |
126 | let mut vectors = Vec::new();
127 | let mut metadata_batch = Vec::new();
128 | for d in payload.data.iter() {
129 | vectors.push(d.vector.clone());
130 | metadata_batch.push(d.metadata.clone());
131 | }
132 |
133 | if vectors.len() > 2500 {
134 | let response = CustomResponse {
135 | success: false,
136 | data: "Batch size should be smaller than 2500.".to_string(),
137 | code: 401,
138 | };
139 | // TODO: fix status code
140 | return HttpResponse::InternalServerError().json(response);
141 | }
142 |
143 | let node_cache = req
144 | .app_data::<Data<NodeCache>>()
145 | .expect("Error getting node cache");
146 | let point_cache = req
147 | .app_data::<Data<PointCache>>()
148 | .expect("Error getting point cache");
149 |
150 | let node_map = node_cache.get_cache(cid.clone());
151 | let point_map = point_cache.get_cache(cid.clone());
152 | let cid_clone = cid.clone();
153 | let result = task::spawn_blocking(move || {
154 | train_worker(
155 | vectors,
156 | metadata_batch,
157 | node_map,
158 | point_map,
159 | rocksdb_client.clone(),
160 | 10_000,
161 | )
162 | })
163 | .await;
164 |
165 | let node_map = node_cache.get_cache(cid_clone);
166 | node_map.reset();
167 |
168 | let (res, code) = result.expect("Error getting result");
169 |
170 | if code != 200 {
171 | return HttpResponse::InternalServerError().json(CustomResponse {
172 | success: false,
173 | data: res,
174 | code: code as u32,
175 | });
176 | }
177 | return HttpResponse::Ok().json(CustomResponse {
178 | success: true,
179 | data: "Values are successfully added to index.".to_string(),
180 | code: 200,
181 | });
182 | }
183 |
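// Maps the requested accuracy level (default 1) to the `ef` search parameter: ef = 20 + 30 * level.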
184 | fn ef_helper(ef: Option<usize>) -> usize {
185 | let level = ef.unwrap_or(1);
186 | 20 + (level * 30)
187 | }
188 |
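// Builds the index for one batch: persists the points and metadata, inserts each
// vector into the HNSW graph (single-threaded below SINGLE_THREADED_HNSW_BUILD_THRESHOLD,
// then in parallel via rayon), and finally flushes the updated neighbors, entry
// point and layer count back to storage in chunks of `batch_size`.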
189 | fn train_worker(
190 | vectors: Vec<Vec<f32>>,
191 | metadata_batch: Vec<Value>,
192 | node_map: Arc<SynchronizedNodes>,
193 | point_map: Cache<String, Point>,
194 | rocksdb_client: Data<RocksdbClient>,
195 | batch_size: usize,
196 | ) -> (String, u16) {
197 | let ind = HNSW::new(16, 128, ef_helper(Some(1)), None, rocksdb_client.clone());
198 |
199 | let mut ds = 0;
200 | let nl = ind.db.get_num_layers();
201 | let num_layers = Arc::new(AtomicUsize::new(0));
202 |
203 | if nl.is_err() {
204 | error!("{}", nl.err().unwrap());
205 | ind.db.set_datasize(0).expect("Error setting datasize");
206 | } else {
207 | let nl_value = nl.expect("").clone();
208 | ds = ind.db.get_datasize().expect("Error getting datasize");
209 |
210 | let res = num_layers.fetch_update(Ordering::SeqCst, Ordering::Relaxed, |x| Some(nl_value));
211 |
212 | if res.is_err() {
213 | error!("{}", res.err().unwrap());
214 | return ("Error setting num layers, atomic".to_string(), 500);
215 | }
216 | }
217 |
218 | let r1 = ind.db.add_points_batch(&vectors, ds);
219 | let r2 = ind.db.set_metadata_batch(metadata_batch, ds);
220 | let r3 = ind.db.set_datasize(ds + vectors.len());
221 |
222 | if r1.is_err() || r2.is_err() || r3.is_err() {
223 | error!("Error adding points as batch");
224 | return ("Error adding points as batch".to_string(), 500);
225 | }
226 |
227 | let epa = Arc::new(AtomicIsize::new(-1));
228 |
229 | let ep = ind.db.get_ep();
230 |
231 | if ep.is_ok() {
232 | let ep_value = ep.expect("").clone();
233 | let res = epa.fetch_update(Ordering::SeqCst, Ordering::Relaxed, |x| {
234 | // Your update logic here
235 | Some(ep_value as isize)
236 | });
237 | if res.is_err() {
238 | error!("{}", res.err().unwrap());
239 | return ("Error setting ep, atomic".to_string(), 500);
240 | }
241 | }
242 | let pool = rayon::ThreadPoolBuilder::new()
243 | .thread_name(|idx| format!("hnsw-build-{idx}"))
244 | .num_threads(8)
245 | .build()
246 | .expect("Error building threadpool");
247 |
248 | if ds < SINGLE_THREADED_HNSW_BUILD_THRESHOLD {
249 | let iter_ind = vectors.len().min(SINGLE_THREADED_HNSW_BUILD_THRESHOLD - ds);
250 | for i in 0..iter_ind {
251 | ind.insert_w_preset(
252 | ds + i,
253 | node_map.clone(),
254 | point_map.clone(),
255 | num_layers.clone(),
256 | epa.clone(),
257 | )
258 | .expect("Error inserting");
259 | }
260 |
261 | if SINGLE_THREADED_HNSW_BUILD_THRESHOLD < (vectors.len() + ds) {
262 | pool.install(|| {
263 | (iter_ind..vectors.len())
264 | .into_par_iter()
265 | .try_for_each(|item| {
266 | ind.insert_w_preset(
267 | ds + item,
268 | node_map.clone(),
269 | point_map.clone(),
270 | num_layers.clone(),
271 | epa.clone(),
272 | )
273 | })
274 | })
275 | .expect("Error inserting");
276 | }
277 | } else {
278 | pool.install(|| {
279 | (0..vectors.len()).into_par_iter().try_for_each(|item| {
280 | ind.insert_w_preset(
281 | item + ds,
282 | node_map.clone(),
283 | point_map.clone(),
284 | num_layers.clone(),
285 | epa.clone(),
286 | )
287 | })
288 | })
289 | .expect("Error inserting");
290 | }
291 |
292 | //replicate neighbors
293 | let values: Vec<LayerNode> = node_map
294 | .clone()
295 | .map
296 | .iter()
297 | .map(|entry| entry.value().clone())
298 | .collect();
299 |
300 | let ep_value = epa.clone().load(Ordering::Relaxed);
301 | let num_layers = num_layers.clone().load(Ordering::Relaxed);
302 |
303 | let batch_size: usize = batch_size;
304 |
305 | for chunk in values.chunks(batch_size) {
306 | let r1 = ind.db.upsert_neighbors(chunk.to_vec());
307 | if r1.is_err() {
308 | error!("{}", r1.err().unwrap());
309 | return ("Error writing batch to blockchain".to_string(), 500);
310 | }
311 | }
312 | let r2 = ind.db.set_ep(ep_value as usize, false);
313 | let r3 = ind.db.set_num_layers(num_layers, false);
314 |
315 | if r2.is_err() || r3.is_err() {
316 | error!("Error writing to blockchain");
317 | return ("Error writing batch to blockchain".to_string(), 500);
318 | }
319 | return ("Values are successfully added to index.".to_string(), 200);
320 | }
321 |
322 | #[cfg(test)]
323 | mod tests {
324 | use super::*;
325 | use actix_web::{http::header::ContentType, test, App};
326 | use rand::{self, Rng};
327 | use simple_home_dir::home_dir;
328 | use std::env;
329 |
330 | fn prepare_env() -> () {
331 | const CONTRACT_ID: &str = "WbcY2a-KfDpk7fsgumUtLC2bu4NQcVzNlXWi13fPMlU";
332 |
333 | env::set_var("CONTRACT_ID", CONTRACT_ID);
334 | let mut rocksdb_path = home_dir().unwrap().to_str().unwrap().to_owned();
335 | rocksdb_path.push_str("/.dria/data/");
336 | rocksdb_path.push_str(CONTRACT_ID);
337 | env::set_var("ROCKSDB_PATH", rocksdb_path);
338 | }
339 |
340 | fn prepare_rocksdb() -> Data<RocksdbClient> {
341 | let cfg = Config::new();
342 | let rocksdb_client = RocksdbClient::new(cfg.contract_id.clone()).unwrap();
343 | web::Data::new(rocksdb_client)
344 | }
345 |
346 | #[actix_web::test]
347 | async fn test_health() {
348 | let app = test::init_service(App::new().configure(|conf| {
349 | conf.service(get_health_status);
350 | }))
351 | .await;
352 |
353 | let req = test::TestRequest::get()
354 | .uri("/health")
355 | .insert_header(ContentType::plaintext())
356 | .to_request();
357 | let resp = test::call_service(&app, req).await;
358 | assert!(resp.status().is_success());
359 | }
360 |
361 | #[actix_web::test]
362 | async fn test_fetch_and_query() {
363 | prepare_env();
364 | let rocksdb_client = prepare_rocksdb();
365 | let node_cache = web::Data::new(NodeCache::new());
366 | let point_cache = web::Data::new(PointCache::new());
367 | let app = test::init_service(
368 | App::new()
369 | .app_data(rocksdb_client)
370 | .app_data(node_cache)
371 | .app_data(point_cache)
372 | .configure(|conf| {
373 | conf.service(query).service(fetch);
374 | }),
375 | )
376 | .await;
377 |
378 | // fetch
379 | let req = test::TestRequest::post()
380 | .uri("/fetch")
381 | .set_json(json!({ "id": [0] }))
382 | .to_request();
383 | let resp = test::call_service(&app, req).await;
384 | assert!(resp.status().is_success());
385 |
386 | // query
387 | let query_vector: Vec<f32> = (0..768).map(|_| rand::thread_rng().gen()).collect();
388 | let req = test::TestRequest::post()
389 | .uri("/query")
390 | .set_json(json!({
391 | "vector": query_vector,
392 | "top_n": 10
393 | }))
394 | .to_request();
395 | let resp = test::call_service(&app, req).await;
396 | assert!(resp.status().is_success());
397 | }
398 | }
399 |
--------------------------------------------------------------------------------
/hollowdb/.dockerignore:
--------------------------------------------------------------------------------
1 | test
2 | config
3 | .vscode
4 | .github
5 | .env.example
6 | .gitignore
7 | .mocharc.json
8 | node_modules
9 |
--------------------------------------------------------------------------------
/hollowdb/.env.example:
--------------------------------------------------------------------------------
1 | # contract to connect to
2 | CONTRACT_ID=
3 |
4 | # path to write rocksdb
5 | ROCKSDB_PATH=
6 |
7 | # Redis URL for the caching layer
8 | REDIS_URL=
9 |
10 | # Bundlr usage
11 | USE_BUNDLR=true
12 |
13 | # Treat values as `hash.txid`
14 | USE_HTX=true
15 |
16 | # Warp's log level
17 | WARP_LOG_LEVEL=info
18 |
19 | # Batch size
20 | BUNDLR_FBS=80
21 |
--------------------------------------------------------------------------------
/hollowdb/.gitignore:
--------------------------------------------------------------------------------
1 | node_modules
2 | build
3 | cache
4 | logs
5 |
6 | dump.rdb
7 |
8 | .vscode
9 | .DS_Store
10 |
11 | !src/cache
12 |
13 | .env
14 | tmp.md
15 |
16 | data
17 |
18 | .yarn
19 |
--------------------------------------------------------------------------------
/hollowdb/.yarnrc.yml:
--------------------------------------------------------------------------------
1 | nodeLinker: node-modules
2 |
--------------------------------------------------------------------------------
/hollowdb/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM node:20 as base
2 | WORKDIR /app
3 |
4 | COPY . .
5 |
6 | RUN yarn set version berry
7 |
8 | # yarn install might take time when a non-readonly volume is attached (see https://github.com/yarnpkg/yarn/issues/7747)
9 | # ~700 seconds on an M2 Macbook Air for example
10 | RUN yarn install
11 |
12 | # Build code
13 | FROM base as builder
14 | WORKDIR /app
15 | RUN yarn build
16 |
17 | # Install prod dependencies
18 | FROM base as dependencies
19 | WORKDIR /app
20 | RUN yarn workspaces focus --production
21 |
22 | # Slim has GLIBC needed by RocksDB, Alpine does not.
23 | FROM node:20-slim
24 | # RUN apk add gcompat
25 |
26 | WORKDIR /app
27 | COPY --from=builder /app/build ./build
28 | COPY --from=dependencies /app/node_modules ./node_modules
29 |
30 | EXPOSE 3000
31 |
32 | CMD ["node", "./build/index.js"]
33 |
--------------------------------------------------------------------------------
/hollowdb/README.md:
--------------------------------------------------------------------------------
1 | # HollowDB API
2 |
3 | Each Dria knowledge is a smart contract on Arweave, which serves the knowledge on the permaweb via a key-value interface. To download it, we use the HollowDB client, a Node package. Here we have a Fastify server that acts as an API for HollowDB, which we primarily use to download & unbundle the values for Dria HNSW to use.
4 |
5 | ## Installation
6 |
7 | Install the packages:
8 |
9 | ```sh
10 | yarn install
11 | ```
12 |
13 | ## Usage
14 |
15 | To run the server, you need to provide a contract ID along with a RocksDB path:
16 |
17 | ```sh
18 | CONTRACT_ID= ROCKSDB_PATH="path/to/rocksdb" yarn start
19 | ```
20 |
21 | HollowDB API is available as a container:
22 |
23 | ```sh
24 | docker pull firstbatch/dria-hollowdb
25 | ```
26 |
27 | In both cases, you will need a Redis container running at the URL defined by the `REDIS_URL` environment variable.
28 |
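A minimal sketch of such a setup (the container names, Redis image tag, and placeholder values below are assumptions; adjust them to your own environment):

```sh
# start a Redis instance on the default port
docker run -d --name dria-redis -p 6379:6379 redis:7

# run the HollowDB API against it
docker run -d --name dria-hollowdb \
  -e CONTRACT_ID=<your-contract-id> \
  -e REDIS_URL=redis://<redis-host>:6379 \
  firstbatch/dria-hollowdb
```
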
29 | ### Configurations
30 |
31 | There are several environment variables to configure the server. You can provide them on the command line or via a `.env` file; an example is given [here](./.env.example), and a combined invocation is sketched after the list below.
32 |
33 | - `REDIS_URL=` You need a Redis server running before you start the server; its URL can be provided with the `REDIS_URL` environment variable. The connection URL defaults to `redis://default:redispw@localhost:6379`.
34 |
35 | - `WARP_LOG_LEVEL=` By default Warp will log at `info` level, but you can change it via the `WARP_LOG_LEVEL` environment variable. Options are the known levels of `debug`, `error`, `fatal`, `info`, `none`, `silly`, `trace` and `warn`.
36 |
37 | - `USE_BUNDLR=` You can treat the values as transaction ids if the `USE_BUNDLR` environment variable is set to `true`. When this is the case, `REFRESH` will actually fetch the uploaded data and download it into RocksDB.
38 |
39 | > [!WARNING]
40 | >
41 | > Uploading to Bundlr via `PUT` or `PUT_MANY` is not yet implemented.
42 |
43 | - `USE_HTX=` When we have `USE_BUNDLR=true` we treat the stored values as transaction ids; however, HollowDB may have an alternative approach where values are stored as `hash.txid` (due to [this implementation](https://github.com/firstbatchxyz/hollowdb/blob/master/src/contracts/hollowdb-htx.contract.ts)). To comply with this approach, set `USE_HTX=true`.
44 |
45 | - `BUNDLR_FBS=` When using Bundlr, downloading values from Arweave cannot be done in a huge `Promise.all`, as it causes timeouts. We instead download values in batches, defaulting to 40 values per batch. To override the batch size, you can provide an integer value to this variable.
46 |
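As a rough sketch combining the above (the contract ID and values below are placeholders, not recommendations):

```sh
CONTRACT_ID=<your-contract-id> \
ROCKSDB_PATH=./data/values \
REDIS_URL=redis://default:redispw@localhost:6379 \
USE_BUNDLR=true \
USE_HTX=true \
WARP_LOG_LEVEL=info \
BUNDLR_FBS=80 \
yarn start
```
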
47 | ## Endpoints
48 |
49 | HollowDB API exposes the following endpoints:
50 |
51 | - GET [`/state`](#state)
52 | - POST [`/get`](#get)
53 | - POST [`/getRaw`](#getraw)
54 | - POST [`/getMany`](#getmany)
55 | - POST [`/getManyRaw`](#getmanyraw)
56 | - POST [`/put`](#put)
57 | - POST [`/putMany`](#putmany)
58 | - POST [`/update`](#update)
59 | - POST [`/remove`](#remove)
60 | - POST [`/refresh`](#refresh)
61 | - POST [`/clear`](#clear)
62 |
63 | ### `get`
64 |
65 | ```ts
66 | interface {
67 | key: string
68 | }
69 |
70 | // response body
71 | interface {
72 | value: any
73 | }
74 | ```
75 |
76 | Returns the value at the given key.
77 |
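For instance, with a hypothetical key and the default port from `src/configurations` (3030):

```sh
curl -X POST http://localhost:3030/get \
  -H "Content-Type: application/json" \
  -d '{"key": "my-key"}'
# => {"value": ...}
```
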
78 | ### `getRaw`
79 |
80 | ```ts
81 |
82 | // response body
83 | interface {
84 | value: any
85 | }
86 | ```
87 |
88 | Returns the value at the given `key`, directly from the cache layer & without involving Warp or Arweave.
89 |
90 | ### `getMany`
91 |
92 | ```ts
93 | interface {
94 | keys: string[]
95 | }
96 |
97 | // response body
98 | interface {
99 | values: any[]
100 | }
101 | ```
102 |
103 | Returns the values at the given `keys`.
104 |
105 | ### `getManyRaw`
106 |
107 | ```ts
108 | interface {
109 | keys: string[]
110 | }
111 |
112 | // response body
113 | interface {
114 | values: any[]
115 | }
116 | ```
117 |
118 | Returns the values at the given `keys`, reads directly from the storage.
119 |
120 | This has the advantage of not being bound to the interaction size limit; however, the user must check that the data is fresh with their own methods.
121 | Furthermore, you must make a call to `REFRESH` before using this endpoint, and subsequent calls to `REFRESH` will update the data with the new on-chain values.
122 |
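A typical flow is therefore a `refresh` followed by raw reads; a sketch with placeholder keys and the default port:

```sh
# sync the latest on-chain values into local storage first
curl -X POST http://localhost:3030/refresh

# then read many values directly from storage
curl -X POST http://localhost:3030/getManyRaw \
  -H "Content-Type: application/json" \
  -d '{"keys": ["key-1", "key-2"]}'
```
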
123 | ### `put`
124 |
125 | ```ts
126 | interface {
127 | key: string,
128 | value: any
129 | }
130 | ```
131 |
132 | Puts `value` at the given `key`. The key must not exist already, or it must have `null` stored at it.
133 |
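A sketch of such a call with a hypothetical key and value (`putMany`, `update`, and `remove` take the same kind of JSON body):

```sh
curl -X POST http://localhost:3030/put \
  -H "Content-Type: application/json" \
  -d '{"key": "my-key", "value": {"hello": "world"}}'
```
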
134 | ### `putMany`
135 |
136 | ```ts
137 | interface {
138 | keys: string[],
139 | values: any[]
140 | }
141 | ```
142 |
143 | Puts the provided `values` at the given `keys`. None of the keys may already exist in the database.
144 |
145 | ### `update`
146 |
147 | ```ts
148 | interface {
149 | key: string,
150 | value: any,
151 | proof?: object
152 | }
153 | ```
154 |
155 | Updates a `key` with the provided `value` and an optional `proof`.
156 |
157 | ### `remove`
158 |
159 | ```ts
160 | interface {
161 | key: string,
162 | proof?: object
163 | }
164 | ```
165 |
166 | Removes the value at `key`, along with an optional `proof`.
167 |
168 | ### `state`
169 |
170 | Syncs & fetches the latest contract state, and returns it.
171 |
172 | ### `refresh`
173 |
174 | Syncs & fetches the latest state and stores the latest sort key for each key in the database. Returns the number of keys refreshed for diagnostic purposes.
175 |
176 | ### `clear`
177 |
178 | ```ts
179 | interface {
180 | keys?: string[]
181 | }
182 | ```
183 |
184 | Clears the values stored for the given `keys` by the `REFRESH` endpoint. This lets you refresh selected keys again without flushing the entire database. Returns the number of keys cleared for diagnostic purposes.
185 |
186 | > [!TIP]
187 | >
188 | > If no `keys` are given to the `CLEAR` endpoint (i.e. `keys = undefined`) then this will clear **all keys**.
189 |
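For example, to clear two placeholder keys so that the next `REFRESH` downloads them again:

```sh
curl -X POST http://localhost:3030/clear \
  -H "Content-Type: application/json" \
  -d '{"keys": ["key-1", "key-2"]}'
```
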
190 | ## Testing
191 |
192 | We have tests that spin up a local Arweave instance and run against it, with the HollowDB server in the middle.
193 |
194 | > [!NOTE]
195 | >
196 | > You need a Redis server running in the background for the tests to work.
197 |
198 | To run tests, do:
199 |
200 | ```sh
201 | yarn test
202 | ```
203 |
--------------------------------------------------------------------------------
/hollowdb/config/.gitignore:
--------------------------------------------------------------------------------
1 | # hide yo wallet
2 | *.json
3 |
--------------------------------------------------------------------------------
/hollowdb/jest.config.ts:
--------------------------------------------------------------------------------
1 | import type { JestConfigWithTsJest } from "ts-jest";
2 |
3 | const config: JestConfigWithTsJest = {
4 | // ts-jest defaults
5 | preset: "ts-jest",
6 | testEnvironment: "node",
7 | transform: {
8 | "^.+\\.(ts|js)$": "ts-jest",
9 | },
10 | // environment setup & teardown scripts
11 | // setup & teardown for spinning up arlocal
12 | // globalSetup: "/tests/environment/setup.ts",
13 | // globalTeardown: "/tests/environment/teardown.ts",
14 | // timeout should be rather large, especially for the workflows
15 | testTimeout: 60000,
16 | // warp & arlocal takes some time to close, so make this 10 secs
17 | openHandlesTimeout: 10000,
18 | // print everything like Mocha
19 | verbose: true,
20 | // tests may hang randomly (not known why yet, it was fixed before)
21 | // that will cause workflow to run all the way, so we might force exit
22 | forceExit: true,
23 | };
24 |
25 | export default config;
26 |
--------------------------------------------------------------------------------
/hollowdb/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "author": "FirstBatch Team ",
3 | "private": true,
4 | "contributors": [
5 | "Erhan Tezcan "
6 | ],
7 | "version": "0.1.0",
8 | "main": "index.js",
9 | "scripts": {
10 | "build": "tsc -p tsconfig.build.json",
11 | "clean": "rm -rf ./build",
12 | "start": "yarn build && node ./build/index.js",
13 | "test": "jest"
14 | },
15 | "dependencies": {
16 | "@fastify/type-provider-typebox": "^4.0.0",
17 | "@sinclair/typebox": "^0.32.14",
18 | "axios": "^1.6.7",
19 | "fastify": "^4.26.1",
20 | "hollowdb": "^1.4.1",
21 | "ioredis": "^5.3.2",
22 | "levelup": "^5.1.1",
23 | "pino-pretty": "^10.3.1",
24 | "rocksdb": "^5.2.1",
25 | "warp-contracts": "^1.4.34",
26 | "warp-contracts-redis": "^0.4.2"
27 | },
28 | "devDependencies": {
29 | "@types/jest": "^29.5.12",
30 | "@types/levelup": "^5.1.5",
31 | "@types/node": "^20.11.17",
32 | "@types/rocksdb": "^3.0.5",
33 | "arlocal": "^1.1.65",
34 | "jest": "^29.7.0",
35 | "lorem-ipsum": "^2.0.8",
36 | "ts-jest": "^29.1.2",
37 | "ts-node": "^10.9.2",
38 | "typescript": "^5.3.3",
39 | "warp-contracts-plugin-deploy": "^1.0.13"
40 | },
41 | "packageManager": "yarn@4.1.0",
42 | "prettier": {
43 | "printWidth": 120
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/hollowdb/src/clients/hollowdb.ts:
--------------------------------------------------------------------------------
1 | import { WarpFactory, Warp } from "warp-contracts";
2 | import type { CacheOptions } from "warp-contracts";
3 | import type { Redis } from "ioredis";
4 | import { RedisCache, type RedisOptions } from "warp-contracts-redis";
5 | import type { CacheTypes } from "../types";
6 |
7 | /**
8 | * Utility to create Warp Redis caches.
9 | *
10 | * @param contractTxId contract transaction id to be used as prefix in the keys
11 | * @param client Redis client to use a self-managed cache
12 | * @returns caches
13 | */
14 | export function createCaches(contractTxId: string, client: Redis): CacheTypes {
15 | const defaultCacheOptions: CacheOptions = {
16 | inMemory: true,
17 | subLevelSeparator: "|",
18 | dbLocation: "redis.hollowdb",
19 | };
20 |
21 | // if a client exists, use it; otherwise connect via URL
22 | const redisOptions: RedisOptions = { client };
23 |
24 | return {
25 | state: new RedisCache(
26 | {
27 | ...defaultCacheOptions,
28 | dbLocation: `${contractTxId}.state`,
29 | },
30 | redisOptions
31 | ),
32 | contract: new RedisCache(
33 | {
34 | ...defaultCacheOptions,
35 | dbLocation: `${contractTxId}.contract`,
36 | },
37 | redisOptions
38 | ),
39 | src: new RedisCache(
40 | {
41 | ...defaultCacheOptions,
42 | dbLocation: `${contractTxId}.src`,
43 | },
44 | redisOptions
45 | ),
46 | kvFactory: (contractTxId: string) =>
47 | new RedisCache(
48 | {
49 | ...defaultCacheOptions,
50 | dbLocation: `${contractTxId}.kv`,
51 | },
52 | redisOptions
53 | ),
54 | };
55 | }
56 |
57 | /** Creates a Warp instance connected to mainnet. */
58 | export function makeWarp(caches: CacheTypes): Warp {
59 | return WarpFactory.forMainnet()
60 | .useStateCache(caches.state)
61 | .useContractCache(caches.contract, caches.src)
62 | .useKVStorageFactory(caches.kvFactory);
63 | }
64 |
--------------------------------------------------------------------------------
/hollowdb/src/clients/rocksdb.ts:
--------------------------------------------------------------------------------
1 | import Levelup from "levelup";
2 | import Rocksdb from "rocksdb";
3 | import { toValueKey } from "../utilities/download";
4 | import { existsSync, mkdirSync } from "fs";
5 |
6 | export class RocksdbClient {
7 | db: ReturnType<typeof Levelup>;
8 | contractTxId: string;
9 |
10 | constructor(path: string, contractTxId: string) {
11 | if (!existsSync(path)) {
12 | mkdirSync(path, { recursive: true });
13 | }
14 |
15 | this.db = Levelup(Rocksdb(path));
16 | this.contractTxId = contractTxId;
17 | }
18 |
19 | async close() {
20 | if (!this.db.isClosed()) {
21 | await this.db.close();
22 | }
23 | }
24 |
25 | async open() {
26 | if (!this.db.isOpen()) {
27 | await this.db.open();
28 | }
29 | }
30 |
31 | async get(key: string) {
32 | const value = await this.db.get(toValueKey(this.contractTxId, key));
33 | return this.tryParse(value);
34 | }
35 |
36 | async getMany(keys: string[]) {
37 | const values = await this.db.getMany(keys.map((k) => toValueKey(this.contractTxId, k)));
38 | return values.map((v) => this.tryParse(v));
39 | }
40 |
41 | async set(key: string, value: string) {
42 | await this.db.put(toValueKey(this.contractTxId, key), value);
43 | }
44 |
45 | async setMany(pairs: [string, string][]) {
46 | await this.db.batch(
47 | pairs.map(([key, value]) => ({
48 | type: "put",
49 | key: toValueKey(this.contractTxId, key),
50 | value: value,
51 | }))
52 | );
53 | }
54 |
55 | async remove(key: string) {
56 | await this.db.del(toValueKey(this.contractTxId, key));
57 | }
58 |
59 | async removeMany(keys: string[]) {
60 | await this.db.batch(keys.map((key) => ({ type: "del", key: toValueKey(this.contractTxId, key) })));
61 | }
62 |
63 | /**
64 | * Given a value, tries to `JSON.parse` it and if parsing fails
65 | * it will return the value as-is.
66 | *
67 | * This is particularly useful when there are some stringified values,
68 | * and some other non-stringified strings together.
69 | *
70 | * @param value a stringified or null value
71 | * @template V type of the expected value
72 | * @returns parsed value or `null`
73 | */
74 | private tryParse<V>(value: Rocksdb.Bytes | null): V | null {
75 | let result = null;
76 |
77 | if (value) {
78 | try {
79 | result = JSON.parse(value.toString());
80 | } catch (err) {
81 | // FIXME: return null here?
82 | result = value;
83 | }
84 | }
85 |
86 | return result as V;
87 | }
88 | }
89 |
--------------------------------------------------------------------------------
/hollowdb/src/configurations/index.ts:
--------------------------------------------------------------------------------
1 | import type { LogLevel } from "fastify";
2 | import type { LoggerFactory } from "warp-contracts";
3 |
4 | export default {
5 | /** Port that is listened by HollowDB. */
6 | PORT: parseInt(process.env.PORT ?? "3030"),
7 | /** Redis URL to connect to. Defaults to `redis://default:redispw@localhost:6379`. */
8 | REDIS_URL: process.env.REDIS_URL ?? "redis://default:redispw@localhost:6379",
9 | /** Path to Rocksdb storage. */
10 | ROCKSDB_PATH: process.env.ROCKSDB_PATH ?? "./data/values",
11 | /** Treat values as Bundlr txIds. */
12 | USE_BUNDLR: process.env.USE_BUNDLR ? process.env.USE_BUNDLR === "true" : process.env.NODE_ENV !== "test",
13 | /** Use the optimized [`hash.txid`](https://github.com/firstbatchxyz/hollowdb/blob/master/src/contracts/hollowdb-htx.contract.ts) layout for values. */
14 | USE_HTX: process.env.USE_HTX ? process.env.USE_HTX === "true" : process.env.NODE_ENV !== "test",
15 | /** Log level for underlying Warp. */
16 | WARP_LOG_LEVEL: (process.env.WARP_LOG_LEVEL ?? "info") as Parameters<LoggerFactory["logLevel"]>[0],
17 | /** How many fetches at once should be made to download Bundlr data? FBS stands for "Fetch Batch Size". */
18 | BUNDLR_FBS: parseInt(process.env.BUNDLR_FBS ?? "40"),
19 | /** Configurations for Bundlr downloads. */
20 | DOWNLOAD: {
21 | /** Download URL for the bundled data. */
22 | BASE_URL: "https://arweave.net",
23 | /** Max allowed timeout (milliseconds). */
24 | TIMEOUT: 50_000,
25 | /** Max attempts to retry on caught errors. */
26 | MAX_ATTEMPTS: 5,
27 | /** Time to sleep (ms) between each attempt. */
28 | ATTEMPT_SLEEP: 1000,
29 | },
30 | /** Logging stuff for the server. */
31 | LOG: {
32 | LEVEL: "info" satisfies LogLevel as LogLevel, // for some reason, :LogLevel doesn't work well
33 | REDACT: [
34 | "reqId",
35 | "res.remoteAddress",
36 | "res.remotePort",
37 | "res.hostname",
38 | "req.remoteAddress",
39 | "req.remotePort",
40 | "req.hostname",
41 | ] as string[],
42 | },
43 | } as const;
44 |
--------------------------------------------------------------------------------
/hollowdb/src/controllers/read.ts:
--------------------------------------------------------------------------------
1 | import type { RouteHandler } from "fastify";
2 | import type { Get, GetMany } from "../schemas";
3 | import { RocksdbClient } from "../clients/rocksdb";
4 |
5 | export const get: RouteHandler<{ Body: Get }> = async ({ server, body }) => {
6 | const value = await server.hollowdb.get(body.key);
7 | return { value };
8 | };
9 |
10 | export const getRaw: RouteHandler<{ Body: Get }> = async ({ server, body }) => {
11 | const rocksdb = new RocksdbClient(server.rocksdbPath, server.hollowdb.contractTxId);
12 |
13 | await rocksdb.open();
14 | const value = await rocksdb.get(body.key);
15 | await rocksdb.close();
16 |
17 | return { value };
18 | };
19 |
20 | export const getMany: RouteHandler<{ Body: GetMany }> = async ({ server, body }) => {
21 | const values = await server.hollowdb.getMany(body.keys);
22 | return { values };
23 | };
24 |
25 | export const getManyRaw: RouteHandler<{ Body: GetMany }> = async ({ server, body }) => {
26 | const rocksdb = new RocksdbClient(server.rocksdbPath, server.hollowdb.contractTxId);
27 | await rocksdb.open();
28 | const values = await rocksdb.getMany(body.keys);
29 | await rocksdb.close();
30 |
31 | return { values };
32 | };
33 |
34 | export const state: RouteHandler = async ({ server }) => {
35 | return await server.hollowdb.getState();
36 | };
37 |
--------------------------------------------------------------------------------
/hollowdb/src/controllers/values.ts:
--------------------------------------------------------------------------------
1 | import { RouteHandler } from "fastify";
2 | import type { Clear } from "../schemas";
3 | import type { Redis } from "ioredis";
4 | import { lastPossibleSortKey } from "warp-contracts";
5 | import { RocksdbClient } from "../clients/rocksdb";
6 | import { toSortKeyKey } from "../utilities/download";
7 | import { refreshKeys } from "../utilities/refresh";
8 |
9 | export const refresh: RouteHandler = async ({ server }, reply) => {
10 | const numKeysRefreshed = await refreshKeys(server);
11 | return reply.send(numKeysRefreshed);
12 | };
13 |
14 | export const clear: RouteHandler<{ Body: Clear }> = async ({ server, body }, reply) => {
15 | const kv = server.hollowdb.base.warp.kvStorageFactory(server.hollowdb.contractTxId);
16 | const redis = kv.storage();
17 |
18 | // get all existing keys (without sortKey)
19 | const keys = body.keys ?? (await kv.keys(lastPossibleSortKey));
20 |
21 | // delete the sortKey mappings from Redis
22 | await redis.del(...keys.map((key) => toSortKeyKey(server.hollowdb.contractTxId, key)));
23 |
24 | // delete the values from Rocksdb
25 | const rocksdb = new RocksdbClient(server.rocksdbPath, server.hollowdb.contractTxId);
26 | await rocksdb.open();
27 | await rocksdb.removeMany(keys);
28 | await rocksdb.close();
29 |
30 | return reply.send();
31 | };
32 |
--------------------------------------------------------------------------------
/hollowdb/src/controllers/write.ts:
--------------------------------------------------------------------------------
1 | import type { RouteHandler } from "fastify";
2 | import type { Put, PutMany, Remove, Set, SetMany, Update } from "../schemas";
3 |
4 | export const put: RouteHandler<{ Body: Put }> = async ({ server, body }, reply) => {
5 | await server.hollowdb.put(body.key, body.value);
6 | return reply.send();
7 | };
8 |
9 | export const putMany: RouteHandler<{ Body: PutMany }> = async ({ server, body }, reply) => {
10 | await server.hollowdb.putMany(body.keys, body.values);
11 | return reply.send();
12 | };
13 |
14 | export const set: RouteHandler<{ Body: Set }> = async ({ server, body }, reply) => {
15 | await server.hollowdb.set(body.key, body.value);
16 | return reply.send();
17 | };
18 |
19 | export const setMany: RouteHandler<{ Body: SetMany }> = async ({ server, body }, reply) => {
20 | await server.hollowdb.setMany(body.keys, body.values);
21 | return reply.send();
22 | };
23 |
24 | export const update: RouteHandler<{ Body: Update }> = async ({ server, body }, reply) => {
25 | await server.hollowdb.update(body.key, body.value, body.proof);
26 | return reply.send();
27 | };
28 |
29 | export const remove: RouteHandler<{ Body: Remove }> = async ({ server, body }, reply) => {
30 | await server.hollowdb.remove(body.key, body.proof);
31 | return reply.send();
32 | };
33 |
--------------------------------------------------------------------------------
/hollowdb/src/global.d.ts:
--------------------------------------------------------------------------------
1 | import fastify from "fastify";
2 | import { SetSDK } from "hollowdb";
3 | import http from "http";
4 |
5 | declare module "fastify" {
6 | export interface FastifyInstance<
7 | HttpServer = http.Server,
8 | HttpRequest = http.IncomingMessage,
9 | HttpResponse = http.ServerResponse
10 | > {
11 | /** HollowDB decorator. */
12 | hollowdb: SetSDK;
13 | /** Contract ID. */
14 | contractTxId: string;
15 | /** RocksDB Path. */
16 | rocksdbPath: string;
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/hollowdb/src/index.ts:
--------------------------------------------------------------------------------
1 | import { SetSDK } from "hollowdb";
2 | import { Redis } from "ioredis";
3 | import { makeServer } from "./server";
4 | import configurations from "./configurations";
5 | import { createCaches, makeWarp } from "./clients/hollowdb";
6 |
7 | async function main() {
8 | const contractId = process.env.CONTRACT_ID;
9 | if (!contractId) {
10 | throw new Error("Please provide CONTRACT_ID environment variable.");
11 | }
12 | if (Buffer.from(contractId, "base64").toString("hex").length !== 64) {
13 | throw new Error("Invalid CONTRACT_ID.");
14 | }
15 |
16 | // ping redis to make sure connection is there before moving on
17 | const redisClient = new Redis(configurations.REDIS_URL);
18 | await redisClient.ping();
19 |
20 | // create Redis caches & use them for Warp
21 | const caches = createCaches(contractId, redisClient);
22 | const warp = makeWarp(caches);
23 |
24 | // create a random wallet, which is ok since we only make read operations
25 | // TODO: or, we can use a dummy throw-away wallet every time?
26 | const wallet = await warp.generateWallet();
27 |
28 | const hollowdb = new SetSDK(wallet.jwk, contractId, warp);
29 | const server = await makeServer(hollowdb, configurations.ROCKSDB_PATH);
30 | const addr = await server.listen({
31 | port: configurations.PORT,
32 | // host is set to listen on all interfaces to allow Docker internal network to work
33 | // see: https://fastify.dev/docs/latest/Reference/Server/#listentextresolver
34 | host: "::",
35 | listenTextResolver: (address) => {
36 | return `HollowDB is listening at ${address}`;
37 | },
38 | });
39 | server.log.info(`Listening at: ${addr}`);
40 | }
41 |
42 | main();
43 |
--------------------------------------------------------------------------------
/hollowdb/src/schemas/index.ts:
--------------------------------------------------------------------------------
1 | import { Static, Type } from "@sinclair/typebox";
2 |
3 | export const Get = Type.Object({
4 | key: Type.String(),
5 | });
6 | export type Get = Static<typeof Get>;
7 |
8 | export const GetMany = Type.Object({
9 | keys: Type.Array(Type.String()),
10 | });
11 | export type GetMany = Static<typeof GetMany>;
12 |
13 | export const Put = Type.Object({
14 | key: Type.String(),
15 | value: Type.Any(),
16 | });
17 | export type Put = Static<typeof Put>;
18 |
19 | export const PutMany = Type.Object({
20 | keys: Type.Array(Type.String()),
21 | values: Type.Array(Type.Any()),
22 | });
23 | export type PutMany = Static<typeof PutMany>;
24 |
25 | export const Remove = Type.Object({
26 | key: Type.String(),
27 | proof: Type.Optional(Type.Any()),
28 | });
29 | export type Remove = Static<typeof Remove>;
30 |
31 | export const Update = Type.Object({
32 | key: Type.String(),
33 | value: Type.Any(),
34 | proof: Type.Optional(Type.Any()),
35 | });
36 | export type Update = Static<typeof Update>;
37 |
38 | export const Set = Type.Object({
39 | key: Type.String(),
40 | value: Type.Any(),
41 | });
42 | export type Set = Static<typeof Set>;
43 |
44 | export const SetMany = Type.Object({
45 | keys: Type.Array(Type.String()),
46 | values: Type.Array(Type.Any()),
47 | });
48 | export type SetMany = Static<typeof SetMany>;
49 |
50 | export const Clear = Type.Object({
51 | keys: Type.Optional(Type.Array(Type.String())),
52 | });
53 | export type Clear = Static<typeof Clear>;
54 |
--------------------------------------------------------------------------------
/hollowdb/src/server.ts:
--------------------------------------------------------------------------------
1 | import fastify, { LogLevel } from "fastify";
2 | import type { TypeBoxTypeProvider } from "@fastify/type-provider-typebox";
3 | import { get, getMany, getManyRaw, getRaw, state } from "./controllers/read";
4 | import { put, putMany, remove, set, setMany, update } from "./controllers/write";
5 | import { clear, refresh } from "./controllers/values";
6 | import { Clear, Get, GetMany, Put, PutMany, Remove, Set, SetMany, Update } from "./schemas";
7 | import { SetSDK } from "hollowdb";
8 | import { LoggerFactory } from "warp-contracts";
9 | import configurations from "./configurations";
10 | import { refreshKeys } from "./utilities/refresh";
11 | import { Redis } from "ioredis";
12 |
13 | export async function makeServer(hollowdb: SetSDK, rocksdbPath: string) {
14 | const server = fastify({
15 | logger: {
16 | level: configurations.LOG.LEVEL,
17 | transport: { target: "pino-pretty" },
18 | redact: {
19 | paths: configurations.LOG.REDACT,
20 | remove: true,
21 | },
22 | },
23 | }).withTypeProvider<TypeBoxTypeProvider>();
24 | LoggerFactory.INST.logLevel(configurations.LOG.LEVEL === "silent" ? "none" : configurations.LOG.LEVEL);
25 |
26 | server.decorate("hollowdb", hollowdb);
27 | server.decorate("contractTxId", hollowdb.contractTxId);
28 | server.decorate("rocksdbPath", rocksdbPath); // TODO: store RocksDB itself here maybe?
29 |
30 | // check redis
31 | await server.hollowdb.base.warp.kvStorageFactory(server.hollowdb.contractTxId).storage().ping();
32 |
33 | // refresh keys
34 | server.log.info("Waiting for cache to be loaded.");
35 | const numKeysRefreshed = await refreshKeys(server);
36 |
37 | server.log.info(`Server synced & ready! (${numKeysRefreshed} keys refreshed)`);
38 | server.log.info(`> Redis: ${configurations.REDIS_URL}`);
39 | server.log.info(`> RocksDB: ${configurations.ROCKSDB_PATH}`);
40 | server.log.info(`> Download URL: ${configurations.DOWNLOAD.BASE_URL} (timeout ${configurations.DOWNLOAD.TIMEOUT})`);
41 | server.log.info(`> Contract: https://sonar.warp.cc/#/app/contract/${server.contractTxId}`);
42 |
43 | server.get("/state", state);
44 | server.post("/get", { schema: { body: Get } }, get);
45 | server.post("/getRaw", { schema: { body: Get } }, getRaw);
46 | server.post("/getMany", { schema: { body: GetMany } }, getMany);
47 | server.post("/getManyRaw", { schema: { body: GetMany } }, getManyRaw);
48 | server.post("/put", { schema: { body: Put } }, put);
49 | server.post("/putMany", { schema: { body: PutMany } }, putMany);
50 | server.post("/set", { schema: { body: Set } }, set);
51 | server.post("/setMany", { schema: { body: SetMany } }, setMany);
52 | server.post("/update", { schema: { body: Update } }, update);
53 | server.post("/remove", { schema: { body: Remove } }, remove);
54 | server.post("/clear", { schema: { body: Clear } }, clear);
55 | server.post("/refresh", refresh);
56 |
57 | server.addHook("onError", (request, reply, error, done) => {
58 | if (error.message.startsWith("Contract Error")) {
59 | reply.status(400);
60 | }
61 | done();
62 | });
63 |
64 | return server;
65 | }
66 |
--------------------------------------------------------------------------------
/hollowdb/src/types/index.ts:
--------------------------------------------------------------------------------
1 | import type { SortKeyCacheResult } from "warp-contracts";
2 | import type { RedisCache } from "warp-contracts-redis";
3 |
4 | /** Cache types used by `Warp`. */
5 | export type CacheTypes<C = RedisCache> = {
6 | state: C;
7 | contract: C;
8 | src: C;
9 | kvFactory: (contractTxId: string) => C;
10 | };
11 |
12 | /**
13 | * A `SortKeyCacheResult` with its respective `key` attached to it.
14 | *
15 | * - `key` is accessed as `.key`
16 | * - `sortKey` is accessed as `.sortKeyCacheResult.sortKey`
17 | * - `value` is accessed as `.sortKeyCacheResult.value`
18 | */
19 | export type KeyedSortKeyCacheResult = { sortKeyCacheResult: SortKeyCacheResult; key: string };
20 |
--------------------------------------------------------------------------------
/hollowdb/src/utilities/download.ts:
--------------------------------------------------------------------------------
1 | import axios from "axios";
2 | import configurations from "../configurations";
3 | import { FastifyBaseLogger } from "fastify";
4 | import { sleep } from "warp-contracts";
5 |
6 | /** Downloads an object from Bundlr network w.r.t transaction id.
7 | *
8 | * If USE_HTX is enabled, it means that values are stored as `hash.txId` (as a string),
9 | * so to get the txid we split by `.` and then get the second element.
10 | *
11 | * @param txid transaction ID on Arweave
12 | * @template V type of the value
13 | * @returns unbundled raw value
14 | */
15 | export async function downloadFromBundlr<V>(txid: string, log: FastifyBaseLogger) {
16 | if (configurations.USE_HTX) {
17 | const split = txid.split(".");
18 | if (split.length !== 2) {
19 | log.warn("Expected value to be a hash.txid, received: " + txid);
20 | }
21 | txid = split[1];
22 | }
23 |
24 | const url = `${configurations.DOWNLOAD.BASE_URL}/${txid}`;
25 |
26 | // try for a few times
27 | for (let a = 0; a <= configurations.DOWNLOAD.MAX_ATTEMPTS; ++a) {
28 | try {
29 | const response = await axios.get(url, {
30 | timeout: configurations.DOWNLOAD.TIMEOUT,
31 | });
32 | if (response.status !== 200) {
33 | throw new Error(`Bundlr failed with ${response.status}`);
34 | }
35 | return response.data as V;
36 | } catch (err) {
37 | if (a === configurations.DOWNLOAD.MAX_ATTEMPTS) {
38 | throw err;
39 | } else {
40 | log.warn(
41 | `(tries: ${a + 1}/${configurations.DOWNLOAD.MAX_ATTEMPTS})` +
42 | `\tError downloading ${url}: ${(err as Error).message}`
43 | );
44 | await sleep(configurations.DOWNLOAD.ATTEMPT_SLEEP);
45 | }
46 | }
47 | }
48 |
49 | throw new Error("All attempts failed.");
50 | }
51 |
52 | /** Returns a pretty string about the current download progress.
53 | * @param cur current value, can be more than `max`
54 | * @param max maximum value
55 | * @param decimals (optional) number of decimals for the percentage (default: 2)
56 | * @returns progress description
57 | */
58 | export function progressString(cur: number, max: number, decimals: number = 2) {
59 | const val = Math.min(cur, max);
60 | const percentage = (val / max) * 100;
61 | return `[${val} / ${max}] (${percentage.toFixed(decimals)}%)`;
62 | }
63 |
64 | /**
65 | * Map a given key to a value key.
66 | * @param contractTxId contract txID
67 | * @param key key
68 | * @returns value key
69 | */
70 | export function toValueKey(contractTxId: string, key: string) {
71 | return `${contractTxId}.value.${key}`;
72 | }
73 |
74 | /**
75 | * Map a given key to a sortKey key.
76 | * @param contractTxId contract txID
77 | * @param key key
78 | * @returns sortKey key
79 | */
80 | export function toSortKeyKey(contractTxId: string, key: string) {
81 | return `${contractTxId}.sortKey.${key}`;
82 | }
83 |
--------------------------------------------------------------------------------
/hollowdb/src/utilities/refresh.ts:
--------------------------------------------------------------------------------
1 | import type { FastifyInstance } from "fastify";
2 | import type { Redis } from "ioredis";
3 | import { lastPossibleSortKey } from "warp-contracts";
4 | import type { KeyedSortKeyCacheResult } from "../types";
5 | import { downloadFromBundlr, progressString, toSortKeyKey } from "./download";
6 | import { RocksdbClient } from "../clients/rocksdb";
7 | import configurations from "../configurations";
8 |
9 | /**
10 | * Refresh keys.
11 | * @param server HollowDB server
12 | * @returns number of refreshed keys
13 | */
14 | export async function refreshKeys(server: FastifyInstance): Promise<number> {
15 | server.log.info(`\nRefreshing keys (${server.hollowdb.contractTxId})\n`);
16 | await server.hollowdb.base.readState(); // get to the latest state
17 |
18 | const kv = server.hollowdb.base.warp.kvStorageFactory(server.hollowdb.contractTxId);
19 | const redis = kv.storage();
20 |
21 | // get all keys
22 | const keys = await kv.keys(lastPossibleSortKey);
23 |
24 | // return early if there are no keys
25 | if (keys.length === 0) {
26 | server.log.info("All keys are up-to-date.");
27 | return 0;
28 | }
29 |
30 | // get the last sortKey for each key
31 | const sortKeyCacheResults = await Promise.all(keys.map((key) => kv.getLast(key)));
32 |
33 | // from these values, get the ones that are out-of-date (i.e. stale)
34 | const latestSortKeys: (string | null)[] = sortKeyCacheResults.map((skcr) => (skcr ? skcr.sortKey : null));
35 | const existingSortKeys: (string | null)[] = await redis.mget(
36 | ...keys.map((key) => toSortKeyKey(server.contractTxId, key))
37 | );
38 | const staleResults: KeyedSortKeyCacheResult[] = sortKeyCacheResults
39 | .map((skcr, i) =>
40 | // filter out existing sortKeys
41 | // also store the respective `key` with the result
42 | latestSortKeys[i] !== existingSortKeys[i] ? { sortKeyCacheResult: skcr, key: keys[i] } : null
43 | )
44 | // this filter will filter out both existing null values, and matching sortKeys
45 | .filter((res): res is KeyedSortKeyCacheResult => res !== null);
46 |
47 | // return early if everything is up-to-date
48 | if (staleResults.length === 0) {
49 | return 0;
50 | }
51 |
52 | const rocksdb = new RocksdbClient(server.rocksdbPath, server.contractTxId);
53 | await rocksdb.open();
54 |
55 | const refreshValues = async (results: KeyedSortKeyCacheResult[], values?: unknown[]) => {
56 | if (values && values.length !== results.length) {
57 | throw new Error("array length mismatch");
58 | }
59 |
60 | // create [key, value] pairs with stringified values
61 | const valuePairs = results.map(({ key, sortKeyCacheResult: { cachedValue } }, i) => {
62 | const val = values
63 | ? // override with given value
64 | typeof values[i] === "string"
65 | ? (values[i] as string)
66 | : JSON.stringify(values[i])
67 | : // use own value
68 | JSON.stringify(cachedValue);
69 | return [key, val] as [string, string];
70 | });
71 |
72 | // write values to disk (as they may be too much for the memory)
73 | await rocksdb.setMany(valuePairs);
74 |
75 | // store the `sortKey`s for later refreshes to see if a `value` is stale
76 | const sortKeyPairs = results.map(
77 | ({ key, sortKeyCacheResult: { sortKey } }) =>
78 | [toSortKeyKey(server.contractTxId, key), sortKey] as [string, string]
79 | );
80 | await redis.mset(...sortKeyPairs.flat());
81 | };
82 |
83 | // update the stale values (values go to RocksDB, sortKeys to Redis)
84 |
85 | const { USE_BUNDLR, BUNDLR_FBS } = configurations;
86 | if (USE_BUNDLR) {
87 | const progress: [number, number] = [0, 0];
88 |
89 | server.log.info("Starting batched Bundlr downloads:");
90 | progress[1] = staleResults.length;
91 | for (let b = 0; b < staleResults.length; b += BUNDLR_FBS) {
92 | const batchResults = staleResults.slice(b, b + BUNDLR_FBS);
93 |
94 | progress[0] = Math.min(b + BUNDLR_FBS, staleResults.length);
95 |
96 | const startTime = performance.now();
97 | const batchValues = await Promise.all(
98 | batchResults.map((result) =>
99 | downloadFromBundlr<{ data: any }>(result.sortKeyCacheResult.cachedValue as string, server.log)
100 | )
101 | );
102 | const endTime = performance.now();
103 | server.log.info(
104 | `${progressString(progress[0], progress[1])} values downloaded (${(endTime - startTime).toFixed(2)} ms)`
105 | );
106 |
107 | await refreshValues(
108 | batchResults,
109 | // our Bundlr service uploads as "{data: payload}" so we parse it here
110 | batchValues.map((val) => val.data)
111 | );
112 | }
113 | server.log.info("Downloaded & refreshed all stale values.");
114 | } else {
115 | await refreshValues(staleResults);
116 | }
117 |
118 | await rocksdb.close();
119 |
120 | return staleResults.length;
121 | }
122 |
--------------------------------------------------------------------------------
/hollowdb/test/index.test.ts:
--------------------------------------------------------------------------------
1 | import ArLocal from "arlocal";
2 | import { ArWallet, LoggerFactory, sleep } from "warp-contracts";
3 |
4 | import { Redis } from "ioredis";
5 | import { SetSDK } from "hollowdb";
6 |
7 | import { makeServer } from "../src/server";
8 | import config from "../src/configurations";
9 | import { createCaches } from "../src/clients/hollowdb";
10 | import { deploy, FetchClient, randomKeyValue, makeLocalWarp } from "./util";
11 | import { Get, GetMany, Put, PutMany, Update } from "../src/schemas";
12 | import { randomBytes } from "crypto";
13 | import { rmSync } from "fs";
14 |
15 | describe("crud operations", () => {
16 | let arlocal: ArLocal;
17 | let redisClient: Redis;
18 | let client: FetchClient;
19 | let url: string;
20 |
21 | const DATA_PATH = "./test/data";
22 | const ARWEAVE_PORT = 3169;
23 | const VALUE = randomBytes(16).toString("hex");
24 | const NEW_VALUE = randomBytes(16).toString("hex");
25 | const KEY = randomBytes(16).toString("hex");
26 |
27 | beforeAll(async () => {
28 | console.log("Starting...");
29 |
30 | // create a local Arweave instance
31 | arlocal = new ArLocal(ARWEAVE_PORT, false);
32 | await arlocal.start();
33 |
34 | // deploy a contract locally and generate a wallet
35 | redisClient = new Redis(config.REDIS_URL, { lazyConnect: false });
36 | let caches = createCaches("testing-setup", redisClient);
37 | let warp = makeLocalWarp(ARWEAVE_PORT, caches);
38 | const owner: ArWallet = (await warp.generateWallet()).jwk;
39 | const { contractTxId } = await deploy(owner, warp);
40 |
41 | // start the server & connect to the contract
42 | caches = createCaches(contractTxId, redisClient);
43 | warp = makeLocalWarp(ARWEAVE_PORT, caches);
44 | const hollowdb = new SetSDK(owner, contractTxId, warp);
45 | const server = await makeServer(hollowdb, `${DATA_PATH}/${contractTxId}`);
46 | url = await server.listen({ port: config.PORT });
47 | LoggerFactory.INST.logLevel("none");
48 |
49 | client = new FetchClient(url);
50 |
51 | // TODO: wait a bit due to state syncing
52 | console.log("waiting a bit for the server to be ready...");
53 | await sleep(1200);
54 | console.log("done");
55 | });
56 |
57 | describe("basic CRUD", () => {
58 | it("should put & get a value", async () => {
59 | const putResponse = await client.post("/put", { key: KEY, value: VALUE });
60 | expect(putResponse.status).toBe(200);
61 |
62 | const getResponse = await client.post("/get", { key: KEY });
63 | expect(getResponse.status).toBe(200);
64 | expect(await getResponse.json().then((body) => body.value)).toBe(VALUE);
65 | });
66 |
67 | it("should NOT put to an existing key", async () => {
68 | const putResponse = await client.post("/put", { key: KEY, value: VALUE });
69 | expect(putResponse.status).toBe(400);
70 | const body = await putResponse.json();
71 | expect(body.message).toBe("Contract Error [put]: Key already exists.");
72 | });
73 |
74 | it("should update & get the new value", async () => {
75 | const updateResponse = await client.post("/update", { key: KEY, value: NEW_VALUE });
76 | expect(updateResponse.status).toBe(200);
77 |
78 | const getResponse = await client.post("/get", { key: KEY });
79 | expect(getResponse.status).toBe(200);
80 | expect(await getResponse.json().then((body) => body.value)).toBe(NEW_VALUE);
81 | });
82 |
83 | it("should remove the new value & get null", async () => {
84 | const removeResponse = await client.post("/remove", {
85 | key: KEY,
86 | });
87 | expect(removeResponse.status).toBe(200);
88 |
89 | const getResponse = await client.post("/get", { key: KEY });
90 | expect(getResponse.status).toBe(200);
91 | expect(await getResponse.json().then((body) => body.value)).toBe(null);
92 | });
93 | });
94 |
95 | describe("batch gets and puts", () => {
96 | const LENGTH = 10;
97 | const KEY_VALUES = Array.from({ length: LENGTH }, () => randomKeyValue({ numVals: 768 }));
98 |
99 | it("should put many values", async () => {
100 | const putResponse = await client.post("/putMany", {
101 | keys: KEY_VALUES.map((kv) => kv.key),
102 | values: KEY_VALUES.map((kv) => kv.value),
103 | });
104 | expect(putResponse.status).toBe(200);
105 | });
106 |
107 | it("should get many values", async () => {
108 | const getManyResponse = await client.post("/getMany", {
109 | keys: KEY_VALUES.map((kv) => kv.key),
110 | });
111 | expect(getManyResponse.status).toBe(200);
112 | const body = await getManyResponse.json();
113 | for (let i = 0; i < KEY_VALUES.length; i++) {
114 | const expected = KEY_VALUES[i].value;
115 | const result = body.values[i] as (typeof KEY_VALUES)[0]["value"];
116 | expect(result.metadata.text).toBe(expected.metadata.text);
117 | expect(result.f).toBe(expected.f);
118 | }
119 | });
120 |
121 | it("should refresh the cache for raw GET operations", async () => {
122 | const refreshResponse = await client.post("/refresh");
123 | expect(refreshResponse.status).toBe(200);
124 | });
125 |
126 | it("should do a raw GET", async () => {
127 | const { key, value } = KEY_VALUES[0];
128 |
129 | const getRawResponse = await client.post("/getRaw", { key });
130 | expect(getRawResponse.status).toBe(200);
131 | const body = await getRawResponse.json();
132 | const result = body.value as typeof value;
133 | expect(result.metadata.text).toBe(value.metadata.text);
134 | expect(result.f).toBe(value.f);
135 | });
136 |
137 | it("should do a raw GET many", async () => {
138 | const getManyRawResponse = await client.post("/getManyRaw", {
139 | keys: KEY_VALUES.map((kv) => kv.key),
140 | });
141 | expect(getManyRawResponse.status).toBe(200);
142 | const body = await getManyRawResponse.json();
143 | for (let i = 0; i < KEY_VALUES.length; i++) {
144 | const expected = KEY_VALUES[i].value;
145 | const result = body.values[i] as (typeof KEY_VALUES)[0]["value"];
146 | expect(result.metadata.text).toBe(expected.metadata.text);
147 | expect(result.f).toBe(expected.f);
148 | }
149 | });
150 |
151 | it("should refresh a newly PUT key", async () => {
152 | const { key, value } = randomKeyValue();
153 | const putResponse = await client.post("/put", { key, value });
154 | expect(putResponse.status).toBe(200);
155 |
156 | // we expect only 1 new key to be added via REFRESH
157 | const refreshResponse = await client.post("/refresh");
158 | expect(refreshResponse.status).toBe(200);
159 | expect(await refreshResponse.text()).toBe("1");
160 | });
161 |
162 | it("should refresh with 0 keys when no additions are made", async () => {
163 | const refreshResponse = await client.post("/refresh");
164 | expect(refreshResponse.status).toBe(200);
165 | expect(await refreshResponse.text()).toBe("0");
166 | });
167 |
168 | it("should clear all keys", async () => {
169 | // we expect 0 keys as nothing has changed since the last refresh
170 | const clearResponse = await client.post("/clear");
171 | expect(clearResponse.status).toBe(200);
172 |
173 | const getManyRawResponse = await client.post("/getManyRaw", {
174 | keys: KEY_VALUES.map((kv) => kv.key),
175 | });
176 | expect(getManyRawResponse.status).toBe(200);
177 | const body = await getManyRawResponse.json();
178 |
179 | (body.values as (typeof KEY_VALUES)[0]["value"][]).forEach((val) => expect(val).toBe(null));
180 | });
181 | });
182 |
183 | afterAll(async () => {
184 | console.log("waiting a bit before closing...");
185 | await sleep(1500);
186 |
187 | rmSync(DATA_PATH, { recursive: true });
188 | await arlocal.stop();
189 | await redisClient.quit();
190 | });
191 | });
192 |
--------------------------------------------------------------------------------
/hollowdb/test/res/contractSource.ts:
--------------------------------------------------------------------------------
1 | // you can use any build under https://github.com/firstbatchxyz/hollowdb/tree/master/src/contracts/build
2 | export default `
3 | // src/contracts/errors/index.ts
4 | var KeyExistsError = new ContractError("Key already exists.");
5 | var KeyNotExistsError = new ContractError("Key does not exist.");
6 | var CantEvolveError = new ContractError("Evolving is disabled.");
7 | var NoVerificationKeyError = new ContractError("No verification key.");
8 | var UnknownProtocolError = new ContractError("Unknown protocol.");
9 | var NotWhitelistedError = new ContractError("Not whitelisted.");
10 | var InvalidProofError = new ContractError("Invalid proof.");
11 | var ExpectedProofError = new ContractError("Expected a proof.");
12 | var NullValueError = new ContractError("Value cant be null, use remove instead.");
13 | var NotOwnerError = new ContractError("Not contract owner.");
14 | var InvalidFunctionError = new ContractError("Invalid function.");
15 | var ArrayLengthMismatchError = new ContractError("Key and value counts mismatch.");
16 |
17 | // src/contracts/utils/index.ts
18 | var verifyProof = async (proof, psignals, verificationKey) => {
19 | if (!verificationKey) {
20 | throw NoVerificationKeyError;
21 | }
22 | if (verificationKey.protocol !== "groth16" && verificationKey.protocol !== "plonk") {
23 | throw UnknownProtocolError;
24 | }
25 | return await SmartWeave.extensions[verificationKey.protocol].verify(verificationKey, psignals, proof);
26 | };
27 | var hashToGroup = (value) => {
28 | if (value) {
29 | return BigInt(SmartWeave.extensions.ethers.utils.ripemd160(Buffer.from(JSON.stringify(value))));
30 | } else {
31 | return BigInt(0);
32 | }
33 | };
34 |
35 | // src/contracts/modifiers/index.ts
36 | var onlyOwner = (caller, input, state) => {
37 | if (caller !== state.owner) {
38 | throw NotOwnerError;
39 | }
40 | return input;
41 | };
42 | var onlyNonNullValue = (_, input) => {
43 | if (input.value === null) {
44 | throw NullValueError;
45 | }
46 | return input;
47 | };
48 | var onlyNonNullValues = (_, input) => {
49 | if (input.values.some((val) => val === null)) {
50 | throw NullValueError;
51 | }
52 | return input;
53 | };
54 | var onlyWhitelisted = (list) => {
55 | return (caller, input, state) => {
56 | if (!state.isWhitelistRequired[list]) {
57 | return input;
58 | }
59 | if (!state.whitelists[list][caller]) {
60 | throw NotWhitelistedError;
61 | }
62 | return input;
63 | };
64 | };
65 | var onlyProofVerified = (proofName, prepareInputs) => {
66 | return async (caller, input, state) => {
67 | if (!state.isProofRequired[proofName]) {
68 | return input;
69 | }
70 | if (!input.proof) {
71 | throw ExpectedProofError;
72 | }
73 | const ok = await verifyProof(
74 | input.proof,
75 | await prepareInputs(caller, input, state),
76 | state.verificationKeys[proofName]
77 | );
78 | if (!ok) {
79 | throw InvalidProofError;
80 | }
81 | return input;
82 | };
83 | };
84 | async function apply(caller, input, state, ...modifiers) {
85 | for (const modifier of modifiers) {
86 | input = await modifier(caller, input, state);
87 | }
88 | return input;
89 | }
90 |
91 | // src/contracts/hollowdb-set.contract.ts
92 | var handle = async (state, action) => {
93 | const { caller, input } = action;
94 | switch (input.function) {
95 | case "get": {
96 | const { key } = await apply(caller, input.value, state);
97 | return { result: await SmartWeave.kv.get(key) };
98 | }
99 | case "getMany": {
100 | const { keys } = await apply(caller, input.value, state);
101 | const values = await Promise.all(keys.map((key) => SmartWeave.kv.get(key)));
102 | return { result: values };
103 | }
104 | case "set": {
105 | const { key, value } = await apply(caller, input.value, state, onlyWhitelisted("set"), onlyNonNullValue);
106 | await SmartWeave.kv.put(key, value);
107 | return { state };
108 | }
109 | case "setMany": {
110 | const { keys, values } = await apply(caller, input.value, state, onlyWhitelisted("set"), onlyNonNullValues);
111 | if (keys.length !== values.length) {
112 | throw new ContractError("Key and value counts mismatch");
113 | }
114 | await Promise.all(keys.map((key, i) => SmartWeave.kv.put(key, values[i])));
115 | return { state };
116 | }
117 | case "getKeys": {
118 | const { options } = await apply(caller, input.value, state);
119 | return { result: await SmartWeave.kv.keys(options) };
120 | }
121 | case "getKVMap": {
122 | const { options } = await apply(caller, input.value, state);
123 | return { result: await SmartWeave.kv.kvMap(options) };
124 | }
125 | case "put": {
126 | const { key, value } = await apply(caller, input.value, state, onlyWhitelisted("put"), onlyNonNullValue);
127 | if (await SmartWeave.kv.get(key) !== null) {
128 | throw KeyExistsError;
129 | }
130 | await SmartWeave.kv.put(key, value);
131 | return { state };
132 | }
133 | case "putMany": {
134 | const { keys, values } = await apply(caller, input.value, state, onlyWhitelisted("put"), onlyNonNullValues);
135 | if (keys.length !== values.length) {
136 | throw new ContractError("Key and value counts mismatch");
137 | }
138 | if (await Promise.all(keys.map((key) => SmartWeave.kv.get(key))).then((values2) => values2.some((val) => val !== null))) {
139 | throw KeyExistsError;
140 | }
141 | await Promise.all(keys.map((key, i) => SmartWeave.kv.put(key, values[i])));
142 | return { state };
143 | }
144 | case "update": {
145 | const { key, value } = await apply(
146 | caller,
147 | input.value,
148 | state,
149 | onlyNonNullValue,
150 | onlyWhitelisted("update"),
151 | onlyProofVerified("auth", async (_, input2) => {
152 | const oldValue = await SmartWeave.kv.get(input2.key);
153 | return [hashToGroup(oldValue), hashToGroup(input2.value), BigInt(input2.key)];
154 | })
155 | );
156 | await SmartWeave.kv.put(key, value);
157 | return { state };
158 | }
159 | case "remove": {
160 | const { key } = await apply(
161 | caller,
162 | input.value,
163 | state,
164 | onlyWhitelisted("update"),
165 | onlyProofVerified("auth", async (_, input2) => {
166 | const oldValue = await SmartWeave.kv.get(input2.key);
167 | return [hashToGroup(oldValue), BigInt(0), BigInt(input2.key)];
168 | })
169 | );
170 | await SmartWeave.kv.del(key);
171 | return { state };
172 | }
173 | case "updateOwner": {
174 | const { newOwner } = await apply(caller, input.value, state, onlyOwner);
175 | state.owner = newOwner;
176 | return { state };
177 | }
178 | case "updateProofRequirement": {
179 | const { name, value } = await apply(caller, input.value, state, onlyOwner);
180 | state.isProofRequired[name] = value;
181 | return { state };
182 | }
183 | case "updateVerificationKey": {
184 | const { name, verificationKey } = await apply(caller, input.value, state, onlyOwner);
185 | state.verificationKeys[name] = verificationKey;
186 | return { state };
187 | }
188 | case "updateWhitelistRequirement": {
189 | const { name, value } = await apply(caller, input.value, state, onlyOwner);
190 | state.isWhitelistRequired[name] = value;
191 | return { state };
192 | }
193 | case "updateWhitelist": {
194 | const { add, remove, name } = await apply(caller, input.value, state, onlyOwner);
195 | add.forEach((user) => {
196 | state.whitelists[name][user] = true;
197 | });
198 | remove.forEach((user) => {
199 | delete state.whitelists[name][user];
200 | });
201 | return { state };
202 | }
203 | case "evolve": {
204 | const srcTxId = await apply(caller, input.value, state, onlyOwner);
205 | if (!state.canEvolve) {
206 | throw CantEvolveError;
207 | }
208 | state.evolve = srcTxId;
209 | return { state };
210 | }
211 | default:
212 | input;
213 | throw InvalidFunctionError;
214 | }
215 | };
216 | ` as string;
217 |
--------------------------------------------------------------------------------
/hollowdb/test/res/initialState.ts:
--------------------------------------------------------------------------------
1 | export default {
2 | owner: "",
3 | verificationKeys: {
4 | auth: null,
5 | },
6 | isProofRequired: {
7 | auth: false,
8 | },
9 | canEvolve: false,
10 | whitelist: {
11 | put: {},
12 | update: {},
13 | set: {},
14 | },
15 | isWhitelistRequired: {
16 | put: false,
17 | update: false,
18 | set: false,
19 | },
20 | } as const;
21 |
--------------------------------------------------------------------------------
/hollowdb/test/util/index.ts:
--------------------------------------------------------------------------------
1 | import initialState from "../res/initialState";
2 | import contractSource from "../res/contractSource";
3 | import { randomUUID } from "crypto";
4 | import { loremIpsum } from "lorem-ipsum";
5 | import { JWKInterface, WarpFactory, Warp } from "warp-contracts";
6 | import { DeployPlugin } from "warp-contracts-plugin-deploy";
7 | import { CacheTypes } from "../../src/types";
8 | /**
9 | * Returns the size of a given data in bytes.
10 | * - To convert to KBs: `size / (1 << 10)`
11 | * - To convert to MBs: `size / (1 << 20)`
12 | * @param data data, such as `JSON.stringify(body)` for a POST request.
13 | * @returns data size in bytes
14 | */
15 | export function size(data: string) {
16 | return new Blob([data]).size;
17 | }
18 |
19 | /** A tiny API wrapper. */
20 | export class FetchClient {
21 | constructor(readonly baseUrl: string) {}
22 |
23 | /**
24 | * Generic POST utility for HollowDB micro. Depending on the
25 | * request, call `response.json()` or `response.text()` to parse
26 | * the returned body.
27 | * @param url url
28 | * @param data body
29 | * @returns response object
30 | */
31 | async post<Body>(url: string, data?: Body) {
32 | const body = JSON.stringify(data ?? {});
33 | return fetch(this.baseUrl + url, {
34 | method: "POST",
35 | headers: {
36 | "Content-Type": "application/json; charset=utf-8",
37 | },
38 | body,
39 | });
40 | }
41 | }
42 |
43 | /**
44 | * Creates a local warp instance, also uses the `DeployPlugin`.
45 | *
46 | * WARNING: Do not use `useStateCache` and `useContractCache` together with `forLocal`.
47 | */
48 | export function makeLocalWarp(port: number, caches: CacheTypes): Warp {
49 | return WarpFactory.forLocal(port).use(new DeployPlugin()).useKVStorageFactory(caches.kvFactory);
50 | }
51 |
52 | /** Returns a random key-value pair related to our internal usage. */
53 | export function randomKeyValue(options?: { numVals?: number; numChildren?: number }): {
54 | key: string;
55 | value: {
56 | children: number[];
57 | f: number;
58 | is_leaf: boolean;
59 | n_descendants: number;
60 | metadata: {
61 | text: string;
62 | };
63 | v: number[];
64 | };
65 | } {
66 | const numChildren = options?.numChildren || Math.round(Math.random() * 5);
67 | const numVals = options?.numVals || Math.round(Math.random() * 500 + 100);
68 |
69 | return {
70 | key: randomUUID(),
71 | value: {
72 | children: Array.from({ length: numChildren }, () => Math.round(Math.random() * 50)),
73 | f: Math.round(Math.random() * 100),
74 | is_leaf: Math.random() < 0.5,
75 | metadata: {
76 | text: loremIpsum({ count: 4 }),
77 | },
78 | n_descendants: Math.round(Math.random() * 50),
79 | v: Array.from({ length: numVals }, () => Math.random() * 2 - 1),
80 | },
81 | };
82 | }
83 |
84 | /**
85 | * Deploy a new contract via the provided Warp instance.
86 | * @param owner owner wallet
87 | * @param warp a `Warp` instance
88 | * */
89 | export async function deploy(
90 | owner: JWKInterface,
91 | warp: Warp
92 | ): Promise<{ contractTxId: string; srcTxId: string | undefined }> {
93 | if (warp.environment !== "local") {
94 | throw new Error("Expected a local Warp environment.");
95 | }
96 |
97 | const { contractTxId, srcTxId } = await warp.deploy(
98 | {
99 | wallet: owner,
100 | initState: JSON.stringify(initialState),
101 | src: contractSource,
102 | evaluationManifest: {
103 | evaluationOptions: {
104 | allowBigInt: true,
105 | useKVStorage: true,
106 | },
107 | },
108 | },
109 | true // disable bundling in test environment
110 | );
111 |
112 | return { contractTxId, srcTxId };
113 | }
114 |
--------------------------------------------------------------------------------
/hollowdb/tsconfig.build.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "./tsconfig.json",
3 | "compilerOptions": {
4 | "rootDir": "./src"
5 | },
6 | "exclude": ["node_modules", "jest.config.ts", "test"]
7 | }
8 |
--------------------------------------------------------------------------------
/hollowdb/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | /* Visit https://aka.ms/tsconfig to read more about this file */
4 |
5 | /* Projects */
6 | // "incremental": true, /* Save .tsbuildinfo files to allow for incremental compilation of projects. */
7 | // "composite": true, /* Enable constraints that allow a TypeScript project to be used with project references. */
8 | // "tsBuildInfoFile": "./.tsbuildinfo", /* Specify the path to .tsbuildinfo incremental compilation file. */
9 | // "disableSourceOfProjectReferenceRedirect": true, /* Disable preferring source files instead of declaration files when referencing composite projects. */
10 | // "disableSolutionSearching": true, /* Opt a project out of multi-project reference checking when editing. */
11 | // "disableReferencedProjectLoad": true, /* Reduce the number of projects loaded automatically by TypeScript. */
12 |
13 | /* Language and Environment */
14 | "target": "ES2022" /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */,
15 | // "lib": [], /* Specify a set of bundled library declaration files that describe the target runtime environment. */
16 | // "jsx": "preserve", /* Specify what JSX code is generated. */
17 | // "experimentalDecorators": true, /* Enable experimental support for legacy experimental decorators. */
18 | // "emitDecoratorMetadata": true, /* Emit design-type metadata for decorated declarations in source files. */
19 | // "jsxFactory": "", /* Specify the JSX factory function used when targeting React JSX emit, e.g. 'React.createElement' or 'h'. */
20 | // "jsxFragmentFactory": "", /* Specify the JSX Fragment reference used for fragments when targeting React JSX emit e.g. 'React.Fragment' or 'Fragment'. */
21 | // "jsxImportSource": "", /* Specify module specifier used to import the JSX factory functions when using 'jsx: react-jsx*'. */
22 | // "reactNamespace": "", /* Specify the object invoked for 'createElement'. This only applies when targeting 'react' JSX emit. */
23 | // "noLib": true, /* Disable including any library files, including the default lib.d.ts. */
24 | // "useDefineForClassFields": true, /* Emit ECMAScript-standard-compliant class fields. */
25 | // "moduleDetection": "auto", /* Control what method is used to detect module-format JS files. */
26 |
27 | /* Modules */
28 | "module": "commonjs" /* Specify what module code is generated. */,
29 | // "rootDir": "./src" /* Specify the root folder within your source files. */,
30 | // "moduleResolution": "node10", /* Specify how TypeScript looks up a file from a given module specifier. */
31 | // "baseUrl": "./", /* Specify the base directory to resolve non-relative module names. */
32 | // "paths": {}, /* Specify a set of entries that re-map imports to additional lookup locations. */
33 | // "rootDirs": [], /* Allow multiple folders to be treated as one when resolving modules. */
34 | // "typeRoots": [], /* Specify multiple folders that act like './node_modules/@types'. */
35 | // "types": [], /* Specify type package names to be included without being referenced in a source file. */
36 | // "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */
37 | // "moduleSuffixes": [], /* List of file name suffixes to search when resolving a module. */
38 | // "allowImportingTsExtensions": true, /* Allow imports to include TypeScript file extensions. Requires '--moduleResolution bundler' and either '--noEmit' or '--emitDeclarationOnly' to be set. */
39 | // "resolvePackageJsonExports": true, /* Use the package.json 'exports' field when resolving package imports. */
40 | // "resolvePackageJsonImports": true, /* Use the package.json 'imports' field when resolving imports. */
41 | // "customConditions": [], /* Conditions to set in addition to the resolver-specific defaults when resolving imports. */
42 | // "resolveJsonModule": true, /* Enable importing .json files. */
43 | // "allowArbitraryExtensions": true, /* Enable importing files with any extension, provided a declaration file is present. */
44 |     // "noResolve": true, /* Disallow 'import's, 'require's or '<reference>'s from expanding the number of files TypeScript should add to a project. */
45 |
46 | /* JavaScript Support */
47 | // "allowJs": true, /* Allow JavaScript files to be a part of your program. Use the 'checkJS' option to get errors from these files. */
48 | // "checkJs": true, /* Enable error reporting in type-checked JavaScript files. */
49 | // "maxNodeModuleJsDepth": 1, /* Specify the maximum folder depth used for checking JavaScript files from 'node_modules'. Only applicable with 'allowJs'. */
50 |
51 | /* Emit */
52 | // "declaration": true, /* Generate .d.ts files from TypeScript and JavaScript files in your project. */
53 | // "declarationMap": true, /* Create sourcemaps for d.ts files. */
54 | // "emitDeclarationOnly": true, /* Only output d.ts files and not JavaScript files. */
55 | // "sourceMap": true, /* Create source map files for emitted JavaScript files. */
56 | // "inlineSourceMap": true, /* Include sourcemap files inside the emitted JavaScript. */
57 | // "outFile": "./", /* Specify a file that bundles all outputs into one JavaScript file. If 'declaration' is true, also designates a file that bundles all .d.ts output. */
58 | "outDir": "./build" /* Specify an output folder for all emitted files. */,
59 | // "removeComments": true, /* Disable emitting comments. */
60 | // "noEmit": true, /* Disable emitting files from a compilation. */
61 | // "importHelpers": true, /* Allow importing helper functions from tslib once per project, instead of including them per-file. */
62 | // "importsNotUsedAsValues": "remove", /* Specify emit/checking behavior for imports that are only used for types. */
63 | // "downlevelIteration": true, /* Emit more compliant, but verbose and less performant JavaScript for iteration. */
64 | // "sourceRoot": "", /* Specify the root path for debuggers to find the reference source code. */
65 | // "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */
66 | // "inlineSources": true, /* Include source code in the sourcemaps inside the emitted JavaScript. */
67 | // "emitBOM": true, /* Emit a UTF-8 Byte Order Mark (BOM) in the beginning of output files. */
68 | // "newLine": "crlf", /* Set the newline character for emitting files. */
69 | // "stripInternal": true, /* Disable emitting declarations that have '@internal' in their JSDoc comments. */
70 | // "noEmitHelpers": true, /* Disable generating custom helper functions like '__extends' in compiled output. */
71 | // "noEmitOnError": true, /* Disable emitting files if any type checking errors are reported. */
72 | // "preserveConstEnums": true, /* Disable erasing 'const enum' declarations in generated code. */
73 | // "declarationDir": "./", /* Specify the output directory for generated declaration files. */
74 | // "preserveValueImports": true, /* Preserve unused imported values in the JavaScript output that would otherwise be removed. */
75 |
76 | /* Interop Constraints */
77 | // "isolatedModules": true, /* Ensure that each file can be safely transpiled without relying on other imports. */
78 | // "verbatimModuleSyntax": true, /* Do not transform or elide any imports or exports not marked as type-only, ensuring they are written in the output file's format based on the 'module' setting. */
79 | // "allowSyntheticDefaultImports": true, /* Allow 'import x from y' when a module doesn't have a default export. */
80 | "esModuleInterop": true /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */,
81 | // "preserveSymlinks": true, /* Disable resolving symlinks to their realpath. This correlates to the same flag in node. */
82 | "forceConsistentCasingInFileNames": true /* Ensure that casing is correct in imports. */,
83 |
84 | /* Type Checking */
85 | "strict": true /* Enable all strict type-checking options. */,
86 | // "noImplicitAny": true, /* Enable error reporting for expressions and declarations with an implied 'any' type. */
87 | // "strictNullChecks": true, /* When type checking, take into account 'null' and 'undefined'. */
88 | // "strictFunctionTypes": true, /* When assigning functions, check to ensure parameters and the return values are subtype-compatible. */
89 | // "strictBindCallApply": true, /* Check that the arguments for 'bind', 'call', and 'apply' methods match the original function. */
90 | // "strictPropertyInitialization": true, /* Check for class properties that are declared but not set in the constructor. */
91 | // "noImplicitThis": true, /* Enable error reporting when 'this' is given the type 'any'. */
92 | // "useUnknownInCatchVariables": true, /* Default catch clause variables as 'unknown' instead of 'any'. */
93 | // "alwaysStrict": true, /* Ensure 'use strict' is always emitted. */
94 | // "noUnusedLocals": true, /* Enable error reporting when local variables aren't read. */
95 | // "noUnusedParameters": true, /* Raise an error when a function parameter isn't read. */
96 | // "exactOptionalPropertyTypes": true, /* Interpret optional property types as written, rather than adding 'undefined'. */
97 | // "noImplicitReturns": true, /* Enable error reporting for codepaths that do not explicitly return in a function. */
98 | // "noFallthroughCasesInSwitch": true, /* Enable error reporting for fallthrough cases in switch statements. */
99 | // "noUncheckedIndexedAccess": true, /* Add 'undefined' to a type when accessed using an index. */
100 | // "noImplicitOverride": true, /* Ensure overriding members in derived classes are marked with an override modifier. */
101 | // "noPropertyAccessFromIndexSignature": true, /* Enforces using indexed accessors for keys declared using an indexed type. */
102 | // "allowUnusedLabels": true, /* Disable error reporting for unused labels. */
103 | // "allowUnreachableCode": true, /* Disable error reporting for unreachable code. */
104 |
105 | /* Completeness */
106 | // "skipDefaultLibCheck": true, /* Skip type checking .d.ts files that are included with TypeScript. */
107 | "skipLibCheck": true /* Skip type checking all .d.ts files. */
108 | }
109 | }
110 |
--------------------------------------------------------------------------------
/hollowdb_wait/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM alpine
2 | WORKDIR /app
3 |
4 | RUN apk update
5 | RUN apk add curl
6 |
7 | COPY wait.sh /app/wait.sh
8 | RUN chmod +x wait.sh
9 |
10 | CMD ./wait.sh
11 |
--------------------------------------------------------------------------------
/hollowdb_wait/README.md:
--------------------------------------------------------------------------------
1 | # HollowDB API `wait-for`
2 |
3 | This is a simple shell-script container that another container can depend on with an "on completion" condition: the script exits once the HollowDB container has finished downloading & refreshing keys.
4 |
5 | The containers are expected to launch in the following order:
6 |
7 | 1. **Redis**: This is the first container to launch.
8 | 2. **HollowDB**: This starts when Redis is live, and immediately begins to download values from Arweave & store them in memory for Dria to access efficiently.
9 | 3. **Dria HNSW**: Dria waits for the HollowDB API's download to complete via [this wait-for script](./wait.sh); once that is done, it launches & starts listening on its port.
10 |
11 | The container image is available on Docker Hub:
12 |
13 | ```sh
14 | docker pull firstbatch/dria-hollowdb-wait-for
15 | ```
16 |
17 | ## Wait-For Script
18 |
19 | The script makes use of the following cURL command:
20 |
21 | ```sh
22 | curl --fail --silent "$TARGET/state"
23 | ```
24 |
25 | If the cache is still loading, HollowDB will respond with status `503` and cURL will return a non-zero exit code, causing the shell script to wait for a while and try again. The body of the response also reports the percentage of keys loaded, should you make the request yourself.
26 |
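For a one-off readiness check outside of Docker, here is a minimal sketch mirroring [wait.sh](./wait.sh); adjust `TARGET` to wherever the HollowDB API is reachable:

```sh
# Probe the HollowDB API once: --fail makes curl exit non-zero on a 503
# (cache still loading), so the success branch only runs when the cache is ready.
TARGET=${TARGET:-"http://localhost:3030"}

if curl --fail --silent "$TARGET/state"; then
  echo "HollowDB API is ready"
else
  echo "HollowDB API is still loading keys"
fi
```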
--------------------------------------------------------------------------------
/hollowdb_wait/wait.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # Connection parameters
4 | TARGET=${TARGET:-"http://localhost:3030"}
5 |
6 | # How long to sleep (seconds) before each attempt.
7 | SLEEPS=${SLEEPS:-2}
8 |
9 | echo "Polling $TARGET"
10 |
11 | while true; do
12 | curl --fail --silent "$TARGET/state"
13 |
14 | # check if exit code of curl is 0
15 | if [ $? -eq 0 ]; then
16 | echo ""
17 | echo "HollowDB API is ready!"
18 | exit 0
19 | fi
20 |
21 | sleep "$SLEEPS"
22 | done
23 |
--------------------------------------------------------------------------------