├── .devops └── build-nuget.yml ├── .gitignore ├── LICENSE ├── README.md ├── Src ├── HNSW.Net.Demo │ ├── App.config │ ├── HNSW.Net.Demo.csproj │ ├── MetricsEventListener.cs │ └── Program.cs ├── HNSW.Net.TestFiltering │ ├── HNSW.Net.TestFiltering.csproj │ └── Program.cs ├── HNSW.Net.Tests │ ├── BinaryHeapTests.cs │ ├── HNSW.Net.Tests.csproj │ ├── RewindableRandomNumberGenerator.cs │ ├── SmallWorldTests.cs │ └── vectors.txt ├── HNSW.Net.sln └── HNSW.Net │ ├── Algorithms.Algorithm3.cs │ ├── Algorithms.Algorithm4.cs │ ├── Algorithms.cs │ ├── BinaryHeap.cs │ ├── CosineDistance.cs │ ├── DefaultRandomGenerator.cs │ ├── DistanceCache.cs │ ├── DistanceUtils.cs │ ├── EventSources.cs │ ├── FastRandom.cs │ ├── Graph.Core.cs │ ├── Graph.Searcher.cs │ ├── Graph.Utils.cs │ ├── Graph.cs │ ├── GraphChangedException.cs │ ├── HNSW.Net.csproj │ ├── IProgressReporter.cs │ ├── IProvideRandomValues.cs │ ├── MessagePackCompat │ ├── FloatBits.cs │ ├── MessagePackBinary.cs │ ├── README.md │ └── StringEncoding.cs │ ├── Node.cs │ ├── Properties │ └── AssemblyInfo.cs │ ├── ReverseComparer.cs │ ├── ScopeLatencyTracker.cs │ ├── SmallWorld.cs │ ├── ThreadSafeFastRandom.cs │ ├── TravelingCosts.cs │ └── VectorUtils.cs └── src └── HNSW.Net └── NeighbourSelectionHeuristic.cs /.devops/build-nuget.yml: -------------------------------------------------------------------------------- 1 | variables: 2 | project: './Src/HNSW.Net/HNSW.Net.csproj' 3 | buildConfiguration: 'Release' 4 | targetVersion: yy.M.$(build.buildId) 5 | 6 | trigger: 7 | - master 8 | 9 | pool: 10 | vmImage: 'windows-latest' 11 | 12 | steps: 13 | - task: NuGetToolInstaller@0 14 | 15 | 16 | - task: PowerShell@2 17 | displayName: 'Create CalVer Version' 18 | inputs: 19 | targetType: 'inline' 20 | script: | 21 | $dottedDate = (Get-Date).ToString("yy.M") 22 | $buildID = $($env:BUILD_BUILDID) 23 | $newTargetVersion = "$dottedDate.$buildID" 24 | Write-Host "##vso[task.setvariable variable=targetVersion;]$newTargetVersion" 25 | Write-Host 
"Updated targetVersion to '$newTargetVersion'" 26 | 27 | 28 | - task: UseDotNet@2 29 | displayName: 'Use .NET 8.0 SDK' 30 | inputs: 31 | packageType: sdk 32 | version: 8.x 33 | includePreviewVersions: false 34 | installationPath: $(Agent.ToolsDirectory)\dotnet 35 | 36 | - task: DotNetCoreCLI@2 37 | inputs: 38 | command: 'restore' 39 | projects: '$(project)' 40 | displayName: 'restore nuget' 41 | 42 | - task: DotNetCoreCLI@2 43 | inputs: 44 | command: 'build' 45 | projects: '$(project)' 46 | arguments: '-c $(buildConfiguration) /p:Version=$(targetVersion) /p:LangVersion=latest' 47 | 48 | - task: DotNetCoreCLI@2 49 | inputs: 50 | command: 'pack' 51 | packagesToPack: '$(project)' 52 | versioningScheme: 'off' 53 | configuration: '$(buildConfiguration)' 54 | buildProperties: 'Version="$(targetVersion)";LangVersion="latest"' 55 | nobuild: true 56 | 57 | - task: NuGetCommand@2 58 | inputs: 59 | command: 'push' 60 | packagesToPush: '**/*.nupkg' 61 | nuGetFeedType: 'external' 62 | publishFeedCredentials: 'nuget-curiosity-org' 63 | displayName: 'push nuget' 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 
3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.suo 8 | *.user 9 | *.userosscache 10 | *.sln.docstates 11 | 12 | # User-specific files (MonoDevelop/Xamarin Studio) 13 | *.userprefs 14 | 15 | # Build results 16 | [Dd]ebug/ 17 | [Dd]ebugPublic/ 18 | [Rr]elease/ 19 | [Rr]eleases/ 20 | x64/ 21 | x86/ 22 | bld/ 23 | [Bb]in/ 24 | [Oo]bj/ 25 | [Ll]og/ 26 | 27 | # Visual Studio 2015/2017 cache/options directory 28 | .vs/ 29 | # Uncomment if you have tasks that create the project's static files in wwwroot 30 | #wwwroot/ 31 | 32 | # Visual Studio 2017 auto generated files 33 | Generated\ Files/ 34 | 35 | # MSTest test Results 36 | [Tt]est[Rr]esult*/ 37 | [Bb]uild[Ll]og.* 38 | 39 | # NUNIT 40 | *.VisualState.xml 41 | TestResult.xml 42 | 43 | # Build Results of an ATL Project 44 | [Dd]ebugPS/ 45 | [Rr]eleasePS/ 46 | dlldata.c 47 | 48 | # Benchmark Results 49 | BenchmarkDotNet.Artifacts/ 50 | 51 | # .NET Core 52 | project.lock.json 53 | project.fragment.lock.json 54 | artifacts/ 55 | **/Properties/launchSettings.json 56 | 57 | # StyleCop 58 | StyleCopReport.xml 59 | 60 | # Files built by Visual Studio 61 | *_i.c 62 | *_p.c 63 | *_i.h 64 | *.ilk 65 | *.meta 66 | *.obj 67 | *.iobj 68 | *.pch 69 | *.pdb 70 | *.ipdb 71 | *.pgc 72 | *.pgd 73 | *.rsp 74 | *.sbr 75 | *.tlb 76 | *.tli 77 | *.tlh 78 | *.tmp 79 | *.tmp_proj 80 | *.log 81 | *.vspscc 82 | *.vssscc 83 | .builds 84 | *.pidb 85 | *.svclog 86 | *.scc 87 | 88 | # Chutzpah Test files 89 | _Chutzpah* 90 | 91 | # Visual C++ cache files 92 | ipch/ 93 | *.aps 94 | *.ncb 95 | *.opendb 96 | *.opensdf 97 | *.sdf 98 | *.cachefile 99 | *.VC.db 100 | *.VC.VC.opendb 101 | 102 | # Visual Studio profiler 103 | *.psess 104 | *.vsp 105 | *.vspx 106 | *.sap 107 | 108 | # Visual Studio Trace Files 109 | *.e2e 110 | 111 | # TFS 2012 Local Workspace 112 | $tf/ 113 | 114 | # Guidance Automation Toolkit 115 | *.gpState 116 | 117 | # ReSharper is a .NET coding 
add-in 118 | _ReSharper*/ 119 | *.[Rr]e[Ss]harper 120 | *.DotSettings.user 121 | 122 | # JustCode is a .NET coding add-in 123 | .JustCode 124 | 125 | # TeamCity is a build add-in 126 | _TeamCity* 127 | 128 | # DotCover is a Code Coverage Tool 129 | *.dotCover 130 | 131 | # AxoCover is a Code Coverage Tool 132 | .axoCover/* 133 | !.axoCover/settings.json 134 | 135 | # Visual Studio code coverage results 136 | *.coverage 137 | *.coveragexml 138 | 139 | # NCrunch 140 | _NCrunch_* 141 | .*crunch*.local.xml 142 | nCrunchTemp_* 143 | 144 | # MightyMoose 145 | *.mm.* 146 | AutoTest.Net/ 147 | 148 | # Web workbench (sass) 149 | .sass-cache/ 150 | 151 | # Installshield output folder 152 | [Ee]xpress/ 153 | 154 | # DocProject is a documentation generator add-in 155 | DocProject/buildhelp/ 156 | DocProject/Help/*.HxT 157 | DocProject/Help/*.HxC 158 | DocProject/Help/*.hhc 159 | DocProject/Help/*.hhk 160 | DocProject/Help/*.hhp 161 | DocProject/Help/Html2 162 | DocProject/Help/html 163 | 164 | # Click-Once directory 165 | publish/ 166 | 167 | # Publish Web Output 168 | *.[Pp]ublish.xml 169 | *.azurePubxml 170 | # Note: Comment the next line if you want to checkin your web deploy settings, 171 | # but database connection strings (with potential passwords) will be unencrypted 172 | *.pubxml 173 | *.publishproj 174 | 175 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 176 | # checkin your Azure Web App publish settings, but sensitive information contained 177 | # in these scripts will be unencrypted 178 | PublishScripts/ 179 | 180 | # NuGet Packages 181 | *.nupkg 182 | # The packages folder can be ignored because of Package Restore 183 | **/[Pp]ackages/* 184 | # except build/, which is used as an MSBuild target. 
185 | !**/[Pp]ackages/build/ 186 | # Uncomment if necessary however generally it will be regenerated when needed 187 | #!**/[Pp]ackages/repositories.config 188 | # NuGet v3's project.json files produces more ignorable files 189 | *.nuget.props 190 | *.nuget.targets 191 | 192 | # Microsoft Azure Build Output 193 | csx/ 194 | *.build.csdef 195 | 196 | # Microsoft Azure Emulator 197 | ecf/ 198 | rcf/ 199 | 200 | # Windows Store app package directories and files 201 | AppPackages/ 202 | BundleArtifacts/ 203 | Package.StoreAssociation.xml 204 | _pkginfo.txt 205 | *.appx 206 | 207 | # Visual Studio cache files 208 | # files ending in .cache can be ignored 209 | *.[Cc]ache 210 | # but keep track of directories ending in .cache 211 | !*.[Cc]ache/ 212 | 213 | # Others 214 | ClientBin/ 215 | ~$* 216 | *~ 217 | *.dbmdl 218 | *.dbproj.schemaview 219 | *.jfm 220 | *.pfx 221 | *.publishsettings 222 | orleans.codegen.cs 223 | 224 | # Including strong name files can present a security risk 225 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 226 | #*.snk 227 | 228 | # Since there are multiple workflows, uncomment next line to ignore bower_components 229 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 230 | #bower_components/ 231 | 232 | # RIA/Silverlight projects 233 | Generated_Code/ 234 | 235 | # Backup & report files from converting an old project file 236 | # to a newer Visual Studio version. 
Backup files are not needed, 237 | # because we have git ;-) 238 | _UpgradeReport_Files/ 239 | Backup*/ 240 | UpgradeLog*.XML 241 | UpgradeLog*.htm 242 | ServiceFabricBackup/ 243 | *.rptproj.bak 244 | 245 | # SQL Server files 246 | *.mdf 247 | *.ldf 248 | *.ndf 249 | 250 | # Business Intelligence projects 251 | *.rdl.data 252 | *.bim.layout 253 | *.bim_*.settings 254 | *.rptproj.rsuser 255 | 256 | # Microsoft Fakes 257 | FakesAssemblies/ 258 | 259 | # GhostDoc plugin setting file 260 | *.GhostDoc.xml 261 | 262 | # Node.js Tools for Visual Studio 263 | .ntvs_analysis.dat 264 | node_modules/ 265 | 266 | # Visual Studio 6 build log 267 | *.plg 268 | 269 | # Visual Studio 6 workspace options file 270 | *.opt 271 | 272 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 273 | *.vbw 274 | 275 | # Visual Studio LightSwitch build output 276 | **/*.HTMLClient/GeneratedArtifacts 277 | **/*.DesktopClient/GeneratedArtifacts 278 | **/*.DesktopClient/ModelManifest.xml 279 | **/*.Server/GeneratedArtifacts 280 | **/*.Server/ModelManifest.xml 281 | _Pvt_Extensions 282 | 283 | # Paket dependency manager 284 | .paket/paket.exe 285 | paket-files/ 286 | 287 | # FAKE - F# Make 288 | .fake/ 289 | 290 | # JetBrains Rider 291 | .idea/ 292 | *.sln.iml 293 | 294 | # CodeRush 295 | .cr/ 296 | 297 | # Python Tools for Visual Studio (PTVS) 298 | __pycache__/ 299 | *.pyc 300 | 301 | # Cake - Uncomment if you are using it 302 | # tools/** 303 | # !tools/packages.config 304 | 305 | # Tabs Studio 306 | *.tss 307 | 308 | # Telerik's JustMock configuration file 309 | *.jmconfig 310 | 311 | # BizTalk build output 312 | *.btp.cs 313 | *.btm.cs 314 | *.odx.cs 315 | *.xsd.cs 316 | 317 | # OpenCover UI analysis results 318 | OpenCover/ 319 | 320 | # Azure Stream Analytics local run output 321 | ASALocalRun/ 322 | 323 | # MSBuild Binary and Structured Log 324 | *.binlog 325 | 326 | # NVidia Nsight GPU debugger configuration file 327 | *.nvuser 328 | 329 | # MFractors 
(Xamarin productivity tool) working folder 330 | .mfractor/ 331 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://dev.azure.com/curiosity-ai/mosaik/_apis/build/status/hnsw-sharp?branchName=master)](https://dev.azure.com/curiosity-ai/mosaik/_build/latest?definitionId=7&branchName=master) 2 | 3 | 4 | 5 | 6 | # HNSW.Net 7 | .Net library for fast approximate nearest neighbours search. 
8 | 9 | Exact _k_ nearest neighbours search algorithms tend to perform poorly in high-dimensional spaces. To overcome the curse of dimensionality, approximate nearest neighbours (ANN) algorithms are used instead. This library implements one such algorithm, described in the ["Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs"](https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf) article. It provides a simple API for building nearest neighbours graphs, (de)serializing them and running k-NN search queries. 10 | 11 | ## Usage 12 | Check out the following code snippets once you've added the library reference to your project. 13 | ##### How to build a graph? 14 | ```c# 15 | var parameters = new SmallWorld<float[], float>.Parameters() 16 | { 17 | M = 15, 18 | LevelLambda = 1 / Math.Log(15), 19 | }; 20 | 21 | float[] vectors = GetFloatVectors(); 22 | var graph = new SmallWorld<float[], float>(CosineDistance.NonOptimized); 23 | graph.BuildGraph(vectors, new Random(42), parameters); 24 | ``` 25 | ##### How to run k-NN search? 26 | ```c# 27 | SmallWorld<float[], float> graph = GetGraph(); 28 | 29 | float[] query = Enumerable.Repeat(1f, 100).ToArray(); 30 | var best20 = graph.KNNSearch(query, 20); 31 | var best1 = best20.OrderBy(r => r.Distance).First(); 32 | ``` 33 | ##### How to (de)serialize the graph? 34 | ```c# 35 | SmallWorld<float[], float> graph = GetGraph(); 36 | byte[] buffer = graph.SerializeGraph(); // buffer stores information about parameters and graph edges 37 | 38 | // distance function must be the same as the one which was used for building the original graph 39 | var copy = new SmallWorld<float[], float>(CosineDistance.NonOptimized); 40 | copy.DeserializeGraph(vectors, buffer); // the original vectors to attach to the "copy" vertices 41 | ``` 42 | ##### Distance functions 43 | The only distance function supplied by the library is cosine distance, but it comes in 4 versions that address the universality/performance tradeoff.
44 | ```c# 45 | CosineDistance.NonOptimized // most generic version works for all cases 46 | CosineDistance.ForUnits // gives correct result only when arguments are "unit" vectors 47 | CosineDistance.SIMD // uses SIMD instructions to optimize calculations 48 | CosineDistance.SIMDForUnits // uses SIMD and requires arguments to be "units" 49 | ``` 50 | However, the API allows you to inject any custom distance function tailored to your specific needs. 51 | 52 | ## Contributing 53 | Your contributions and suggestions are very welcome! 54 | Please note that this project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 55 | 56 | The contributions to this project are [released](https://help.github.com/articles/github-terms-of-service/#6-contributions-under-repository-license) to the public under the [project's open source license](LICENSE). Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.microsoft.com. 57 | 58 | ### How to contribute 59 | If you've found a bug or have a feature request, please open an issue with a detailed description. 60 | We will be glad to see your pull requests as well. 61 | 62 | 1. Prepare the workspace. 63 | ``` 64 | git clone https://github.com/Microsoft/HNSW.Net.git 65 | cd HNSW.Net 66 | git checkout -b [username]/[feature] 67 | ``` 68 | 2. Update the library and add tests if needed. 69 | 3. Build and test the changes. 70 | ``` 71 | cd Src 72 | dotnet build 73 | dotnet test 74 | ``` 75 | 4. Send the pull request from `[username]/[feature]` to the `master` branch. 76 | 5. Get approval and merge the changes.
77 | 78 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repositories using our CLA. 79 | 80 | ### Releasing 81 | The library is distributed as a bundle of sources. 82 | We are working on enabling CI and creating Nuget package for the project. 83 | 84 | -------------------------------------------------------------------------------- /Src/HNSW.Net.Demo/App.config: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /Src/HNSW.Net.Demo/HNSW.Net.Demo.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | net8.0 5 | Exe 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /Src/HNSW.Net.Demo/MetricsEventListener.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License. 
4 | // 5 | 6 | namespace HNSW.Net.Demo; 7 | 8 | using System; 9 | using System.Collections.Generic; 10 | using System.Diagnostics.Tracing; 11 | using System.Linq; 12 | /// <summary>Prints the counter snapshots published by one <see cref="EventSource"/> (name, mean, standard deviation, count) to the console, requesting a snapshot every second.</summary> 13 | public class MetricsEventListener : EventListener 14 | { 15 | private readonly EventSource eventSource; 16 | 17 | public MetricsEventListener(EventSource eventSource) 18 | { 19 | this.eventSource = eventSource; 20 | EnableEvents(this.eventSource, EventLevel.LogAlways, EventKeywords.All, new Dictionary<string, string> { { "EventCounterIntervalSec", "1" } }); 21 | } 22 | 23 | public override void Dispose() 24 | { 25 | DisableEvents(eventSource); 26 | base.Dispose(); 27 | } 28 | 29 | protected override void OnEventWritten(EventWrittenEventArgs eventData) 30 | { 31 | // Counter snapshots carry a dictionary as their first payload item; skip any other event (or an empty snapshot) instead of dereferencing a null dictionary below. 32 | if (eventData.Payload?.FirstOrDefault() is not IDictionary<string, object> counterData || counterData.Count == 0) 33 | { 34 | return; 35 | } 36 | 37 | Console.WriteLine($"[{counterData["Name"]}]: Avg={counterData["Mean"]:n1}; SD={counterData["StandardDeviation"]:n1}; Count={counterData["Count"]}"); 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /Src/HNSW.Net.Demo/Program.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License.
4 | // 5 | 6 | namespace HNSW.Net.Demo 7 | { 8 | using System; 9 | using System.Collections.Generic; 10 | using System.Diagnostics; 11 | using System.IO; 12 | using System.Linq; 13 | using System.Numerics; 14 | using System.Runtime.CompilerServices; 15 | using System.Runtime.Intrinsics; 16 | using System.Runtime.Serialization.Formatters.Binary; 17 | using System.Threading; 18 | using System.Threading.Tasks; 19 | 20 | using Parameters = SmallWorld<float[], float>.Parameters; 21 | /// <summary>Demo driver: concurrent add/search against a live graph, then build, save, reload and search a graph of normalized random vectors.</summary> 22 | public static partial class Program 23 | { 24 | private const int SampleSize = 500; 25 | private const int SampleIncrSize = 100; 26 | private const int TestSize = 10 * SampleSize; 27 | private const int Dimensionality = 128; 28 | private const string VectorsPathSuffix = "vectors.hnsw"; 29 | private const string GraphPathSuffix = "graph.hnsw"; 30 | 31 | public static async Task Main() 32 | { 33 | await MultithreadAddAndReadAsync(); 34 | BuildAndSave("random"); 35 | LoadAndSearch("random"); 36 | } 37 | 38 | private static async Task MultithreadAddAndReadAsync() 39 | { 40 | var world = new SmallWorld<float[], float>(CosineDistance.SIMDForUnits, DefaultRandomGenerator.Instance, new Parameters() { EnableDistanceCacheForConstruction = true, InitialDistanceCacheSize = SampleSize, NeighbourHeuristic = NeighbourSelectionHeuristic.SelectHeuristic, KeepPrunedConnections = true, ExpandBestSelection = true}, threadSafe : false); 41 | // NOTE(review): items are added and searched concurrently below while threadSafe is false - confirm this is the intended configuration. 42 | using var cts = new CancellationTokenSource(); 43 | 44 | var taskAdd = Task.Run(() => 45 | { 46 | while (!cts.IsCancellationRequested) 47 | { 48 | Console.Write($"Generating {SampleSize} sample vectors... "); 49 | var clock = Stopwatch.StartNew(); 50 | var sampleVectors = RandomVectors(Dimensionality, SampleSize); 51 | Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms."); 52 |
53 | Console.WriteLine("Building HNSW graph... "); 54 | 55 | using (var listener = new MetricsEventListener(EventSources.GraphBuildEventSource.Instance)) 56 | { 57 | clock = Stopwatch.StartNew(); 58 | for (int i = 0; i < (SampleSize / SampleIncrSize); i++) 59 | { 60 | world.AddItems(sampleVectors.Skip(i * SampleIncrSize).Take(SampleIncrSize).ToArray()); 61 | Console.WriteLine($"\nAt {i + 1} of {SampleSize / SampleIncrSize} Elapsed: {clock.ElapsedMilliseconds} ms.\n"); 62 | } 63 | Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms."); 64 | } 65 | } 66 | }); 67 | 68 | var taskSearch = Task.Run(async () => 69 | { 70 | while (!cts.IsCancellationRequested) 71 | { 72 | try 73 | { 74 | var searchVectors = RandomVectors(Dimensionality, SampleSize); 75 | Console.WriteLine("Running search agains the graph... "); 76 | using (var listener = new MetricsEventListener(EventSources.GraphSearchEventSource.Instance)) 77 | { 78 | var clock = Stopwatch.StartNew(); 79 | await Parallel.ForEachAsync(searchVectors, (vector, ct) => 80 | { 81 | world.KNNSearch(vector, 10); 82 | Console.Write('.'); 83 | return default; 84 | }); 85 | Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms."); 86 | } 87 | } 88 | catch 89 | { 90 | cts.Cancel(); throw; // also stop the add loop, so Task.WhenAll surfaces this failure instead of waiting out the timeout 91 | } 92 | } 93 | }); 94 | 95 | 96 | cts.CancelAfter(TimeSpan.FromMinutes(15)); 97 | 98 | await Task.WhenAll(taskAdd, taskSearch); 99 | } 100 | 101 | private static void BuildAndSave(string pathPrefix) 102 | { 103 | var world = new SmallWorld<float[], float>(CosineDistance.SIMDForUnits, DefaultRandomGenerator.Instance, new Parameters() { EnableDistanceCacheForConstruction = true, InitialDistanceCacheSize = SampleSize, NeighbourHeuristic = NeighbourSelectionHeuristic.SelectHeuristic, KeepPrunedConnections = true, ExpandBestSelection = true}); 104 |
105 | Console.Write($"Generating {SampleSize} sample vectors... "); 106 | var clock = Stopwatch.StartNew(); 107 | var sampleVectors = RandomVectors(Dimensionality, SampleSize); 108 | Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms."); 109 | 110 | Console.WriteLine("Building HNSW graph... "); 111 | using (var listener = new MetricsEventListener(EventSources.GraphBuildEventSource.Instance)) 112 | { 113 | clock = Stopwatch.StartNew(); 114 | for(int i = 0; i < (SampleSize / SampleIncrSize); i++) 115 | { 116 | world.AddItems(sampleVectors.Skip(i * SampleIncrSize).Take(SampleIncrSize).ToArray()); 117 | Console.WriteLine($"\nAt {i+1} of {SampleSize / SampleIncrSize} Elapsed: {clock.ElapsedMilliseconds} ms.\n"); 118 | } 119 | Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms."); 120 | } 121 | 122 | Console.Write($"Saving HNSW graph to '{Path.Combine(Directory.GetCurrentDirectory(), pathPrefix)}'... "); 123 | clock = Stopwatch.StartNew(); 124 | SaveVectors($"{pathPrefix}.{VectorsPathSuffix}", sampleVectors); // BinaryFormatter is obsolete and disabled by default from .NET 8 on, so the vectors use a plain binary layout 125 | 126 | using (var f = File.Open($"{pathPrefix}.{GraphPathSuffix}", FileMode.Create)) 127 | { 128 | world.SerializeGraph(f); 129 | } 130 | 131 | Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms."); 132 | } 133 | 134 | private static void LoadAndSearch(string pathPrefix) 135 | { 136 | Stopwatch clock; 137 |
138 | Console.Write("Loading HNSW graph... "); 139 | clock = Stopwatch.StartNew(); 140 | var sampleVectors = LoadVectors($"{pathPrefix}.{VectorsPathSuffix}"); 141 | SmallWorld<float[], float> world; 142 | using (var f = File.OpenRead($"{pathPrefix}.{GraphPathSuffix}")) 143 | { 144 | (world, _) = SmallWorld<float[], float>.DeserializeGraph(sampleVectors, CosineDistance.SIMDForUnits, DefaultRandomGenerator.Instance, f); 145 | } 146 | Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms."); 147 | 148 | Console.Write($"Generating {TestSize} test vectos... "); 149 | clock = Stopwatch.StartNew(); 150 | var vectors = RandomVectors(Dimensionality, TestSize); 151 | Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms."); 152 | 153 | Console.WriteLine("Running search agains the graph... "); 154 | using (var listener = new MetricsEventListener(EventSources.GraphSearchEventSource.Instance)) 155 | { 156 | clock = Stopwatch.StartNew(); 157 | Parallel.ForEach(vectors, (vector) => 158 | { 159 | world.KNNSearch(vector, 10); 160 | }); 161 | Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms."); 162 | } 163 | } 164 | 165 | private static List<float[]> RandomVectors(int vectorSize, int vectorsCount) 166 | { 167 | var vectors = new List<float[]>(); 168 | 169 | for (int i = 0; i < vectorsCount; i++) 170 | { 171 | var vector = new float[vectorSize]; 172 | DefaultRandomGenerator.Instance.NextFloats(vector); 173 | VectorUtils.NormalizeSIMD(vector); 174 | vectors.Add(vector); 175 | } 176 | 177 | return vectors; 178 | } 179 | 180 | /// <summary>Writes the vectors as an Int32 count followed, per vector, by its Int32 length and its Single components.</summary> 181 | private static void SaveVectors(string path, IReadOnlyList<float[]> vectors) 182 | { 183 | using var writer = new BinaryWriter(File.Create(path)); 184 | writer.Write(vectors.Count); 185 | foreach (var vector in vectors) 186 | { 187 | writer.Write(vector.Length); 188 | foreach (var value in vector) 189 | { 190 | writer.Write(value); 191 | } 192 | } 193 | } 194 | 195 | /// <summary>Reads vectors in the layout produced by <see cref="SaveVectors"/>.</summary> 196 | private static List<float[]> LoadVectors(string path) 197 | { 198 | using var reader = new BinaryReader(File.OpenRead(path)); 199 | var count = reader.ReadInt32(); 200 | var vectors = new List<float[]>(count); 201 | for (int i = 0; i < count; i++) 202 | { 203 | var vector = new float[reader.ReadInt32()]; 204 | for (int j = 0; j < vector.Length; j++) 205 | { 206 | vector[j] = reader.ReadSingle(); 207 | } 208 | vectors.Add(vector); 209 | } 210 | 211 | return vectors; 212 | } 213 | } 214 | } 215 | -------------------------------------------------------------------------------- /Src/HNSW.Net.TestFiltering/HNSW.Net.TestFiltering.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Exe 5 | net8.0 6 | enable 7 | 8 | 9 | 10
| 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /Src/HNSW.Net.TestFiltering/Program.cs: -------------------------------------------------------------------------------- 1 | using HNSW.Net; 2 | using MessagePack; 3 | using System.Diagnostics; 4 | // Builds (or loads a cached copy of) an HNSW graph over normalized random vectors, then times KNNSearch calls whose filter excludes 0%, 10%, ..., 100% of the items. 5 | int SampleSize = 100_000; 6 | int SampleIncrSize = 100; 7 | int Dimensions = 128; 8 | 9 | var world = new SmallWorld<VectorID, float>(VectorID.Distance, DefaultRandomGenerator.Instance, new () { EnableDistanceCacheForConstruction = true, InitialDistanceCacheSize = SampleSize, NeighbourHeuristic = NeighbourSelectionHeuristic.SelectHeuristic, KeepPrunedConnections = true, ExpandBestSelection = true }, threadSafe: false); 10 | 11 | Console.WriteLine($"Creating {SampleSize:n0} vectors"); 12 | List<VectorID> vectors; 13 | var fileName = Path.Combine(Directory.GetCurrentDirectory(), $"data-{Dimensions}-{SampleSize}.bin"); 14 | if (File.Exists(fileName)) 15 | { 16 | Console.WriteLine("Loading HNSW graph... "); 17 | using (var f = File.OpenRead(Path.ChangeExtension(fileName, ".vec"))) 18 | { 19 | vectors = MessagePackSerializer.Deserialize<List<VectorID>>(f); 20 | } 21 | 22 | using (var f = File.OpenRead(fileName)) 23 | { 24 | (world, _) = SmallWorld<VectorID, float>.DeserializeGraph(vectors, VectorID.Distance, DefaultRandomGenerator.Instance, f, false); 25 | } 26 | } 27 | else 28 | { 29 | vectors = RandomVectors(Dimensions, SampleSize);
30 | Console.WriteLine("Building HNSW graph... "); 31 | //using (var listener = new HNSW.Net.Demo.MetricsEventListener(EventSources.GraphBuildEventSource.Instance)) 32 | { 33 | var clock = Stopwatch.StartNew(); 34 | for (int i = 0; i < (SampleSize / SampleIncrSize); i++) 35 | { 36 | clock.Restart(); 37 | world.AddItems(vectors.Skip(i * SampleIncrSize).Take(SampleIncrSize).ToArray()); 38 | Console.WriteLine($"Indexing vectors, at {(i + 1) * SampleIncrSize:n0} of {SampleSize:n0}, at {(double)clock.ElapsedMilliseconds / SampleIncrSize:n1}ms/vector"); 39 | } 40 | Console.WriteLine("Done building HNSW graph"); 41 | } 42 | // File.Create truncates an existing file; File.OpenWrite would leave stale trailing bytes behind if an older, larger file is present. 43 | using (var f = File.Create(Path.ChangeExtension(fileName, ".vec"))) 44 | { 45 | MessagePackSerializer.Serialize(f, vectors); 46 | } 47 | 48 | using (var f = File.Create(fileName)) 49 | { 50 | world.SerializeGraph(f); 51 | } 52 | } 53 | 54 | var times = new TimeSpan[11]; 55 | int sample = 100; 56 | 57 | for (int repeat = 0; repeat < 2; repeat++) 58 | { 59 | Console.WriteLine($"----------------------------\nRun: {repeat}\n----------------------------"); 60 | for (int p = 0; p <= 10; p++) 61 | {
62 | Console.Write($"Testing {p * 10}% out... "); 63 | var sw = Stopwatch.StartNew(); 64 | foreach (var v in vectors.Take(sample)) 65 | { 66 | using (var cts = new CancellationTokenSource()) 67 | { 68 | cts.CancelAfter(TimeSpan.FromMilliseconds(1)); 69 | var results = world.KNNSearch(v, 50, candidate => candidate.ID % 100 < (100 - p * 10), cts.Token); 70 | } 71 | } 72 | sw.Stop(); 73 | times[p] = sw.Elapsed; 74 | Console.WriteLine($"{sw.Elapsed.TotalMilliseconds / sample:n2}ms / call"); 75 | } 76 | 77 | Console.WriteLine(); 78 | } 79 | Console.WriteLine($"----------------------------\nResults\n----------------------------"); 80 | Console.WriteLine(string.Join("\n", times.Select((t, i) => $"Exclude {i*10}%\t{t.TotalMilliseconds / sample:n2}ms / call"))); 81 | 82 | 83 | List<VectorID> RandomVectors(int vectorSize, int vectorsCount) 84 | { 85 | var generated = new List<VectorID>(); 86 | 87 | for (int i = 0; i < vectorsCount; i++) 88 | { 89 | var vector = new float[vectorSize]; 90 | DefaultRandomGenerator.Instance.NextFloats(vector); 91 | VectorUtils.NormalizeSIMD(vector); 92 | generated.Add(new VectorID(vector, i)); 93 | } 94 | 95 | return generated; 96 | } 97 | // MessagePack-serializable item: the vector plus the integer ID that the search filter above inspects. 98 | [MessagePackObject] 99 | public struct VectorID 100 | { 101 | [Key(0)] public float[] Vector; 102 | [Key(1)] public int ID; 103 | 104 | public VectorID(float[] vector, int iD) 105 | { 106 | Vector = vector; 107 | ID = iD; 108 | } 109 | 110 | internal static float Distance(VectorID a, VectorID b) 111 | { 112 | return CosineDistance.SIMDForUnits(a.Vector, b.Vector); 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /Src/HNSW.Net.Tests/BinaryHeapTests.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License.
4 | // 5 | 6 | namespace HNSW.Net.Tests 7 | { 8 | using System.Linq; 9 | using Microsoft.VisualStudio.TestTools.UnitTesting; 10 | 11 | /// <summary> 12 | /// Tests for <see cref="BinaryHeap{T}"/> 13 | /// </summary> 14 | [TestClass] 15 | public class BinaryHeapTests 16 | { 17 | /// <summary> 18 | /// Tests heap construction. 19 | /// </summary> 20 | [TestMethod] 21 | public void HeapifyTest() 22 | { 23 | // Basic tests 24 | { 25 | var heap = new BinaryHeap<int>(Enumerable.Empty<int>().ToList()); 26 | Assert.IsFalse(heap.Buffer.Any()); 27 | 28 | heap = new BinaryHeap<int>(Enumerable.Range(1, 1).ToList()); 29 | Assert.AreEqual(1, heap.Buffer.Count); 30 | Assert.AreEqual(1, heap.Buffer.First()); 31 | 32 | heap = new BinaryHeap<int>(Enumerable.Range(1, 2).ToList()); 33 | Assert.AreEqual(2, heap.Buffer.Count); 34 | Assert.AreEqual(2, heap.Buffer.First()); 35 | } 36 | 37 | // Heapify produces correct heap. 38 | { 39 | const string input = "Hello, World!"; 40 | var heap = new BinaryHeap<char>(input.ToList()); 41 | AssertMaxHeap(heap); 42 | } 43 | } 44 | 45 | /// <summary> 46 | /// Tests <see cref="BinaryHeap{T}.Push(T)"/> and <see cref="BinaryHeap{T}.Pop"/> 47 | /// </summary> 48 | [TestMethod] 49 | public void PushPopTest() 50 | { 51 | var heap = new BinaryHeap<int>(Enumerable.Empty<int>().ToList()); 52 | for (int i = 0; i < 10; ++i) 53 | { 54 | heap.Push(i); 55 | } 56 | 57 | AssertMaxHeap(heap); 58 | 59 | int top = heap.Buffer.First(); 60 | while (heap.Buffer.Any()) 61 | { 62 | Assert.AreEqual(top, heap.Pop()); 63 | top = heap.Buffer.FirstOrDefault(); 64 | } 65 | } 66 | /// <summary>Asserts the max-heap property: each parent at index p compares greater than or equal to its children at 2p + 1 and 2p + 2.</summary> 67 | private static void AssertMaxHeap<T>(BinaryHeap<T> heap) 68 | { 69 | for (int p = 0; p < heap.Buffer.Count; ++p) 70 | { 71 | int l = (2 * p) + 1; 72 | int r = l + 1; 73 | 74 | var parent = heap.Buffer[p]; 75 | if (l < heap.Buffer.Count) 76 | { 77 | var left = heap.Buffer[l]; 78 | Assert.IsTrue(heap.Comparer.Compare(parent, left) >= 0); 79 | } 80 | 81 | if (r < heap.Buffer.Count) 82 | { 83 | var right = heap.Buffer[r]; 84 | Assert.IsTrue(heap.Comparer.Compare(parent, right) >= 0); 85 | } 86 | } 87 | } 88 | } 89 | } --------------------------------------------------------------------------------
/Src/HNSW.Net.Tests/HNSW.Net.Tests.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | net8.0 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | PreserveNewest 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /Src/HNSW.Net.Tests/RewindableRandomNumberGenerator.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License. 4 | // 5 | 6 | namespace HNSW.Net.Tests 7 | { 8 | using System; 9 | using System.Collections.Generic; 10 | using System.Linq; 11 | 12 | /* Test double for IProvideRandomValues: every call is recorded as a replayable action, so the generator can be rewound to an earlier call count by re-seeding with 42 and replaying that prefix. */ public class RewindableRandomNumberGenerator : IProvideRandomValues 13 | { 14 | public bool IsThreadSafe => false; 15 | 16 | private FastRandom _random = new FastRandom(42); 17 | private List _calls = new List(); 18 | 19 | public int Next(int minValue, int maxValue) 20 | { 21 | _calls.Add( () => _random.Next(minValue, maxValue) ); 22 | return _random.Next(minValue, maxValue); 23 | } 24 | 25 | public float NextFloat() 26 | { 27 | _calls.Add(() => _random.NextFloat()); 28 | return _random.NextFloat(); 29 | } 30 | 31 | public void NextFloats(Span buffer) 32 | { 33 | var len = buffer.Length; 34 | _calls.Add(() => 35 | { 36 | var b = new float[len]; 37 | _random.NextFloats(b); 38 | }); 39 | _random.NextFloats(buffer); 40 | } 41 | 42 | public int GetState() => _calls.Count; /* number of calls recorded so far; pass it to RewindTo to return to this point */ 43 | 44 | public void RewindTo(int state) 45 | { 46 | _random = new FastRandom(42); 47 | var toInvoke = _calls.Take(state).ToList(); 48 | _calls = toInvoke; /* keep the replayed prefix as the history: the recorded lambdas call _random directly and never re-record themselves, so clearing the list here made GetState() report 0 after a rewind */ 49 | foreach (var a in toInvoke) 50 | { 51 | a(); 52 | } 53 | } 54 | } 55 | } -------------------------------------------------------------------------------- /Src/HNSW.Net.Tests/SmallWorldTests.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation.
All rights reserved. 3 | // Licensed under the MIT License. 4 | // 5 | 6 | namespace HNSW.Net.Tests 7 | { 8 | using System; 9 | using System.Collections.Generic; 10 | using System.Globalization; 11 | using System.IO; 12 | using System.Linq; 13 | using Microsoft.VisualStudio.TestTools.UnitTesting; 14 | 15 | /// 16 | /// Tests for 17 | /// 18 | [TestClass] 19 | public class SmallWorldTests 20 | { 21 | // Set floating point error to 5.96 * 10^-7 22 | // For cosine distance error can be bigger in theory but for test data it's not the case. 23 | private const float FloatError = 0.000000596f; 24 | 25 | private IReadOnlyList vectors; 26 | 27 | /// 28 | /// Initializes test resources. 29 | /// 30 | [TestInitialize] 31 | public void TestInitialize() 32 | { 33 | var data = File.ReadAllLines(@"vectors.txt"); 34 | vectors = data.Select(r => Array.ConvertAll(r.Split('\t'), x => float.Parse(x, CultureInfo.InvariantCulture))).ToList(); /* the data file is machine-generated; parsing with the current culture breaks on locales whose decimal separator is not '.' */ 35 | } 36 | 37 | /// 38 | /// Basic test for knn search - this test might fail sometimes, as the construction of the graph does not guarantee an exact answer 39 | /// 40 | [TestMethod] 41 | public void KNNSearchTest() 42 | { 43 | var parameters = new SmallWorld.Parameters(); 44 | var graph = new SmallWorld(CosineDistance.NonOptimized, DefaultRandomGenerator.Instance, parameters); 45 | graph.AddItems(vectors); 46 | 47 | int bestWrong = 0; 48 | float maxError = float.MinValue; 49 | 50 | for (int i = 0; i < vectors.Count; ++i) 51 | { 52 | var result = graph.KNNSearch(vectors[i], 20); 53 | var best = result.OrderBy(r => r.Distance).First(); 54 | Assert.AreEqual(20, result.Count); 55 | if (best.Id != i) 56 | { 57 | bestWrong++; 58 | } 59 | maxError = Math.Max(maxError, best.Distance); 60 | } 61 | Assert.AreEqual(0, bestWrong); 62 | Assert.AreEqual(0, maxError, FloatError); 63 | } 64 | 65 | /// 66 | /// Verifies that a filter which rejects every item makes knn search return no results. 67 | /// 68 | [TestMethod]
69 | public void KNNSearchWithFilterTest() 70 | { 71 | var parameters = new SmallWorld.Parameters(); 72 | var graph = new SmallWorld(CosineDistance.NonOptimized, DefaultRandomGenerator.Instance, parameters); 73 | graph.AddItems(vectors); 74 | 75 | for (int i = 0; i < vectors.Count; ++i) 76 | { 77 | var result = graph.KNNSearch(vectors[i], 20, filterItem: v => false); 78 | Assert.AreEqual(0, result.Count); 79 | } 80 | } 81 | 82 | /// 83 | /// Basic test for knn search with the neighbour selection heuristic (Algorithm 4) - this test might fail sometimes, as the construction of the graph does not guarantee an exact answer 84 | /// 85 | [DataTestMethod] 86 | [DataRow(false,false)] 87 | [DataRow(false,true)] 88 | [DataRow(true, false)] 89 | [DataRow(true, true)] 90 | public void KNNSearchTestAlgorithm4(bool expandBestSelection, bool keepPrunedConnections ) 91 | { 92 | var parameters = new SmallWorld.Parameters() { NeighbourHeuristic = NeighbourSelectionHeuristic.SelectHeuristic, ExpandBestSelection = expandBestSelection, KeepPrunedConnections = keepPrunedConnections }; 93 | var graph = new SmallWorld(CosineDistance.NonOptimized, DefaultRandomGenerator.Instance, parameters); 94 | graph.AddItems(vectors); 95 | 96 | int bestWrong = 0; 97 | float maxError = float.MinValue; 98 | 99 | for (int i = 0; i < vectors.Count; ++i) 100 | { 101 | var result = graph.KNNSearch(vectors[i], 20); 102 | var best = result.OrderBy(r => r.Distance).First(); 103 | Assert.AreEqual(20, result.Count); 104 | if (best.Id != i) 105 | { 106 | bestWrong++; 107 | } 108 | maxError = Math.Max(maxError, best.Distance); 109 | } 110 | Assert.AreEqual(0, bestWrong); 111 | Assert.AreEqual(0, maxError, FloatError); 112 | } 113 | 114 | /// 115 | /// Serialization deserialization tests.
116 | /// 117 | [TestMethod] 118 | public void SerializeDeserializeTest() 119 | { 120 | 121 | string original; 122 | 123 | // restrict scope of original graph 124 | using var stream = new MemoryStream(); 125 | { 126 | var parameters = new SmallWorld.Parameters() 127 | { 128 | M = 15, 129 | LevelLambda = 1 / Math.Log(15), 130 | }; 131 | 132 | var graph = new SmallWorld(CosineDistance.NonOptimized, DefaultRandomGenerator.Instance, parameters); 133 | graph.AddItems(vectors); 134 | 135 | graph.SerializeGraph(stream); 136 | original = graph.Print(); 137 | } 138 | stream.Position = 0; 139 | 140 | var copy = SmallWorld.DeserializeGraph(vectors, CosineDistance.NonOptimized, DefaultRandomGenerator.Instance, stream); 141 | 142 | Assert.AreEqual(original, copy.Graph.Print()); 143 | } 144 | 145 | /// 146 | /// Serialization deserialization test where half of the items are added only after serialization and must be re-added to the deserialized copy with the same random sequence. 147 | /// 148 | [TestMethod] 149 | public void SerializeDeserializeWithRemainingItemsTest() 150 | { 151 | 152 | string original; 153 | int rng_state; 154 | var rng = new RewindableRandomNumberGenerator(); 155 | 156 | // restrict scope of original graph 157 | using var stream = new MemoryStream(); 158 | { 159 | var parameters = new SmallWorld.Parameters() 160 | { 161 | M = 15, 162 | LevelLambda = 1 / Math.Log(15), 163 | }; 164 | int itemsToSerialize = (int)(vectors.Count / 2); 165 | 166 | var graph = new SmallWorld(CosineDistance.NonOptimized, rng, parameters); 167 | graph.AddItems(vectors.Take(itemsToSerialize).ToArray()); 168 | 169 | graph.SerializeGraph(stream); 170 | rng_state = rng.GetState(); 171 | graph.AddItems(vectors.Skip(itemsToSerialize).ToArray()); 172 | rng.RewindTo(rng_state); 173 | 174 | original = graph.Print(); 175 | } 176 | stream.Position = 0; 177 | 178 | var copy = SmallWorld.DeserializeGraph(vectors, CosineDistance.NonOptimized, rng, stream); 179 | 180 | copy.Graph.AddItems(copy.ItemsNotInGraph); 181 | 182 | Assert.AreEqual(original, 
copy.Graph.Print()); 183 | } 184 | } 185 | } -------------------------------------------------------------------------------- /Src/HNSW.Net.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio Version 17 4 | VisualStudioVersion = 17.12.35209.166 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HNSW.Net", "HNSW.Net\HNSW.Net.csproj", "{1A7D27E7-4F89-4197-BA86-2A7E80E702DE}" 7 | EndProject 8 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HNSW.Net.Tests", "HNSW.Net.Tests\HNSW.Net.Tests.csproj", "{F1C038AB-3654-45F3-BE98-378D4EDC2762}" 9 | EndProject 10 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HNSW.Net.Demo", "HNSW.Net.Demo\HNSW.Net.Demo.csproj", "{6F62C4BE-5D1C-4302-9813-E707EB902221}" 11 | EndProject 12 | Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = ".azuredevops", ".azuredevops", "{07C754D2-86A3-46F1-B78D-9C81F8974996}" 13 | ProjectSection(SolutionItems) = preProject 14 | ..\.devops\build-nuget.yml = ..\.devops\build-nuget.yml 15 | EndProjectSection 16 | EndProject 17 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HNSW.Net.TestFiltering", "HNSW.Net.TestFiltering\HNSW.Net.TestFiltering.csproj", "{3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}" 18 | EndProject 19 | Global 20 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 21 | Debug|Any CPU = Debug|Any CPU 22 | Debug|x64 = Debug|x64 23 | Debug|x86 = Debug|x86 24 | Release|Any CPU = Release|Any CPU 25 | Release|x64 = Release|x64 26 | Release|x86 = Release|x86 27 | EndGlobalSection 28 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 29 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 30 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Debug|Any CPU.Build.0 = Debug|Any CPU 31 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Debug|x64.ActiveCfg = Debug|Any CPU 32 | 
{1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Debug|x64.Build.0 = Debug|Any CPU 33 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Debug|x86.ActiveCfg = Debug|Any CPU 34 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Debug|x86.Build.0 = Debug|Any CPU 35 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Release|Any CPU.ActiveCfg = Release|Any CPU 36 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Release|Any CPU.Build.0 = Release|Any CPU 37 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Release|x64.ActiveCfg = Release|Any CPU 38 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Release|x64.Build.0 = Release|Any CPU 39 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Release|x86.ActiveCfg = Release|Any CPU 40 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Release|x86.Build.0 = Release|Any CPU 41 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 42 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Debug|Any CPU.Build.0 = Debug|Any CPU 43 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Debug|x64.ActiveCfg = Debug|Any CPU 44 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Debug|x64.Build.0 = Debug|Any CPU 45 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Debug|x86.ActiveCfg = Debug|Any CPU 46 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Debug|x86.Build.0 = Debug|Any CPU 47 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Release|Any CPU.ActiveCfg = Release|Any CPU 48 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Release|Any CPU.Build.0 = Release|Any CPU 49 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Release|x64.ActiveCfg = Release|Any CPU 50 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Release|x64.Build.0 = Release|Any CPU 51 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Release|x86.ActiveCfg = Release|Any CPU 52 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Release|x86.Build.0 = Release|Any CPU 53 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 54 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Debug|Any CPU.Build.0 = Debug|Any CPU 55 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Debug|x64.ActiveCfg = Debug|Any CPU 56 | 
{6F62C4BE-5D1C-4302-9813-E707EB902221}.Debug|x64.Build.0 = Debug|Any CPU 57 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Debug|x86.ActiveCfg = Debug|Any CPU 58 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Debug|x86.Build.0 = Debug|Any CPU 59 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Release|Any CPU.ActiveCfg = Release|Any CPU 60 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Release|Any CPU.Build.0 = Release|Any CPU 61 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Release|x64.ActiveCfg = Release|Any CPU 62 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Release|x64.Build.0 = Release|Any CPU 63 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Release|x86.ActiveCfg = Release|Any CPU 64 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Release|x86.Build.0 = Release|Any CPU 65 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 66 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Debug|Any CPU.Build.0 = Debug|Any CPU 67 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Debug|x64.ActiveCfg = Debug|Any CPU 68 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Debug|x64.Build.0 = Debug|Any CPU 69 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Debug|x86.ActiveCfg = Debug|Any CPU 70 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Debug|x86.Build.0 = Debug|Any CPU 71 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Release|Any CPU.ActiveCfg = Release|Any CPU 72 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Release|Any CPU.Build.0 = Release|Any CPU 73 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Release|x64.ActiveCfg = Release|Any CPU 74 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Release|x64.Build.0 = Release|Any CPU 75 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Release|x86.ActiveCfg = Release|Any CPU 76 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Release|x86.Build.0 = Release|Any CPU 77 | EndGlobalSection 78 | GlobalSection(SolutionProperties) = preSolution 79 | HideSolutionNode = FALSE 80 | EndGlobalSection 81 | GlobalSection(ExtensibilityGlobals) = postSolution 82 | SolutionGuid = {0B720E3E-7DD4-4ED8-BD60-49E166055CBD} 83 | EndGlobalSection 84 | 
EndGlobal 85 | -------------------------------------------------------------------------------- /Src/HNSW.Net/Algorithms.Algorithm3.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License. 4 | // 5 | 6 | namespace HNSW.Net 7 | { 8 | using System; 9 | using System.Collections.Generic; 10 | 11 | internal partial class Algorithms 12 | { 13 | /// 14 | /// The implementation of the SELECT-NEIGHBORS-SIMPLE(q, C, M) algorithm. 15 | /// Article: Section 4. Algorithm 3. 16 | /// 17 | /// The typeof the items in the small world. 18 | /// The type of the distance in the small world. 19 | internal class Algorithm3 : Algorithm where TDistance : struct, IComparable 20 | { 21 | public Algorithm3(Graph.Core graphCore) : base(graphCore) 22 | { 23 | } 24 | 25 | /// 26 | internal override List SelectBestForConnecting(List candidatesIds, TravelingCosts travelingCosts, int layer) 27 | { 28 | /* 29 | * q ← this 30 | * return M nearest elements from C to q 31 | */ 32 | 33 | // !NO COPY! in-place selection 34 | var bestN = GetM(layer); 35 | var candidatesHeap = new BinaryHeap(candidatesIds, travelingCosts); 36 | while (candidatesHeap.Buffer.Count > bestN) 37 | { 38 | candidatesHeap.Pop(); 39 | } 40 | 41 | return candidatesHeap.Buffer; 42 | } 43 | } 44 | } 45 | } -------------------------------------------------------------------------------- /Src/HNSW.Net/Algorithms.Algorithm4.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License. 
4 | // 5 | 6 | namespace HNSW.Net 7 | { 8 | using System; 9 | using System.Collections.Generic; 10 | using System.Linq; 11 | 12 | internal partial class Algorithms 13 | { 14 | /// 15 | /// The implementation of the SELECT-NEIGHBORS-HEURISTIC(q, C, M, lc, extendCandidates, keepPrunedConnections) algorithm. 16 | /// Article: Section 4. Algorithm 4. 17 | /// 18 | /// The typeof the items in the small world. 19 | /// The type of the distance in the small world. 20 | internal sealed class Algorithm4 : Algorithm where TDistance : struct, IComparable 21 | { 22 | public Algorithm4(Graph.Core graphCore) : base(graphCore) 23 | { 24 | } 25 | 26 | /// 27 | internal override List SelectBestForConnecting(List candidatesIds, TravelingCosts travelingCosts, int layer) 28 | { 29 | /* 30 | * q ← this 31 | * R ← ∅ // result 32 | * W ← C // working queue for the candidates 33 | * if expandCandidates // expand candidates 34 | * for each e ∈ C 35 | * for each eadj ∈ neighbourhood(e) at layer lc 36 | * if eadj ∉ W 37 | * W ← W ⋃ eadj 38 | * 39 | * Wd ← ∅ // queue for the discarded candidates 40 | * while │W│ gt 0 and │R│ lt M 41 | * e ← extract nearest element from W to q 42 | * if e is closer to q compared to any element from R 43 | * R ← R ⋃ e 44 | * else 45 | * Wd ← Wd ⋃ e 46 | * 47 | * if keepPrunedConnections // add some of the discarded connections from Wd 48 | * while │Wd│ gt 0 and │R│ lt M 49 | * R ← R ⋃ extract nearest element from Wd to q 50 | * 51 | * return R 52 | */ 53 | 54 | IComparer fartherIsOnTop = travelingCosts; 55 | IComparer closerIsOnTop = fartherIsOnTop.Reverse(); 56 | 57 | var layerM = GetM(layer); 58 | 59 | var resultHeap = new BinaryHeap(new List(layerM + 1), fartherIsOnTop); 60 | var candidatesHeap = new BinaryHeap(candidatesIds, closerIsOnTop); 61 | 62 | // expand candidates option is enabled 63 | if (GraphCore.Parameters.ExpandBestSelection) 64 | { 65 | var visited = new HashSet(candidatesHeap.Buffer); 66 | var toAdd = new HashSet(); 67 | foreach (var 
candidateId in candidatesHeap.Buffer) 68 | { 69 | var candidateNeighborsIDs = GraphCore.Nodes[candidateId][layer]; 70 | foreach (var candidateNeighbourId in candidateNeighborsIDs) 71 | { 72 | if (!visited.Contains(candidateNeighbourId)) 73 | { 74 | toAdd.Add(candidateNeighbourId); 75 | visited.Add(candidateNeighbourId); 76 | } 77 | } 78 | } 79 | foreach(var id in toAdd) 80 | { 81 | candidatesHeap.Push(id); 82 | } 83 | } 84 | 85 | // main stage of moving candidates to result 86 | var discardedHeap = new BinaryHeap(new List(candidatesHeap.Buffer.Count), closerIsOnTop); 87 | while (candidatesHeap.Buffer.Any() && resultHeap.Buffer.Count < layerM) 88 | { 89 | var candidateId = candidatesHeap.Pop(); 90 | var farestResultId = resultHeap.Buffer.FirstOrDefault(); 91 | 92 | if (!resultHeap.Buffer.Any() || DistanceUtils.LowerThan(travelingCosts.From(candidateId), travelingCosts.From(farestResultId))) 93 | { 94 | resultHeap.Push(candidateId); 95 | } 96 | else if (GraphCore.Parameters.KeepPrunedConnections) 97 | { 98 | discardedHeap.Push(candidateId); 99 | } 100 | } 101 | 102 | // keep pruned option is enabled 103 | if (GraphCore.Parameters.KeepPrunedConnections) 104 | { 105 | while (discardedHeap.Buffer.Any() && resultHeap.Buffer.Count < layerM) 106 | { 107 | resultHeap.Push(discardedHeap.Pop()); 108 | } 109 | } 110 | 111 | return resultHeap.Buffer; 112 | } 113 | } 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /Src/HNSW.Net/Algorithms.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License. 4 | // 5 | 6 | namespace HNSW.Net 7 | { 8 | using System; 9 | using System.Collections.Generic; 10 | 11 | internal partial class Algorithms 12 | { 13 | /// 14 | /// The abstract class representing algorithm to control node capacity. 15 | /// 16 | /// The typeof the items in the small world. 
17 | /// The type of the distance in the small world. 18 | internal abstract class Algorithm where TDistance : struct, IComparable 19 | { 20 | protected readonly Graph.Core GraphCore; 21 | 22 | protected readonly Func NodeDistance; 23 | 24 | public Algorithm(Graph.Core graphCore) 25 | { 26 | GraphCore = graphCore; 27 | NodeDistance = graphCore.GetDistance; 28 | } 29 | 30 | /// 31 | /// Creates a new instance of the struct. Controls the exact type of connection lists. 32 | /// 33 | /// The identifier of the node. 34 | /// The max layer where the node is presented. 35 | /// The new instance. 36 | internal virtual Node NewNode(int nodeId, int maxLayer) 37 | { 38 | var connections = new List>(maxLayer + 1); 39 | for (int layer = 0; layer <= maxLayer; ++layer) 40 | { 41 | // M + 1 neighbours to not realloc in AddConnection when the level is full 42 | int layerM = GetM(layer) + 1; 43 | connections.Add(new List(layerM)); 44 | } 45 | 46 | return new Node 47 | { 48 | Id = nodeId, 49 | Connections = connections 50 | }; 51 | } 52 | 53 | /// 54 | /// The algorithm which selects best neighbours from the candidates for the given node. 55 | /// 56 | /// The identifiers of candidates to neighbourhood. 57 | /// Traveling costs to compare candidates. 58 | /// The layer of the neighbourhood. 59 | /// Best nodes selected from the candidates. 60 | internal abstract List SelectBestForConnecting(List candidatesIds, TravelingCosts travelingCosts, int layer); 61 | 62 | /// 63 | /// Get maximum allowed connections for the given level. 64 | /// 65 | /// 66 | /// Article: Section 4.1: 67 | /// "Selection of the Mmax0 (the maximum number of connections that an element can have in the zero layer) also 68 | /// has a strong influence on the search performance, especially in case of high quality(high recall) search. 
69 | /// Simulations show that setting Mmax0 to M(this corresponds to kNN graphs on each layer if the neighbors 70 | /// selection heuristic is not used) leads to a very strong performance penalty at high recall. 71 | /// Simulations also suggest that 2∙M is a good choice for Mmax0; 72 | /// setting the parameter higher leads to performance degradation and excessive memory usage." 73 | /// 74 | /// The level of the layer. 75 | /// The maximum number of connections. 76 | internal int GetM(int layer) 77 | { 78 | return layer == 0 ? 2 * GraphCore.Parameters.M : GraphCore.Parameters.M; 79 | } 80 | 81 | /// 82 | /// Tries to connect the node with the new neighbour. 83 | /// 84 | /// The node to add neighbour to. 85 | /// The new neighbour. 86 | /// The layer to add neighbour to. 87 | internal void Connect(Node node, Node neighbour, int layer) 88 | { 89 | var nodeLayer = node[layer]; 90 | nodeLayer.Add(neighbour.Id); 91 | if (nodeLayer.Count > GetM(layer)) 92 | { 93 | var travelingCosts = new TravelingCosts(NodeDistance, node.Id); 94 | node[layer] = SelectBestForConnecting(nodeLayer, travelingCosts, layer); 95 | } 96 | } 97 | } 98 | 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /Src/HNSW.Net/BinaryHeap.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License. 4 | // 5 | 6 | namespace HNSW.Net 7 | { 8 | using System; 9 | using System.Collections.Generic; 10 | using System.Diagnostics.CodeAnalysis; 11 | 12 | /// 13 | /// Binary heap wrapper around the It's a max-heap implementation i.e. the maximum element is always on top. But the order of elements can be customized by providing instance. 14 | /// 15 | /// The type of the items in the source list. 
16 | [SuppressMessage("Performance", "CA1815:Override equals and operator equals on value types", Justification = "By design")] 17 | internal struct BinaryHeap 18 | { 19 | internal IComparer Comparer; 20 | internal List Buffer; 21 | internal bool Any => Buffer.Count > 0; 22 | internal BinaryHeap(List buffer) : this(buffer, Comparer.Default) { } 23 | internal BinaryHeap(List buffer, IComparer comparer) 24 | { 25 | Buffer = buffer ?? throw new ArgumentNullException(nameof(buffer)); 26 | Comparer = comparer; 27 | for (int i = 1; i < Buffer.Count; ++i) { SiftUp(i); } 28 | } 29 | 30 | internal void Push(T item) 31 | { 32 | Buffer.Add(item); 33 | SiftUp(Buffer.Count - 1); 34 | } 35 | 36 | internal T Pop() 37 | { 38 | if (Buffer.Count > 0) 39 | { 40 | var result = Buffer[0]; 41 | 42 | Buffer[0] = Buffer[Buffer.Count - 1]; 43 | Buffer.RemoveAt(Buffer.Count - 1); 44 | SiftDown(0); 45 | 46 | return result; 47 | } 48 | 49 | throw new InvalidOperationException("Heap is empty"); 50 | } 51 | 52 | /// 53 | /// Restores the heap property starting from i'th position down to the bottom given that the downstream items fulfill the rule. 54 | /// 55 | /// The position of item where heap property is violated. 56 | private void SiftDown(int i) 57 | { 58 | while (i < Buffer.Count) 59 | { 60 | int l = (i << 1) + 1; 61 | int r = l + 1; 62 | if (l >= Buffer.Count) 63 | { 64 | break; 65 | } 66 | 67 | int m = r < Buffer.Count && Comparer.Compare(Buffer[l], Buffer[r]) < 0 ? r : l; 68 | if (Comparer.Compare(Buffer[m], Buffer[i]) <= 0) 69 | { 70 | break; 71 | } 72 | 73 | Swap(i, m); 74 | i = m; 75 | } 76 | } 77 | 78 | /// 79 | /// Restores the heap property starting from i'th position up to the head given that the upstream items fulfill the rule. 80 | /// 81 | /// The position of item where heap property is violated. 
82 | private void SiftUp(int i) 83 | { 84 | while (i > 0) 85 | { 86 | int p = (i - 1) >> 1; 87 | if (Comparer.Compare(Buffer[i], Buffer[p]) <= 0) 88 | { 89 | break; 90 | } 91 | 92 | Swap(i, p); 93 | i = p; 94 | } 95 | } 96 | 97 | private void Swap(int i, int j) 98 | { 99 | var temp = Buffer[i]; 100 | Buffer[i] = Buffer[j]; 101 | Buffer[j] = temp; 102 | } 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /Src/HNSW.Net/CosineDistance.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License. 4 | // 5 | 6 | namespace HNSW.Net 7 | { 8 | using System; 9 | using System.Collections.Generic; 10 | using System.Numerics; 11 | using System.Runtime.CompilerServices; 12 | 13 | /// 14 | /// Calculates cosine similarity. 15 | /// 16 | /// 17 | /// Intuition behind selecting float as a carrier. 18 | /// 19 | /// 1. In practice we work with vectors of dimensionality 100 and each component has value in range [-1; 1] 20 | /// There certainly is a possibility of underflow. 21 | /// But we assume that such cases are rare and we can rely on such underflow losses. 22 | /// 23 | /// 2. According to the article http://www.ti3.tuhh.de/paper/rump/JeaRu13.pdf 24 | /// the floating point rounding error is less then 100 * 2^-24 * sqrt(100) * sqrt(100) < 0.0005960 25 | /// We deem such precision is satisfactory for out needs. 26 | /// 27 | public static class CosineDistance 28 | { 29 | /// 30 | /// Calculates cosine distance without making any optimizations. 31 | /// 32 | /// Left vector. 33 | /// Right vector. 34 | /// Cosine distance between u and v. 
35 | public static float NonOptimized(float[] u, float[] v) 36 | { 37 | if (u.Length != v.Length) 38 | { 39 | throw new ArgumentException("Vectors have non-matching dimensions"); 40 | } 41 | 42 | float dot = 0.0f; 43 | float nru = 0.0f; 44 | float nrv = 0.0f; 45 | for (int i = 0; i < u.Length; ++i) 46 | { 47 | dot += u[i] * v[i]; 48 | nru += u[i] * u[i]; 49 | nrv += v[i] * v[i]; 50 | } 51 | 52 | var similarity = dot / (float)(Math.Sqrt(nru) * Math.Sqrt(nrv)); 53 | return 1 - similarity; 54 | } 55 | 56 | /// 57 | /// Calculates cosine distance with assumption that u and v are unit vectors. 58 | /// 59 | /// Left vector. 60 | /// Right vector. 61 | /// Cosine distance between u and v. 62 | public static float ForUnits(float[] u, float[] v) 63 | { 64 | if (u.Length!= v.Length) 65 | { 66 | throw new ArgumentException("Vectors have non-matching dimensions"); 67 | } 68 | 69 | float dot = 0; 70 | for (int i = 0; i < u.Length; ++i) 71 | { 72 | dot += u[i] * v[i]; 73 | } 74 | 75 | return 1 - dot; 76 | } 77 | 78 | /// 79 | /// Calculates cosine distance optimized using SIMD instructions. 80 | /// 81 | /// Left vector. 82 | /// Right vector. 83 | /// Cosine distance between u and v. 
84 | public static float SIMD(float[] u, float[] v) 85 | { 86 | if (!Vector.IsHardwareAccelerated) 87 | { 88 | throw new NotSupportedException($"SIMD version of {nameof(CosineDistance)} is not supported"); 89 | } 90 | 91 | if (u.Length != v.Length) 92 | { 93 | throw new ArgumentException("Vectors have non-matching dimensions"); 94 | } 95 | 96 | float dot = 0; 97 | var norm = default(Vector2); 98 | int step = Vector.Count; 99 | 100 | int i, to = u.Length - step; 101 | for (i = 0; i <= to; i += step) 102 | { 103 | var ui = new Vector(u, i); 104 | var vi = new Vector(v, i); 105 | dot += Vector.Dot(ui, vi); 106 | norm.X += Vector.Dot(ui, ui); 107 | norm.Y += Vector.Dot(vi, vi); 108 | } 109 | 110 | for (; i < u.Length; ++i) 111 | { 112 | dot += u[i] * v[i]; 113 | norm.X += u[i] * u[i]; 114 | norm.Y += v[i] * v[i]; 115 | } 116 | 117 | norm = Vector2.SquareRoot(norm); 118 | float n = (norm.X * norm.Y); 119 | 120 | if (n == 0) 121 | { 122 | return 1f; 123 | } 124 | 125 | var similarity = dot / n; 126 | return 1f - similarity; 127 | } 128 | 129 | /// 130 | /// Calculates cosine distance with assumption that u and v are unit vectors using SIMD instructions. 131 | /// 132 | /// Left vector. 133 | /// Right vector. 134 | /// Cosine distance between u and v. 
135 | public static float SIMDForUnits(float[] u, float[] v) 136 | { 137 | return 1f - DotProduct(ref u, ref v); 138 | } 139 | 140 | private static readonly int _vs1 = Vector.Count; 141 | private static readonly int _vs2 = 2 * Vector.Count; 142 | private static readonly int _vs3 = 3 * Vector.Count; 143 | private static readonly int _vs4 = 4 * Vector.Count; 144 | 145 | [MethodImpl(MethodImplOptions.AggressiveInlining)] 146 | private static float DotProduct(ref float[] lhs, ref float[] rhs) 147 | { 148 | float result = 0f; 149 | 150 | var count = lhs.Length; 151 | var offset = 0; 152 | 153 | while (count >= _vs4) 154 | { 155 | result += Vector.Dot(new Vector(lhs, offset), new Vector(rhs, offset)); 156 | result += Vector.Dot(new Vector(lhs, offset + _vs1), new Vector(rhs, offset + _vs1)); 157 | result += Vector.Dot(new Vector(lhs, offset + _vs2), new Vector(rhs, offset + _vs2)); 158 | result += Vector.Dot(new Vector(lhs, offset + _vs3), new Vector(rhs, offset + _vs3)); 159 | if (count == _vs4) return result; 160 | count -= _vs4; 161 | offset += _vs4; 162 | } 163 | 164 | if (count >= _vs2) 165 | { 166 | result += Vector.Dot(new Vector(lhs, offset), new Vector(rhs, offset)); 167 | result += Vector.Dot(new Vector(lhs, offset + _vs1), new Vector(rhs, offset + _vs1)); 168 | if (count == _vs2) return result; 169 | count -= _vs2; 170 | offset += _vs2; 171 | } 172 | if (count >= _vs1) 173 | { 174 | result += Vector.Dot(new Vector(lhs, offset), new Vector(rhs, offset)); 175 | if (count == _vs1) return result; 176 | count -= _vs1; 177 | offset += _vs1; 178 | } 179 | if (count > 0) 180 | { 181 | while (count > 0) 182 | { 183 | result += lhs[offset] * rhs[offset]; 184 | offset++; count--; 185 | } 186 | } 187 | return result; 188 | } 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /Src/HNSW.Net/DefaultRandomGenerator.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;

namespace HNSW.Net
{
    /// <summary>
    /// <see cref="IProvideRandomValues"/> implementation that forwards every draw to the shared
    /// <c>ThreadSafeFastRandom</c> helper. The two singletons below use that same source of randomness
    /// and differ only in the value they report from <see cref="IsThreadSafe"/>.
    /// </summary>
    public sealed class DefaultRandomGenerator : IProvideRandomValues
    {
        /// <summary>
        /// Default generator. Reports <see cref="IsThreadSafe"/> as <c>true</c>, signalling that work driven by
        /// it may run on multiple threads.
        /// </summary>
        public static DefaultRandomGenerator Instance { get; } = new DefaultRandomGenerator(true);

        /// <summary>
        /// Backed by the same generator as <see cref="Instance"/>, but reports <see cref="IsThreadSafe"/> as
        /// <c>false</c> to request single-threaded processing. This is useful when many requests are served
        /// concurrently or when one request should not occupy every CPU.
        /// </summary>
        public static DefaultRandomGenerator DisableThreading { get; } = new DefaultRandomGenerator(false);

        private DefaultRandomGenerator(bool threadSafe)
        {
            IsThreadSafe = threadSafe;
        }

        /// <summary>
        /// Fixed at construction: <c>true</c> for <see cref="Instance"/>, <c>false</c> for <see cref="DisableThreading"/>.
        /// </summary>
        public bool IsThreadSafe { get; }

        /// <summary>Forwards to <c>ThreadSafeFastRandom.Next(minValue, maxValue)</c>.</summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public int Next(int minValue, int maxValue)
        {
            return ThreadSafeFastRandom.Next(minValue, maxValue);
        }

        /// <summary>Forwards to <c>ThreadSafeFastRandom.NextFloat()</c>.</summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public float NextFloat()
        {
            return ThreadSafeFastRandom.NextFloat();
        }

        /// <summary>Fills <paramref name="buffer"/> via <c>ThreadSafeFastRandom.NextFloats</c>.</summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public void NextFloats(Span<float> buffer)
        {
            ThreadSafeFastRandom.NextFloats(buffer);
        }
    }
}
-------------------------------------------------------------------------------- /Src/HNSW.Net/DistanceCache.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License.
//

using System;
using System.Runtime.CompilerServices;

namespace HNSW.Net
{
    /// <summary>
    /// Process-wide limits applied to every <see cref="DistanceCache{TDistance}"/>.
    /// </summary>
    public static class DistanceCacheLimits
    {
        // 2^28 slots. The NET7+ code path already capped the limit at this value; the fallback path did not,
        // and overflowed int for requests above 2^30.
        private const uint MaxAllowedArrayLength = 0x10000000;

        /// <summary>
        /// Upper bound on the number of slots a distance cache may allocate
        /// (see https://referencesource.microsoft.com/#mscorlib/system/array.cs,2d2b551eabe74985,references).
        /// Assigned values are rounded up to the next power of two, so a slot can be found with a bit mask
        /// instead of a modulo, and are clamped to at most 2^28 = 268435456.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">The assigned value is zero or negative.</exception>
        public static int MaxArrayLength
        {
            get { return _maxArrayLength; }
            set
            {
                // Zero or negative limits used to turn into a 0 (or garbage) length that broke every later lookup.
                if (value <= 0)
                {
                    throw new ArgumentOutOfRangeException(nameof(value), value, "MaxArrayLength must be positive.");
                }

                _maxArrayLength = NextPowerOf2((uint)value);
            }
        }

        /// <summary>
        /// Smallest power of two that is greater than or equal to <paramref name="x"/>, clamped to [1, 2^28].
        /// Pure integer arithmetic, so the result is exact on every target framework. The former
        /// Math.Log/Math.Pow fallback could misround exact powers of two.
        /// </summary>
        private static int NextPowerOf2(uint x)
        {
            if (x <= 1)
            {
                return 1;
            }

            if (x >= MaxAllowedArrayLength)
            {
                return (int)MaxAllowedArrayLength;
            }

            // Smear the highest set bit of (x - 1) into every lower bit, then add one.
            uint v = x - 1;
            v |= v >> 1;
            v |= v >> 2;
            v |= v >> 4;
            v |= v >> 8;
            v |= v >> 16;
            return (int)(v + 1);
        }

        private static int _maxArrayLength = 268_435_456; // 0x10000000;
    }

    /// <summary>
    /// Fixed-size, direct-mapped cache of pairwise distances keyed by an unordered pair of point ids.
    /// Each slot holds one (key, value) pair; a colliding insert overwrites the slot.
    /// </summary>
    /// <typeparam name="TDistance">The distance type.</typeparam>
    internal class DistanceCache<TDistance> where TDistance : struct
    {
        private TDistance[] _values;

        // -1 marks an empty slot; MakeKey never produces negative keys for non-negative ids.
        private long[] _keys;

        internal int HitCount;

        internal DistanceCache()
        {
        }

        /// <summary>
        /// Ensures the table is large enough for the pairs of <paramref name="pointsCount"/> points.
        /// </summary>
        /// <param name="pointsCount">Number of points to size for; values &lt;= 0 fall back to 1024.</param>
        /// <param name="overwrite">When true, discard every cached entry and reallocate.</param>
        internal void Resize(int pointsCount, bool overwrite)
        {
            if (pointsCount <= 0) { pointsCount = 1024; }

            // Number of unordered pairs (i, j), j <= i < pointsCount, which is the key range of MakeKey.
            long pairCount = ((long)pointsCount * (pointsCount + 1)) >> 1;
            long capacity = SlotCountFor(pairCount);

            if (_keys is null || capacity > _keys.Length || overwrite)
            {
                int i0 = 0;
                if (_keys is null || overwrite)
                {
                    _keys = new long[(int)capacity];
                    _values = new TDistance[(int)capacity];
                }
                else
                {
                    i0 = _keys.Length;
                    Array.Resize(ref _keys, (int)capacity);
                    Array.Resize(ref _values, (int)capacity);
                }

                // TODO: may be there is a better way to warm up cache and force OS to allocate pages
                _keys.AsSpan().Slice(i0).Fill(-1);
                _values.AsSpan().Slice(i0).Fill(default);
            }
        }

        /// <summary>
        /// Smallest power of two that is &gt;= <paramref name="requested"/>, capped at
        /// <see cref="DistanceCacheLimits.MaxArrayLength"/> (itself a power of two).
        /// Lookups compute the slot as <c>key &amp; (length - 1)</c>, which spreads keys over the whole table
        /// only when the length is a power of two. For example, the former 524800-slot table for 1024 points
        /// was masked with 0x801FF and used just 1024 of its slots. Rounding up also keeps every key below
        /// <paramref name="requested"/> in its own slot.
        /// </summary>
        private static long SlotCountFor(long requested)
        {
            long limit = DistanceCacheLimits.MaxArrayLength;
            long slots = 1;
            while (slots < requested && slots < limit)
            {
                slots <<= 1;
            }

            return slots;
        }

        /// <summary>
        /// Returns the cached distance for the pair (<paramref name="fromId"/>, <paramref name="toId"/>).
        /// On a miss the distance is computed with <paramref name="getter"/> and stored, evicting whatever
        /// pair occupied the slot.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        internal TDistance GetOrCacheValue(int fromId, int toId, Func<int, int, TDistance> getter)
        {
            long key = MakeKey(fromId, toId);
            int hash = (int)(key & (_keys.Length - 1));

            if (_keys[hash] == key)
            {
                HitCount++;
                return _values[hash];
            }
            else
            {
                var d = getter(fromId, toId);
                _keys[hash] = key;
                _values[hash] = d;
                return d;
            }
        }

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private void SetValue(int fromId, int toId, TDistance distance)
        {
            long key = MakeKey(fromId, toId);
            int hash = (int)(key & (_keys.Length - 1));
            _keys[hash] = key;
            _values[hash] = distance;
        }

        /// <summary>
        /// Builds the key for a pair of points as a triangular index:
        /// MakeKey(fromId, toId) == MakeKey(toId, fromId).
        /// </summary>
        /// <param name="fromId">The from point identifier.</param>
        /// <param name="toId">The to point identifier.</param>
        /// <returns>Key of the pair.</returns>
        private static long MakeKey(int fromId, int toId)
        {
            return fromId > toId ? (((long)fromId * (fromId + 1)) >> 1) + toId : (((long)toId * (toId + 1)) >> 1) + fromId;
        }
    }
}
-------------------------------------------------------------------------------- /Src/HNSW.Net/DistanceUtils.cs: --------------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//

namespace HNSW.Net
{
    using System;

    /// <summary>
    /// Comparison helpers for generic distance values.
    /// </summary>
    public static class DistanceUtils
    {
        /// <summary>Returns true when <paramref name="x"/> compares strictly below <paramref name="y"/>.</summary>
        public static bool LowerThan<TDistance>(TDistance x, TDistance y) where TDistance : IComparable<TDistance>
        {
            return x.CompareTo(y) < 0;
        }

        /// <summary>Returns true when <paramref name="x"/> compares strictly above <paramref name="y"/>.</summary>
        public static bool GreaterThan<TDistance>(TDistance x, TDistance y) where TDistance : IComparable<TDistance>
        {
            return x.CompareTo(y) > 0;
        }

        /// <summary>Returns true when <paramref name="x"/> and <paramref name="y"/> compare as equal.</summary>
        public static bool IsEqual<TDistance>(TDistance x, TDistance y) where TDistance : IComparable<TDistance>
        {
            return x.CompareTo(y) == 0;
        }
    }
}
-------------------------------------------------------------------------------- /Src/HNSW.Net/EventSources.cs: --------------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//

namespace HNSW.Net
{
    using System;
    using System.Diagnostics.CodeAnalysis;
    using System.Diagnostics.Tracing;
    using System.Runtime.InteropServices;

    /// <summary>
    /// Event sources that publish graph construction and search metrics.
    /// </summary>
    public static class EventSources
    {
        /// <summary>
        /// Writes specific metric if the source is enabled.
        /// </summary>
        /// <param name="source">The event source to check.</param>
        /// <param name="counter">The counter to write metric.</param>
        /// <param name="value">The value to write.</param>
        internal static void WriteMetricIfEnabled(EventSource source, EventCounter counter, float value)
        {
            if (source.IsEnabled())
            {
                counter.WriteMetric(value);
            }
        }

        /// <summary>
        /// Source of events occurring at graph construction phase.
        /// </summary>
        [SuppressMessage("Design", "CA1034:Nested types should not be visible", Justification = "By Design")]
        [EventSource(Name = "HNSW.Net.Graph.Build")]
        [ComVisible(false)]
        public class GraphBuildEventSource : EventSource
        {
            /// <summary>
            /// The singleton instance of the source.
            /// </summary>
            public static readonly GraphBuildEventSource Instance = new GraphBuildEventSource();

            /// <summary>
            /// Initializes a new instance of the <see cref="GraphBuildEventSource"/> class.
            /// </summary>
            private GraphBuildEventSource() : base(EventSourceSettings.EtwSelfDescribingEventFormat)
            {
                var coreGetDistanceCacheHitRate = new EventCounter("GetDistance.CacheHitRate", this);
                CoreGetDistanceCacheHitRateReporter = (float value) => WriteMetricIfEnabled(this, coreGetDistanceCacheHitRate, value);

                var graphInsertNodeLatency = new EventCounter("InsertNode.Latency", this);
                GraphInsertNodeLatencyReporter = (float value) => WriteMetricIfEnabled(this, graphInsertNodeLatency, value);
            }

            /// <summary>
            /// Gets the delegate to report the hit rate of the distance cache.
            /// </summary>
            internal Action<float> CoreGetDistanceCacheHitRateReporter { get; }

            /// <summary>
            /// Gets the delegate to report the node insertion latency.
            /// </summary>
            internal Action<float> GraphInsertNodeLatencyReporter { get; }
        }

        /// <summary>
        /// Source of events occurring at graph search phase.
        /// </summary>
        [SuppressMessage("Design", "CA1034:Nested types should not be visible", Justification = "By Design")]
        [EventSource(Name = "HNSW.Net.Graph.Search")]
        [ComVisible(false)]
        public class GraphSearchEventSource : EventSource
        {
            /// <summary>
            /// The singleton instance of the source.
            /// </summary>
            public static readonly GraphSearchEventSource Instance = new GraphSearchEventSource();

            /// <summary>
            /// Initializes a new instance of the <see cref="GraphSearchEventSource"/> class.
            /// </summary>
            private GraphSearchEventSource() : base(EventSourceSettings.EtwSelfDescribingEventFormat)
            {
                var graphKNearestLatency = new EventCounter("KNearest.Latency", this);
                GraphKNearestLatencyReporter = (float value) => WriteMetricIfEnabled(this, graphKNearestLatency, value);

                var graphKNearestVisitedNodes = new EventCounter("KNearest.VisitedNodes", this);
                GraphKNearestVisitedNodesReporter = (float value) => WriteMetricIfEnabled(this, graphKNearestVisitedNodes, value);
            }

            /// <summary>
            /// Gets the delegate to report latency.
            /// </summary>
            internal Action<float> GraphKNearestLatencyReporter { get; }

            /// <summary>
            /// Gets the delegate to report the number of nodes visited by a search.
            /// </summary>
            internal Action<float> GraphKNearestVisitedNodesReporter { get; }
        }
    }
}
-------------------------------------------------------------------------------- /Src/HNSW.Net/FastRandom.cs: --------------------------------------------------------------------------------
using System;

namespace HNSW.Net
{
    /// <summary>
    /// A fast random number generator for .NET, from https://www.codeproject.com/Articles/9187/A-fast-equivalent-for-System-Random
    /// Colin Green, January 2005
    ///
    /// September 4th 2005
    ///  Added NextBytesUnsafe() - commented out by default.
    ///  Fixed bug in Reinitialise() - y,z and w variables were not being reset.
    ///
    /// Key points:
    ///  1) Based on a simple and fast xor-shift pseudo random number generator (RNG) specified in:
    ///  Marsaglia, George. (2003). Xorshift RNGs.
    ///  http://www.jstatsoft.org/v08/i14/xorshift.pdf
    ///
    ///  This particular implementation of xorshift has a period of 2^128-1. See the above paper to see
    ///  how this can be easily extended if you need a longer period. At the time of writing I could find no
    ///  information on the period of System.Random for comparison.
    ///
    ///  2) Faster than System.Random. Up to 8x faster, depending on which methods are called.
    ///
    ///  3) Direct replacement for System.Random. This class implements all of the methods that System.Random
    ///  does plus some additional methods. The like named methods are functionally equivalent.
    ///
    ///  4) Allows fast re-initialisation with a seed, unlike System.Random which accepts a seed at construction
    ///  time which then executes a relatively expensive initialisation routine. This provides a vast speed improvement
    ///  if you need to reset the pseudo-random number sequence many times, e.g. if you want to re-generate the same
    ///  sequence many times. An alternative might be to cache random numbers in an array, but that approach is limited
    ///  by memory capacity and the fact that you may also want a large number of different sequences cached. Each sequence
    ///  can each be represented by a single seed value (int) when using FastRandom.
    ///
    /// Notes.
    /// A further performance improvement can be obtained by declaring local variables as static, thus avoiding
    /// re-allocation of variables on each call. However care should be taken if multiple instances of
    /// FastRandom are in use or if being used in a multi-threaded environment.
    /// </summary>
    internal class FastRandom
    {
        // 2^-31. Unlike REAL_UNIT_INT, this scale alone does NOT keep float results below 1.0,
        // because (float)int.MaxValue + 1.0f == (float)int.MaxValue. See ToUnitFloat.
        const float FLOAT_UNIT_INT = 1.0f / ((float)int.MaxValue + 1.0f);

        // Largest float strictly below 1.0 (1 - 2^-24), substituted when rounding would produce exactly 1.0.
        const float FLOAT_LARGEST_BELOW_ONE = 1.0f - (1.0f / 16777216.0f);

        // The +1 ensures NextDouble doesn't generate 1.0
        const double REAL_UNIT_INT = 1.0 / ((double)int.MaxValue + 1.0);
        const double REAL_UNIT_UINT = 1.0 / ((double)uint.MaxValue + 1.0);
        const uint Y = 842502087, Z = 3579807591, W = 273326509;

        uint x, y, z, w;

        /// <summary>
        /// Initialises a new instance using time dependent seed.
        /// </summary>
        public FastRandom()
        {
            // Initialise using the system tick count.
            Reinitialise(Environment.TickCount);
        }

        /// <summary>
        /// Initialises a new instance using an int value as seed.
        /// This constructor signature is provided to maintain compatibility with
        /// System.Random
        /// </summary>
        public FastRandom(int seed)
        {
            Reinitialise(seed);
        }

        /// <summary>
        /// Reinitialises using an int value as a seed.
        /// </summary>
        public void Reinitialise(int seed)
        {
            // The only stipulation stated for the xorshift RNG is that at least one of
            // the seeds x,y,z,w is non-zero. We fulfill that requirement by only allowing
            // resetting of the x seed
            x = (uint)seed;
            y = Y;
            z = Z;
            w = W;
        }

        /// <summary>
        /// Generates a random int over the range 0 to int.MaxValue-1.
        /// MaxValue is not generated in order to remain functionally equivalent to System.Random.Next().
        /// This does slightly eat into some of the performance gain over System.Random, but not much.
        /// For better performance see:
        ///
        /// Call NextInt() for an int over the range 0 to int.MaxValue.
        ///
        /// Call NextUInt() and cast the result to an int to generate an int over the full Int32 value range
        /// including negative values.
        /// </summary>
        public int Next()
        {
            uint t = (x ^ (x << 11));
            x = y; y = z; z = w;
            w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));

            // Handle the special case where the value int.MaxValue is generated. This is outside of
            // the range of permitted values, so we therefore call Next() to try again.
            uint rtn = w & 0x7FFFFFFF;
            if (rtn == 0x7FFFFFFF)
            {
                return Next();
            }

            return (int)rtn;
        }

        /// <summary>
        /// Generates a random int over the range 0 to upperBound-1, and not including upperBound.
        /// </summary>
        public int Next(int upperBound)
        {
            if (upperBound < 0)
            {
                throw new ArgumentOutOfRangeException("upperBound", upperBound, "upperBound must be >=0");
            }

            uint t = (x ^ (x << 11));
            x = y; y = z; z = w;

            // The explicit int cast before the first multiplication gives better performance.
            // See comments in NextDouble.
            return (int)((REAL_UNIT_INT * (int)(0x7FFFFFFF & (w = (w ^ (w >> 19)) ^ (t ^ (t >> 8))))) * upperBound);
        }

        /// <summary>
        /// Generates a random int over the range lowerBound to upperBound-1, and not including upperBound.
        /// upperBound must be >= lowerBound. lowerBound may be negative.
        /// </summary>
        public int Next(int lowerBound, int upperBound)
        {
            if (lowerBound > upperBound)
            {
                throw new ArgumentOutOfRangeException("upperBound", upperBound, "upperBound must be >=lowerBound");
            }

            uint t = (x ^ (x << 11));
            x = y; y = z; z = w;

            // The explicit int cast before the first multiplication gives better performance.
            // See comments in NextDouble.
            int range = upperBound - lowerBound;
            if (range < 0)
            {
                // If range is <0 then an overflow has occured and must resort to using long integer arithmetic instead (slower).
                // We also must use all 32 bits of precision, instead of the normal 31, which again is slower.
                return lowerBound + (int)((REAL_UNIT_UINT * (double)(w = (w ^ (w >> 19)) ^ (t ^ (t >> 8)))) * (double)((long)upperBound - (long)lowerBound));
            }

            // 31 bits of precision will suffice if range<=int.MaxValue. This allows us to cast to an int and gain
            // a little more performance.
            return lowerBound + (int)((REAL_UNIT_INT * (double)(int)(0x7FFFFFFF & (w = (w ^ (w >> 19)) ^ (t ^ (t >> 8))))) * (double)range);
        }

        /// <summary>
        /// Generates a random double. Values returned are from 0.0 up to but not including 1.0.
        /// </summary>
        public double NextDouble()
        {
            uint t = (x ^ (x << 11));
            x = y; y = z; z = w;

            // Here we can gain a 2x speed improvement by generating a value that can be cast to
            // an int instead of the more easily available uint. If we then explicitly cast to an
            // int the compiler will then cast the int to a double to perform the multiplication,
            // this final cast is a lot faster than casting from a uint to a double. The extra cast
            // to an int is very fast (the allocated bits remain the same) and so the overall effect
            // of the extra cast is a significant performance improvement.
            //
            // Also note that the loss of one bit of precision is equivalent to what occurs within
            // System.Random.
            return (REAL_UNIT_INT * (int)(0x7FFFFFFF & (w = (w ^ (w >> 19)) ^ (t ^ (t >> 8)))));
        }

        /// <summary>
        /// Generates a random float. Values returned are from 0.0 up to but not including 1.0.
        /// </summary>
        public float NextFloat()
        {
            uint x = this.x, y = this.y, z = this.z, w = this.w;
            uint t = (x ^ (x << 11));
            x = y; y = z; z = w;
            w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));
            var value = ToUnitFloat(w);
            this.x = x; this.y = y; this.z = z; this.w = w;
            return value;
        }

        /// <summary>
        /// Fills the provided span with random floats in the range [0.0, 1.0).
        /// </summary>
        public void NextFloats(Span<float> buffer)
        {
            uint x = this.x, y = this.y, z = this.z, w = this.w;
            int i = 0;
            uint t;
            for (int bound = buffer.Length; i < bound;)
            {
                t = (x ^ (x << 11));
                x = y; y = z; z = w;
                w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));

                buffer[i++] = ToUnitFloat(w);
            }

            this.x = x; this.y = y; this.z = z; this.w = w;
        }

        /// <summary>
        /// Maps the low 31 bits of <paramref name="bits"/> onto [0.0, 1.0).
        /// The int -> float conversion rounds to 24 significant bits, so the largest few 31-bit values
        /// round up to 2^31 and the scaled result is exactly 1.0f. Those are clamped to the largest float
        /// below 1.0; every other value is unchanged.
        /// </summary>
        private static float ToUnitFloat(uint bits)
        {
            float value = FLOAT_UNIT_INT * (int)(0x7FFFFFFF & bits);
            return value < 1.0f ? value : FLOAT_LARGEST_BELOW_ONE;
        }

        /// <summary>
        /// Fills the provided byte array with random bytes.
        /// This method is functionally equivalent to System.Random.NextBytes().
        /// </summary>
        public void NextBytes(byte[] buffer)
        {
            // Fill up the bulk of the buffer in chunks of 4 bytes at a time.
            uint x = this.x, y = this.y, z = this.z, w = this.w;
            int i = 0;
            uint t;
            for (int bound = buffer.Length - 3; i < bound;)
            {
                // Generate 4 bytes.
                // Increased performance is achieved by generating 4 random bytes per loop.
                // Also note that no mask needs to be applied to zero out the higher order bytes before
                // casting because the cast ignores thos bytes. Thanks to Stefan Troschütz for pointing this out.
                t = (x ^ (x << 11));
                x = y; y = z; z = w;
                w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));

                buffer[i++] = (byte)w;
                buffer[i++] = (byte)(w >> 8);
                buffer[i++] = (byte)(w >> 16);
                buffer[i++] = (byte)(w >> 24);
            }

            // Fill up any remaining bytes in the buffer.
            if (i < buffer.Length)
            {
                // Generate 4 bytes.
                t = (x ^ (x << 11));
                x = y; y = z; z = w;
                w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));

                buffer[i++] = (byte)w;
                if (i < buffer.Length)
                {
                    buffer[i++] = (byte)(w >> 8);
                    if (i < buffer.Length)
                    {
                        buffer[i++] = (byte)(w >> 16);
                        if (i < buffer.Length)
                        {
                            buffer[i] = (byte)(w >> 24);
                        }
                    }
                }
            }
            this.x = x; this.y = y; this.z = z; this.w = w;
        }

        /// <summary>
        /// Fills the provided span with random bytes.
        /// This method is functionally equivalent to System.Random.NextBytes().
        /// </summary>
        public void NextBytes(Span<byte> buffer)
        {
            // Fill up the bulk of the buffer in chunks of 4 bytes at a time.
            uint x = this.x, y = this.y, z = this.z, w = this.w;
            int i = 0;
            uint t;
            for (int bound = buffer.Length - 3; i < bound;)
            {
                // Generate 4 bytes.
                // Increased performance is achieved by generating 4 random bytes per loop.
                // Also note that no mask needs to be applied to zero out the higher order bytes before
                // casting because the cast ignores thos bytes. Thanks to Stefan Troschütz for pointing this out.
                t = (x ^ (x << 11));
                x = y; y = z; z = w;
                w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));

                buffer[i++] = (byte)w;
                buffer[i++] = (byte)(w >> 8);
                buffer[i++] = (byte)(w >> 16);
                buffer[i++] = (byte)(w >> 24);
            }

            // Fill up any remaining bytes in the buffer.
            if (i < buffer.Length)
            {
                // Generate 4 bytes.
                t = (x ^ (x << 11));
                x = y; y = z; z = w;
                w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));

                buffer[i++] = (byte)w;
                if (i < buffer.Length)
                {
                    buffer[i++] = (byte)(w >> 8);
                    if (i < buffer.Length)
                    {
                        buffer[i++] = (byte)(w >> 16);
                        if (i < buffer.Length)
                        {
                            buffer[i] = (byte)(w >> 24);
                        }
                    }
                }
            }
            this.x = x; this.y = y; this.z = z; this.w = w;
        }

        /// <summary>
        /// Generates a uint. Values returned are over the full range of a uint,
        /// uint.MinValue to uint.MaxValue, inclusive.
        ///
        /// This is the fastest method for generating a single random number because the underlying
        /// random number generator algorithm generates 32 random bits that can be cast directly to
        /// a uint.
        /// </summary>
        public uint NextUInt()
        {
            uint t = (x ^ (x << 11));
            x = y; y = z; z = w;
            return (w = (w ^ (w >> 19)) ^ (t ^ (t >> 8)));
        }

        /// <summary>
        /// Generates a random int over the range 0 to int.MaxValue, inclusive.
        /// This method differs from Next() only in that the range is 0 to int.MaxValue
        /// and not 0 to int.MaxValue-1.
        ///
        /// The slight difference in range means this method is slightly faster than Next()
        /// but is not functionally equivalent to System.Random.Next().
        /// </summary>
        public int NextInt()
        {
            uint t = (x ^ (x << 11));
            x = y; y = z; z = w;
            return (int)(0x7FFFFFFF & (w = (w ^ (w >> 19)) ^ (t ^ (t >> 8))));
        }

        // Buffer 32 bits in bitBuffer, return 1 at a time, keep track of how many have been returned
        // with bitMask.
        uint bitBuffer;
        uint bitMask = 1;

        /// <summary>
        /// Generates a single random bit.
        /// This method's performance is improved by generating 32 bits in one operation and storing them
        /// ready for future calls.
        /// </summary>
        public bool NextBool()
        {
            if (bitMask == 1)
            {
                // Generate 32 more bits.
                uint t = (x ^ (x << 11));
                x = y; y = z; z = w;
                bitBuffer = w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));

                // Reset the bitMask that tells us which bit to read next.
                bitMask = 0x80000000;
                return (bitBuffer & bitMask) == 0;
            }

            return (bitBuffer & (bitMask >>= 1)) == 0;
        }
    }
}
-------------------------------------------------------------------------------- /Src/HNSW.Net/Graph.Core.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License.
//

