├── .devops
└── build-nuget.yml
├── .gitignore
├── LICENSE
├── README.md
├── Src
├── HNSW.Net.Demo
│ ├── App.config
│ ├── HNSW.Net.Demo.csproj
│ ├── MetricsEventListener.cs
│ └── Program.cs
├── HNSW.Net.TestFiltering
│ ├── HNSW.Net.TestFiltering.csproj
│ └── Program.cs
├── HNSW.Net.Tests
│ ├── BinaryHeapTests.cs
│ ├── HNSW.Net.Tests.csproj
│ ├── RewindableRandomNumberGenerator.cs
│ ├── SmallWorldTests.cs
│ └── vectors.txt
├── HNSW.Net.sln
└── HNSW.Net
│ ├── Algorithms.Algorithm3.cs
│ ├── Algorithms.Algorithm4.cs
│ ├── Algorithms.cs
│ ├── BinaryHeap.cs
│ ├── CosineDistance.cs
│ ├── DefaultRandomGenerator.cs
│ ├── DistanceCache.cs
│ ├── DistanceUtils.cs
│ ├── EventSources.cs
│ ├── FastRandom.cs
│ ├── Graph.Core.cs
│ ├── Graph.Searcher.cs
│ ├── Graph.Utils.cs
│ ├── Graph.cs
│ ├── GraphChangedException.cs
│ ├── HNSW.Net.csproj
│ ├── IProgressReporter.cs
│ ├── IProvideRandomValues.cs
│ ├── MessagePackCompat
│ ├── FloatBits.cs
│ ├── MessagePackBinary.cs
│ ├── README.md
│ └── StringEncoding.cs
│ ├── Node.cs
│ ├── Properties
│ └── AssemblyInfo.cs
│ ├── ReverseComparer.cs
│ ├── ScopeLatencyTracker.cs
│ ├── SmallWorld.cs
│ ├── ThreadSafeFastRandom.cs
│ ├── TravelingCosts.cs
│ └── VectorUtils.cs
└── src
└── HNSW.Net
└── NeighbourSelectionHeuristic.cs
/.devops/build-nuget.yml:
--------------------------------------------------------------------------------
1 | variables:
2 | project: './Src/HNSW.Net/HNSW.Net.csproj'
3 | buildConfiguration: 'Release'
4 | targetVersion: yy.M.$(build.buildId)
5 |
6 | trigger:
7 | - master
8 |
9 | pool:
10 | vmImage: 'windows-latest'
11 |
12 | steps:
13 | - task: NuGetToolInstaller@0
14 |
15 |
16 | - task: PowerShell@2
17 | displayName: 'Create CalVer Version'
18 | inputs:
19 | targetType: 'inline'
20 | script: |
21 | $dottedDate = (Get-Date).ToString("yy.M")
22 | $buildID = $($env:BUILD_BUILDID)
23 | $newTargetVersion = "$dottedDate.$buildID"
24 | Write-Host "##vso[task.setvariable variable=targetVersion;]$newTargetVersion"
25 | Write-Host "Updated targetVersion to '$newTargetVersion'"
26 |
27 |
28 | - task: UseDotNet@2
29 | displayName: 'Use .NET 8.0 SDK'
30 | inputs:
31 | packageType: sdk
32 | version: 8.x
33 | includePreviewVersions: false
34 | installationPath: $(Agent.ToolsDirectory)\dotnet
35 |
36 | - task: DotNetCoreCLI@2
37 | inputs:
38 | command: 'restore'
39 | projects: '$(project)'
40 | displayName: 'restore nuget'
41 |
42 | - task: DotNetCoreCLI@2
43 | inputs:
44 | command: 'build'
45 | projects: '$(project)'
46 | arguments: '-c $(buildConfiguration) /p:Version=$(targetVersion) /p:LangVersion=latest'
47 |
48 | - task: DotNetCoreCLI@2
49 | inputs:
50 | command: 'pack'
51 | packagesToPack: '$(project)'
52 | versioningScheme: 'off'
53 | configuration: '$(buildConfiguration)'
54 | buildProperties: 'Version="$(targetVersion)";LangVersion="latest"'
55 | nobuild: true
56 |
57 | - task: NuGetCommand@2
58 | inputs:
59 | command: 'push'
60 | packagesToPush: '**/*.nupkg'
61 | nuGetFeedType: 'external'
62 | publishFeedCredentials: 'nuget-curiosity-org'
63 | displayName: 'push nuget'
64 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ##
4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
5 |
6 | # User-specific files
7 | *.suo
8 | *.user
9 | *.userosscache
10 | *.sln.docstates
11 |
12 | # User-specific files (MonoDevelop/Xamarin Studio)
13 | *.userprefs
14 |
15 | # Build results
16 | [Dd]ebug/
17 | [Dd]ebugPublic/
18 | [Rr]elease/
19 | [Rr]eleases/
20 | x64/
21 | x86/
22 | bld/
23 | [Bb]in/
24 | [Oo]bj/
25 | [Ll]og/
26 |
27 | # Visual Studio 2015/2017 cache/options directory
28 | .vs/
29 | # Uncomment if you have tasks that create the project's static files in wwwroot
30 | #wwwroot/
31 |
32 | # Visual Studio 2017 auto generated files
33 | Generated\ Files/
34 |
35 | # MSTest test Results
36 | [Tt]est[Rr]esult*/
37 | [Bb]uild[Ll]og.*
38 |
39 | # NUNIT
40 | *.VisualState.xml
41 | TestResult.xml
42 |
43 | # Build Results of an ATL Project
44 | [Dd]ebugPS/
45 | [Rr]eleasePS/
46 | dlldata.c
47 |
48 | # Benchmark Results
49 | BenchmarkDotNet.Artifacts/
50 |
51 | # .NET Core
52 | project.lock.json
53 | project.fragment.lock.json
54 | artifacts/
55 | **/Properties/launchSettings.json
56 |
57 | # StyleCop
58 | StyleCopReport.xml
59 |
60 | # Files built by Visual Studio
61 | *_i.c
62 | *_p.c
63 | *_i.h
64 | *.ilk
65 | *.meta
66 | *.obj
67 | *.iobj
68 | *.pch
69 | *.pdb
70 | *.ipdb
71 | *.pgc
72 | *.pgd
73 | *.rsp
74 | *.sbr
75 | *.tlb
76 | *.tli
77 | *.tlh
78 | *.tmp
79 | *.tmp_proj
80 | *.log
81 | *.vspscc
82 | *.vssscc
83 | .builds
84 | *.pidb
85 | *.svclog
86 | *.scc
87 |
88 | # Chutzpah Test files
89 | _Chutzpah*
90 |
91 | # Visual C++ cache files
92 | ipch/
93 | *.aps
94 | *.ncb
95 | *.opendb
96 | *.opensdf
97 | *.sdf
98 | *.cachefile
99 | *.VC.db
100 | *.VC.VC.opendb
101 |
102 | # Visual Studio profiler
103 | *.psess
104 | *.vsp
105 | *.vspx
106 | *.sap
107 |
108 | # Visual Studio Trace Files
109 | *.e2e
110 |
111 | # TFS 2012 Local Workspace
112 | $tf/
113 |
114 | # Guidance Automation Toolkit
115 | *.gpState
116 |
117 | # ReSharper is a .NET coding add-in
118 | _ReSharper*/
119 | *.[Rr]e[Ss]harper
120 | *.DotSettings.user
121 |
122 | # JustCode is a .NET coding add-in
123 | .JustCode
124 |
125 | # TeamCity is a build add-in
126 | _TeamCity*
127 |
128 | # DotCover is a Code Coverage Tool
129 | *.dotCover
130 |
131 | # AxoCover is a Code Coverage Tool
132 | .axoCover/*
133 | !.axoCover/settings.json
134 |
135 | # Visual Studio code coverage results
136 | *.coverage
137 | *.coveragexml
138 |
139 | # NCrunch
140 | _NCrunch_*
141 | .*crunch*.local.xml
142 | nCrunchTemp_*
143 |
144 | # MightyMoose
145 | *.mm.*
146 | AutoTest.Net/
147 |
148 | # Web workbench (sass)
149 | .sass-cache/
150 |
151 | # Installshield output folder
152 | [Ee]xpress/
153 |
154 | # DocProject is a documentation generator add-in
155 | DocProject/buildhelp/
156 | DocProject/Help/*.HxT
157 | DocProject/Help/*.HxC
158 | DocProject/Help/*.hhc
159 | DocProject/Help/*.hhk
160 | DocProject/Help/*.hhp
161 | DocProject/Help/Html2
162 | DocProject/Help/html
163 |
164 | # Click-Once directory
165 | publish/
166 |
167 | # Publish Web Output
168 | *.[Pp]ublish.xml
169 | *.azurePubxml
170 | # Note: Comment the next line if you want to checkin your web deploy settings,
171 | # but database connection strings (with potential passwords) will be unencrypted
172 | *.pubxml
173 | *.publishproj
174 |
175 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
176 | # checkin your Azure Web App publish settings, but sensitive information contained
177 | # in these scripts will be unencrypted
178 | PublishScripts/
179 |
180 | # NuGet Packages
181 | *.nupkg
182 | # The packages folder can be ignored because of Package Restore
183 | **/[Pp]ackages/*
184 | # except build/, which is used as an MSBuild target.
185 | !**/[Pp]ackages/build/
186 | # Uncomment if necessary however generally it will be regenerated when needed
187 | #!**/[Pp]ackages/repositories.config
188 | # NuGet v3's project.json files produces more ignorable files
189 | *.nuget.props
190 | *.nuget.targets
191 |
192 | # Microsoft Azure Build Output
193 | csx/
194 | *.build.csdef
195 |
196 | # Microsoft Azure Emulator
197 | ecf/
198 | rcf/
199 |
200 | # Windows Store app package directories and files
201 | AppPackages/
202 | BundleArtifacts/
203 | Package.StoreAssociation.xml
204 | _pkginfo.txt
205 | *.appx
206 |
207 | # Visual Studio cache files
208 | # files ending in .cache can be ignored
209 | *.[Cc]ache
210 | # but keep track of directories ending in .cache
211 | !*.[Cc]ache/
212 |
213 | # Others
214 | ClientBin/
215 | ~$*
216 | *~
217 | *.dbmdl
218 | *.dbproj.schemaview
219 | *.jfm
220 | *.pfx
221 | *.publishsettings
222 | orleans.codegen.cs
223 |
224 | # Including strong name files can present a security risk
225 | # (https://github.com/github/gitignore/pull/2483#issue-259490424)
226 | #*.snk
227 |
228 | # Since there are multiple workflows, uncomment next line to ignore bower_components
229 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
230 | #bower_components/
231 |
232 | # RIA/Silverlight projects
233 | Generated_Code/
234 |
235 | # Backup & report files from converting an old project file
236 | # to a newer Visual Studio version. Backup files are not needed,
237 | # because we have git ;-)
238 | _UpgradeReport_Files/
239 | Backup*/
240 | UpgradeLog*.XML
241 | UpgradeLog*.htm
242 | ServiceFabricBackup/
243 | *.rptproj.bak
244 |
245 | # SQL Server files
246 | *.mdf
247 | *.ldf
248 | *.ndf
249 |
250 | # Business Intelligence projects
251 | *.rdl.data
252 | *.bim.layout
253 | *.bim_*.settings
254 | *.rptproj.rsuser
255 |
256 | # Microsoft Fakes
257 | FakesAssemblies/
258 |
259 | # GhostDoc plugin setting file
260 | *.GhostDoc.xml
261 |
262 | # Node.js Tools for Visual Studio
263 | .ntvs_analysis.dat
264 | node_modules/
265 |
266 | # Visual Studio 6 build log
267 | *.plg
268 |
269 | # Visual Studio 6 workspace options file
270 | *.opt
271 |
272 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
273 | *.vbw
274 |
275 | # Visual Studio LightSwitch build output
276 | **/*.HTMLClient/GeneratedArtifacts
277 | **/*.DesktopClient/GeneratedArtifacts
278 | **/*.DesktopClient/ModelManifest.xml
279 | **/*.Server/GeneratedArtifacts
280 | **/*.Server/ModelManifest.xml
281 | _Pvt_Extensions
282 |
283 | # Paket dependency manager
284 | .paket/paket.exe
285 | paket-files/
286 |
287 | # FAKE - F# Make
288 | .fake/
289 |
290 | # JetBrains Rider
291 | .idea/
292 | *.sln.iml
293 |
294 | # CodeRush
295 | .cr/
296 |
297 | # Python Tools for Visual Studio (PTVS)
298 | __pycache__/
299 | *.pyc
300 |
301 | # Cake - Uncomment if you are using it
302 | # tools/**
303 | # !tools/packages.config
304 |
305 | # Tabs Studio
306 | *.tss
307 |
308 | # Telerik's JustMock configuration file
309 | *.jmconfig
310 |
311 | # BizTalk build output
312 | *.btp.cs
313 | *.btm.cs
314 | *.odx.cs
315 | *.xsd.cs
316 |
317 | # OpenCover UI analysis results
318 | OpenCover/
319 |
320 | # Azure Stream Analytics local run output
321 | ASALocalRun/
322 |
323 | # MSBuild Binary and Structured Log
324 | *.binlog
325 |
326 | # NVidia Nsight GPU debugger configuration file
327 | *.nvuser
328 |
329 | # MFractors (Xamarin productivity tool) working folder
330 | .mfractor/
331 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation. All rights reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://dev.azure.com/curiosity-ai/mosaik/_build/latest?definitionId=7&branchName=master)
2 |
3 |
4 |
5 |
6 | # HNSW.Net
7 | .NET library for fast approximate nearest neighbour search.
8 |
9 | Exact _k_ nearest neighbour search algorithms tend to perform poorly in high-dimensional spaces. To overcome the curse of dimensionality, approximate nearest neighbour (ANN) algorithms are used instead. This library implements one such algorithm, described in the ["Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs"](https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf) article. It provides a simple API for building nearest neighbour graphs, (de)serializing them and running k-NN search queries.
10 |
11 | ## Usage
12 | Check out the following code snippets once you've added the library reference to your project.
13 | ##### How to build a graph?
14 | ```c#
15 | var parameters = new SmallWorld.Parameters()
16 | {
17 | M = 15,
18 | LevelLambda = 1 / Math.Log(15),
19 | };
20 |
21 | float[] vectors = GetFloatVectors();
22 | var graph = new SmallWorld(CosineDistance.NonOptimized);
23 | graph.BuildGraph(vectors, new Random(42), parameters);
24 | ```
25 | ##### How to run k-NN search?
26 | ```c#
27 | SmallWorld graph = GetGraph();
28 |
29 | float[] query = Enumerable.Repeat(1f, 100).ToArray();
30 | var best20 = graph.KNNSearch(query, 20);
31 | var best1 = best20.OrderBy(r => r.Distance).First();
32 | ```
33 | ##### How to (de)serialize the graph?
34 | ```c#
35 | SmallWorld graph = GetGraph();
36 | byte[] buffer = graph.SerializeGraph(); // buffer stores information about parameters and graph edges
37 |
38 | // distance function must be the same as the one which was used for building the original graph
39 | var copy = new SmallWorld(CosineDistance.NonOptimized);
40 | copy.DeserializeGraph(vectors, buffer); // the original vectors to attach to the "copy" vertices
41 | ```
42 | ##### Distance functions
43 | The only distance function supplied by the library is the cosine distance, but it comes in four versions that address the universality/performance trade-off.
44 | ```c#
45 | CosineDistance.NonOptimized // most generic version works for all cases
46 | CosineDistance.ForUnits // gives correct result only when arguments are "unit" vectors
47 | CosineDistance.SIMD // uses SIMD instructions to optimize calculations
48 | CosineDistance.SIMDForUnits // uses SIMD and requires arguments to be "units"
49 | ```
50 | However, the API allows you to inject any custom distance function tailored specifically to your needs.
51 |
52 | ## Contributing
53 | Your contributions and suggestions are very welcome!
54 | Please note that this project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
55 |
56 | The contributions to this project are [released](https://help.github.com/articles/github-terms-of-service/#6-contributions-under-repository-license) to the public under the [project's open source license](LICENSE). Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.microsoft.com.
57 |
58 | ### How to contribute
59 | If you've found a bug or have a feature request, please open an issue with a detailed description.
60 | We will be glad to see your pull requests as well.
61 |
62 | 1. Prepare workspace.
63 | ```
64 | git clone https://github.com/Microsoft/HNSW.Net.git
65 | cd HNSW.Net
66 | git checkout -b [username]/[feature]
67 | ```
68 | 2. Update the library and add tests if needed.
69 | 3. Build and test the changes.
70 | ```
71 | cd Src
72 | dotnet build
73 | dotnet test
74 | ```
75 | 4. Send the pull request from `[username]/[feature]` to `master` branch.
76 | 5. Get approval and merge the changes.
77 |
78 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
79 |
80 | ### Releasing
81 | The library is distributed as a bundle of sources.
82 | We are working on enabling CI and creating a NuGet package for the project.
83 |
84 |
--------------------------------------------------------------------------------
/Src/HNSW.Net.Demo/App.config:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/Src/HNSW.Net.Demo/HNSW.Net.Demo.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | net8.0
5 | Exe
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/Src/HNSW.Net.Demo/MetricsEventListener.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
6 | namespace HNSW.Net.Demo;
7 |
8 | using System;
9 | using System.Collections.Generic;
10 | using System.Diagnostics.Tracing;
11 | using System.Linq;
12 |
/// <summary>
/// Prints the counter snapshots published by a single <see cref="EventSource"/> to the console.
/// </summary>
public class MetricsEventListener : EventListener
{
    private readonly EventSource eventSource;

    /// <summary>
    /// Enables all events of <paramref name="eventSource"/> and requests counter
    /// snapshots every second (via the "EventCounterIntervalSec" argument).
    /// </summary>
    /// <param name="eventSource">The source whose counters should be printed.</param>
    public MetricsEventListener(EventSource eventSource)
    {
        this.eventSource = eventSource;
        EnableEvents(this.eventSource, EventLevel.LogAlways, EventKeywords.All, new Dictionary<string, string> { { "EventCounterIntervalSec", "1" } });
    }

    /// <summary>
    /// Disables the events enabled in the constructor, then disposes the listener.
    /// </summary>
    public override void Dispose()
    {
        DisableEvents(eventSource);
        base.Dispose();
    }

    /// <summary>
    /// Prints one line per counter snapshot; any event whose first payload item is not a
    /// non-empty string/object dictionary is ignored.
    /// </summary>
    protected override void OnEventWritten(EventWrittenEventArgs eventData)
    {
        // The previous check (counterData?.Count == 0) evaluated to false for a null
        // dictionary, so non-counter events crashed on the indexer below.
        if (eventData.Payload?.FirstOrDefault() is not IDictionary<string, object> counterData || counterData.Count == 0)
        {
            return;
        }

        // Missing keys (e.g. incrementing counters report no "Mean") print as empty
        // instead of throwing KeyNotFoundException.
        Console.WriteLine($"[{GetValue(counterData, "Name")}]: Avg={GetValue(counterData, "Mean"):n1}; SD={GetValue(counterData, "StandardDeviation"):n1}; Count={GetValue(counterData, "Count")}");
    }

    /// <summary>
    /// Returns the value stored under <paramref name="key"/>, or null when it is absent.
    /// </summary>
    private static object GetValue(IDictionary<string, object> data, string key)
        => data.TryGetValue(key, out var value) ? value : null;
}
40 |
--------------------------------------------------------------------------------
/Src/HNSW.Net.Demo/Program.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net.Demo
{
    using System;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.IO;
    using System.Linq;
    using System.Numerics;
    using System.Runtime.CompilerServices;
    using System.Runtime.Intrinsics;
    using System.Runtime.Serialization.Formatters.Binary;
    using System.Threading;
    using System.Threading.Tasks;

    using Parameters = SmallWorld<float[], float>.Parameters;

    /// <summary>
    /// Demo driver. It first stresses concurrent add/search on a single graph. It then builds a
    /// graph, saves it to disk, reloads it and runs a batch of searches against it.
    /// </summary>
    public static partial class Program
    {
        private const int SampleSize = 500;
        private const int SampleIncrSize = 100;
        private const int TestSize = 10 * SampleSize;
        private const int Dimensionality = 128;
        private const string VectorsPathSuffix = "vectors.hnsw";
        private const string GraphPathSuffix = "graph.hnsw";

        public static async Task Main()
        {
            await MultithreadAddAndReadAsync();
            BuildAndSave("random");
            LoadAndSearch("random");
        }

        /// <summary>
        /// For 15 minutes, one task keeps adding batches of random vectors to a graph while
        /// another keeps running parallel k-NN searches against the same graph.
        /// </summary>
        private static async Task MultithreadAddAndReadAsync()
        {
            var world = new SmallWorld<float[], float>(CosineDistance.SIMDForUnits, DefaultRandomGenerator.Instance, new Parameters() { EnableDistanceCacheForConstruction = true, InitialDistanceCacheSize = SampleSize, NeighbourHeuristic = NeighbourSelectionHeuristic.SelectHeuristic, KeepPrunedConnections = true, ExpandBestSelection = true }, threadSafe: false);

            using var cts = new CancellationTokenSource();

            var taskAdd = Task.Run(() =>
            {
                while (!cts.IsCancellationRequested)
                {
                    Console.Write($"Generating {SampleSize} sample vectors... ");
                    var clock = Stopwatch.StartNew();
                    var sampleVectors = RandomVectors(Dimensionality, SampleSize);
                    Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms.");

                    AddInBatches(world, sampleVectors);
                }
            });

            var taskSearch = Task.Run(async () =>
            {
                while (!cts.IsCancellationRequested)
                {
                    var searchVectors = RandomVectors(Dimensionality, SampleSize);
                    Console.WriteLine("Running search against the graph... ");
                    using (var listener = new MetricsEventListener(EventSources.GraphSearchEventSource.Instance))
                    {
                        var clock = Stopwatch.StartNew();
                        await Parallel.ForEachAsync(searchVectors, (vector, ct) =>
                        {
                            world.KNNSearch(vector, 10);
                            Console.Write('.');
                            return default;
                        });
                        Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms.");
                    }
                }
            });

            cts.CancelAfter(TimeSpan.FromMinutes(15));

            await Task.WhenAll(taskAdd, taskSearch);
        }

        /// <summary>
        /// Builds a graph from SampleSize random vectors and writes it to
        /// "{pathPrefix}.vectors.hnsw" and "{pathPrefix}.graph.hnsw" in the current directory.
        /// </summary>
        private static void BuildAndSave(string pathPrefix)
        {
            var world = new SmallWorld<float[], float>(CosineDistance.SIMDForUnits, DefaultRandomGenerator.Instance, new Parameters() { EnableDistanceCacheForConstruction = true, InitialDistanceCacheSize = SampleSize, NeighbourHeuristic = NeighbourSelectionHeuristic.SelectHeuristic, KeepPrunedConnections = true, ExpandBestSelection = true });

            Console.Write($"Generating {SampleSize} sample vectors... ");
            var clock = Stopwatch.StartNew();
            var sampleVectors = RandomVectors(Dimensionality, SampleSize);
            Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms.");

            AddInBatches(world, sampleVectors);

            // The previous format string contained a stray '$' that was printed literally.
            Console.Write($"Saving HNSW graph to '{Path.Combine(Directory.GetCurrentDirectory(), pathPrefix)}'... ");
            clock = Stopwatch.StartNew();

            // BinaryFormatter is disabled by default on .NET 8 (its methods throw
            // NotSupportedException), so the vectors use a plain binary layout instead.
            WriteVectors($"{pathPrefix}.{VectorsPathSuffix}", sampleVectors);

            using (var f = File.Open($"{pathPrefix}.{GraphPathSuffix}", FileMode.Create))
            {
                world.SerializeGraph(f);
            }

            Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms.");
        }

        /// <summary>
        /// Reloads the vectors and graph written by <see cref="BuildAndSave"/>, then runs
        /// TestSize parallel 10-NN searches with random query vectors.
        /// </summary>
        private static void LoadAndSearch(string pathPrefix)
        {
            Console.Write("Loading HNSW graph... ");
            var clock = Stopwatch.StartNew();
            var sampleVectors = ReadVectors($"{pathPrefix}.{VectorsPathSuffix}");
            SmallWorld<float[], float> world;
            using (var f = File.OpenRead($"{pathPrefix}.{GraphPathSuffix}"))
            {
                (world, _) = SmallWorld<float[], float>.DeserializeGraph(sampleVectors, CosineDistance.SIMDForUnits, DefaultRandomGenerator.Instance, f);
            }
            Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms.");

            Console.Write($"Generating {TestSize} test vectors... ");
            clock = Stopwatch.StartNew();
            var vectors = RandomVectors(Dimensionality, TestSize);
            Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms.");

            Console.WriteLine("Running search against the graph... ");
            using (var listener = new MetricsEventListener(EventSources.GraphSearchEventSource.Instance))
            {
                clock = Stopwatch.StartNew();
                Parallel.ForEach(vectors, (vector) =>
                {
                    world.KNNSearch(vector, 10);
                });
                Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms.");
            }
        }

        /// <summary>
        /// Adds <paramref name="sampleVectors"/> to <paramref name="world"/> in batches of
        /// SampleIncrSize, printing progress and the graph-build counters.
        /// </summary>
        private static void AddInBatches(SmallWorld<float[], float> world, List<float[]> sampleVectors)
        {
            Console.WriteLine("Building HNSW graph... ");
            using (var listener = new MetricsEventListener(EventSources.GraphBuildEventSource.Instance))
            {
                var clock = Stopwatch.StartNew();
                const int batches = SampleSize / SampleIncrSize;
                for (int i = 0; i < batches; i++)
                {
                    world.AddItems(sampleVectors.Skip(i * SampleIncrSize).Take(SampleIncrSize).ToArray());
                    Console.WriteLine($"\nAt {i + 1} of {batches} Elapsed: {clock.ElapsedMilliseconds} ms.\n");
                }
                Console.WriteLine($"Done in {clock.ElapsedMilliseconds} ms.");
            }
        }

        /// <summary>
        /// Writes vectors as: int32 count, then for each vector an int32 length followed by
        /// its float32 components. Existing files are overwritten.
        /// </summary>
        private static void WriteVectors(string path, IReadOnlyList<float[]> vectors)
        {
            using var writer = new BinaryWriter(File.Create(path));
            writer.Write(vectors.Count);
            foreach (var vector in vectors)
            {
                writer.Write(vector.Length);
                foreach (var component in vector)
                {
                    writer.Write(component);
                }
            }
        }

        /// <summary>
        /// Reads vectors in the layout produced by <see cref="WriteVectors"/>.
        /// </summary>
        private static List<float[]> ReadVectors(string path)
        {
            using var reader = new BinaryReader(File.OpenRead(path));
            int count = reader.ReadInt32();
            var vectors = new List<float[]>(count);
            for (int i = 0; i < count; i++)
            {
                var vector = new float[reader.ReadInt32()];
                for (int j = 0; j < vector.Length; j++)
                {
                    vector[j] = reader.ReadSingle();
                }
                vectors.Add(vector);
            }
            return vectors;
        }

        /// <summary>
        /// Creates <paramref name="vectorsCount"/> random vectors of length
        /// <paramref name="vectorSize"/>, each normalized with VectorUtils.NormalizeSIMD.
        /// </summary>
        private static List<float[]> RandomVectors(int vectorSize, int vectorsCount)
        {
            var vectors = new List<float[]>();

            for (int i = 0; i < vectorsCount; i++)
            {
                var vector = new float[vectorSize];
                DefaultRandomGenerator.Instance.NextFloats(vector);
                VectorUtils.NormalizeSIMD(vector);
                vectors.Add(vector);
            }

            return vectors;
        }
    }
}
190 |
--------------------------------------------------------------------------------
/Src/HNSW.Net.TestFiltering/HNSW.Net.TestFiltering.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Exe
5 | net8.0
6 | enable
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/Src/HNSW.Net.TestFiltering/Program.cs:
--------------------------------------------------------------------------------
1 | using HNSW.Net;
2 | using MessagePack;
3 | using System.Diagnostics;
4 |
int SampleSize = 100_000;
int SampleIncrSize = 100;
int Dimensions = 128;

SmallWorld<VectorID, float> world;
List<VectorID> vectors;
var fileName = Path.Combine(Directory.GetCurrentDirectory(), $"data-{Dimensions}-{SampleSize}.bin");

// The vectors are stored next to the graph with a ".vec" extension. Path.ChangeExtension
// only touches the file name; string.Replace(".bin", ...) would also rewrite directory names.
var vectorsFileName = Path.ChangeExtension(fileName, ".vec");

if (File.Exists(fileName))
{
    Console.WriteLine("Loading HNSW graph... ");
    using (var f = File.OpenRead(vectorsFileName))
    {
        vectors = MessagePackSerializer.Deserialize<List<VectorID>>(f);
    }

    using (var f = File.OpenRead(fileName))
    {
        (world, _) = SmallWorld<VectorID, float>.DeserializeGraph(vectors, VectorID.Distance, DefaultRandomGenerator.Instance, f, false);
    }
}
else
{
    Console.WriteLine($"Creating {SampleSize:n0} vectors");
    vectors = RandomVectors(Dimensions, SampleSize);
    world = new SmallWorld<VectorID, float>(VectorID.Distance, DefaultRandomGenerator.Instance, new() { EnableDistanceCacheForConstruction = true, InitialDistanceCacheSize = SampleSize, NeighbourHeuristic = NeighbourSelectionHeuristic.SelectHeuristic, KeepPrunedConnections = true, ExpandBestSelection = true }, threadSafe: false);

    Console.WriteLine("Building HNSW graph... ");
    var clock = Stopwatch.StartNew();
    for (int i = 0; i < (SampleSize / SampleIncrSize); i++)
    {
        clock.Restart();
        world.AddItems(vectors.Skip(i * SampleIncrSize).Take(SampleIncrSize).ToArray());
        Console.WriteLine($"Indexing vectors, at {(i + 1) * SampleIncrSize:n0} of {SampleSize:n0}, at {(double)clock.ElapsedMilliseconds / SampleIncrSize:n1}ms/vector");
    }
    Console.WriteLine("Done building HNSW graph");

    // File.Create truncates existing files. File.OpenWrite does not, so stale trailing
    // bytes from a previous, longer file would corrupt the data.
    using (var f = File.Create(vectorsFileName))
    {
        MessagePackSerializer.Serialize(f, vectors);
    }

    using (var f = File.Create(fileName))
    {
        world.SerializeGraph(f);
    }
}

// Only the timings of the last run are kept; the earlier runs act as warm-up.
var times = new TimeSpan[11];
int sample = 100;

for (int repeat = 0; repeat < 2; repeat++)
{
    Console.WriteLine($"----------------------------\nRun: {repeat}\n----------------------------");
    for (int p = 0; p <= 10; p++)
    {
        Console.Write($"Testing {p * 10}% out... ");
        var sw = Stopwatch.StartNew();
        foreach (var query in vectors.Take(sample))
        {
            using var cts = new CancellationTokenSource();
            cts.CancelAfter(TimeSpan.FromMilliseconds(1));

            // Only items whose ID % 100 is below (100 - p*10) pass the filter, so roughly
            // p*10 % of the items are excluded.
            _ = world.KNNSearch(query, 50, item => item.ID % 100 < (100 - p * 10), cts.Token);
        }
        sw.Stop();
        times[p] = sw.Elapsed;
        Console.WriteLine($"{sw.Elapsed.TotalMilliseconds / sample:n2}ms / call");
    }

    Console.WriteLine();
}
Console.WriteLine($"----------------------------\nResults\n----------------------------");
Console.WriteLine(string.Join("\n", times.Select((t, i) => $"Exclude {i*10}%\t{t.TotalMilliseconds / sample:n2}ms / call")));


// Creates vectorsCount normalized random vectors, tagging each with its index as ID.
List<VectorID> RandomVectors(int vectorSize, int vectorsCount)
{
    var vectors = new List<VectorID>();

    for (int i = 0; i < vectorsCount; i++)
    {
        var vector = new float[vectorSize];
        DefaultRandomGenerator.Instance.NextFloats(vector);
        VectorUtils.NormalizeSIMD(vector);
        vectors.Add(new VectorID(vector, i));
    }

    return vectors;
}

/// <summary>
/// A vector paired with an integer ID used by the search filter; serialized with MessagePack.
/// </summary>
[MessagePackObject]
public struct VectorID
{
    [Key(0)] public float[] Vector;
    [Key(1)] public int ID;

    public VectorID(float[] vector, int iD)
    {
        Vector = vector;
        ID = iD;
    }

    /// <summary>
    /// Cosine distance between the two vectors, via CosineDistance.SIMDForUnits (which
    /// expects unit-length inputs).
    /// </summary>
    internal static float Distance(VectorID a, VectorID b)
    {
        return CosineDistance.SIMDForUnits(a.Vector, b.Vector);
    }
}
115 |
--------------------------------------------------------------------------------
/Src/HNSW.Net.Tests/BinaryHeapTests.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net.Tests
{
    using System.Linq;
    using Microsoft.VisualStudio.TestTools.UnitTesting;

    /// <summary>
    /// Tests for <see cref="BinaryHeap{T}"/>.
    /// </summary>
    [TestClass]
    public class BinaryHeapTests
    {
        /// <summary>
        /// Tests heap construction.
        /// </summary>
        [TestMethod]
        public void HeapifyTest()
        {
            // Basic tests
            {
                var heap = new BinaryHeap<int>(Enumerable.Empty<int>().ToList());
                Assert.IsFalse(heap.Buffer.Any());

                heap = new BinaryHeap<int>(Enumerable.Range(1, 1).ToList());
                Assert.AreEqual(1, heap.Buffer.Count);
                Assert.AreEqual(1, heap.Buffer.First());

                heap = new BinaryHeap<int>(Enumerable.Range(1, 2).ToList());
                Assert.AreEqual(2, heap.Buffer.Count);
                Assert.AreEqual(2, heap.Buffer.First());
            }

            // Heapify produces correct heap.
            {
                const string input = "Hello, World!";
                var heap = new BinaryHeap<char>(input.ToList());
                AssertMaxHeap(heap);
            }
        }

        /// <summary>
        /// Tests <see cref="BinaryHeap{T}.Push"/> and <see cref="BinaryHeap{T}.Pop"/>.
        /// The heap must keep its invariant after every operation, and pops must yield
        /// the pushed values in descending order.
        /// </summary>
        [TestMethod]
        public void PushPopTest()
        {
            const int count = 10;
            var heap = new BinaryHeap<int>(Enumerable.Empty<int>().ToList());
            for (int i = 0; i < count; ++i)
            {
                heap.Push(i);
                AssertMaxHeap(heap);
            }

            // The original test only compared Pop() with the current top. That passes even
            // when the remaining elements are left in a broken order.
            for (int expected = count - 1; expected >= 0; --expected)
            {
                Assert.AreEqual(expected, heap.Buffer.First());
                Assert.AreEqual(expected, heap.Pop());
                AssertMaxHeap(heap);
            }

            Assert.IsFalse(heap.Buffer.Any());
        }

        /// <summary>
        /// Asserts that every parent in the heap buffer compares greater than or equal to
        /// its children, using the heap's own comparer.
        /// </summary>
        private static void AssertMaxHeap<T>(BinaryHeap<T> heap)
        {
            for (int p = 0; p < heap.Buffer.Count; ++p)
            {
                int l = (2 * p) + 1;
                int r = l + 1;

                var parent = heap.Buffer[p];
                if (l < heap.Buffer.Count)
                {
                    var left = heap.Buffer[l];
                    Assert.IsTrue(heap.Comparer.Compare(parent, left) >= 0);
                }

                if (r < heap.Buffer.Count)
                {
                    var right = heap.Buffer[r];
                    Assert.IsTrue(heap.Comparer.Compare(parent, right) >= 0);
                }
            }
        }
    }
}
--------------------------------------------------------------------------------
/Src/HNSW.Net.Tests/HNSW.Net.Tests.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | net8.0
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | PreserveNewest
20 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/Src/HNSW.Net.Tests/RewindableRandomNumberGenerator.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net.Tests
{
    using System;
    using System.Collections.Generic;
    using System.Linq;

    /// <summary>
    /// Deterministic <see cref="IProvideRandomValues"/> for tests. Every call is recorded so the
    /// generator can be rewound to an earlier point and then reproduce the same sequence of values.
    /// </summary>
    public class RewindableRandomNumberGenerator : IProvideRandomValues
    {
        private const int Seed = 42;

        // One entry per call made so far; replaying them against a freshly seeded generator
        // advances it to the same state.
        private readonly List<Action> _calls = new List<Action>();

        private FastRandom _random = new FastRandom(Seed);

        public bool IsThreadSafe => false;

        public int Next(int minValue, int maxValue)
        {
            _calls.Add(() => _random.Next(minValue, maxValue));
            return _random.Next(minValue, maxValue);
        }

        public float NextFloat()
        {
            _calls.Add(() => _random.NextFloat());
            return _random.NextFloat();
        }

        public void NextFloats(Span<float> buffer)
        {
            var len = buffer.Length;
            _calls.Add(() =>
            {
                var b = new float[len];
                _random.NextFloats(b);
            });
            _random.NextFloats(buffer);
        }

        /// <summary>
        /// Returns the number of calls made so far; pass it to <see cref="RewindTo(int)"/> later.
        /// </summary>
        public int GetState() => _calls.Count;

        /// <summary>
        /// Resets the generator to the point captured by <see cref="GetState"/>.
        /// </summary>
        /// <param name="state">A value previously returned by <see cref="GetState"/>.</param>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="state"/> is negative or larger than the number of recorded calls.
        /// </exception>
        public void RewindTo(int state)
        {
            if (state < 0 || state > _calls.Count)
            {
                throw new ArgumentOutOfRangeException(nameof(state), state, "State must be a value previously returned by GetState().");
            }

            _random = new FastRandom(Seed);

            // Keep the first `state` calls recorded, so GetState() stays correct and later rewinds
            // still see the full history. The replayed lambdas call _random directly and never
            // touch _calls, so enumerating _calls while invoking them is safe.
            _calls.RemoveRange(state, _calls.Count - state);
            foreach (var call in _calls)
            {
                call();
            }
        }
    }
}
--------------------------------------------------------------------------------
/Src/HNSW.Net.Tests/SmallWorldTests.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net.Tests
{
    using System;
    using System.Collections.Generic;
    using System.Globalization;
    using System.IO;
    using System.Linq;
    using Microsoft.VisualStudio.TestTools.UnitTesting;

    /// <summary>
    /// Tests for <see cref="SmallWorld{TItem, TDistance}"/>.
    /// </summary>
    [TestClass]
    public class SmallWorldTests
    {
        // Set floating point error to 5.96 * 10^-7.
        // For cosine distance the error can be bigger in theory, but for the test data it's not the case.
        private const float FloatError = 0.000000596f;

        // Number of neighbours requested by every KNN search in these tests.
        private const int K = 20;

        private IReadOnlyList<float[]> vectors;

        /// <summary>
        /// Initializes test resources.
        /// </summary>
        [TestInitialize]
        public void TestInitialize()
        {
            var data = File.ReadAllLines(@"vectors.txt");

            // The test data is machine-readable, so parse it culture-independently. With CurrentCulture
            // the tests break on machines whose locale uses ',' as the decimal separator.
            // NOTE(review): assumes vectors.txt uses '.' as its decimal separator.
            vectors = data.Select(r => Array.ConvertAll(r.Split('\t'), x => float.Parse(x, CultureInfo.InvariantCulture))).ToList();
        }

        /// <summary>
        /// Basic test for knn search - this test might fail sometimes, as the construction of the graph does not guarantee an exact answer
        /// </summary>
        [TestMethod]
        public void KNNSearchTest()
        {
            var parameters = new SmallWorld<float[], float>.Parameters();
            var graph = new SmallWorld<float[], float>(CosineDistance.NonOptimized, DefaultRandomGenerator.Instance, parameters);
            graph.AddItems(vectors);

            AssertEachVectorIsItsOwnNearestNeighbour(graph);
        }

        /// <summary>
        /// Knn search with a filter that rejects every item must return no results.
        /// </summary>
        [TestMethod]
        public void KNNSearchWithFilterTest()
        {
            var parameters = new SmallWorld<float[], float>.Parameters();
            var graph = new SmallWorld<float[], float>(CosineDistance.NonOptimized, DefaultRandomGenerator.Instance, parameters);
            graph.AddItems(vectors);

            for (int i = 0; i < vectors.Count; ++i)
            {
                var result = graph.KNNSearch(vectors[i], K, filterItem: v => false);
                Assert.AreEqual(0, result.Count);
            }
        }

        /// <summary>
        /// Basic test for knn search using the neighbour selection heuristic - this test might fail sometimes, as the construction of the graph does not guarantee an exact answer
        /// </summary>
        [DataTestMethod]
        [DataRow(false, false)]
        [DataRow(false, true)]
        [DataRow(true, false)]
        [DataRow(true, true)]
        public void KNNSearchTestAlgorithm4(bool expandBestSelection, bool keepPrunedConnections)
        {
            var parameters = new SmallWorld<float[], float>.Parameters()
            {
                NeighbourHeuristic = NeighbourSelectionHeuristic.SelectHeuristic,
                ExpandBestSelection = expandBestSelection,
                KeepPrunedConnections = keepPrunedConnections,
            };
            var graph = new SmallWorld<float[], float>(CosineDistance.NonOptimized, DefaultRandomGenerator.Instance, parameters);
            graph.AddItems(vectors);

            AssertEachVectorIsItsOwnNearestNeighbour(graph);
        }

        /// <summary>
        /// Serialization deserialization tests.
        /// </summary>
        [TestMethod]
        public void SerializeDeserializeTest()
        {
            string original;

            using var stream = new MemoryStream();

            // restrict scope of original graph
            {
                var parameters = new SmallWorld<float[], float>.Parameters()
                {
                    M = 15,
                    LevelLambda = 1 / Math.Log(15),
                };

                var graph = new SmallWorld<float[], float>(CosineDistance.NonOptimized, DefaultRandomGenerator.Instance, parameters);
                graph.AddItems(vectors);

                graph.SerializeGraph(stream);
                original = graph.Print();
            }

            stream.Position = 0;

            var copy = SmallWorld<float[], float>.DeserializeGraph(vectors, CosineDistance.NonOptimized, DefaultRandomGenerator.Instance, stream);

            Assert.AreEqual(original, copy.Graph.Print());
        }

        /// <summary>
        /// Serializes a graph built from half of the items, then checks that adding the remaining
        /// items to the deserialized copy reproduces the original full graph.
        /// </summary>
        [TestMethod]
        public void SerializeDeserializeWithRemainingItemsTest()
        {
            string original;
            var rng = new RewindableRandomNumberGenerator();

            using var stream = new MemoryStream();

            // restrict scope of original graph
            {
                var parameters = new SmallWorld<float[], float>.Parameters()
                {
                    M = 15,
                    LevelLambda = 1 / Math.Log(15),
                };
                int itemsInSerializedGraph = vectors.Count / 2;

                var graph = new SmallWorld<float[], float>(CosineDistance.NonOptimized, rng, parameters);
                graph.AddItems(vectors.Take(itemsInSerializedGraph).ToArray());

                graph.SerializeGraph(stream);

                // Finish the original graph, then rewind the generator so the deserialized copy
                // draws the same random values when it adds the remaining items.
                int rngState = rng.GetState();
                graph.AddItems(vectors.Skip(itemsInSerializedGraph).ToArray());
                rng.RewindTo(rngState);

                original = graph.Print();
            }

            stream.Position = 0;

            var copy = SmallWorld<float[], float>.DeserializeGraph(vectors, CosineDistance.NonOptimized, rng, stream);

            copy.Graph.AddItems(copy.ItemsNotInGraph);

            Assert.AreEqual(original, copy.Graph.Print());
        }

        /// <summary>
        /// Searches for every test vector and asserts that each one returns <see cref="K"/> results,
        /// that its closest result is the vector itself, and that this self-distance is ~0.
        /// </summary>
        private void AssertEachVectorIsItsOwnNearestNeighbour(SmallWorld<float[], float> graph)
        {
            int bestWrong = 0;
            float maxError = float.MinValue;

            for (int i = 0; i < vectors.Count; ++i)
            {
                var result = graph.KNNSearch(vectors[i], K);
                var best = result.OrderBy(r => r.Distance).First();
                Assert.AreEqual(K, result.Count);
                if (best.Id != i)
                {
                    bestWrong++;
                }

                maxError = Math.Max(maxError, best.Distance);
            }

            Assert.AreEqual(0, bestWrong);
            Assert.AreEqual(0, maxError, FloatError);
        }
    }
}
--------------------------------------------------------------------------------
/Src/HNSW.Net.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio Version 17
4 | VisualStudioVersion = 17.12.35209.166
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HNSW.Net", "HNSW.Net\HNSW.Net.csproj", "{1A7D27E7-4F89-4197-BA86-2A7E80E702DE}"
7 | EndProject
8 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HNSW.Net.Tests", "HNSW.Net.Tests\HNSW.Net.Tests.csproj", "{F1C038AB-3654-45F3-BE98-378D4EDC2762}"
9 | EndProject
10 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HNSW.Net.Demo", "HNSW.Net.Demo\HNSW.Net.Demo.csproj", "{6F62C4BE-5D1C-4302-9813-E707EB902221}"
11 | EndProject
12 | Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = ".azuredevops", ".azuredevops", "{07C754D2-86A3-46F1-B78D-9C81F8974996}"
13 | ProjectSection(SolutionItems) = preProject
14 | ..\.devops\build-nuget.yml = ..\.devops\build-nuget.yml
15 | EndProjectSection
16 | EndProject
17 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HNSW.Net.TestFiltering", "HNSW.Net.TestFiltering\HNSW.Net.TestFiltering.csproj", "{3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}"
18 | EndProject
19 | Global
20 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
21 | Debug|Any CPU = Debug|Any CPU
22 | Debug|x64 = Debug|x64
23 | Debug|x86 = Debug|x86
24 | Release|Any CPU = Release|Any CPU
25 | Release|x64 = Release|x64
26 | Release|x86 = Release|x86
27 | EndGlobalSection
28 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
29 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
30 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Debug|Any CPU.Build.0 = Debug|Any CPU
31 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Debug|x64.ActiveCfg = Debug|Any CPU
32 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Debug|x64.Build.0 = Debug|Any CPU
33 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Debug|x86.ActiveCfg = Debug|Any CPU
34 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Debug|x86.Build.0 = Debug|Any CPU
35 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Release|Any CPU.ActiveCfg = Release|Any CPU
36 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Release|Any CPU.Build.0 = Release|Any CPU
37 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Release|x64.ActiveCfg = Release|Any CPU
38 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Release|x64.Build.0 = Release|Any CPU
39 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Release|x86.ActiveCfg = Release|Any CPU
40 | {1A7D27E7-4F89-4197-BA86-2A7E80E702DE}.Release|x86.Build.0 = Release|Any CPU
41 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
42 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Debug|Any CPU.Build.0 = Debug|Any CPU
43 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Debug|x64.ActiveCfg = Debug|Any CPU
44 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Debug|x64.Build.0 = Debug|Any CPU
45 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Debug|x86.ActiveCfg = Debug|Any CPU
46 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Debug|x86.Build.0 = Debug|Any CPU
47 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Release|Any CPU.ActiveCfg = Release|Any CPU
48 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Release|Any CPU.Build.0 = Release|Any CPU
49 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Release|x64.ActiveCfg = Release|Any CPU
50 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Release|x64.Build.0 = Release|Any CPU
51 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Release|x86.ActiveCfg = Release|Any CPU
52 | {F1C038AB-3654-45F3-BE98-378D4EDC2762}.Release|x86.Build.0 = Release|Any CPU
53 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
54 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Debug|Any CPU.Build.0 = Debug|Any CPU
55 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Debug|x64.ActiveCfg = Debug|Any CPU
56 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Debug|x64.Build.0 = Debug|Any CPU
57 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Debug|x86.ActiveCfg = Debug|Any CPU
58 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Debug|x86.Build.0 = Debug|Any CPU
59 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Release|Any CPU.ActiveCfg = Release|Any CPU
60 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Release|Any CPU.Build.0 = Release|Any CPU
61 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Release|x64.ActiveCfg = Release|Any CPU
62 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Release|x64.Build.0 = Release|Any CPU
63 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Release|x86.ActiveCfg = Release|Any CPU
64 | {6F62C4BE-5D1C-4302-9813-E707EB902221}.Release|x86.Build.0 = Release|Any CPU
65 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
66 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Debug|Any CPU.Build.0 = Debug|Any CPU
67 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Debug|x64.ActiveCfg = Debug|Any CPU
68 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Debug|x64.Build.0 = Debug|Any CPU
69 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Debug|x86.ActiveCfg = Debug|Any CPU
70 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Debug|x86.Build.0 = Debug|Any CPU
71 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Release|Any CPU.ActiveCfg = Release|Any CPU
72 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Release|Any CPU.Build.0 = Release|Any CPU
73 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Release|x64.ActiveCfg = Release|Any CPU
74 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Release|x64.Build.0 = Release|Any CPU
75 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Release|x86.ActiveCfg = Release|Any CPU
76 | {3D3E2F64-B88B-4E53-B9DD-27870B6D98AA}.Release|x86.Build.0 = Release|Any CPU
77 | EndGlobalSection
78 | GlobalSection(SolutionProperties) = preSolution
79 | HideSolutionNode = FALSE
80 | EndGlobalSection
81 | GlobalSection(ExtensibilityGlobals) = postSolution
82 | SolutionGuid = {0B720E3E-7DD4-4ED8-BD60-49E166055CBD}
83 | EndGlobalSection
84 | EndGlobal
85 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/Algorithms.Algorithm3.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net
{
    using System;
    using System.Collections.Generic;

    internal partial class Algorithms
    {
        /// <summary>
        /// SELECT-NEIGHBORS-SIMPLE(q, C, M): keeps the M candidates nearest to q.
        /// Article: Section 4. Algorithm 3.
        /// </summary>
        /// <typeparam name="TItem">The type of the items in the small world.</typeparam>
        /// <typeparam name="TDistance">The type of the distance in the small world.</typeparam>
        internal class Algorithm3<TItem, TDistance> : Algorithm<TItem, TDistance> where TDistance : struct, IComparable<TDistance>
        {
            public Algorithm3(Graph<TItem, TDistance>.Core graphCore) : base(graphCore)
            {
            }

            /// <inheritdoc/>
            internal override List<int> SelectBestForConnecting(List<int> candidatesIds, TravelingCosts<int, TDistance> travelingCosts, int layer)
            {
                // The heap is laid over candidatesIds itself (no copy), so the caller's list is reordered
                // and trimmed in place. Ordered by travelingCosts, the farthest candidate from q sits on top,
                // so each Pop discards the current worst candidate.
                var heap = new BinaryHeap<int>(candidatesIds, travelingCosts);

                for (int surplus = heap.Buffer.Count - GetM(layer); surplus > 0; --surplus)
                {
                    heap.Pop();
                }

                return heap.Buffer;
            }
        }
    }
}
--------------------------------------------------------------------------------
/Src/HNSW.Net/Algorithms.Algorithm4.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net
{
    using System;
    using System.Collections.Generic;
    using System.Linq;

    internal partial class Algorithms
    {
        /// <summary>
        /// The implementation of the SELECT-NEIGHBORS-HEURISTIC(q, C, M, lc, extendCandidates, keepPrunedConnections) algorithm.
        /// Article: Section 4. Algorithm 4.
        /// </summary>
        /// <typeparam name="TItem">The type of the items in the small world.</typeparam>
        /// <typeparam name="TDistance">The type of the distance in the small world.</typeparam>
        internal sealed class Algorithm4<TItem, TDistance> : Algorithm<TItem, TDistance> where TDistance : struct, IComparable<TDistance>
        {
            public Algorithm4(Graph<TItem, TDistance>.Core graphCore) : base(graphCore)
            {
            }

            /// <inheritdoc/>
            /// <remarks>
            /// A candidate e is kept only if no element r already in R is closer to e than q is
            /// (the diversity criterion used by the reference hnswlib implementation).
            /// Comparing dist(e, q) with dist(farthest r, q) instead is degenerate: candidates are
            /// extracted closest-first, so that comparison can never succeed after the first acceptance.
            /// </remarks>
            internal override List<int> SelectBestForConnecting(List<int> candidatesIds, TravelingCosts<int, TDistance> travelingCosts, int layer)
            {
                /*
                 * q ← this
                 * R ← ∅ // result
                 * W ← C // working queue for the candidates
                 * if expandCandidates // expand candidates
                 *   for each e ∈ C
                 *     for each eadj ∈ neighbourhood(e) at layer lc
                 *       if eadj ∉ W
                 *         W ← W ⋃ eadj
                 *
                 * Wd ← ∅ // queue for the discarded candidates
                 * while │W│ gt 0 and │R│ lt M
                 *   e ← extract nearest element from W to q
                 *   if e is closer to q than to every element of R
                 *     R ← R ⋃ e
                 *   else
                 *     Wd ← Wd ⋃ e
                 *
                 * if keepPrunedConnections // add some of the discarded connections from Wd
                 *   while │Wd│ gt 0 and │R│ lt M
                 *     R ← R ⋃ extract nearest element from Wd to q
                 *
                 * return R
                 */

                IComparer<int> fartherIsOnTop = travelingCosts;
                IComparer<int> closerIsOnTop = fartherIsOnTop.Reverse();

                var layerM = GetM(layer);

                var resultHeap = new BinaryHeap<int>(new List<int>(layerM + 1), fartherIsOnTop);

                // NOTE: the heap is laid over candidatesIds itself, so the caller's list is reordered/consumed in place.
                var candidatesHeap = new BinaryHeap<int>(candidatesIds, closerIsOnTop);

                // expand candidates option is enabled
                if (GraphCore.Parameters.ExpandBestSelection)
                {
                    // NOTE(review): neighbours of the candidates may include q itself (e.g. when called from
                    // Connect); nothing here filters it out - confirm whether that is intended.
                    var visited = new HashSet<int>(candidatesHeap.Buffer);
                    var toAdd = new HashSet<int>();
                    foreach (var candidateId in candidatesHeap.Buffer)
                    {
                        foreach (var candidateNeighbourId in GraphCore.Nodes[candidateId][layer])
                        {
                            if (visited.Add(candidateNeighbourId))
                            {
                                toAdd.Add(candidateNeighbourId);
                            }
                        }
                    }

                    // Pushed after the scan so candidatesHeap.Buffer is not modified while being enumerated.
                    foreach (var id in toAdd)
                    {
                        candidatesHeap.Push(id);
                    }
                }

                // main stage of moving candidates to result
                var discardedHeap = new BinaryHeap<int>(new List<int>(candidatesHeap.Buffer.Count), closerIsOnTop);
                while (candidatesHeap.Buffer.Count > 0 && resultHeap.Buffer.Count < layerM)
                {
                    var candidateId = candidatesHeap.Pop();
                    var candidateToQuery = travelingCosts.From(candidateId);

                    if (IsCloserToQueryThanToResults(candidateId, candidateToQuery, resultHeap.Buffer))
                    {
                        resultHeap.Push(candidateId);
                    }
                    else if (GraphCore.Parameters.KeepPrunedConnections)
                    {
                        discardedHeap.Push(candidateId);
                    }
                }

                // keep pruned option is enabled
                if (GraphCore.Parameters.KeepPrunedConnections)
                {
                    while (discardedHeap.Buffer.Count > 0 && resultHeap.Buffer.Count < layerM)
                    {
                        resultHeap.Push(discardedHeap.Pop());
                    }
                }

                return resultHeap.Buffer;
            }

            /// <summary>
            /// Returns true when no already-selected node is strictly closer to the candidate than q is.
            /// An empty selection always accepts the candidate.
            /// </summary>
            /// <param name="candidateId">The candidate being considered.</param>
            /// <param name="candidateToQuery">Distance between the candidate and q.</param>
            /// <param name="resultIds">The nodes selected so far.</param>
            private bool IsCloserToQueryThanToResults(int candidateId, TDistance candidateToQuery, List<int> resultIds)
            {
                foreach (var resultId in resultIds)
                {
                    if (DistanceUtils.LowerThan(NodeDistance(candidateId, resultId), candidateToQuery))
                    {
                        return false;
                    }
                }

                return true;
            }
        }
    }
}
116 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/Algorithms.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net
{
    using System;
    using System.Collections.Generic;

    internal partial class Algorithms
    {
        /// <summary>
        /// Base class for the neighbour-selection algorithms; it decides how node connection lists are
        /// allocated and how they are pruned once they exceed their capacity.
        /// </summary>
        /// <typeparam name="TItem">The type of the items in the small world.</typeparam>
        /// <typeparam name="TDistance">The type of the distance in the small world.</typeparam>
        internal abstract class Algorithm<TItem, TDistance> where TDistance : struct, IComparable<TDistance>
        {
            protected readonly Graph<TItem, TDistance>.Core GraphCore;

            protected readonly Func<int, int, TDistance> NodeDistance;

            public Algorithm(Graph<TItem, TDistance>.Core graphCore)
            {
                GraphCore = graphCore;
                NodeDistance = graphCore.GetDistance;
            }

            /// <summary>
            /// Builds a <see cref="Node"/> with one empty connection list per layer from 0 to <paramref name="maxLayer"/>.
            /// </summary>
            /// <param name="nodeId">The identifier of the node.</param>
            /// <param name="maxLayer">The highest layer the node belongs to.</param>
            /// <returns>The new node.</returns>
            internal virtual Node NewNode(int nodeId, int maxLayer)
            {
                var layers = new List<List<int>>(maxLayer + 1);
                for (int level = 0; level <= maxLayer; level++)
                {
                    // One slot beyond the limit: Connect appends first and prunes afterwards,
                    // so a full layer never needs to grow its backing array.
                    layers.Add(new List<int>(GetM(level) + 1));
                }

                return new Node
                {
                    Id = nodeId,
                    Connections = layers
                };
            }

            /// <summary>
            /// Picks the neighbours to keep for a node from a set of candidates.
            /// </summary>
            /// <param name="candidatesIds">The identifiers of candidates to neighbourhood.</param>
            /// <param name="travelingCosts">Traveling costs to compare candidates.</param>
            /// <param name="layer">The layer of the neighbourhood.</param>
            /// <returns>The selected node identifiers.</returns>
            internal abstract List<int> SelectBestForConnecting(List<int> candidatesIds, TravelingCosts<int, TDistance> travelingCosts, int layer);

            /// <summary>
            /// Maximum number of connections allowed on a layer: 2·M on layer 0, M on every other layer.
            /// </summary>
            /// <remarks>
            /// Article: Section 4.1:
            /// "Selection of the Mmax0 (the maximum number of connections that an element can have in the zero layer) also
            /// has a strong influence on the search performance, especially in case of high quality(high recall) search.
            /// Simulations show that setting Mmax0 to M(this corresponds to kNN graphs on each layer if the neighbors
            /// selection heuristic is not used) leads to a very strong performance penalty at high recall.
            /// Simulations also suggest that 2∙M is a good choice for Mmax0;
            /// setting the parameter higher leads to performance degradation and excessive memory usage."
            /// </remarks>
            /// <param name="layer">The level of the layer.</param>
            /// <returns>The maximum number of connections.</returns>
            internal int GetM(int layer)
            {
                int m = GraphCore.Parameters.M;
                return layer == 0 ? m * 2 : m;
            }

            /// <summary>
            /// Appends <paramref name="neighbour"/> to the connections of <paramref name="node"/> on the given layer,
            /// re-selecting the neighbourhood with <see cref="SelectBestForConnecting"/> if the layer overflows.
            /// </summary>
            /// <param name="node">The node to add neighbour to.</param>
            /// <param name="neighbour">The new neighbour.</param>
            /// <param name="layer">The layer to add neighbour to.</param>
            internal void Connect(Node node, Node neighbour, int layer)
            {
                var connections = node[layer];
                connections.Add(neighbour.Id);

                if (connections.Count <= GetM(layer))
                {
                    return;
                }

                var costs = new TravelingCosts<int, TDistance>(NodeDistance, node.Id);
                node[layer] = SelectBestForConnecting(connections, costs, layer);
            }
        }
    }
}
101 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/BinaryHeap.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net
{
    using System;
    using System.Collections.Generic;
    using System.Diagnostics.CodeAnalysis;

    /// <summary>
    /// Max-heap laid directly over a <see cref="List{T}"/>: the element ranked highest by
    /// <see cref="Comparer"/> is kept at index 0. The list passed in is reordered in place; supply a
    /// custom <see cref="IComparer{T}"/> to change which element counts as the "maximum".
    /// </summary>
    /// <typeparam name="T">The type of the items in the source list.</typeparam>
    [SuppressMessage("Performance", "CA1815:Override equals and operator equals on value types", Justification = "By design")]
    internal struct BinaryHeap<T>
    {
        internal IComparer<T> Comparer;
        internal List<T> Buffer;
        internal bool Any => Buffer.Count > 0;

        internal BinaryHeap(List<T> buffer) : this(buffer, Comparer<T>.Default) { }

        internal BinaryHeap(List<T> buffer, IComparer<T> comparer)
        {
            Buffer = buffer ?? throw new ArgumentNullException(nameof(buffer));
            Comparer = comparer;

            // Heapify by sifting each element up in index order (this exact order determines the buffer layout).
            int length = Buffer.Count;
            for (var index = 1; index < length; index++)
            {
                SiftUp(index);
            }
        }

        internal void Push(T item)
        {
            Buffer.Add(item);
            SiftUp(Buffer.Count - 1);
        }

        internal T Pop()
        {
            int lastIndex = Buffer.Count - 1;
            if (lastIndex < 0)
            {
                throw new InvalidOperationException("Heap is empty");
            }

            T top = Buffer[0];

            // Move the last leaf to the root, drop the tail slot, then restore the heap downward.
            Buffer[0] = Buffer[lastIndex];
            Buffer.RemoveAt(lastIndex);
            SiftDown(0);

            return top;
        }

        /// <summary>
        /// Pushes the item at <paramref name="i"/> down until neither child outranks it,
        /// assuming both subtrees below it are already valid heaps.
        /// </summary>
        /// <param name="i">The position whose item may violate the heap property.</param>
        private void SiftDown(int i)
        {
            int count = Buffer.Count;
            for (int child = (i << 1) + 1; child < count; child = (i << 1) + 1)
            {
                int right = child + 1;
                if (right < count && Comparer.Compare(Buffer[child], Buffer[right]) < 0)
                {
                    child = right;
                }

                if (Comparer.Compare(Buffer[child], Buffer[i]) <= 0)
                {
                    return;
                }

                Swap(i, child);
                i = child;
            }
        }

        /// <summary>
        /// Lifts the item at <paramref name="i"/> toward the root while it outranks its parent,
        /// assuming everything above it already satisfies the heap property.
        /// </summary>
        /// <param name="i">The position whose item may violate the heap property.</param>
        private void SiftUp(int i)
        {
            while (i > 0)
            {
                int parent = (i - 1) / 2;
                if (Comparer.Compare(Buffer[i], Buffer[parent]) <= 0)
                {
                    return;
                }

                Swap(i, parent);
                i = parent;
            }
        }

        private void Swap(int i, int j)
        {
            (Buffer[i], Buffer[j]) = (Buffer[j], Buffer[i]);
        }
    }
}
105 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/CosineDistance.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net
{
    using System;
    using System.Collections.Generic;
    using System.Numerics;
    using System.Runtime.CompilerServices;

    /// <summary>
    /// Calculates cosine distance (1 - cosine similarity) between float vectors.
    /// </summary>
    /// <remarks>
    /// Intuition behind selecting float as a carrier.
    ///
    /// 1. In practice we work with vectors of dimensionality 100 and each component has value in range [-1; 1].
    ///    There certainly is a possibility of underflow,
    ///    but we assume that such cases are rare and that the resulting precision loss is acceptable.
    ///
    /// 2. According to the article http://www.ti3.tuhh.de/paper/rump/JeaRu13.pdf
    ///    the floating point rounding error is less than 100 * 2^-24 * sqrt(100) * sqrt(100) &lt; 0.0005960.
    ///    We deem such precision satisfactory for our needs.
    /// </remarks>
    public static class CosineDistance
    {
        /// <summary>
        /// Calculates cosine distance without making any optimizations.
        /// </summary>
        /// <param name="u">Left vector.</param>
        /// <param name="v">Right vector.</param>
        /// <returns>Cosine distance between u and v; 1 when either vector has zero norm.</returns>
        /// <exception cref="ArgumentException">The vectors have different lengths.</exception>
        public static float NonOptimized(float[] u, float[] v)
        {
            EnsureSameLength(u, v);

            float dot = 0.0f;
            float nru = 0.0f;
            float nrv = 0.0f;
            for (int i = 0; i < u.Length; ++i)
            {
                dot += u[i] * v[i];
                nru += u[i] * u[i];
                nrv += v[i] * v[i];
            }

            var norm = (float)(Math.Sqrt(nru) * Math.Sqrt(nrv));
            if (norm == 0)
            {
                // Similarity is undefined for a zero vector (0/0 would yield NaN); return 1 like SIMD does.
                return 1f;
            }

            var similarity = dot / norm;
            return 1 - similarity;
        }

        /// <summary>
        /// Calculates cosine distance with assumption that u and v are unit vectors.
        /// </summary>
        /// <param name="u">Left vector.</param>
        /// <param name="v">Right vector.</param>
        /// <returns>Cosine distance between u and v.</returns>
        /// <exception cref="ArgumentException">The vectors have different lengths.</exception>
        public static float ForUnits(float[] u, float[] v)
        {
            EnsureSameLength(u, v);

            float dot = 0;
            for (int i = 0; i < u.Length; ++i)
            {
                dot += u[i] * v[i];
            }

            return 1 - dot;
        }

        /// <summary>
        /// Calculates cosine distance optimized using SIMD instructions.
        /// </summary>
        /// <param name="u">Left vector.</param>
        /// <param name="v">Right vector.</param>
        /// <returns>Cosine distance between u and v; 1 when either vector has zero norm.</returns>
        /// <exception cref="NotSupportedException">SIMD is not hardware accelerated on this machine.</exception>
        /// <exception cref="ArgumentException">The vectors have different lengths.</exception>
        public static float SIMD(float[] u, float[] v)
        {
            if (!Vector.IsHardwareAccelerated)
            {
                throw new NotSupportedException($"SIMD version of {nameof(CosineDistance)} is not supported");
            }

            EnsureSameLength(u, v);

            float dot = 0;
            var norm = default(Vector2);
            int step = Vector<float>.Count;

            int i, to = u.Length - step;
            for (i = 0; i <= to; i += step)
            {
                var ui = new Vector<float>(u, i);
                var vi = new Vector<float>(v, i);
                dot += Vector.Dot(ui, vi);
                norm.X += Vector.Dot(ui, ui);
                norm.Y += Vector.Dot(vi, vi);
            }

            // Scalar tail for the elements that do not fill a whole vector register.
            for (; i < u.Length; ++i)
            {
                dot += u[i] * v[i];
                norm.X += u[i] * u[i];
                norm.Y += v[i] * v[i];
            }

            norm = Vector2.SquareRoot(norm);
            float n = norm.X * norm.Y;

            if (n == 0)
            {
                return 1f;
            }

            var similarity = dot / n;
            return 1f - similarity;
        }

        /// <summary>
        /// Calculates cosine distance with assumption that u and v are unit vectors using SIMD instructions.
        /// </summary>
        /// <param name="u">Left vector.</param>
        /// <param name="v">Right vector.</param>
        /// <returns>Cosine distance between u and v.</returns>
        /// <exception cref="ArgumentException">The vectors have different lengths.</exception>
        public static float SIMDForUnits(float[] u, float[] v)
        {
            // Without this check a longer v is silently truncated and a shorter v fails inside Vector<T>.
            EnsureSameLength(u, v);
            return 1f - DotProduct(u, v);
        }

        private static readonly int _vs1 = Vector<float>.Count;
        private static readonly int _vs2 = 2 * Vector<float>.Count;
        private static readonly int _vs3 = 3 * Vector<float>.Count;
        private static readonly int _vs4 = 4 * Vector<float>.Count;

        /// <summary>
        /// Throws the shared dimension-mismatch error used by every distance function.
        /// </summary>
        private static void EnsureSameLength(float[] u, float[] v)
        {
            if (u.Length != v.Length)
            {
                throw new ArgumentException("Vectors have non-matching dimensions");
            }
        }

        /// <summary>
        /// Dot product of two equal-length arrays: 4x-, 2x- and 1x-unrolled SIMD chunks, then a scalar tail.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static float DotProduct(float[] lhs, float[] rhs)
        {
            float result = 0f;

            var count = lhs.Length;
            var offset = 0;

            while (count >= _vs4)
            {
                result += Vector.Dot(new Vector<float>(lhs, offset), new Vector<float>(rhs, offset));
                result += Vector.Dot(new Vector<float>(lhs, offset + _vs1), new Vector<float>(rhs, offset + _vs1));
                result += Vector.Dot(new Vector<float>(lhs, offset + _vs2), new Vector<float>(rhs, offset + _vs2));
                result += Vector.Dot(new Vector<float>(lhs, offset + _vs3), new Vector<float>(rhs, offset + _vs3));
                if (count == _vs4) return result;
                count -= _vs4;
                offset += _vs4;
            }

            if (count >= _vs2)
            {
                result += Vector.Dot(new Vector<float>(lhs, offset), new Vector<float>(rhs, offset));
                result += Vector.Dot(new Vector<float>(lhs, offset + _vs1), new Vector<float>(rhs, offset + _vs1));
                if (count == _vs2) return result;
                count -= _vs2;
                offset += _vs2;
            }

            if (count >= _vs1)
            {
                result += Vector.Dot(new Vector<float>(lhs, offset), new Vector<float>(rhs, offset));
                if (count == _vs1) return result;
                count -= _vs1;
                offset += _vs1;
            }

            while (count > 0)
            {
                result += lhs[offset] * rhs[offset];
                offset++;
                count--;
            }

            return result;
        }
    }
}
191 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/DefaultRandomGenerator.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Runtime.CompilerServices;
4 | using System.Text;
5 |
namespace HNSW.Net
{
    /// <summary>
    /// <see cref="IProvideRandomValues"/> implementation that forwards every request to ThreadSafeFastRandom.
    /// The two static instances differ only in the value they report from <see cref="IsThreadSafe"/>.
    /// </summary>
    public sealed class DefaultRandomGenerator : IProvideRandomValues
    {
        /// <summary>
        /// Default configuration: reports itself as thread-safe, so the optimization process is allowed to run on multiple threads.
        /// </summary>
        public static DefaultRandomGenerator Instance { get; } = new DefaultRandomGenerator(allowParallel: true);

