├── .github └── workflows │ ├── ci.yml │ └── gh-pages.yml ├── .gitignore ├── .vscode └── settings.json ├── LICENSE.md ├── README.md ├── demo ├── Makefile ├── server │ └── main.go ├── site │ ├── favicon.ico │ ├── index.html │ └── wasm_exec.js ├── wasm │ ├── go.mod │ ├── go.sum │ ├── main.go │ └── words.txt └── words.py ├── dll.go ├── example_test.go ├── go.mod ├── go.sum ├── go.work ├── go.work.sum ├── heap.go ├── search.go ├── search_test.go ├── search_whitebox_test.go ├── trie.go ├── trie_test.go ├── walk.go └── walk_test.go /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v4 15 | - name: Install Go 16 | uses: actions/setup-go@v5 17 | with: 18 | go-version: '1.18' 19 | - name: Install goimports 20 | # Pinned: recent golang.org/x/tools releases no longer build with Go 1.18. 21 | run: go install golang.org/x/tools/cmd/goimports@v0.1.12 22 | - run: goimports -w . 23 | - run: go mod tidy 24 | - name: Verify no changes from goimports and go mod tidy 25 | run: | 26 | if [ -n "$(git status --porcelain)" ]; then 27 | git status --porcelain 28 | git diff 29 | exit 1 30 | fi 31 | test: 32 | runs-on: ubuntu-latest 33 | steps: 34 | - name: Checkout code 35 | uses: actions/checkout@v4 36 | - name: Install Go 37 | uses: actions/setup-go@v5 38 | with: 39 | go-version: '1.18' 40 | - name: Run tests 41 | run: go test -v ./... 42 | - name: Run benchmarks 43 | # -run='^$' matches no test, so only benchmarks run (tests already ran above). 44 | run: go test -v -run='^$' -bench=.
-------------------------------------------------------------------------------- /.github/workflows/gh-pages.yml: -------------------------------------------------------------------------------- 1 | name: github pages 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | deploy: 11 | runs-on: ubuntu-latest 12 | permissions: 13 | contents: write 14 | steps: 15 | - name: Checkout code 16 | uses: actions/checkout@v4 17 | - name: Install Go 18 | uses: actions/setup-go@v5 19 | with: 20 | # demo/site/wasm_exec.js is vendored from a pre-1.21 Go release (it exports the 21 | # "go" import module); a newer toolchain emits a main.wasm that shim cannot load. 22 | go-version: '1.18' 23 | - name: Create wasm module 24 | run: cd demo && make build-wasm 25 | - name: Deploy 26 | uses: peaceiris/actions-gh-pages@v4 27 | if: github.ref == 'refs/heads/master' 28 | with: 29 | github_token: ${{ secrets.GITHUB_TOKEN }} 30 | publish_dir: ./demo/site -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | main.wasm 2 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "gopls": { 3 | "build.env": { 4 | "GOOS": "js", 5 | "GOARCH": "wasm" 6 | } 7 | } 8 | } -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 shivam mamgain 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## trie [![Go Reference](https://pkg.go.dev/badge/github.com/shivamMg/trie.svg)](https://pkg.go.dev/github.com/shivamMg/trie) ![Build](https://github.com/shivamMg/trie/actions/workflows/ci.yml/badge.svg?branch=master) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 2 | 3 | An implementation of the [Trie](https://en.wikipedia.org/wiki/Trie) data structure in Go. It provides more features than the usual Trie prefix search, and is meant to be used for auto-completion. 4 | 5 | ### Demo 6 | 7 | A WebAssembly demo can be tried at [shivamMg.github.io/trie](https://shivammg.github.io/trie/). 8 | 9 | ### Features 10 | 11 | - Keys are `[]string` instead of `string`, thereby supporting more use cases - e.g. `[]string{"the", "quick", "brown", "fox"}` can be a key where each word is a node in the Trie 12 | - Support for Put key and Delete key 13 | - Support for Prefix search - e.g. searching for _nation_ might return _nation_, _national_, _nationalism_, _nationalist_, etc. 14 | - Support for Edit distance search (aka Levenshtein distance) - e.g. searching for _wheat_ might return similar looking words like _wheat_, _cheat_, _heat_, _what_, etc.
15 | - Order of search results is deterministic. By default it follows insertion order; `WithTopKLeastEdited` returns least-edited results first. 16 | 17 | ### Examples 18 | 19 | ```go 20 | tri := trie.New() 21 | // Put keys ([]string) and values (any) 22 | tri.Put([]string{"the"}, 1) 23 | tri.Put([]string{"the", "quick", "brown", "fox"}, 2) 24 | tri.Put([]string{"the", "quick", "sports", "car"}, 3) 25 | tri.Put([]string{"the", "green", "tree"}, 4) 26 | tri.Put([]string{"an", "apple", "tree"}, 5) 27 | tri.Put([]string{"an", "umbrella"}, 6) 28 | 29 | tri.Root().Print() 30 | // Output (full trie with terminals ending with ($)): 31 | // ^ 32 | // ├─ the ($) 33 | // │ ├─ quick 34 | // │ │ ├─ brown 35 | // │ │ │ └─ fox ($) 36 | // │ │ └─ sports 37 | // │ │ └─ car ($) 38 | // │ └─ green 39 | // │ └─ tree ($) 40 | // └─ an 41 | // ├─ apple 42 | // │ └─ tree ($) 43 | // └─ umbrella ($) 44 | 45 | results := tri.Search([]string{"the", "quick"}) 46 | for _, res := range results.Results { 47 | fmt.Println(res.Key, res.Value) 48 | } 49 | // Output (prefix-based search): 50 | // [the quick brown fox] 2 51 | // [the quick sports car] 3 52 | 53 | key := []string{"the", "tree"} 54 | results = tri.Search(key, trie.WithMaxEditDistance(2), // An edit can be insert, delete, replace 55 | trie.WithEditOps()) 56 | for _, res := range results.Results { 57 | fmt.Println(res.Key, res.EditDistance) // EditDistance is number of edits needed to convert to [the tree] 58 | } 59 | // Output (results not more than 2 edits away from [the tree]): 60 | // [the] 1 61 | // [the green tree] 1 62 | // [an apple tree] 2 63 | // [an umbrella] 2 64 | 65 | result := results.Results[2] 66 | fmt.Printf("To convert %v to %v:\n", result.Key, key) 67 | printEditOps(result.EditOps) 68 | // Output (edit operations needed to convert a result to [the tree]): 69 | // To convert [an apple tree] to [the tree]: 70 | // - delete "an" 71 | // - replace "apple" with "the" 72 | // - don't edit "tree" 73 | 74 | results = tri.Search(key, trie.WithMaxEditDistance(2), trie.WithTopKLeastEdited(),
trie.WithMaxResults(2)) 75 | for _, res := range results.Results { 76 | fmt.Println(res.Key, res.Value, res.EditDistance) 77 | } 78 | // Output (top 2 least edited results): 79 | // [the] 1 1 80 | // [the green tree] 4 1 81 | ``` 82 | 83 | ### References 84 | 85 | * https://en.wikipedia.org/wiki/Levenshtein_distance#Iterative_with_full_matrix 86 | * http://stevehanov.ca/blog/?id=114 87 | * https://gist.github.com/jlherren/d97839b1276b9bd7faa930f74711a4b6 -------------------------------------------------------------------------------- /demo/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build-wasm 2 | build-wasm: 3 | cd wasm && GOOS=js GOARCH=wasm go build -o ../site/main.wasm 4 | 5 | .PHONY: server 6 | server: 7 | cd server && go run . 8 | -------------------------------------------------------------------------------- /demo/server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "time" 7 | ) 8 | 9 | const ( 10 | siteDir = "./../site" 11 | addr = ":8080" 12 | ) 13 | 14 | // main serves the static files in siteDir over HTTP on addr. 15 | func main() { 16 | srv := &http.Server{ 17 | Addr: addr, 18 | Handler: http.FileServer(http.Dir(siteDir)), 19 | // Bound how long a client may take to send request headers so a slow 20 | // client cannot hold a connection open indefinitely (Slowloris). 21 | ReadHeaderTimeout: 10 * time.Second, 22 | } 23 | log.Println("server will start at", addr) 24 | log.Fatal(srv.ListenAndServe()) 25 | } 26 | -------------------------------------------------------------------------------- /demo/site/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shivamMg/trie/fdf2c274601aaf8ec709ce4c1461b5cca1dd868c/demo/site/favicon.ico -------------------------------------------------------------------------------- /demo/site/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | trie auto-completion demo 5 | 6 | 7 | 8 | 53 | 54 | 55 |
56 |
57 |
58 |

59 | This is a WebAssembly demo for auto-completion using the trie Go library. 60 | The WASM module contains a Trie populated with English dictionary words. Searching for a word returns at most 10 results. 61 |

62 |

63 | The default search is prefix search, i.e. every result starts with the searched text. 64 |

65 |

66 | If "Edit distance search" is enabled, results will be at most 3 67 | edits away from the searched text, ordered least-edited first. 68 |