namespace HNSW.Net
{
    using System;
    using System.Collections.Generic;
    using System.IO;
    using System.Linq;
    using System.Runtime.CompilerServices;
    using System.Threading;
    using MessagePack;

    using static HNSW.Net.EventSources;

    internal partial class Graph<TItem, TDistance>
    {
        /// <summary>
        /// Shared state of the graph: the inserted items, their nodes, the neighbour-selection algorithm
        /// and the optional distance cache used during construction.
        /// </summary>
        internal class Core
        {
            private readonly Func<TItem, TItem, TDistance> Distance;

            private readonly DistanceCache<TDistance> DistanceCache;

            private long DistanceCalculationsCount;

            internal List<Node> Nodes { get; private set; }

            internal List<TItem> Items { get; private set; }

            internal Algorithms.Algorithm<TItem, TDistance> Algorithm { get; private set; }

            internal SmallWorld<TItem, TDistance>.Parameters Parameters { get; private set; }

            /// <summary>
            /// Fraction of <see cref="GetDistance"/> calls answered by the cache. It is 0 when the cache is
            /// disabled or when no distance has been requested yet; before the fix, 0/0 produced NaN.
            /// </summary>
            internal float DistanceCacheHitRate
            {
                get
                {
                    if (DistanceCalculationsCount == 0)
                    {
                        return 0f;
                    }

                    return (float)(DistanceCache?.HitCount ?? 0) / DistanceCalculationsCount;
                }
            }