        /// <summary>
        /// Same underlying generator, but reports itself as not thread-safe so the optimization process runs on a single thread.
        /// Useful when several requests may be processed concurrently, or when a single request should not occupy every CPU.
        /// </summary>
        public static DefaultRandomGenerator DisableThreading { get; } = new DefaultRandomGenerator(allowParallel: false);

        private DefaultRandomGenerator(bool allowParallel)
        {
            IsThreadSafe = allowParallel;
        }

        /// <summary>
        /// Whether callers may use this generator from several threads at once.
        /// </summary>
        public bool IsThreadSafe { get; }

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public int Next(int minValue, int maxValue)
        {
            return ThreadSafeFastRandom.Next(minValue, maxValue);
        }

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public float NextFloat()
        {
            return ThreadSafeFastRandom.NextFloat();
        }

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public void NextFloats(Span<float> buffer)
        {
            ThreadSafeFastRandom.NextFloats(buffer);
        }
    }
}
--------------------------------------------------------------------------------
/Src/HNSW.Net/DistanceCache.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
6 | using System;
7 | using System.Runtime.CompilerServices;
8 |
namespace HNSW.Net
{
    /// <summary>
    /// Global limits applied to every <c>DistanceCache</c>.
    /// </summary>
    public static class DistanceCacheLimits
    {
        /// <summary>
        /// Upper bound (exclusive of the cap) on array lengths; 2^28 = 268,435,456 entries.
        /// See https://referencesource.microsoft.com/#mscorlib/system/array.cs,2d2b551eabe74985,references
        /// </summary>
        private const int MaxSupportedLength = 0x10000000;