69 |
70 |
71 | 72 | 73 |
74 |
75 | 76 |
77 |
78 |
79 |
80 |
81 | 82 | 134 | -------------------------------------------------------------------------------- /demo/site/wasm_exec.js: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | "use strict"; 6 | 7 | (() => { 8 | const enosys = () => { 9 | const err = new Error("not implemented"); 10 | err.code = "ENOSYS"; 11 | return err; 12 | }; 13 | 14 | if (!globalThis.fs) { 15 | let outputBuf = ""; 16 | globalThis.fs = { 17 | constants: { O_WRONLY: -1, O_RDWR: -1, O_CREAT: -1, O_TRUNC: -1, O_APPEND: -1, O_EXCL: -1 }, // unused 18 | writeSync(fd, buf) { 19 | outputBuf += decoder.decode(buf); 20 | const nl = outputBuf.lastIndexOf("\n"); 21 | if (nl != -1) { 22 | console.log(outputBuf.substr(0, nl)); 23 | outputBuf = outputBuf.substr(nl + 1); 24 | } 25 | return buf.length; 26 | }, 27 | write(fd, buf, offset, length, position, callback) { 28 | if (offset !== 0 || length !== buf.length || position !== null) { 29 | callback(enosys()); 30 | return; 31 | } 32 | const n = this.writeSync(fd, buf); 33 | callback(null, n); 34 | }, 35 | chmod(path, mode, callback) { callback(enosys()); }, 36 | chown(path, uid, gid, callback) { callback(enosys()); }, 37 | close(fd, callback) { callback(enosys()); }, 38 | fchmod(fd, mode, callback) { callback(enosys()); }, 39 | fchown(fd, uid, gid, callback) { callback(enosys()); }, 40 | fstat(fd, callback) { callback(enosys()); }, 41 | fsync(fd, callback) { callback(null); }, 42 | ftruncate(fd, length, callback) { callback(enosys()); }, 43 | lchown(path, uid, gid, callback) { callback(enosys()); }, 44 | link(path, link, callback) { callback(enosys()); }, 45 | lstat(path, callback) { callback(enosys()); }, 46 | mkdir(path, perm, callback) { callback(enosys()); }, 47 | open(path, flags, mode, callback) { callback(enosys()); }, 48 | read(fd, buffer, offset, 
length, position, callback) { callback(enosys()); }, 49 | readdir(path, callback) { callback(enosys()); }, 50 | readlink(path, callback) { callback(enosys()); }, 51 | rename(from, to, callback) { callback(enosys()); }, 52 | rmdir(path, callback) { callback(enosys()); }, 53 | stat(path, callback) { callback(enosys()); }, 54 | symlink(path, link, callback) { callback(enosys()); }, 55 | truncate(path, length, callback) { callback(enosys()); }, 56 | unlink(path, callback) { callback(enosys()); }, 57 | utimes(path, atime, mtime, callback) { callback(enosys()); }, 58 | }; 59 | } 60 | 61 | if (!globalThis.process) { 62 | globalThis.process = { 63 | getuid() { return -1; }, 64 | getgid() { return -1; }, 65 | geteuid() { return -1; }, 66 | getegid() { return -1; }, 67 | getgroups() { throw enosys(); }, 68 | pid: -1, 69 | ppid: -1, 70 | umask() { throw enosys(); }, 71 | cwd() { throw enosys(); }, 72 | chdir() { throw enosys(); }, 73 | } 74 | } 75 | 76 | if (!globalThis.crypto) { 77 | throw new Error("globalThis.crypto is not available, polyfill required (crypto.getRandomValues only)"); 78 | } 79 | 80 | if (!globalThis.performance) { 81 | throw new Error("globalThis.performance is not available, polyfill required (performance.now only)"); 82 | } 83 | 84 | if (!globalThis.TextEncoder) { 85 | throw new Error("globalThis.TextEncoder is not available, polyfill required"); 86 | } 87 | 88 | if (!globalThis.TextDecoder) { 89 | throw new Error("globalThis.TextDecoder is not available, polyfill required"); 90 | } 91 | 92 | const encoder = new TextEncoder("utf-8"); 93 | const decoder = new TextDecoder("utf-8"); 94 | 95 | globalThis.Go = class { 96 | constructor() { 97 | this.argv = ["js"]; 98 | this.env = {}; 99 | this.exit = (code) => { 100 | if (code !== 0) { 101 | console.warn("exit code:", code); 102 | } 103 | }; 104 | this._exitPromise = new Promise((resolve) => { 105 | this._resolveExitPromise = resolve; 106 | }); 107 | this._pendingEvent = null; 108 | this._scheduledTimeouts = 
new Map(); 109 | this._nextCallbackTimeoutID = 1; 110 | 111 | const setInt64 = (addr, v) => { 112 | this.mem.setUint32(addr + 0, v, true); 113 | this.mem.setUint32(addr + 4, Math.floor(v / 4294967296), true); 114 | } 115 | 116 | const getInt64 = (addr) => { 117 | const low = this.mem.getUint32(addr + 0, true); 118 | const high = this.mem.getInt32(addr + 4, true); 119 | return low + high * 4294967296; 120 | } 121 | 122 | const loadValue = (addr) => { 123 | const f = this.mem.getFloat64(addr, true); 124 | if (f === 0) { 125 | return undefined; 126 | } 127 | if (!isNaN(f)) { 128 | return f; 129 | } 130 | 131 | const id = this.mem.getUint32(addr, true); 132 | return this._values[id]; 133 | } 134 | 135 | const storeValue = (addr, v) => { 136 | const nanHead = 0x7FF80000; 137 | 138 | if (typeof v === "number" && v !== 0) { 139 | if (isNaN(v)) { 140 | this.mem.setUint32(addr + 4, nanHead, true); 141 | this.mem.setUint32(addr, 0, true); 142 | return; 143 | } 144 | this.mem.setFloat64(addr, v, true); 145 | return; 146 | } 147 | 148 | if (v === undefined) { 149 | this.mem.setFloat64(addr, 0, true); 150 | return; 151 | } 152 | 153 | let id = this._ids.get(v); 154 | if (id === undefined) { 155 | id = this._idPool.pop(); 156 | if (id === undefined) { 157 | id = this._values.length; 158 | } 159 | this._values[id] = v; 160 | this._goRefCounts[id] = 0; 161 | this._ids.set(v, id); 162 | } 163 | this._goRefCounts[id]++; 164 | let typeFlag = 0; 165 | switch (typeof v) { 166 | case "object": 167 | if (v !== null) { 168 | typeFlag = 1; 169 | } 170 | break; 171 | case "string": 172 | typeFlag = 2; 173 | break; 174 | case "symbol": 175 | typeFlag = 3; 176 | break; 177 | case "function": 178 | typeFlag = 4; 179 | break; 180 | } 181 | this.mem.setUint32(addr + 4, nanHead | typeFlag, true); 182 | this.mem.setUint32(addr, id, true); 183 | } 184 | 185 | const loadSlice = (addr) => { 186 | const array = getInt64(addr + 0); 187 | const len = getInt64(addr + 8); 188 | return new 
Uint8Array(this._inst.exports.mem.buffer, array, len); 189 | } 190 | 191 | const loadSliceOfValues = (addr) => { 192 | const array = getInt64(addr + 0); 193 | const len = getInt64(addr + 8); 194 | const a = new Array(len); 195 | for (let i = 0; i < len; i++) { 196 | a[i] = loadValue(array + i * 8); 197 | } 198 | return a; 199 | } 200 | 201 | const loadString = (addr) => { 202 | const saddr = getInt64(addr + 0); 203 | const len = getInt64(addr + 8); 204 | return decoder.decode(new DataView(this._inst.exports.mem.buffer, saddr, len)); 205 | } 206 | 207 | const timeOrigin = Date.now() - performance.now(); 208 | this.importObject = { 209 | go: { 210 | // Go's SP does not change as long as no Go code is running. Some operations (e.g. calls, getters and setters) 211 | // may synchronously trigger a Go event handler. This makes Go code get executed in the middle of the imported 212 | // function. A goroutine can switch to a new stack if the current stack is too small (see morestack function). 213 | // This changes the SP, thus we have to update the SP used by the imported function. 
214 | 215 | // func wasmExit(code int32) 216 | "runtime.wasmExit": (sp) => { 217 | sp >>>= 0; 218 | const code = this.mem.getInt32(sp + 8, true); 219 | this.exited = true; 220 | delete this._inst; 221 | delete this._values; 222 | delete this._goRefCounts; 223 | delete this._ids; 224 | delete this._idPool; 225 | this.exit(code); 226 | }, 227 | 228 | // func wasmWrite(fd uintptr, p unsafe.Pointer, n int32) 229 | "runtime.wasmWrite": (sp) => { 230 | sp >>>= 0; 231 | const fd = getInt64(sp + 8); 232 | const p = getInt64(sp + 16); 233 | const n = this.mem.getInt32(sp + 24, true); 234 | fs.writeSync(fd, new Uint8Array(this._inst.exports.mem.buffer, p, n)); 235 | }, 236 | 237 | // func resetMemoryDataView() 238 | "runtime.resetMemoryDataView": (sp) => { 239 | sp >>>= 0; 240 | this.mem = new DataView(this._inst.exports.mem.buffer); 241 | }, 242 | 243 | // func nanotime1() int64 244 | "runtime.nanotime1": (sp) => { 245 | sp >>>= 0; 246 | setInt64(sp + 8, (timeOrigin + performance.now()) * 1000000); 247 | }, 248 | 249 | // func walltime() (sec int64, nsec int32) 250 | "runtime.walltime": (sp) => { 251 | sp >>>= 0; 252 | const msec = (new Date).getTime(); 253 | setInt64(sp + 8, msec / 1000); 254 | this.mem.setInt32(sp + 16, (msec % 1000) * 1000000, true); 255 | }, 256 | 257 | // func scheduleTimeoutEvent(delay int64) int32 258 | "runtime.scheduleTimeoutEvent": (sp) => { 259 | sp >>>= 0; 260 | const id = this._nextCallbackTimeoutID; 261 | this._nextCallbackTimeoutID++; 262 | this._scheduledTimeouts.set(id, setTimeout( 263 | () => { 264 | this._resume(); 265 | while (this._scheduledTimeouts.has(id)) { 266 | // for some reason Go failed to register the timeout event, log and try again 267 | // (temporary workaround for https://github.com/golang/go/issues/28975) 268 | console.warn("scheduleTimeoutEvent: missed timeout event"); 269 | this._resume(); 270 | } 271 | }, 272 | getInt64(sp + 8) + 1, // setTimeout has been seen to fire up to 1 millisecond early 273 | )); 274 | 
this.mem.setInt32(sp + 16, id, true); 275 | }, 276 | 277 | // func clearTimeoutEvent(id int32) 278 | "runtime.clearTimeoutEvent": (sp) => { 279 | sp >>>= 0; 280 | const id = this.mem.getInt32(sp + 8, true); 281 | clearTimeout(this._scheduledTimeouts.get(id)); 282 | this._scheduledTimeouts.delete(id); 283 | }, 284 | 285 | // func getRandomData(r []byte) 286 | "runtime.getRandomData": (sp) => { 287 | sp >>>= 0; 288 | crypto.getRandomValues(loadSlice(sp + 8)); 289 | }, 290 | 291 | // func finalizeRef(v ref) 292 | "syscall/js.finalizeRef": (sp) => { 293 | sp >>>= 0; 294 | const id = this.mem.getUint32(sp + 8, true); 295 | this._goRefCounts[id]--; 296 | if (this._goRefCounts[id] === 0) { 297 | const v = this._values[id]; 298 | this._values[id] = null; 299 | this._ids.delete(v); 300 | this._idPool.push(id); 301 | } 302 | }, 303 | 304 | // func stringVal(value string) ref 305 | "syscall/js.stringVal": (sp) => { 306 | sp >>>= 0; 307 | storeValue(sp + 24, loadString(sp + 8)); 308 | }, 309 | 310 | // func valueGet(v ref, p string) ref 311 | "syscall/js.valueGet": (sp) => { 312 | sp >>>= 0; 313 | const result = Reflect.get(loadValue(sp + 8), loadString(sp + 16)); 314 | sp = this._inst.exports.getsp() >>> 0; // see comment above 315 | storeValue(sp + 32, result); 316 | }, 317 | 318 | // func valueSet(v ref, p string, x ref) 319 | "syscall/js.valueSet": (sp) => { 320 | sp >>>= 0; 321 | Reflect.set(loadValue(sp + 8), loadString(sp + 16), loadValue(sp + 32)); 322 | }, 323 | 324 | // func valueDelete(v ref, p string) 325 | "syscall/js.valueDelete": (sp) => { 326 | sp >>>= 0; 327 | Reflect.deleteProperty(loadValue(sp + 8), loadString(sp + 16)); 328 | }, 329 | 330 | // func valueIndex(v ref, i int) ref 331 | "syscall/js.valueIndex": (sp) => { 332 | sp >>>= 0; 333 | storeValue(sp + 24, Reflect.get(loadValue(sp + 8), getInt64(sp + 16))); 334 | }, 335 | 336 | // valueSetIndex(v ref, i int, x ref) 337 | "syscall/js.valueSetIndex": (sp) => { 338 | sp >>>= 0; 339 | 
Reflect.set(loadValue(sp + 8), getInt64(sp + 16), loadValue(sp + 24)); 340 | }, 341 | 342 | // func valueCall(v ref, m string, args []ref) (ref, bool) 343 | "syscall/js.valueCall": (sp) => { 344 | sp >>>= 0; 345 | try { 346 | const v = loadValue(sp + 8); 347 | const m = Reflect.get(v, loadString(sp + 16)); 348 | const args = loadSliceOfValues(sp + 32); 349 | const result = Reflect.apply(m, v, args); 350 | sp = this._inst.exports.getsp() >>> 0; // see comment above 351 | storeValue(sp + 56, result); 352 | this.mem.setUint8(sp + 64, 1); 353 | } catch (err) { 354 | sp = this._inst.exports.getsp() >>> 0; // see comment above 355 | storeValue(sp + 56, err); 356 | this.mem.setUint8(sp + 64, 0); 357 | } 358 | }, 359 | 360 | // func valueInvoke(v ref, args []ref) (ref, bool) 361 | "syscall/js.valueInvoke": (sp) => { 362 | sp >>>= 0; 363 | try { 364 | const v = loadValue(sp + 8); 365 | const args = loadSliceOfValues(sp + 16); 366 | const result = Reflect.apply(v, undefined, args); 367 | sp = this._inst.exports.getsp() >>> 0; // see comment above 368 | storeValue(sp + 40, result); 369 | this.mem.setUint8(sp + 48, 1); 370 | } catch (err) { 371 | sp = this._inst.exports.getsp() >>> 0; // see comment above 372 | storeValue(sp + 40, err); 373 | this.mem.setUint8(sp + 48, 0); 374 | } 375 | }, 376 | 377 | // func valueNew(v ref, args []ref) (ref, bool) 378 | "syscall/js.valueNew": (sp) => { 379 | sp >>>= 0; 380 | try { 381 | const v = loadValue(sp + 8); 382 | const args = loadSliceOfValues(sp + 16); 383 | const result = Reflect.construct(v, args); 384 | sp = this._inst.exports.getsp() >>> 0; // see comment above 385 | storeValue(sp + 40, result); 386 | this.mem.setUint8(sp + 48, 1); 387 | } catch (err) { 388 | sp = this._inst.exports.getsp() >>> 0; // see comment above 389 | storeValue(sp + 40, err); 390 | this.mem.setUint8(sp + 48, 0); 391 | } 392 | }, 393 | 394 | // func valueLength(v ref) int 395 | "syscall/js.valueLength": (sp) => { 396 | sp >>>= 0; 397 | setInt64(sp + 16, 
parseInt(loadValue(sp + 8).length)); 398 | }, 399 | 400 | // valuePrepareString(v ref) (ref, int) 401 | "syscall/js.valuePrepareString": (sp) => { 402 | sp >>>= 0; 403 | const str = encoder.encode(String(loadValue(sp + 8))); 404 | storeValue(sp + 16, str); 405 | setInt64(sp + 24, str.length); 406 | }, 407 | 408 | // valueLoadString(v ref, b []byte) 409 | "syscall/js.valueLoadString": (sp) => { 410 | sp >>>= 0; 411 | const str = loadValue(sp + 8); 412 | loadSlice(sp + 16).set(str); 413 | }, 414 | 415 | // func valueInstanceOf(v ref, t ref) bool 416 | "syscall/js.valueInstanceOf": (sp) => { 417 | sp >>>= 0; 418 | this.mem.setUint8(sp + 24, (loadValue(sp + 8) instanceof loadValue(sp + 16)) ? 1 : 0); 419 | }, 420 | 421 | // func copyBytesToGo(dst []byte, src ref) (int, bool) 422 | "syscall/js.copyBytesToGo": (sp) => { 423 | sp >>>= 0; 424 | const dst = loadSlice(sp + 8); 425 | const src = loadValue(sp + 32); 426 | if (!(src instanceof Uint8Array || src instanceof Uint8ClampedArray)) { 427 | this.mem.setUint8(sp + 48, 0); 428 | return; 429 | } 430 | const toCopy = src.subarray(0, dst.length); 431 | dst.set(toCopy); 432 | setInt64(sp + 40, toCopy.length); 433 | this.mem.setUint8(sp + 48, 1); 434 | }, 435 | 436 | // func copyBytesToJS(dst ref, src []byte) (int, bool) 437 | "syscall/js.copyBytesToJS": (sp) => { 438 | sp >>>= 0; 439 | const dst = loadValue(sp + 8); 440 | const src = loadSlice(sp + 16); 441 | if (!(dst instanceof Uint8Array || dst instanceof Uint8ClampedArray)) { 442 | this.mem.setUint8(sp + 48, 0); 443 | return; 444 | } 445 | const toCopy = src.subarray(0, dst.length); 446 | dst.set(toCopy); 447 | setInt64(sp + 40, toCopy.length); 448 | this.mem.setUint8(sp + 48, 1); 449 | }, 450 | 451 | "debug": (value) => { 452 | console.log(value); 453 | }, 454 | } 455 | }; 456 | } 457 | 458 | async run(instance) { 459 | if (!(instance instanceof WebAssembly.Instance)) { 460 | throw new Error("Go.run: WebAssembly.Instance expected"); 461 | } 462 | this._inst = instance; 
463 | this.mem = new DataView(this._inst.exports.mem.buffer); 464 | this._values = [ // JS values that Go currently has references to, indexed by reference id 465 | NaN, 466 | 0, 467 | null, 468 | true, 469 | false, 470 | globalThis, 471 | this, 472 | ]; 473 | this._goRefCounts = new Array(this._values.length).fill(Infinity); // number of references that Go has to a JS value, indexed by reference id 474 | this._ids = new Map([ // mapping from JS values to reference ids 475 | [0, 1], 476 | [null, 2], 477 | [true, 3], 478 | [false, 4], 479 | [globalThis, 5], 480 | [this, 6], 481 | ]); 482 | this._idPool = []; // unused ids that have been garbage collected 483 | this.exited = false; // whether the Go program has exited 484 | 485 | // Pass command line arguments and environment variables to WebAssembly by writing them to the linear memory. 486 | let offset = 4096; 487 | 488 | const strPtr = (str) => { 489 | const ptr = offset; 490 | const bytes = encoder.encode(str + "\0"); 491 | new Uint8Array(this.mem.buffer, offset, bytes.length).set(bytes); 492 | offset += bytes.length; 493 | if (offset % 8 !== 0) { 494 | offset += 8 - (offset % 8); 495 | } 496 | return ptr; 497 | }; 498 | 499 | const argc = this.argv.length; 500 | 501 | const argvPtrs = []; 502 | this.argv.forEach((arg) => { 503 | argvPtrs.push(strPtr(arg)); 504 | }); 505 | argvPtrs.push(0); 506 | 507 | const keys = Object.keys(this.env).sort(); 508 | keys.forEach((key) => { 509 | argvPtrs.push(strPtr(`${key}=${this.env[key]}`)); 510 | }); 511 | argvPtrs.push(0); 512 | 513 | const argv = offset; 514 | argvPtrs.forEach((ptr) => { 515 | this.mem.setUint32(offset, ptr, true); 516 | this.mem.setUint32(offset + 4, 0, true); 517 | offset += 8; 518 | }); 519 | 520 | // The linker guarantees global data starts from at least wasmMinDataAddr. 521 | // Keep in sync with cmd/link/internal/ld/data.go:wasmMinDataAddr. 
522 | const wasmMinDataAddr = 4096 + 8192; 523 | if (offset >= wasmMinDataAddr) { 524 | throw new Error("total length of command line and environment variables exceeds limit"); 525 | } 526 | 527 | this._inst.exports.run(argc, argv); 528 | if (this.exited) { 529 | this._resolveExitPromise(); 530 | } 531 | await this._exitPromise; 532 | } 533 | 534 | _resume() { 535 | if (this.exited) { 536 | throw new Error("Go program has already exited"); 537 | } 538 | this._inst.exports.resume(); 539 | if (this.exited) { 540 | this._resolveExitPromise(); 541 | } 542 | } 543 | 544 | _makeFuncWrapper(id) { 545 | const go = this; 546 | return function () { 547 | const event = { id: id, this: this, args: arguments }; 548 | go._pendingEvent = event; 549 | go._resume(); 550 | return event.result; 551 | }; 552 | } 553 | } 554 | })(); 555 | -------------------------------------------------------------------------------- /demo/wasm/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/shivamMg/trie/demo/wasm 2 | 3 | go 1.18 4 | 5 | require github.com/shivamMg/trie v0.0.0-20220605114339-55c1368e363e 6 | 7 | require github.com/shivamMg/ppds v0.0.1 // indirect 8 | 9 | replace github.com/shivamMg/trie => ../../ 10 | -------------------------------------------------------------------------------- /demo/wasm/go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 3 | github.com/shivamMg/ppds v0.0.1 h1:idK2dpaen652zOO+OmcwmyoPNncBNqfHjF/14eS5JIk= 4 | github.com/shivamMg/ppds v0.0.1/go.mod h1:hb39VqUO6qfkb9zBBQPTIV1vWBtI7yQsG0wr3pN78fM= 5 | github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= 6 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 7 | 
-------------------------------------------------------------------------------- /demo/wasm/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | _ "embed" 7 | "fmt" 8 | "io" 9 | "strings" 10 | "sync" 11 | "syscall/js" 12 | 13 | "github.com/shivamMg/trie" 14 | ) 15 | 16 | //go:embed words.txt 17 | var data []byte 18 | 19 | // mu guards tri and longestWordLen, which initTrie writes and searchWord reads. 20 | var ( 21 | mu sync.Mutex 22 | tri *trie.Trie 23 | 24 | // longestWordLen is the number of key parts (characters) in the longest word stored in tri. 25 | longestWordLen int 26 | ) 27 | 28 | func init() { 29 | initTrie() 30 | } 31 | 32 | // initTrie (re)builds tri from the embedded word list, one word per line; 33 | // each word is split into single-character key parts. 34 | func initTrie() { 35 | mu.Lock() 36 | defer mu.Unlock() 37 | tri = trie.New() 38 | longestWordLen = 0 39 | r := bufio.NewReader(bytes.NewReader(data)) 40 | for { 41 | word, err := r.ReadString('\n') 42 | if err != nil && err != io.EOF { 43 | fmt.Println(err) 44 | break 45 | } 46 | // ReadString returns the last line together with io.EOF when the data does not 47 | // end in a newline, so the word is stored before EOF ends the loop. 48 | word = strings.TrimRight(word, "\r\n") 49 | if word != "" { 50 | key := strings.Split(word, "") 51 | if len(key) > longestWordLen { 52 | longestWordLen = len(key) 53 | } 54 | tri.Put(key, struct{}{}) 55 | } 56 | if err == io.EOF { 57 | break 58 | } 59 | } 60 | } 61 | 62 | // getNoEdits reports, for each part of the result key, whether ops leave it unedited. 63 | // ops is ordered along key (as built by Search with WithEditOps()): NoEdit, Delete and 64 | // Replace each consume one part of key, while Insert adds a part of the searched word and 65 | // consumes none. Tracking the position this way marks the right letter even when letters 66 | // repeat, e.g. key "aa" searched as "ba" (replace "a", keep "a") marks only the second "a". 67 | func getNoEdits(key []string, ops []*trie.EditOp) []interface{} { 68 | noEdits := make([]interface{}, len(key)) 69 | for i := range noEdits { 70 | noEdits[i] = false 71 | } 72 | pos := 0 73 | for _, op := range ops { 74 | if op.Type == trie.EditOpTypeInsert { 75 | continue 76 | } 77 | if op.Type == trie.EditOpTypeNoEdit && pos < len(noEdits) { 78 | noEdits[pos] = true 79 | } 80 | pos++ 81 | } 82 | return noEdits 83 | } 84 | 85 | // getNoEditsForPrefixSearch marks the first wordLen of keyLen key parts as unedited: 86 | // in a prefix search those parts are the searched word itself. 87 | func getNoEditsForPrefixSearch(wordLen int, keyLen int) []interface{} { 88 | noEdits := make([]interface{}, keyLen) 89 | for i := 0; i < keyLen; i++ { 90 | noEdits[i] = i < wordLen 91 | } 92 | return noEdits 93 | } 94 | 95 | // searchWord is exposed to JavaScript as searchWord(word, approximate) and returns 96 | // {words, noEdits}: at most 10 matching words and, for each word, one boolean per 97 | // letter telling whether that letter is unedited. approximate=false does a prefix 98 | // search; approximate=true allows up to 3 edits and returns the least-edited words. 99 | func searchWord(this js.Value, args []js.Value) interface{} { 100 | mu.Lock() 101 | defer mu.Unlock() 102 | empty := map[string]interface{}{ 103 | "words": []interface{}{}, 104 | "noEdits": []interface{}{}, 105 | } 106 | // An unrecovered panic would terminate the Go program (and the exported function), 107 | // so reject calls that are missing arguments instead of indexing past the end of args. 108 | if len(args) < 2 { 109 | return empty 110 | } 111 | word := args[0].String() 112 | key := strings.Split(word, "") 113 | // Compare lengths in key parts, the unit initTrie uses for longestWordLen. 114 | if len(key) > longestWordLen { 115 | return empty 116 | } 117 | approximate := args[1].Bool() 118 | opts := []func(*trie.SearchOptions){trie.WithMaxResults(10)} 119 | if approximate { 120 | opts = append(opts, trie.WithMaxEditDistance(3), trie.WithEditOps(), trie.WithTopKLeastEdited()) 121 | } 122 | results := tri.Search(key, opts...) 123 | n := len(results.Results) 124 | words := make([]interface{}, n) 125 | noEdits := make([]interface{}, n) 126 | for i, res := range results.Results { 127 | words[i] = strings.Join(res.Key, "") 128 | if approximate { 129 | noEdits[i] = getNoEdits(res.Key, res.EditOps) 130 | } else { 131 | noEdits[i] = getNoEditsForPrefixSearch(len(key), len(res.Key)) 132 | } 133 | } 134 | return map[string]interface{}{ 135 | "words": words, 136 | "noEdits": noEdits, 137 | } 138 | } 139 | 140 | func main() { 141 | // Expose searchWord to JavaScript, then block forever so the Go program (and the 142 | // registered callback) stays alive. 143 | js.Global().Set("searchWord", js.FuncOf(searchWord)) 144 | select {} 145 | } 146 | 
 -------------------------------------------------------------------------------- /demo/words.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import string 4 | 5 | 6 | WORDS_FILE = '/usr/share/dict/american-english' # https://en.wikipedia.org/wiki/Words_(Unix) 7 | TARGET_FILE = 'wasm/words.txt' 8 | 9 | 10 | if __name__ == '__main__': 11 | with open(WORDS_FILE) as f: 12 | words = [line.rstrip('\n') for line in f] 13 | words = [word for word in words if word and all(ch in string.ascii_lowercase for ch in word)] # removes empty lines and words like Abby (noun), accuser's (apostrophe) 14 | word_set = set(words) 15 | is_singular = lambda word: not word.endswith('s') or word[:-1] not in word_set 16 | words = [word for word in words if is_singular(word)] 17 | 18 | with open(TARGET_FILE, 'w') as f: 19 | for word in words: 20 | f.write(word + '\n') 21 | print(len(words), 'words written to', TARGET_FILE) -------------------------------------------------------------------------------- /dll.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | // TODO: tests 4 | type doublyLinkedList struct { 5 | head, tail *dllNode 6 | } 7 | 8 | type dllNode struct { 9 | trieNode *Node 10 | next, prev *dllNode 11 | } 12 | 13 | func newDLLNode(trieNode *Node) *dllNode { 14 | return &dllNode{trieNode: trieNode} 15 | } 16 | 17 | func (dll *doublyLinkedList) append(node *dllNode) { 18 | if dll.head == nil { 19 | dll.head = node 20 | dll.tail = node 21 | return 22 | } 23 | dll.tail.next = node 24 | node.prev = dll.tail 25 | dll.tail = node 26 | } 27 | 28 | func (dll *doublyLinkedList) pop(node *dllNode) { 29 | if node == dll.head && node == dll.tail { 30 | dll.head = nil 31 | dll.tail = nil 32 | return 33 | } 34 | if node == dll.head { 35 | dll.head = node.next 36 | dll.head.prev = nil 37 | node.next = nil 38 | return 39 | } 40 | if node == dll.tail { 41 | dll.tail = node.prev 42 | dll.tail.next = nil 43 | node.prev = nil 44 | return 45 | } 46 | prev := node.prev 47 | next := node.next 48 | prev.next = next 49 | next.prev = prev 50 | node.next = nil 51 | node.prev = nil 52 | } 53 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package trie_test 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/shivamMg/trie" 7 | ) 8 | 9 | func printEditOps(ops []*trie.EditOp) { 10 | for _, op := range ops { 11 | switch op.Type { 12 | case trie.EditOpTypeNoEdit: 13 |
 fmt.Printf("- don't edit %q\n", op.KeyPart) 14 | case trie.EditOpTypeInsert: 15 | fmt.Printf("- insert %q\n", op.KeyPart) 16 | case trie.EditOpTypeDelete: 17 | fmt.Printf("- delete %q\n", op.KeyPart) 18 | case trie.EditOpTypeReplace: 19 | fmt.Printf("- replace %q with %q\n", op.KeyPart, op.ReplaceWith) 20 | } 21 | } 22 | } 23 | 24 | func Example() { 25 | tri := trie.New() 26 | // Put keys ([]string) and values (any) 27 | tri.Put([]string{"the"}, 1) 28 | tri.Put([]string{"the", "quick", "brown", "fox"}, 2) 29 | tri.Put([]string{"the", "quick", "sports", "car"}, 3) 30 | tri.Put([]string{"the", "green", "tree"}, 4) 31 | tri.Put([]string{"an", "apple", "tree"}, 5) 32 | tri.Put([]string{"an", "umbrella"}, 6) 33 | 34 | tri.Root().Print() 35 | // Output (full trie with terminals ending with ($)): 36 | // ^ 37 | // ├─ the ($) 38 | // │ ├─ quick 39 | // │ │ ├─ brown 40 | // │ │ │ └─ fox ($) 41 | // │ │ └─ sports 42 | // │ │ └─ car ($) 43 | // │ └─ green 44 | // │ └─ tree ($) 45 | // └─ an 46 | // ├─ apple 47 | // │ └─ tree ($) 48 | // └─ umbrella ($) 49 | 50 | results := tri.Search([]string{"the", "quick"}) 51 | for _, res := range results.Results { 52 | fmt.Println(res.Key, res.Value) 53 | } 54 | // Output (prefix-based search): 55 | // [the quick brown fox] 2 56 | // [the quick sports car] 3 57 | 58 | key := []string{"the", "tree"} 59 | results = tri.Search(key, trie.WithMaxEditDistance(2), // An edit can be insert, delete, replace 60 | trie.WithEditOps()) 61 | for _, res := range results.Results { 62 | fmt.Println(res.Key, res.EditDistance) // EditDistance is number of edits needed to convert to [the tree] 63 | } 64 | // Output (results not more than 2 edits away from [the tree]): 65 | // [the] 1 66 | // [the green tree] 1 67 | // [an apple tree] 2 68 | // [an umbrella] 2 69 | 70 | result := results.Results[2] 71 | fmt.Printf("To convert %v to %v:\n", result.Key, key) 72 | printEditOps(result.EditOps) 73 | // Output (edit operations needed to convert a result to [the
 tree]): 74 | // To convert [an apple tree] to [the tree]: 75 | // - delete "an" 76 | // - replace "apple" with "the" 77 | // - don't edit "tree" 78 | 79 | results = tri.Search(key, trie.WithMaxEditDistance(2), trie.WithTopKLeastEdited(), trie.WithMaxResults(2)) 80 | for _, res := range results.Results { 81 | fmt.Println(res.Key, res.Value, res.EditDistance) 82 | } 83 | // Output (top 2 least edited results): 84 | // [the] 1 1 85 | // [the green tree] 4 1 86 | } 87 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/shivamMg/trie 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/shivamMg/ppds v0.0.1 7 | github.com/stretchr/testify v1.7.1 8 | ) 9 | 10 | require ( 11 | github.com/davecgh/go-spew v1.1.0 // indirect 12 | github.com/pmezard/go-difflib v1.0.0 // indirect 13 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect 14 | ) 15 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/shivamMg/ppds v0.0.1 h1:idK2dpaen652zOO+OmcwmyoPNncBNqfHjF/14eS5JIk= 6 | github.com/shivamMg/ppds v0.0.1/go.mod h1:hb39VqUO6qfkb9zBBQPTIV1vWBtI7yQsG0wr3pN78fM= 7 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 8 | github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= 9 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 10 | 
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 11 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 12 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 13 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 14 | -------------------------------------------------------------------------------- /go.work: -------------------------------------------------------------------------------- 1 | go 1.18 2 | 3 | use ./demo/wasm 4 | -------------------------------------------------------------------------------- /go.work.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 3 | github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= 4 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 5 | -------------------------------------------------------------------------------- /heap.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | type searchResultMaxHeap []*SearchResult 4 | 5 | func (s searchResultMaxHeap) Len() int { 6 | return len(s) 7 | } 8 | 9 | func (s searchResultMaxHeap) Less(i, j int) bool { 10 | if s[i].EditDistance == s[j].EditDistance { 11 | return s[i].tiebreaker > s[j].tiebreaker 12 | } 13 | return s[i].EditDistance > s[j].EditDistance 14 | } 15 | 16 | func (s searchResultMaxHeap) Swap(i, j int) { 17 | s[i], s[j] = s[j], s[i] 18 | } 19 | 20 | func (s *searchResultMaxHeap) Push(x interface{}) { 21 | *s = append(*s, x.(*SearchResult)) 22 | } 23 | 24 | func (s *searchResultMaxHeap) Pop() interface{} { 25 | old := *s 
26 | n := len(old) 27 | x := old[n-1] 28 | *s = old[0 : n-1] 29 | return x 30 | } 31 | -------------------------------------------------------------------------------- /search.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "container/heap" 5 | "errors" 6 | "math" 7 | ) 8 | 9 | type EditOpType int 10 | 11 | const ( 12 | EditOpTypeNoEdit EditOpType = iota 13 | EditOpTypeInsert 14 | EditOpTypeDelete 15 | EditOpTypeReplace 16 | ) 17 | 18 | // EditOp represents an Edit Operation. 19 | type EditOp struct { 20 | Type EditOpType 21 | // KeyPart: 22 | // - In case of NoEdit, KeyPart is to be retained. 23 | // - In case of Insert, KeyPart is to be inserted in the key. 24 | // - In case of Delete/Replace, KeyPart is the part of the key on which delete/replace is performed. 25 | KeyPart string 26 | // ReplaceWith is set for Type=EditOpTypeReplace 27 | ReplaceWith string 28 | } 29 | 30 | type SearchResults struct { 31 | Results []*SearchResult 32 | heap *searchResultMaxHeap 33 | tiebreakerCount int 34 | } 35 | 36 | type SearchResult struct { 37 | // Key is the key that was Put() into the Trie. 38 | Key []string 39 | // Value is the value that was Put() into the Trie. 40 | Value interface{} 41 | // EditDistance is the number of edits (insert/delete/replace) needed to convert Key into the Search()-ed key. 42 | EditDistance int 43 | // EditOps is the list of edit operations (see EditOpType) needed to convert Key into the Search()-ed key. 44 | EditOps []*EditOp 45 | 46 | tiebreaker int 47 | } 48 | 49 | type SearchOptions struct { 50 | // - WithExactKey 51 | // - WithMaxResults 52 | // - WithMaxEditDistance 53 | // - WithEditOps 54 | // - WithTopKLeastEdited 55 | exactKey bool 56 | maxResults bool 57 | maxResultsCount int 58 | editDistance bool 59 | maxEditDistance int 60 | editOps bool 61 | topKLeastEdited bool 62 | } 63 | 64 | // WithExactKey can be passed to Search(). 
When passed, Search() returns just the result with 65 | // Key=Search()-ed key. If the key does not exist, result list will be empty. 66 | func WithExactKey() func(*SearchOptions) { 67 | return func(so *SearchOptions) { 68 | so.exactKey = true 69 | } 70 | } 71 | 72 | // WithMaxResults can be passed to Search(). When passed, Search() will return at most maxResults 73 | // number of results. 74 | func WithMaxResults(maxResults int) func(*SearchOptions) { 75 | if maxResults <= 0 { 76 | panic(errors.New("invalid usage: maxResults must be greater than zero")) 77 | } 78 | return func(so *SearchOptions) { 79 | so.maxResults = true 80 | so.maxResultsCount = maxResults 81 | } 82 | } 83 | 84 | // WithMaxEditDistance can be passed to Search(). When passed, Search() changes its default behaviour from 85 | // Prefix search to Edit distance search. It can be used to return "Approximate" results instead of strict 86 | // Prefix search results. 87 | // 88 | // maxDistance is the maximum number of edits allowed on Trie keys to consider them as a SearchResult. 89 | // Higher the maxDistance, more lenient and slower the search becomes. 90 | // 91 | // e.g. If a Trie stores English words, then searching for "wheat" with maxDistance=1 might return similar 92 | // looking words like "wheat", "cheat", "heat", "what", etc. With maxDistance=2 it might also return words like 93 | // "beat", "ahead", etc. 94 | // 95 | // Read about Edit distance: https://en.wikipedia.org/wiki/Edit_distance 96 | func WithMaxEditDistance(maxDistance int) func(*SearchOptions) { 97 | if maxDistance <= 0 { 98 | panic(errors.New("invalid usage: maxDistance must be greater than zero")) 99 | } 100 | return func(so *SearchOptions) { 101 | so.editDistance = true 102 | so.maxEditDistance = maxDistance 103 | } 104 | } 105 | 106 | // WithEditOps can be passed to Search() alongside WithMaxEditDistance(). When passed, Search() also returns EditOps 107 | // for each SearchResult. 
EditOps can be used to determine the minimum number of edit operations needed to convert 108 | // a result Key into the Search()-ed key. 109 | // 110 | // e.g. Searching for "wheat" in a Trie that stores English words might return "beat". EditOps for this result might be: 111 | // 1. insert "w" 2. replace "b" with "h". 112 | // 113 | // There might be multiple ways to edit a key into another. EditOps represents only one. 114 | // 115 | // Computing EditOps makes Search() slower. 116 | func WithEditOps() func(*SearchOptions) { 117 | return func(so *SearchOptions) { 118 | so.editOps = true 119 | } 120 | } 121 | 122 | // WithTopKLeastEdited can be passed to Search() alongside WithMaxEditDistance() and WithMaxResults(). When passed, 123 | // Search() returns maxResults number of results that have the lowest EditDistances. Results are sorted on EditDistance 124 | // (lowest to highest). 125 | // 126 | // e.g. In a Trie that stores English words searching for "wheat" might return "wheat" (EditDistance=0), "cheat" (EditDistance=1), 127 | // "beat" (EditDistance=2) - in that order. 128 | func WithTopKLeastEdited() func(*SearchOptions) { 129 | return func(so *SearchOptions) { 130 | so.topKLeastEdited = true 131 | } 132 | } 133 | 134 | // Search() takes a key and some options to return results (see SearchResult) from the Trie. 135 | // Without any options, it does a Prefix search i.e. result Keys have the same prefix as key. 136 | // Order of the results is deterministic and will follow the order in which Put() was called for the keys. 137 | // See "With*" functions for options accepted by Search(). 
138 | func (t *Trie) Search(key []string, options ...func(*SearchOptions)) *SearchResults { 139 | opts := &SearchOptions{} 140 | for _, f := range options { 141 | f(opts) 142 | } 143 | if opts.editOps && !opts.editDistance { 144 | panic(errors.New("invalid usage: WithEditOps() must not be passed without WithMaxEditDistance()")) 145 | } 146 | if opts.topKLeastEdited && !opts.editDistance { 147 | panic(errors.New("invalid usage: WithTopKLeastEdited() must not be passed without WithMaxEditDistance()")) 148 | } 149 | if opts.exactKey && opts.editDistance { 150 | panic(errors.New("invalid usage: WithExactKey() must not be passed with WithMaxEditDistance()")) 151 | } 152 | if opts.exactKey && opts.maxResults { 153 | panic(errors.New("invalid usage: WithExactKey() must not be passed with WithMaxResults()")) 154 | } 155 | if opts.topKLeastEdited && !opts.maxResults { 156 | panic(errors.New("invalid usage: WithTopKLeastEdited() must not be passed without WithMaxResults()")) 157 | } 158 | 159 | if opts.editDistance { 160 | return t.searchWithEditDistance(key, opts) 161 | } 162 | return t.search(key, opts) 163 | } 164 | 165 | func (t *Trie) searchWithEditDistance(key []string, opts *SearchOptions) *SearchResults { 166 | // https://en.wikipedia.org/wiki/Levenshtein_distance#Iterative_with_full_matrix 167 | // http://stevehanov.ca/blog/?id=114 168 | columns := len(key) + 1 169 | newRow := make([]int, columns) 170 | for i := 0; i < columns; i++ { 171 | newRow[i] = i 172 | } 173 | m := len(key) 174 | if m == 0 { 175 | m = 1 176 | } 177 | rows := make([][]int, 1, m) 178 | rows[0] = newRow 179 | results := &SearchResults{} 180 | if opts.topKLeastEdited { 181 | results.heap = &searchResultMaxHeap{} 182 | } 183 | 184 | keyColumn := make([]string, 1, m) 185 | stop := false 186 | // prioritize Node that has the same keyPart as key. this results in better results 187 | // e.g. if key=national, build with Node(keyPart=n) first so that keys like notional, nation, nationally, etc. 
are prioritized 188 | // same logic is used inside the recursive buildWithEditDistance() method 189 | var prioritizedNode *Node 190 | if len(key) > 0 { 191 | if prioritizedNode = t.root.children[key[0]]; prioritizedNode != nil { 192 | keyColumn[0] = prioritizedNode.keyPart 193 | t.buildWithEditDistance(&stop, results, prioritizedNode, &keyColumn, &rows, key, opts) 194 | } 195 | } 196 | for dllNode := t.root.childrenDLL.head; dllNode != nil; dllNode = dllNode.next { 197 | node := dllNode.trieNode 198 | if node == prioritizedNode { 199 | continue 200 | } 201 | keyColumn[0] = node.keyPart 202 | t.buildWithEditDistance(&stop, results, node, &keyColumn, &rows, key, opts) 203 | } 204 | if opts.topKLeastEdited { 205 | n := results.heap.Len() 206 | results.Results = make([]*SearchResult, n) 207 | for n != 0 { 208 | result := heap.Pop(results.heap).(*SearchResult) 209 | result.tiebreaker = 0 210 | results.Results[n-1] = result 211 | n-- 212 | } 213 | results.heap = nil 214 | results.tiebreakerCount = 0 215 | } 216 | return results 217 | } 218 | 219 | func (t *Trie) buildWithEditDistance(stop *bool, results *SearchResults, node *Node, keyColumn *[]string, rows *[][]int, key []string, opts *SearchOptions) { 220 | if *stop { 221 | return 222 | } 223 | prevRow := (*rows)[len(*rows)-1] 224 | columns := len(key) + 1 225 | newRow := make([]int, columns) 226 | newRow[0] = prevRow[0] + 1 227 | for i := 1; i < columns; i++ { 228 | replaceCost := 1 229 | if key[i-1] == (*keyColumn)[len(*keyColumn)-1] { 230 | replaceCost = 0 231 | } 232 | newRow[i] = min( 233 | newRow[i-1]+1, // insertion 234 | prevRow[i]+1, // deletion 235 | prevRow[i-1]+replaceCost, // substitution 236 | ) 237 | } 238 | *rows = append(*rows, newRow) 239 | 240 | if newRow[columns-1] <= opts.maxEditDistance && node.isTerminal { 241 | editDistance := newRow[columns-1] 242 | lazyCreate := func() *SearchResult { // optimization for the case where topKLeastEdited=true and the result should not be pushed to heap 243 | 
resultKey := make([]string, len(*keyColumn)) 244 | copy(resultKey, *keyColumn) 245 | result := &SearchResult{Key: resultKey, Value: node.value, EditDistance: editDistance} 246 | if opts.editOps { 247 | result.EditOps = t.getEditOps(rows, keyColumn, key) 248 | } 249 | return result 250 | } 251 | if opts.topKLeastEdited { 252 | results.tiebreakerCount++ 253 | if results.heap.Len() < opts.maxResultsCount { 254 | result := lazyCreate() 255 | result.tiebreaker = results.tiebreakerCount 256 | heap.Push(results.heap, result) 257 | } else if (*results.heap)[0].EditDistance > editDistance { 258 | result := lazyCreate() 259 | result.tiebreaker = results.tiebreakerCount 260 | heap.Pop(results.heap) 261 | heap.Push(results.heap, result) 262 | } 263 | } else { 264 | result := lazyCreate() 265 | results.Results = append(results.Results, result) 266 | if opts.maxResults && len(results.Results) == opts.maxResultsCount { 267 | *stop = true 268 | return 269 | } 270 | } 271 | } 272 | 273 | if min(newRow...) 
<= opts.maxEditDistance { 274 | var prioritizedNode *Node 275 | m := len(*keyColumn) 276 | if m < len(key) { 277 | if prioritizedNode = node.children[key[m]]; prioritizedNode != nil { 278 | *keyColumn = append(*keyColumn, prioritizedNode.keyPart) 279 | t.buildWithEditDistance(stop, results, prioritizedNode, keyColumn, rows, key, opts) 280 | *keyColumn = (*keyColumn)[:len(*keyColumn)-1] 281 | } 282 | } 283 | for dllNode := node.childrenDLL.head; dllNode != nil; dllNode = dllNode.next { 284 | child := dllNode.trieNode 285 | if child == prioritizedNode { 286 | continue 287 | } 288 | *keyColumn = append(*keyColumn, child.keyPart) 289 | t.buildWithEditDistance(stop, results, child, keyColumn, rows, key, opts) 290 | *keyColumn = (*keyColumn)[:len(*keyColumn)-1] 291 | } 292 | } 293 | 294 | *rows = (*rows)[:len(*rows)-1] 295 | } 296 | 297 | func (t *Trie) getEditOps(rows *[][]int, keyColumn *[]string, key []string) []*EditOp { 298 | // https://gist.github.com/jlherren/d97839b1276b9bd7faa930f74711a4b6 299 | ops := make([]*EditOp, 0, len(key)) 300 | r, c := len(*rows)-1, len((*rows)[0])-1 301 | for r > 0 || c > 0 { 302 | insertionCost, deletionCost, substitutionCost := math.MaxInt, math.MaxInt, math.MaxInt 303 | if c > 0 { 304 | insertionCost = (*rows)[r][c-1] 305 | } 306 | if r > 0 { 307 | deletionCost = (*rows)[r-1][c] 308 | } 309 | if r > 0 && c > 0 { 310 | substitutionCost = (*rows)[r-1][c-1] 311 | } 312 | minCost := min(insertionCost, deletionCost, substitutionCost) 313 | if minCost == substitutionCost { 314 | if (*rows)[r][c] > (*rows)[r-1][c-1] { 315 | ops = append(ops, &EditOp{Type: EditOpTypeReplace, KeyPart: (*keyColumn)[r-1], ReplaceWith: key[c-1]}) 316 | } else { 317 | ops = append(ops, &EditOp{Type: EditOpTypeNoEdit, KeyPart: (*keyColumn)[r-1]}) 318 | } 319 | r -= 1 320 | c -= 1 321 | } else if minCost == deletionCost { 322 | ops = append(ops, &EditOp{Type: EditOpTypeDelete, KeyPart: (*keyColumn)[r-1]}) 323 | r -= 1 324 | } else if minCost == insertionCost { 325 
| ops = append(ops, &EditOp{Type: EditOpTypeInsert, KeyPart: key[c-1]}) 326 | c -= 1 327 | } 328 | } 329 | for i, j := 0, len(ops)-1; i < j; i, j = i+1, j-1 { 330 | ops[i], ops[j] = ops[j], ops[i] 331 | } 332 | return ops 333 | } 334 | 335 | func (t *Trie) search(prefixKey []string, opts *SearchOptions) *SearchResults { 336 | results := &SearchResults{} 337 | node := t.root 338 | for _, keyPart := range prefixKey { 339 | child, ok := node.children[keyPart] 340 | if !ok { 341 | return results 342 | } 343 | node = child 344 | } 345 | if opts.exactKey { 346 | if node.isTerminal { 347 | result := &SearchResult{Key: prefixKey, Value: node.value} 348 | results.Results = append(results.Results, result) 349 | } 350 | return results 351 | } 352 | t.build(results, node, &prefixKey, opts) 353 | return results 354 | } 355 | 356 | func (t *Trie) build(results *SearchResults, node *Node, prefixKey *[]string, opts *SearchOptions) (stop bool) { 357 | if node.isTerminal { 358 | key := make([]string, len(*prefixKey)) 359 | copy(key, *prefixKey) 360 | result := &SearchResult{Key: key, Value: node.value} 361 | results.Results = append(results.Results, result) 362 | if opts.maxResults && len(results.Results) == opts.maxResultsCount { 363 | return true 364 | } 365 | } 366 | 367 | for dllNode := node.childrenDLL.head; dllNode != nil; dllNode = dllNode.next { 368 | child := dllNode.trieNode 369 | *prefixKey = append(*prefixKey, child.keyPart) 370 | stop := t.build(results, child, prefixKey, opts) 371 | *prefixKey = (*prefixKey)[:len(*prefixKey)-1] 372 | if stop { 373 | return true 374 | } 375 | } 376 | return false 377 | } 378 | 379 | func min(s ...int) int { 380 | m := s[0] 381 | for _, a := range s[1:] { 382 | if a < m { 383 | m = a 384 | } 385 | } 386 | return m 387 | } 388 | -------------------------------------------------------------------------------- /search_test.go: -------------------------------------------------------------------------------- 1 | package trie_test 2 | 3 | 
import ( 4 | "bufio" 5 | "io" 6 | "os" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/shivamMg/trie" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | var ( 15 | benchmarkResults *trie.SearchResults // https://dave.cheney.net/2013/06/30/how-to-write-benchmarks-in-go 16 | wordsTrie *trie.Trie 17 | ) 18 | 19 | func TestTrie_Search(t *testing.T) { 20 | tri := trie.New() 21 | tri.Put([]string{"the"}, 1) 22 | tri.Put([]string{"the", "quick", "brown", "fox"}, 2) 23 | tri.Put([]string{"the", "quick", "swimmer"}, 3) 24 | tri.Put([]string{"the", "green", "tree"}, 4) 25 | tri.Put([]string{"an", "apple", "tree"}, 5) 26 | tri.Put([]string{"an", "umbrella"}, 6) 27 | 28 | testCases := []struct { 29 | name string 30 | inputKey []string 31 | inputOptions []func(*trie.SearchOptions) 32 | expectedResults *trie.SearchResults 33 | }{ 34 | { 35 | name: "prefix-one-word", 36 | inputKey: []string{"the"}, 37 | expectedResults: &trie.SearchResults{ 38 | Results: []*trie.SearchResult{ 39 | {Key: []string{"the"}, Value: 1}, 40 | {Key: []string{"the", "quick", "brown", "fox"}, Value: 2}, 41 | {Key: []string{"the", "quick", "swimmer"}, Value: 3}, 42 | {Key: []string{"the", "green", "tree"}, Value: 4}, 43 | }, 44 | }, 45 | }, 46 | { 47 | name: "prefix-one-word-with-max-three-results", 48 | inputKey: []string{"the"}, 49 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxResults(3)}, 50 | expectedResults: &trie.SearchResults{ 51 | Results: []*trie.SearchResult{ 52 | {Key: []string{"the"}, Value: 1}, 53 | {Key: []string{"the", "quick", "brown", "fox"}, Value: 2}, 54 | {Key: []string{"the", "quick", "swimmer"}, Value: 3}, 55 | }, 56 | }, 57 | }, 58 | { 59 | name: "prefix-multiple-words", 60 | inputKey: []string{"the", "quick"}, 61 | expectedResults: &trie.SearchResults{ 62 | Results: []*trie.SearchResult{ 63 | {Key: []string{"the", "quick", "brown", "fox"}, Value: 2}, 64 | {Key: []string{"the", "quick", "swimmer"}, Value: 3}, 65 | }, 66 | }, 67 | }, 68 | { 69 | name: 
"prefix-non-existing", 70 | inputKey: []string{"non-existing"}, 71 | expectedResults: &trie.SearchResults{}, 72 | }, 73 | { 74 | name: "prefix-empty", 75 | inputKey: []string{}, 76 | expectedResults: &trie.SearchResults{ 77 | Results: []*trie.SearchResult{ 78 | {Key: []string{"the"}, Value: 1}, 79 | {Key: []string{"the", "quick", "brown", "fox"}, Value: 2}, 80 | {Key: []string{"the", "quick", "swimmer"}, Value: 3}, 81 | {Key: []string{"the", "green", "tree"}, Value: 4}, 82 | {Key: []string{"an", "apple", "tree"}, Value: 5}, 83 | {Key: []string{"an", "umbrella"}, Value: 6}, 84 | }, 85 | }, 86 | }, 87 | { 88 | name: "prefix-nil", 89 | inputKey: nil, 90 | expectedResults: &trie.SearchResults{ 91 | Results: []*trie.SearchResult{ 92 | {Key: []string{"the"}, Value: 1}, 93 | {Key: []string{"the", "quick", "brown", "fox"}, Value: 2}, 94 | {Key: []string{"the", "quick", "swimmer"}, Value: 3}, 95 | {Key: []string{"the", "green", "tree"}, Value: 4}, 96 | {Key: []string{"an", "apple", "tree"}, Value: 5}, 97 | {Key: []string{"an", "umbrella"}, Value: 6}, 98 | }, 99 | }, 100 | }, 101 | { 102 | name: "prefix-one-word-with-exact-key", 103 | inputKey: []string{"the"}, 104 | inputOptions: []func(*trie.SearchOptions){trie.WithExactKey()}, 105 | expectedResults: &trie.SearchResults{ 106 | Results: []*trie.SearchResult{ 107 | {Key: []string{"the"}, Value: 1}, 108 | }, 109 | }, 110 | }, 111 | { 112 | name: "prefix-multiple-words-with-exact-key", 113 | inputKey: []string{"the", "quick", "swimmer"}, 114 | inputOptions: []func(*trie.SearchOptions){trie.WithExactKey()}, 115 | expectedResults: &trie.SearchResults{ 116 | Results: []*trie.SearchResult{ 117 | {Key: []string{"the", "quick", "swimmer"}, Value: 3}, 118 | }, 119 | }, 120 | }, 121 | { 122 | name: "edit-distance-one-edit", 123 | inputKey: []string{"the", "tree"}, 124 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(1)}, 125 | expectedResults: &trie.SearchResults{ 126 | Results: []*trie.SearchResult{ 127 | {Key: 
[]string{"the"}, Value: 1, EditDistance: 1}, 128 | {Key: []string{"the", "green", "tree"}, Value: 4, EditDistance: 1}, 129 | }, 130 | }, 131 | }, 132 | { 133 | name: "edit-distance-one-edit-with-edit-opts", 134 | inputKey: []string{"the", "tree"}, 135 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(1), trie.WithEditOps()}, 136 | expectedResults: &trie.SearchResults{ 137 | Results: []*trie.SearchResult{ 138 | {Key: []string{"the"}, Value: 1, EditDistance: 1, EditOps: []*trie.EditOp{ 139 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"}, 140 | {Type: trie.EditOpTypeInsert, KeyPart: "tree"}, 141 | }}, 142 | {Key: []string{"the", "green", "tree"}, Value: 4, EditDistance: 1, EditOps: []*trie.EditOp{ 143 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"}, 144 | {Type: trie.EditOpTypeDelete, KeyPart: "green"}, 145 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"}, 146 | }}, 147 | }, 148 | }, 149 | }, 150 | { 151 | name: "edit-distance-two-edits-with-edit-opts", 152 | inputKey: []string{"the", "tree"}, 153 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(2), trie.WithEditOps()}, 154 | expectedResults: &trie.SearchResults{ 155 | Results: []*trie.SearchResult{ 156 | {Key: []string{"the"}, Value: 1, EditDistance: 1, EditOps: []*trie.EditOp{ 157 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"}, 158 | {Type: trie.EditOpTypeInsert, KeyPart: "tree"}, 159 | }}, 160 | {Key: []string{"the", "quick", "swimmer"}, Value: 3, EditDistance: 2, EditOps: []*trie.EditOp{ 161 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"}, 162 | {Type: trie.EditOpTypeDelete, KeyPart: "quick"}, 163 | {Type: trie.EditOpTypeReplace, KeyPart: "swimmer", ReplaceWith: "tree"}, 164 | }}, 165 | {Key: []string{"the", "green", "tree"}, Value: 4, EditDistance: 1, EditOps: []*trie.EditOp{ 166 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"}, 167 | {Type: trie.EditOpTypeDelete, KeyPart: "green"}, 168 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"}, 169 | }}, 170 | {Key: []string{"an", 
"apple", "tree"}, Value: 5, EditDistance: 2, EditOps: []*trie.EditOp{ 171 | {Type: trie.EditOpTypeDelete, KeyPart: "an"}, 172 | {Type: trie.EditOpTypeReplace, KeyPart: "apple", ReplaceWith: "the"}, 173 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"}, 174 | }}, 175 | {Key: []string{"an", "umbrella"}, Value: 6, EditDistance: 2, EditOps: []*trie.EditOp{ 176 | {Type: trie.EditOpTypeReplace, KeyPart: "an", ReplaceWith: "the"}, 177 | {Type: trie.EditOpTypeReplace, KeyPart: "umbrella", ReplaceWith: "tree"}, 178 | }}, 179 | }, 180 | }, 181 | }, 182 | { 183 | name: "edit-distance-two-edits-with-edit-opts-with-max-four-results", 184 | inputKey: []string{"the", "tree"}, 185 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(2), trie.WithEditOps(), trie.WithMaxResults(4)}, 186 | expectedResults: &trie.SearchResults{ 187 | Results: []*trie.SearchResult{ 188 | {Key: []string{"the"}, Value: 1, EditDistance: 1, EditOps: []*trie.EditOp{ 189 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"}, 190 | {Type: trie.EditOpTypeInsert, KeyPart: "tree"}, 191 | }}, 192 | {Key: []string{"the", "quick", "swimmer"}, Value: 3, EditDistance: 2, EditOps: []*trie.EditOp{ 193 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"}, 194 | {Type: trie.EditOpTypeDelete, KeyPart: "quick"}, 195 | {Type: trie.EditOpTypeReplace, KeyPart: "swimmer", ReplaceWith: "tree"}, 196 | }}, 197 | {Key: []string{"the", "green", "tree"}, Value: 4, EditDistance: 1, EditOps: []*trie.EditOp{ 198 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"}, 199 | {Type: trie.EditOpTypeDelete, KeyPart: "green"}, 200 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"}, 201 | }}, 202 | {Key: []string{"an", "apple", "tree"}, Value: 5, EditDistance: 2, EditOps: []*trie.EditOp{ 203 | {Type: trie.EditOpTypeDelete, KeyPart: "an"}, 204 | {Type: trie.EditOpTypeReplace, KeyPart: "apple", ReplaceWith: "the"}, 205 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"}, 206 | }}, 207 | }, 208 | }, 209 | }, 210 | { 211 | name: 
"edit-distance-one-edit-with-topk", 212 | inputKey: []string{"the", "tree"}, 213 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(1), trie.WithTopKLeastEdited(), trie.WithMaxResults(1)}, 214 | expectedResults: &trie.SearchResults{ 215 | Results: []*trie.SearchResult{ 216 | {Key: []string{"the"}, Value: 1, EditDistance: 1}, 217 | }, 218 | }, 219 | }, 220 | { 221 | name: "edit-distance-two-edits-with-edit-opts-with-topk", 222 | inputKey: []string{"the", "tree"}, 223 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(2), trie.WithEditOps(), trie.WithTopKLeastEdited(), trie.WithMaxResults(4)}, 224 | expectedResults: &trie.SearchResults{ 225 | Results: []*trie.SearchResult{ 226 | {Key: []string{"the"}, Value: 1, EditDistance: 1, EditOps: []*trie.EditOp{ 227 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"}, 228 | {Type: trie.EditOpTypeInsert, KeyPart: "tree"}, 229 | }}, 230 | {Key: []string{"the", "green", "tree"}, Value: 4, EditDistance: 1, EditOps: []*trie.EditOp{ 231 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"}, 232 | {Type: trie.EditOpTypeDelete, KeyPart: "green"}, 233 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"}, 234 | }}, 235 | {Key: []string{"the", "quick", "swimmer"}, Value: 3, EditDistance: 2, EditOps: []*trie.EditOp{ 236 | {Type: trie.EditOpTypeNoEdit, KeyPart: "the"}, 237 | {Type: trie.EditOpTypeDelete, KeyPart: "quick"}, 238 | {Type: trie.EditOpTypeReplace, KeyPart: "swimmer", ReplaceWith: "tree"}, 239 | }}, 240 | {Key: []string{"an", "apple", "tree"}, Value: 5, EditDistance: 2, EditOps: []*trie.EditOp{ 241 | {Type: trie.EditOpTypeDelete, KeyPart: "an"}, 242 | {Type: trie.EditOpTypeReplace, KeyPart: "apple", ReplaceWith: "the"}, 243 | {Type: trie.EditOpTypeNoEdit, KeyPart: "tree"}, 244 | }}, 245 | }, 246 | }, 247 | }, 248 | { 249 | name: "edit-distance-two-edits-with-two-topk", 250 | inputKey: []string{"the", "tree"}, 251 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(2), 
trie.WithTopKLeastEdited(), trie.WithMaxResults(2)}, 252 | expectedResults: &trie.SearchResults{ 253 | Results: []*trie.SearchResult{ 254 | {Key: []string{"the"}, Value: 1, EditDistance: 1}, 255 | {Key: []string{"the", "green", "tree"}, Value: 4, EditDistance: 1}, 256 | }, 257 | }, 258 | }, 259 | { 260 | name: "edit-distance-empty", 261 | inputKey: []string{}, 262 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(2), trie.WithTopKLeastEdited(), trie.WithMaxResults(5)}, 263 | expectedResults: &trie.SearchResults{ 264 | Results: []*trie.SearchResult{ 265 | {Key: []string{"the"}, Value: 1, EditDistance: 1}, 266 | {Key: []string{"an", "umbrella"}, Value: 6, EditDistance: 2}, 267 | }, 268 | }, 269 | }, 270 | { 271 | name: "edit-distance-nil", 272 | inputKey: nil, 273 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(2), trie.WithTopKLeastEdited(), trie.WithMaxResults(5)}, 274 | expectedResults: &trie.SearchResults{ 275 | Results: []*trie.SearchResult{ 276 | {Key: []string{"the"}, Value: 1, EditDistance: 1}, 277 | {Key: []string{"an", "umbrella"}, Value: 6, EditDistance: 2}, 278 | }, 279 | }, 280 | }, 281 | } 282 | 283 | for _, tc := range testCases { 284 | t.Run(tc.name, func(t *testing.T) { 285 | actual := tri.Search(tc.inputKey, tc.inputOptions...) 
286 | assert.Len(t, actual.Results, len(tc.expectedResults.Results)) 287 | assert.Equal(t, tc.expectedResults, actual) 288 | }) 289 | } 290 | } 291 | 292 | func TestTrie_Search_WordsTrie(t *testing.T) { 293 | tri := getWordsTrie() 294 | testCases := []struct { 295 | name string 296 | inputKey []string 297 | inputOptions []func(*trie.SearchOptions) 298 | expectedResults *trie.SearchResults 299 | }{ 300 | { 301 | name: "prefix", 302 | inputKey: strings.Split("aband", ""), 303 | expectedResults: &trie.SearchResults{ 304 | Results: []*trie.SearchResult{ 305 | {Key: strings.Split("abandon", "")}, 306 | {Key: strings.Split("abandoned", "")}, 307 | {Key: strings.Split("abandoning", "")}, 308 | {Key: strings.Split("abandonment", "")}, 309 | }, 310 | }, 311 | }, 312 | { 313 | name: "edit-distance", 314 | inputKey: strings.Split("wheat", ""), 315 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(1)}, 316 | expectedResults: &trie.SearchResults{ 317 | Results: []*trie.SearchResult{ 318 | {Key: strings.Split("wheat", ""), EditDistance: 0}, 319 | {Key: strings.Split("wheal", ""), EditDistance: 1}, 320 | {Key: strings.Split("whet", ""), EditDistance: 1}, 321 | {Key: strings.Split("what", ""), EditDistance: 1}, 322 | {Key: strings.Split("cheat", ""), EditDistance: 1}, 323 | {Key: strings.Split("heat", ""), EditDistance: 1}, 324 | }, 325 | }, 326 | }, 327 | { 328 | name: "edit-distance-with-max-results", 329 | inputKey: strings.Split("national", ""), 330 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(3), trie.WithMaxResults(13)}, 331 | expectedResults: &trie.SearchResults{ 332 | Results: []*trie.SearchResult{ 333 | {Key: strings.Split("nation", ""), EditDistance: 2}, 334 | {Key: strings.Split("national", ""), EditDistance: 0}, 335 | {Key: strings.Split("nationalism", ""), EditDistance: 3}, 336 | {Key: strings.Split("nationalist", ""), EditDistance: 3}, 337 | {Key: strings.Split("nationality", ""), EditDistance: 3}, 338 | {Key: 
strings.Split("nationalize", ""), EditDistance: 3}, 339 | {Key: strings.Split("nationally", ""), EditDistance: 2}, 340 | {Key: strings.Split("natal", ""), EditDistance: 3}, 341 | {Key: strings.Split("natural", ""), EditDistance: 3}, 342 | {Key: strings.Split("nautical", ""), EditDistance: 3}, 343 | {Key: strings.Split("notion", ""), EditDistance: 3}, 344 | {Key: strings.Split("notional", ""), EditDistance: 1}, 345 | {Key: strings.Split("notionally", ""), EditDistance: 3}, 346 | }, 347 | }, 348 | }, 349 | { 350 | name: "edit-distance-with-topk", 351 | inputKey: strings.Split("national", ""), 352 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(3), trie.WithMaxResults(13), trie.WithTopKLeastEdited()}, 353 | expectedResults: &trie.SearchResults{ 354 | Results: []*trie.SearchResult{ 355 | {Key: strings.Split("national", ""), EditDistance: 0}, 356 | {Key: strings.Split("notional", ""), EditDistance: 1}, 357 | {Key: strings.Split("rational", ""), EditDistance: 1}, 358 | {Key: strings.Split("nation", ""), EditDistance: 2}, 359 | {Key: strings.Split("nationally", ""), EditDistance: 2}, 360 | {Key: strings.Split("atonal", ""), EditDistance: 2}, 361 | {Key: strings.Split("factional", ""), EditDistance: 2}, 362 | {Key: strings.Split("optional", ""), EditDistance: 2}, 363 | {Key: strings.Split("rationale", ""), EditDistance: 2}, 364 | {Key: strings.Split("nationalism", ""), EditDistance: 3}, 365 | {Key: strings.Split("nationalist", ""), EditDistance: 3}, 366 | {Key: strings.Split("nationality", ""), EditDistance: 3}, 367 | {Key: strings.Split("nationalize", ""), EditDistance: 3}, 368 | }, 369 | }, 370 | }, 371 | { 372 | name: "edit-distance-with-topk-stop-after-prioritized", 373 | inputKey: strings.Split("national", ""), 374 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(3), trie.WithMaxResults(2), trie.WithTopKLeastEdited()}, 375 | expectedResults: &trie.SearchResults{ 376 | Results: []*trie.SearchResult{ 377 | {Key: 
strings.Split("national", ""), EditDistance: 0}, 378 | {Key: strings.Split("notional", ""), EditDistance: 1}, 379 | }, 380 | }, 381 | }, 382 | } 383 | for _, tc := range testCases { 384 | t.Run(tc.name, func(t *testing.T) { 385 | actual := tri.Search(tc.inputKey, tc.inputOptions...) 386 | assert.Len(t, actual.Results, len(tc.expectedResults.Results)) 387 | assert.Equal(t, tc.expectedResults, actual) 388 | }) 389 | } 390 | } 391 | 392 | func TestTrie_Search_InvalidUsage_EditDistance_LessThanZeroDistance(t *testing.T) { 393 | assert.PanicsWithError(t, "invalid usage: maxDistance must be greater than zero", func() { 394 | trie.WithMaxEditDistance(0) 395 | }) 396 | assert.PanicsWithError(t, "invalid usage: maxDistance must be greater than zero", func() { 397 | trie.WithMaxEditDistance(-1) 398 | }) 399 | } 400 | 401 | func TestTrie_Search_InvalidUsage_MaxResults_LessThanZero(t *testing.T) { 402 | assert.PanicsWithError(t, "invalid usage: maxResults must be greater than zero", func() { 403 | trie.WithMaxResults(0) 404 | }) 405 | assert.PanicsWithError(t, "invalid usage: maxResults must be greater than zero", func() { 406 | trie.WithMaxResults(-1) 407 | }) 408 | } 409 | 410 | func TestTrie_Search_InvalidUsage_EditOpsWithoutMaxEditDistance(t *testing.T) { 411 | tri := trie.New() 412 | 413 | assert.PanicsWithError(t, "invalid usage: WithEditOps() must not be passed without WithMaxEditDistance()", func() { 414 | tri.Search(nil, trie.WithEditOps()) 415 | }) 416 | } 417 | 418 | func TestTrie_Search_InvalidUsage_TopKWithoutMaxEditDistance(t *testing.T) { 419 | tri := trie.New() 420 | 421 | assert.PanicsWithError(t, "invalid usage: WithTopKLeastEdited() must not be passed without WithMaxEditDistance()", func() { 422 | tri.Search(nil, trie.WithTopKLeastEdited()) 423 | }) 424 | } 425 | 426 | func TestTrie_Search_InvalidUsage_ExactKeyWithMaxEditDistance(t *testing.T) { 427 | tri := trie.New() 428 | 429 | assert.PanicsWithError(t, "invalid usage: WithExactKey() must not be passed 
with WithMaxEditDistance()", func() { 430 | tri.Search(nil, trie.WithExactKey(), trie.WithMaxEditDistance(1)) 431 | }) 432 | } 433 | 434 | func TestTrie_Search_InvalidUsage_ExactKeyWithMaxResults(t *testing.T) { 435 | tri := trie.New() 436 | 437 | assert.PanicsWithError(t, "invalid usage: WithExactKey() must not be passed with WithMaxResults()", func() { 438 | tri.Search(nil, trie.WithExactKey(), trie.WithMaxResults(1)) 439 | }) 440 | } 441 | 442 | func TestTrie_Search_InvalidUsage_TopKLeastEditedWithoutMaxResults(t *testing.T) { 443 | tri := trie.New() 444 | 445 | assert.PanicsWithError(t, "invalid usage: WithTopKLeastEdited() must not be passed without WithMaxResults()", func() { 446 | tri.Search(nil, trie.WithMaxEditDistance(1), trie.WithTopKLeastEdited()) 447 | }) 448 | } 449 | 450 | func BenchmarkTrie_Search_WordsTrie(b *testing.B) { 451 | tri := getWordsTrie() 452 | benchmarks := []struct { 453 | name string 454 | inputKey []string 455 | inputOptions []func(*trie.SearchOptions) 456 | }{ 457 | { 458 | name: "prefix", 459 | inputKey: strings.Split("ab", ""), 460 | }, 461 | { 462 | name: "prefix-with-max-results", 463 | inputKey: strings.Split("ab", ""), 464 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxResults(20)}, 465 | }, 466 | { 467 | name: "edit-distance", 468 | inputKey: strings.Split("someday", ""), 469 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(5)}, 470 | }, 471 | { 472 | name: "edit-distance-with-edit-ops", 473 | inputKey: strings.Split("someday", ""), 474 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(5), trie.WithEditOps()}, 475 | }, 476 | { 477 | name: "edit-distance-with-edit-ops-with-max-results", 478 | inputKey: strings.Split("someday", ""), 479 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(5), trie.WithEditOps(), trie.WithMaxResults(20)}, 480 | }, 481 | { 482 | name: "edit-distance-with-edit-ops-with-max-results-with-top-k", 483 | inputKey: strings.Split("someday", 
""), 484 | inputOptions: []func(*trie.SearchOptions){trie.WithMaxEditDistance(5), trie.WithEditOps(), trie.WithMaxResults(20), 485 | trie.WithTopKLeastEdited()}, 486 | }, 487 | } 488 | var results *trie.SearchResults 489 | for _, bm := range benchmarks { 490 | b.Run(bm.name, func(b *testing.B) { 491 | for i := 0; i < b.N; i++ { 492 | results = tri.Search(bm.inputKey, bm.inputOptions...) 493 | } 494 | benchmarkResults = results 495 | }) 496 | } 497 | } 498 | 499 | func getWordsTrie() *trie.Trie { 500 | if wordsTrie != nil { 501 | return wordsTrie 502 | } 503 | f, err := os.Open("./demo/wasm/words.txt") 504 | if err != nil { 505 | panic(err) 506 | } 507 | tri := trie.New() 508 | r := bufio.NewReader(f) 509 | for { 510 | word, err := r.ReadString('\n') 511 | if err == io.EOF { 512 | break 513 | } 514 | if err != nil { 515 | panic(err) 516 | } 517 | word = strings.TrimRight(word, "\n") 518 | word = strings.TrimRight(word, "\r") // windows 519 | key := strings.Split(word, "") 520 | tri.Put(key, nil) 521 | } 522 | wordsTrie = tri 523 | return tri 524 | } 525 | -------------------------------------------------------------------------------- /search_whitebox_test.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestTrie_getEditOps(t *testing.T) { 12 | testCases := []struct { 13 | fromKeyColumn []string 14 | toKey []string 15 | rows [][]int 16 | expectedOps []*EditOp 17 | }{ 18 | { 19 | fromKeyColumn: strings.Split("sitting", ""), 20 | toKey: strings.Split("kitten", ""), 21 | rows: [][]int{ 22 | {0, 1, 2, 3, 4, 5, 6}, 23 | {1, 1, 2, 3, 4, 5, 6}, 24 | {2, 2, 1, 2, 3, 4, 5}, 25 | {3, 3, 2, 1, 2, 3, 4}, 26 | {4, 4, 3, 2, 1, 2, 3}, 27 | {5, 5, 4, 3, 2, 2, 3}, 28 | {6, 6, 5, 4, 3, 3, 2}, 29 | {7, 7, 6, 5, 4, 4, 3}, 30 | }, 31 | expectedOps: []*EditOp{ 32 | {Type: EditOpTypeReplace, KeyPart: "s", ReplaceWith: 
"k"}, 33 | {Type: EditOpTypeNoEdit, KeyPart: "i"}, 34 | {Type: EditOpTypeNoEdit, KeyPart: "t"}, 35 | {Type: EditOpTypeNoEdit, KeyPart: "t"}, 36 | {Type: EditOpTypeReplace, KeyPart: "i", ReplaceWith: "e"}, 37 | {Type: EditOpTypeNoEdit, KeyPart: "n"}, 38 | {Type: EditOpTypeDelete, KeyPart: "g"}, 39 | }, 40 | }, 41 | { 42 | fromKeyColumn: strings.Split("Sunday", ""), 43 | toKey: strings.Split("Saturday", ""), 44 | rows: [][]int{ 45 | {0, 1, 2, 3, 4, 5, 6, 7, 8}, 46 | {1, 0, 1, 2, 3, 4, 5, 6, 7}, 47 | {2, 1, 1, 2, 2, 3, 4, 5, 6}, 48 | {3, 2, 2, 2, 3, 3, 4, 5, 6}, 49 | {4, 3, 3, 3, 3, 4, 3, 4, 5}, 50 | {5, 4, 3, 4, 4, 4, 4, 3, 4}, 51 | {6, 5, 4, 4, 5, 5, 5, 4, 3}, 52 | }, 53 | expectedOps: []*EditOp{ 54 | {Type: EditOpTypeNoEdit, KeyPart: "S"}, 55 | {Type: EditOpTypeInsert, KeyPart: "a"}, 56 | {Type: EditOpTypeInsert, KeyPart: "t"}, 57 | {Type: EditOpTypeNoEdit, KeyPart: "u"}, 58 | {Type: EditOpTypeReplace, KeyPart: "n", ReplaceWith: "r"}, 59 | {Type: EditOpTypeNoEdit, KeyPart: "d"}, 60 | {Type: EditOpTypeNoEdit, KeyPart: "a"}, 61 | {Type: EditOpTypeNoEdit, KeyPart: "y"}, 62 | }, 63 | }, 64 | } 65 | 66 | for _, tc := range testCases { 67 | t.Run(fmt.Sprintf("from %v to %v", tc.fromKeyColumn, tc.toKey), func(t *testing.T) { 68 | tri := New() 69 | actual := tri.getEditOps(&tc.rows, &tc.fromKeyColumn, tc.toKey) 70 | assert.Equal(t, tc.expectedOps, actual) 71 | }) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /trie.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "github.com/shivamMg/ppds/tree" 5 | ) 6 | 7 | const ( 8 | RootKeyPart = "^" 9 | terminalSuffix = "($)" 10 | ) 11 | 12 | // Trie is the trie data structure. 13 | type Trie struct { 14 | root *Node 15 | } 16 | 17 | // Node is a tree node inside Trie. 
18 | type Node struct { 19 | keyPart string // the part of the key this node represents (RootKeyPart for the root) 20 | isTerminal bool // whether a key ends at this node 21 | value interface{} // value stored for the key ending here; only set while isTerminal 22 | dllNode *dllNode // this node's entry in its parent's childrenDLL (unset for the root) 23 | children map[string]*Node // children keyed by their keyPart, for lookup 24 | childrenDLL *doublyLinkedList // the same children in insertion order, for ordered traversal 25 | } 26 | 27 | func newNode(keyPart string) *Node { 28 | return &Node{ 29 | keyPart: keyPart, 30 | children: make(map[string]*Node), 31 | childrenDLL: &doublyLinkedList{}, 32 | } 33 | } 34 | 35 | // KeyPart returns the part (string) of the key ([]string) that this Node represents. 36 | func (n *Node) KeyPart() string { 37 | return n.keyPart 38 | } 39 | 40 | // IsTerminal reports whether a key ends at this Node. 41 | func (n *Node) IsTerminal() bool { 42 | return n.isTerminal 43 | } 44 | 45 | // Value returns the value stored for the key ending at this Node. If Node is not a terminal, it returns nil. 46 | func (n *Node) Value() interface{} { 47 | return n.value 48 | } 49 | 50 | // SetValue sets the value for the key ending at this Node. If Node is not a terminal, the call is a no-op. 51 | func (n *Node) SetValue(value interface{}) { 52 | if n.isTerminal { 53 | n.value = value 54 | } 55 | } 56 | 57 | // ChildNodes returns the child-nodes of this Node in insertion order, as a newly allocated slice. 58 | func (n *Node) ChildNodes() []*Node { 59 | return n.childNodes() 60 | } 61 | 62 | // Data is used in Print(); it returns the key part, suffixed with " ($)" for terminal nodes. Use Value() to get the value at this Node. 63 | func (n *Node) Data() interface{} { 64 | data := n.keyPart 65 | if n.isTerminal { 66 | data += " " + terminalSuffix 67 | } 68 | return data 69 | } 70 | 71 | // Children is used in Print(). Use ChildNodes() to get child-nodes of this Node. 72 | func (n *Node) Children() []tree.Node { 73 | children := n.childNodes() 74 | result := make([]tree.Node, len(children)) 75 | for i, child := range children { 76 | result[i] = tree.Node(child) 77 | } 78 | return result 79 | } 80 | 81 | // Print prints the tree rooted at this Node. A Trie's root node is printed as RootKeyPart. 82 | // All the terminal nodes are suffixed with ($).
83 | func (n *Node) Print() { 84 | tree.PrintHrn(n) 85 | } 86 | // Sprint is like Print but returns the rendered tree as a string instead of printing it. 87 | func (n *Node) Sprint() string { 88 | return tree.SprintHrn(n) 89 | } 90 | // childNodes returns the children in childrenDLL order (insertion order) as a new slice. 91 | func (n *Node) childNodes() []*Node { 92 | children := make([]*Node, 0, len(n.children)) 93 | dllNode := n.childrenDLL.head 94 | for dllNode != nil { 95 | children = append(children, dllNode.trieNode) 96 | dllNode = dllNode.next 97 | } 98 | return children 99 | } 100 | 101 | // New returns a new instance of Trie. 102 | func New() *Trie { 103 | return &Trie{root: newNode(RootKeyPart)} 104 | } 105 | 106 | // Root returns the root node of the Trie. 107 | func (t *Trie) Root() *Node { 108 | return t.root 109 | } 110 | 111 | // Put upserts value for the given key in the Trie, creating intermediate nodes as needed. It reports whether 112 | // the key already existed. An empty key stores nothing and returns false. 113 | func (t *Trie) Put(key []string, value interface{}) (existed bool) { 114 | node := t.root 115 | for i, part := range key { 116 | child, ok := node.children[part] 117 | if !ok { 118 | child = newNode(part) 119 | child.dllNode = newDLLNode(child) 120 | node.children[part] = child 121 | node.childrenDLL.append(child.dllNode) // keeps children iterable in insertion order 122 | } 123 | if i == len(key)-1 { 124 | existed = child.isTerminal 125 | child.isTerminal = true 126 | child.value = value 127 | } 128 | node = child 129 | } 130 | return existed 131 | } 132 | 133 | // Delete deletes key-value for the given key in the Trie. It returns (value, true) if the key existed, 134 | // else (nil, false).
135 | func (t *Trie) Delete(key []string) (value interface{}, existed bool) { 136 | node := t.root 137 | parent := make(map[*Node]*Node) 138 | for _, keyPart := range key { 139 | child, ok := node.children[keyPart] 140 | if !ok { 141 | return nil, false 142 | } 143 | parent[child] = node 144 | node = child 145 | } 146 | if !node.isTerminal { 147 | return nil, false 148 | } 149 | node.isTerminal = false 150 | value = node.value 151 | node.value = nil 152 | // Prune now-empty, non-terminal nodes bottom-up. Stop at the root: it has no entry in parent 153 | // (parent[t.root] is nil) and must never be removed, so deleting the last key must not reach it. 154 | for node != t.root && !node.isTerminal && len(node.children) == 0 { 155 | delete(parent[node].children, node.keyPart) 156 | parent[node].childrenDLL.pop(node.dllNode) 157 | node = parent[node] 158 | } 159 | return value, true 160 | } 161 | -------------------------------------------------------------------------------- /trie_test.go: -------------------------------------------------------------------------------- 1 | package trie_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/shivamMg/ppds/tree" 7 | "github.com/shivamMg/trie" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestNode_SetValue(t *testing.T) { 12 | tri := trie.New() 13 | tri.Put([]string{"a", "b"}, 1) 14 | tri.Put([]string{"a", "b", "c"}, 2) 15 | 16 | node := tri.Root() 17 | node.SetValue(10) 18 | assert.Equal(t, nil, node.Value()) 19 | 20 | node = tri.Root().ChildNodes()[0] 21 | node.SetValue(10) 22 | assert.Equal(t, nil, node.Value()) 23 | 24 | node = tri.Root().ChildNodes()[0].ChildNodes()[0] 25 | node.SetValue(10) 26 | assert.Equal(t, 10, node.Value()) 27 | } 28 | 29 | func TestTrie_Put(t *testing.T) { 30 | tri := trie.New() 31 | existed := tri.Put([]string{"an", "umbrella"}, 2) 32 | assert.False(t, existed) 33 | existed = tri.Put([]string{"the"}, 1) 34 | assert.False(t, existed) 35 | existed = tri.Put([]string{"the", "swimmer"}, 4) 36 | assert.False(t, existed) 37 | existed = tri.Put([]string{"the", "tree"}, 3) 38 | assert.False(t, existed) 39 | 40 | // validate full tree 41 | expected := `^ 42 | ├─ an 43 | │ └─ umbrella ($)
44 | └─ the ($) 45 | ├─ swimmer ($) 46 | └─ tree ($) 47 | ` 48 | actual := tree.SprintHrn(tri.Root()) 49 | assert.Equal(t, expected, actual) 50 | 51 | // validate attributes for each node 52 | root := tri.Root() 53 | assert.Equal(t, root.KeyPart(), trie.RootKeyPart) 54 | assert.False(t, root.IsTerminal()) 55 | assert.Nil(t, root.Value()) 56 | 57 | rootChildren := root.ChildNodes() 58 | an, the := rootChildren[0], rootChildren[1] 59 | assert.Equal(t, an.KeyPart(), "an") 60 | assert.False(t, an.IsTerminal()) 61 | assert.Nil(t, an.Value()) 62 | 63 | assert.Equal(t, the.KeyPart(), "the") 64 | assert.True(t, the.IsTerminal()) 65 | assert.Equal(t, the.Value(), 1) 66 | 67 | umbrella := an.ChildNodes()[0] 68 | assert.Equal(t, umbrella.KeyPart(), "umbrella") 69 | assert.True(t, umbrella.IsTerminal()) 70 | assert.Equal(t, umbrella.Value(), 2) 71 | 72 | theChildren := the.ChildNodes() 73 | swimmer, tree_ := theChildren[0], theChildren[1] 74 | assert.Equal(t, swimmer.KeyPart(), "swimmer") 75 | assert.True(t, swimmer.IsTerminal()) 76 | assert.Equal(t, swimmer.Value(), 4) 77 | 78 | assert.Equal(t, tree_.KeyPart(), "tree") 79 | assert.True(t, tree_.IsTerminal()) 80 | assert.Equal(t, tree_.Value(), 3) 81 | 82 | // validate update 83 | existed = tri.Put([]string{"an", "umbrella"}, 5) 84 | assert.True(t, existed) 85 | assert.Equal(t, umbrella.Value(), 5) 86 | } 87 | 88 | func TestTrie_Delete(t *testing.T) { 89 | tri := trie.New() 90 | tri.Put([]string{"an", "apple", "tree"}, 5) 91 | tri.Put([]string{"an", "umbrella"}, 6) 92 | tri.Put([]string{"the"}, 1) 93 | tri.Put([]string{"the", "green", "tree"}, 4) 94 | tri.Put([]string{"the", "quick", "brown", "fox"}, 2) 95 | tri.Put([]string{"the", "quick", "swimmer"}, 3) 96 | 97 | value, existed := tri.Delete([]string{"the", "quick", "brown", "fox"}) 98 | assert.True(t, existed) 99 | assert.Equal(t, value, 2) 100 | expected := `^ 101 | ├─ an 102 | │ ├─ apple 103 | │ │ └─ tree ($) 104 | │ └─ umbrella ($) 105 | └─ the ($) 106 | ├─ green 107 | │
└─ tree ($) 108 | └─ quick 109 | └─ swimmer ($) 110 | ` 111 | assert.Equal(t, expected, tri.Root().Sprint()) 112 | 113 | value, existed = tri.Delete([]string{"the", "quick", "swimmer"}) 114 | assert.True(t, existed) 115 | assert.Equal(t, value, 3) 116 | expected = `^ 117 | ├─ an 118 | │ ├─ apple 119 | │ │ └─ tree ($) 120 | │ └─ umbrella ($) 121 | └─ the ($) 122 | └─ green 123 | └─ tree ($) 124 | ` 125 | assert.Equal(t, expected, tri.Root().Sprint()) 126 | 127 | value, existed = tri.Delete([]string{"the"}) 128 | assert.True(t, existed) 129 | assert.Equal(t, value, 1) 130 | expected = `^ 131 | ├─ an 132 | │ ├─ apple 133 | │ │ └─ tree ($) 134 | │ └─ umbrella ($) 135 | └─ the 136 | └─ green 137 | └─ tree ($) 138 | ` 139 | assert.Equal(t, expected, tri.Root().Sprint()) 140 | 141 | value, existed = tri.Delete([]string{"non", "existing"}) 142 | assert.False(t, existed) 143 | assert.Nil(t, value) 144 | expected = `^ 145 | ├─ an 146 | │ ├─ apple 147 | │ │ └─ tree ($) 148 | │ └─ umbrella ($) 149 | └─ the 150 | └─ green 151 | └─ tree ($) 152 | ` 153 | assert.Equal(t, expected, tri.Root().Sprint()) 154 | 155 | value, existed = tri.Delete([]string{"an"}) 156 | assert.False(t, existed) 157 | assert.Nil(t, value) 158 | expected = `^ 159 | ├─ an 160 | │ ├─ apple 161 | │ │ └─ tree ($) 162 | │ └─ umbrella ($) 163 | └─ the 164 | └─ green 165 | └─ tree ($) 166 | ` 167 | assert.Equal(t, expected, tri.Root().Sprint()) 168 | } 169 | 170 | func TestTrie_Delete_LastKey(t *testing.T) { 171 | tri := trie.New() 172 | tri.Put([]string{"a", "b"}, 1) 173 | 174 | // Deleting the only key must prune everything below the root without panicking. 175 | value, existed := tri.Delete([]string{"a", "b"}) 176 | assert.True(t, existed) 177 | assert.Equal(t, 1, value) 178 | assert.Empty(t, tri.Root().ChildNodes()) 179 | 180 | value, existed = tri.Delete([]string{"a", "b"}) 181 | assert.False(t, existed) 182 | assert.Nil(t, value) 183 | } 184 | -------------------------------------------------------------------------------- /walk.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | type WalkFunc func(key []string, node *Node) error 4 | 5 | // Walk traverses the Trie and calls walker function. If walker function returns an error, Walk early-returns with that error. 6 | // Traversal follows insertion order.
7 | func (t *Trie) Walk(key []string, walker WalkFunc) error { 8 | node := t.root 9 | for _, keyPart := range key { 10 | child, ok := node.children[keyPart] 11 | if !ok { 12 | // No node exists for key, so there is nothing to visit. 13 | return nil 14 | } 15 | node = child 16 | } 17 | // walk appends to *prefixKey while descending. Copy key into a slice 18 | // owned by Walk so those appends never write into spare capacity of the 19 | // caller's backing array. 20 | prefix := make([]string, len(key)) 21 | copy(prefix, key) 22 | return t.walk(node, &prefix, walker) 23 | } 24 | 25 | // walk visits node and then its descendants depth-first, taking children in 26 | // childrenDLL order, and calls walker for every terminal node it visits. 27 | // *prefixKey must hold node's key; it is extended with each child's keyPart 28 | // before recursing and truncated again afterwards, so its length is unchanged 29 | // on return. walker receives its own copy of the key, which it may retain. 30 | func (t *Trie) walk(node *Node, prefixKey *[]string, walker WalkFunc) error { 31 | if node.isTerminal { 32 | key := make([]string, len(*prefixKey)) 33 | copy(key, *prefixKey) 34 | if err := walker(key, node); err != nil { 35 | return err 36 | } 37 | } 38 | 39 | for dllNode := node.childrenDLL.head; dllNode != nil; dllNode = dllNode.next { 40 | child := dllNode.trieNode 41 | *prefixKey = append(*prefixKey, child.keyPart) 42 | err := t.walk(child, prefixKey, walker) 43 | // Truncate before checking err so *prefixKey is restored on every return path. 44 | *prefixKey = (*prefixKey)[:len(*prefixKey)-1] 45 | if err != nil { 46 | return err 47 | } 48 | } 49 | return nil 50 | } 51 | -------------------------------------------------------------------------------- /walk_test.go: -------------------------------------------------------------------------------- 1 | package trie_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/shivamMg/trie" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestTrie_WalkErr(t *testing.T) { 12 | tri := trie.New() 13 | tri.Put([]string{"d", "a", "l", "i"}, 1) 14 | tri.Put([]string{"d", "a", "l", "i", "b"}, 2) 15 | tri.Put([]string{"d", "a", "l", "i", "b", "e"}, 3) 16 | tri.Put([]string{"d", "a", "l", "i", "b", "e", "r", "t"}, 4) 17 | 18 | var selected []string 19 | walker := func(key []string, node *trie.Node) error { 20 | what := node.Value().(int) 21 | if what == 3 { 22 | selected = key 23 | return errors.New("found") 24 | } 25 | return nil 26 | } 27 | 28 | err := tri.Walk(nil, walker) 29 | assert.EqualError(t, err, "found") 30 | assert.EqualValues(t, []string{"d", "a", "l", "i", "b", "e"}, selected) 31 | } 32 | 33 | func TestTrie_Walk(t *testing.T) { 34 | tri := trie.New() 35 | tri.Put([]string{"d", "a", "l", "i"}, []int{0, 1,
2, 4, 5}) 36 | tri.Put([]string{"d", "a", "l", "i", "b"}, []int{1, 2, 4, 5}) 37 | tri.Put([]string{"d", "a", "l", "i", "b", "e"}, []int{1, 0, 2, 4, 5, 0}) 38 | tri.Put([]string{"d", "a", "l", "i", "b", "e", "r", "t"}, []int{1, 2, 4, 5}) 39 | type KVPair struct { 40 | key []string 41 | value []int 42 | } 43 | var selected []KVPair 44 | walker := func(key []string, node *trie.Node) error { 45 | what := node.Value().([]int) 46 | for _, i := range what { 47 | if i == 0 { 48 | selected = append(selected, KVPair{key, what}) 49 | break 50 | } 51 | } 52 | return nil 53 | } 54 | 55 | err := tri.Walk(nil, walker) 56 | assert.NoError(t, err) 57 | expected := []KVPair{ 58 | {[]string{"d", "a", "l", "i"}, []int{0, 1, 2, 4, 5}}, 59 | {[]string{"d", "a", "l", "i", "b", "e"}, []int{1, 0, 2, 4, 5, 0}}, 60 | } 61 | assert.EqualValues(t, expected, selected) 62 | 63 | selected = nil 64 | err = tri.Walk([]string{"d", "a", "l", "i", "b"}, walker) 65 | assert.NoError(t, err) 66 | expected = []KVPair{ 67 | {[]string{"d", "a", "l", "i", "b", "e"}, []int{1, 0, 2, 4, 5, 0}}, 68 | } 69 | assert.EqualValues(t, expected, selected) 70 | 71 | selected = nil 72 | err = tri.Walk([]string{"a", "b", "c"}, walker) 73 | assert.NoError(t, err) 74 | assert.Nil(t, selected) 75 | } 76 | 77 | func TestTrie_WalkDoesNotModifyKey(t *testing.T) { 78 | tri := trie.New() 79 | tri.Put([]string{"a", "b"}, 1) 80 | 81 | // key has spare capacity holding "untouched"; Walk must not write into it. 82 | backing := []string{"a", "untouched"} 83 | key := backing[:1] 84 | var visited [][]string 85 | err := tri.Walk(key, func(k []string, _ *trie.Node) error { 86 | visited = append(visited, k) 87 | return nil 88 | }) 89 | assert.NoError(t, err) 90 | assert.Equal(t, [][]string{{"a", "b"}}, visited) 91 | assert.Equal(t, []string{"a", "untouched"}, backing) 92 | } 93 | --------------------------------------------------------------------------------