            /// <summary>
            /// Creates an empty core. The neighbour-selection algorithm is picked from
            /// <c>parameters.NeighbourHeuristic</c>. A distance cache is created only when
            /// <c>parameters.EnableDistanceCacheForConstruction</c> is set.
            /// </summary>
            internal Core(Func<TItem, TItem, TDistance> distance, SmallWorld<TItem, TDistance>.Parameters parameters)
            {
                Distance = distance;
                Parameters = parameters;

                var initialSize = Math.Max(1024, parameters.InitialItemsSize);

                Nodes = new List<Node>(initialSize);
                Items = new List<TItem>(initialSize);

                switch (Parameters.NeighbourHeuristic)
                {
                    case NeighbourSelectionHeuristic.SelectSimple:
                    {
                        Algorithm = new Algorithms.Algorithm3<TItem, TDistance>(this);
                        break;
                    }
                    case NeighbourSelectionHeuristic.SelectHeuristic:
                    {
                        Algorithm = new Algorithms.Algorithm4<TItem, TDistance>(this);
                        break;
                    }
                }

                if (Parameters.EnableDistanceCacheForConstruction)
                {
                    DistanceCache = new DistanceCache<TDistance>();
                    DistanceCache.Resize(parameters.InitialDistanceCacheSize, false);
                }

                DistanceCalculationsCount = 0;
            }