        private static int _maxArrayLength = MaxSupportedLength; // 268_435_456

        /// <summary>
        /// Largest number of entries a distance cache may allocate. Assigned values are rounded up
        /// to a power of two (the cache uses a bit mask instead of a modulo to pick a slot) and
        /// capped at 2^28 = 268,435,456 on every target framework.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">Thrown by the setter when the value is less than 1.</exception>
        public static int MaxArrayLength
        {
            get { return _maxArrayLength; }
            set
            {
                // 0 or negative lengths would produce an empty (or bogus) cache array that the
                // masking lookup in DistanceCache would index out of range.
                if (value < 1)
                {
                    throw new ArgumentOutOfRangeException(nameof(value), value, "MaxArrayLength must be at least 1.");
                }

                _maxArrayLength = NextPowerOf2((uint)value);
            }
        }

        /// <summary>
        /// Rounds <paramref name="x"/> up to the next power of two, clamped to [1, 2^28].
        /// Uses integer bit smearing only, so exact powers of two are returned unchanged
        /// (a floating-point log/pow formulation can overshoot them).
        /// </summary>
        private static int NextPowerOf2(uint x)
        {
            if (x <= 1u)
            {
                return 1;
            }

            if (x >= MaxSupportedLength)
            {
                return MaxSupportedLength;
            }

            x--;
            x |= x >> 1;
            x |= x >> 2;
            x |= x >> 4;
            x |= x >> 8;
            x |= x >> 16;
            return (int)(x + 1u);
        }
    }

