├── .github
│   └── workflows
│       ├── ci.yml
│       └── gh-pages.yml
├── .gitignore
├── .vscode
│   └── settings.json
├── LICENSE.md
├── README.md
├── demo
│   ├── Makefile
│   ├── server
│   │   └── main.go
│   ├── site
│   │   ├── favicon.ico
│   │   ├── index.html
│   │   └── wasm_exec.js
│   ├── wasm
│   │   ├── go.mod
│   │   ├── go.sum
│   │   ├── main.go
│   │   └── words.txt
│   └── words.py
├── dll.go
├── example_test.go
├── go.mod
├── go.sum
├── go.work
├── go.work.sum
├── heap.go
├── search.go
├── search_test.go
├── search_whitebox_test.go
├── trie.go
├── trie_test.go
├── walk.go
└── walk_test.go
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
name: CI

on:
  push:
    branches: [master]
  pull_request:
    branches: [master]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Install Go
        uses: actions/setup-go@v5
        with:
          go-version: '1.18'
      # goimports is pinned: current golang.org/x/tools releases require a much
      # newer Go than 1.18, so `@latest` no longer builds with this toolchain.
      - name: Install goimports
        run: go install golang.org/x/tools/cmd/goimports@v0.1.12
      - run: goimports -w .
      - run: go mod tidy
      - name: Verify no changes from goimports and go mod tidy
        run: |
          if [ -n "$(git status --porcelain)" ]; then
            echo "goimports / go mod tidy changed these files:"
            git status --porcelain
            git diff
            exit 1
          fi
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Install Go
        uses: actions/setup-go@v5
        with:
          go-version: '1.18'
      - name: Run tests
        run: go test -v ./...
      # -run='^$' matches no test, so this step only runs the benchmarks instead
      # of repeating the whole test suite from the previous step.
      - name: Run benchmarks
        run: go test -v -run='^$' -bench=.
--------------------------------------------------------------------------------
/.github/workflows/gh-pages.yml:
--------------------------------------------------------------------------------
name: github pages

on:
  push:
    branches: [master]
  pull_request:
    branches: [master]

jobs:
  deploy:
    # The ubuntu-20.04 runner image has been retired by GitHub.
    runs-on: ubuntu-latest
    # The deploy step pushes to the gh-pages branch using GITHUB_TOKEN.
    permissions:
      contents: write
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Install Go
        uses: actions/setup-go@v5
        with:
          # Pinned (not '^1.18'): demo/site/wasm_exec.js is a vendored copy of
          # the Go JS glue and must match the toolchain that builds main.wasm.
          # Newer toolchains import from a different wasm module ("gojs")
          # than the "go" module this copy provides, so the page would fail to load.
          go-version: '1.18'
      - name: Create wasm module
        run: cd demo && make build-wasm
      - name: Deploy
        uses: peaceiris/actions-gh-pages@v3
        if: github.ref == 'refs/heads/master'
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
          publish_dir: ./demo/site
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | main.wasm
2 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "gopls": {
3 | "build.env": {
4 | "GOOS": "js",
5 | "GOARCH": "wasm"
6 | }
7 | }
8 | }
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 shivam mamgain
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
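1 | ## trie [![Go Reference](https://pkg.go.dev/badge/github.com/shivamMg/trie.svg)](https://pkg.go.dev/github.com/shivamMg/trie) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)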
2 |
3 | An implementation of the [Trie](https://en.wikipedia.org/wiki/Trie) data structure in Go. It provides more features than the usual Trie prefix-search, and is meant to be used for auto-completion.
4 |
5 | ### Demo
6 |
7 | A WebAssembly demo can be tried at [shivamMg.github.io/trie](https://shivammg.github.io/trie/).
8 |
9 | ### Features
10 |
11 | - Keys are `[]string` instead of `string`, thereby supporting more use cases - e.g. `[]string{"the", "quick", "brown", "fox"}` can be a key where each word will be a node in the Trie
12 | - Support for Put key and Delete key
13 | - Support for Prefix search - e.g. searching for _nation_ might return _nation_, _national_, _nationalism_, _nationalist_, etc.
14 | - Support for Edit distance search (aka Levenshtein distance) - e.g. searching for _wheat_ might return similar looking words like _wheat_, _cheat_, _heat_, _what_, etc.
15 | - Order of search results is deterministic. It follows insertion order.
16 |
17 | ### Examples
18 |
19 | ```go
20 | tri := trie.New()
21 | // Put keys ([]string) and values (any)
22 | tri.Put([]string{"the"}, 1)
23 | tri.Put([]string{"the", "quick", "brown", "fox"}, 2)
24 | tri.Put([]string{"the", "quick", "sports", "car"}, 3)
25 | tri.Put([]string{"the", "green", "tree"}, 4)
26 | tri.Put([]string{"an", "apple", "tree"}, 5)
27 | tri.Put([]string{"an", "umbrella"}, 6)
28 |
29 | tri.Root().Print()
30 | // Output (full trie with terminals ending with ($)):
31 | // ^
32 | // ├─ the ($)
33 | // │ ├─ quick
34 | // │ │ ├─ brown
35 | // │ │ │ └─ fox ($)
36 | // │ │ └─ sports
37 | // │ │ └─ car ($)
38 | // │ └─ green
39 | // │ └─ tree ($)
40 | // └─ an
41 | // ├─ apple
42 | // │ └─ tree ($)
43 | // └─ umbrella ($)
44 |
45 | results := tri.Search([]string{"the", "quick"})
46 | for _, res := range results.Results {
47 | fmt.Println(res.Key, res.Value)
48 | }
49 | // Output (prefix-based search):
50 | // [the quick brown fox] 2
51 | // [the quick sports car] 3
52 |
53 | key := []string{"the", "tree"}
54 | results = tri.Search(key, trie.WithMaxEditDistance(2), // An edit can be insert, delete, replace
55 | trie.WithEditOps())
56 | for _, res := range results.Results {
57 | fmt.Println(res.Key, res.EditDistance) // EditDistance is number of edits needed to convert to [the tree]
58 | }
59 | // Output (results not more than 2 edits away from [the tree]):
60 | // [the] 1
61 | // [the green tree] 1
62 | // [an apple tree] 2
63 | // [an umbrella] 2
64 |
65 | result := results.Results[2]
66 | fmt.Printf("To convert %v to %v:\n", result.Key, key)
67 | printEditOps(result.EditOps)
68 | // Output (edit operations needed to convert a result to [the tree]):
69 | // To convert [an apple tree] to [the tree]:
70 | // - delete "an"
71 | // - replace "apple" with "the"
72 | // - don't edit "tree"
73 |
74 | results = tri.Search(key, trie.WithMaxEditDistance(2), trie.WithTopKLeastEdited(), trie.WithMaxResults(2))
75 | for _, res := range results.Results {
76 | fmt.Println(res.Key, res.Value, res.EditDistance)
77 | }
78 | // Output (top 2 least edited results):
79 | // [the] 1 1
80 | // [the green tree] 4 1
81 | ```
82 |
83 | ### References
84 |
85 | * https://en.wikipedia.org/wiki/Levenshtein_distance#Iterative_with_full_matrix
86 | * http://stevehanov.ca/blog/?id=114
87 | * https://gist.github.com/jlherren/d97839b1276b9bd7faa930f74711a4b6
--------------------------------------------------------------------------------
/demo/Makefile:
--------------------------------------------------------------------------------
# build-wasm compiles the wasm/ module into site/main.wasm for the browser.
.PHONY: build-wasm
build-wasm:
	cd wasm && GOOS=js GOARCH=wasm go build -o ../site/main.wasm

# server serves site/ on :8080. main.wasm is git-ignored, so it is built first;
# otherwise the page would be served without its WebAssembly module.
.PHONY: server
server: build-wasm
	cd server && go run .
8 |
--------------------------------------------------------------------------------
/demo/server/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "log"
5 | "net/http"
6 | )
7 |
8 | const (
9 | siteDir = "./../site"
10 | addr = ":8080"
11 | )
12 |
13 | func main() {
14 | log.Println("server will start at", addr)
15 | log.Fatal(http.ListenAndServe(addr, http.FileServer(http.Dir(siteDir))))
16 | }
17 |
--------------------------------------------------------------------------------
/demo/site/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shivamMg/trie/fdf2c274601aaf8ec709ce4c1461b5cca1dd868c/demo/site/favicon.ico
--------------------------------------------------------------------------------
/demo/site/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | trie auto-completion demo
5 |
6 |
7 |
8 |
53 |
54 |
55 |
56 |
57 |
58 |
59 | This is a WebAssembly demo for auto-completion using the trie Go library.
60 | The WASM module contains a Trie populated with English dictionary words. Searching for a word retrieves at most 10 results.
61 |
62 |
63 | The default search is Prefix search i.e. results will have the same prefix.
64 |
65 |
66 | If "Edit distance search" is enabled, results will be at most 3 edits away
67 | and sorted least-edited first.
68 |
69 |
70 |
71 |
72 | Edit distance search (type after toggling)
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
134 |
--------------------------------------------------------------------------------
/demo/site/wasm_exec.js:
--------------------------------------------------------------------------------
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// JavaScript host glue for a GOOS=js GOARCH=wasm Go program: it defines
// globalThis.Go, whose importObject supplies the functions the wasm module
// imports and whose run() starts the program.
//
// NOTE(review): this appears to be an unmodified copy of the wasm_exec.js that
// ships with the Go distribution. It is presumably coupled to the Go toolchain
// that builds main.wasm, so re-copy it from that toolchain instead of editing it.

"use strict";

(() => {
	// Error used by the polyfilled operations below that are not supported.
	const enosys = () => {
		const err = new Error("not implemented");
		err.code = "ENOSYS";
		return err;
	};

	// Minimal fs used when the host provides no globalThis.fs. Only writing is
	// supported: output is buffered and everything up to the last newline is
	// flushed to console.log. fsync succeeds; every other call fails with ENOSYS.
	if (!globalThis.fs) {
		let outputBuf = "";
		globalThis.fs = {
			constants: { O_WRONLY: -1, O_RDWR: -1, O_CREAT: -1, O_TRUNC: -1, O_APPEND: -1, O_EXCL: -1 }, // unused
			writeSync(fd, buf) {
				outputBuf += decoder.decode(buf);
				const nl = outputBuf.lastIndexOf("\n");
				if (nl != -1) {
					console.log(outputBuf.substr(0, nl));
					outputBuf = outputBuf.substr(nl + 1);
				}
				return buf.length;
			},
			// Only whole-buffer writes at the current position are supported.
			write(fd, buf, offset, length, position, callback) {
				if (offset !== 0 || length !== buf.length || position !== null) {
					callback(enosys());
					return;
				}
				const n = this.writeSync(fd, buf);
				callback(null, n);
			},
			chmod(path, mode, callback) { callback(enosys()); },
			chown(path, uid, gid, callback) { callback(enosys()); },
			close(fd, callback) { callback(enosys()); },
			fchmod(fd, mode, callback) { callback(enosys()); },
			fchown(fd, uid, gid, callback) { callback(enosys()); },
			fstat(fd, callback) { callback(enosys()); },
			fsync(fd, callback) { callback(null); },
			ftruncate(fd, length, callback) { callback(enosys()); },
			lchown(path, uid, gid, callback) { callback(enosys()); },
			link(path, link, callback) { callback(enosys()); },
			lstat(path, callback) { callback(enosys()); },
			mkdir(path, perm, callback) { callback(enosys()); },
			open(path, flags, mode, callback) { callback(enosys()); },
			read(fd, buffer, offset, length, position, callback) { callback(enosys()); },
			readdir(path, callback) { callback(enosys()); },
			readlink(path, callback) { callback(enosys()); },
			rename(from, to, callback) { callback(enosys()); },
			rmdir(path, callback) { callback(enosys()); },
			stat(path, callback) { callback(enosys()); },
			symlink(path, link, callback) { callback(enosys()); },
			truncate(path, length, callback) { callback(enosys()); },
			unlink(path, callback) { callback(enosys()); },
			utimes(path, atime, mtime, callback) { callback(enosys()); },
		};
	}

	// Minimal process used when the host provides none: all ids are -1 and the
	// remaining calls throw ENOSYS.
	if (!globalThis.process) {
		globalThis.process = {
			getuid() { return -1; },
			getgid() { return -1; },
			geteuid() { return -1; },
			getegid() { return -1; },
			getgroups() { throw enosys(); },
			pid: -1,
			ppid: -1,
			umask() { throw enosys(); },
			cwd() { throw enosys(); },
			chdir() { throw enosys(); },
		}
	}

	// Host APIs the glue relies on; fail early when any is missing.
	if (!globalThis.crypto) {
		throw new Error("globalThis.crypto is not available, polyfill required (crypto.getRandomValues only)");
	}

	if (!globalThis.performance) {
		throw new Error("globalThis.performance is not available, polyfill required (performance.now only)");
	}

	if (!globalThis.TextEncoder) {
		throw new Error("globalThis.TextEncoder is not available, polyfill required");
	}

	if (!globalThis.TextDecoder) {
		throw new Error("globalThis.TextDecoder is not available, polyfill required");
	}

	const encoder = new TextEncoder("utf-8");
	const decoder = new TextDecoder("utf-8");

	globalThis.Go = class {
		constructor() {
			this.argv = ["js"];
			this.env = {};
			this.exit = (code) => {
				if (code !== 0) {
					console.warn("exit code:", code);
				}
			};
			this._exitPromise = new Promise((resolve) => {
				this._resolveExitPromise = resolve;
			});
			this._pendingEvent = null;
			this._scheduledTimeouts = new Map();
			this._nextCallbackTimeoutID = 1;

			// Writes v at addr as a little-endian 64-bit value split into two 32-bit words.
			const setInt64 = (addr, v) => {
				this.mem.setUint32(addr + 0, v, true);
				this.mem.setUint32(addr + 4, Math.floor(v / 4294967296), true);
			}

			// Reads a little-endian 64-bit value (unsigned low word, signed high word).
			const getInt64 = (addr) => {
				const low = this.mem.getUint32(addr + 0, true);
				const high = this.mem.getInt32(addr + 4, true);
				return low + high * 4294967296;
			}

			// Decodes an 8-byte value slot written by storeValue: 0 means undefined,
			// a non-NaN float64 is a number, and a NaN-boxed slot carries an id
			// (low word) into this._values.
			const loadValue = (addr) => {
				const f = this.mem.getFloat64(addr, true);
				if (f === 0) {
					return undefined;
				}
				if (!isNaN(f)) {
					return f;
				}

				const id = this.mem.getUint32(addr, true);
				return this._values[id];
			}

			// Encodes v into an 8-byte slot at addr. Non-zero numbers are stored as
			// plain float64s and undefined as 0. Everything else (including the
			// number 0 and NaN) is stored as a NaN-boxed reference: the high word is
			// nanHead | typeFlag and the low word is an id into this._values, whose
			// Go-side reference count is incremented.
			const storeValue = (addr, v) => {
				const nanHead = 0x7FF80000;

				if (typeof v === "number" && v !== 0) {
					if (isNaN(v)) {
						this.mem.setUint32(addr + 4, nanHead, true);
						this.mem.setUint32(addr, 0, true);
						return;
					}
					this.mem.setFloat64(addr, v, true);
					return;
				}

				if (v === undefined) {
					this.mem.setFloat64(addr, 0, true);
					return;
				}

				let id = this._ids.get(v);
				if (id === undefined) {
					id = this._idPool.pop();
					if (id === undefined) {
						id = this._values.length;
					}
					this._values[id] = v;
					this._goRefCounts[id] = 0;
					this._ids.set(v, id);
				}
				this._goRefCounts[id]++;
				let typeFlag = 0;
				switch (typeof v) {
					case "object":
						if (v !== null) {
							typeFlag = 1;
						}
						break;
					case "string":
						typeFlag = 2;
						break;
					case "symbol":
						typeFlag = 3;
						break;
					case "function":
						typeFlag = 4;
						break;
				}
				this.mem.setUint32(addr + 4, nanHead | typeFlag, true);
				this.mem.setUint32(addr, id, true);
			}

			// Views a Go []byte (data pointer at addr, length at addr+8) as a
			// Uint8Array over wasm memory, without copying.
			const loadSlice = (addr) => {
				const array = getInt64(addr + 0);
				const len = getInt64(addr + 8);
				return new Uint8Array(this._inst.exports.mem.buffer, array, len);
			}

			// Decodes a Go slice of value slots into a JS array.
			const loadSliceOfValues = (addr) => {
				const array = getInt64(addr + 0);
				const len = getInt64(addr + 8);
				const a = new Array(len);
				for (let i = 0; i < len; i++) {
					a[i] = loadValue(array + i * 8);
				}
				return a;
			}

			// Decodes a Go string (pointer at addr, length at addr+8) as UTF-8.
			const loadString = (addr) => {
				const saddr = getInt64(addr + 0);
				const len = getInt64(addr + 8);
				return decoder.decode(new DataView(this._inst.exports.mem.buffer, saddr, len));
			}

			const timeOrigin = Date.now() - performance.now();
			this.importObject = {
				go: {
					// Go's SP does not change as long as no Go code is running. Some operations (e.g. calls, getters and setters)
					// may synchronously trigger a Go event handler. This makes Go code get executed in the middle of the imported
					// function. A goroutine can switch to a new stack if the current stack is too small (see morestack function).
					// This changes the SP, thus we have to update the SP used by the imported function.

					// func wasmExit(code int32)
					"runtime.wasmExit": (sp) => {
						sp >>>= 0;
						const code = this.mem.getInt32(sp + 8, true);
						this.exited = true;
						delete this._inst;
						delete this._values;
						delete this._goRefCounts;
						delete this._ids;
						delete this._idPool;
						this.exit(code);
					},

					// func wasmWrite(fd uintptr, p unsafe.Pointer, n int32)
					"runtime.wasmWrite": (sp) => {
						sp >>>= 0;
						const fd = getInt64(sp + 8);
						const p = getInt64(sp + 16);
						const n = this.mem.getInt32(sp + 24, true);
						fs.writeSync(fd, new Uint8Array(this._inst.exports.mem.buffer, p, n));
					},

					// func resetMemoryDataView()
					"runtime.resetMemoryDataView": (sp) => {
						sp >>>= 0;
						this.mem = new DataView(this._inst.exports.mem.buffer);
					},

					// func nanotime1() int64
					"runtime.nanotime1": (sp) => {
						sp >>>= 0;
						setInt64(sp + 8, (timeOrigin + performance.now()) * 1000000);
					},

					// func walltime() (sec int64, nsec int32)
					"runtime.walltime": (sp) => {
						sp >>>= 0;
						const msec = (new Date).getTime();
						setInt64(sp + 8, msec / 1000);
						this.mem.setInt32(sp + 16, (msec % 1000) * 1000000, true);
					},

					// func scheduleTimeoutEvent(delay int64) int32
					"runtime.scheduleTimeoutEvent": (sp) => {
						sp >>>= 0;
						const id = this._nextCallbackTimeoutID;
						this._nextCallbackTimeoutID++;
						this._scheduledTimeouts.set(id, setTimeout(
							() => {
								this._resume();
								while (this._scheduledTimeouts.has(id)) {
									// for some reason Go failed to register the timeout event, log and try again
									// (temporary workaround for https://github.com/golang/go/issues/28975)
									console.warn("scheduleTimeoutEvent: missed timeout event");
									this._resume();
								}
							},
							getInt64(sp + 8) + 1, // setTimeout has been seen to fire up to 1 millisecond early
						));
						this.mem.setInt32(sp + 16, id, true);
					},

					// func clearTimeoutEvent(id int32)
					"runtime.clearTimeoutEvent": (sp) => {
						sp >>>= 0;
						const id = this.mem.getInt32(sp + 8, true);
						clearTimeout(this._scheduledTimeouts.get(id));
						this._scheduledTimeouts.delete(id);
					},

					// func getRandomData(r []byte)
					"runtime.getRandomData": (sp) => {
						sp >>>= 0;
						crypto.getRandomValues(loadSlice(sp + 8));
					},

					// func finalizeRef(v ref)
					// Drops one Go reference; when none remain the id is recycled via _idPool.
					"syscall/js.finalizeRef": (sp) => {
						sp >>>= 0;
						const id = this.mem.getUint32(sp + 8, true);
						this._goRefCounts[id]--;
						if (this._goRefCounts[id] === 0) {
							const v = this._values[id];
							this._values[id] = null;
							this._ids.delete(v);
							this._idPool.push(id);
						}
					},

					// func stringVal(value string) ref
					"syscall/js.stringVal": (sp) => {
						sp >>>= 0;
						storeValue(sp + 24, loadString(sp + 8));
					},

					// func valueGet(v ref, p string) ref
					"syscall/js.valueGet": (sp) => {
						sp >>>= 0;
						const result = Reflect.get(loadValue(sp + 8), loadString(sp + 16));
						sp = this._inst.exports.getsp() >>> 0; // see comment above
						storeValue(sp + 32, result);
					},

					// func valueSet(v ref, p string, x ref)
					"syscall/js.valueSet": (sp) => {
						sp >>>= 0;
						Reflect.set(loadValue(sp + 8), loadString(sp + 16), loadValue(sp + 32));
					},

					// func valueDelete(v ref, p string)
					"syscall/js.valueDelete": (sp) => {
						sp >>>= 0;
						Reflect.deleteProperty(loadValue(sp + 8), loadString(sp + 16));
					},

					// func valueIndex(v ref, i int) ref
					"syscall/js.valueIndex": (sp) => {
						sp >>>= 0;
						storeValue(sp + 24, Reflect.get(loadValue(sp + 8), getInt64(sp + 16)));
					},

					// valueSetIndex(v ref, i int, x ref)
					"syscall/js.valueSetIndex": (sp) => {
						sp >>>= 0;
						Reflect.set(loadValue(sp + 8), getInt64(sp + 16), loadValue(sp + 24));
					},

					// func valueCall(v ref, m string, args []ref) (ref, bool)
					// The bool result is 1 on success; on a thrown error the error is returned with 0.
					"syscall/js.valueCall": (sp) => {
						sp >>>= 0;
						try {
							const v = loadValue(sp + 8);
							const m = Reflect.get(v, loadString(sp + 16));
							const args = loadSliceOfValues(sp + 32);
							const result = Reflect.apply(m, v, args);
							sp = this._inst.exports.getsp() >>> 0; // see comment above
							storeValue(sp + 56, result);
							this.mem.setUint8(sp + 64, 1);
						} catch (err) {
							sp = this._inst.exports.getsp() >>> 0; // see comment above
							storeValue(sp + 56, err);
							this.mem.setUint8(sp + 64, 0);
						}
					},

					// func valueInvoke(v ref, args []ref) (ref, bool)
					"syscall/js.valueInvoke": (sp) => {
						sp >>>= 0;
						try {
							const v = loadValue(sp + 8);
							const args = loadSliceOfValues(sp + 16);
							const result = Reflect.apply(v, undefined, args);
							sp = this._inst.exports.getsp() >>> 0; // see comment above
							storeValue(sp + 40, result);
							this.mem.setUint8(sp + 48, 1);
						} catch (err) {
							sp = this._inst.exports.getsp() >>> 0; // see comment above
							storeValue(sp + 40, err);
							this.mem.setUint8(sp + 48, 0);
						}
					},

					// func valueNew(v ref, args []ref) (ref, bool)
					"syscall/js.valueNew": (sp) => {
						sp >>>= 0;
						try {
							const v = loadValue(sp + 8);
							const args = loadSliceOfValues(sp + 16);
							const result = Reflect.construct(v, args);
							sp = this._inst.exports.getsp() >>> 0; // see comment above
							storeValue(sp + 40, result);
							this.mem.setUint8(sp + 48, 1);
						} catch (err) {
							sp = this._inst.exports.getsp() >>> 0; // see comment above
							storeValue(sp + 40, err);
							this.mem.setUint8(sp + 48, 0);
						}
					},

					// func valueLength(v ref) int
					"syscall/js.valueLength": (sp) => {
						sp >>>= 0;
						setInt64(sp + 16, parseInt(loadValue(sp + 8).length));
					},

					// valuePrepareString(v ref) (ref, int)
					// Encodes the value's string form as UTF-8 and returns the bytes plus their length.
					"syscall/js.valuePrepareString": (sp) => {
						sp >>>= 0;
						const str = encoder.encode(String(loadValue(sp + 8)));
						storeValue(sp + 16, str);
						setInt64(sp + 24, str.length);
					},

					// valueLoadString(v ref, b []byte)
					"syscall/js.valueLoadString": (sp) => {
						sp >>>= 0;
						const str = loadValue(sp + 8);
						loadSlice(sp + 16).set(str);
					},

					// func valueInstanceOf(v ref, t ref) bool
					"syscall/js.valueInstanceOf": (sp) => {
						sp >>>= 0;
						this.mem.setUint8(sp + 24, (loadValue(sp + 8) instanceof loadValue(sp + 16)) ? 1 : 0);
					},

					// func copyBytesToGo(dst []byte, src ref) (int, bool)
					"syscall/js.copyBytesToGo": (sp) => {
						sp >>>= 0;
						const dst = loadSlice(sp + 8);
						const src = loadValue(sp + 32);
						if (!(src instanceof Uint8Array || src instanceof Uint8ClampedArray)) {
							this.mem.setUint8(sp + 48, 0);
							return;
						}
						const toCopy = src.subarray(0, dst.length);
						dst.set(toCopy);
						setInt64(sp + 40, toCopy.length);
						this.mem.setUint8(sp + 48, 1);
					},

					// func copyBytesToJS(dst ref, src []byte) (int, bool)
					"syscall/js.copyBytesToJS": (sp) => {
						sp >>>= 0;
						const dst = loadValue(sp + 8);
						const src = loadSlice(sp + 16);
						if (!(dst instanceof Uint8Array || dst instanceof Uint8ClampedArray)) {
							this.mem.setUint8(sp + 48, 0);
							return;
						}
						const toCopy = src.subarray(0, dst.length);
						dst.set(toCopy);
						setInt64(sp + 40, toCopy.length);
						this.mem.setUint8(sp + 48, 1);
					},

					"debug": (value) => {
						console.log(value);
					},
				}
			};
		}

		// Starts the Go program in instance and resolves once it has exited.
		async run(instance) {
			if (!(instance instanceof WebAssembly.Instance)) {
				throw new Error("Go.run: WebAssembly.Instance expected");
			}
			this._inst = instance;
			this.mem = new DataView(this._inst.exports.mem.buffer);
			this._values = [ // JS values that Go currently has references to, indexed by reference id
				NaN,
				0,
				null,
				true,
				false,
				globalThis,
				this,
			];
			this._goRefCounts = new Array(this._values.length).fill(Infinity); // number of references that Go has to a JS value, indexed by reference id
			this._ids = new Map([ // mapping from JS values to reference ids
				[0, 1],
				[null, 2],
				[true, 3],
				[false, 4],
				[globalThis, 5],
				[this, 6],
			]);
			this._idPool = []; // unused ids that have been garbage collected
			this.exited = false; // whether the Go program has exited

			// Pass command line arguments and environment variables to WebAssembly by writing them to the linear memory.
			// Strings are NUL-terminated and 8-byte aligned, starting at offset 4096.
			let offset = 4096;

			const strPtr = (str) => {
				const ptr = offset;
				const bytes = encoder.encode(str + "\0");
				new Uint8Array(this.mem.buffer, offset, bytes.length).set(bytes);
				offset += bytes.length;
				if (offset % 8 !== 0) {
					offset += 8 - (offset % 8);
				}
				return ptr;
			};

			const argc = this.argv.length;

			// Pointer table: argv pointers, 0, env pointers ("KEY=value", sorted by key), 0.
			const argvPtrs = [];
			this.argv.forEach((arg) => {
				argvPtrs.push(strPtr(arg));
			});
			argvPtrs.push(0);

			const keys = Object.keys(this.env).sort();
			keys.forEach((key) => {
				argvPtrs.push(strPtr(`${key}=${this.env[key]}`));
			});
			argvPtrs.push(0);

			const argv = offset;
			argvPtrs.forEach((ptr) => {
				this.mem.setUint32(offset, ptr, true);
				this.mem.setUint32(offset + 4, 0, true);
				offset += 8;
			});

			// The linker guarantees global data starts from at least wasmMinDataAddr.
			// Keep in sync with cmd/link/internal/ld/data.go:wasmMinDataAddr.
			const wasmMinDataAddr = 4096 + 8192;
			if (offset >= wasmMinDataAddr) {
				throw new Error("total length of command line and environment variables exceeds limit");
			}

			this._inst.exports.run(argc, argv);
			if (this.exited) {
				this._resolveExitPromise();
			}
			await this._exitPromise;
		}

		// Re-enters the Go program (e.g. after a timeout or a JS callback) and
		// resolves the exit promise if it exited while running.
		_resume() {
			if (this.exited) {
				throw new Error("Go program has already exited");
			}
			this._inst.exports.resume();
			if (this.exited) {
				this._resolveExitPromise();
			}
		}

		// Returns a JS function that records each call as the pending event for
		// id, resumes Go to handle it, and returns whatever event.result holds
		// afterwards.
		_makeFuncWrapper(id) {
			const go = this;
			return function () {
				const event = { id: id, this: this, args: arguments };
				go._pendingEvent = event;
				go._resume();
				return event.result;
			};
		}
	}
})();
555 |
--------------------------------------------------------------------------------
/demo/wasm/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/shivamMg/trie/demo/wasm
2 |
3 | go 1.18
4 |
5 | require github.com/shivamMg/trie v0.0.0-20220605114339-55c1368e363e
6 |
7 | require github.com/shivamMg/ppds v0.0.1 // indirect
8 |
9 | replace github.com/shivamMg/trie => ../../
10 |
--------------------------------------------------------------------------------
/demo/wasm/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
2 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
3 | github.com/shivamMg/ppds v0.0.1 h1:idK2dpaen652zOO+OmcwmyoPNncBNqfHjF/14eS5JIk=
4 | github.com/shivamMg/ppds v0.0.1/go.mod h1:hb39VqUO6qfkb9zBBQPTIV1vWBtI7yQsG0wr3pN78fM=
5 | github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
6 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
7 |
--------------------------------------------------------------------------------
/demo/wasm/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | _ "embed"
7 | "fmt"
8 | "io"
9 | "strings"
10 | "sync"
11 | "syscall/js"
12 |
13 | "github.com/shivamMg/trie"
14 | )
15 |
// data is the embedded word list, one word per line. It is generated by
// demo/words.py, which writes wasm/words.txt.
//
//go:embed words.txt
var data []byte

var (
	// mu guards tri and longestWordLen; initTrie and searchWord hold it for
	// their whole run.
	mu  sync.Mutex
	tri *trie.Trie // trie of every word in data, keyed by single characters

	// longestWordLen is the byte length of the longest word loaded into tri.
	longestWordLen int
)
25 |
// init builds the trie from the embedded word list at program start-up, before
// main registers searchWord with JavaScript.
func init() {
	initTrie()
}
29 |
30 | func initTrie() {
31 | mu.Lock()
32 | defer mu.Unlock()
33 | tri = trie.New()
34 | r := bufio.NewReader(bytes.NewReader(data))
35 | for {
36 | word, err := r.ReadString('\n')
37 | if err == io.EOF {
38 | break
39 | }
40 | if err != nil {
41 | fmt.Println(err)
42 | }
43 | word = strings.TrimRight(word, "\n")
44 | if len(word) > longestWordLen {
45 | longestWordLen = len(word)
46 | }
47 | key := strings.Split(word, "")
48 | tri.Put(key, struct{}{})
49 | }
50 | }
51 |
52 | func getNoEdits(key []string, ops []*trie.EditOp) []interface{} {
53 | uneditedLetters := make([]string, 0)
54 | for _, op := range ops {
55 | if op.Type == trie.EditOpTypeNoEdit {
56 | uneditedLetters = append(uneditedLetters, op.KeyPart)
57 | }
58 | }
59 | noEdits := make([]interface{}, len(key))
60 | j := 0
61 | for i, letter := range key {
62 | unedited := false
63 | if j < len(uneditedLetters) && letter == uneditedLetters[j] {
64 | unedited = true
65 | j += 1
66 | }
67 | noEdits[i] = unedited
68 | }
69 | return noEdits
70 | }
71 |
// getNoEditsForPrefixSearch builds the per-character "unedited" flags for a
// prefix-search result with keyLen characters: the first wordLen positions
// (the typed prefix) are true and the remaining ones false.
func getNoEditsForPrefixSearch(wordLen int, keyLen int) []interface{} {
	flags := make([]interface{}, 0, keyLen)
	for len(flags) < keyLen {
		flags = append(flags, len(flags) < wordLen)
	}
	return flags
}
79 |
80 | func searchWord(this js.Value, args []js.Value) interface{} {
81 | mu.Lock()
82 | defer mu.Unlock()
83 | word := args[0].String()
84 | if len(word) > longestWordLen {
85 | return map[string]interface{}{
86 | "words": []interface{}{},
87 | "noEdits": []interface{}{},
88 | }
89 | }
90 | approximate := args[1].Bool()
91 | key := strings.Split(word, "")
92 | opts := []func(*trie.SearchOptions){trie.WithMaxResults(10)}
93 | if approximate {
94 | opts = append(opts, trie.WithMaxEditDistance(3), trie.WithEditOps(), trie.WithTopKLeastEdited())
95 | }
96 | results := tri.Search(key, opts...)
97 | n := len(results.Results)
98 | words := make([]interface{}, n)
99 | noEdits := make([]interface{}, n)
100 | for i, res := range results.Results {
101 | words[i] = strings.Join(res.Key, "")
102 | if approximate {
103 | noEdits[i] = getNoEdits(res.Key, res.EditOps)
104 | } else {
105 | noEdits[i] = getNoEditsForPrefixSearch(len(word), len(res.Key))
106 | }
107 | }
108 | return map[string]interface{}{
109 | "words": words,
110 | "noEdits": noEdits,
111 | }
112 | }
113 |
114 | func main() {
115 | c := make(chan struct{}, 0)
116 | js.Global().Set("searchWord", js.FuncOf(searchWord))
117 | <-c
118 | }
119 |
--------------------------------------------------------------------------------
/demo/words.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import string
4 |
5 |
WORDS_FILE = '/usr/share/dict/american-english'  # https://en.wikipedia.org/wiki/Words_(Unix)
TARGET_FILE = 'wasm/words.txt'

_ALLOWED_CHARS = frozenset(string.ascii_lowercase)


def is_plain_lowercase(word):
    """Return True for a non-empty word made only of ASCII letters a-z.

    This drops proper nouns (e.g. "Abby"), possessives (e.g. "accuser's"),
    accented words and blank lines. Blank lines would otherwise become empty
    trie keys.
    """
    return bool(word) and _ALLOWED_CHARS.issuperset(word)


def filter_words(words):
    """Keep plain lowercase words, dropping plurals formed with a trailing 's'.

    A word ending in 's' is dropped only when the word without that 's' is also
    kept (so "cats" goes when "cat" exists, but "bus" stays). Input order is
    preserved.
    """
    words = [word for word in words if is_plain_lowercase(word)]
    word_set = set(words)
    return [word for word in words if not (word.endswith('s') and word[:-1] in word_set)]


def main():
    """Write the filtered dictionary to TARGET_FILE, one word per line."""
    # The system dictionary is UTF-8; don't depend on the locale's default encoding.
    with open(WORDS_FILE, encoding='utf-8') as f:
        words = [line.rstrip('\n') for line in f]
    words = filter_words(words)

    with open(TARGET_FILE, 'w', encoding='utf-8') as f:
        for word in words:
            f.write(word + '\n')
    print(len(words), 'words written to', TARGET_FILE)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/dll.go:
--------------------------------------------------------------------------------
1 | package trie
2 |
3 | // TODO: tests
4 | type doublyLinkedList struct {
5 | head, tail *dllNode
6 | }
7 |
8 | type dllNode struct {
9 | trieNode *Node
10 | next, prev *dllNode
11 | }
12 |
13 | func newDLLNode(trieNode *Node) *dllNode {
14 | return &dllNode{trieNode: trieNode}
15 | }
16 |
17 | func (dll *doublyLinkedList) append(node *dllNode) {
18 | if dll.head == nil {
19 | dll.head = node
20 | dll.tail = node
21 | return
22 | }
23 | dll.tail.next = node
24 | node.prev = dll.tail
25 | dll.tail = node
26 | }
27 |
28 | func (dll *doublyLinkedList) pop(node *dllNode) {
29 | if node == dll.head && node == dll.tail {
30 | dll.head = nil
31 | dll.tail = nil
32 | return
33 | }
34 | if node == dll.head {
35 | dll.head = node.next
36 | dll.head.prev = nil
37 | node.next = nil
38 | return
39 | }
40 | if node == dll.tail {
41 | dll.tail = node.prev
42 | dll.tail.next = nil
43 | node.prev = nil
44 | return
45 | }
46 | prev := node.prev
47 | next := node.next
48 | prev.next = next
49 | next.prev = prev
50 | node.next = nil
51 | node.prev = nil
52 | }
53 |
--------------------------------------------------------------------------------
/example_test.go:
--------------------------------------------------------------------------------
1 | package trie_test
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/shivamMg/trie"
7 | )
8 |
9 | func printEditOps(ops []*trie.EditOp) {
10 | for _, op := range ops {
11 | switch op.Type {
12 | case trie.EditOpTypeNoEdit:
13 | fmt.Printf("- don't edit %q\n", op.KeyPart)
14 | case trie.EditOpTypeInsert:
15 | fmt.Printf("- insert %q\n", op.KeyPart)
16 | case trie.EditOpTypeDelete:
17 | fmt.Printf("- delete %q\n", op.KeyPart)
18 | case trie.EditOpTypeReplace:
19 | fmt.Printf("- replace %q with %q\n", op.KeyPart, op.ReplaceWith)
20 | }
21 | }
22 | }
23 |
// Example walks through the Trie API: Put, printing the trie, a prefix search,
// an edit-distance search with EditOps, and a top-K least edited search.
//
// NOTE(review): the "// Output (...)" comments below are explanatory only. Go
// runs an example and checks its output only when it has a comment beginning
// with "Output:", so `go test` compiles this example but neither executes it
// nor verifies the outputs shown here.
func Example() {
	tri := trie.New()
	// Put keys ([]string) and values (any)
	tri.Put([]string{"the"}, 1)
	tri.Put([]string{"the", "quick", "brown", "fox"}, 2)
	tri.Put([]string{"the", "quick", "sports", "car"}, 3)
	tri.Put([]string{"the", "green", "tree"}, 4)
	tri.Put([]string{"an", "apple", "tree"}, 5)
	tri.Put([]string{"an", "umbrella"}, 6)

	tri.Root().Print()
	// Output (full trie with terminals ending with ($)):
	// ^
	// ├─ the ($)
	// │ ├─ quick
	// │ │ ├─ brown
	// │ │ │ └─ fox ($)
	// │ │ └─ sports
	// │ │ └─ car ($)
	// │ └─ green
	// │ └─ tree ($)
	// └─ an
	// ├─ apple
	// │ └─ tree ($)
	// └─ umbrella ($)

	results := tri.Search([]string{"the", "quick"})
	for _, res := range results.Results {
		fmt.Println(res.Key, res.Value)
	}
	// Output (prefix-based search):
	// [the quick brown fox] 2
	// [the quick sports car] 3

	key := []string{"the", "tree"}
	results = tri.Search(key, trie.WithMaxEditDistance(2), // An edit can be insert, delete, replace
		trie.WithEditOps())
	for _, res := range results.Results {
		fmt.Println(res.Key, res.EditDistance) // EditDistance is number of edits needed to convert to [the tree]
	}
	// Output (results not more than 2 edits away from [the tree]):
	// [the] 1
	// [the green tree] 1
	// [an apple tree] 2
	// [an umbrella] 2

	// Index 2 relies on the result order listed above.
	result := results.Results[2]
	fmt.Printf("To convert %v to %v:\n", result.Key, key)
	printEditOps(result.EditOps)
	// Output (edit operations needed to convert a result to [the tree]):
	// To convert [an apple tree] to [the tree]:
	// - delete "an"
	// - replace "apple" with "the"
	// - don't edit "tree"

	results = tri.Search(key, trie.WithMaxEditDistance(2), trie.WithTopKLeastEdited(), trie.WithMaxResults(2))
	for _, res := range results.Results {
		fmt.Println(res.Key, res.Value, res.EditDistance)
	}
	// Output (top 2 least edited results):
	// [the] 1 1
	// [the green tree] 4 1
}
87 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/shivamMg/trie
2 |
3 | go 1.18
4 |
5 | require (
6 | github.com/shivamMg/ppds v0.0.1
7 | github.com/stretchr/testify v1.7.1
8 | )
9 |
10 | require (
11 | github.com/davecgh/go-spew v1.1.0 // indirect
12 | github.com/pmezard/go-difflib v1.0.0 // indirect
13 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
14 | )
15 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
5 | github.com/shivamMg/ppds v0.0.1 h1:idK2dpaen652zOO+OmcwmyoPNncBNqfHjF/14eS5JIk=
6 | github.com/shivamMg/ppds v0.0.1/go.mod h1:hb39VqUO6qfkb9zBBQPTIV1vWBtI7yQsG0wr3pN78fM=
7 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
8 | github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
9 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
10 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
11 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
12 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
13 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
14 |
--------------------------------------------------------------------------------
/go.work:
--------------------------------------------------------------------------------
1 | go 1.18
2 |
3 | use ./demo/wasm
4 |
--------------------------------------------------------------------------------
/go.work.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
2 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
3 | github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
4 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
5 |
--------------------------------------------------------------------------------
/heap.go:
--------------------------------------------------------------------------------
1 | package trie
2 |
3 | type searchResultMaxHeap []*SearchResult
4 |
5 | func (s searchResultMaxHeap) Len() int {
6 | return len(s)
7 | }
8 |
9 | func (s searchResultMaxHeap) Less(i, j int) bool {
10 | if s[i].EditDistance == s[j].EditDistance {
11 | return s[i].tiebreaker > s[j].tiebreaker
12 | }
13 | return s[i].EditDistance > s[j].EditDistance
14 | }
15 |
16 | func (s searchResultMaxHeap) Swap(i, j int) {
17 | s[i], s[j] = s[j], s[i]
18 | }
19 |
20 | func (s *searchResultMaxHeap) Push(x interface{}) {
21 | *s = append(*s, x.(*SearchResult))
22 | }
23 |
24 | func (s *searchResultMaxHeap) Pop() interface{} {
25 | old := *s
26 | n := len(old)
27 | x := old[n-1]
28 | *s = old[0 : n-1]
29 | return x
30 | }
31 |
--------------------------------------------------------------------------------
/search.go:
--------------------------------------------------------------------------------
1 | package trie
2 |
3 | import (
4 | "container/heap"
5 | "errors"
6 | "math"
7 | )
8 |
// EditOpType identifies the kind of an EditOp.
type EditOpType int

// The edit operation kinds. Their numeric values are fixed (0-3).
const (
	// EditOpTypeNoEdit keeps a key part unchanged.
	EditOpTypeNoEdit EditOpType = 0
	// EditOpTypeInsert inserts a key part.
	EditOpTypeInsert EditOpType = 1
	// EditOpTypeDelete deletes a key part.
	EditOpTypeDelete EditOpType = 2
	// EditOpTypeReplace replaces a key part with another one.
	EditOpTypeReplace EditOpType = 3
)
17 |
18 | // EditOp represents an Edit Operation.
19 | type EditOp struct {
20 | Type EditOpType
21 | // KeyPart:
22 | // - In case of NoEdit, KeyPart is to be retained.
23 | // - In case of Insert, KeyPart is to be inserted in the key.
24 | // - In case of Delete/Replace, KeyPart is the part of the key on which delete/replace is performed.
25 | KeyPart string
26 | // ReplaceWith is set for Type=EditOpTypeReplace
27 | ReplaceWith string
28 | }
29 |
30 | type SearchResults struct {
31 | Results []*SearchResult
32 | heap *searchResultMaxHeap
33 | tiebreakerCount int
34 | }
35 |
36 | type SearchResult struct {
37 | // Key is the key that was Put() into the Trie.
38 | Key []string
39 | // Value is the value that was Put() into the Trie.
40 | Value interface{}
41 | // EditDistance is the number of edits (insert/delete/replace) needed to convert Key into the Search()-ed key.
42 | EditDistance int
43 | // EditOps is the list of edit operations (see EditOpType) needed to convert Key into the Search()-ed key.
44 | EditOps []*EditOp
45 |
46 | tiebreaker int
47 | }
48 |
// SearchOptions holds the settings collected from the option functions passed
// to Search(): WithExactKey, WithMaxResults, WithMaxEditDistance, WithEditOps
// and WithTopKLeastEdited. Its fields are unexported; set them through those
// functions.
type SearchOptions struct {
	exactKey        bool // set by WithExactKey
	maxResults      bool // set by WithMaxResults; enables maxResultsCount
	maxResultsCount int
	editDistance    bool // set by WithMaxEditDistance; enables maxEditDistance
	maxEditDistance int
	editOps         bool // set by WithEditOps
	topKLeastEdited bool // set by WithTopKLeastEdited
}
63 |
64 | // WithExactKey can be passed to Search(). When passed, Search() returns just the result with
65 | // Key=Search()-ed key. If the key does not exist, result list will be empty.
66 | func WithExactKey() func(*SearchOptions) {
67 | return func(so *SearchOptions) {
68 | so.exactKey = true
69 | }
70 | }
71 |
72 | // WithMaxResults can be passed to Search(). When passed, Search() will return at most maxResults
73 | // number of results.
74 | func WithMaxResults(maxResults int) func(*SearchOptions) {
75 | if maxResults <= 0 {
76 | panic(errors.New("invalid usage: maxResults must be greater than zero"))
77 | }
78 | return func(so *SearchOptions) {
79 | so.maxResults = true
80 | so.maxResultsCount = maxResults
81 | }
82 | }
83 |
84 | // WithMaxEditDistance can be passed to Search(). When passed, Search() changes its default behaviour from
85 | // Prefix search to Edit distance search. It can be used to return "Approximate" results instead of strict
86 | // Prefix search results.
87 | //
88 | // maxDistance is the maximum number of edits allowed on Trie keys to consider them as a SearchResult.
89 | // Higher the maxDistance, more lenient and slower the search becomes.
90 | //
91 | // e.g. If a Trie stores English words, then searching for "wheat" with maxDistance=1 might return similar
92 | // looking words like "wheat", "cheat", "heat", "what", etc. With maxDistance=2 it might also return words like
93 | // "beat", "ahead", etc.
94 | //
95 | // Read about Edit distance: https://en.wikipedia.org/wiki/Edit_distance
96 | func WithMaxEditDistance(maxDistance int) func(*SearchOptions) {
97 | if maxDistance <= 0 {
98 | panic(errors.New("invalid usage: maxDistance must be greater than zero"))
99 | }
100 | return func(so *SearchOptions) {
101 | so.editDistance = true
102 | so.maxEditDistance = maxDistance
103 | }
104 | }
105 |
106 | // WithEditOps can be passed to Search() alongside WithMaxEditDistance(). When passed, Search() also returns EditOps
107 | // for each SearchResult. EditOps can be used to determine the minimum number of edit operations needed to convert
108 | // a result Key into the Search()-ed key.
109 | //
110 | // e.g. Searching for "wheat" in a Trie that stores English words might return "beat". EditOps for this result might be:
111 | // 1. insert "w" 2. replace "b" with "h".
112 | //
113 | // There might be multiple ways to edit a key into another. EditOps represents only one.
114 | //
115 | // Computing EditOps makes Search() slower.
116 | func WithEditOps() func(*SearchOptions) {
117 | return func(so *SearchOptions) {
118 | so.editOps = true
119 | }
120 | }
121 |
122 | // WithTopKLeastEdited can be passed to Search() alongside WithMaxEditDistance() and WithMaxResults(). When passed,
123 | // Search() returns maxResults number of results that have the lowest EditDistances. Results are sorted on EditDistance
124 | // (lowest to highest).
125 | //
126 | // e.g. In a Trie that stores English words searching for "wheat" might return "wheat" (EditDistance=0), "cheat" (EditDistance=1),
127 | // "beat" (EditDistance=2) - in that order.
128 | func WithTopKLeastEdited() func(*SearchOptions) {
129 | return func(so *SearchOptions) {
130 | so.topKLeastEdited = true
131 | }
132 | }
133 |
134 | // Search() takes a key and some options to return results (see SearchResult) from the Trie.
135 | // Without any options, it does a Prefix search i.e. result Keys have the same prefix as key.
136 | // Order of the results is deterministic and will follow the order in which Put() was called for the keys.
137 | // See "With*" functions for options accepted by Search().
138 | func (t *Trie) Search(key []string, options ...func(*SearchOptions)) *SearchResults {
139 | opts := &SearchOptions{}
140 | for _, f := range options {
141 | f(opts)
142 | }
143 | if opts.editOps && !opts.editDistance {
144 | panic(errors.New("invalid usage: WithEditOps() must not be passed without WithMaxEditDistance()"))
145 | }
146 | if opts.topKLeastEdited && !opts.editDistance {
147 | panic(errors.New("invalid usage: WithTopKLeastEdited() must not be passed without WithMaxEditDistance()"))
148 | }
149 | if opts.exactKey && opts.editDistance {
150 | panic(errors.New("invalid usage: WithExactKey() must not be passed with WithMaxEditDistance()"))
151 | }
152 | if opts.exactKey && opts.maxResults {
153 | panic(errors.New("invalid usage: WithExactKey() must not be passed with WithMaxResults()"))
154 | }
155 | if opts.topKLeastEdited && !opts.maxResults {
156 | panic(errors.New("invalid usage: WithTopKLeastEdited() must not be passed without WithMaxResults()"))
157 | }
158 |
159 | if opts.editDistance {
160 | return t.searchWithEditDistance(key, opts)
161 | }
162 | return t.search(key, opts)
163 | }
164 |
// searchWithEditDistance returns the terminal keys whose Levenshtein distance
// to key, counted in key parts, is at most opts.maxEditDistance.
//
// It seeds the DP matrix with the row [0, 1, ..., len(key)] and walks the trie
// depth-first through buildWithEditDistance, which pushes one row per visited
// node onto rows. The root child whose key part equals key[0] is walked first.
// With opts.topKLeastEdited, candidates are kept in a max-heap capped at
// opts.maxResultsCount and emitted sorted by increasing EditDistance (ties in
// discovery order); otherwise results are appended as they are found.
func (t *Trie) searchWithEditDistance(key []string, opts *SearchOptions) *SearchResults {
	// https://en.wikipedia.org/wiki/Levenshtein_distance#Iterative_with_full_matrix
	// http://stevehanov.ca/blog/?id=114
	columns := len(key) + 1
	newRow := make([]int, columns)
	for i := 0; i < columns; i++ {
		newRow[i] = i
	}
	// m is only a capacity hint for the rows/keyColumn stacks (kept >= 1);
	// both stacks grow with trie depth via append.
	m := len(key)
	if m == 0 {
		m = 1
	}
	rows := make([][]int, 1, m)
	rows[0] = newRow
	results := &SearchResults{}
	if opts.topKLeastEdited {
		results.heap = &searchResultMaxHeap{}
	}

	// keyColumn is the key of the node being visited; keyColumn[0] is
	// overwritten with each root child's key part before that child is walked.
	keyColumn := make([]string, 1, m)
	stop := false
	// prioritize the root child whose keyPart equals key[0]. this results in better results
	// e.g. if key=national, build with Node(keyPart=n) first so that keys like notional, nation, nationally, etc. are prioritized
	// same logic is used inside the recursive buildWithEditDistance() method
	var prioritizedNode *Node
	if len(key) > 0 {
		if prioritizedNode = t.root.children[key[0]]; prioritizedNode != nil {
			keyColumn[0] = prioritizedNode.keyPart
			t.buildWithEditDistance(&stop, results, prioritizedNode, &keyColumn, &rows, key, opts)
		}
	}
	// Remaining root children in list order, skipping the one already walked.
	for dllNode := t.root.childrenDLL.head; dllNode != nil; dllNode = dllNode.next {
		node := dllNode.trieNode
		if node == prioritizedNode {
			continue
		}
		keyColumn[0] = node.keyPart
		t.buildWithEditDistance(&stop, results, node, &keyColumn, &rows, key, opts)
	}
	if opts.topKLeastEdited {
		// Pop from the max-heap (worst first) and fill Results from the back,
		// so Results ends up sorted by increasing EditDistance, ties by tiebreaker.
		n := results.heap.Len()
		results.Results = make([]*SearchResult, n)
		for n != 0 {
			result := heap.Pop(results.heap).(*SearchResult)
			result.tiebreaker = 0
			results.Results[n-1] = result
			n--
		}
		// Clear the top-K scratch state before returning.
		results.heap = nil
		results.tiebreakerCount = 0
	}
	return results
}
218 |
219 | func (t *Trie) buildWithEditDistance(stop *bool, results *SearchResults, node *Node, keyColumn *[]string, rows *[][]int, key []string, opts *SearchOptions) {
220 | if *stop {
221 | return
222 | }
223 | prevRow := (*rows)[len(*rows)-1]
224 | columns := len(key) + 1
225 | newRow := make([]int, columns)
226 | newRow[0] = prevRow[0] + 1
227 | for i := 1; i < columns; i++ {
228 | replaceCost := 1
229 | if key[i-1] == (*keyColumn)[len(*keyColumn)-1] {
230 | replaceCost = 0
231 | }
232 | newRow[i] = min(
233 | newRow[i-1]+1, // insertion
234 | prevRow[i]+1, // deletion
235 | prevRow[i-1]+replaceCost, // substitution
236 | )
237 | }
238 | *rows = append(*rows, newRow)
239 |
240 | if newRow[columns-1] <= opts.maxEditDistance && node.isTerminal {
241 | editDistance := newRow[columns-1]
242 | lazyCreate := func() *SearchResult { // optimization for the case where topKLeastEdited=true and the result should not be pushed to heap
243 | resultKey := make([]string, len(*keyColumn))
244 | copy(resultKey, *keyColumn)
245 | result := &SearchResult{Key: resultKey, Value: node.value, EditDistance: editDistance}
246 | if opts.editOps {
247 | result.EditOps = t.getEditOps(rows, keyColumn, key)
248 | }
249 | return result
250 | }
251 | if opts.topKLeastEdited {
252 | results.tiebreakerCount++
253 | if results.heap.Len() < opts.maxResultsCount {
254 | result := lazyCreate()
255 | result.tiebreaker = results.tiebreakerCount
256 | heap.Push(results.heap, result)
257 | } else if (*results.heap)[0].EditDistance > editDistance {
258 | result := lazyCreate()
259 | result.tiebreaker = results.tiebreakerCount
260 | heap.Pop(results.heap)
261 | heap.Push(results.heap, result)
262 | }
263 | } else {
264 | result := lazyCreate()
265 | results.Results = append(results.Results, result)
266 | if opts.maxResults && len(results.Results) == opts.maxResultsCount {
267 | *stop = true
268 | return
269 | }
270 | }
271 | }
272 |
273 | if min(newRow...) <= opts.maxEditDistance {
274 | var prioritizedNode *Node
275 | m := len(*keyColumn)
276 | if m < len(key) {
277 | if prioritizedNode = node.children[key[m]]; prioritizedNode != nil {
278 | *keyColumn = append(*keyColumn, prioritizedNode.keyPart)
279 | t.buildWithEditDistance(stop, results, prioritizedNode, keyColumn, rows, key, opts)
280 | *keyColumn = (*keyColumn)[:len(*keyColumn)-1]
281 | }
282 | }
283 | for dllNode := node.childrenDLL.head; dllNode != nil; dllNode = dllNode.next {
284 | child := dllNode.trieNode
285 | if child == prioritizedNode {
286 | continue
287 | }
288 | *keyColumn = append(*keyColumn, child.keyPart)
289 | t.buildWithEditDistance(stop, results, child, keyColumn, rows, key, opts)
290 | *keyColumn = (*keyColumn)[:len(*keyColumn)-1]
291 | }
292 | }
293 |
294 | *rows = (*rows)[:len(*rows)-1]
295 | }
296 |
297 | func (t *Trie) getEditOps(rows *[][]int, keyColumn *[]string, key []string) []*EditOp {
298 | // https://gist.github.com/jlherren/d97839b1276b9bd7faa930f74711a4b6
299 | ops := make([]*EditOp, 0, len(key))
300 | r, c := len(*rows)-1, len((*rows)[0])-1
301 | for r > 0 || c > 0 {
302 | insertionCost, deletionCost, substitutionCost := math.MaxInt, math.MaxInt, math.MaxInt
303 | if c > 0 {
304 | insertionCost = (*rows)[r][c-1]
305 | }
306 | if r > 0 {
307 | deletionCost = (*rows)[r-1][c]
308 | }
309 | if r > 0 && c > 0 {
310 | substitutionCost = (*rows)[r-1][c-1]
311 | }
312 | minCost := min(insertionCost, deletionCost, substitutionCost)
313 | if minCost == substitutionCost {
314 | if (*rows)[r][c] > (*rows)[r-1][c-1] {
315 | ops = append(ops, &EditOp{Type: EditOpTypeReplace, KeyPart: (*keyColumn)[r-1], ReplaceWith: key[c-1]})
316 | } else {
317 | ops = append(ops, &EditOp{Type: EditOpTypeNoEdit, KeyPart: (*keyColumn)[r-1]})
318 | }
319 | r -= 1
320 | c -= 1
321 | } else if minCost == deletionCost {
322 | ops = append(ops, &EditOp{Type: EditOpTypeDelete, KeyPart: (*keyColumn)[r-1]})
323 | r -= 1
324 | } else if minCost == insertionCost {
325 | ops = append(ops, &EditOp{Type: EditOpTypeInsert, KeyPart: key[c-1]})
326 | c -= 1
327 | }
328 | }
329 | for i, j := 0, len(ops)-1; i < j; i, j = i+1, j-1 {
330 | ops[i], ops[j] = ops[j], ops[i]
331 | }
332 | return ops
333 | }
334 |
335 | func (t *Trie) search(prefixKey []string, opts *SearchOptions) *SearchResults {
336 | results := &SearchResults{}
337 | node := t.root
338 | for _, keyPart := range prefixKey {
339 | child, ok := node.children[keyPart]
340 | if !ok {
341 | return results
342 | }
343 | node = child
344 | }
345 | if opts.exactKey {
346 | if node.isTerminal {
347 | result := &SearchResult{Key: prefixKey, Value: node.value}
348 | results.Results = append(results.Results, result)
349 | }
350 | return results
351 | }
352 | t.build(results, node, &prefixKey, opts)
353 | return results
354 | }
355 |
356 | func (t *Trie) build(results *SearchResults, node *Node, prefixKey *[]string, opts *SearchOptions) (stop bool) {
357 | if node.isTerminal {
358 | key := make([]string, len(*prefixKey))
359 | copy(key, *prefixKey)
360 | result := &SearchResult{Key: key, Value: node.value}
361 | results.Results = append(results.Results, result)
362 | if opts.maxResults && len(results.Results) == opts.maxResultsCount {
363 | return true
364 | }
365 | }
366 |
367 | for dllNode := node.childrenDLL.head; dllNode != nil; dllNode = dllNode.next {
368 | child := dllNode.trieNode
369 | *prefixKey = append(*prefixKey, child.keyPart)
370 | stop := t.build(results, child, prefixKey, opts)
371 | *prefixKey = (*prefixKey)[:len(*prefixKey)-1]
372 | if stop {
373 | return true
374 | }
375 | }
376 | return false
377 | }
378 |
// min returns the smallest of its arguments and panics when called with none.
// It is defined here because the built-in min needs Go 1.21 while go.mod
// targets Go 1.18.
func min(s ...int) int {
	smallest := s[0]
	for i := 1; i < len(s); i++ {
		if s[i] < smallest {
			smallest = s[i]
		}
	}
	return smallest
}
388 |
--------------------------------------------------------------------------------
/search_test.go:
--------------------------------------------------------------------------------
1 | package trie_test
2 |
3 | import (
4 | "bufio"
5 | "io"
6 | "os"
7 | "strings"
8 | "testing"
9 |
10 | "github.com/shivamMg/trie"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | var (
15 | benchmarkResults *trie.SearchResults // https://dave.cheney.net/2013/06/30/how-to-write-benchmarks-in-go
16 | wordsTrie *trie.Trie
17 | )
18 |
19 | func TestTrie_Search(t *testing.T) {
20 | tri := trie.New()
21 | tri.Put([]string{"the"}, 1)
22 | tri.Put([]string{"the", "quick", "brown", "fox"}, 2)
23 | tri.Put([]string{"the", "quick", "swimmer"}, 3)
24 | tri.Put([]string{"the", "green", "tree"}, 4)
25 | tri.Put([]string{"an", "apple", "tree"}, 5)
26 | tri.Put([]string{"an", "umbrella"}, 6)
27 |
28 | testCases := []struct {
29 | name string
30 | inputKey []string
31 | inputOptions []func(*trie.SearchOptions)
32 | expectedResults *trie.SearchResults
33 | }{
34 | {
35 | name: "prefix-one-word",
36 | inputKey: []string{"the"},
37 | expectedResults: &trie.SearchResults{
38 | Results: []*trie.SearchResult{
39 | {Key: []string{"the"}, Value: 1},
40 | {Key: []string{"the", "quick", "brown", "fox"}, Value: 2},
41 | {Key: []string{"the", "quick", "swimmer"}, Value: 3},
42 | {Key: []string{"the", "green", "tree"}, Value: 4},
43 | },
44 | },
45 | },
46 | {
47 | name: "prefix-one-word-with-max-three-results",
48 | inputKey: []string{"the"},
49 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxResults(3)},
50 | expectedResults: &trie.SearchResults{
51 | Results: []*trie.SearchResult{
52 | {Key: []string{"the"}, Value: 1},
53 | {Key: []string{"the", "quick", "brown", "fox"}, Value: 2},
54 | {Key: []string{"the", "quick", "swimmer"}, Value: 3},
55 | },
56 | },
57 | },
58 | {
59 | name: "prefix-multiple-words",
60 | inputKey: []string{"the", "quick"},
61 | expectedResults: &trie.SearchResults{
62 | Results: []*trie.SearchResult{
63 | {Key: []string{"the", "quick", "brown", "fox"}, Value: 2},
64 | {Key: []string{"the", "quick", "swimmer"}, Value: 3},
65 | },
66 | },
67 | },
68 | {
69 | name: "prefix-non-existing",
70 | inputKey: []string{"non-existing"},
71 | expectedResults: &trie.SearchResults{},
72 | },
73 | {
74 | name: "prefix-empty",
75 | inputKey: []string{},
76 | expectedResults: &trie.SearchResults{
77 | Results: []*trie.SearchResult{
78 | {Key: []string{"the"}, Value: 1},
79 | {Key: []string{"the", "quick", "brown", "fox"}, Value: 2},
80 | {Key: []string{"the", "quick", "swimmer"}, Value: 3},
81 | {Key: []string{"the", "green", "tree"}, Value: 4},
82 | {Key: []string{"an", "apple", "tree"}, Value: 5},
83 | {Key: []string{"an", "umbrella"}, Value: 6},
84 | },
85 | },
86 | },
87 | {
88 | name: "prefix-nil",
89 | inputKey: nil,
90 | expectedResults: &trie.SearchResults{
91 | Results: []*trie.SearchResult{
92 | {Key: []string{"the"}, Value: 1},
93 | {Key: []string{"the", "quick", "brown", "fox"}, Value: 2},
94 | {Key: []string{"the", "quick", "swimmer"}, Value: 3},
95 | {Key: []string{"the", "green", "tree"}, Value: 4},
96 | {Key: []string{"an", "apple", "tree"}, Value: 5},
97 | {Key: []string{"an", "umbrella"}, Value: 6},
98 | },
99 | },
100 | },
101 | {
102 | name: "prefix-one-word-with-exact-key",
103 | inputKey: []string{"the"},
104 | inputOptions: []func(*trie.SearchOptions){trie.WithExactKey()},
105 | expectedResults: &trie.SearchResults{
106 | Results: []*trie.SearchResult{
107 | {Key: []string{"the"}, Value: 1},
108 | },
109 | },
110 | },
111 | {
112 | name: "prefix-multiple-words-with-exact-key",
113 | inputKey: []string{"the", "quick", "swimmer"},
114 | inputOptions: []func(*trie.SearchOptions){trie.WithExactKey()},
115 | expectedResults: &trie.SearchResults{
116 | Results: []*trie.SearchResult{
117 | {Key: []string{"the", "quick", "swimmer"}, Value: 3},
118 | },
119 | },
120 | },
121 | {
122 | name: "edit-distance-one-edit",
123 | inputKey: []string{"the", "tree"},
124 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(1)},
125 | expectedResults: &trie.SearchResults{
126 | Results: []*trie.SearchResult{
127 | {Key: []string{"the"}, Value: 1, EditDistance: 1},
128 | {Key: []string{"the", "green", "tree"}, Value: 4, EditDistance: 1},
129 | },
130 | },
131 | },
132 | {
133 | name: "edit-distance-one-edit-with-edit-opts",
134 | inputKey: []string{"the", "tree"},
135 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(1), trie.WithEditOps()},
136 | expectedResults: &trie.SearchResults{
137 | Results: []*trie.SearchResult{
138 | {Key: []string{"the"}, Value: 1, EditDistance: 1, EditOps: []*trie.EditOp{
139 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"},
140 | {Type: trie.EditOpTypeInsert, KeyPart: "tree"},
141 | }},
142 | {Key: []string{"the", "green", "tree"}, Value: 4, EditDistance: 1, EditOps: []*trie.EditOp{
143 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"},
144 | {Type: trie.EditOpTypeDelete, KeyPart: "green"},
145 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"},
146 | }},
147 | },
148 | },
149 | },
150 | {
151 | name: "edit-distance-two-edits-with-edit-opts",
152 | inputKey: []string{"the", "tree"},
153 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(2), trie.WithEditOps()},
154 | expectedResults: &trie.SearchResults{
155 | Results: []*trie.SearchResult{
156 | {Key: []string{"the"}, Value: 1, EditDistance: 1, EditOps: []*trie.EditOp{
157 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"},
158 | {Type: trie.EditOpTypeInsert, KeyPart: "tree"},
159 | }},
160 | {Key: []string{"the", "quick", "swimmer"}, Value: 3, EditDistance: 2, EditOps: []*trie.EditOp{
161 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"},
162 | {Type: trie.EditOpTypeDelete, KeyPart: "quick"},
163 | {Type: trie.EditOpTypeReplace, KeyPart: "swimmer", ReplaceWith: "tree"},
164 | }},
165 | {Key: []string{"the", "green", "tree"}, Value: 4, EditDistance: 1, EditOps: []*trie.EditOp{
166 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"},
167 | {Type: trie.EditOpTypeDelete, KeyPart: "green"},
168 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"},
169 | }},
170 | {Key: []string{"an", "apple", "tree"}, Value: 5, EditDistance: 2, EditOps: []*trie.EditOp{
171 | {Type: trie.EditOpTypeDelete, KeyPart: "an"},
172 | {Type: trie.EditOpTypeReplace, KeyPart: "apple", ReplaceWith: "the"},
173 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"},
174 | }},
175 | {Key: []string{"an", "umbrella"}, Value: 6, EditDistance: 2, EditOps: []*trie.EditOp{
176 | {Type: trie.EditOpTypeReplace, KeyPart: "an", ReplaceWith: "the"},
177 | {Type: trie.EditOpTypeReplace, KeyPart: "umbrella", ReplaceWith: "tree"},
178 | }},
179 | },
180 | },
181 | },
182 | {
183 | name: "edit-distance-two-edits-with-edit-opts-with-max-four-results",
184 | inputKey: []string{"the", "tree"},
185 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(2), trie.WithEditOps(), trie.WithMaxResults(4)},
186 | expectedResults: &trie.SearchResults{
187 | Results: []*trie.SearchResult{
188 | {Key: []string{"the"}, Value: 1, EditDistance: 1, EditOps: []*trie.EditOp{
189 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"},
190 | {Type: trie.EditOpTypeInsert, KeyPart: "tree"},
191 | }},
192 | {Key: []string{"the", "quick", "swimmer"}, Value: 3, EditDistance: 2, EditOps: []*trie.EditOp{
193 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"},
194 | {Type: trie.EditOpTypeDelete, KeyPart: "quick"},
195 | {Type: trie.EditOpTypeReplace, KeyPart: "swimmer", ReplaceWith: "tree"},
196 | }},
197 | {Key: []string{"the", "green", "tree"}, Value: 4, EditDistance: 1, EditOps: []*trie.EditOp{
198 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"},
199 | {Type: trie.EditOpTypeDelete, KeyPart: "green"},
200 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"},
201 | }},
202 | {Key: []string{"an", "apple", "tree"}, Value: 5, EditDistance: 2, EditOps: []*trie.EditOp{
203 | {Type: trie.EditOpTypeDelete, KeyPart: "an"},
204 | {Type: trie.EditOpTypeReplace, KeyPart: "apple", ReplaceWith: "the"},
205 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"},
206 | }},
207 | },
208 | },
209 | },
210 | {
211 | name: "edit-distance-one-edit-with-topk",
212 | inputKey: []string{"the", "tree"},
213 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(1), trie.WithTopKLeastEdited(), trie.WithMaxResults(1)},
214 | expectedResults: &trie.SearchResults{
215 | Results: []*trie.SearchResult{
216 | {Key: []string{"the"}, Value: 1, EditDistance: 1},
217 | },
218 | },
219 | },
220 | {
221 | name: "edit-distance-two-edits-with-edit-opts-with-topk",
222 | inputKey: []string{"the", "tree"},
223 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(2), trie.WithEditOps(), trie.WithTopKLeastEdited(), trie.WithMaxResults(4)},
224 | expectedResults: &trie.SearchResults{
225 | Results: []*trie.SearchResult{
226 | {Key: []string{"the"}, Value: 1, EditDistance: 1, EditOps: []*trie.EditOp{
227 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"},
228 | {Type: trie.EditOpTypeInsert, KeyPart: "tree"},
229 | }},
230 | {Key: []string{"the", "green", "tree"}, Value: 4, EditDistance: 1, EditOps: []*trie.EditOp{
231 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"},
232 | {Type: trie.EditOpTypeDelete, KeyPart: "green"},
233 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"},
234 | }},
235 | {Key: []string{"the", "quick", "swimmer"}, Value: 3, EditDistance: 2, EditOps: []*trie.EditOp{
236 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"},
237 | {Type: trie.EditOpTypeDelete, KeyPart: "quick"},
238 | {Type: trie.EditOpTypeReplace, KeyPart: "swimmer", ReplaceWith: "tree"},
239 | }},
240 | {Key: []string{"an", "apple", "tree"}, Value: 5, EditDistance: 2, EditOps: []*trie.EditOp{
241 | {Type: trie.EditOpTypeDelete, KeyPart: "an"},
242 | {Type: trie.EditOpTypeReplace, KeyPart: "apple", ReplaceWith: "the"},
243 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"},
244 | }},
245 | },
246 | },
247 | },
248 | {
249 | name: "edit-distance-two-edits-with-two-topk",
250 | inputKey: []string{"the", "tree"},
251 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(2), trie.WithTopKLeastEdited(), trie.WithMaxResults(2)},
252 | expectedResults: &trie.SearchResults{
253 | Results: []*trie.SearchResult{
254 | {Key: []string{"the"}, Value: 1, EditDistance: 1},
255 | {Key: []string{"the", "green", "tree"}, Value: 4, EditDistance: 1},
256 | },
257 | },
258 | },
259 | {
260 | name: "edit-distance-empty",
261 | inputKey: []string{},
262 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(2), trie.WithTopKLeastEdited(), trie.WithMaxResults(5)},
263 | expectedResults: &trie.SearchResults{
264 | Results: []*trie.SearchResult{
265 | {Key: []string{"the"}, Value: 1, EditDistance: 1},
266 | {Key: []string{"an", "umbrella"}, Value: 6, EditDistance: 2},
267 | },
268 | },
269 | },
270 | {
271 | name: "edit-distance-nil",
272 | inputKey: nil,
273 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(2), trie.WithTopKLeastEdited(), trie.WithMaxResults(5)},
274 | expectedResults: &trie.SearchResults{
275 | Results: []*trie.SearchResult{
276 | {Key: []string{"the"}, Value: 1, EditDistance: 1},
277 | {Key: []string{"an", "umbrella"}, Value: 6, EditDistance: 2},
278 | },
279 | },
280 | },
281 | }
282 |
283 | for _, tc := range testCases {
284 | t.Run(tc.name, func(t *testing.T) {
285 | actual := tri.Search(tc.inputKey, tc.inputOptions...)
286 | assert.Len(t, actual.Results, len(tc.expectedResults.Results))
287 | assert.Equal(t, tc.expectedResults, actual)
288 | })
289 | }
290 | }
291 |
292 | func TestTrie_Search_WordsTrie(t *testing.T) {
293 | tri := getWordsTrie()
294 | testCases := []struct {
295 | name string
296 | inputKey []string
297 | inputOptions []func(*trie.SearchOptions)
298 | expectedResults *trie.SearchResults
299 | }{
300 | {
301 | name: "prefix",
302 | inputKey: strings.Split("aband", ""),
303 | expectedResults: &trie.SearchResults{
304 | Results: []*trie.SearchResult{
305 | {Key: strings.Split("abandon", "")},
306 | {Key: strings.Split("abandoned", "")},
307 | {Key: strings.Split("abandoning", "")},
308 | {Key: strings.Split("abandonment", "")},
309 | },
310 | },
311 | },
312 | {
313 | name: "edit-distance",
314 | inputKey: strings.Split("wheat", ""),
315 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(1)},
316 | expectedResults: &trie.SearchResults{
317 | Results: []*trie.SearchResult{
318 | {Key: strings.Split("wheat", ""), EditDistance: 0},
319 | {Key: strings.Split("wheal", ""), EditDistance: 1},
320 | {Key: strings.Split("whet", ""), EditDistance: 1},
321 | {Key: strings.Split("what", ""), EditDistance: 1},
322 | {Key: strings.Split("cheat", ""), EditDistance: 1},
323 | {Key: strings.Split("heat", ""), EditDistance: 1},
324 | },
325 | },
326 | },
327 | {
328 | name: "edit-distance-with-max-results",
329 | inputKey: strings.Split("national", ""),
330 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(3), trie.WithMaxResults(13)},
331 | expectedResults: &trie.SearchResults{
332 | Results: []*trie.SearchResult{
333 | {Key: strings.Split("nation", ""), EditDistance: 2},
334 | {Key: strings.Split("national", ""), EditDistance: 0},
335 | {Key: strings.Split("nationalism", ""), EditDistance: 3},
336 | {Key: strings.Split("nationalist", ""), EditDistance: 3},
337 | {Key: strings.Split("nationality", ""), EditDistance: 3},
338 | {Key: strings.Split("nationalize", ""), EditDistance: 3},
339 | {Key: strings.Split("nationally", ""), EditDistance: 2},
340 | {Key: strings.Split("natal", ""), EditDistance: 3},
341 | {Key: strings.Split("natural", ""), EditDistance: 3},
342 | {Key: strings.Split("nautical", ""), EditDistance: 3},
343 | {Key: strings.Split("notion", ""), EditDistance: 3},
344 | {Key: strings.Split("notional", ""), EditDistance: 1},
345 | {Key: strings.Split("notionally", ""), EditDistance: 3},
346 | },
347 | },
348 | },
349 | {
350 | name: "edit-distance-with-topk",
351 | inputKey: strings.Split("national", ""),
352 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(3), trie.WithMaxResults(13), trie.WithTopKLeastEdited()},
353 | expectedResults: &trie.SearchResults{
354 | Results: []*trie.SearchResult{
355 | {Key: strings.Split("national", ""), EditDistance: 0},
356 | {Key: strings.Split("notional", ""), EditDistance: 1},
357 | {Key: strings.Split("rational", ""), EditDistance: 1},
358 | {Key: strings.Split("nation", ""), EditDistance: 2},
359 | {Key: strings.Split("nationally", ""), EditDistance: 2},
360 | {Key: strings.Split("atonal", ""), EditDistance: 2},
361 | {Key: strings.Split("factional", ""), EditDistance: 2},
362 | {Key: strings.Split("optional", ""), EditDistance: 2},
363 | {Key: strings.Split("rationale", ""), EditDistance: 2},
364 | {Key: strings.Split("nationalism", ""), EditDistance: 3},
365 | {Key: strings.Split("nationalist", ""), EditDistance: 3},
366 | {Key: strings.Split("nationality", ""), EditDistance: 3},
367 | {Key: strings.Split("nationalize", ""), EditDistance: 3},
368 | },
369 | },
370 | },
371 | {
372 | name: "edit-distance-with-topk-stop-after-prioritized",
373 | inputKey: strings.Split("national", ""),
374 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(3), trie.WithMaxResults(2), trie.WithTopKLeastEdited()},
375 | expectedResults: &trie.SearchResults{
376 | Results: []*trie.SearchResult{
377 | {Key: strings.Split("national", ""), EditDistance: 0},
378 | {Key: strings.Split("notional", ""), EditDistance: 1},
379 | },
380 | },
381 | },
382 | }
383 | for _, tc := range testCases {
384 | t.Run(tc.name, func(t *testing.T) {
385 | actual := tri.Search(tc.inputKey, tc.inputOptions...)
386 | assert.Len(t, actual.Results, len(tc.expectedResults.Results))
387 | assert.Equal(t, tc.expectedResults, actual)
388 | })
389 | }
390 | }
391 |
392 | func TestTrie_Search_InvalidUsage_EditDistance_LessThanZeroDistance(t *testing.T) {
393 | assert.PanicsWithError(t, "invalid usage: maxDistance must be greater than zero", func() {
394 | trie.WithMaxEditDistance(0)
395 | })
396 | assert.PanicsWithError(t, "invalid usage: maxDistance must be greater than zero", func() {
397 | trie.WithMaxEditDistance(-1)
398 | })
399 | }
400 |
401 | func TestTrie_Search_InvalidUsage_MaxResults_LessThanZero(t *testing.T) {
402 | assert.PanicsWithError(t, "invalid usage: maxResults must be greater than zero", func() {
403 | trie.WithMaxResults(0)
404 | })
405 | assert.PanicsWithError(t, "invalid usage: maxResults must be greater than zero", func() {
406 | trie.WithMaxResults(-1)
407 | })
408 | }
409 |
410 | func TestTrie_Search_InvalidUsage_EditOpsWithoutMaxEditDistance(t *testing.T) {
411 | tri := trie.New()
412 |
413 | assert.PanicsWithError(t, "invalid usage: WithEditOps() must not be passed without WithMaxEditDistance()", func() {
414 | tri.Search(nil, trie.WithEditOps())
415 | })
416 | }
417 |
418 | func TestTrie_Search_InvalidUsage_TopKWithoutMaxEditDistance(t *testing.T) {
419 | tri := trie.New()
420 |
421 | assert.PanicsWithError(t, "invalid usage: WithTopKLeastEdited() must not be passed without WithMaxEditDistance()", func() {
422 | tri.Search(nil, trie.WithTopKLeastEdited())
423 | })
424 | }
425 |
426 | func TestTrie_Search_InvalidUsage_ExactKeyWithMaxEditDistance(t *testing.T) {
427 | tri := trie.New()
428 |
429 | assert.PanicsWithError(t, "invalid usage: WithExactKey() must not be passed with WithMaxEditDistance()", func() {
430 | tri.Search(nil, trie.WithExactKey(), trie.WithMaxEditDistance(1))
431 | })
432 | }
433 |
434 | func TestTrie_Search_InvalidUsage_ExactKeyWithMaxResults(t *testing.T) {
435 | tri := trie.New()
436 |
437 | assert.PanicsWithError(t, "invalid usage: WithExactKey() must not be passed with WithMaxResults()", func() {
438 | tri.Search(nil, trie.WithExactKey(), trie.WithMaxResults(1))
439 | })
440 | }
441 |
442 | func TestTrie_Search_InvalidUsage_TopKLeastEditedWithoutMaxResults(t *testing.T) {
443 | tri := trie.New()
444 |
445 | assert.PanicsWithError(t, "invalid usage: WithTopKLeastEdited() must not be passed without WithMaxResults()", func() {
446 | tri.Search(nil, trie.WithMaxEditDistance(1), trie.WithTopKLeastEdited())
447 | })
448 | }
449 |
450 | func BenchmarkTrie_Search_WordsTrie(b *testing.B) {
451 | tri := getWordsTrie()
452 | benchmarks := []struct {
453 | name string
454 | inputKey []string
455 | inputOptions []func(*trie.SearchOptions)
456 | }{
457 | {
458 | name: "prefix",
459 | inputKey: strings.Split("ab", ""),
460 | },
461 | {
462 | name: "prefix-with-max-results",
463 | inputKey: strings.Split("ab", ""),
464 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxResults(20)},
465 | },
466 | {
467 | name: "edit-distance",
468 | inputKey: strings.Split("someday", ""),
469 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(5)},
470 | },
471 | {
472 | name: "edit-distance-with-edit-ops",
473 | inputKey: strings.Split("someday", ""),
474 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(5), trie.WithEditOps()},
475 | },
476 | {
477 | name: "edit-distance-with-edit-ops-with-max-results",
478 | inputKey: strings.Split("someday", ""),
479 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(5), trie.WithEditOps(), trie.WithMaxResults(20)},
480 | },
481 | {
482 | name: "edit-distance-with-edit-ops-with-max-results-with-top-k",
483 | inputKey: strings.Split("someday", ""),
484 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(5), trie.WithEditOps(), trie.WithMaxResults(20),
485 | trie.WithTopKLeastEdited()},
486 | },
487 | }
488 | var results *trie.SearchResults
489 | for _, bm := range benchmarks {
490 | b.Run(bm.name, func(b *testing.B) {
491 | for i := 0; i < b.N; i++ {
492 | results = tri.Search(bm.inputKey, bm.inputOptions...)
493 | }
494 | benchmarkResults = results
495 | })
496 | }
497 | }
498 |
499 | func getWordsTrie() *trie.Trie {
500 | if wordsTrie != nil {
501 | return wordsTrie
502 | }
503 | f, err := os.Open("./demo/wasm/words.txt")
504 | if err != nil {
505 | panic(err)
506 | }
507 | tri := trie.New()
508 | r := bufio.NewReader(f)
509 | for {
510 | word, err := r.ReadString('\n')
511 | if err == io.EOF {
512 | break
513 | }
514 | if err != nil {
515 | panic(err)
516 | }
517 | word = strings.TrimRight(word, "\n")
518 | word = strings.TrimRight(word, "\r") // windows
519 | key := strings.Split(word, "")
520 | tri.Put(key, nil)
521 | }
522 | wordsTrie = tri
523 | return tri
524 | }
525 |
--------------------------------------------------------------------------------
/search_whitebox_test.go:
--------------------------------------------------------------------------------
1 | package trie
2 |
3 | import (
4 | "fmt"
5 | "strings"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestTrie_getEditOps(t *testing.T) {
12 | testCases := []struct {
13 | fromKeyColumn []string
14 | toKey []string
15 | rows [][]int
16 | expectedOps []*EditOp
17 | }{
18 | {
19 | fromKeyColumn: strings.Split("sitting", ""),
20 | toKey: strings.Split("kitten", ""),
21 | rows: [][]int{
22 | {0, 1, 2, 3, 4, 5, 6},
23 | {1, 1, 2, 3, 4, 5, 6},
24 | {2, 2, 1, 2, 3, 4, 5},
25 | {3, 3, 2, 1, 2, 3, 4},
26 | {4, 4, 3, 2, 1, 2, 3},
27 | {5, 5, 4, 3, 2, 2, 3},
28 | {6, 6, 5, 4, 3, 3, 2},
29 | {7, 7, 6, 5, 4, 4, 3},
30 | },
31 | expectedOps: []*EditOp{
32 | {Type: EditOpTypeReplace, KeyPart: "s", ReplaceWith: "k"},
33 | {Type: EditOpTypeNoEdit, KeyPart: "i"},
34 | {Type: EditOpTypeNoEdit, KeyPart: "t"},
35 | {Type: EditOpTypeNoEdit, KeyPart: "t"},
36 | {Type: EditOpTypeReplace, KeyPart: "i", ReplaceWith: "e"},
37 | {Type: EditOpTypeNoEdit, KeyPart: "n"},
38 | {Type: EditOpTypeDelete, KeyPart: "g"},
39 | },
40 | },
41 | {
42 | fromKeyColumn: strings.Split("Sunday", ""),
43 | toKey: strings.Split("Saturday", ""),
44 | rows: [][]int{
45 | {0, 1, 2, 3, 4, 5, 6, 7, 8},
46 | {1, 0, 1, 2, 3, 4, 5, 6, 7},
47 | {2, 1, 1, 2, 2, 3, 4, 5, 6},
48 | {3, 2, 2, 2, 3, 3, 4, 5, 6},
49 | {4, 3, 3, 3, 3, 4, 3, 4, 5},
50 | {5, 4, 3, 4, 4, 4, 4, 3, 4},
51 | {6, 5, 4, 4, 5, 5, 5, 4, 3},
52 | },
53 | expectedOps: []*EditOp{
54 | {Type: EditOpTypeNoEdit, KeyPart: "S"},
55 | {Type: EditOpTypeInsert, KeyPart: "a"},
56 | {Type: EditOpTypeInsert, KeyPart: "t"},
57 | {Type: EditOpTypeNoEdit, KeyPart: "u"},
58 | {Type: EditOpTypeReplace, KeyPart: "n", ReplaceWith: "r"},
59 | {Type: EditOpTypeNoEdit, KeyPart: "d"},
60 | {Type: EditOpTypeNoEdit, KeyPart: "a"},
61 | {Type: EditOpTypeNoEdit, KeyPart: "y"},
62 | },
63 | },
64 | }
65 |
66 | for _, tc := range testCases {
67 | t.Run(fmt.Sprintf("from %v to %v", tc.fromKeyColumn, tc.toKey), func(t *testing.T) {
68 | tri := New()
69 | actual := tri.getEditOps(&tc.rows, &tc.fromKeyColumn, tc.toKey)
70 | assert.Equal(t, tc.expectedOps, actual)
71 | })
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/trie.go:
--------------------------------------------------------------------------------
1 | package trie
2 |
3 | import (
4 | "github.com/shivamMg/ppds/tree"
5 | )
6 |
// RootKeyPart is the key part carried by every Trie's root node. New creates
// the root with it, so Print shows the root under this label.
const RootKeyPart = "^"

// terminalSuffix is appended (after a space) to the label of a node where a
// key ends; see Node.Data.
const terminalSuffix = "($)"
11 |
// Trie is the trie data structure.
//
// Create one with New. In the zero value root is nil, so Put with a
// non-empty key would dereference a nil node.
type Trie struct {
	// root is the node created by New with key part RootKeyPart. Put never
	// marks it terminal (an empty key is a no-op), so it never holds a value.
	root *Node
}
16 |
// Node is a tree node inside Trie.
type Node struct {
	// keyPart is the single element of a key ([]string) that this node stands for.
	keyPart string
	// isTerminal is true when a stored key ends at this node.
	isTerminal bool
	// value is the payload for the key ending here. Put sets it, and Delete
	// resets it to nil.
	value interface{}
	// dllNode is this node's entry in its parent's childrenDLL. It is set by
	// Put when the node is linked under a parent, so it stays nil for the root.
	dllNode *dllNode
	// children indexes the child nodes by their key part.
	children map[string]*Node
	// childrenDLL holds the same children in the order Put appended them. It
	// is used for ordered iteration (childNodes, walk).
	childrenDLL *doublyLinkedList
}
26 |
27 | func newNode(keyPart string) *Node {
28 | return &Node{
29 | keyPart: keyPart,
30 | children: make(map[string]*Node),
31 | childrenDLL: &doublyLinkedList{},
32 | }
33 | }
34 |
35 | // KeyPart returns the part (string) of the key ([]string) that this Node represents.
36 | func (n *Node) KeyPart() string {
37 | return n.keyPart
38 | }
39 |
40 | // IsTerminal returns a boolean that tells whether a key ends at this Node.
41 | func (n *Node) IsTerminal() bool {
42 | return n.isTerminal
43 | }
44 |
45 | // Value returns the value stored for the key ending at this Node. If Node is not a terminal, it returns nil.
46 | func (n *Node) Value() interface{} {
47 | return n.value
48 | }
49 |
50 | // SetValue sets the value for the key ending at this Node. If Node is not a terminal, value is not set.
51 | func (n *Node) SetValue(value interface{}) {
52 | if n.isTerminal {
53 | n.value = value
54 | }
55 | }
56 |
57 | // ChildNodes returns the child-nodes of this Node.
58 | func (n *Node) ChildNodes() []*Node {
59 | return n.childNodes()
60 | }
61 |
62 | // Data is used in Print(). Use Value() to get value at this Node.
63 | func (n *Node) Data() interface{} {
64 | data := n.keyPart
65 | if n.isTerminal {
66 | data += " " + terminalSuffix
67 | }
68 | return data
69 | }
70 |
71 | // Children is used in Print(). Use ChildNodes() to get child-nodes of this Node.
72 | func (n *Node) Children() []tree.Node {
73 | children := n.childNodes()
74 | result := make([]tree.Node, len(children))
75 | for i, child := range children {
76 | result[i] = tree.Node(child)
77 | }
78 | return result
79 | }
80 |
81 | // Print prints the tree rooted at this Node. A Trie's root node is printed as RootKeyPart.
82 | // All the terminal nodes are suffixed with ($).
83 | func (n *Node) Print() {
84 | tree.PrintHrn(n)
85 | }
86 |
87 | func (n *Node) Sprint() string {
88 | return tree.SprintHrn(n)
89 | }
90 |
91 | func (n *Node) childNodes() []*Node {
92 | children := make([]*Node, 0, len(n.children))
93 | dllNode := n.childrenDLL.head
94 | for dllNode != nil {
95 | children = append(children, dllNode.trieNode)
96 | dllNode = dllNode.next
97 | }
98 | return children
99 | }
100 |
101 | // New returns a new instance of Trie.
102 | func New() *Trie {
103 | return &Trie{root: newNode(RootKeyPart)}
104 | }
105 |
106 | // Root returns the root node of the Trie.
107 | func (t *Trie) Root() *Node {
108 | return t.root
109 | }
110 |
111 | // Put upserts value for the given key in the Trie. It returns a boolean depending on
112 | // whether the key already existed or not.
113 | func (t *Trie) Put(key []string, value interface{}) (existed bool) {
114 | node := t.root
115 | for i, part := range key {
116 | child, ok := node.children[part]
117 | if !ok {
118 | child = newNode(part)
119 | child.dllNode = newDLLNode(child)
120 | node.children[part] = child
121 | node.childrenDLL.append(child.dllNode)
122 | }
123 | if i == len(key)-1 {
124 | existed = child.isTerminal
125 | child.isTerminal = true
126 | child.value = value
127 | }
128 | node = child
129 | }
130 | return existed
131 | }
132 |
133 | // Delete deletes key-value for the given key in the Trie. It returns (value, true) if the key existed,
134 | // else (nil, false).
135 | func (t *Trie) Delete(key []string) (value interface{}, existed bool) {
136 | node := t.root
137 | parent := make(map[*Node]*Node)
138 | for _, keyPart := range key {
139 | child, ok := node.children[keyPart]
140 | if !ok {
141 | return nil, false
142 | }
143 | parent[child] = node
144 | node = child
145 | }
146 | if !node.isTerminal {
147 | return nil, false
148 | }
149 | node.isTerminal = false
150 | value = node.value
151 | node.value = nil
152 | for node != nil && !node.isTerminal && len(node.children) == 0 {
153 | delete(parent[node].children, node.keyPart)
154 | parent[node].childrenDLL.pop(node.dllNode)
155 | node = parent[node]
156 | }
157 | return value, true
158 | }
159 |
--------------------------------------------------------------------------------
/trie_test.go:
--------------------------------------------------------------------------------
1 | package trie_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/shivamMg/ppds/tree"
7 | "github.com/shivamMg/trie"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestNode_SetValue(t *testing.T) {
12 | tri := trie.New()
13 | tri.Put([]string{"a", "b"}, 1)
14 | tri.Put([]string{"a", "b", "c"}, 2)
15 |
16 | node := tri.Root()
17 | node.SetValue(10)
18 | assert.Equal(t, nil, node.Value())
19 |
20 | node = tri.Root().ChildNodes()[0]
21 | node.SetValue(10)
22 | assert.Equal(t, nil, node.Value())
23 |
24 | node = tri.Root().ChildNodes()[0].ChildNodes()[0]
25 | node.SetValue(10)
26 | assert.Equal(t, 10, node.Value())
27 | }
28 |
// TestTrie_Put inserts four keys into an empty Trie and checks that each Put
// reports the key as new. It then verifies the rendered tree and every node's
// key part, terminal flag and value. Finally, it re-puts an existing key and
// checks that Put reports existed=true and updates the value on the same node.
func TestTrie_Put(t *testing.T) {
	tri := trie.New()
	existed := tri.Put([]string{"an", "umbrella"}, 2)
	assert.False(t, existed)
	existed = tri.Put([]string{"the"}, 1)
	assert.False(t, existed)
	existed = tri.Put([]string{"the", "swimmer"}, 4)
	assert.False(t, existed)
	existed = tri.Put([]string{"the", "tree"}, 3)
	assert.False(t, existed)

	// validate full tree
	// NOTE(review): whitespace inside this expected rendering may have been
	// collapsed in this copy of the file; confirm against tree.SprintHrn output.
	expected := `^
├─ an
│ └─ umbrella ($)
└─ the ($)
├─ swimmer ($)
└─ tree ($)
`
	actual := tree.SprintHrn(tri.Root())
	assert.Equal(t, expected, actual)

	// validate attributes for each node
	root := tri.Root()
	assert.Equal(t, root.KeyPart(), trie.RootKeyPart)
	assert.False(t, root.IsTerminal())
	assert.Nil(t, root.Value())

	// Children come back in insertion order: "an" was put before "the".
	rootChildren := root.ChildNodes()
	an, the := rootChildren[0], rootChildren[1]
	assert.Equal(t, an.KeyPart(), "an")
	assert.False(t, an.IsTerminal())
	assert.Nil(t, an.Value())

	assert.Equal(t, the.KeyPart(), "the")
	assert.True(t, the.IsTerminal())
	assert.Equal(t, the.Value(), 1)

	umbrella := an.ChildNodes()[0]
	assert.Equal(t, umbrella.KeyPart(), "umbrella")
	assert.True(t, umbrella.IsTerminal())
	assert.Equal(t, umbrella.Value(), 2)

	// "swimmer" was put before "tree". The trailing underscore in tree_ avoids
	// shadowing the imported tree package.
	theChildren := the.ChildNodes()
	swimmer, tree_ := theChildren[0], theChildren[1]
	assert.Equal(t, swimmer.KeyPart(), "swimmer")
	assert.True(t, swimmer.IsTerminal())
	assert.Equal(t, swimmer.Value(), 4)

	assert.Equal(t, tree_.KeyPart(), "tree")
	assert.True(t, tree_.IsTerminal())
	assert.Equal(t, tree_.Value(), 3)

	// validate update: umbrella is the node fetched above, so the new value
	// must be visible through it (Put updates the existing node in place).
	existed = tri.Put([]string{"an", "umbrella"}, 5)
	assert.True(t, existed)
	assert.Equal(t, umbrella.Value(), 5)
}
87 |
// TestTrie_Delete checks Delete's return values and the pruning of the tree
// after each deletion:
//   - removing a leaf key prunes only nodes left without keys or children;
//   - removing a key that still has descendants only clears its terminal flag;
//   - deleting a missing key or a non-terminal prefix returns (nil, false) and
//     leaves the tree unchanged.
func TestTrie_Delete(t *testing.T) {
	tri := trie.New()
	tri.Put([]string{"an", "apple", "tree"}, 5)
	tri.Put([]string{"an", "umbrella"}, 6)
	tri.Put([]string{"the"}, 1)
	tri.Put([]string{"the", "green", "tree"}, 4)
	tri.Put([]string{"the", "quick", "brown", "fox"}, 2)
	tri.Put([]string{"the", "quick", "swimmer"}, 3)

	// "brown" and "fox" are pruned; "quick" survives because of "swimmer".
	// NOTE(review): whitespace inside the expected renderings below may have
	// been collapsed in this copy of the file; confirm against Sprint output.
	value, existed := tri.Delete([]string{"the", "quick", "brown", "fox"})
	assert.True(t, existed)
	assert.Equal(t, value, 2)
	expected := `^
├─ an
│ ├─ apple
│ │ └─ tree ($)
│ └─ umbrella ($)
└─ the ($)
├─ green
│ └─ tree ($)
└─ quick
└─ swimmer ($)
`
	assert.Equal(t, expected, tri.Root().Sprint())

	// "swimmer" was quick's last child, so "quick" is pruned as well.
	value, existed = tri.Delete([]string{"the", "quick", "swimmer"})
	assert.True(t, existed)
	assert.Equal(t, value, 3)
	expected = `^
├─ an
│ ├─ apple
│ │ └─ tree ($)
│ └─ umbrella ($)
└─ the ($)
└─ green
└─ tree ($)
`
	assert.Equal(t, expected, tri.Root().Sprint())

	// "the" still has descendants: only its terminal marker disappears.
	value, existed = tri.Delete([]string{"the"})
	assert.True(t, existed)
	assert.Equal(t, value, 1)
	expected = `^
├─ an
│ ├─ apple
│ │ └─ tree ($)
│ └─ umbrella ($)
└─ the
└─ green
└─ tree ($)
`
	assert.Equal(t, expected, tri.Root().Sprint())

	// A key whose path does not exist: no-op.
	value, existed = tri.Delete([]string{"non", "existing"})
	assert.False(t, existed)
	assert.Nil(t, value)
	expected = `^
├─ an
│ ├─ apple
│ │ └─ tree ($)
│ └─ umbrella ($)
└─ the
└─ green
└─ tree ($)
`
	assert.Equal(t, expected, tri.Root().Sprint())

	// An existing but non-terminal prefix: also a no-op.
	value, existed = tri.Delete([]string{"an"})
	assert.False(t, existed)
	assert.Nil(t, value)
	expected = `^
├─ an
│ ├─ apple
│ │ └─ tree ($)
│ └─ umbrella ($)
└─ the
└─ green
└─ tree ($)
`
	assert.Equal(t, expected, tri.Root().Sprint())
}
169 |
--------------------------------------------------------------------------------
/walk.go:
--------------------------------------------------------------------------------
1 | package trie
2 |
// WalkFunc is the callback Trie.Walk invokes for every node at which a stored
// key ends. key is a freshly allocated copy of that node's full key (it starts
// with the prefix given to Walk), so the callback may retain it. node is the
// terminal node itself. Returning a non-nil error stops the walk, and Walk
// returns that error.
type WalkFunc func(key []string, node *Node) error
4 |
5 | // Walk traverses the Trie and calls walker function. If walker function returns an error, Walk early-returns with that error.
6 | // Traversal follows insertion order.
7 | func (t *Trie) Walk(key []string, walker WalkFunc) error {
8 | node := t.root
9 | for _, keyPart := range key {
10 | child, ok := node.children[keyPart]
11 | if !ok {
12 | return nil
13 | }
14 | node = child
15 | }
16 | return t.walk(node, &key, walker)
17 | }
18 |
19 | func (t *Trie) walk(node *Node, prefixKey *[]string, walker WalkFunc) error {
20 | if node.isTerminal {
21 | key := make([]string, len(*prefixKey))
22 | copy(key, *prefixKey)
23 | if err := walker(key, node); err != nil {
24 | return err
25 | }
26 | }
27 |
28 | for dllNode := node.childrenDLL.head; dllNode != nil; dllNode = dllNode.next {
29 | child := dllNode.trieNode
30 | *prefixKey = append(*prefixKey, child.keyPart)
31 | err := t.walk(child, prefixKey, walker)
32 | *prefixKey = (*prefixKey)[:len(*prefixKey)-1]
33 | if err != nil {
34 | return err
35 | }
36 | }
37 | return nil
38 | }
39 |
--------------------------------------------------------------------------------
/walk_test.go:
--------------------------------------------------------------------------------
1 | package trie_test
2 |
3 | import (
4 | "errors"
5 | "testing"
6 |
7 | "github.com/shivamMg/trie"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestTrie_WalkErr(t *testing.T) {
12 | tri := trie.New()
13 | tri.Put([]string{"d", "a", "l", "i"}, 1)
14 | tri.Put([]string{"d", "a", "l", "i", "b"}, 2)
15 | tri.Put([]string{"d", "a", "l", "i", "b", "e"}, 3)
16 | tri.Put([]string{"d", "a", "l", "i", "b", "e", "r", "t"}, 4)
17 |
18 | var selected []string
19 | walker := func(key []string, node *trie.Node) error {
20 | what := node.Value().(int)
21 | if what == 3 {
22 | selected = key
23 | return errors.New("found")
24 | }
25 | return nil
26 | }
27 |
28 | err := tri.Walk(nil, walker)
29 | assert.EqualError(t, err, "found")
30 | assert.EqualValues(t, []string{"d", "a", "l", "i", "b", "e"}, selected)
31 | }
32 |
33 | func TestTrie_Walk(t *testing.T) {
34 | tri := trie.New()
35 | tri.Put([]string{"d", "a", "l", "i"}, []int{0, 1, 2, 4, 5})
36 | tri.Put([]string{"d", "a", "l", "i", "b"}, []int{1, 2, 4, 5})
37 | tri.Put([]string{"d", "a", "l", "i", "b", "e"}, []int{1, 0, 2, 4, 5, 0})
38 | tri.Put([]string{"d", "a", "l", "i", "b", "e", "r", "t"}, []int{1, 2, 4, 5})
39 | type KVPair struct {
40 | key []string
41 | value []int
42 | }
43 | var selected []KVPair
44 | walker := func(key []string, node *trie.Node) error {
45 | what := node.Value().([]int)
46 | for _, i := range what {
47 | if i == 0 {
48 | selected = append(selected, KVPair{key, what})
49 | break
50 | }
51 | }
52 | return nil
53 | }
54 |
55 | err := tri.Walk(nil, walker)
56 | assert.NoError(t, err)
57 | expected := []KVPair{
58 | {[]string{"d", "a", "l", "i"}, []int{0, 1, 2, 4, 5}},
59 | {[]string{"d", "a", "l", "i", "b", "e"}, []int{1, 0, 2, 4, 5, 0}},
60 | }
61 | assert.EqualValues(t, expected, selected)
62 |
63 | selected = nil
64 | err = tri.Walk([]string{"d", "a", "l", "i", "b"}, walker)
65 | assert.NoError(t, err)
66 | expected = []KVPair{
67 | {[]string{"d", "a", "l", "i", "b", "e"}, []int{1, 0, 2, 4, 5, 0}},
68 | }
69 | assert.EqualValues(t, expected, selected)
70 |
71 | selected = nil
72 | err = tri.Walk([]string{"a", "b", "c"}, walker)
73 | assert.NoError(t, err)
74 | assert.Nil(t, selected)
75 | }
76 |
--------------------------------------------------------------------------------