            /// <summary>
            /// Appends <paramref name="items"/> and creates one node per item. Ids are assigned consecutively
            /// after the existing nodes, and each node gets a random top layer.
            /// </summary>
            /// <returns>The ids assigned to the new items, in input order.</returns>
            internal IReadOnlyList<int> AddItems(IReadOnlyList<TItem> items, IProvideRandomValues generator)
            {
                int newCount = items.Count;

                var newIDs = new List<int>(newCount);
                Items.AddRange(items);

                // NOTE(review): the cache is sized from this batch's count, not Items.Count; pairs that
                // involve earlier items therefore share slots. Confirm this memory/accuracy trade-off is intended.
                DistanceCache?.Resize(newCount, false);

                int id0 = Nodes.Count;

                for (int id = 0; id < newCount; ++id)
                {
                    Nodes.Add(Algorithm.NewNode(id0 + id, RandomLayer(generator, Parameters.LevelLambda)));
                    newIDs.Add(id0 + id);
                }
                return newIDs;
            }

            /// <summary>
            /// Reallocates the distance cache (discarding cached entries) for <paramref name="newSize"/> points.
            /// Does nothing when the cache is disabled.
            /// </summary>
            internal void ResizeDistanceCache(int newSize)
            {
                DistanceCache?.Resize(newSize, true);
            }