    /// <summary>
    /// Direct-mapped cache of pairwise distances keyed by an unordered pair of point ids.
    /// A pair is encoded into a single key (see <see cref="MakeKey"/>), and the slot is
    /// <c>key &amp; (length - 1)</c>, so the backing arrays are kept at a power-of-two length.
    /// A slot holds only the last pair written to it; lookups compare the full key.
    /// </summary>
    /// <typeparam name="TDistance">The distance value type.</typeparam>
    internal class DistanceCache<TDistance> where TDistance : struct
    {
        private TDistance[] _values;

        private long[] _keys;

        /// <summary>Number of lookups answered from the cache by <see cref="GetOrCacheValue"/>.</summary>
        internal int HitCount;

        internal DistanceCache()
        {
        }

        /// <summary>
        /// Ensures the cache is large enough for the pairs among <paramref name="pointsCount"/> points,
        /// subject to <see cref="DistanceCacheLimits.MaxArrayLength"/>.
        /// </summary>
        /// <param name="pointsCount">Number of points; values &lt;= 0 fall back to 1024.</param>
        /// <param name="overwrite">
        /// When true the arrays are always reallocated (discarding cached entries); otherwise they only grow.
        /// </param>
        internal void Resize(int pointsCount, bool overwrite)
        {
            if (pointsCount <= 0) { pointsCount = 1024; }

            // Ids in [0, n) produce keys in [0, n(n+1)/2). Rounding that count up to a power of two
            // lets the mask reach every slot; when the limit does not cap it, those ids never collide.
            long pairs = ((long)pointsCount * (pointsCount + 1)) >> 1;
            int capacity = CapacityFor(pairs, DistanceCacheLimits.MaxArrayLength);

            if (_keys is null || capacity > _keys.Length || overwrite)
            {
                int i0 = 0;
                if (_keys is null || overwrite)
                {
                    _keys = new long[capacity];
                    _values = new TDistance[capacity];
                }
                else
                {
                    // Growing keeps old entries; a stale entry at the "wrong" slot is harmless
                    // because lookups compare the full key.
                    i0 = _keys.Length;
                    Array.Resize(ref _keys, capacity);
                    Array.Resize(ref _values, capacity);
                }

                // -1 marks an empty slot (keys built from non-negative ids are >= 0).
                // TODO: may be there is a better way to warm up cache and force OS to allocate pages
                _keys.AsSpan().Slice(i0).Fill(-1);
                _values.AsSpan().Slice(i0).Fill(default);
            }
        }

        /// <summary>
        /// Returns the cached distance for the pair, or computes it with <paramref name="getter"/>,
        /// stores it (evicting whatever occupied the slot) and returns it.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        internal TDistance GetOrCacheValue(int fromId, int toId, Func<int, int, TDistance> getter)
        {
            long key = MakeKey(fromId, toId);
            int hash = (int)(key & (_keys.Length - 1));

            if (_keys[hash] == key)
            {
                HitCount++;
                return _values[hash];
            }
            else
            {
                var d = getter(fromId, toId);
                _keys[hash] = key;
                _values[hash] = d;
                return d;
            }
        }

        /// <summary>Stores <paramref name="distance"/> for the pair, overwriting the slot.</summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private void SetValue(int fromId, int toId, TDistance distance)
        {
            long key = MakeKey(fromId, toId);
            int hash = (int)(key & (_keys.Length - 1));
            _keys[hash] = key;
            _values[hash] = distance;
        }

        /// <summary>
        /// Smallest power of two that is &gt;= <paramref name="pairs"/>, but never above
        /// <paramref name="limit"/> and never below 1.
        /// </summary>
        private static int CapacityFor(long pairs, int limit)
        {
            long capacity = 1;
            while (capacity < pairs && capacity < limit)
            {
                capacity <<= 1;
            }

            return (int)Math.Min(capacity, Math.Max(1, limit));
        }