            /// <summary>
            /// Writes <see cref="Nodes"/> to <paramref name="stream"/> with MessagePack. Items are not serialized.
            /// </summary>
            internal void Serialize(Stream stream)
            {
                MessagePackSerializer.Serialize(stream, Nodes);
            }

            /// <summary>
            /// Reads <see cref="Nodes"/> from <paramref name="stream"/> and pairs them with the first
            /// <c>Nodes.Count</c> entries of <paramref name="items"/>.
            /// </summary>
            /// <returns>The items beyond the deserialized node count, which still need to be added.</returns>
            internal TItem[] Deserialize(IReadOnlyList<TItem> items, Stream stream)
            {
                // readStrict: true -> removed, as not available anymore on MessagePack 2.0 - also probably not necessary anymore
                // see https://github.com/neuecc/MessagePack-CSharp/pull/663
                Nodes = MessagePackSerializer.Deserialize<List<Node>>(stream);
                var remainingItems = items.Skip(Nodes.Count).ToArray();
                Items.AddRange(items.Take(Nodes.Count));
                return remainingItems;
            }

            /// <summary>
            /// Distance between the items with the given ids. The distance cache is consulted when enabled.
            /// </summary>
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            internal TDistance GetDistance(int fromId, int toId)
            {
                DistanceCalculationsCount++;
                if (DistanceCache is object)
                {
                    return DistanceCache.GetOrCacheValue(fromId, toId, GetDistanceSkipCache);
                }
                else
                {
                    return Distance(Items[fromId], Items[toId]);
                }
            }

            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            private TDistance GetDistanceSkipCache(int fromId, int toId)
            {
                return Distance(Items[fromId], Items[toId]);
            }

            /// <summary>
            /// Draws the top layer of a new node as <c>-ln(U) * lambda</c> truncated to int, where U comes
            /// from <paramref name="generator"/>.
            /// </summary>
            private static int RandomLayer(IProvideRandomValues generator, double lambda)
            {
                float u = generator.NextFloat();

                // -ln(0) is +infinity. In C#, converting infinity or NaN to int yields an unspecified value,
                // which would become an absurd (possibly negative) layer. Keep U inside (0, 1].
                if (!(u > 0f))
                {
                    u = float.Epsilon;
                }
                else if (u > 1f)
                {
                    u = 1f;
                }

                var r = -Math.Log(u) * lambda;
                return (int)r;
            }
        }
    }
}
-------------------------------------------------------------------------------- /Src/HNSW.Net/Graph.Searcher.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//

namespace HNSW.Net
{
    using System;
    using System.Collections.Generic;
    using System.Threading;

    /// <summary>
    /// The implementation of knn search.
    /// </summary>
    internal partial class Graph<TItem, TDistance>
    {
        /// <summary>
        /// The graph searcher. It owns reusable scratch buffers, which are emptied again when each run ends.
        /// </summary>
        internal struct Searcher
        {
            private readonly Core Core;
            private readonly List<int> ExpansionBuffer;
            private readonly VisitedBitSet VisitedSet;

            /// <summary>
            /// Initializes a new instance of the <see cref="Searcher"/> struct.
            /// </summary>
            /// <param name="core">The core of the graph.</param>
            internal Searcher(Core core)
            {
                Core = core;
                ExpansionBuffer = new List<int>();
                VisitedSet = new VisitedBitSet(core.Nodes.Count);
            }