        /// <summary>
        /// Builds the key for the pair of points; for non-negative ids it is unique per unordered pair
        /// (triangular-number encoding), so MakeKey(fromId, toId) == MakeKey(toId, fromId).
        /// </summary>
        /// <param name="fromId">The from point identifier.</param>
        /// <param name="toId">The to point identifier.</param>
        /// <returns>Key of the pair.</returns>
        private static long MakeKey(int fromId, int toId)
        {
            return fromId > toId ? (((long)fromId * (fromId + 1)) >> 1) + toId : (((long)toId * (toId + 1)) >> 1) + fromId;
        }
    }
}
123 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/DistanceUtils.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net
{
    using System;

    /// <summary>
    /// Ordering helpers for distance values, expressed through <c>CompareTo</c>.
    /// </summary>
    public static class DistanceUtils
    {
        /// <summary>True when <paramref name="x"/> compares strictly before <paramref name="y"/>.</summary>
        public static bool LowerThan<TDistance>(TDistance x, TDistance y) where TDistance : IComparable<TDistance>
            => 0 > x.CompareTo(y);

        /// <summary>True when <paramref name="x"/> compares strictly after <paramref name="y"/>.</summary>
        public static bool GreaterThan<TDistance>(TDistance x, TDistance y) where TDistance : IComparable<TDistance>
            => 0 < x.CompareTo(y);

        /// <summary>True when <paramref name="x"/> and <paramref name="y"/> compare as equal.</summary>
        public static bool IsEqual<TDistance>(TDistance x, TDistance y) where TDistance : IComparable<TDistance>
            => 0 == x.CompareTo(y);
    }
}
28 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/EventSources.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net
{
    using System;
    using System.Diagnostics.CodeAnalysis;
    using System.Diagnostics.Tracing;
    using System.Runtime.InteropServices;

    /// <summary>
    /// Event sources and counters used to publish HNSW graph metrics.
    /// </summary>
    public static class EventSources
    {
        /// <summary>
        /// Writes a metric to <paramref name="counter"/>, but only while <paramref name="source"/> is enabled.
        /// </summary>
        /// <param name="source">The event source whose enabled state is checked.</param>
        /// <param name="counter">The counter that receives the metric.</param>
        /// <param name="value">The value to write.</param>
        internal static void WriteMetricIfEnabled(EventSource source, EventCounter counter, float value)
        {
            if (source.IsEnabled())
            {
                counter.WriteMetric(value);
            }
        }

        /// <summary>
        /// Source of events occurring during the graph construction phase.
        /// </summary>
        [SuppressMessage("Design", "CA1034:Nested types should not be visible", Justification = "By Design")]
        [EventSource(Name = "HNSW.Net.Graph.Build")]
        [ComVisible(false)]
        public class GraphBuildEventSource : EventSource
        {
            /// <summary>
            /// The singleton instance of the source.
            /// </summary>
            public static readonly GraphBuildEventSource Instance = new GraphBuildEventSource();

            /// <summary>
            /// Initializes a new instance of the <see cref="GraphBuildEventSource"/> class and creates
            /// its "GetDistance.CacheHitRate" and "InsertNode.Latency" counters.
            /// </summary>
            private GraphBuildEventSource() : base(EventSourceSettings.EtwSelfDescribingEventFormat)
            {
                var coreGetDistanceCacheHitRate = new EventCounter("GetDistance.CacheHitRate", this);
                CoreGetDistanceCacheHitRateReporter = (float value) => WriteMetricIfEnabled(this, coreGetDistanceCacheHitRate, value);

                var graphInsertNodeLatency = new EventCounter("InsertNode.Latency", this);
                GraphInsertNodeLatencyReporter = (float value) => WriteMetricIfEnabled(this, graphInsertNodeLatency, value);
            }

            /// <summary>
            /// Gets the delegate that reports the hit rate of the distance cache.
            /// </summary>
            internal Action<float> CoreGetDistanceCacheHitRateReporter { get; }

            /// <summary>
            /// Gets the delegate that reports the node insertion latency.
            /// </summary>
            internal Action<float> GraphInsertNodeLatencyReporter { get; }
        }

        /// <summary>
        /// Source of events occurring during the graph search phase.
        /// </summary>
        [SuppressMessage("Design", "CA1034:Nested types should not be visible", Justification = "By Design")]
        [EventSource(Name = "HNSW.Net.Graph.Search")]
        [ComVisible(false)]
        public class GraphSearchEventSource : EventSource
        {
            /// <summary>
            /// The singleton instance of the source.
            /// </summary>
            public static readonly GraphSearchEventSource Instance = new GraphSearchEventSource();

            /// <summary>
            /// Initializes a new instance of the <see cref="GraphSearchEventSource"/> class and creates
            /// its "KNearest.Latency" and "KNearest.VisitedNodes" counters.
            /// </summary>
            private GraphSearchEventSource() : base(EventSourceSettings.EtwSelfDescribingEventFormat)
            {
                var graphKNearestLatency = new EventCounter("KNearest.Latency", this);
                GraphKNearestLatencyReporter = (float value) => WriteMetricIfEnabled(this, graphKNearestLatency, value);

                var graphKNearestVisitedNodes = new EventCounter("KNearest.VisitedNodes", this);
                GraphKNearestVisitedNodesReporter = (float value) => WriteMetricIfEnabled(this, graphKNearestVisitedNodes, value);
            }

            /// <summary>
            /// Gets the delegate that reports k-nearest search latency.
            /// </summary>
            internal Action<float> GraphKNearestLatencyReporter { get; }

            /// <summary>
            /// Gets the delegate that reports the number of nodes visited by a k-nearest search.
            /// </summary>
            internal Action<float> GraphKNearestVisitedNodesReporter { get; }
        }
    }
}
105 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/FastRandom.cs:
--------------------------------------------------------------------------------
1 | using System;
2 |
namespace HNSW.Net
{
    /// <summary>
    /// A fast random number generator for .NET, from
    /// https://www.codeproject.com/Articles/9187/A-fast-equivalent-for-System-Random
    /// (Colin Green, January 2005).
    /// </summary>
    /// <remarks>
    /// History of the original: September 4th 2005 - added NextBytesUnsafe() (commented out by default and
    /// not part of this port); fixed a bug in Reinitialise() where y, z and w were not being reset.
    ///
    /// Key points:
    /// 1) Based on a simple and fast xor-shift pseudo random number generator (RNG) specified in
    ///    Marsaglia, George. (2003). Xorshift RNGs. http://www.jstatsoft.org/v08/i14/xorshift.pdf
    ///    This particular implementation of xorshift has a period of 2^128-1. See the above paper to see
    ///    how this can be easily extended if you need a longer period.
    /// 2) Faster than System.Random. Up to 8x faster, depending on which methods are called.
    /// 3) Provides the methods System.Random provides (Next, NextDouble, NextBytes) plus some additional
    ///    ones; the like-named methods are functionally equivalent.
    /// 4) Allows fast re-initialisation with a seed (see Reinitialise), unlike System.Random which only accepts
    ///    a seed at construction time and then runs a relatively expensive initialisation routine. This is
    ///    useful when the same pseudo-random sequence must be regenerated many times; each sequence can be
    ///    represented by a single int seed.
    ///
    /// Instances mutate their state without any synchronization, so a single instance must not be shared
    /// between threads without external locking.
    /// </remarks>
    internal class FastRandom
    {
        // Scales a 31-bit non-negative int onto [0.0, 1.0). The +1 ensures NextDouble doesn't generate 1.0.
        const double REAL_UNIT_INT = 1.0 / ((double)int.MaxValue + 1.0);

        // Scales a full 32-bit uint onto [0.0, 1.0).
        const double REAL_UNIT_UINT = 1.0 / ((double)uint.MaxValue + 1.0);

        // Scales a 24-bit int onto [0.0f, 1.0f). A float has a 24-bit significand, so (w >> 8) * 2^-24 is
        // computed exactly and its maximum, 1 - 2^-24, is strictly below 1.0f. Scaling a 31-bit int
        // instead would round the largest inputs up to exactly 1.0f.
        const float FLOAT_UNIT_24 = 1.0f / (1 << 24);

        // Fixed reset values for y, z and w; only x is derived from the seed.
        const uint Y = 842502087, Z = 3579807591, W = 273326509;

        // Xorshift state.
        uint x, y, z, w;

        // NextBool hands out the 32 bits of one generated word one at a time. bitMask selects the bit
        // returned last; a value of 1 means the buffer is exhausted and a new word must be generated.
        uint bitBuffer;
        uint bitMask = 1;

        /// <summary>
        /// Initialises a new instance seeded from <see cref="Environment.TickCount"/>.
        /// </summary>
        public FastRandom()
        {
            Reinitialise(Environment.TickCount);
        }

        /// <summary>
        /// Initialises a new instance using an int value as seed.
        /// This constructor signature is provided to maintain compatibility with System.Random.
        /// </summary>
        public FastRandom(int seed)
        {
            Reinitialise(seed);
        }

        /// <summary>
        /// Reinitialises using an int value as a seed. Afterwards the instance produces exactly the same
        /// sequence as a freshly constructed <c>new FastRandom(seed)</c>.
        /// </summary>
        public void Reinitialise(int seed)
        {
            // The only stipulation stated for the xorshift RNG is that at least one of
            // the seeds x,y,z,w is non-zero. We fulfill that requirement by only allowing
            // resetting of the x seed.
            x = (uint)seed;
            y = Y;
            z = Z;
            w = W;

            // Discard bits buffered by NextBool; otherwise the first NextBool calls after a reseed
            // would still return bits generated from the previous state.
            bitBuffer = 0;
            bitMask = 1;
        }

        /// <summary>
        /// Generates a random int over the range 0 to int.MaxValue-1.
        /// MaxValue is not generated in order to remain functionally equivalent to System.Random.Next().
        /// For better performance call NextInt() for an int over the range 0 to int.MaxValue, or call
        /// NextUInt() and cast the result to an int to cover the full Int32 range including negatives.
        /// </summary>
        public int Next()
        {
            while (true)
            {
                uint t = (x ^ (x << 11));
                x = y; y = z; z = w;
                w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));