            /// <summary>
            /// The implementation of SEARCH-LAYER(q, ep, ef, lc) algorithm.
            /// Article: Section 4. Algorithm 2.
            /// </summary>
            /// <param name="entryPointId">The identifier of the entry point for the search.</param>
            /// <param name="targetCosts">The traveling costs for the search target.</param>
            /// <param name="resultList">The list of identifiers of the nearest neighbours at the level.</param>
            /// <param name="layer">The layer to perform search at.</param>
            /// <param name="k">The number of the nearest neighbours to get from the layer.</param>
            /// <param name="version">The live version of the graph, re-read to detect concurrent changes.</param>
            /// <param name="versionAtStart">The graph version observed when the search started.</param>
            /// <param name="keepResult">Filter deciding whether a visited node may enter the result list.</param>
            /// <param name="cancellationToken">Stops the search early, returning the visits so far.</param>
            /// <returns>The number of expanded nodes during the run.</returns>
            internal int RunKnnAtLayer(int entryPointId, TravelingCosts<int, TDistance> targetCosts, List<int> resultList, int layer, int k, ref long version, long versionAtStart, Func<int, bool> keepResult, CancellationToken cancellationToken = default)
            {
                /*
                 * v ← ep // set of visited elements
                 * C ← ep // set of candidates
                 * W ← ep // dynamic list of found nearest neighbors
                 * while │C│ > 0
                 *   c ← extract nearest element from C to q
                 *   f ← get furthest element from W to q
                 *   if distance(c, q) > distance(f, q)
                 *     break // all elements in W are evaluated
                 *   for each e ∈ neighbourhood(c) at layer lc // update C and W
                 *     if e ∉ v
                 *       v ← v ⋃ e
                 *       f ← get furthest element from W to q
                 *       if distance(e, q) < distance(f, q) or │W│ < ef
                 *         C ← C ⋃ e
                 *         W ← W ⋃ e
                 *         if │W│ > ef
                 *           remove furthest element from W to q
                 * return W
                 */

                // prepare tools
                IComparer<int> fartherIsOnTop = targetCosts;
                IComparer<int> closerIsOnTop = fartherIsOnTop.Reverse();

                // prepare collections
                // TODO: Optimize by providing buffers
                var resultHeap = new BinaryHeap<int>(resultList, fartherIsOnTop);
                var expansionHeap = new BinaryHeap<int>(ExpansionBuffer, closerIsOnTop);

                try
                {
                    if (keepResult(entryPointId))
                    {
                        resultHeap.Push(entryPointId);
                    }

                    expansionHeap.Push(entryPointId);
                    VisitedSet.Add(entryPointId);

                    // run bfs
                    int visitedNodesCount = 1;
                    while (expansionHeap.Buffer.Count > 0)
                    {
                        if (cancellationToken.IsCancellationRequested)
                        {
                            return visitedNodesCount;
                        }

                        GraphChangedException.ThrowIfChanged(ref version, versionAtStart);

                        // get next candidate to check and expand
                        var toExpandId = expansionHeap.Pop();
                        var farthestResultId = resultHeap.Buffer.Count > 0 ? resultHeap.Buffer[0] : -1;

                        // Node ids start at 0, so -1 (empty result heap) is the only "no result" value.
                        // The former "> 0" test skipped this stop condition whenever node 0 was the farthest result.
                        if (farthestResultId >= 0 && DistanceUtils.GreaterThan(targetCosts.From(toExpandId), targetCosts.From(farthestResultId)))
                        {
                            // the closest candidate is farther than farthest result
                            break;
                        }

                        // expand candidate
                        var neighboursIds = Core.Nodes[toExpandId][layer];

                        for (int i = 0; i < neighboursIds.Count; ++i)
                        {
                            if (cancellationToken.IsCancellationRequested)
                            {
                                return visitedNodesCount;
                            }

                            int neighbourId = neighboursIds[i];

                            if (!VisitedSet.Contains(neighbourId))
                            {
                                // enqueue perspective neighbours to expansion list
                                farthestResultId = resultHeap.Buffer.Count > 0 ? resultHeap.Buffer[0] : -1;
                                if (resultHeap.Buffer.Count < k || (farthestResultId >= 0 && DistanceUtils.LowerThan(targetCosts.From(neighbourId), targetCosts.From(farthestResultId))))
                                {
                                    expansionHeap.Push(neighbourId);

                                    if (keepResult(neighbourId))
                                    {
                                        resultHeap.Push(neighbourId);
                                    }

                                    if (resultHeap.Buffer.Count > k)
                                    {
                                        resultHeap.Pop();
                                    }
                                }

                                // update visited list
                                ++visitedNodesCount;
                                VisitedSet.Add(neighbourId);
                            }
                        }
                    }

                    return visitedNodesCount;
                }
                catch (Exception)
                {
                    // Throws if the collection changed, otherwise propagates the original exception
                    GraphChangedException.ThrowIfChanged(ref version, versionAtStart);
                    throw;
                }
                finally
                {
                    // Reset the reusable scratch state on every exit path (normal, cancelled, or failed).
                    // Before, only the normal path cleared it, leaving stale candidates and visited bits behind.
                    ExpansionBuffer.Clear();
                    VisitedSet.Clear();
                }
            }
        }
    }
}
-------------------------------------------------------------------------------- /Src/HNSW.Net/Graph.Utils.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License.
4 | // 5 | 6 | namespace HNSW.Net 7 | { 8 | using System; 9 | using System.Collections.Generic; 10 | using System.Linq; 11 | 12 | internal partial class Graph 13 | { 14 | /// 15 | /// Runs breadth first search, calling the visit action exactly once for every node reachable from the entry point at the given layer. 16 | /// 17 | /// The graph core. 18 | /// The entry point. 19 | /// The layer of the graph where to run BFS. 20 | /// The action to perform on each node. 21 | internal static void BFS(Core core, Node entryPoint, int layer, Action visitAction) 22 | { 23 | var visitedIds = new HashSet(); 24 | var expansionQueue = new Queue(new[] { entryPoint.Id }); 25 | 26 | while (expansionQueue.Count > 0) 27 | { 28 | var currentNode = core.Nodes[expansionQueue.Dequeue()]; 29 | if (visitedIds.Add(currentNode.Id)) // Add returns false when the node was already visited 30 | { 31 | visitAction(currentNode); 32 | foreach (var neighbourId in currentNode[layer]) 33 | { 34 | if (!visitedIds.Contains(neighbourId)) // visited ids would be skipped on dequeue anyway, so don't queue them 35 | { 36 | expansionQueue.Enqueue(neighbourId); 37 | } 38 | } 39 | } 40 | } 41 | } 42 | 43 | internal class VisitedBitSet // fixed-size bitmap with one bit per node id, sized for ids 0..nodesCount 44 | { 45 | private readonly int[] Buffer; 46 | 47 | internal VisitedBitSet(int nodesCount) 48 | { 49 | Buffer = new int[(nodesCount >> 5) + 1]; 50 | } 51 | 52 | internal bool Contains(int nodeId) 53 | { 54 | int carrier = Buffer[nodeId >> 5]; 55 | return ((1 << (nodeId & 31)) & carrier) != 0; 56 | } 57 | 58 | internal void Add(int nodeId) 59 | { 60 | int mask = 1 << (nodeId & 31); 61 | Buffer[nodeId >> 5] |= mask; 62 | } 63 | 64 | internal void Clear() 65 | { 66 | Array.Clear(Buffer, 0, Buffer.Length); 67 | } 68 | } 69 | } 70 | } -------------------------------------------------------------------------------- /Src/HNSW.Net/Graph.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License.
4 | // 5 | 6 | namespace HNSW.Net 7 | { 8 | using System; 9 | using System.Collections.Generic; 10 | using System.IO; 11 | using System.Linq; 12 | using MessagePack; 13 | using System.Text; 14 | 15 | using static HNSW.Net.EventSources; 16 | using System.Threading; 17 | 18 | /// 19 | /// The implementation of a hierarchical small world graph. 20 | /// 21 | /// The type of items to connect into small world. 22 | /// The type of distance between items (expects any numeric type: float, double, decimal, int, ...). 23 | internal partial class Graph where TDistance : struct, IComparable 24 | { 25 | private readonly Func Distance; 26 | 27 | internal Core GraphCore; 28 | 29 | private Node? EntryPoint; 30 | 31 | private long _version; // incremented on every graph mutation; searches compare it with the value seen at start to detect concurrent changes 32 | 33 | /// 34 | /// Initializes a new instance of the class. 35 | /// 36 | /// The distance function. 37 | /// The parameters of the world. 38 | internal Graph(Func distance, SmallWorld.Parameters parameters) 39 | { 40 | Distance = distance; 41 | Parameters = parameters; 42 | } 43 | 44 | internal SmallWorld.Parameters Parameters { get; } 45 | 46 | /// 47 | /// Creates graph from the given items. 48 | /// Contains implementation of INSERT(hnsw, q, M, Mmax, efConstruction, mL) algorithm. 49 | /// Article: Section 4. Algorithm 1. 50 | /// 51 | /// The items to insert. 52 | /// The random number generator to distribute nodes across layers. 53 | /// Optional progress reporter; after each inserted node it receives (index of that node within this batch, number of nodes in this batch) 54 | internal IReadOnlyList AddItems(IReadOnlyList items, IProvideRandomValues generator, IProgressReporter progressReporter) 55 | { 56 | if (items is null || !items.Any()) { return Array.Empty(); } 57 | 58 | GraphCore = GraphCore ?? new Core(Distance, Parameters); 59 | 60 | int startIndex = GraphCore.Items.Count; 61 | 62 | var newIDs = GraphCore.AddItems(items, generator); 63 | 64 | var entryPoint = EntryPoint ??
GraphCore.Nodes[0]; 65 | 66 | var searcher = new Searcher(GraphCore); 67 | Func nodeDistance = GraphCore.GetDistance; 68 | var neighboursIdsBuffer = new List(GraphCore.Algorithm.GetM(0) + 1); 69 | 70 | for (int nodeId = startIndex; nodeId < GraphCore.Nodes.Count; ++nodeId) 71 | { 72 | var versionNow = Interlocked.Increment(ref _version); 73 | 74 | using (new ScopeLatencyTracker(GraphBuildEventSource.Instance?.GraphInsertNodeLatencyReporter)) 75 | { 76 | /* 77 | * W ← ∅ // list for the currently found nearest elements 78 | * ep ← get enter point for hnsw 79 | * L ← level of ep // top layer for hnsw 80 | * l ← ⌊-ln(unif(0..1))∙mL⌋ // new element’s level 81 | * for lc ← L … l+1 82 | * W ← SEARCH-LAYER(q, ep, ef=1, lc) 83 | * ep ← get the nearest element from W to q 84 | * for lc ← min(L, l) … 0 85 | * W ← SEARCH-LAYER(q, ep, efConstruction, lc) 86 | * neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4 87 | * for each e ∈ neighbors // shrink connections if needed 88 | * eConn ← neighbourhood(e) at layer lc 89 | * if │eConn│ > Mmax // shrink connections of e if lc = 0 then Mmax = Mmax0 90 | * eNewConn ← SELECT-NEIGHBORS(e, eConn, Mmax, lc) // alg. 3 or alg.
4 91 | * set neighbourhood(e) at layer lc to eNewConn 92 | * ep ← W 93 | * if l > L 94 | * set enter point for hnsw to q 95 | */ 96 | 97 | // zoom in and find the best peer on the same level as newNode 98 | var bestPeer = entryPoint; 99 | var currentNode = GraphCore.Nodes[nodeId]; 100 | var currentNodeTravelingCosts = new TravelingCosts(nodeDistance, nodeId); 101 | for (int layer = bestPeer.MaxLayer; layer > currentNode.MaxLayer; --layer) 102 | { 103 | searcher.RunKnnAtLayer(bestPeer.Id, currentNodeTravelingCosts, neighboursIdsBuffer, layer, 1, ref _version, versionNow, _ => true); 104 | bestPeer = GraphCore.Nodes[neighboursIdsBuffer[0]]; 105 | neighboursIdsBuffer.Clear(); 106 | } 107 | 108 | // connecting new node to the small world 109 | for (int layer = Math.Min(currentNode.MaxLayer, entryPoint.MaxLayer); layer >= 0; --layer) 110 | { 111 | searcher.RunKnnAtLayer(bestPeer.Id, currentNodeTravelingCosts, neighboursIdsBuffer, layer, Parameters.ConstructionPruning, ref _version, versionNow, _ => true); 112 | var bestNeighboursIds = GraphCore.Algorithm.SelectBestForConnecting(neighboursIdsBuffer, currentNodeTravelingCosts, layer); 113 | 114 | for (int i = 0; i < bestNeighboursIds.Count; ++i) 115 | { 116 | int newNeighbourId = bestNeighboursIds[i]; 117 | versionNow = Interlocked.Increment(ref _version); 118 | GraphCore.Algorithm.Connect(currentNode, GraphCore.Nodes[newNeighbourId], layer); 119 | 120 | versionNow = Interlocked.Increment(ref _version); 121 | GraphCore.Algorithm.Connect(GraphCore.Nodes[newNeighbourId], currentNode, layer); 122 | 123 | // if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer 124 | if (DistanceUtils.LowerThan(currentNodeTravelingCosts.From(newNeighbourId), currentNodeTravelingCosts.From(bestPeer.Id))) 125 | { 126 | bestPeer = GraphCore.Nodes[newNeighbourId]; 127 | } 128 | } 129 | 130 | neighboursIdsBuffer.Clear(); 131 | } 132 | 133 | // zoom out to the highest level 134 | if (currentNode.MaxLayer >
entryPoint.MaxLayer) 135 | { 136 | entryPoint = currentNode; 137 | } 138 | 139 | // report distance cache hit rate 140 | GraphBuildEventSource.Instance?.CoreGetDistanceCacheHitRateReporter?.Invoke(GraphCore.DistanceCacheHitRate); 141 | } 142 | progressReporter?.Progress(nodeId - startIndex, GraphCore.Nodes.Count - startIndex); 143 | } 144 | 145 | // construction is done 146 | EntryPoint = entryPoint; 147 | 148 | return newIDs; 149 | } 150 | 151 | /// 152 | /// Get k nearest items for a given one. 153 | /// Contains implementation of K-NN-SEARCH(hnsw, q, K, ef) algorithm. 154 | /// Article: Section 4. Algorithm 5. 155 | /// 156 | /// The given node to get the nearest neighbourhood for. 157 | /// The size of the neighbourhood. 158 | /// Optional predicate applied to each candidate item (not its id); return true to keep it, false to exclude it from results. Its answer is cached per node id for the duration of this call 159 | /// Cancellation token; when cancelled, the search stops early and returns the results gathered so far 160 | /// The list of the nearest neighbours, or null when the graph has no entry point (no items added yet). 161 | internal IList.KNNSearchResult> KNearest(TItem destination, int k, Func filterItem = null, CancellationToken cancellationToken = default) 162 | { 163 | if (EntryPoint is null) return null; // empty graph: callers receive null rather than an empty list 164 | 165 | Func keepResultInner = _ => true; 166 | 167 | if (filterItem is object) 168 | { 169 | var keepResults = new Dictionary(); 170 | keepResultInner = (id) => 171 | { 172 | if (keepResults.TryGetValue(id, out var v)) return v; 173 | v = filterItem(GraphCore.Items[id]); 174 | keepResults[id] = v; 175 | return v; 176 | }; 177 | } 178 | 179 | int retries = 1_024; // how many times the search is restarted after a concurrent graph change is detected 180 | 181 | // TODO: hack we know that destination id is -1. 182 | TDistance RuntimeDistance(int x, int y) 183 | { 184 | int nodeId = x >= 0 ?
x : y; 185 | return Distance(destination, GraphCore.Items[nodeId]); 186 | } 187 | 188 | while (true) 189 | { 190 | var versionNow = Interlocked.Read(ref _version); 191 | 192 | try 193 | { 194 | using (new ScopeLatencyTracker(GraphSearchEventSource.Instance?.GraphKNearestLatencyReporter)) 195 | { 196 | var bestPeer = EntryPoint.Value; 197 | var searcher = new Searcher(GraphCore); 198 | var destinationTravelingCosts = new TravelingCosts(RuntimeDistance, -1); 199 | var resultIds = new List(k + 1); 200 | 201 | int visitedNodesCount = 0; 202 | 203 | for (int layer = EntryPoint.Value.MaxLayer; layer > 0; --layer) 204 | { 205 | visitedNodesCount += searcher.RunKnnAtLayer(bestPeer.Id, destinationTravelingCosts, resultIds, layer, 1, ref _version, versionNow, keepResultInner, cancellationToken); 206 | 207 | if (cancellationToken.IsCancellationRequested) 208 | { 209 | //Return best so far - TODO: Investigate if this assumption is correct 210 | return resultIds.Select(id => new SmallWorld.KNNSearchResult(id, GraphCore.Items[id], RuntimeDistance(id, -1))).ToList(); 211 | } 212 | 213 | if (resultIds.Count > 0) 214 | { 215 | bestPeer = GraphCore.Nodes[resultIds[0]]; 216 | } 217 | 218 | resultIds.Clear(); 219 | } 220 | 221 | visitedNodesCount += searcher.RunKnnAtLayer(bestPeer.Id, destinationTravelingCosts, resultIds, 0, k, ref _version, versionNow, keepResultInner, cancellationToken); 222 | 223 | GraphSearchEventSource.Instance?.GraphKNearestVisitedNodesReporter?.Invoke(visitedNodesCount); 224 | // NOTE(review): resultIds is the backing list of the search's result heap, so entries are presumably in heap order rather than sorted by distance — confirm before relying on ordering 225 | return resultIds.Select(id => new SmallWorld.KNNSearchResult(id, GraphCore.Items[id], RuntimeDistance(id, -1))).ToList(); 226 | } 227 | } 228 | catch (GraphChangedException) 229 | { 230 | if(retries > 0) 231 | { 232 | retries--; 233 | continue; 234 | } 235 | throw; 236 | } 237 | catch(Exception) 238 | { 239 | if (versionNow != Interlocked.Read(ref _version)) 240 | { 241 | if (retries > 0) 242 | { 243 | retries--; 244 | continue; 245 | } 246 | } 247 | throw; 248 | } 249 | } 250 | }
251 | 252 | /// 253 | /// Serializes core of the graph. 254 | /// 255 | /// The stream that receives the serialized core followed by the entry point. 256 | internal void Serialize(Stream stream) 257 | { 258 | GraphCore.Serialize(stream); 259 | MessagePackSerializer.Serialize(stream, EntryPoint); 260 | } 261 | 262 | /// 263 | /// Deserializes graph edges and assigns nodes to the items. 264 | /// 265 | /// The underlying items. 266 | /// The serialized core followed by the serialized entry point. 267 | internal TItem[] Deserialize(IReadOnlyList items, Stream stream) 268 | { 269 | // readStrict: true -> removed, as not available anymore on MessagePack 2.0 - also probably not necessary anymore 270 | // see https://github.com/neuecc/MessagePack-CSharp/pull/663 271 | 272 | var core = new Core(Distance, Parameters); 273 | var remainingItems = core.Deserialize(items, stream); 274 | EntryPoint = MessagePackSerializer.Deserialize(stream); 275 | GraphCore = core; 276 | return remainingItems; 277 | } 278 | 279 | /// 280 | /// Prints edges of the graph. 281 | /// 282 | /// String representation of the graph's edges. 283 | internal string Print() // NOTE(review): EntryPoint.Value throws InvalidOperationException when the graph has no entry point 284 | { 285 | var buffer = new StringBuilder(); 286 | for (int layer = EntryPoint.Value.MaxLayer; layer >= 0; --layer) 287 | { 288 | buffer.AppendLine($"[LEVEL {layer}]"); 289 | BFS(GraphCore, EntryPoint.Value, layer, (node) => 290 | { 291 | var neighbours = string.Join(", ", node[layer]); 292 | buffer.AppendLine($"({node.Id}) -> {{{neighbours}}}"); 293 | }); 294 | 295 | buffer.AppendLine(); 296 | } 297 | 298 | return buffer.ToString(); 299 | } 300 | } 301 | } 302 | -------------------------------------------------------------------------------- /Src/HNSW.Net/GraphChangedException.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License.
4 | // 5 | 6 | using System; 7 | using System.Runtime.Serialization; 8 | using System.Threading; 9 | 10 | namespace HNSW.Net 11 | { 12 | // Signals that the graph's version counter moved while a reader was traversing the graph. 13 | [Serializable] 14 | internal class GraphChangedException : Exception 15 | { 16 | public GraphChangedException() : base() 17 | { 18 | } 19 | 20 | public GraphChangedException(string message) : base(message) 21 | { 22 | } 23 | 24 | public GraphChangedException(string message, Exception innerException) : base(message, innerException) 25 | { 26 | } 27 | 28 | protected GraphChangedException(SerializationInfo info, StreamingContext context) : base(info, context) 29 | { 30 | } 31 | 32 | // Atomically reads the current version and throws unless it still equals the version observed at start. 33 | internal static void ThrowIfChanged(ref long version, long versionAtStart) 34 | { 35 | long observedVersion = Interlocked.Read(ref version); 36 | if (observedVersion == versionAtStart) 37 | { 38 | return; 39 | } 40 | 41 | throw new GraphChangedException(); 42 | } 43 | } 44 | } -------------------------------------------------------------------------------- /Src/HNSW.Net/HNSW.Net.csproj: -------------------------------------------------------------------------------- 1 | ﻿ 2 | 3 | 4 | 5 | netstandard2.0;netstandard2.1;net7.0;net8.0 6 | AnyCPU 7 | Curiosity GmbH 8 | HNSW 9 | Curiosity GmbH and original HSNW.Net authors 10 | HNSW 11 | C# library for fast approximate nearest neighbours search using Hierarchical Navigable Small World graphs. Fork from original at https://github.com/microsoft/HNSW.Net, adds support for incremental build and MessagePack serialization.
12 | MIT 13 | (c) Copyright 2019 Curiosity GmbH, Copyright (c) Microsoft Corporation 14 | true 15 | https://github.com/curiosity-ai/hnsw.net 16 | https://github.com/curiosity-ai/hnsw.net 17 | kNN, ANN, approximate nearest neighbor 18 | false 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /Src/HNSW.Net/IProgressReporter.cs: -------------------------------------------------------------------------------- 1 | namespace HNSW.Net 2 | { 3 | public interface IProgressReporter 4 | { 5 | void Progress(int current, int total); // Graph.AddItems passes the index of the node just inserted (within the batch) and the batch size 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /Src/HNSW.Net/IProvideRandomValues.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Runtime.CompilerServices; 3 | 4 | namespace HNSW.Net 5 | { 6 | public interface IProvideRandomValues 7 | { 8 | bool IsThreadSafe { get; } 9 | 10 | /// 11 | /// Generates a random float. Values returned are from 0.0 up to but not including 1.0. 12 | /// 13 | [MethodImpl(MethodImplOptions.AggressiveInlining)] // NOTE(review): on a body-less interface member this attribute likely has no effect; implementations must apply it themselves 14 | float NextFloat(); 15 | 16 | /// 17 | /// Fills every element of the given span with a random float (presumably in the same [0.0, 1.0) range as NextFloat — confirm against implementations). 18 | /// 19 | /// The span of floats to fill with random numbers. 20 | [MethodImpl(MethodImplOptions.AggressiveInlining)] 21 | void NextFloats(Span buffer); 22 | 23 | /// 24 | /// Returns a random integer that is within a specified range. 25 | /// 26 | /// The inclusive lower bound of the random number returned. 27 | /// The exclusive upper bound of the random number returned. maxValue must be greater than or equal to minValue. 28 | /// A 32-bit signed integer greater than or equal to minValue and less than maxValue; that is, the range of return values includes minValue but not maxValue. If minValue 29 | /// equals maxValue, minValue is returned.
30 | [MethodImpl(MethodImplOptions.AggressiveInlining)] 31 | int Next(int minValue, int maxValue); 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /Src/HNSW.Net/MessagePackCompat/FloatBits.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Runtime.InteropServices; 3 | 4 | namespace MessagePackCompat 5 | { 6 | // safe accessor of Single/Double's underlying bytes: an explicit-layout overlay where each ByteN field aliases one byte of Value. 7 | // This code is borrowed from MsgPack-Cli https://github.com/msgpack/msgpack-cli 8 | 9 | [StructLayout(LayoutKind.Explicit)] 10 | internal struct Float32Bits 11 | { 12 | [FieldOffset(0)] 13 | public readonly float Value; 14 | 15 | [FieldOffset(0)] 16 | public readonly Byte Byte0; 17 | 18 | [FieldOffset(1)] 19 | public readonly Byte Byte1; 20 | 21 | [FieldOffset(2)] 22 | public readonly Byte Byte2; 23 | 24 | [FieldOffset(3)] 25 | public readonly Byte Byte3; 26 | 27 | public Float32Bits(float value) 28 | { 29 | this = default(Float32Bits); // definitely assigns every (overlapping) field before one of them is written 30 | this.Value = value; 31 | } 32 | 33 | public Float32Bits(byte[] bigEndianBytes, int offset) 34 | { 35 | this = default(Float32Bits); 36 | // input is big-endian; on little-endian hosts the byte order is reversed so Value reads correctly 37 | if (BitConverter.IsLittleEndian) 38 | { 39 | this.Byte0 = bigEndianBytes[offset + 3]; 40 | this.Byte1 = bigEndianBytes[offset + 2]; 41 | this.Byte2 = bigEndianBytes[offset + 1]; 42 | this.Byte3 = bigEndianBytes[offset]; 43 | } 44 | else 45 | { 46 | this.Byte0 = bigEndianBytes[offset]; 47 | this.Byte1 = bigEndianBytes[offset + 1]; 48 | this.Byte2 = bigEndianBytes[offset + 2]; 49 | this.Byte3 = bigEndianBytes[offset + 3]; 50 | } 51 | } 52 | } 53 | 54 | [StructLayout(LayoutKind.Explicit)] 55 | internal struct Float64Bits 56 | { 57 | [FieldOffset(0)] 58 | public readonly double Value; 59 | 60 | [FieldOffset(0)] 61 | public readonly Byte Byte0; 62 | 63 | [FieldOffset(1)] 64 | public readonly Byte Byte1; 65 | 66 | [FieldOffset(2)] 67 | public readonly Byte Byte2; 68 | 69 | [FieldOffset(3)] 70 | public readonly
Byte Byte3; 71 | 72 | [FieldOffset(4)] 73 | public readonly Byte Byte4; 74 | 75 | [FieldOffset(5)] 76 | public readonly Byte Byte5; 77 | 78 | [FieldOffset(6)] 79 | public readonly Byte Byte6; 80 | 81 | [FieldOffset(7)] 82 | public readonly Byte Byte7; 83 | 84 | public Float64Bits(double value) 85 | { 86 | this = default(Float64Bits); 87 | this.Value = value; 88 | } 89 | 90 | public Float64Bits(byte[] bigEndianBytes, int offset) 91 | { 92 | this = default(Float64Bits); 93 | // same big-endian to host-order mapping as Float32Bits, over 8 bytes 94 | if (BitConverter.IsLittleEndian) 95 | { 96 | this.Byte0 = bigEndianBytes[offset + 7]; 97 | this.Byte1 = bigEndianBytes[offset + 6]; 98 | this.Byte2 = bigEndianBytes[offset + 5]; 99 | this.Byte3 = bigEndianBytes[offset + 4]; 100 | this.Byte4 = bigEndianBytes[offset + 3]; 101 | this.Byte5 = bigEndianBytes[offset + 2]; 102 | this.Byte6 = bigEndianBytes[offset + 1]; 103 | this.Byte7 = bigEndianBytes[offset]; 104 | } 105 | else 106 | { 107 | this.Byte0 = bigEndianBytes[offset]; 108 | this.Byte1 = bigEndianBytes[offset + 1]; 109 | this.Byte2 = bigEndianBytes[offset + 2]; 110 | this.Byte3 = bigEndianBytes[offset + 3]; 111 | this.Byte4 = bigEndianBytes[offset + 4]; 112 | this.Byte5 = bigEndianBytes[offset + 5]; 113 | this.Byte6 = bigEndianBytes[offset + 6]; 114 | this.Byte7 = bigEndianBytes[offset + 7]; 115 | } 116 | } 117 | } 118 | } -------------------------------------------------------------------------------- /Src/HNSW.Net/MessagePackCompat/README.md: -------------------------------------------------------------------------------- 1 | These files are included here because they've been removed from the new 2.0 MessagePack API - but we depend on them.
2 | Once we refactor the code to use the new Writer/Reader, we can remove this 3 | See: https://github.com/neuecc/MessagePack-CSharp/blob/master/doc/migration.md 4 | 5 | 6 | 7 | 8 | MessagePack for C# 9 | 10 | MIT License 11 | 12 | Copyright (c) 2017 Yoshifumi Kawai and contributors 13 | 14 | Permission is hereby granted, free of charge, to any person obtaining a copy 15 | of this software and associated documentation files (the "Software"), to deal 16 | in the Software without restriction, including without limitation the rights 17 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 18 | copies of the Software, and to permit persons to whom the Software is 19 | furnished to do so, subject to the following conditions: 20 | 21 | The above copyright notice and this permission notice shall be included in all 22 | copies or substantial portions of the Software. 23 | 24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 25 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 26 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 27 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 28 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 29 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 30 | SOFTWARE. 31 | 32 | --- 33 | 34 | lz4net 35 | 36 | Copyright (c) 2013-2017, Milosz Krajewski 37 | 38 | All rights reserved. 39 | 40 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 41 | 42 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
43 | 44 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 45 | 46 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /Src/HNSW.Net/MessagePackCompat/StringEncoding.cs: -------------------------------------------------------------------------------- 1 | using System.Text; 2 | 3 | namespace MessagePackCompat 4 | { 5 | internal static class StringEncoding 6 | { 7 | public static readonly Encoding UTF8 = new UTF8Encoding(false); 8 | } 9 | } -------------------------------------------------------------------------------- /Src/HNSW.Net/Node.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License. 4 | // 5 | 6 | namespace HNSW.Net 7 | { 8 | using MessagePack; 9 | using System.Collections.Generic; 10 | 11 | /// 12 | /// The implementation of the node in hnsw graph. 
13 | /// 14 | [MessagePackObject] 15 | public struct Node 16 | { 17 | [Key(0)] 18 | public List> Connections; 19 | 20 | [Key(1)] public int Id; 21 | 22 | /// 23 | /// Gets the max layer where the node is presented. 24 | /// 25 | [IgnoreMember] 26 | public int MaxLayer 27 | { 28 | get 29 | { 30 | return Connections.Count - 1; 31 | } 32 | } 33 | 34 | /// 35 | /// Gets connections ids of the node at the given layer 36 | /// 37 | /// The layer to get connections at. 38 | /// The connections of the node at the given layer. 39 | public List this[int layer] 40 | { 41 | get 42 | { 43 | return Connections[layer]; 44 | } 45 | set 46 | { 47 | Connections[layer] = value; 48 | } 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /Src/HNSW.Net/Properties/AssemblyInfo.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License. 4 | // 5 | 6 | using System.Runtime.CompilerServices; 7 | 8 | [assembly: InternalsVisibleTo("HNSW.Net.Tests")] -------------------------------------------------------------------------------- /Src/HNSW.Net/ReverseComparer.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License. 4 | // 5 | 6 | namespace HNSW.Net 7 | { 8 | using System.Collections.Generic; 9 | using System.Diagnostics.CodeAnalysis; 10 | 11 | /// 12 | /// Reverses the order of the nested comparer. 13 | /// 14 | /// The types of items to comapre. 15 | public class ReverseComparer : IComparer 16 | { 17 | private readonly IComparer Comparer; 18 | 19 | /// 20 | /// Initializes a new instance of the class. 21 | /// 22 | /// The comparer to invert. 
23 | public ReverseComparer(IComparer comparer) 24 | { 25 | Comparer = comparer; 26 | } 27 | 28 | /// 29 | public int Compare(T x, T y) 30 | { 31 | return Comparer.Compare(y, x); 32 | } 33 | } 34 | 35 | /// 36 | /// Extension methods to shortcut usage. 37 | /// 38 | [SuppressMessage("StyleCop.CSharp.MaintainabilityRules", "SA1402:File may only contain a single type", Justification = "By Design")] 39 | [SuppressMessage("StyleCop.CSharp.OrderingRules", "SA1204:Static elements must appear before instance elements", Justification = "By Design")] 40 | public static class ReverseComparerExtensions 41 | { 42 | /// 43 | /// Creates new wrapper for the given comparer. 44 | /// 45 | /// The types of items to comapre. 46 | /// The source comparer. 47 | /// The inverted to source comparer. 48 | public static ReverseComparer Reverse(this IComparer comparer) 49 | { 50 | return new ReverseComparer(comparer); 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /Src/HNSW.Net/ScopeLatencyTracker.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License. 4 | // 5 | 6 | namespace HNSW.Net 7 | { 8 | using System; 9 | using System.Diagnostics; 10 | 11 | /// 12 | /// Latency tracker for using scope. 13 | /// TODO: make it ref struct in C# 8.0 14 | /// 15 | internal struct ScopeLatencyTracker : IDisposable 16 | { 17 | private long StartTimestamp; 18 | private Action LatencyCallback; 19 | 20 | /// 21 | /// Initializes a new instance of the struct. 22 | /// 23 | /// The latency reporting callback to associate with the scope. 24 | internal ScopeLatencyTracker(Action callback) 25 | { 26 | StartTimestamp = callback != null ? Stopwatch.GetTimestamp() : 0; 27 | LatencyCallback = callback; 28 | } 29 | 30 | /// 31 | /// Reports the time ellsapsed between the tracker creation and this call. 
32 | /// 33 | public void Dispose() 34 | { 35 | const long ticksPerMicroSecond = TimeSpan.TicksPerMillisecond / 1000; 36 | if (LatencyCallback != null) 37 | { 38 | long ellapsedMuS = (Stopwatch.GetTimestamp() - StartTimestamp) / ticksPerMicroSecond; 39 | LatencyCallback(ellapsedMuS); 40 | } 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /Src/HNSW.Net/SmallWorld.cs: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) Microsoft Corporation. All rights reserved. 3 | // Licensed under the MIT License. 4 | // 5 | 6 | namespace HNSW.Net 7 | { 8 | using System; 9 | using System.Collections.Generic; 10 | using System.Diagnostics; 11 | using System.Diagnostics.CodeAnalysis; 12 | using System.IO; 13 | using System.Linq; 14 | using System.Threading; 15 | using MessagePack; 16 | using MessagePackCompat; 17 | 18 | /// 19 | /// The Hierarchical Navigable Small World Graphs. https://arxiv.org/abs/1603.09320 20 | /// 21 | /// The type of items to connect into small world. 22 | /// The type of distance between items (expect any numeric type: float, double, decimal, int, ...). 23 | public partial class SmallWorld where TDistance : struct, IComparable 24 | { 25 | private const string SERIALIZATION_HEADER = "HNSW"; 26 | private readonly Func Distance; 27 | 28 | private Graph Graph; 29 | private IProvideRandomValues Generator; 30 | 31 | private ReaderWriterLockSlim _rwLock; 32 | 33 | /// 34 | /// Gets the list of items currently held by the SmallWorld graph. 35 | /// The list is not protected by any locks, and should only be used when it is known the graph won't change 36 | /// 37 | public IReadOnlyList UnsafeItems => Graph?.GraphCore?.Items; 38 | 39 | /// 40 | /// Gets a copy of the list of items currently held by the SmallWorld graph. 41 | /// This call is protected by a read-lock and is safe to be called from multiple threads. 
42 | /// 43 | public IReadOnlyList Items 44 | { 45 | get 46 | { 47 | if (_rwLock is object) 48 | { 49 | _rwLock.EnterReadLock(); 50 | try 51 | { 52 | return Graph.GraphCore.Items.ToList(); 53 | } 54 | finally 55 | { 56 | _rwLock.ExitReadLock(); 57 | } 58 | } 59 | else 60 | { 61 | return Graph?.GraphCore?.Items; 62 | } 63 | } 64 | } 65 | 66 | 67 | /// 68 | /// Initializes a new instance of the class. 69 | /// 70 | /// The distance function to use in the small world. 71 | /// The random number generator for building graph. 72 | /// Parameters of the algorithm. 73 | public SmallWorld(Func distance, IProvideRandomValues generator, Parameters parameters, bool threadSafe = true) 74 | { 75 | Distance = distance; 76 | Graph = new Graph(Distance, parameters); 77 | Generator = generator; 78 | _rwLock = threadSafe ? new ReaderWriterLockSlim() : null; 79 | } 80 | 81 | /// 82 | /// Builds hnsw graph from the items. 83 | /// 84 | /// The items to connect into the graph. 85 | 86 | public IReadOnlyList AddItems(IReadOnlyList items, IProgressReporter progressReporter = null) 87 | { 88 | _rwLock?.EnterWriteLock(); 89 | try 90 | { 91 | return Graph.AddItems(items, Generator, progressReporter); 92 | } 93 | finally 94 | { 95 | _rwLock?.ExitWriteLock(); 96 | } 97 | } 98 | 99 | /// 100 | /// Run knn search for a given item. 101 | /// 102 | /// The item to search nearest neighbours. 103 | /// The number of nearest neighbours. 104 | /// Filter results by ID that should be kept (return true to keep, false to exclude from results) 105 | /// Cancellation Token for stopping the search when filtering is active 106 | /// The list of found nearest neighbours. 
107 | public IList KNNSearch(TItem item, int k, Func filterItem = null, CancellationToken cancellationToken = default) 108 | { 109 | _rwLock?.EnterReadLock(); 110 | try 111 | { 112 | return Graph.KNearest(item, k, filterItem, cancellationToken); 113 | } 114 | finally 115 | { 116 | _rwLock?.ExitReadLock(); 117 | } 118 | } 119 | 120 | /// 121 | /// Get the item with the index 122 | /// 123 | /// The index of the item 124 | public TItem GetItem(int index) 125 | { 126 | _rwLock?.EnterReadLock(); 127 | try 128 | { 129 | return Items[index]; 130 | } 131 | finally 132 | { 133 | _rwLock?.ExitReadLock(); 134 | } 135 | } 136 | 137 | /// 138 | /// Serializes the graph WITHOUT linked items. 139 | /// 140 | /// Bytes representing the graph. 141 | public void SerializeGraph(Stream stream) 142 | { 143 | if (Graph == null) 144 | { 145 | throw new InvalidOperationException("The graph does not exist"); 146 | } 147 | _rwLock?.EnterReadLock(); 148 | try 149 | { 150 | MessagePackBinary.WriteString(stream, SERIALIZATION_HEADER); 151 | MessagePackSerializer.Serialize(stream, Graph.Parameters); 152 | Graph.Serialize(stream); 153 | } 154 | finally 155 | { 156 | _rwLock?.ExitReadLock(); 157 | } 158 | } 159 | 160 | /// 161 | /// Deserializes the graph from byte array. 162 | /// 163 | /// The items to assign to the graph's verticies. 164 | /// The serialized parameters and edges. 
165 | public static (SmallWorld Graph, TItem[] ItemsNotInGraph) DeserializeGraph(IReadOnlyList items, Func distance, IProvideRandomValues generator, Stream stream, bool threadSafe = true) 166 | { 167 | var p0 = stream.Position; 168 | string hnswHeader; 169 | try 170 | { 171 | hnswHeader = MessagePackBinary.ReadString(stream); 172 | } 173 | catch(Exception E) 174 | { 175 | if(stream.CanSeek) { stream.Position = p0; } //Resets the stream to original position 176 | throw new InvalidDataException($"Invalid header found in stream, data is corrupted or invalid", E); 177 | } 178 | 179 | if (hnswHeader != SERIALIZATION_HEADER) 180 | { 181 | if (stream.CanSeek) { stream.Position = p0; } //Resets the stream to original position 182 | throw new InvalidDataException($"Invalid header found in stream, data is corrupted or invalid"); 183 | } 184 | 185 | // readStrict: true -> removed, as not available anymore on MessagePack 2.0 - also probably not necessary anymore 186 | // see https://github.com/neuecc/MessagePack-CSharp/pull/663 187 | 188 | var parameters = MessagePackSerializer.Deserialize(stream); 189 | 190 | //Overwrite previous InitialDistanceCacheSize parameter, so we don't waste time/memory allocating a distance cache for an already existing graph 191 | parameters.InitialDistanceCacheSize = 0; 192 | 193 | var world = new SmallWorld(distance, generator, parameters, threadSafe: threadSafe); 194 | var remainingItems = world.Graph.Deserialize(items, stream); 195 | return (world, remainingItems); 196 | } 197 | 198 | /// 199 | /// Prints edges of the graph. Mostly for debug and test purposes. 200 | /// 201 | /// String representation of the graph's edges. 
public string Print()
{
    return Graph.Print();
}

/// <summary>
/// Resizes the graph's distance cache, e.g. to shrink it and release memory once construction is finished.
/// </summary>
/// <param name="newSize">
/// The new size of the distance cache. NOTE(review): how a given size (in particular 0) is
/// interpreted is decided by the graph core's cache implementation; confirm there.
/// </param>
public void ResizeDistanceCache(int newSize)
{
    Graph.GraphCore.ResizeDistanceCache(newSize);
}

/// <summary>
/// Construction parameters of the graph. Serialized with MessagePack using property names as keys.
/// </summary>
[MessagePackObject(keyAsPropertyName: true)]
public class Parameters
{
    /// <summary>
    /// Initializes a new instance with the default values: M = 10, LevelLambda = 1 / ln(M),
    /// simple neighbour selection, ConstructionPruning = 200, distance cache enabled with
    /// 1024 * 1024 initial entries, and an initial items capacity of 1024.
    /// </summary>
    public Parameters()
    {
        M = 10;
        LevelLambda = 1 / Math.Log(M);
        NeighbourHeuristic = NeighbourSelectionHeuristic.SelectSimple;
        ConstructionPruning = 200;
        ExpandBestSelection = false;
        KeepPrunedConnections = false;
        EnableDistanceCacheForConstruction = true;
        InitialDistanceCacheSize = 1024 * 1024;
        InitialItemsSize = 1024;
    }

    /// <summary>
    /// Gets or sets the parameter which defines the maximum number of neighbors in the zero and above-zero layers.
    /// The maximum number of neighbors for the zero layer is 2 * M.
    /// The maximum number of neighbors for higher layers is M.
    /// </summary>
    public int M { get; set; }

    /// <summary>
    /// Gets or sets the max level decay parameter. https://en.wikipedia.org/wiki/Exponential_distribution See 'mL' parameter in the HNSW article.
    /// </summary>
    public double LevelLambda { get; set; }

    /// <summary>
    /// Gets or sets the parameter which specifies the type of heuristic to use for best neighbours selection.
    /// </summary>
    public NeighbourSelectionHeuristic NeighbourHeuristic { get; set; }

    /// <summary>
    /// Gets or sets the number of candidates to consider as neighbours for a given node at the graph construction phase. See 'efConstruction' parameter in the article.
    /// </summary>
    public int ConstructionPruning { get; set; }

    /// <summary>
    /// Gets or sets a value indicating whether to expand candidates if <see cref="NeighbourSelectionHeuristic.SelectHeuristic"/> is used. See 'extendCandidates' parameter in the article.
    /// </summary>
    public bool ExpandBestSelection { get; set; }

    /// <summary>
    /// Gets or sets a value indicating whether to keep pruned candidates if <see cref="NeighbourSelectionHeuristic.SelectHeuristic"/> is used. See 'keepPrunedConnections' parameter in the article.
    /// </summary>
    public bool KeepPrunedConnections { get; set; }

    /// <summary>
    /// Gets or sets a value indicating whether to cache calculated distances at graph construction time.
    /// </summary>
    public bool EnableDistanceCacheForConstruction { get; set; }

    /// <summary>
    /// Gets or sets the initial distance cache size.
    /// Note: This value is reset to 0 on deserialization to avoid allocating the distance cache for pre-built graphs.
    /// </summary>
    public int InitialDistanceCacheSize { get; set; }

    /// <summary>
    /// Gets or sets the initial size of the Items list.
    /// </summary>
    public int InitialItemsSize { get; set; }
}

/// <summary>
/// A single k-NN search hit: the item's id, the item itself and the distance reported for it.
/// </summary>
public class KNNSearchResult
{
    internal KNNSearchResult(int id, TItem item, TDistance distance)
    {
        Id = id;
        Item = item;
        Distance = distance;
    }

    /// <summary>Gets the id of the found item.</summary>
    public int Id { get; }

    /// <summary>Gets the found item.</summary>
    public TItem Item { get; }

    /// <summary>Gets the distance reported for the found item.</summary>
    public TDistance Distance { get; }

    public override string ToString()
    {
        return $"I:{Id} Dist:{Distance:n2} [{Item}]";
    }
}
}
}

--------------------------------------------------------------------------------
/Src/HNSW.Net/ThreadSafeFastRandom.cs:
--------------------------------------------------------------------------------
using System;
using System.Runtime.CompilerServices;

namespace HNSW.Net
{
    /// <summary>
    /// Static random helpers backed by one <see cref="FastRandom"/> instance per thread.
    /// Each thread's generator is seeded lazily from a shared, lock-protected <see cref="Random"/>.
    /// </summary>
    internal static class ThreadSafeFastRandom
    {
        private static readonly Random _global = new Random();

        [ThreadStatic]
        private static FastRandom _local;

        /// <summary>
        /// Draws a seed from the shared generator; the lock serializes access to it across threads.
        /// </summary>
        private static int GetGlobalSeed()
        {
            int seed;
            lock (_global)
            {
                seed = _global.Next();
            }
            return seed;
        }

        /// <summary>
        /// Returns the calling thread's generator, creating and seeding it on first use.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static FastRandom GetLocal()
        {
            var inst = _local;
            if (inst == null)
            {
                inst = new FastRandom(GetGlobalSeed());
                _local = inst;
            }
            return inst;
        }

        /// <summary>
        /// Returns a non-negative random integer.
        /// </summary>
        /// <returns>A 32-bit signed integer that is greater than or equal to 0 and less than System.Int32.MaxValue.</returns>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public static int Next()
        {
            return GetLocal().Next();
        }

        /// <summary>
        /// Returns a non-negative random integer that is less than the specified maximum.
        /// </summary>
        /// <param name="maxValue">The exclusive upper bound of the random number to be generated. maxValue must be greater than or equal to 0.</param>
        /// <returns>
        /// A 32-bit signed integer that is greater than or equal to 0, and less than maxValue; that is, the range of return values ordinarily includes 0 but not maxValue. However,
        /// if maxValue equals 0, maxValue is returned.
        /// </returns>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public static int Next(int maxValue)
        {
            // Honour the documented contract for an empty range explicitly. Without this, the
            // rejection loop below could never terminate when the generator yields 0 for a bound of 0.
            if (maxValue == 0)
            {
                return 0;
            }

            var inst = GetLocal();
            int ans;
            do
            {
                // Defensive: re-draw if the underlying generator ever returns the exclusive bound itself.
                ans = inst.Next(maxValue);
            } while (ans == maxValue);

            return ans;
        }

        /// <summary>
        /// Returns a random integer that is within a specified range.
        /// </summary>
        /// <param name="minValue">The inclusive lower bound of the random number returned.</param>
        /// <param name="maxValue">The exclusive upper bound of the random number returned. maxValue must be greater than or equal to minValue.</param>
        /// <returns>
        /// A 32-bit signed integer greater than or equal to minValue and less than maxValue; that is, the range of return values includes minValue but not maxValue. If minValue
        /// equals maxValue, minValue is returned.
        /// </returns>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public static int Next(int minValue, int maxValue)
        {
            return GetLocal().Next(minValue, maxValue);
        }

        /// <summary>
        /// Generates a random float. Values returned are from 0.0 up to but not including 1.0.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public static float NextFloat()
        {
            return GetLocal().NextFloat();
        }

        /// <summary>
        /// Fills every element of <paramref name="buffer"/> with random floats from the calling thread's generator.
        /// </summary>
        /// <param name="buffer">The span of floats to fill.</param>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public static void NextFloats(Span<float> buffer)
        {
            GetLocal().NextFloats(buffer);
        }
    }
}
--------------------------------------------------------------------------------
/Src/HNSW.Net/TravelingCosts.cs:
--------------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//

namespace HNSW.Net
{
    using System;
    using System.Collections.Generic;

    /// <summary>
    /// Implementation of distance calculation from an arbitrary point to the given destination.
    /// </summary>
    /// <typeparam name="TItem">Type of the points.</typeparam>
    /// <typeparam name="TDistance">Type of the distance.</typeparam>
    public class TravelingCosts<TItem, TDistance> : IComparer<TItem>
    {
        private static readonly Comparer<TDistance> DistanceComparer = Comparer<TDistance>.Default;

        private readonly Func<TItem, TItem, TDistance> Distance;

        /// <summary>
        /// Initializes a new instance measuring distances to <paramref name="destination"/>.
        /// </summary>
        /// <param name="distance">The distance function, invoked as distance(departure, destination).</param>
        /// <param name="destination">The point all distances are measured to.</param>
        /// <exception cref="ArgumentNullException"><paramref name="distance"/> is null.</exception>
        public TravelingCosts(Func<TItem, TItem, TDistance> distance, TItem destination)
        {
            Distance = distance ?? throw new ArgumentNullException(nameof(distance));
            Destination = destination;
        }

        /// <summary>
        /// Gets the point all distances are measured to.
        /// </summary>
        public TItem Destination { get; }

        /// <summary>
        /// Computes the distance from <paramref name="departure"/> to <see cref="Destination"/>.
        /// </summary>
        /// <param name="departure">The starting point.</param>
        /// <returns>The distance from <paramref name="departure"/> to the destination.</returns>
        public TDistance From(TItem departure)
        {
            return Distance(departure, Destination);
        }

        /// <summary>
        /// Compares 2 points by the distance from the destination.
        /// </summary>
        /// <param name="x">Left point.</param>
        /// <param name="y">Right point.</param>
        /// <returns>
        /// A negative value if x is closer to the destination than y;
        /// 0 if x and y are equally far from the destination;
        /// a positive value if x is farther from the destination than y.
        /// (The exact magnitude comes from <see cref="Comparer{T}.Default"/> for <typeparamref name="TDistance"/>.)
        /// </returns>
        public int Compare(TItem x, TItem y)
        {
            var fromX = From(x);
            var fromY = From(y);
            return DistanceComparer.Compare(fromX, fromY);
        }
    }
}

--------------------------------------------------------------------------------
/Src/HNSW.Net/VectorUtils.cs:
--------------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//

namespace HNSW.Net
{
    using System;
    using System.Collections.Generic;
    using System.Numerics;

    /// <summary>
    /// Helpers that compute the Euclidean (L2) magnitude of float vectors and scale them to unit length,
    /// in a scalar (<see cref="IList{T}"/>) variant and a SIMD (<see cref="Vector{T}"/>) variant.
    /// </summary>
    public static class VectorUtils
    {
        /// <summary>
        /// Computes the Euclidean (L2) magnitude of <paramref name="vector"/>.
        /// </summary>
        /// <param name="vector">The vector to measure.</param>
        /// <returns>The square root of the sum of squared components; 0 for an empty vector.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="vector"/> is null.</exception>
        public static float Magnitude(IList<float> vector)
        {
            if (vector is null)
            {
                throw new ArgumentNullException(nameof(vector));
            }

            float magnitude = 0.0f;
            for (int i = 0; i < vector.Count; ++i)
            {
                magnitude += vector[i] * vector[i];
            }

            return (float)Math.Sqrt(magnitude);
        }

        /// <summary>
        /// Scales <paramref name="vector"/> in place so that its magnitude becomes 1.
        /// </summary>
        /// <remarks>
        /// A zero vector is not special-cased: the reciprocal of a zero magnitude is infinity,
        /// so every component of an all-zero vector becomes NaN.
        /// </remarks>
        /// <param name="vector">The vector to normalize in place.</param>
        /// <exception cref="ArgumentNullException"><paramref name="vector"/> is null.</exception>
        public static void Normalize(IList<float> vector)
        {
            float normFactor = 1 / Magnitude(vector);
            for (int i = 0; i < vector.Count; ++i)
            {
                vector[i] *= normFactor;
            }
        }

        /// <summary>
        /// Computes the Euclidean (L2) magnitude of <paramref name="vector"/> using hardware SIMD.
        /// </summary>
        /// <param name="vector">The vector to measure.</param>
        /// <returns>The square root of the sum of squared components; 0 for an empty vector.</returns>
        /// <exception cref="NotSupportedException">SIMD is not hardware accelerated on this machine.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="vector"/> is null.</exception>
        public static float MagnitudeSIMD(float[] vector)
        {
            if (!Vector.IsHardwareAccelerated)
            {
                throw new NotSupportedException($"{nameof(VectorUtils.MagnitudeSIMD)} is not supported");
            }

            if (vector is null)
            {
                throw new ArgumentNullException(nameof(vector));
            }

            float magnitude = 0.0f;
            int step = Vector<float>.Count;

            // Whole SIMD-width chunks first...
            int i, to = vector.Length - step;
            for (i = 0; i <= to; i += step)
            {
                var vi = new Vector<float>(vector, i);
                magnitude += Vector.Dot(vi, vi);
            }

            // ...then the scalar tail that does not fill a full vector.
            for (; i < vector.Length; ++i)
            {
                magnitude += vector[i] * vector[i];
            }

            return (float)Math.Sqrt(magnitude);
        }

        /// <summary>
        /// Scales <paramref name="vector"/> in place to unit length using hardware SIMD.
        /// </summary>
        /// <remarks>
        /// As with <see cref="Normalize(IList{float})"/>, an all-zero vector becomes all NaN.
        /// </remarks>
        /// <param name="vector">The vector to normalize in place.</param>
        /// <exception cref="NotSupportedException">SIMD is not hardware accelerated on this machine.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="vector"/> is null.</exception>
        public static void NormalizeSIMD(float[] vector)
        {
            if (!Vector.IsHardwareAccelerated)
            {
                throw new NotSupportedException($"{nameof(VectorUtils.NormalizeSIMD)} is not supported");
            }

            float normFactor = 1f / MagnitudeSIMD(vector);
            int step = Vector<float>.Count;

            int i, to = vector.Length - step;
            for (i = 0; i <= to; i += step)
            {
                var vi = new Vector<float>(vector, i);
                vi = Vector.Multiply(normFactor, vi);
                vi.CopyTo(vector, i);
            }

            for (; i < vector.Length; ++i)
            {
                vector[i] *= normFactor;
            }
        }
    }
}

--------------------------------------------------------------------------------
/src/HNSW.Net/NeighbourSelectionHeuristic.cs:
--------------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//

namespace HNSW.Net
{
    /// <summary>
    /// Chooses the strategy used to pick the best neighbours of a node while the graph is built.
    /// </summary>
    public enum NeighbourSelectionHeuristic
    {
        /// <summary>
        /// Algorithm 3 (SELECT-NEIGHBORS-SIMPLE) from the HNSW article; implemented in Algorithms.Algorithm3 (Algorithms.Algorithm3.cs).
        /// </summary>
        SelectSimple = 0,

        /// <summary>
        /// Algorithm 4 (SELECT-NEIGHBORS-HEURISTIC) from the HNSW article; implemented in Algorithms.Algorithm4 (Algorithms.Algorithm4.cs).
        /// </summary>
        SelectHeuristic = 1,
    }
}

--------------------------------------------------------------------------------