                // int.MaxValue is outside the permitted range, so draw again in that (rare) case.
                uint rtn = w & 0x7FFFFFFF;
                if (rtn != 0x7FFFFFFF)
                {
                    return (int)rtn;
                }
            }
        }

        /// <summary>
        /// Generates a random int over the range 0 to upperBound-1, not including upperBound
        /// (returns 0 when upperBound is 0).
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">upperBound is negative.</exception>
        public int Next(int upperBound)
        {
            if (upperBound < 0)
                throw new ArgumentOutOfRangeException(nameof(upperBound), upperBound, "upperBound must be >=0");

            uint t = (x ^ (x << 11));
            x = y; y = z; z = w;

            // The explicit int cast before the first multiplication gives better performance.
            // See comments in NextDouble.
            return (int)((REAL_UNIT_INT * (int)(0x7FFFFFFF & (w = (w ^ (w >> 19)) ^ (t ^ (t >> 8))))) * upperBound);
        }

        /// <summary>
        /// Generates a random int over the range lowerBound to upperBound-1, not including upperBound.
        /// upperBound must be >= lowerBound. lowerBound may be negative.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">lowerBound is greater than upperBound.</exception>
        public int Next(int lowerBound, int upperBound)
        {
            if (lowerBound > upperBound)
                throw new ArgumentOutOfRangeException(nameof(upperBound), upperBound, "upperBound must be >=lowerBound");

            uint t = (x ^ (x << 11));
            x = y; y = z; z = w;

            // The explicit int cast before the first multiplication gives better performance.
            // See comments in NextDouble.
            int range = upperBound - lowerBound;
            if (range < 0)
            {
                // A negative range means the subtraction overflowed: fall back to long arithmetic and
                // use all 32 bits of precision instead of the normal 31 (slower, but correct).
                return lowerBound + (int)((REAL_UNIT_UINT * (double)(w = (w ^ (w >> 19)) ^ (t ^ (t >> 8)))) * (double)((long)upperBound - (long)lowerBound));
            }

            // 31 bits of precision suffice when range <= int.MaxValue, which allows a cheaper int cast.
            return lowerBound + (int)((REAL_UNIT_INT * (double)(int)(0x7FFFFFFF & (w = (w ^ (w >> 19)) ^ (t ^ (t >> 8))))) * (double)range);
        }

        /// <summary>
        /// Generates a random double. Values returned are from 0.0 up to but not including 1.0.
        /// </summary>
        public double NextDouble()
        {
            uint t = (x ^ (x << 11));
            x = y; y = z; z = w;

            // Masking to 31 bits lets the value be cast to int, and int-to-double conversion is much
            // faster than uint-to-double. The lost bit of precision matches System.Random.
            return (REAL_UNIT_INT * (int)(0x7FFFFFFF & (w = (w ^ (w >> 19)) ^ (t ^ (t >> 8)))));
        }

        /// <summary>
        /// Generates a random float. Values returned are from 0.0f up to but not including 1.0f,
        /// in steps of 2^-24.
        /// </summary>
        public float NextFloat()
        {
            uint t = (x ^ (x << 11));
            x = y; y = z; z = w;
            w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));

            // Top 24 bits only: exactly representable, so the result can never round up to 1.0f.
            return FLOAT_UNIT_24 * (w >> 8);
        }

        /// <summary>
        /// Fills <paramref name="buffer"/> with random floats in [0.0f, 1.0f); equivalent to calling
        /// <see cref="NextFloat"/> once per element.
        /// </summary>
        public void NextFloats(Span<float> buffer)
        {
            // Work on local copies of the state inside the loop, then write it back once.
            uint x = this.x, y = this.y, z = this.z, w = this.w;
            for (int i = 0; i < buffer.Length; i++)
            {
                uint t = (x ^ (x << 11));
                x = y; y = z; z = w;
                w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));

                buffer[i] = FLOAT_UNIT_24 * (w >> 8);
            }

            this.x = x; this.y = y; this.z = z; this.w = w;
        }

        /// <summary>
        /// Fills the provided byte array with random bytes.
        /// This method is functionally equivalent to System.Random.NextBytes().
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
        public void NextBytes(byte[] buffer)
        {
            if (buffer is null)
            {
                throw new ArgumentNullException(nameof(buffer));
            }

            NextBytes(buffer.AsSpan());
        }

        /// <summary>
        /// Fills the provided span with random bytes.
        /// This method is functionally equivalent to System.Random.NextBytes().
        /// </summary>
        public void NextBytes(Span<byte> buffer)
        {
            // Fill up the bulk of the buffer in chunks of 4 bytes at a time.
            uint x = this.x, y = this.y, z = this.z, w = this.w;
            int i = 0;
            uint t;
            for (int bound = buffer.Length - 3; i < bound;)
            {
                // Generate 4 bytes per generated word. No mask is needed before each byte cast
                // because the cast discards the higher-order bytes.
                t = (x ^ (x << 11));
                x = y; y = z; z = w;
                w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));

                buffer[i++] = (byte)w;
                buffer[i++] = (byte)(w >> 8);
                buffer[i++] = (byte)(w >> 16);
                buffer[i++] = (byte)(w >> 24);
            }

            // Fill up any remaining (at most 3) bytes from one more word.
            if (i < buffer.Length)
            {
                t = (x ^ (x << 11));
                x = y; y = z; z = w;
                w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));

                buffer[i++] = (byte)w;
                if (i < buffer.Length)
                {
                    buffer[i++] = (byte)(w >> 8);
                    if (i < buffer.Length)
                    {
                        buffer[i++] = (byte)(w >> 16);
                        if (i < buffer.Length)
                        {
                            buffer[i] = (byte)(w >> 24);
                        }
                    }
                }
            }

            this.x = x; this.y = y; this.z = z; this.w = w;
        }

        /// <summary>
        /// Generates a uint over the full range uint.MinValue to uint.MaxValue, inclusive.
        /// This is the fastest method for a single random number because the generator natively
        /// produces 32 random bits.
        /// </summary>
        public uint NextUInt()
        {
            uint t = (x ^ (x << 11));
            x = y; y = z; z = w;
            return (w = (w ^ (w >> 19)) ^ (t ^ (t >> 8)));
        }

        /// <summary>
        /// Generates a random int over the range 0 to int.MaxValue, inclusive.
        /// Unlike Next(), int.MaxValue can be returned, which makes this slightly faster but not
        /// functionally equivalent to System.Random.Next().
        /// </summary>
        public int NextInt()
        {
            uint t = (x ^ (x << 11));
            x = y; y = z; z = w;
            return (int)(0x7FFFFFFF & (w = (w ^ (w >> 19)) ^ (t ^ (t >> 8))));
        }

        /// <summary>
        /// Generates a single random bit. 32 bits are generated at once and handed out one per call,
        /// most significant bit first.
        /// </summary>
        public bool NextBool()
        {
            if (bitMask == 1)
            {
                // Buffer exhausted: generate 32 more bits.
                uint t = (x ^ (x << 11));
                x = y; y = z; z = w;
                bitBuffer = w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));

                // Restart at the most significant bit.
                bitMask = 0x80000000;
                return (bitBuffer & bitMask) == 0;
            }

            return (bitBuffer & (bitMask >>= 1)) == 0;
        }
    }
}
--------------------------------------------------------------------------------
/Src/HNSW.Net/Graph.Core.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net
{
    using System;
    using System.Collections.Generic;
    using System.IO;
    using System.Linq;
    using System.Runtime.CompilerServices;
    using System.Threading;
    using MessagePack;

    using static HNSW.Net.EventSources;

    internal partial class Graph<TItem, TDistance>
    {
        /// <summary>
        /// Holds the graph's items and nodes, the neighbour-selection algorithm chosen by the
        /// parameters, and the optional construction-time distance cache.
        /// </summary>
        internal class Core
        {
            private readonly Func<TItem, TItem, TDistance> Distance;

            private readonly DistanceCache<TDistance> DistanceCache;

            private long DistanceCalculationsCount;

            internal List<Node> Nodes { get; private set; }

            internal List<TItem> Items { get; private set; }

            internal Algorithms.Algorithm<TItem, TDistance> Algorithm { get; private set; }

            internal SmallWorld<TItem, TDistance>.Parameters Parameters { get; private set; }

            /// <summary>
            /// Fraction of <see cref="GetDistance"/> calls answered by the distance cache.
            /// Returns 0 before any distance has been requested (instead of 0/0 = NaN).
            /// </summary>
            internal float DistanceCacheHitRate =>
                DistanceCalculationsCount == 0
                    ? 0f
                    : (float)(DistanceCache?.HitCount ?? 0) / DistanceCalculationsCount;

            /// <summary>
            /// Creates the core with pre-sized item/node lists (at least 1024 entries), selects the
            /// neighbour-selection algorithm and, when enabled, allocates the distance cache.
            /// </summary>
            /// <param name="distance">The distance function between two items.</param>
            /// <param name="parameters">The graph parameters.</param>
            internal Core(Func<TItem, TItem, TDistance> distance, SmallWorld<TItem, TDistance>.Parameters parameters)
            {
                Distance = distance;
                Parameters = parameters;

                var initialSize = Math.Max(1024, parameters.InitialItemsSize);

                Nodes = new List<Node>(initialSize);
                Items = new List<TItem>(initialSize);

                switch (Parameters.NeighbourHeuristic)
                {
                    case NeighbourSelectionHeuristic.SelectSimple:
                    {
                        Algorithm = new Algorithms.Algorithm3<TItem, TDistance>(this);
                        break;
                    }
                    case NeighbourSelectionHeuristic.SelectHeuristic:
                    {
                        Algorithm = new Algorithms.Algorithm4<TItem, TDistance>(this);
                        break;
                    }
                }

                if (Parameters.EnableDistanceCacheForConstruction)
                {
                    DistanceCache = new DistanceCache<TDistance>();
                    DistanceCache.Resize(parameters.InitialDistanceCacheSize, false);
                }

                DistanceCalculationsCount = 0;
            }

            /// <summary>
            /// Appends <paramref name="items"/> and creates one node per item with a random top layer.
            /// </summary>
            /// <returns>The ids assigned to the new items (consecutive, starting at the previous node count).</returns>
            internal IReadOnlyList<int> AddItems(IReadOnlyList<TItem> items, IProvideRandomValues generator)
            {
                int newCount = items.Count;

                var newIDs = new List<int>();
                Items.AddRange(items);

                // NOTE(review): the cache is sized from this batch's count, although cache keys use
                // global ids; sizing by the total would be more accurate but can use far more memory.
                // Confirm which is intended.
                DistanceCache?.Resize(newCount, false);

                int id0 = Nodes.Count;

                for (int id = 0; id < newCount; ++id)
                {
                    Nodes.Add(Algorithm.NewNode(id0 + id, RandomLayer(generator, Parameters.LevelLambda)));
                    newIDs.Add(id0 + id);
                }
                return newIDs;
            }

            /// <summary>
            /// Reallocates the distance cache (discarding its contents) for <paramref name="newSize"/> points.
            /// </summary>
            internal void ResizeDistanceCache(int newSize)
            {
                DistanceCache?.Resize(newSize, true);
            }

            /// <summary>Writes the node list to <paramref name="stream"/> with MessagePack.</summary>
            internal void Serialize(Stream stream)
            {
                MessagePackSerializer.Serialize(stream, Nodes);
            }

            /// <summary>
            /// Reads the node list from <paramref name="stream"/> and attaches the first
            /// <c>Nodes.Count</c> entries of <paramref name="items"/> to it.
            /// </summary>
            /// <returns>The items beyond the deserialized node count, which still need to be added.</returns>
            internal TItem[] Deserialize(IReadOnlyList<TItem> items, Stream stream)
            {
                // readStrict: true -> removed, as not available anymore on MessagePack 2.0 - also probably not necessary anymore
                // see https://github.com/neuecc/MessagePack-CSharp/pull/663
                Nodes = MessagePackSerializer.Deserialize<List<Node>>(stream);
                var remainingItems = items.Skip(Nodes.Count).ToArray();
                Items.AddRange(items.Take(Nodes.Count));
                return remainingItems;
            }

            /// <summary>
            /// Returns the distance between two items by id, going through the cache when it is enabled.
            /// </summary>
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            internal TDistance GetDistance(int fromId, int toId)
            {
                DistanceCalculationsCount++;
                if (DistanceCache is object)
                {
                    return DistanceCache.GetOrCacheValue(fromId, toId, GetDistanceSkipCache);
                }
                else
                {
                    return Distance(Items[fromId], Items[toId]);
                }
            }

            /// <summary>Computes the distance between two items by id without touching the cache.</summary>
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            private TDistance GetDistanceSkipCache(int fromId, int toId)
            {
                return Distance(Items[fromId], Items[toId]);
            }

            /// <summary>
            /// Draws a node's top layer as floor(-ln(u) * lambda) for u from <paramref name="generator"/>.
            /// </summary>
            private static int RandomLayer(IProvideRandomValues generator, double lambda)
            {
                // u == 0 would make -ln(u) = +Infinity, whose conversion to int is unspecified (typically a
                // negative layer). Clamp to the smallest positive float; every u > 0 is left unchanged.
                double u = Math.Max(generator.NextFloat(), float.Epsilon);
                var r = -Math.Log(u) * lambda;
                return (int)r;
            }
        }
    }
}
137 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/Graph.Searcher.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net
{
    using System;
    using System.Collections.Generic;
    using System.Threading;

    /// <summary>
    /// The implementation of knn search.
    /// </summary>
    internal partial class Graph<TItem, TDistance>
    {
        /// <summary>
        /// The graph searcher. Owns an expansion buffer and a visited set that are reused across
        /// <see cref="RunKnnAtLayer"/> calls and cleared at the end of every call.
        /// </summary>
        internal struct Searcher
        {
            private readonly Core Core;
            private readonly List<int> ExpansionBuffer;
            private readonly VisitedBitSet VisitedSet;

            /// <summary>
            /// Initializes a new instance of the <see cref="Searcher"/> struct; the visited set is sized
            /// for the node count of <paramref name="core"/> at this moment.
            /// </summary>
            /// <param name="core">The core of the graph.</param>
            internal Searcher(Core core)
            {
                Core = core;
                ExpansionBuffer = new List<int>();
                VisitedSet = new VisitedBitSet(core.Nodes.Count);
            }

            /// <summary>
            /// The implementation of SEARCH-LAYER(q, ep, ef, lc) algorithm.
            /// Article: Section 4. Algorithm 2.
            /// </summary>
            /// <param name="entryPointId">The identifier of the entry point for the search.</param>
            /// <param name="targetCosts">The traveling costs for the search target.</param>
            /// <param name="resultList">Receives the identifiers of the nearest neighbours at the layer (heap-ordered, farthest on top).</param>
            /// <param name="layer">The layer to perform search at.</param>
            /// <param name="k">The number of the nearest neighbours to get from the layer.</param>
            /// <param name="version">The current version of the graph.</param>
            /// <param name="versionAtStart">The graph version when the search started; a mismatch raises a graph-changed exception.</param>
            /// <param name="keepResult">Filter deciding which visited nodes may enter <paramref name="resultList"/>.</param>
            /// <param name="cancellationToken">Stops the search early, returning the count so far.</param>
            /// <returns>The number of expanded nodes during the run.</returns>
            internal int RunKnnAtLayer(int entryPointId, TravelingCosts<int, TDistance> targetCosts, List<int> resultList, int layer, int k, ref long version, long versionAtStart, Func<int, bool> keepResult, CancellationToken cancellationToken = default)
            {
                /*
                 * v ← ep // set of visited elements
                 * C ← ep // set of candidates
                 * W ← ep // dynamic list of found nearest neighbors
                 * while │C│ > 0
                 *   c ← extract nearest element from C to q
                 *   f ← get furthest element from W to q
                 *   if distance(c, q) > distance(f, q)
                 *     break // all elements in W are evaluated
                 *   for each e ∈ neighbourhood(c) at layer lc // update C and W
                 *     if e ∉ v
                 *       v ← v ⋃ e
                 *       f ← get furthest element from W to q
                 *       if distance(e, q) < distance(f, q) or │W│ < ef
                 *         C ← C ⋃ e
                 *         W ← W ⋃ e
                 *         if │W│ > ef
                 *           remove furthest element from W to q
                 * return W
                 */

                // prepare tools
                IComparer<int> fartherIsOnTop = targetCosts;
                IComparer<int> closerIsOnTop = fartherIsOnTop.Reverse();

                // prepare collections
                // TODO: Optimize by providing buffers
                var resultHeap = new BinaryHeap<int>(resultList, fartherIsOnTop);
                var expansionHeap = new BinaryHeap<int>(ExpansionBuffer, closerIsOnTop);

                try
                {
                    if (keepResult(entryPointId))
                    {
                        resultHeap.Push(entryPointId);
                    }

                    expansionHeap.Push(entryPointId);
                    VisitedSet.Add(entryPointId);

                    // run bfs
                    int visitedNodesCount = 1;
                    while (expansionHeap.Buffer.Count > 0)
                    {
                        if (cancellationToken.IsCancellationRequested)
                        {
                            return visitedNodesCount;
                        }

                        GraphChangedException.ThrowIfChanged(ref version, versionAtStart);

                        // get next candidate to check and expand
                        var toExpandId = expansionHeap.Pop();
                        var farthestResultId = resultHeap.Buffer.Count > 0 ? resultHeap.Buffer[0] : -1;

                        // -1 means "no result yet"; 0 is a valid node id and must take part in the check.
                        if (farthestResultId >= 0 && DistanceUtils.GreaterThan(targetCosts.From(toExpandId), targetCosts.From(farthestResultId)))
                        {
                            // the closest candidate is farther than farthest result
                            break;
                        }

                        // expand candidate
                        var neighboursIds = Core.Nodes[toExpandId][layer];

                        for (int i = 0; i < neighboursIds.Count; ++i)
                        {
                            if (cancellationToken.IsCancellationRequested)
                            {
                                return visitedNodesCount;
                            }

                            int neighbourId = neighboursIds[i];

                            if (!VisitedSet.Contains(neighbourId))
                            {
                                // enqueue perspective neighbours to expansion list
                                farthestResultId = resultHeap.Buffer.Count > 0 ? resultHeap.Buffer[0] : -1;
                                if (resultHeap.Buffer.Count < k || (farthestResultId >= 0 && DistanceUtils.LowerThan(targetCosts.From(neighbourId), targetCosts.From(farthestResultId))))
                                {
                                    expansionHeap.Push(neighbourId);

                                    if (keepResult(neighbourId))
                                    {
                                        resultHeap.Push(neighbourId);
                                    }

                                    if (resultHeap.Buffer.Count > k)
                                    {
                                        resultHeap.Pop();
                                    }
                                }

                                // update visited list
                                ++visitedNodesCount;
                                VisitedSet.Add(neighbourId);
                            }
                        }
                    }

                    return visitedNodesCount;
                }
                catch (Exception)
                {
                    // Throws if the collection changed, otherwise propagates the original exception
                    GraphChangedException.ThrowIfChanged(ref version, versionAtStart);
                    throw;
                }
                finally
                {
                    // The buffers are reused by later calls, so reset them on every exit path
                    // (normal completion, break, cancellation and exceptions).
                    ExpansionBuffer.Clear();
                    VisitedSet.Clear();
                }
            }
        }
    }
}
164 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/Graph.Utils.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
6 | namespace HNSW.Net
7 | {
8 | using System;
9 | using System.Collections.Generic;
10 | using System.Linq;
11 |
12 | internal partial class Graph
13 | {
14 | ///
15 | /// Runs breadth first search.
16 | ///
17 | /// The graph core.
18 | /// The entry point.
19 | /// The layer of the graph where to run BFS.
20 | /// The action to perform on each node.
21 | internal static void BFS(Core core, Node entryPoint, int layer, Action visitAction)
22 | {
23 | var visitedIds = new HashSet();
24 | var expansionQueue = new Queue(new[] { entryPoint.Id });
25 |
26 | while (expansionQueue.Any())
27 | {
28 | var currentNode = core.Nodes[expansionQueue.Dequeue()];
29 | if (!visitedIds.Contains(currentNode.Id))
30 | {
31 | visitAction(currentNode);
32 | visitedIds.Add(currentNode.Id);
33 | foreach (var neighbourId in currentNode[layer])
34 | {
35 | expansionQueue.Enqueue(neighbourId);
36 | }
37 | }
38 | }
39 | }
40 |
41 | internal class VisitedBitSet
42 | {
43 | private int[] Buffer;
44 |
45 | internal VisitedBitSet(int nodesCount)
46 | {
47 | Buffer = new int[(nodesCount >> 5) + 1];
48 | }
49 |
50 | internal bool Contains(int nodeId)
51 | {
52 | int carrier = Buffer[nodeId >> 5];
53 | return ((1 << (nodeId & 31)) & carrier) != 0;
54 | }
55 |
56 | internal void Add(int nodeId)
57 | {
58 | int mask = 1 << (nodeId & 31);
59 | Buffer[nodeId >> 5] |= mask;
60 | }
61 |
62 | internal void Clear()
63 | {
64 | Array.Clear(Buffer, 0, Buffer.Length);
65 | }
66 | }
67 | }
68 | }
--------------------------------------------------------------------------------
/Src/HNSW.Net/Graph.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
6 | namespace HNSW.Net
7 | {
8 | using System;
9 | using System.Collections.Generic;
10 | using System.IO;
11 | using System.Linq;
12 | using MessagePack;
13 | using System.Text;
14 |
15 | using static HNSW.Net.EventSources;
16 | using System.Threading;
17 |
18 | ///
19 | /// The implementation of a hierarchical small world graph.
20 | ///
21 | /// The type of items to connect into small world.
22 | /// The type of distance between items (expects any numeric type: float, double, decimal, int, ...).
23 | internal partial class Graph where TDistance : struct, IComparable
24 | {
        // Distance function between two items, as passed to the constructor.
        private readonly Func Distance;

        // Nodes, items and distance cache of the graph. Null until the first non-empty AddItems call
        // creates it, or until Deserialize assigns a deserialized core.
        internal Core GraphCore;

        // Node where searches start: initially the first node, then replaced whenever a node with a higher
        // top layer is inserted. Null while the graph is empty.
        private Node? EntryPoint;

        // Modification counter. AddItems increments it (via Interlocked) before inserting each node and before
        // every Connect. Searches compare it with the value read when they started, to detect concurrent changes.
        private long _version;
32 |
33 | ///
34 | /// Initializes a new instance of the class.
35 | ///
36 | /// The distance function.
37 | /// The parameters of the world.
38 | internal Graph(Func distance, SmallWorld.Parameters parameters)
39 | {
40 | Distance = distance;
41 | Parameters = parameters;
42 | }
43 |
44 | internal SmallWorld.Parameters Parameters { get; }
45 |
        /// <summary>
        /// Inserts the given items into the graph, creating the graph core on first use.
        /// Contains implementation of INSERT(hnsw, q, M, Mmax, efConstruction, mL) algorithm.
        /// Article: Section 4. Algorithm 1.
        /// </summary>
        /// <param name="items">The items to insert.</param>
        /// <param name="generator">The random number generator to distribute nodes across layers.</param>
        /// <param name="progressReporter">Interface to report progress; may be null.</param>
        /// <returns>The ids returned by the core for the new items, or an empty list when <paramref name="items"/> is null or empty.</returns>
        internal IReadOnlyList AddItems(IReadOnlyList items, IProvideRandomValues generator, IProgressReporter progressReporter)
        {
            // Nothing to insert: return without touching (or creating) the graph core.
            if (items is null || !items.Any()) { return Array.Empty(); }

            // Lazily create the core on the first insertion.
            GraphCore = GraphCore ?? new Core(Distance, Parameters);

            // Nodes created by this call occupy indices [startIndex, GraphCore.Nodes.Count).
            int startIndex = GraphCore.Items.Count;

            var newIDs = GraphCore.AddItems(items, generator);

            // For a graph without an entry point yet, node 0 becomes the initial entry point.
            var entryPoint = EntryPoint ?? GraphCore.Nodes[0];

            var searcher = new Searcher(GraphCore);
            Func nodeDistance = GraphCore.GetDistance;
            // Reused for every layer search; cleared after each use.
            var neighboursIdsBuffer = new List(GraphCore.Algorithm.GetM(0) + 1);

            for (int nodeId = startIndex; nodeId < GraphCore.Nodes.Count; ++nodeId)
            {
                // Bump the version so concurrent searches notice that the graph is changing.
                var versionNow = Interlocked.Increment(ref _version);

                using (new ScopeLatencyTracker(GraphBuildEventSource.Instance?.GraphInsertNodeLatencyReporter))
                {
                    /*
                     * W ← ∅ // list for the currently found nearest elements
                     * ep ← get enter point for hnsw
                     * L ← level of ep // top layer for hnsw
                     * l ← ⌊-ln(unif(0..1))∙mL⌋ // new element’s level
                     * for lc ← L … l+1
                     * W ← SEARCH-LAYER(q, ep, ef=1, lc)
                     * ep ← get the nearest element from W to q
                     * for lc ← min(L, l) … 0
                     * W ← SEARCH-LAYER(q, ep, efConstruction, lc)
                     * neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4
                     * for each e ∈ neighbors // shrink connections if needed
                     * eConn ← neighbourhood(e) at layer lc
                     * if │eConn│ > Mmax // shrink connections of e if lc = 0 then Mmax = Mmax0
                     * eNewConn ← SELECT-NEIGHBORS(e, eConn, Mmax, lc) // alg. 3 or alg. 4
                     * set neighbourhood(e) at layer lc to eNewConn
                     * ep ← W
                     * if l > L
                     * set enter point for hnsw to q
                     */

                    // zoom in and find the best peer on the same level as newNode
                    // (greedy search with ef = 1 on every layer above the new node's top layer)
                    var bestPeer = entryPoint;
                    var currentNode = GraphCore.Nodes[nodeId];
                    var currentNodeTravelingCosts = new TravelingCosts(nodeDistance, nodeId);
                    for (int layer = bestPeer.MaxLayer; layer > currentNode.MaxLayer; --layer)
                    {
                        searcher.RunKnnAtLayer(bestPeer.Id, currentNodeTravelingCosts, neighboursIdsBuffer, layer, 1, ref _version, versionNow, _ => true);
                        bestPeer = GraphCore.Nodes[neighboursIdsBuffer[0]];
                        neighboursIdsBuffer.Clear();
                    }

                    // connecting new node to the small world
                    for (int layer = Math.Min(currentNode.MaxLayer, entryPoint.MaxLayer); layer >= 0; --layer)
                    {
                        searcher.RunKnnAtLayer(bestPeer.Id, currentNodeTravelingCosts, neighboursIdsBuffer, layer, Parameters.ConstructionPruning, ref _version, versionNow, _ => true);
                        var bestNeighboursIds = GraphCore.Algorithm.SelectBestForConnecting(neighboursIdsBuffer, currentNodeTravelingCosts, layer);

                        for (int i = 0; i < bestNeighboursIds.Count; ++i)
                        {
                            // Link in both directions, bumping the version before each mutation. The updated
                            // versionNow is what this node's next RunKnnAtLayer call compares against, so the
                            // builder's own changes do not abort its own searches.
                            int newNeighbourId = bestNeighboursIds[i];
                            versionNow = Interlocked.Increment(ref _version);
                            GraphCore.Algorithm.Connect(currentNode, GraphCore.Nodes[newNeighbourId], layer);

                            versionNow = Interlocked.Increment(ref _version);
                            GraphCore.Algorithm.Connect(GraphCore.Nodes[newNeighbourId], currentNode, layer);

                            // if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer
                            if (DistanceUtils.LowerThan(currentNodeTravelingCosts.From(newNeighbourId), currentNodeTravelingCosts.From(bestPeer.Id)))
                            {
                                bestPeer = GraphCore.Nodes[newNeighbourId];
                            }
                        }

                        neighboursIdsBuffer.Clear();
                    }

                    // zoom out to the highest level
                    if (currentNode.MaxLayer > entryPoint.MaxLayer)
                    {
                        entryPoint = currentNode;
                    }

                    // report distance cache hit rate
                    GraphBuildEventSource.Instance?.CoreGetDistanceCacheHitRateReporter?.Invoke(GraphCore.DistanceCacheHitRate);
                }
                // NOTE(review): "current" is the 0-based offset of the node just inserted, so the last call is
                // (total - 1, total); confirm whether IProgressReporter consumers expect a completed count instead.
                progressReporter?.Progress(nodeId - startIndex, GraphCore.Nodes.Count - startIndex);
            }

            // construction is done
            EntryPoint = entryPoint;

            return newIDs;
        }
150 |
        /// <summary>
        /// Get k nearest items for a given one.
        /// Contains implementation of K-NN-SEARCH(hnsw, q, K, ef) algorithm.
        /// Article: Section 4. Algorithm 5.
        /// </summary>
        /// <param name="destination">The given item to get the nearest neighbourhood for.</param>
        /// <param name="k">The size of the neighbourhood.</param>
        /// <param name="filterItem">Optional predicate invoked with candidate items (not ids); return true to keep an item in the results, false to exclude it.</param>
        /// <param name="cancellationToken">Cancellation token for stopping the search early; a possibly incomplete result list is then returned.</param>
        /// <returns>The list of the nearest neighbours, or null when the graph has no entry point (empty graph).</returns>
        internal IList.KNNSearchResult> KNearest(TItem destination, int k, Func filterItem = null, CancellationToken cancellationToken = default)
        {
            // Empty graph: there is no entry point to start from.
            if (EntryPoint is null) return null;

            // Without a filter, every visited node is an acceptable result.
            Func keepResultInner = _ => true;

            if (filterItem is object)
            {
                // Memoize the filter per node id so the user callback runs at most once per item
                // for the whole call (the cache survives retries of the loop below).
                var keepResults = new Dictionary();
                keepResultInner = (id) =>
                {
                    if (keepResults.TryGetValue(id, out var v)) return v;
                    v = filterItem(GraphCore.Items[id]);
                    keepResults[id] = v;
                    return v;
                };
            }

            // Maximum number of restarts when the graph changes while the search is running.
            int retries = 1_024;

            // TODO: hack we know that destination id is -1.
            // The query item has no node id and is represented as -1, so the non-negative argument is the stored node.
            TDistance RuntimeDistance(int x, int y)
            {
                int nodeId = x >= 0 ? x : y;
                return Distance(destination, GraphCore.Items[nodeId]);
            }

            while (true)
            {
                // Snapshot the version; the searcher throws GraphChangedException when it changes mid-search.
                var versionNow = Interlocked.Read(ref _version);

                try
                {
                    using (new ScopeLatencyTracker(GraphSearchEventSource.Instance?.GraphKNearestLatencyReporter))
                    {
                        var bestPeer = EntryPoint.Value;
                        var searcher = new Searcher(GraphCore);
                        var destinationTravelingCosts = new TravelingCosts(RuntimeDistance, -1);
                        var resultIds = new List(k + 1);

                        int visitedNodesCount = 0;

                        // Greedy descent through the upper layers (ef = 1) to find the starting node for layer 0.
                        // NOTE(review): keepResultInner is applied on these layers too, so with a selective filter
                        // resultIds can come back empty and bestPeer then stays unchanged; confirm this is intended.
                        for (int layer = EntryPoint.Value.MaxLayer; layer > 0; --layer)
                        {
                            visitedNodesCount += searcher.RunKnnAtLayer(bestPeer.Id, destinationTravelingCosts, resultIds, layer, 1, ref _version, versionNow, keepResultInner, cancellationToken);

                            if (cancellationToken.IsCancellationRequested)
                            {
                                //Return best so far - TODO: Investigate if this assumption is correct
                                return resultIds.Select(id => new SmallWorld.KNNSearchResult(id, GraphCore.Items[id], RuntimeDistance(id, -1))).ToList();
                            }

                            if (resultIds.Count > 0)
                            {
                                bestPeer = GraphCore.Nodes[resultIds[0]];
                            }

                            resultIds.Clear();
                        }

                        // Full search on the bottom layer with ef = k.
                        visitedNodesCount += searcher.RunKnnAtLayer(bestPeer.Id, destinationTravelingCosts, resultIds, 0, k, ref _version, versionNow, keepResultInner, cancellationToken);

                        GraphSearchEventSource.Instance?.GraphKNearestVisitedNodesReporter?.Invoke(visitedNodesCount);

                        return resultIds.Select(id => new SmallWorld.KNNSearchResult(id, GraphCore.Items[id], RuntimeDistance(id, -1))).ToList();
                    }
                }
                catch (GraphChangedException)
                {
                    // Concurrent modification detected: restart the search while the retry budget lasts.
                    if(retries > 0)
                    {
                        retries--;
                        continue;
                    }
                    throw;
                }
                catch(Exception)
                {
                    // Other failures are retried only if the graph changed in the meantime, since they may come
                    // from reading a partially updated structure; otherwise they are rethrown unchanged.
                    if (versionNow != Interlocked.Read(ref _version))
                    {
                        if (retries > 0)
                        {
                            retries--;
                            continue;
                        }
                    }
                    throw;
                }
            }
        }
251 |
252 | ///
253 | /// Serializes core of the graph.
254 | ///
255 | /// Bytes representing edges.
256 | internal void Serialize(Stream stream)
257 | {
258 | GraphCore.Serialize(stream);
259 | MessagePackSerializer.Serialize(stream, EntryPoint);
260 | }
261 |
262 | ///
263 | /// Deserilaizes graph edges and assigns nodes to the items.
264 | ///
265 | /// The underlying items.
266 | /// The serialized edges.
267 | internal TItem[] Deserialize(IReadOnlyList items, Stream stream)
268 | {
269 | // readStrict: true -> removed, as not available anymore on MessagePack 2.0 - also probably not necessary anymore
270 | // see https://github.com/neuecc/MessagePack-CSharp/pull/663
271 |
272 | var core = new Core(Distance, Parameters);
273 | var remainingItems = core.Deserialize(items, stream);
274 | EntryPoint = MessagePackSerializer.Deserialize(stream);
275 | GraphCore = core;
276 | return remainingItems;
277 | }
278 |
279 | ///
280 | /// Prints edges of the graph.
281 | ///
282 | /// String representation of the graph's edges.
283 | internal string Print()
284 | {
285 | var buffer = new StringBuilder();
286 | for (int layer = EntryPoint.Value.MaxLayer; layer >= 0; --layer)
287 | {
288 | buffer.AppendLine($"[LEVEL {layer}]");
289 | BFS(GraphCore, EntryPoint.Value, layer, (node) =>
290 | {
291 | var neighbours = string.Join(", ", node[layer]);
292 | buffer.AppendLine($"({node.Id}) -> {{{neighbours}}}");
293 | });
294 |
295 | buffer.AppendLine();
296 | }
297 |
298 | return buffer.ToString();
299 | }
300 | }
301 | }
302 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/GraphChangedException.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
6 | using System;
7 | using System.Runtime.Serialization;
8 | using System.Threading;
9 |
10 | namespace HNSW.Net
11 | {
12 | [Serializable]
13 | internal class GraphChangedException : Exception
14 | {
15 | public GraphChangedException()
16 | {
17 | }
18 |
19 | public GraphChangedException(string message) : base(message)
20 | {
21 | }
22 |
23 | public GraphChangedException(string message, Exception innerException) : base(message, innerException)
24 | {
25 | }
26 |
27 | protected GraphChangedException(SerializationInfo info, StreamingContext context) : base(info, context)
28 | {
29 | }
30 |
31 | internal static void ThrowIfChanged(ref long version, long versionAtStart)
32 | {
33 | if (Interlocked.Read(ref version) != versionAtStart)
34 | {
35 | throw new GraphChangedException();
36 | }
37 | }
38 | }
39 | }
--------------------------------------------------------------------------------
/Src/HNSW.Net/HNSW.Net.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | netstandard2.0;netstandard2.1;net7.0;net8.0
6 | AnyCPU
7 | Curiosity GmbH
8 | HNSW
9 | Curiosity GmbH and original HSNW.Net authors
10 | HNSW
11 | C# library for fast approximate nearest neighbours search using Hierarchical Navigable Small World graphs. Fork from original at https://github.com/microsoft/HNSW.Net, adds support for incremental build and MessagePack serialization.
12 | MIT
13 | (c) Copyright 2019 Curiosity GmbH, Copyright (c) Microsoft Corporation
14 | true
15 | https://github.com/curiosity-ai/hnsw.net
16 | https://github.com/curiosity-ai/hnsw.net
17 | kNN, ANN, approximate nearest neighbor
18 | false
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/IProgressReporter.cs:
--------------------------------------------------------------------------------
1 | namespace HNSW.Net
2 | {
3 | public interface IProgressReporter
4 | {
5 | void Progress(int current, int total);
6 | }
7 | }
8 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/IProvideRandomValues.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Runtime.CompilerServices;
3 |
4 | namespace HNSW.Net
5 | {
6 | public interface IProvideRandomValues
7 | {
8 | bool IsThreadSafe { get; }
9 |
10 | ///
11 | /// Generates a random float. Values returned are from 0.0 up to but not including 1.0.
12 | ///
13 | [MethodImpl(MethodImplOptions.AggressiveInlining)]
14 | float NextFloat();
15 |
16 | ///
17 | /// Fills the elements of a specified array of bytes with random numbers.
18 | ///
19 | /// An array of bytes to contain random numbers.
20 | [MethodImpl(MethodImplOptions.AggressiveInlining)]
21 | void NextFloats(Span buffer);
22 |
23 | ///
24 | /// Returns a random integer that is within a specified range.
25 | ///
26 | /// The inclusive lower bound of the random number returned.
27 | /// The exclusive upper bound of the random number returned. maxValue must be greater than or equal to minValue.
28 | /// A 32-bit signed integer greater than or equal to minValue and less than maxValue; that is, the range of return values includes minValue but not maxValue. If minValue
29 | // equals maxValue, minValue is returned.
30 | [MethodImpl(MethodImplOptions.AggressiveInlining)]
31 | int Next(int minValue, int maxValue);
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/MessagePackCompat/FloatBits.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Runtime.InteropServices;
3 |
4 | namespace MessagePackCompat
5 | {
6 | // safe accessor of Single/Double's underlying byte.
7 | // This code is borrowed from MsgPack-Cli https://github.com/msgpack/msgpack-cli
8 |
9 | [StructLayout(LayoutKind.Explicit)]
10 | internal struct Float32Bits
11 | {
12 | [FieldOffset(0)]
13 | public readonly float Value;
14 |
15 | [FieldOffset(0)]
16 | public readonly Byte Byte0;
17 |
18 | [FieldOffset(1)]
19 | public readonly Byte Byte1;
20 |
21 | [FieldOffset(2)]
22 | public readonly Byte Byte2;
23 |
24 | [FieldOffset(3)]
25 | public readonly Byte Byte3;
26 |
27 | public Float32Bits(float value)
28 | {
29 | this = default(Float32Bits);
30 | this.Value = value;
31 | }
32 |
33 | public Float32Bits(byte[] bigEndianBytes, int offset)
34 | {
35 | this = default(Float32Bits);
36 |
37 | if (BitConverter.IsLittleEndian)
38 | {
39 | this.Byte0 = bigEndianBytes[offset + 3];
40 | this.Byte1 = bigEndianBytes[offset + 2];
41 | this.Byte2 = bigEndianBytes[offset + 1];
42 | this.Byte3 = bigEndianBytes[offset];
43 | }
44 | else
45 | {
46 | this.Byte0 = bigEndianBytes[offset];
47 | this.Byte1 = bigEndianBytes[offset + 1];
48 | this.Byte2 = bigEndianBytes[offset + 2];
49 | this.Byte3 = bigEndianBytes[offset + 3];
50 | }
51 | }
52 | }
53 |
54 | [StructLayout(LayoutKind.Explicit)]
55 | internal struct Float64Bits
56 | {
57 | [FieldOffset(0)]
58 | public readonly double Value;
59 |
60 | [FieldOffset(0)]
61 | public readonly Byte Byte0;
62 |
63 | [FieldOffset(1)]
64 | public readonly Byte Byte1;
65 |
66 | [FieldOffset(2)]
67 | public readonly Byte Byte2;
68 |
69 | [FieldOffset(3)]
70 | public readonly Byte Byte3;
71 |
72 | [FieldOffset(4)]
73 | public readonly Byte Byte4;
74 |
75 | [FieldOffset(5)]
76 | public readonly Byte Byte5;
77 |
78 | [FieldOffset(6)]
79 | public readonly Byte Byte6;
80 |
81 | [FieldOffset(7)]
82 | public readonly Byte Byte7;
83 |
84 | public Float64Bits(double value)
85 | {
86 | this = default(Float64Bits);
87 | this.Value = value;
88 | }
89 |
90 | public Float64Bits(byte[] bigEndianBytes, int offset)
91 | {
92 | this = default(Float64Bits);
93 |
94 | if (BitConverter.IsLittleEndian)
95 | {
96 | this.Byte0 = bigEndianBytes[offset + 7];
97 | this.Byte1 = bigEndianBytes[offset + 6];
98 | this.Byte2 = bigEndianBytes[offset + 5];
99 | this.Byte3 = bigEndianBytes[offset + 4];
100 | this.Byte4 = bigEndianBytes[offset + 3];
101 | this.Byte5 = bigEndianBytes[offset + 2];
102 | this.Byte6 = bigEndianBytes[offset + 1];
103 | this.Byte7 = bigEndianBytes[offset];
104 | }
105 | else
106 | {
107 | this.Byte0 = bigEndianBytes[offset];
108 | this.Byte1 = bigEndianBytes[offset + 1];
109 | this.Byte2 = bigEndianBytes[offset + 2];
110 | this.Byte3 = bigEndianBytes[offset + 3];
111 | this.Byte4 = bigEndianBytes[offset + 4];
112 | this.Byte5 = bigEndianBytes[offset + 5];
113 | this.Byte6 = bigEndianBytes[offset + 6];
114 | this.Byte7 = bigEndianBytes[offset + 7];
115 | }
116 | }
117 | }
118 | }
--------------------------------------------------------------------------------
/Src/HNSW.Net/MessagePackCompat/README.md:
--------------------------------------------------------------------------------
1 | These files are included here because they were removed from the MessagePack 2.0 API, but this library still depends on them.
2 | Once the code is refactored to use the new MessagePackWriter/MessagePackReader API, these files can be removed.
3 | See: https://github.com/neuecc/MessagePack-CSharp/blob/master/doc/migration.md
4 |
5 |
6 |
7 |
8 | MessagePack for C#
9 |
10 | MIT License
11 |
12 | Copyright (c) 2017 Yoshifumi Kawai and contributors
13 |
14 | Permission is hereby granted, free of charge, to any person obtaining a copy
15 | of this software and associated documentation files (the "Software"), to deal
16 | in the Software without restriction, including without limitation the rights
17 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18 | copies of the Software, and to permit persons to whom the Software is
19 | furnished to do so, subject to the following conditions:
20 |
21 | The above copyright notice and this permission notice shall be included in all
22 | copies or substantial portions of the Software.
23 |
24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30 | SOFTWARE.
31 |
32 | ---
33 |
34 | lz4net
35 |
36 | Copyright (c) 2013-2017, Milosz Krajewski
37 |
38 | All rights reserved.
39 |
40 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
41 |
42 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
43 |
44 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
45 |
46 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/Src/HNSW.Net/MessagePackCompat/StringEncoding.cs:
--------------------------------------------------------------------------------
1 | using System.Text;
2 |
3 | namespace MessagePackCompat
4 | {
5 | internal static class StringEncoding
6 | {
7 | public static readonly Encoding UTF8 = new UTF8Encoding(false);
8 | }
9 | }
--------------------------------------------------------------------------------
/Src/HNSW.Net/Node.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
6 | namespace HNSW.Net
7 | {
8 | using MessagePack;
9 | using System.Collections.Generic;
10 |
11 | ///
12 | /// The implementation of the node in hnsw graph.
13 | ///
14 | [MessagePackObject]
15 | public struct Node
16 | {
17 | [Key(0)]
18 | public List> Connections;
19 |
20 | [Key(1)] public int Id;
21 |
22 | ///
23 | /// Gets the max layer where the node is presented.
24 | ///
25 | [IgnoreMember]
26 | public int MaxLayer
27 | {
28 | get
29 | {
30 | return Connections.Count - 1;
31 | }
32 | }
33 |
34 | ///
35 | /// Gets connections ids of the node at the given layer
36 | ///
37 | /// The layer to get connections at.
38 | /// The connections of the node at the given layer.
39 | public List this[int layer]
40 | {
41 | get
42 | {
43 | return Connections[layer];
44 | }
45 | set
46 | {
47 | Connections[layer] = value;
48 | }
49 | }
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/Properties/AssemblyInfo.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
6 | using System.Runtime.CompilerServices;
7 |
// Lets the HNSW.Net.Tests unit-test assembly access this library's internal types and members.
[assembly: InternalsVisibleTo("HNSW.Net.Tests")]
--------------------------------------------------------------------------------
/Src/HNSW.Net/ReverseComparer.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
namespace HNSW.Net
{
    using System.Collections.Generic;
    using System.Diagnostics.CodeAnalysis;

    /// <summary>
    /// An <see cref="IComparer{T}"/> that inverts the ordering produced by another comparer.
    /// </summary>
    /// <typeparam name="T">The type of items to compare.</typeparam>
    public class ReverseComparer<T> : IComparer<T>
    {
        private readonly IComparer<T> inner;

        /// <summary>
        /// Initializes a new instance of the <see cref="ReverseComparer{T}"/> class.
        /// </summary>
        /// <param name="comparer">The comparer whose ordering is inverted.</param>
        public ReverseComparer(IComparer<T> comparer)
        {
            inner = comparer;
        }

        /// <inheritdoc />
        public int Compare(T x, T y) => inner.Compare(y, x);
    }

    /// <summary>
    /// Extension methods for building reversed comparers.
    /// </summary>
    [SuppressMessage("StyleCop.CSharp.MaintainabilityRules", "SA1402:File may only contain a single type", Justification = "By Design")]
    [SuppressMessage("StyleCop.CSharp.OrderingRules", "SA1204:Static elements must appear before instance elements", Justification = "By Design")]
    public static class ReverseComparerExtensions
    {
        /// <summary>
        /// Wraps <paramref name="comparer"/> in a comparer with the opposite ordering.
        /// </summary>
        /// <typeparam name="T">The type of items to compare.</typeparam>
        /// <param name="comparer">The source comparer.</param>
        /// <returns>A comparer that inverts the source comparer.</returns>
        public static ReverseComparer<T> Reverse<T>(this IComparer<T> comparer) => new ReverseComparer<T>(comparer);
    }
}
54 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/ScopeLatencyTracker.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
6 | namespace HNSW.Net
7 | {
8 | using System;
9 | using System.Diagnostics;
10 |
11 | ///
12 | /// Latency tracker for using scope.
13 | /// TODO: make it ref struct in C# 8.0
14 | ///
15 | internal struct ScopeLatencyTracker : IDisposable
16 | {
17 | private long StartTimestamp;
18 | private Action LatencyCallback;
19 |
20 | ///
21 | /// Initializes a new instance of the struct.
22 | ///
23 | /// The latency reporting callback to associate with the scope.
24 | internal ScopeLatencyTracker(Action callback)
25 | {
26 | StartTimestamp = callback != null ? Stopwatch.GetTimestamp() : 0;
27 | LatencyCallback = callback;
28 | }
29 |
30 | ///
31 | /// Reports the time ellsapsed between the tracker creation and this call.
32 | ///
33 | public void Dispose()
34 | {
35 | const long ticksPerMicroSecond = TimeSpan.TicksPerMillisecond / 1000;
36 | if (LatencyCallback != null)
37 | {
38 | long ellapsedMuS = (Stopwatch.GetTimestamp() - StartTimestamp) / ticksPerMicroSecond;
39 | LatencyCallback(ellapsedMuS);
40 | }
41 | }
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/SmallWorld.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
6 | namespace HNSW.Net
7 | {
8 | using System;
9 | using System.Collections.Generic;
10 | using System.Diagnostics;
11 | using System.Diagnostics.CodeAnalysis;
12 | using System.IO;
13 | using System.Linq;
14 | using System.Threading;
15 | using MessagePack;
16 | using MessagePackCompat;
17 |
18 | ///
19 | /// The Hierarchical Navigable Small World Graphs. https://arxiv.org/abs/1603.09320
20 | ///
21 | /// The type of items to connect into small world.
22 | /// The type of distance between items (expect any numeric type: float, double, decimal, int, ...).
23 | public partial class SmallWorld where TDistance : struct, IComparable
24 | {
        // Marker string written at the start of a serialized graph and verified when deserializing.
        private const string SERIALIZATION_HEADER = "HNSW";
        // Distance function between two items, forwarded to the underlying graph.
        private readonly Func Distance;

        // The underlying graph holding the items and their edges.
        private Graph Graph;
        // Random source passed to the graph whenever items are inserted.
        private IProvideRandomValues Generator;

        // Guards Graph against concurrent readers and writers; null when the world was created with threadSafe: false.
        private ReaderWriterLockSlim _rwLock;

        ///
        /// Gets the list of items currently held by the SmallWorld graph, or null while the graph has no core yet
        /// (nothing added and nothing deserialized).
        /// The list is not protected by any locks, and should only be used when it is known the graph won't change.
        ///
        public IReadOnlyList UnsafeItems => Graph?.GraphCore?.Items;
38 |
39 | ///
40 | /// Gets a copy of the list of items currently held by the SmallWorld graph.
41 | /// This call is protected by a read-lock and is safe to be called from multiple threads.
42 | ///
43 | public IReadOnlyList Items
44 | {
45 | get
46 | {
47 | if (_rwLock is object)
48 | {
49 | _rwLock.EnterReadLock();
50 | try
51 | {
52 | return Graph.GraphCore.Items.ToList();
53 | }
54 | finally
55 | {
56 | _rwLock.ExitReadLock();
57 | }
58 | }
59 | else
60 | {
61 | return Graph?.GraphCore?.Items;
62 | }
63 | }
64 | }
65 |
66 |
67 | ///
68 | /// Initializes a new instance of the class.
69 | ///
70 | /// The distance function to use in the small world.
71 | /// The random number generator for building graph.
72 | /// Parameters of the algorithm.
73 | public SmallWorld(Func distance, IProvideRandomValues generator, Parameters parameters, bool threadSafe = true)
74 | {
75 | Distance = distance;
76 | Graph = new Graph(Distance, parameters);
77 | Generator = generator;
78 | _rwLock = threadSafe ? new ReaderWriterLockSlim() : null;
79 | }
80 |
81 | ///
82 | /// Builds hnsw graph from the items.
83 | ///
84 | /// The items to connect into the graph.
85 |
86 | public IReadOnlyList AddItems(IReadOnlyList items, IProgressReporter progressReporter = null)
87 | {
88 | _rwLock?.EnterWriteLock();
89 | try
90 | {
91 | return Graph.AddItems(items, Generator, progressReporter);
92 | }
93 | finally
94 | {
95 | _rwLock?.ExitWriteLock();
96 | }
97 | }
98 |
99 | ///
100 | /// Run knn search for a given item.
101 | ///
102 | /// The item to search nearest neighbours.
103 | /// The number of nearest neighbours.
104 | /// Filter results by ID that should be kept (return true to keep, false to exclude from results)
105 | /// Cancellation Token for stopping the search when filtering is active
106 | /// The list of found nearest neighbours.
107 | public IList KNNSearch(TItem item, int k, Func filterItem = null, CancellationToken cancellationToken = default)
108 | {
109 | _rwLock?.EnterReadLock();
110 | try
111 | {
112 | return Graph.KNearest(item, k, filterItem, cancellationToken);
113 | }
114 | finally
115 | {
116 | _rwLock?.ExitReadLock();
117 | }
118 | }
119 |
120 | ///
121 | /// Get the item with the index
122 | ///
123 | /// The index of the item
124 | public TItem GetItem(int index)
125 | {
126 | _rwLock?.EnterReadLock();
127 | try
128 | {
129 | return Items[index];
130 | }
131 | finally
132 | {
133 | _rwLock?.ExitReadLock();
134 | }
135 | }
136 |
137 | ///
138 | /// Serializes the graph WITHOUT linked items.
139 | ///
140 | /// Bytes representing the graph.
141 | public void SerializeGraph(Stream stream)
142 | {
143 | if (Graph == null)
144 | {
145 | throw new InvalidOperationException("The graph does not exist");
146 | }
147 | _rwLock?.EnterReadLock();
148 | try
149 | {
150 | MessagePackBinary.WriteString(stream, SERIALIZATION_HEADER);
151 | MessagePackSerializer.Serialize(stream, Graph.Parameters);
152 | Graph.Serialize(stream);
153 | }
154 | finally
155 | {
156 | _rwLock?.ExitReadLock();
157 | }
158 | }
159 |
/// <summary>
/// Deserializes a graph from a stream. The stream must contain the header, parameters and edges
/// as written by <see cref="SerializeGraph"/>.
/// </summary>
/// <param name="items">The items to assign to the graph's vertices.</param>
/// <param name="distance">The distance function used by the restored world.</param>
/// <param name="generator">The random value provider used by the restored world.</param>
/// <param name="stream">The stream holding the serialized header, parameters and edges.</param>
/// <param name="threadSafe">Whether the restored world is created as thread-safe.</param>
/// <returns>
/// The restored world, plus the items that <c>Graph.Deserialize</c> reported as remaining,
/// i.e. not placed in the serialized graph.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="stream"/> is null.</exception>
/// <exception cref="InvalidDataException">The stream does not start with a valid serialization header.</exception>
public static (SmallWorld<TItem, TDistance> Graph, TItem[] ItemsNotInGraph) DeserializeGraph(IReadOnlyList<TItem> items, Func<TItem, TItem, TDistance> distance, IProvideRandomValues generator, Stream stream, bool threadSafe = true)
{
    if (stream == null)
    {
        throw new ArgumentNullException(nameof(stream));
    }

    // Stream.Position throws NotSupportedException on non-seekable streams. Only capture the
    // start position when the stream can actually be rewound to it.
    long startPosition = stream.CanSeek ? stream.Position : 0;

    // On a bad header, put a seekable stream back where the caller handed it over.
    void RewindToStart()
    {
        if (stream.CanSeek)
        {
            stream.Position = startPosition;
        }
    }

    string hnswHeader;
    try
    {
        hnswHeader = MessagePackBinary.ReadString(stream);
    }
    catch (Exception ex)
    {
        RewindToStart();
        throw new InvalidDataException("Invalid header found in stream, data is corrupted or invalid", ex);
    }

    if (hnswHeader != SERIALIZATION_HEADER)
    {
        RewindToStart();
        throw new InvalidDataException("Invalid header found in stream, data is corrupted or invalid");
    }

    // readStrict: true -> removed, as not available anymore on MessagePack 2.0 - also probably not necessary anymore
    // see https://github.com/neuecc/MessagePack-CSharp/pull/663
    var parameters = MessagePackSerializer.Deserialize<Parameters>(stream);

    // Overwrite the persisted InitialDistanceCacheSize, so no time/memory is spent allocating
    // a distance cache for an already existing graph.
    parameters.InitialDistanceCacheSize = 0;

    var world = new SmallWorld<TItem, TDistance>(distance, generator, parameters, threadSafe: threadSafe);
    var remainingItems = world.Graph.Deserialize(items, stream);
    return (world, remainingItems);
}
197 |
/// <summary>
/// Renders the edges of the graph as text. Mostly for debug and test purposes.
/// </summary>
/// <returns>The string produced by the underlying graph's own <c>Print</c> method.</returns>
public string Print() => Graph.Print();
206 |
/// <summary>
/// Resizes the distance cache held by the graph core to <paramref name="newSize"/>.
/// </summary>
/// <remarks>
/// Presumably a size of 0 frees the memory used by the distance cache. Confirm against
/// <c>GraphCore.ResizeDistanceCache</c>.
/// </remarks>
/// <param name="newSize">The requested size of the distance cache.</param>
public void ResizeDistanceCache(int newSize) => Graph.GraphCore.ResizeDistanceCache(newSize);
214 |
/// <summary>
/// Tunable construction parameters of the graph. These are written by SerializeGraph and read back by
/// DeserializeGraph.
/// </summary>
[MessagePackObject(keyAsPropertyName:true)]
public class Parameters
{
    /// <summary>
    /// Creates parameters with the defaults given by the property initializers below, and derives
    /// <see cref="LevelLambda"/> from the default <see cref="M"/>.
    /// </summary>
    public Parameters()
    {
        // Property initializers have already run at this point, so M is 10 here.
        // 1 / ln(M) is the mL normalization from the HNSW article.
        LevelLambda = 1 / Math.Log(M);
    }

    /// <summary>
    /// Gets or sets the parameter which defines the maximum number of neighbors in the zero and above-zero layers.
    /// The maximum number of neighbors for the zero layer is 2 * M.
    /// The maximum number of neighbors for higher layers is M.
    /// Defaults to 10.
    /// </summary>
    public int M { get; set; } = 10;

    /// <summary>
    /// Gets or sets the max level decay parameter. https://en.wikipedia.org/wiki/Exponential_distribution See 'mL' parameter in the HNSW article.
    /// Defaults to 1 / ln(M) for the default M.
    /// </summary>
    public double LevelLambda { get; set; }

    /// <summary>
    /// Gets or sets which heuristic is used for best neighbours selection.
    /// Defaults to <see cref="NeighbourSelectionHeuristic.SelectSimple"/>.
    /// </summary>
    public NeighbourSelectionHeuristic NeighbourHeuristic { get; set; } = NeighbourSelectionHeuristic.SelectSimple;

    /// <summary>
    /// Gets or sets the number of candidates considered as neighbours for a node during graph construction.
    /// See 'efConstruction' parameter in the article. Defaults to 200.
    /// </summary>
    public int ConstructionPruning { get; set; } = 200;

    /// <summary>
    /// Gets or sets whether to expand candidates if <see cref="NeighbourSelectionHeuristic.SelectHeuristic"/> is used.
    /// See 'extendCandidates' parameter in the article. Defaults to false.
    /// </summary>
    public bool ExpandBestSelection { get; set; }

    /// <summary>
    /// Gets or sets whether to keep pruned candidates if <see cref="NeighbourSelectionHeuristic.SelectHeuristic"/> is used.
    /// See 'keepPrunedConnections' parameter in the article. Defaults to false.
    /// </summary>
    public bool KeepPrunedConnections { get; set; }

    /// <summary>
    /// Gets or sets whether to cache calculated distances at graph construction time. Defaults to true.
    /// </summary>
    public bool EnableDistanceCacheForConstruction { get; set; } = true;

    /// <summary>
    /// Gets or sets the initial distance cache size. Defaults to 1024 * 1024.
    /// Note: This value is reset to 0 on deserialization to avoid allocating the distance cache for pre-built graphs.
    /// </summary>
    public int InitialDistanceCacheSize { get; set; } = 1024 * 1024;

    /// <summary>
    /// Gets or sets the initial size of the Items list. Defaults to 1024.
    /// </summary>
    public int InitialItemsSize { get; set; } = 1024;
}
279 |
/// <summary>
/// One hit of a k-nearest-neighbour search: the item, its identifier and its distance to the query.
/// </summary>
public class KNNSearchResult
{
    /// <summary>Gets the identifier of the matched item within the graph.</summary>
    public int Id { get; }

    /// <summary>Gets the matched item.</summary>
    public TItem Item { get; }

    /// <summary>Gets the distance between the matched item and the searched item.</summary>
    public TDistance Distance { get; }

    internal KNNSearchResult(int id, TItem item, TDistance distance)
    {
        Id = id;
        Item = item;
        Distance = distance;
    }

    /// <summary>
    /// Formats the result as <c>I:{Id} Dist:{Distance} [{Item}]</c>, with the distance formatted using "n2".
    /// </summary>
    public override string ToString() => $"I:{Id} Dist:{Distance:n2} [{Item}]";
}
300 | }
301 | }
302 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/ThreadSafeFastRandom.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Runtime.CompilerServices;
3 |
4 | namespace HNSW.Net
5 | {
/// <summary>
/// Thread-safe access to <see cref="FastRandom"/>. Every thread lazily gets its own instance,
/// seeded from a single shared <see cref="Random"/> guarded by a lock.
/// </summary>
internal static class ThreadSafeFastRandom
{
    private static readonly Random _global = new Random();

    [ThreadStatic]
    private static FastRandom _local;

    /// <summary>
    /// Draws a seed for a new per-thread generator from the shared <see cref="Random"/>.
    /// </summary>
    private static int GetGlobalSeed()
    {
        // System.Random is not thread-safe, so concurrent seed draws are serialized.
        lock (_global)
        {
            return _global.Next();
        }
    }

    /// <summary>
    /// Returns the calling thread's generator, creating and seeding it on first use.
    /// </summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static FastRandom GetLocal()
    {
        // [ThreadStatic] fields start out null on every thread, hence the lazy creation.
        return _local ?? (_local = new FastRandom(GetGlobalSeed()));
    }

    /// <summary>
    /// Returns a non-negative random integer.
    /// </summary>
    /// <returns>A 32-bit signed integer that is greater than or equal to 0 and less than System.Int32.MaxValue.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static int Next()
    {
        return GetLocal().Next();
    }

    /// <summary>
    /// Returns a non-negative random integer that is less than the specified maximum.
    /// </summary>
    /// <param name="maxValue">The exclusive upper bound of the random number to be generated. maxValue must be greater than or equal to 0.</param>
    /// <returns>
    /// A 32-bit signed integer that is greater than or equal to 0, and less than maxValue; that is, the range of return values ordinarily includes 0 but not maxValue.
    /// However, if maxValue equals 0, maxValue is returned.
    /// </returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static int Next(int maxValue)
    {
        if (maxValue == 0)
        {
            // Honour the documented contract directly. The rejection loop below retries while the
            // result equals maxValue, so for maxValue == 0 it could never terminate: the only
            // candidate is 0, assuming FastRandom.Next(0) yields 0 rather than throwing.
            return 0;
        }

        var inst = GetLocal();
        int ans;

        // Reject a result equal to the exclusive bound, in case the underlying generator ever produces it.
        do
        {
            ans = inst.Next(maxValue);
        }
        while (ans == maxValue);

        return ans;
    }

    /// <summary>
    /// Returns a random integer that is within a specified range.
    /// </summary>
    /// <param name="minValue">The inclusive lower bound of the random number returned.</param>
    /// <param name="maxValue">The exclusive upper bound of the random number returned. maxValue must be greater than or equal to minValue.</param>
    /// <returns>
    /// A 32-bit signed integer greater than or equal to minValue and less than maxValue; that is, the range of return values includes minValue but not maxValue.
    /// If minValue equals maxValue, minValue is returned.
    /// </returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static int Next(int minValue, int maxValue)
    {
        return GetLocal().Next(minValue, maxValue);
    }

    /// <summary>
    /// Generates a random float. Values returned are from 0.0 up to but not including 1.0.
    /// </summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static float NextFloat()
    {
        return GetLocal().NextFloat();
    }

    /// <summary>
    /// Fills every element of the specified span with random float values.
    /// </summary>
    /// <param name="buffer">The span of floats to fill.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static void NextFloats(Span<float> buffer)
    {
        GetLocal().NextFloats(buffer);
    }
}
118 | }
--------------------------------------------------------------------------------
/Src/HNSW.Net/TravelingCosts.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
6 | namespace HNSW.Net
7 | {
8 | using System;
9 | using System.Collections.Generic;
10 |
/// <summary>
/// Measures how far arbitrary points are from a fixed destination, and orders points by that distance.
/// </summary>
/// <typeparam name="TItem">Type of the points.</typeparam>
/// <typeparam name="TDistance">Type of the distance.</typeparam>
public class TravelingCosts<TItem, TDistance> : IComparer<TItem>
{
    // Default ordering of distance values, shared by all instances with the same TDistance.
    private static readonly Comparer<TDistance> DefaultDistanceComparer = Comparer<TDistance>.Default;

    private readonly Func<TItem, TItem, TDistance> _distance;

    /// <summary>
    /// Creates a cost measure towards <paramref name="destination"/>.
    /// </summary>
    /// <param name="distance">Distance function, invoked as distance(departure, destination).</param>
    /// <param name="destination">The fixed point all distances are measured to.</param>
    public TravelingCosts(Func<TItem, TItem, TDistance> distance, TItem destination)
    {
        _distance = distance;
        Destination = destination;
    }

    /// <summary>Gets the fixed point all distances are measured to.</summary>
    public TItem Destination { get; }

    /// <summary>
    /// Computes the distance from <paramref name="departure"/> to <see cref="Destination"/>.
    /// </summary>
    public TDistance From(TItem departure) => _distance(departure, Destination);

    /// <summary>
    /// Compares two points by their distance from the destination.
    /// </summary>
    /// <param name="x">Left point.</param>
    /// <param name="y">Right point.</param>
    /// <returns>
    /// A negative value if x is closer to the destination than y, zero if both are equally far,
    /// and a positive value if x is farther, as reported by the default comparer of TDistance.
    /// </returns>
    public int Compare(TItem x, TItem y) => DefaultDistanceComparer.Compare(From(x), From(y));
}
52 | }
53 |
--------------------------------------------------------------------------------
/Src/HNSW.Net/VectorUtils.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
6 | namespace HNSW.Net
7 | {
8 | using System;
9 | using System.Collections.Generic;
10 | using System.Numerics;
11 |
/// <summary>
/// Helpers for computing the Euclidean (L2) magnitude of float vectors and normalizing them in place.
/// </summary>
public static class VectorUtils
{
    /// <summary>
    /// Computes the Euclidean (L2) magnitude of <paramref name="vector"/>.
    /// </summary>
    /// <param name="vector">The vector to measure.</param>
    /// <returns>The square root of the sum of squared components (0 for an empty vector).</returns>
    public static float Magnitude(IList<float> vector)
    {
        float magnitude = 0.0f;
        for (int i = 0; i < vector.Count; ++i)
        {
            magnitude += vector[i] * vector[i];
        }

        return (float)Math.Sqrt(magnitude);
    }

    /// <summary>
    /// Scales <paramref name="vector"/> in place to unit length.
    /// A vector whose magnitude is exactly 0 is left unchanged instead of being filled with NaN.
    /// </summary>
    /// <param name="vector">The vector to normalize in place.</param>
    public static void Normalize(IList<float> vector)
    {
        float magnitude = Magnitude(vector);
        if (magnitude == 0f)
        {
            // 1 / 0 is +Infinity and 0 * Infinity is NaN, which would silently poison every component.
            return;
        }

        float normFactor = 1 / magnitude;
        for (int i = 0; i < vector.Count; ++i)
        {
            vector[i] *= normFactor;
        }
    }

    /// <summary>
    /// SIMD-accelerated equivalent of <see cref="Magnitude"/> for arrays.
    /// </summary>
    /// <param name="vector">The vector to measure.</param>
    /// <returns>The Euclidean (L2) magnitude of <paramref name="vector"/>.</returns>
    /// <exception cref="NotSupportedException">Hardware acceleration for <see cref="Vector{T}"/> is not available.</exception>
    public static float MagnitudeSIMD(float[] vector)
    {
        if (!Vector.IsHardwareAccelerated)
        {
            // Report this method's own name; the message must not point at NormalizeSIMD.
            throw new NotSupportedException($"{nameof(VectorUtils.MagnitudeSIMD)} is not supported");
        }

        float magnitude = 0.0f;
        int step = Vector<float>.Count;

        // Whole SIMD-width chunks first...
        int i, to = vector.Length - step;
        for (i = 0; i <= to; i += step)
        {
            var vi = new Vector<float>(vector, i);
            magnitude += Vector.Dot(vi, vi);
        }

        // ...then the scalar tail that does not fill a full SIMD register.
        for (; i < vector.Length; ++i)
        {
            magnitude += vector[i] * vector[i];
        }

        return (float)Math.Sqrt(magnitude);
    }

    /// <summary>
    /// SIMD-accelerated equivalent of <see cref="Normalize"/> for arrays: scales <paramref name="vector"/>
    /// in place to unit length. A vector whose magnitude is exactly 0 is left unchanged.
    /// </summary>
    /// <param name="vector">The vector to normalize in place.</param>
    /// <exception cref="NotSupportedException">Hardware acceleration for <see cref="Vector{T}"/> is not available.</exception>
    public static void NormalizeSIMD(float[] vector)
    {
        if (!Vector.IsHardwareAccelerated)
        {
            throw new NotSupportedException($"{nameof(VectorUtils.NormalizeSIMD)} is not supported");
        }

        float magnitude = MagnitudeSIMD(vector);
        if (magnitude == 0f)
        {
            // Same reasoning as Normalize: scaling by 1 / 0 would turn every component into NaN.
            return;
        }

        float normFactor = 1f / magnitude;
        int step = Vector<float>.Count;

        int i, to = vector.Length - step;
        for (i = 0; i <= to; i += step)
        {
            var vi = new Vector<float>(vector, i);
            vi = Vector.Multiply(normFactor, vi);
            vi.CopyTo(vector, i);
        }

        for (; i < vector.Length; ++i)
        {
            vector[i] *= normFactor;
        }
    }
}
83 | }
84 |
--------------------------------------------------------------------------------
/src/HNSW.Net/NeighbourSelectionHeuristic.cs:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright (c) Microsoft Corporation. All rights reserved.
3 | // Licensed under the MIT License.
4 | //
5 |
6 | namespace HNSW.Net
7 | {
/// <summary>
/// Type of heuristic used to select the best neighbours for a node.
/// </summary>
/// <remarks>
/// Member values are spelled out because this enum is a field of the graph parameters that are
/// persisted by SerializeGraph. Presumably MessagePack stores it by numeric value (its default
/// enum handling), so reordering members must not silently change the meaning of saved graphs.
/// </remarks>
public enum NeighbourSelectionHeuristic
{
    /// <summary>
    /// Marker for Algorithm 3 (SELECT-NEIGHBORS-SIMPLE) from the article, implemented in Algorithms.Algorithm3.cs.
    /// </summary>
    SelectSimple = 0,

    /// <summary>
    /// Marker for Algorithm 4 (SELECT-NEIGHBORS-HEURISTIC) from the article, implemented in Algorithms.Algorithm4.cs.
    /// </summary>
    SelectHeuristic = 1
}
23 | }
24 |
--------------------------------------------------------------------------------