├── .gitattributes
├── restore.ps1
├── pack.ps1
├── src
├── OnnxSharp
│ ├── packageIcon.png
│ ├── Formatting
│ │ ├── Align.cs
│ │ ├── ColumnSpec.cs
│ │ ├── TextWriterExtensions.cs
│ │ ├── MarkdownFormatter.cs
│ │ └── ColumnSpecs.cs
│ ├── ValueInfoProtoExtensions.cs
│ ├── TypeProtoTypesSequenceExtensions.cs
│ ├── TensorProtoExtensions.cs
│ ├── TypeProtoTypesMapExtensions.cs
│ ├── TypeProtoTypesTensorExtensions.cs
│ ├── Thrower.cs
│ ├── Ops.cs
│ ├── MessageExtensions.cs
│ ├── MessageParserExtensions.cs
│ ├── Collections
│ │ ├── ReadOnlyListExtensions.cs
│ │ └── ListExtensions.cs
│ ├── OnnxSharp.csproj
│ ├── DimParamOrValue.cs
│ ├── GraphExtensions.Info.cs
│ ├── GraphExtensions.Clean.cs
│ ├── GraphExtensions.SetDim.cs
│ └── onnx.proto3
├── OnnxSharp.Test
│ ├── mnist-8.onnx
│ ├── mnist-8-expected-Clean.onnx
│ ├── mnist-8-expected-SetDim.onnx
│ ├── mnist-8-expected-RemoveInitializersFromInputs.onnx
│ ├── mnist-8-expected-RemoveUnnecessaryInitializerReshapes.onnx
│ ├── ModelProtoTestExtensions.cs
│ ├── OnnxSharp.Test.csproj
│ ├── AssemblyResourceLoader.cs
│ └── GraphExtensionsTest.cs
├── dotnet-onnx
│ ├── packageIcon.png
│ ├── Commands
│ │ ├── CleanCommand.cs
│ │ ├── Command.cs
│ │ ├── InfoCommand.cs
│ │ ├── InputCommand.cs
│ │ ├── InputOutputCommand.cs
│ │ └── SetDimCommand.cs
│ ├── Program.cs
│ └── dotnet-onnx.csproj
├── OnnxSharpConsole
│ ├── mnist-8.onnx
│ ├── Program.cs
│ └── OnnxSharpConsole.csproj
├── Directory.Build.props
├── Project.Output.Executable.props
├── Project.Output.Common.props
├── Project.Output.Library.props
└── Project.Output.Test.props
├── global.json
├── update-tool-from-build.ps1
├── all.ps1
├── nuget.config
├── clean.ps1
├── test.ps1
├── rename.ps1
├── .github
└── workflows
│ └── dotnet.yml
├── update.ps1
├── LICENSE
├── README.md
├── OnnxSharp.sln
└── .gitignore
/.gitattributes:
--------------------------------------------------------------------------------
1 | * text=auto
2 | *.cs diff=csharp
--------------------------------------------------------------------------------
/restore.ps1:
--------------------------------------------------------------------------------
1 | #!/usr/local/bin/powershell
2 | dotnet restore
--------------------------------------------------------------------------------
/pack.ps1:
--------------------------------------------------------------------------------
1 | #!/usr/local/bin/powershell
2 | dotnet pack --nologo -c Release
--------------------------------------------------------------------------------
/src/OnnxSharp/packageIcon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nietras/OnnxSharp/HEAD/src/OnnxSharp/packageIcon.png
--------------------------------------------------------------------------------
/src/OnnxSharp.Test/mnist-8.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nietras/OnnxSharp/HEAD/src/OnnxSharp.Test/mnist-8.onnx
--------------------------------------------------------------------------------
/src/dotnet-onnx/packageIcon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nietras/OnnxSharp/HEAD/src/dotnet-onnx/packageIcon.png
--------------------------------------------------------------------------------
/src/OnnxSharpConsole/mnist-8.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nietras/OnnxSharp/HEAD/src/OnnxSharpConsole/mnist-8.onnx
--------------------------------------------------------------------------------
/global.json:
--------------------------------------------------------------------------------
1 | {
2 | "sdk": {
3 | "version": "8.0.405",
4 | "rollForward": "latestPatch",
5 | "allowPrerelease": false
6 | }
7 | }
--------------------------------------------------------------------------------
/update-tool-from-build.ps1:
--------------------------------------------------------------------------------
1 | #!/usr/local/bin/powershell
2 | dotnet tool update dotnet-onnx --add-source ./build/dotnet-onnx_AnyCPU_Release -g
--------------------------------------------------------------------------------
/src/OnnxSharp.Test/mnist-8-expected-Clean.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nietras/OnnxSharp/HEAD/src/OnnxSharp.Test/mnist-8-expected-Clean.onnx
--------------------------------------------------------------------------------
/src/OnnxSharp.Test/mnist-8-expected-SetDim.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nietras/OnnxSharp/HEAD/src/OnnxSharp.Test/mnist-8-expected-SetDim.onnx
--------------------------------------------------------------------------------
/src/OnnxSharp/Formatting/Align.cs:
--------------------------------------------------------------------------------
1 | namespace Onnx.Formatting
2 | {
3 | internal enum Align
4 | {
5 | Left,
6 | Right,
7 | }
8 | }
--------------------------------------------------------------------------------
/src/OnnxSharp.Test/mnist-8-expected-RemoveInitializersFromInputs.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nietras/OnnxSharp/HEAD/src/OnnxSharp.Test/mnist-8-expected-RemoveInitializersFromInputs.onnx
--------------------------------------------------------------------------------
/src/OnnxSharp.Test/mnist-8-expected-RemoveUnnecessaryInitializerReshapes.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nietras/OnnxSharp/HEAD/src/OnnxSharp.Test/mnist-8-expected-RemoveUnnecessaryInitializerReshapes.onnx
--------------------------------------------------------------------------------
/all.ps1:
--------------------------------------------------------------------------------
1 | #!/usr/local/bin/powershell
2 | .\restore.ps1
3 | .\build.ps1
4 | .\test.ps1
5 | .\pack.ps1
6 | Write-Host "Check output for errors, since scripts do not stop if errors occur." -foregroundcolor "yellow"
--------------------------------------------------------------------------------
/nuget.config:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/src/OnnxSharp/ValueInfoProtoExtensions.cs:
--------------------------------------------------------------------------------
1 | namespace Onnx
2 | {
3 | /// Convenience extension methods.
4 | public static class ValueInfoProtoExtensions
5 | {
6 | }
7 | }
8 |
--------------------------------------------------------------------------------
/src/OnnxSharp/TypeProtoTypesSequenceExtensions.cs:
--------------------------------------------------------------------------------
1 | namespace Onnx
2 | {
3 | /// Convenience extension methods.
4 | public static class TypeProtoTypesSequenceExtensions
5 | {
6 | }
7 | }
8 |
--------------------------------------------------------------------------------
/clean.ps1:
--------------------------------------------------------------------------------
1 | #!/usr/local/bin/powershell
2 | $buildPath = "./build"
3 | $testResultsPath = "./TestResults"
4 | If (Test-Path $buildPath)
5 | { Remove-Item -Confirm -Recurse -Path $buildPath }
6 | If (Test-Path $testResultsPath)
7 | { Remove-Item -Confirm -Recurse -Path $testResultsPath }
--------------------------------------------------------------------------------
/src/dotnet-onnx/Commands/CleanCommand.cs:
--------------------------------------------------------------------------------
1 | using McMaster.Extensions.CommandLineUtils;
2 | using Onnx;
3 |
4 | [Command("clean", Description = "Clean model for inference e.g. remove initializers from inputs")]
5 | public class CleanCommand : InputOutputCommand
6 | {
7 | public CleanCommand(IConsole console) : base(console)
8 | { }
9 |
10 | protected override void Run(ModelProto model)
11 | {
12 | model.Graph.Clean();
13 | }
14 | }
--------------------------------------------------------------------------------
/src/OnnxSharp/TensorProtoExtensions.cs:
--------------------------------------------------------------------------------
1 | namespace Onnx
2 | {
3 | /// Convenience extension methods.
4 | public static class TensorProtoExtensions
5 | {
6 | /// Get data type of as enum.
7 | public static TensorProto.Types.DataType DataType(this TensorProto tensor) =>
8 | (TensorProto.Types.DataType)tensor.DataType;
9 | }
10 | }
11 |
--------------------------------------------------------------------------------
/src/OnnxSharp/TypeProtoTypesMapExtensions.cs:
--------------------------------------------------------------------------------
1 | namespace Onnx
2 | {
3 | /// Convenience extension methods.
4 | public static class TypeProtoTypesMapExtensions
5 | {
6 | /// Get key data type of as enum.
7 | public static TensorProto.Types.DataType KeyType(this TypeProto.Types.Map map) =>
8 | (TensorProto.Types.DataType)map.KeyType;
9 | }
10 | }
11 |
--------------------------------------------------------------------------------
/test.ps1:
--------------------------------------------------------------------------------
1 | #!/usr/local/bin/powershell
2 | Write-Host "Testing Debug X64"
3 | dotnet test --nologo -c Debug -- RunConfiguration.TargetPlatform=x64 /Parallel
4 | Write-Host "Testing Release X64"
5 | dotnet test --nologo -c Release -- RunConfiguration.TargetPlatform=x64 /Parallel
6 | Write-Host "Testing Debug X86"
7 | dotnet test --nologo -c Debug -- RunConfiguration.TargetPlatform=x86 /Parallel
8 | Write-Host "Testing Release X86"
9 | dotnet test --nologo -c Release -- RunConfiguration.TargetPlatform=x86 /Parallel
--------------------------------------------------------------------------------
/src/OnnxSharp/TypeProtoTypesTensorExtensions.cs:
--------------------------------------------------------------------------------
1 | namespace Onnx
2 | {
3 | /// Convenience extension methods.
4 | public static class TypeProtoTypesTensorExtensions
5 | {
6 | /// Get element data type of as enum.
7 | public static TensorProto.Types.DataType ElemType(this TypeProto.Types.Tensor tensor) =>
8 | (TensorProto.Types.DataType)tensor.ElemType;
9 | }
10 | }
11 |
--------------------------------------------------------------------------------
/src/OnnxSharp/Formatting/ColumnSpec.cs:
--------------------------------------------------------------------------------
1 | using System;
2 |
3 | namespace Onnx.Formatting
4 | {
5 | internal record ColumnSpec(string Name, Align Align);
6 | internal record ColumnSpec(string Name, Align Align, Func Get) : ColumnSpec(Name, Align);
7 | }
8 |
9 | // https://stackoverflow.com/questions/64749385/predefined-type-system-runtime-compilerservices-isexternalinit-is-not-defined
10 | namespace System.Runtime.CompilerServices
11 | {
12 | internal static class IsExternalInit { }
13 | }
14 |
15 |
--------------------------------------------------------------------------------
/src/Directory.Build.props:
--------------------------------------------------------------------------------
1 |
2 |
3 | nietras
4 | Copyright © nietras 2025
5 | en
6 | 0.3.0.0
7 | 0.3.2
8 | $(FileVersion)
9 | $(InformationalVersion)
10 |
11 | true
12 | 12.0
13 | false
14 |
15 |
--------------------------------------------------------------------------------
/src/dotnet-onnx/Commands/Command.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Threading.Tasks;
3 |
4 | public abstract class Command
5 | {
6 | public async Task OnExecuteAsync()
7 | {
8 | try
9 | {
10 | await Run();
11 | }
12 | //catch (CliException e)
13 | catch (Exception e)
14 | {
15 | Console.ForegroundColor = ConsoleColor.Red;
16 | Console.Error.WriteLine(e.Message);
17 | Console.ResetColor();
18 | //Environment.Exit(e.ExitCode);
19 | }
20 | }
21 |
22 | public abstract Task Run();
23 | }
24 |
--------------------------------------------------------------------------------
/src/dotnet-onnx/Commands/InfoCommand.cs:
--------------------------------------------------------------------------------
1 | using McMaster.Extensions.CommandLineUtils;
2 | using Onnx;
3 |
4 | [Command("info", Description = "Print information about a model e.g. inputs and outputs")]
5 | public class InfoCommand : InputCommand
6 | {
7 | public InfoCommand(IConsole console) : base(console)
8 | {
9 | LogInput = null;
10 | }
11 |
12 | protected override void Run(ModelProto model)
13 | {
14 | var writer = _console.Out;
15 |
16 | writer.WriteLine($"# {Input}");
17 | writer.WriteLine();
18 |
19 | model.Graph.Info(writer);
20 | }
21 | }
--------------------------------------------------------------------------------
/rename.ps1:
--------------------------------------------------------------------------------
1 | #!/usr/local/bin/powershell
2 | $oldName = "OLD"
3 | $newName = "NEW"
4 | Get-ChildItem -Filter "*$oldName*" -Recurse | Rename-Item -NewName {$_.name -replace $oldName, $newName }
5 | Get-ChildItem -Recurse -Include "*.sln","*.cs","*.xaml","*.xml","*.csproj","*.xproj","*.json","*.md","*.cmd","*.props","*.txt","*.bat" |
6 | ForEach-Object { $a = $_.fullname; ( [System.IO.File]::ReadAllText($a) ) | % {
7 | If ($_.Contains($oldName))
8 | {
9 | $newContent = $_.Replace($oldName, $newName)
10 | #$newContent
11 | [System.IO.File]::WriteAllText($a, $newContent)
12 | "Changed: " + $a
13 | }
14 | }
15 | }
--------------------------------------------------------------------------------
/src/OnnxSharp/Thrower.cs:
--------------------------------------------------------------------------------
1 | using System;
2 |
3 | namespace Onnx
4 | {
5 | internal static class Thrower
6 | {
7 | internal static void EnsureLittleEndian()
8 | {
9 | if (!BitConverter.IsLittleEndian)
10 | {
11 | var message = "Only little-endian systems are supported. " +
12 | "This is due to raw data in onnx files being stored in little-endian order and " +
13 | "conversion to big-endian has not implemented.";
14 | throw new NotSupportedException(message);
15 | }
16 | }
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/src/OnnxSharpConsole/Program.cs:
--------------------------------------------------------------------------------
1 | using Onnx;
2 |
3 | // Examples see https://github.com/onnx/models
4 | var onnxInputFilePath = @"mnist-8.onnx";
5 |
6 | var model = ModelProto.Parser.ParseFromFile(onnxInputFilePath);
7 |
8 | var graph = model.Graph;
9 | // Clean graph e.g. remove initializers from inputs that may prevent constant folding
10 | graph.Clean();
11 | // Set dimension in graph to enable dynamic batch size during inference
12 | graph.SetDim(dimIndex: 0, DimParamOrValue.New("N"));
13 | // Get summarized info about the graph
14 | var info = graph.Info();
15 |
16 | System.Console.WriteLine(info);
17 |
18 | model.WriteToFile(@"mnist-8-clean-dynamic-batch-size.onnx");
--------------------------------------------------------------------------------
/src/Project.Output.Executable.props:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | $(OutputRelativePath)/$(ProjectOrAssemblyName)_$(ShortPlatform)_$(Configuration)
8 | $(OutputRelativePath)/Obj_Exe/$(ProjectOrAssemblyName)_$(ShortPlatform)
9 | $(BaseIntermediateOutputPath)_$(Configuration)
10 |
11 |
12 |
--------------------------------------------------------------------------------
/src/Project.Output.Common.props:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | $(ProjectName)
5 | $(AssemblyName)
6 | $(MSBuildProjectName)
7 | AnyCPU
8 | ../../build/
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.github/workflows/dotnet.yml:
--------------------------------------------------------------------------------
1 | name: .NET
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | build:
11 | runs-on: ${{ matrix.os }}
12 |
13 | strategy:
14 | fail-fast: false
15 | matrix:
16 | os: [ ubuntu-latest, windows-latest, macos-latest ]
17 | configuration: [ Release, Debug ]
18 |
19 | steps:
20 | - uses: actions/checkout@v2
21 | - name: Setup .NET
22 | uses: actions/setup-dotnet@v1
23 | with:
24 | dotnet-version: 8.0.x
25 | - name: Restore dependencies
26 | run: dotnet restore
27 | - name: Build
28 | run: dotnet build --no-restore
29 | - name: Test
30 | run: dotnet test --no-build --verbosity normal
31 |
--------------------------------------------------------------------------------
/src/Project.Output.Library.props:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | $(OutputRelativePath)/Libs_$(ShortPlatform)_$(Configuration)/
8 | $(OutputRelativePath)/Obj_Libs/$(ProjectOrAssemblyName)_$(ShortPlatform)
9 | $(BaseIntermediateOutputPath)_$(Configuration)/
10 | $(IntermediateOutputPath)
11 | $(OutputPath)
12 |
13 |
14 |
--------------------------------------------------------------------------------
/src/Project.Output.Test.props:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | $(OutputRelativePath)/Tests/$(ProjectOrAssemblyName)_$(ShortPlatform)_$(Configuration)_$(TargetFramework)/
8 | $(OutputRelativePath)/Obj_Tests/$(ProjectOrAssemblyName)_$(ShortPlatform)
9 | $(BaseIntermediateOutputPath)_$(Configuration)/
10 | $(IntermediateOutputPath)
11 | $(OutputPath)
12 | false
13 |
14 |
15 |
--------------------------------------------------------------------------------
/src/OnnxSharp/Formatting/TextWriterExtensions.cs:
--------------------------------------------------------------------------------
1 | using System.IO;
2 |
3 | namespace Onnx.Formatting
4 | {
5 | internal static class TextWriterExtensions
6 | {
7 | internal static void WriteAligned(this TextWriter writer,
8 | string columnName, Align alignment, char pad, int width)
9 | {
10 | var padCount = width - columnName.Length;
11 | if (alignment == Align.Right)
12 | {
13 | writer.Write(pad, padCount);
14 | }
15 | writer.Write(columnName);
16 | if (alignment == Align.Left)
17 | {
18 | writer.Write(pad, padCount);
19 | }
20 | }
21 |
22 | internal static void Write(this TextWriter writer, char value, int repeatCount)
23 | {
24 | for (int i = 0; i < repeatCount; i++)
25 | {
26 | writer.Write(value);
27 | }
28 | }
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/update.ps1:
--------------------------------------------------------------------------------
1 | #!/usr/local/bin/powershell
2 | # For this script to work with our Azure DevOps feeds,
3 | # you need to install the credential provider for Azure DevOps:
4 | # https://raw.githubusercontent.com/microsoft/artifacts-credprovider/master/helpers/installcredprovider.ps1
5 | $regex = 'PackageReference Include="([^"]*)" Version="([^"]*)"'
6 |
7 | ForEach ($file in get-childitem . -recurse | where {$_.extension -like "*proj"})
8 | {
9 | $packages = Get-Content $file.FullName |
10 | select-string -pattern $regex -AllMatches |
11 | ForEach-Object {$_.Matches} |
12 | ForEach-Object {$_.Groups[1].Value.ToString()}|
13 | sort -Unique
14 |
15 | ForEach ($package in $packages)
16 | {
17 | write-host "Update $file package :$package" -foreground 'magenta'
18 | $fullName = $file.FullName
19 | $command = "dotnet add $fullName package --interactive $package"
20 | write-host $command
21 | iex $command
22 | }
23 | }
--------------------------------------------------------------------------------
/src/dotnet-onnx/Commands/InputCommand.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.ComponentModel.DataAnnotations;
3 | using System.Threading.Tasks;
4 | using McMaster.Extensions.CommandLineUtils;
5 | using Onnx;
6 |
7 | public abstract class InputCommand : Command
8 | {
9 | protected readonly IConsole _console;
10 | protected Action LogInput;
11 |
12 | public InputCommand(IConsole console)
13 | {
14 | _console = console;
15 | LogInput = t => _console.WriteLine(t);
16 | }
17 |
18 | [Argument(0, "input", Description = "Input file path")]
19 | [Required]
20 | public string Input { get; }
21 |
22 | public override Task Run()
23 | {
24 | var model = ModelProto.Parser.ParseFromFile(Input);
25 |
26 | LogInput?.Invoke($"Parsed input file '{Input}' of size {model.CalculateSize()}");
27 |
28 | Run(model);
29 |
30 | return Task.CompletedTask;
31 | }
32 |
33 | protected abstract void Run(ModelProto model);
34 | }
35 |
--------------------------------------------------------------------------------
/src/OnnxSharp.Test/ModelProtoTestExtensions.cs:
--------------------------------------------------------------------------------
1 | using System.IO;
2 | using System.Text.Json;
3 | using Google.Protobuf;
4 | using Onnx;
5 |
6 | namespace OnnxSharp.Test
7 | {
8 | public static class ModelProtoTestExtensions
9 | {
10 | public static void WriteIndentedJsonToFile(this ModelProto model, string filePath)
11 | {
12 | var jsonText = JsonFormatter.Default.Format(model);
13 | var jsonElement = JsonSerializer.Deserialize(jsonText);
14 |
15 | var options = new JsonSerializerOptions() { WriteIndented = true };
16 |
17 | var jsonTextPretty = JsonSerializer.Serialize(jsonElement, options);
18 | File.WriteAllText(filePath, jsonTextPretty);
19 | // Below does not indent
20 | //using var stream = File.Open(filePath, FileMode.Create);
21 | //using var writer = new Utf8JsonWriter(stream);
22 | //JsonSerializer.Serialize(writer, jsonElement, options);
23 | }
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/src/OnnxSharp/Ops.cs:
--------------------------------------------------------------------------------
1 | using System;
2 |
3 | namespace Onnx
4 | {
5 | internal static class Ops
6 | {
7 | internal static class Reshape
8 | {
9 | internal const int InputDataIndex = 0;
10 | internal const int InputShapeIndex = 1;
11 |
12 | // Reshape op supports only one dimension in shape to be dynamic,
13 | // which is defined as -1.
14 | internal const int DynamicReshapeValue = -1;
15 |
16 | internal static readonly OpSpec Spec = new OpSpec(nameof(Reshape), 2, 1);
17 | }
18 |
19 | internal readonly struct OpSpec
20 | {
21 | public OpSpec(string opType, int inputs, int outputs)
22 | {
23 | OpType = opType ?? throw new ArgumentNullException(nameof(opType));
24 | Inputs = inputs;
25 | Outputs = outputs;
26 | }
27 |
28 | public string OpType { get; }
29 | public int Inputs { get; }
30 | public int Outputs { get; }
31 | }
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/src/OnnxSharpConsole/OnnxSharpConsole.csproj:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 |
7 |
8 |
9 |
10 | Exe
11 | net8.0
12 | false
13 |
14 |
15 |
16 |
17 | PreserveNewest
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/src/OnnxSharp/MessageExtensions.cs:
--------------------------------------------------------------------------------
1 | using Google.Protobuf;
2 | using System.IO;
3 |
4 | namespace Onnx
5 | {
6 | /// Convenience extension methods.
7 | public static partial class MessageExtensions
8 | {
9 | ///
10 | /// Writes the given data to the
11 | /// given in protobuf encoding.
12 | ///
13 | public static void WriteToFile(this IMessage message, string filePath)
14 | {
15 | using var stream = File.Open(filePath, FileMode.Create);
16 | message.WriteTo(stream);
17 | }
18 |
19 | ///
20 | /// Writes the given data to the
21 | /// given in JSON encoding.
22 | ///
23 | public static void WriteJsonToFile(this IMessage message, string filePath)
24 | {
25 | using var writer = new StreamWriter(filePath);
26 | JsonFormatter.Default.Format(message, writer);
27 | }
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 nietras
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/dotnet-onnx/Commands/InputOutputCommand.cs:
--------------------------------------------------------------------------------
1 | using System.ComponentModel.DataAnnotations;
2 | using System.Threading.Tasks;
3 | using McMaster.Extensions.CommandLineUtils;
4 | using Onnx;
5 |
6 | public abstract class InputOutputCommand : Command
7 | {
8 | protected readonly IConsole _console;
9 |
10 | public InputOutputCommand(IConsole console)
11 | {
12 | _console = console;
13 | }
14 |
15 | [Argument(0, "input", Description = "Input file path")]
16 | [Required]
17 | public string Input { get; }
18 |
19 | [Argument(1, "output", Description = "Output file path")]
20 | [Required]
21 | public string Output { get; }
22 |
23 | public override Task Run()
24 | {
25 | var model = ModelProto.Parser.ParseFromFile(Input);
26 |
27 | _console.WriteLine($"Parsed input file '{Input}' of size {model.CalculateSize()}");
28 |
29 | Run(model);
30 |
31 | model.WriteToFile(Output);
32 |
33 | _console.WriteLine($"Wrote output file '{Output}' of size {model.CalculateSize()}");
34 |
35 | return Task.CompletedTask;
36 | }
37 |
38 | protected abstract void Run(ModelProto model);
39 | }
40 |
--------------------------------------------------------------------------------
/src/OnnxSharp/MessageParserExtensions.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.IO;
3 | using Google.Protobuf;
4 |
5 | namespace Onnx
6 | {
7 | /// Convenience extension methods.
8 | public static partial class MessageParserExtensions
9 | {
10 | ///
11 | /// Parse from file via .
12 | ///
13 | public static T ParseFromFile(this MessageParser parser, string filePath)
14 | where T : IMessage
15 | {
16 | using var stream = File.Open(filePath, FileMode.Open);
17 | return parser.ParseFrom(stream);
18 | }
19 |
20 | ///
21 | /// Parse from file via
22 | /// and disposes the created stream after parsing is done.
23 | ///
24 | public static T ParseFrom(this MessageParser parser, Func createStream)
25 | where T : IMessage
26 | {
27 | using var stream = createStream();
28 | return parser.ParseFrom(stream);
29 | }
30 | }
31 | }
32 |
--------------------------------------------------------------------------------
/src/dotnet-onnx/Commands/SetDimCommand.cs:
--------------------------------------------------------------------------------
1 | using McMaster.Extensions.CommandLineUtils;
2 | using Onnx;
3 |
4 | [Command("setdim", Description = "Set dimension of reshapes, inputs and outputs of model e.g. set new dynamic or fixed batch size.")]
5 | public class SetDimCommand : InputOutputCommand
6 | {
7 | public SetDimCommand(IConsole console) : base(console)
8 | { }
9 |
10 | [Option("-i|--index", Description = "Dimension index to set. Default = 0.")]
11 | public int Index { get; } = 0; // Parametize defaults
12 |
13 | [Option("-d|--dim", Description = "Dimension to set. Default = N. Use string e.g. 'N' for dynamic batch size or integer e.g. '3' for fixed size")]
14 | public string Dim { get; } = "N";
15 |
16 | protected override void Run(ModelProto model)
17 | {
18 | // Should this not be before loading input? Is the abstract base really that good?
19 |
20 | var dimParamOrValue = int.TryParse(Dim, out var dimValue)
21 | ? DimParamOrValue.New(dimValue)
22 | : DimParamOrValue.New(Dim);
23 |
24 | _console.WriteLine($"Setting dimension at {Index} to '{dimParamOrValue}'");
25 |
26 | model.Graph.SetDim(Index, dimParamOrValue);
27 | }
28 | }
--------------------------------------------------------------------------------
/src/OnnxSharp/Collections/ReadOnlyListExtensions.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 |
4 | namespace Onnx.Collections
5 | {
6 | /// Convenience extension methods for .
7 | public static class ReadOnlyListExtensions
8 | {
9 | /// Compute the product of all values.
10 | public static long Product(this IReadOnlyList values)
11 | {
12 | var product = 1L;
13 | for (int i = 0; i < values.Count; i++)
14 | {
15 | product *= values[i];
16 | }
17 | return product;
18 | }
19 |
20 | internal static T Single(this IReadOnlyList fields, Func select, TSelect valueToFind)
21 | where TSelect : IEquatable
22 | {
23 | for (int i = 0; i < fields.Count; i++)
24 | {
25 | var field = fields[i];
26 | var value = select(field);
27 | if (value.Equals(valueToFind))
28 | {
29 | return field;
30 | }
31 | }
32 | throw new ArgumentException($"Could not find field with value '{valueToFind}'");
33 | }
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/src/dotnet-onnx/Program.cs:
--------------------------------------------------------------------------------
1 | using System.IO;
2 | using System.Threading.Tasks;
3 | using McMaster.Extensions.CommandLineUtils;
4 |
5 | // https://github.com/natemcmaster/CommandLineUtils
6 | // https://natemcmaster.github.io/CommandLineUtils/docs/advanced/dependency-injection.html
7 | // TODO: Change to builder API
8 | // TODO: Clean up
9 | // TODO: Handle multiple command names etc.
10 |
11 | // https://github.com/jonstodle/DotNetSdkHelpers/blob/master/src/DotNetSdkHelpers/Program.cs
12 | // TODO: Switch from attributes to code instead
13 | [Command("dotnet onnx", Description = "Inspect and manipulate ONNX files. Copyright nietras 2021."),
14 | Subcommand(typeof(CleanCommand)),
15 | Subcommand(typeof(SetDimCommand)),
16 | Subcommand(typeof(InfoCommand))
17 | ]
18 | class Program
19 | {
20 | static Task Main(string[] args)
21 | {
22 | var app = new CommandLineApplication(
23 | PhysicalConsole.Singleton,
24 | Directory.GetCurrentDirectory());
25 |
26 | app.Conventions.UseDefaultConventions();
27 | app.UsePagerForHelpText = false;
28 |
29 | return app.ExecuteAsync(args);
30 | }
31 |
32 | public Task OnExecuteAsync(CommandLineApplication app)
33 | {
34 | app.ShowHelp();
35 |
36 | return Task.FromResult(0);
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/src/OnnxSharp/OnnxSharp.csproj:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 |
7 |
8 |
9 |
10 | netstandard2.0
11 | Onnx
12 |
13 | ONNX format parsing and manipulation in C#.
14 | true
15 |
16 | packageIcon.png
17 | $(MSBuildThisFileDirectory)packageIcon.png
18 | onnx
19 | MIT
20 | https://github.com/nietras/OnnxSharp
21 | true
22 | true
23 |
24 | true
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
--------------------------------------------------------------------------------
/src/dotnet-onnx/dotnet-onnx.csproj:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 |
7 |
8 |
9 |
10 | Exe
11 | net8.0
12 | dotnet_onnx
13 |
14 | true
15 | dotnet-onnx
16 | dotnet-onnx
17 |
18 | Inspect and manipulate ONNX files
19 |
20 | packageIcon.png
21 | $(MSBuildThisFileDirectory)packageIcon.png
22 | onnx
23 | MIT
24 | https://github.com/nietras/OnnxSharp
25 | true
26 | true
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/src/OnnxSharp/Collections/ListExtensions.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 |
4 | namespace Onnx.Collections
5 | {
/// <summary>Convenience extension methods for <see cref="IList{T}"/>.</summary>
internal static class ListExtensions
{
    /// <summary>
    /// Removes the first element of <paramref name="fields"/> whose selected value
    /// satisfies <paramref name="predicate"/>.
    /// </summary>
    /// <returns><c>true</c> if an element was removed, <c>false</c> if none matched.</returns>
    internal static bool TryRemove<T, TSelect>(this IList<T> fields, Func<T, TSelect> select, Predicate<TSelect> predicate)
    {
        var matchIndex = FindIndex(fields, select, predicate);
        return RemoveAtIfFound(fields, matchIndex);
    }

    /// <summary>
    /// Removes the first element of <paramref name="fields"/> whose selected value
    /// equals <paramref name="valueToRemove"/>.
    /// </summary>
    /// <returns><c>true</c> if an element was removed, <c>false</c> if none matched.</returns>
    internal static bool TryRemove<T, TSelect>(this IList<T> fields, Func<T, TSelect> select, TSelect valueToRemove)
        where TSelect : IEquatable<TSelect>
    {
        var matchIndex = IndexOf(fields, select, valueToRemove);
        return RemoveAtIfFound(fields, matchIndex);
    }

    /// <summary>
    /// Finds the index of the first element whose selected value equals <paramref name="valueToFind"/>.
    /// </summary>
    /// <returns>The zero-based index of the match, or -1 if there is none.</returns>
    internal static int IndexOf<T, TSelect>(this IList<T> fields, Func<T, TSelect> select, TSelect valueToFind)
        where TSelect : IEquatable<TSelect> =>
        FindIndex(fields, select, candidate => candidate.Equals(valueToFind));

    // Linear scan shared by all lookups; returns -1 when no selected value matches.
    static int FindIndex<T, TSelect>(IList<T> fields, Func<T, TSelect> select, Predicate<TSelect> predicate)
    {
        for (var position = 0; position < fields.Count; ++position)
        {
            if (predicate(select(fields[position])))
            {
                return position;
            }
        }
        return -1;
    }

    // Removes the element at index when it denotes a found element (index >= 0).
    static bool RemoveAtIfFound<T>(IList<T> fields, int index)
    {
        if (index < 0)
        {
            return false;
        }
        fields.RemoveAt(index);
        return true;
    }
}
51 | }
52 |
--------------------------------------------------------------------------------
/src/OnnxSharp.Test/OnnxSharp.Test.csproj:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 |
7 |
8 |
9 |
10 | net462
11 | $(TargetFrameworks);net8.0
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 | $([System.String]::new('%(EmbeddedResource.Identity)').Replace('\','/'))
38 |
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/src/OnnxSharp/DimParamOrValue.cs:
--------------------------------------------------------------------------------
1 | using System;
2 |
3 | namespace Onnx
4 | {
/// <summary>
/// Dimension represented either a string 'Param' or an integer 'Value'.
/// </summary>
/// <remarks>
/// Equality compares both the parameter name (ordinal) and the value, so two named
/// dimensions are equal when their names are equal and two fixed dimensions are equal
/// when their values are equal.
/// </remarks>
public readonly struct DimParamOrValue : IEquatable<DimParamOrValue>
{
    readonly string _param;
    readonly int _value;

    private DimParamOrValue(string param, int value)
    {
        _param = param;
        _value = value;
    }

    /// <summary>Create a new named dimension parameter.</summary>
    /// <exception cref="ArgumentException"><paramref name="param"/> is null, empty or whitespace.</exception>
    public static DimParamOrValue New(string param)
    {
        if (!IsParamValid(param))
        {
            throw new ArgumentException($"{nameof(param)} '{param}' must be a non-whitespace string like 'N'.");
        }
        return new DimParamOrValue(param, default);
    }

    /// <summary>Create a new fixed size dimension.</summary>
    public static DimParamOrValue New(int value) =>
        new DimParamOrValue(default, value);

    /// <summary>Get dimension as a named parameter string.</summary>
    /// <exception cref="ArgumentException">The dimension is a fixed value.</exception>
    public string Param { get { CheckIsParam(); return _param; } }
    /// <summary>Get dimension as an integer value.</summary>
    /// <exception cref="ArgumentException">The dimension is a named parameter.</exception>
    public int Value { get { CheckIsValue(); return _value; } }

    /// <summary>Is the dimension a named parameter.</summary>
    public bool IsParam => IsParamValid(_param);
    /// <summary>Is the dimension a fixed sized integer.</summary>
    public bool IsValue => !IsParam;

    /// <summary>Indicates whether this dimension equals <paramref name="other"/>.</summary>
    public bool Equals(DimParamOrValue other) =>
        _value == other._value && string.Equals(_param, other._param, StringComparison.Ordinal);

    /// <summary>Indicates whether this dimension equals <paramref name="obj"/>.</summary>
    public override bool Equals(object obj) => obj is DimParamOrValue other && Equals(other);

    /// <summary>Converts the dimension to its equivalent string representation.</summary>
    public override string ToString() => IsParam ? Param : Value.ToString();
    /// <summary>Returns the hash code for this instance.</summary>
    public override int GetHashCode() => IsParam ? Param.GetHashCode() : Value.GetHashCode();

    /// <summary>Equality operator, see <see cref="Equals(DimParamOrValue)"/>.</summary>
    public static bool operator ==(DimParamOrValue left, DimParamOrValue right) => left.Equals(right);
    /// <summary>Inequality operator, see <see cref="Equals(DimParamOrValue)"/>.</summary>
    public static bool operator !=(DimParamOrValue left, DimParamOrValue right) => !left.Equals(right);

    void CheckIsParam()
    {
        if (IsValue)
        {
            throw new ArgumentException($"{nameof(DimParamOrValue)} is a value '{_value}' not a param.");
        }
    }

    void CheckIsValue()
    {
        if (IsParam)
        {
            throw new ArgumentException($"{nameof(DimParamOrValue)} is a param '{_param}' not a value.");
        }
    }

    static bool IsParamValid(string param) => !string.IsNullOrWhiteSpace(param);
}
66 | }
67 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | [](https://github.com/nietras/OnnxSharp/stargazers)
3 | [](LICENSE)
4 |
5 | |What |Links and Status|
6 | |---------------|------|
7 | |`OnnxSharp` |[](https://www.nuget.org/packages/OnnxSharp/) [](https://www.nuget.org/packages/OnnxSharp/) |
8 | |`dotnet-onnx`|[](https://www.nuget.org/packages/dotnet-onnx/) [](https://www.nuget.org/packages/dotnet-onnx/) |
9 |
10 | # `OnnxSharp` library and `dotnet-onnx` tool
11 | ONNX format parsing and manipulation in C# and with a command-line .NET tool.
12 |
13 | # Quick Guide
14 | Install the latest version of .NET:
15 | * PowerShell (Windows): [https://dot.net/v1/dotnet-install.ps1](https://dot.net/v1/dotnet-install.ps1)
16 | * Bash (Linux/macOS): [https://dot.net/v1/dotnet-install.sh](https://dot.net/v1/dotnet-install.sh)
17 |
18 | #### Code
19 | |What |How |
20 | |--------------|---------------------------------------------------|
21 | |Install |`dotnet add PROJECT.csproj package OnnxSharp`|
22 | |Parse |`var model = ModelProto.Parser.ParseFromFile("mnist-8.onnx");`|
23 | |Info |`var info = model.Graph.Info();`|
24 | |Clean |`model.Graph.Clean();`|
25 | |SetDim |`model.Graph.SetDim();`|
26 | |Write |`model.WriteToFile("mnist-8-clean-dynamic.onnx");`|
27 |
28 | #### Tool
29 | |What |How |
30 | |--------------|----------------------------|
31 | |Install |`dotnet tool install dotnet-onnx -g`|
32 | |Info |`dotnet onnx info mnist-8.onnx`|
34 | |Clean |`dotnet onnx clean mnist-8.onnx mnist-8-clean.onnx`|
35 | |SetDim |`dotnet onnx setdim mnist-8.onnx mnist-8-setdim.onnx`|
36 |
37 | # Source Code
38 | Base functionality is based on:
39 | ```
40 | .\protoc.exe .\onnx.proto3 --csharp_out=OnnxSharp
41 | ```
42 | Everything else is written in beautiful C# 9.0 as extensions to this generated code.
43 |
44 | # Example Code
45 | ```csharp
46 | using Onnx;
47 |
48 | // Examples see https://github.com/onnx/models
49 | var onnxInputFilePath = @"mnist-8.onnx";
50 |
51 | var model = ModelProto.Parser.ParseFromFile(onnxInputFilePath);
52 |
53 | var graph = model.Graph;
54 | // Clean graph e.g. remove initializers from inputs that may prevent constant folding
55 | graph.Clean();
56 | // Set dimension in graph to enable dynamic batch size during inference
57 | graph.SetDim(dimIndex: 0, DimParamOrValue.New("N"));
58 | // Get summarized info about the graph
59 | var info = graph.Info();
60 |
61 | System.Console.WriteLine(info);
62 |
63 | model.WriteToFile(@"mnist-8-clean-dynamic-batch-size.onnx");
64 | ```
--------------------------------------------------------------------------------
/src/OnnxSharp/GraphExtensions.Info.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.IO;
4 | using System.Linq;
5 | using Onnx.Formatting;
6 |
7 | namespace Onnx
8 | {
public static partial class GraphExtensions
{
    /// <summary>Summarize information about the <see cref="GraphProto"/>.</summary>
    /// <returns>The Markdown summary written by <see cref="Info(GraphProto, TextWriter)"/>.</returns>
    public static string Info(this GraphProto graph)
    {
        using (var summary = new StringWriter())
        {
            graph.Info(summary);
            return summary.ToString();
        }
    }

    /// <summary>Summarize information about the <see cref="GraphProto"/> into <paramref name="writer"/>.</summary>
    /// <remarks>
    /// Graph inputs are split by whether an initializer with the same name exists; inputs
    /// without one are listed first, followed by outputs, initializer-backed inputs,
    /// initializers and value infos.
    /// </remarks>
    public static void Info(this GraphProto graph, TextWriter writer)
    {
        var initializerNames = new HashSet<string>(graph.Initializer.Select(tensor => tensor.Name));
        // Lookup keyed on "has initializer"; element order within each group is preserved
        var inputsByHasInitializer = graph.Input.ToLookup(input => initializerNames.Contains(input.Name));
        var inputsWithoutInitializer = inputsByHasInitializer[false].ToList();
        var inputsWithInitializer = inputsByHasInitializer[true].ToList();

        writer.WriteLine("## Inputs without Initializer");
        Info(inputsWithoutInitializer, writer);

        writer.WriteLine();
        writer.WriteLine("## Outputs");
        Info(graph.Output, writer);

        writer.WriteLine();
        writer.WriteLine("## Inputs with Initializer");
        Info(inputsWithInitializer, writer);

        writer.WriteLine();
        writer.WriteLine("## Initializers (Parameters etc.)");
        MarkdownFormatter.Format(graph.Initializer, writer);

        writer.WriteLine();
        writer.WriteLine("## Value Infos (Intermediate Outputs/Feature Maps etc.)");
        Info(graph.ValueInfo, writer);
    }

    // Writes one table per type category (tensor, sequence, map, none) that has entries.
    static void Info(IReadOnlyList<ValueInfoProto> valueInfos, TextWriter writer)
    {
        WriteInfoIfAny(OfValueCase(valueInfos, TypeProto.ValueOneofCase.TensorType),
            "Tensors", MarkdownFormatter.FormatAsTensors, writer);
        WriteInfoIfAny(OfValueCase(valueInfos, TypeProto.ValueOneofCase.SequenceType),
            "Sequences", MarkdownFormatter.FormatAsSequences, writer);
        WriteInfoIfAny(OfValueCase(valueInfos, TypeProto.ValueOneofCase.MapType),
            "Maps", MarkdownFormatter.FormatAsMaps, writer);
        WriteInfoIfAny(OfValueCase(valueInfos, TypeProto.ValueOneofCase.None),
            "Nones", MarkdownFormatter.FormatAsNones, writer);
    }

    // Value infos whose type is of the given oneof case, in original order.
    static List<ValueInfoProto> OfValueCase(IReadOnlyList<ValueInfoProto> valueInfos, TypeProto.ValueOneofCase valueCase) =>
        valueInfos.Where(valueInfo => valueInfo.Type.ValueCase == valueCase).ToList();

    // Writes a "### name" heading followed by the formatted values; writes nothing when empty.
    static void WriteInfoIfAny<T>(IReadOnlyList<T> values, string name,
        Action<IReadOnlyList<T>, TextWriter> info, TextWriter writer)
    {
        if (values.Count == 0)
        {
            return;
        }
        writer.WriteLine($"### {name}");
        info(values, writer);
    }
}
71 | }
72 |
--------------------------------------------------------------------------------
/src/OnnxSharp.Test/AssemblyResourceLoader.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.IO;
3 | using System.Linq;
4 | using System.Reflection;
5 |
6 | namespace OnnxSharp.Test
7 | {
/// <summary>Reads files embedded as manifest resources in the assembly declaring this type.</summary>
public static class AssemblyResourceLoader
{
    // Single source for all resource lookups, the same assembly ResourceNamespace is derived from.
    static readonly Assembly ResourceAssembly = typeof(AssemblyResourceLoader).Assembly;

    public static readonly string ResourceNamespace =
        typeof(AssemblyResourceLoader).Assembly.GetName().Name;
    public const string ResourceNamePrefix = "";

    /// <summary>Reads the whole resource as bytes.</summary>
    /// <exception cref="ArgumentException">The resource does not exist.</exception>
    public static byte[] GetBytes(string resourceName)
    {
        using (var stream = GetStream(resourceName))
        using (var memoryStream = new MemoryStream())
        {
            stream.CopyTo(memoryStream);
            return memoryStream.ToArray();
        }
    }

    /// <summary>Reads the resource as text and splits it into non-empty lines.</summary>
    public static string[] GetLines(string resourceName) => GetString(resourceName)
        .Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries);

    /// <summary>Reads the whole resource as text.</summary>
    /// <exception cref="ArgumentException">The resource does not exist.</exception>
    public static string GetString(string resourceName)
    {
        using (var stream = GetStream(resourceName))
        using (var reader = new StreamReader(stream))
        {
            return reader.ReadToEnd();
        }
    }

    /// <summary>Prepends <see cref="ResourceNamePrefix"/> to <paramref name="resourceName"/>.</summary>
    public static string GetFullResourceName(string resourceName) =>
        ResourceNamePrefix + resourceName;

    /// <summary>Finds the single resource name (without prefix) accepted by <paramref name="filter"/>.</summary>
    /// <exception cref="ArgumentException">No resource, or more than one resource, matches.</exception>
    public static string FindResourceName(Func<string, bool> filter)
    {
        var names = FindResourceNames(filter);

        if (names.Length == 0)
        {
            throw new ArgumentException("Could not find any resource. " +
                "The desired file might not have been defined as Embedded Resource.");
        }
        else if (names.Length != 1)
        {
            throw new ArgumentException($"Ambiguous name, cannot identify resource - " +
                $"found {names.Length} possible candidates.");
        }
        else
        {
            return names[0];
        }
    }

    /// <summary>Resource names (without prefix) accepted by <paramref name="filter"/>.</summary>
    public static string[] FindResourceNames(Func<string, bool> filter) =>
        ResourceAssembly.GetManifestResourceNames()
            // Resource names are identifiers: compare ordinally, not culture-sensitively
            .Where(s => s.StartsWith(ResourceNamePrefix, StringComparison.Ordinal))
            .Select(s => s.Substring(ResourceNamePrefix.Length))
            .Where(filter)
            .ToArray();

    /// <summary>
    /// Opens the embedded resource stream; the caller owns and must dispose it.
    /// http://stackoverflow.com/questions/3314140/how-to-read-embedded-resource-text-file
    /// </summary>
    /// <exception cref="ArgumentException">The resource does not exist.</exception>
    public static Stream GetStream(string resourceName)
    {
        var fullResourceName = GetFullResourceName(resourceName);
        var stream = ResourceAssembly.GetManifestResourceStream(fullResourceName);
        if (stream == null)
        {
            throw new ArgumentException($"Could not find resource '{resourceName}'. " +
                $"The desired file might not have been defined as Embedded Resource.");
        }
        return stream;
    }
}
86 | }
87 |
--------------------------------------------------------------------------------
/OnnxSharp.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio Version 16
4 | VisualStudioVersion = 16.0.30907.101
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OnnxSharp", "src\OnnxSharp\OnnxSharp.csproj", "{226093F1-29E7-477D-B7D6-9662B94A41D4}"
7 | EndProject
8 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OnnxSharpConsole", "src\OnnxSharpConsole\OnnxSharpConsole.csproj", "{1EC63E50-3866-4148-BCC1-54561A564D38}"
9 | EndProject
10 | Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "root", "root", "{604A7FA2-17D1-4B9E-B1D3-15F186389B0B}"
11 | ProjectSection(SolutionItems) = preProject
12 | .gitattributes = .gitattributes
13 | .gitignore = .gitignore
14 | all.ps1 = all.ps1
15 | build.ps1 = build.ps1
16 | clean.ps1 = clean.ps1
17 | global.json = global.json
18 | LICENSE = LICENSE
19 | nuget.config = nuget.config
20 | pack.ps1 = pack.ps1
21 | README.md = README.md
22 | rename.ps1 = rename.ps1
23 | restore.ps1 = restore.ps1
24 | test.ps1 = test.ps1
25 | update-tool-from-build.ps1 = update-tool-from-build.ps1
26 | update.ps1 = update.ps1
27 | EndProjectSection
28 | EndProject
29 | Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{8727AD59-4D98-463D-B4F9-48EB2AC5F8F8}"
30 | ProjectSection(SolutionItems) = preProject
31 | src\Directory.Build.props = src\Directory.Build.props
32 | src\Project.Output.Common.props = src\Project.Output.Common.props
33 | src\Project.Output.Executable.props = src\Project.Output.Executable.props
34 | src\Project.Output.Library.props = src\Project.Output.Library.props
35 | src\Project.Output.Test.props = src\Project.Output.Test.props
36 | EndProjectSection
37 | EndProject
38 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OnnxSharp.Test", "src\OnnxSharp.Test\OnnxSharp.Test.csproj", "{33CFFE3C-7846-419C-89B7-00A512949546}"
39 | EndProject
40 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "dotnet-onnx", "src\dotnet-onnx\dotnet-onnx.csproj", "{DA6F8267-24F1-4104-AB36-10C97B899A8E}"
41 | EndProject
42 | Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "github-workflows", "github-workflows", "{1EA6AA06-C436-482C-A14C-1AC414D5552F}"
43 | ProjectSection(SolutionItems) = preProject
44 | .github\workflows\dotnet.yml = .github\workflows\dotnet.yml
45 | EndProjectSection
46 | EndProject
47 | Global
48 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
49 | Debug|Any CPU = Debug|Any CPU
50 | Release|Any CPU = Release|Any CPU
51 | EndGlobalSection
52 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
53 | {226093F1-29E7-477D-B7D6-9662B94A41D4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
54 | {226093F1-29E7-477D-B7D6-9662B94A41D4}.Debug|Any CPU.Build.0 = Debug|Any CPU
55 | {226093F1-29E7-477D-B7D6-9662B94A41D4}.Release|Any CPU.ActiveCfg = Release|Any CPU
56 | {226093F1-29E7-477D-B7D6-9662B94A41D4}.Release|Any CPU.Build.0 = Release|Any CPU
57 | {1EC63E50-3866-4148-BCC1-54561A564D38}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
58 | {1EC63E50-3866-4148-BCC1-54561A564D38}.Debug|Any CPU.Build.0 = Debug|Any CPU
59 | {1EC63E50-3866-4148-BCC1-54561A564D38}.Release|Any CPU.ActiveCfg = Release|Any CPU
60 | {1EC63E50-3866-4148-BCC1-54561A564D38}.Release|Any CPU.Build.0 = Release|Any CPU
61 | {33CFFE3C-7846-419C-89B7-00A512949546}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
62 | {33CFFE3C-7846-419C-89B7-00A512949546}.Debug|Any CPU.Build.0 = Debug|Any CPU
63 | {33CFFE3C-7846-419C-89B7-00A512949546}.Release|Any CPU.ActiveCfg = Release|Any CPU
64 | {33CFFE3C-7846-419C-89B7-00A512949546}.Release|Any CPU.Build.0 = Release|Any CPU
65 | {DA6F8267-24F1-4104-AB36-10C97B899A8E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
66 | {DA6F8267-24F1-4104-AB36-10C97B899A8E}.Debug|Any CPU.Build.0 = Debug|Any CPU
67 | {DA6F8267-24F1-4104-AB36-10C97B899A8E}.Release|Any CPU.ActiveCfg = Release|Any CPU
68 | {DA6F8267-24F1-4104-AB36-10C97B899A8E}.Release|Any CPU.Build.0 = Release|Any CPU
69 | EndGlobalSection
70 | GlobalSection(SolutionProperties) = preSolution
71 | HideSolutionNode = FALSE
72 | EndGlobalSection
73 | GlobalSection(ExtensibilityGlobals) = postSolution
74 | SolutionGuid = {013E7AA1-9947-4C4E-B7EA-78F4524777AB}
75 | EndGlobalSection
76 | EndGlobal
77 |
--------------------------------------------------------------------------------
/src/OnnxSharp/Formatting/MarkdownFormatter.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.IO;
4 | using System.Linq;
5 |
6 | namespace Onnx.Formatting
7 | {
/// <summary>Formats lists of values as Markdown tables, one row per value.</summary>
internal static class MarkdownFormatter
{
    const char ColumnSeparator = '|';
    const char AlignmentMarker = ':';
    const char HeaderUnderline = '-';
    const char Padding = ' ';

    /// <summary>Writes tensor-typed value infos as a Markdown table.</summary>
    internal static void FormatAsTensors(this IReadOnlyList<ValueInfoProto> valueInfos, TextWriter writer) =>
        Format(valueInfos, ColumnSpecs.ValueInfo.Tensor, writer);

    /// <summary>Writes sequence-typed value infos as a Markdown table.</summary>
    internal static void FormatAsSequences(this IReadOnlyList<ValueInfoProto> valueInfos, TextWriter writer) =>
        Format(valueInfos, ColumnSpecs.ValueInfo.Sequence, writer);

    /// <summary>Writes map-typed value infos as a Markdown table.</summary>
    internal static void FormatAsMaps(this IReadOnlyList<ValueInfoProto> valueInfos, TextWriter writer) =>
        Format(valueInfos, ColumnSpecs.ValueInfo.Map, writer);

    /// <summary>Writes value infos without a type value as a Markdown table.</summary>
    internal static void FormatAsNones(this IReadOnlyList<ValueInfoProto> valueInfos, TextWriter writer) =>
        Format(valueInfos, ColumnSpecs.ValueInfo.None, writer);

    /// <summary>Writes tensors (e.g. initializers) as a Markdown table.</summary>
    internal static void Format(this IReadOnlyList<TensorProto> summaries, TextWriter writer) =>
        Format(summaries, ColumnSpecs.Tensor, writer);

    /// <summary>
    /// Builds the cell texts for <paramref name="values"/> using <paramref name="columnSpecs"/>
    /// and writes them as a table whose columns are as wide as their widest cell or header.
    /// </summary>
    internal static void Format<T>(
        IReadOnlyList<T> values,
        IReadOnlyList<ColumnSpec<T>> columnSpecs,
        TextWriter writer)
    {
        var rowCount = values.Count;
        var columnCount = columnSpecs.Count;

        // Widths start at the header name length and grow to fit each cell
        var widths = new int[columnCount];
        for (var c = 0; c < columnCount; c++)
        {
            widths[c] = columnSpecs[c].Name.Length;
        }

        var cells = new string[rowCount, columnCount];
        for (var r = 0; r < rowCount; r++)
        {
            var value = values[r];
            for (var c = 0; c < columnCount; c++)
            {
                var cell = columnSpecs[c].Get(value);
                cells[r, c] = cell;
                if (cell.Length > widths[c])
                {
                    widths[c] = cell.Length;
                }
            }
        }

        Format(cells, columnSpecs, widths, writer);
    }

    /// <summary>
    /// Writes a header row, an alignment row and one row per table row, padding every
    /// cell to the corresponding entry of <paramref name="columnWidths"/>.
    /// </summary>
    internal static void Format(
        string[,] table,
        IReadOnlyList<ColumnSpec> columnSpecs,
        IReadOnlyList<int> columnWidths,
        TextWriter writer)
    {
        var rowCount = table.GetLength(0);
        var columnCount = table.GetLength(1);

        WriteHeaderRow();
        WriteAlignmentRow();
        for (var r = 0; r < rowCount; r++)
        {
            WriteBodyRow(r);
        }

        void WriteHeaderRow()
        {
            for (var c = 0; c < columnCount; c++)
            {
                var name = columnSpecs[c].Name;
                writer.Write(ColumnSeparator);
                writer.Write(name);
                writer.Write(Padding, columnWidths[c] - name.Length);
            }
            EndRow();
        }

        // Markdown alignment row: ':' on the left for left-aligned, on the right for right-aligned
        void WriteAlignmentRow()
        {
            for (var c = 0; c < columnCount; c++)
            {
                writer.Write(ColumnSeparator);
                var align = columnSpecs[c].Align;
                if (align == Align.Left)
                {
                    writer.Write(AlignmentMarker);
                }
                writer.Write(HeaderUnderline, columnWidths[c] - 1);
                if (align == Align.Right)
                {
                    writer.Write(AlignmentMarker);
                }
            }
            EndRow();
        }

        void WriteBodyRow(int r)
        {
            for (var c = 0; c < columnCount; c++)
            {
                writer.Write(ColumnSeparator);
                writer.WriteAligned(table[r, c], columnSpecs[c].Align, Padding, columnWidths[c]);
            }
            EndRow();
        }

        void EndRow()
        {
            writer.Write(ColumnSeparator);
            writer.WriteLine();
        }
    }
}
117 | }
118 |
--------------------------------------------------------------------------------
/src/OnnxSharp/GraphExtensions.Clean.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using System.Diagnostics;
3 | using System.Linq;
4 | using Google.Protobuf.Collections;
5 | using Onnx.Collections;
6 |
7 | namespace Onnx
8 | {
/// <summary>Extension methods to ONNX graph.</summary>
public static partial class GraphExtensions
{
    /// <summary>Clean graph for inference.</summary>
    /// <remarks>
    /// Runs <see cref="RemoveInitializersFromInputs"/> and then
    /// <see cref="RemoveUnnecessaryInitializerReshapes"/>.
    /// </remarks>
    public static void Clean(this GraphProto graph)
    {
        graph.RemoveInitializersFromInputs();
        graph.RemoveUnnecessaryInitializerReshapes();
    }

    /// <summary>Remove initializers from inputs of graph.</summary>
    /// <remarks>
    /// Every graph input whose name matches an initializer is removed; the order of the
    /// remaining inputs is kept.
    /// </remarks>
    // https://github.com/microsoft/onnxruntime/blob/master/tools/python/remove_initializer_from_input.py
    public static void RemoveInitializersFromInputs(this GraphProto graph)
    {
        var initializerNames = new HashSet<string>(graph.Initializer.Select(i => i.Name));
        var inputs = graph.Input;

        // Rebuild in one pass instead of one RepeatedField.Remove (a linear search) per
        // initializer, which is quadratic for models with many parameters.
        var keptInputs = inputs.Where(i => !initializerNames.Contains(i.Name)).ToList();
        if (keptInputs.Count != inputs.Count)
        {
            inputs.Clear();
            inputs.AddRange(keptInputs);
        }
    }

    /// <summary>Remove unnecessary initializer reshapes from graph.</summary>
    /// <remarks>
    /// A Reshape node is removed when both its data and shape inputs are initializers,
    /// its output has a value info with only positive fixed dimensions, and the element
    /// counts of that shape and the data match. The data initializer then gets the output
    /// shape directly and all consumers of the reshape output are rewired to it.
    /// A reshape is skipped when the data initializer has other consumers (its dims are
    /// changed in place), when the data or the reshape output is a graph output, or when
    /// no static output shape is known. The shape initializer is removed only once no
    /// remaining node consumes it.
    /// </remarks>
    // https://github.com/microsoft/onnxruntime/blob/master/tools/python/remove_initializer_from_input.py
    public static void RemoveUnnecessaryInitializerReshapes(this GraphProto graph)
    {
        var nameToInitializer = graph.Initializer.ToDictionary(i => i.Name, i => i);
        var graphOutputNames = new HashSet<string>(graph.Output.Select(o => o.Name));

        var nodes = graph.Node;
        var valueInfos = graph.ValueInfo;
        // Number of node inputs referencing each name; kept up to date as reshapes are removed
        var inputUseCounts = CountInputUses(nodes);

        var opSpec = Ops.Reshape.Spec;
        var nodesToRemove = new List<NodeProto>();
        for (int nodeIndex = 0; nodeIndex < nodes.Count; nodeIndex++)
        {
            var node = nodes[nodeIndex];
            if (node.OpType != opSpec.OpType) { continue; }

            var inputs = node.Input;
            var outputs = node.Output;

            // Expected Reshape takes 2 inputs and has 1 output
            if (inputs.Count != opSpec.Inputs || outputs.Count != opSpec.Outputs) { continue; }

            var dataName = inputs[0];
            var shapeName = inputs[1];
            var reshapeOutputName = outputs[0];

            // Both inputs must be initializers ("static")
            if (!nameToInitializer.TryGetValue(dataName, out var dataInitializer) ||
                !nameToInitializer.ContainsKey(shapeName))
            { continue; }

            // The data initializer's dims are changed in place, so this reshape must be its only consumer
            if (GetUseCount(inputUseCounts, dataName) != 1) { continue; }

            // Graph outputs must keep their shape and their producing node
            if (graphOutputNames.Contains(dataName) || graphOutputNames.Contains(reshapeOutputName)) { continue; }

            if (!TryGetStaticPositiveShape(valueInfos, reshapeOutputName, out var outputShape)) { continue; }

            // Check shape compared to initializer shape
            var dataShape = dataInitializer.Dims.ToArray();
            if (outputShape.Product() != dataShape.Product()) { continue; }

            // Change data shape to the reshape output shape directly
            dataInitializer.Dims.Clear();
            dataInitializer.Dims.AddRange(outputShape);

            // Remove reshape shape both as initializer and input, but only once unused
            var remainingShapeUses = GetUseCount(inputUseCounts, shapeName) - 1;
            inputUseCounts[shapeName] = remainingShapeUses;
            if (remainingShapeUses <= 0)
            {
                graph.Initializer.TryRemove(i => i.Name, shapeName);
                graph.Input.TryRemove(i => i.Name, shapeName);
            }

            nodesToRemove.Add(node);

            // Replace reshape output name with data name directly in all nodes
            ReplaceInput(nodes, reshapeOutputName, dataName);
            // The data initializer is now consumed by every former consumer of the reshape output
            inputUseCounts[dataName] = GetUseCount(inputUseCounts, reshapeOutputName);
            inputUseCounts.Remove(reshapeOutputName);
        }
        foreach (var node in nodesToRemove)
        {
            nodes.Remove(node);
        }
    }

    /// <summary>Replaces every node input named <paramref name="oldValue"/> with <paramref name="newValue"/>.</summary>
    internal static void ReplaceInput(RepeatedField<NodeProto> nodes, string oldValue, string newValue)
    {
        for (int nodeIndex = 0; nodeIndex < nodes.Count; nodeIndex++)
        {
            var updateNodeInputs = nodes[nodeIndex].Input;
            for (int inputIndex = 0; inputIndex < updateNodeInputs.Count; inputIndex++)
            {
                if (updateNodeInputs[inputIndex] == oldValue)
                {
                    updateNodeInputs[inputIndex] = newValue;
                }
            }
        }
    }

    // Counts how many node inputs (over all nodes) reference each value name.
    static Dictionary<string, int> CountInputUses(RepeatedField<NodeProto> nodes)
    {
        var counts = new Dictionary<string, int>();
        foreach (var node in nodes)
        {
            foreach (var input in node.Input)
            {
                counts.TryGetValue(input, out var count);
                counts[input] = count + 1;
            }
        }
        return counts;
    }

    // Use count of name, 0 when it is not referenced by any node input.
    static int GetUseCount(Dictionary<string, int> useCounts, string name) =>
        useCounts.TryGetValue(name, out var count) ? count : 0;

    // Gets the shape of the value info named `name` if it is a tensor whose dims are all positive fixed values.
    static bool TryGetStaticPositiveShape(RepeatedField<ValueInfoProto> valueInfos, string name, out long[] shape)
    {
        shape = null;

        // Value infos for intermediate outputs are optional, so a missing entry just means "unknown"
        var index = valueInfos.IndexOf(v => v.Name, name);
        if (index < 0) { return false; }

        // Type, tensor type or shape may be unset (null) for non-tensor or unknown-rank values
        var dims = valueInfos[index].Type?.TensorType?.Shape?.Dim;
        if (dims == null) { return false; }

        var allValue = dims.All(d => d.ValueCase ==
            TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue);
        if (!allValue) { return false; }

        var values = dims.Select(d => d.DimValue).ToArray();
        if (!values.All(d => d > 0)) { return false; }

        shape = values;
        return true;
    }
}
130 | }
131 |
--------------------------------------------------------------------------------
/src/OnnxSharp/GraphExtensions.SetDim.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Diagnostics;
4 | using System.Linq;
5 | using System.Runtime.InteropServices;
6 | using Google.Protobuf;
7 | using Google.Protobuf.Collections;
8 | using Onnx.Collections;
9 |
10 | namespace Onnx
11 | {
public static partial class GraphExtensions
{
    /// <summary>
    /// Set dimension of inputs, value infos, outputs and potential Reshape ops.
    /// Default sets leading dimension to dynamic batch size 'N'.
    /// </summary>
    public static void SetDim(this GraphProto graph) =>
        graph.SetDim(dimIndex: 0, DimParamOrValue.New("N"));

    /// <summary>
    /// Set dimension of inputs, value infos, outputs and potential Reshape ops.
    /// Can be used to make models have dynamic batch size or different static batch sizes.
    /// </summary>
    /// <remarks>
    /// Only dimensions currently equal to 1 are replaced. Inputs backed by an initializer
    /// are left untouched, as are value infos that are not tensors, have no shape, or have
    /// no dimension at <paramref name="dimIndex"/>. Reshape nodes are updated only when
    /// their shape is a static initializer.
    /// </remarks>
    /// <param name="graph">Graph modified in place.</param>
    /// <param name="dimIndex">Zero-based index of the dimension to set.</param>
    /// <param name="dimParamOrValue">Named (dynamic) parameter or fixed value to set.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="dimIndex"/> is negative.</exception>
    public static void SetDim(this GraphProto graph, int dimIndex, DimParamOrValue dimParamOrValue)
    {
        // Validate up front so an invalid index cannot leave the graph partially modified
        if (dimIndex < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(dimIndex), dimIndex,
                "Dimension index must be zero or positive.");
        }

        // Reshape ops have their "new shape" defined as input to the reshape op.
        // This input needs to be changed to reflect new dim e.g. be set -1 if dynamic.
        var reshapeDimValue = dimParamOrValue.IsParam
            ? Ops.Reshape.DynamicReshapeValue
            : dimParamOrValue.Value;
        SetDimInReshapes(graph, dimIndex, reshapeDimValue);

        // Should we set this based on nodes instead? Handling input, outputs based on that?

        // Shapes are defined in inputs, valueInfos and outputs
        //
        // Only real inputs should be changed, not "initializer" inputs
        var initializerNames = new HashSet<string>(graph.Initializer.Select(i => i.Name));
        var inferenceInputs = graph.Input.Where(i => !initializerNames.Contains(i.Name));
        foreach (var input in inferenceInputs)
        {
            SetDim(input, dimIndex, dimParamOrValue);
        }

        SetDim(graph.ValueInfo, dimIndex, dimParamOrValue);
        SetDim(graph.Output, dimIndex, dimParamOrValue);
    }

    // Updates the shape initializer of every Reshape whose data input is not an initializer.
    static void SetDimInReshapes(GraphProto graph, int dimIndex, int dimValue)
    {
        var initializers = graph.Initializer;

        // TODO: Only fix reshapes that have data input and with dynamic shape after

        var opSpec = Ops.Reshape.Spec;
        foreach (var node in graph.Node)
        {
            if (node.OpType != opSpec.OpType) { continue; }

            var nodeInputs = node.Input;
            // A Reshape without a shape input (e.g. Reshape-1 uses a 'shape' attribute) has nothing to patch
            if (nodeInputs.Count <= Ops.Reshape.InputDataIndex ||
                nodeInputs.Count <= Ops.Reshape.InputShapeIndex)
            { continue; }

            var dataInputName = nodeInputs[Ops.Reshape.InputDataIndex];

            // Check if data input is an initializer if so we should not change the reshape
            // and hence skip this reshape node
            if (initializers.IndexOf(t => t.Name, dataInputName) >= 0) { continue; }

            var shapeInputName = nodeInputs[Ops.Reshape.InputShapeIndex];

            // A shape that is not an initializer (e.g. computed by other nodes) cannot be
            // patched statically; skip it instead of failing the whole operation
            var shapeIndex = initializers.IndexOf(t => t.Name, shapeInputName);
            if (shapeIndex < 0) { continue; }

            SetDimInReshapeTensorShape(initializers[shapeIndex], dimIndex, dimValue);
        }
    }

    // Sets element dimIndex of a 1-D int64 Reshape shape tensor to dimValue if it is currently 1.
    static void SetDimInReshapeTensorShape(TensorProto shape, int dimIndex, int dimValue)
    {
        Debug.Assert(shape.DataType == (int)TensorProto.Types.DataType.Int64);
        var dims = shape.Dims;
        // The shape tensor is 1-D, so dims[0] is the number of target dimensions; it must
        // include dimIndex. (Indexing dims with dimIndex itself fails for dimIndex > 0.)
        if (dims.Count != 1 || dims[0] <= dimIndex) { return; }

        // Data may be stored as Int64 or Raw (fixed-width, little-endian)
        var int64Data = shape.Int64Data;
        if (int64Data.Count > dimIndex && int64Data[dimIndex] == 1) // Dimension we replace
        {
            int64Data[dimIndex] = dimValue;
        }
        if (!shape.RawData.IsEmpty)
        {
            // NOTE: MemoryMarshal.Cast reinterprets bytes in host byte order, so this assumes a little-endian host
            var rawAsInt64Data = MemoryMarshal.Cast<byte, long>(shape.RawData.Span);
            Debug.Assert(rawAsInt64Data.Length == dims[0]);
            if (rawAsInt64Data.Length > dimIndex && rawAsInt64Data[dimIndex] == 1) // Dimension we replace
            {
                var newShape = rawAsInt64Data.ToArray();
                newShape[dimIndex] = dimValue;
                var newShapeBytes = MemoryMarshal.Cast<long, byte>(newShape.AsSpan());
                shape.RawData = ByteString.CopyFrom(newShapeBytes);
            }
        }
    }

    /// <summary>Applies <see cref="SetDim(ValueInfoProto, int, DimParamOrValue)"/> to each value info.</summary>
    internal static void SetDim(RepeatedField<ValueInfoProto> valueInfos,
        int dimIndex, DimParamOrValue dimParamOrValue)
    {
        for (int i = 0; i < valueInfos.Count; i++)
        {
            var valueInfo = valueInfos[i];
            SetDim(valueInfo, dimIndex, dimParamOrValue);
        }
    }

    /// <summary>
    /// Sets dimension <paramref name="dimIndex"/> of a tensor value info when it is the fixed value 1.
    /// Value infos without a tensor shape, or without that dimension, are left unchanged.
    /// </summary>
    internal static void SetDim(ValueInfoProto valueInfo,
        int dimIndex, DimParamOrValue dimParamOrValue)
    {
        // Type/TensorType/Shape are null for unset fields (e.g. sequence or map types, unknown rank)
        var dims = valueInfo.Type?.TensorType?.Shape?.Dim;
        // Lower-rank values (e.g. scalars) have no dimension at dimIndex
        if (dims == null || dimIndex >= dims.Count) { return; }

        var dim = dims[dimIndex];
        // TODO: Should perhaps be parameter that says
        //       bool shouldSetDimFor(dim)
        if (dim.ValueCase == TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue &&
            dim.DimValue == 1)
        {
            SetDim(dim, dimParamOrValue);
        }
    }

    /// <summary>Replaces the dimension's value with the given named parameter or fixed value.</summary>
    internal static void SetDim(TensorShapeProto.Types.Dimension dim,
        DimParamOrValue dimParamOrValue)
    {
        dim.ClearValue();
        if (dimParamOrValue.IsParam)
        {
            dim.DimParam = dimParamOrValue.Param;
        }
        else
        {
            dim.DimValue = dimParamOrValue.Value;
        }
    }
}
152 | }
153 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ##
4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
5 |
6 | # User-specific files
7 | *.rsuser
8 | *.suo
9 | *.user
10 | *.userosscache
11 | *.sln.docstates
12 |
13 | # User-specific files (MonoDevelop/Xamarin Studio)
14 | *.userprefs
15 |
16 | # Mono auto generated files
17 | mono_crash.*
18 |
19 | # Build results
20 | [Dd]ebug/
21 | [Dd]ebugPublic/
22 | [Rr]elease/
23 | [Rr]eleases/
24 | x64/
25 | x86/
26 | [Aa][Rr][Mm]/
27 | [Aa][Rr][Mm]64/
28 | bld/
29 | [Bb]in/
30 | [Oo]bj/
31 | [Ll]og/
32 | [Ll]ogs/
33 | [Bb]uild/
34 |
35 | # Visual Studio 2015/2017 cache/options directory
36 | .vs/
37 | # Uncomment if you have tasks that create the project's static files in wwwroot
38 | #wwwroot/
39 |
40 | # Visual Studio 2017 auto generated files
41 | Generated\ Files/
42 |
43 | # MSTest test Results
44 | [Tt]est[Rr]esult*/
45 | [Bb]uild[Ll]og.*
46 |
47 | # NUnit
48 | *.VisualState.xml
49 | TestResult.xml
50 | nunit-*.xml
51 |
52 | # Build Results of an ATL Project
53 | [Dd]ebugPS/
54 | [Rr]eleasePS/
55 | dlldata.c
56 |
57 | # Benchmark Results
58 | BenchmarkDotNet.Artifacts/
59 |
60 | # .NET Core
61 | project.lock.json
62 | project.fragment.lock.json
63 | artifacts/
64 |
65 | # StyleCop
66 | StyleCopReport.xml
67 |
68 | # Files built by Visual Studio
69 | *_i.c
70 | *_p.c
71 | *_h.h
72 | *.ilk
73 | *.meta
74 | *.obj
75 | *.iobj
76 | *.pch
77 | *.pdb
78 | *.ipdb
79 | *.pgc
80 | *.pgd
81 | *.rsp
82 | *.sbr
83 | *.tlb
84 | *.tli
85 | *.tlh
86 | *.tmp
87 | *.tmp_proj
88 | *_wpftmp.csproj
89 | *.log
90 | *.vspscc
91 | *.vssscc
92 | .builds
93 | *.pidb
94 | *.svclog
95 | *.scc
96 |
97 | # Chutzpah Test files
98 | _Chutzpah*
99 |
100 | # Visual C++ cache files
101 | ipch/
102 | *.aps
103 | *.ncb
104 | *.opendb
105 | *.opensdf
106 | *.sdf
107 | *.cachefile
108 | *.VC.db
109 | *.VC.VC.opendb
110 |
111 | # Visual Studio profiler
112 | *.psess
113 | *.vsp
114 | *.vspx
115 | *.sap
116 |
117 | # Visual Studio Trace Files
118 | *.e2e
119 |
120 | # TFS 2012 Local Workspace
121 | $tf/
122 |
123 | # Guidance Automation Toolkit
124 | *.gpState
125 |
126 | # ReSharper is a .NET coding add-in
127 | _ReSharper*/
128 | *.[Rr]e[Ss]harper
129 | *.DotSettings.user
130 |
131 | # TeamCity is a build add-in
132 | _TeamCity*
133 |
134 | # DotCover is a Code Coverage Tool
135 | *.dotCover
136 |
137 | # AxoCover is a Code Coverage Tool
138 | .axoCover/*
139 | !.axoCover/settings.json
140 |
141 | # Visual Studio code coverage results
142 | *.coverage
143 | *.coveragexml
144 |
145 | # NCrunch
146 | _NCrunch_*
147 | .*crunch*.local.xml
148 | nCrunchTemp_*
149 |
150 | # MightyMoose
151 | *.mm.*
152 | AutoTest.Net/
153 |
154 | # Web workbench (sass)
155 | .sass-cache/
156 |
157 | # Installshield output folder
158 | [Ee]xpress/
159 |
160 | # DocProject is a documentation generator add-in
161 | DocProject/buildhelp/
162 | DocProject/Help/*.HxT
163 | DocProject/Help/*.HxC
164 | DocProject/Help/*.hhc
165 | DocProject/Help/*.hhk
166 | DocProject/Help/*.hhp
167 | DocProject/Help/Html2
168 | DocProject/Help/html
169 |
170 | # Click-Once directory
171 | publish/
172 |
173 | # Publish Web Output
174 | *.[Pp]ublish.xml
175 | *.azurePubxml
176 | # Note: Comment the next line if you want to checkin your web deploy settings,
177 | # but database connection strings (with potential passwords) will be unencrypted
178 | *.pubxml
179 | *.publishproj
180 |
181 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
182 | # checkin your Azure Web App publish settings, but sensitive information contained
183 | # in these scripts will be unencrypted
184 | PublishScripts/
185 |
186 | # NuGet Packages
187 | *.nupkg
188 | # NuGet Symbol Packages
189 | *.snupkg
190 | # The packages folder can be ignored because of Package Restore
191 | **/[Pp]ackages/*
192 | # except build/, which is used as an MSBuild target.
193 | !**/[Pp]ackages/build/
194 | # Uncomment if necessary however generally it will be regenerated when needed
195 | #!**/[Pp]ackages/repositories.config
196 | # NuGet v3's project.json files produces more ignorable files
197 | *.nuget.props
198 | *.nuget.targets
199 |
200 | # Microsoft Azure Build Output
201 | csx/
202 | *.build.csdef
203 |
204 | # Microsoft Azure Emulator
205 | ecf/
206 | rcf/
207 |
208 | # Windows Store app package directories and files
209 | AppPackages/
210 | BundleArtifacts/
211 | Package.StoreAssociation.xml
212 | _pkginfo.txt
213 | *.appx
214 | *.appxbundle
215 | *.appxupload
216 |
217 | # Visual Studio cache files
218 | # files ending in .cache can be ignored
219 | *.[Cc]ache
220 | # but keep track of directories ending in .cache
221 | !?*.[Cc]ache/
222 |
223 | # Others
224 | ClientBin/
225 | ~$*
226 | *~
227 | *.dbmdl
228 | *.dbproj.schemaview
229 | *.jfm
230 | *.pfx
231 | *.publishsettings
232 | orleans.codegen.cs
233 |
234 | # Including strong name files can present a security risk
235 | # (https://github.com/github/gitignore/pull/2483#issue-259490424)
236 | #*.snk
237 |
238 | # Since there are multiple workflows, uncomment next line to ignore bower_components
239 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
240 | #bower_components/
241 |
242 | # RIA/Silverlight projects
243 | Generated_Code/
244 |
245 | # Backup & report files from converting an old project file
246 | # to a newer Visual Studio version. Backup files are not needed,
247 | # because we have git ;-)
248 | _UpgradeReport_Files/
249 | Backup*/
250 | UpgradeLog*.XML
251 | UpgradeLog*.htm
252 | ServiceFabricBackup/
253 | *.rptproj.bak
254 |
255 | # SQL Server files
256 | *.mdf
257 | *.ldf
258 | *.ndf
259 |
260 | # Business Intelligence projects
261 | *.rdl.data
262 | *.bim.layout
263 | *.bim_*.settings
264 | *.rptproj.rsuser
265 | *- [Bb]ackup.rdl
266 | *- [Bb]ackup ([0-9]).rdl
267 | *- [Bb]ackup ([0-9][0-9]).rdl
268 |
269 | # Microsoft Fakes
270 | FakesAssemblies/
271 |
272 | # GhostDoc plugin setting file
273 | *.GhostDoc.xml
274 |
275 | # Node.js Tools for Visual Studio
276 | .ntvs_analysis.dat
277 | node_modules/
278 |
279 | # Visual Studio 6 build log
280 | *.plg
281 |
282 | # Visual Studio 6 workspace options file
283 | *.opt
284 |
285 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
286 | *.vbw
287 |
288 | # Visual Studio LightSwitch build output
289 | **/*.HTMLClient/GeneratedArtifacts
290 | **/*.DesktopClient/GeneratedArtifacts
291 | **/*.DesktopClient/ModelManifest.xml
292 | **/*.Server/GeneratedArtifacts
293 | **/*.Server/ModelManifest.xml
294 | _Pvt_Extensions
295 |
296 | # Paket dependency manager
297 | .paket/paket.exe
298 | paket-files/
299 |
300 | # FAKE - F# Make
301 | .fake/
302 |
303 | # CodeRush personal settings
304 | .cr/personal
305 |
306 | # Python Tools for Visual Studio (PTVS)
307 | __pycache__/
308 | *.pyc
309 |
310 | # Cake - Uncomment if you are using it
311 | # tools/**
312 | # !tools/packages.config
313 |
314 | # Tabs Studio
315 | *.tss
316 |
317 | # Telerik's JustMock configuration file
318 | *.jmconfig
319 |
320 | # BizTalk build output
321 | *.btp.cs
322 | *.btm.cs
323 | *.odx.cs
324 | *.xsd.cs
325 |
326 | # OpenCover UI analysis results
327 | OpenCover/
328 |
329 | # Azure Stream Analytics local run output
330 | ASALocalRun/
331 |
332 | # MSBuild Binary and Structured Log
333 | *.binlog
334 |
335 | # NVidia Nsight GPU debugger configuration file
336 | *.nvuser
337 |
338 | # MFractors (Xamarin productivity tool) working folder
339 | .mfractor/
340 |
341 | # Local History for Visual Studio
342 | .localhistory/
343 |
344 | # BeatPulse healthcheck temp database
345 | healthchecksdb
346 |
347 | # Backup folder for Package Reference Convert tool in Visual Studio 2017
348 | MigrationBackup/
349 |
350 | # Ionide (cross platform F# VS Code tools) working folder
351 | .ionide/
352 |
353 |
--------------------------------------------------------------------------------
/src/OnnxSharp.Test/GraphExtensionsTest.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Diagnostics;
3 | using System.IO;
4 | using Google.Protobuf;
5 | using Microsoft.VisualStudio.TestTools.UnitTesting;
6 | using Onnx;
7 |
8 | namespace OnnxSharp.Test
9 | {
/// <summary>
/// Tests the graph operations (Info, Clean, RemoveInitializersFromInputs,
/// RemoveUnnecessaryInitializerReshapes, SetDim) against the embedded
/// "mnist-8.onnx" model and the embedded expected-output models.
/// </summary>
10 | [TestClass]
11 | public class GraphExtensionsTest
12 | {
// Stream factory for the embedded "mnist-8.onnx" resource. Every test parses
// its own model from it, so mutations in one test do not leak into another
// (presumably GetStream returns a fresh stream per call — confirm).
13 | readonly Func m_createStream = () => AssemblyResourceLoader.GetStream("mnist-8.onnx");
14 |
/// <summary>Parsing succeeds and yields the expected input/output counts.</summary>
15 | [TestMethod]
16 | public void ParseFrom()
17 | {
18 | // Act
19 | var model = ModelProto.Parser.ParseFrom(m_createStream);
20 |
21 | // Assert
22 | var graph = model.Graph;
23 | // 9 inputs since includes initializers
24 | Assert.AreEqual(9, graph.Input.Count);
25 | Assert.AreEqual(1, graph.Output.Count);
26 | }
27 |
/// <summary>Info() output must equal <see cref="ExpectedInfo"/> exactly.</summary>
28 | [TestMethod]
29 | public void Info()
30 | {
31 | // Arrange
32 | var model = ModelProto.Parser.ParseFrom(m_createStream);
33 |
34 | // Act
35 | var actual = model.Graph.Info();
36 |
37 | // Assert
// Trace the actual output so a mismatch can be inspected/copied from the test log.
38 | Trace.WriteLine(actual);
39 | var expected = ExpectedInfo;
40 | Assert.AreEqual(expected, actual);
41 | }
42 |
/// <summary>
/// Clean() leaves a single graph input and produces a model byte-identical
/// to the embedded "mnist-8-expected-Clean.onnx".
/// </summary>
43 | [TestMethod]
44 | public void Clean()
45 | {
46 | // Arrange
47 | var model = ModelProto.Parser.ParseFrom(m_createStream);
48 |
49 | // Act
50 | model.Graph.Clean();
51 |
52 | // Assert
53 | var graph = model.Graph;
54 | Assert.AreEqual(1, graph.Input.Count);
55 | Assert.AreEqual(1, graph.Output.Count);
56 | var expectedName = $"mnist-8-expected-{nameof(Clean)}.onnx";
57 | AssertModelBytesEqualToEmbeddedExpected(model, expectedName);
58 | }
59 |
/// <summary>
/// Removing initializers from the inputs leaves one input (down from 9) and
/// matches the embedded expected model byte-for-byte.
/// </summary>
60 | [TestMethod]
61 | public void RemoveInitializersFromInputs()
62 | {
63 | // Arrange
64 | var model = ModelProto.Parser.ParseFrom(m_createStream);
65 |
66 | // Act
67 | model.Graph.RemoveInitializersFromInputs();
68 |
69 | // Assert
70 | var graph = model.Graph;
71 | Assert.AreEqual(1, graph.Input.Count);
72 | Assert.AreEqual(1, graph.Output.Count);
73 | var expectedName = $"mnist-8-expected-{nameof(RemoveInitializersFromInputs)}.onnx";
74 | AssertModelBytesEqualToEmbeddedExpected(model, expectedName);
75 | }
76 |
/// <summary>
/// Removing unnecessary initializer reshapes drops one input (9 to 8) and
/// matches the embedded expected model byte-for-byte.
/// </summary>
77 | [TestMethod]
78 | public void RemoveUnnecessaryInitializerReshapes()
79 | {
80 | // Arrange
81 | var model = ModelProto.Parser.ParseFrom(m_createStream);
82 |
83 | // Act
84 | model.Graph.RemoveUnnecessaryInitializerReshapes();
85 |
86 | // Assert
87 | var graph = model.Graph;
88 | Assert.AreEqual(8, graph.Input.Count);
89 | Assert.AreEqual(1, graph.Output.Count);
90 | var expectedName = $"mnist-8-expected-{nameof(RemoveUnnecessaryInitializerReshapes)}.onnx";
91 | AssertModelBytesEqualToEmbeddedExpected(model, expectedName);
92 | }
93 |
/// <summary>
/// Setting dimension 0 to the symbolic parameter "N" keeps input/output counts
/// unchanged and matches the embedded expected model byte-for-byte.
/// </summary>
94 | [TestMethod]
95 | public void SetDim()
96 | {
97 | // Arrange
98 | var model = ModelProto.Parser.ParseFrom(m_createStream);
99 |
100 | // Act
101 | model.Graph.SetDim(dimIndex: 0, DimParamOrValue.New("N"));
102 |
103 | // Assert
104 | var graph = model.Graph;
105 | Assert.AreEqual(9, graph.Input.Count);
106 | Assert.AreEqual(1, graph.Output.Count);
107 | var expectedName = $"mnist-8-expected-{nameof(SetDim)}.onnx";
108 | AssertModelBytesEqualToEmbeddedExpected(model, expectedName);
109 | }
110 |
/// <summary>
/// Serializes <paramref name="model"/> and asserts the bytes equal the embedded
/// resource named <paramref name="expectedName"/>.
/// </summary>
111 | static void AssertModelBytesEqualToEmbeddedExpected(ModelProto model, string expectedName)
112 | {
113 | var actualBytes = model.ToByteArray();
// Uncomment the next line to write the actual model to `expectedName`,
// presumably to regenerate the fixture after an intended behavior change.
114 | //model.WriteToFile(expectedName);
115 | var expectedBytes = AssemblyResourceLoader.GetBytes(expectedName);
116 | CollectionAssert.AreEqual(expectedBytes, actualBytes);
117 | }
118 |
// Expected Info() output for mnist-8.onnx. Whitespace-sensitive: it is compared
// with Assert.AreEqual, so the column padding must match exactly.
// NOTE(review): this verbatim literal embeds the source file's line endings, which
// depend on git's `text=auto` checkout — confirm Info() emits matching newlines.
119 | const string ExpectedInfo = @"## Inputs without Initializer
120 | ### Tensors
121 | |Name |Type |ElemType|Shape |Π(Shape)|SizeInBytes|SizeInFile|
122 | |:-----|:---------|:-------|--------:|-------:|----------:|---------:|
123 | |Input3|TensorType|Float |1x1x28x28| 784| 3136| 32|
124 |
125 | ## Outputs
126 | ### Tensors
127 | |Name |Type |ElemType|Shape|Π(Shape)|SizeInBytes|SizeInFile|
128 | |:---------------|:---------|:-------|----:|-------:|----------:|---------:|
129 | |Plus214_Output_0|TensorType|Float | 1x10| 10| 40| 34|
130 |
131 | ## Inputs with Initializer
132 | ### Tensors
133 | |Name |Type |ElemType|Shape |Π(Shape)|SizeInBytes|SizeInFile|
134 | |:---------------------------------|:---------|:-------|--------:|-------:|----------:|---------:|
135 | |Parameter5 |TensorType|Float | 8x1x5x5| 200| 800| 36|
136 | |Parameter6 |TensorType|Float | 8x1x1| 8| 32| 32|
137 | |Parameter87 |TensorType|Float | 16x8x5x5| 3200| 12800| 37|
138 | |Parameter88 |TensorType|Float | 16x1x1| 16| 64| 33|
139 | |Pooling160_Output_0_reshape0_shape|TensorType|Int64 | 2| 2| 16| 48|
140 | |Parameter193 |TensorType|Float |16x4x4x10| 2560| 10240| 38|
141 | |Parameter193_reshape1_shape |TensorType|Int64 | 2| 2| 16| 41|
142 | |Parameter194 |TensorType|Float | 1x10| 10| 40| 30|
143 |
144 | ## Initializers (Parameters etc.)
145 | |Name |DataType|Dims |Π(Dims)|[v0,v1..vN] or (Min,Mean,Max) |SizeInBytes|SizeInFile|
146 | |:---------------------------------|:-------|--------:|------:|-----------------------------------:|----------:|---------:|
147 | |Parameter193 |Float |16x4x4x10| 2560|(-7.595E-001,-1.779E-003,1.186E+000)| 10240| 10265|
148 | |Parameter87 |Float | 16x8x5x5| 3200|(-5.089E-001,-3.028E-002,5.647E-001)| 12800| 12824|
149 | |Parameter5 |Float | 8x1x5x5| 200|(-9.727E-001,-7.360E-003,1.019E+000)| 800| 823|
150 | |Parameter6 |Float | 8x1x1| 8|(-4.338E-001,-1.023E-001,9.164E-002)| 32| 53|
151 | |Parameter88 |Float | 16x1x1| 16|(-4.147E-001,-1.554E-001,1.328E-002)| 64| 86|
152 | |Pooling160_Output_0_reshape0_shape|Int64 | 2| 2| [1,256]| 16| 46|
153 | |Parameter193_reshape1_shape |Int64 | 2| 2| [256,10]| 16| 39|
154 | |Parameter194 |Float | 1x10| 10|(-1.264E-001,-4.777E-006,1.402E-001)| 40| 62|
155 |
156 | ## Value Infos (Intermediate Outputs/Feature Maps etc.)
157 | ### Tensors
158 | |Name |Type |ElemType|Shape |Π(Shape)|SizeInBytes|SizeInFile|
159 | |:---------------------------|:---------|:-------|---------:|-------:|----------:|---------:|
160 | |Parameter193_reshape1 |TensorType|Float | 256x10| 2560| 10240| 40|
161 | |Convolution28_Output_0 |TensorType|Float | 1x8x28x28| 6272| 25088| 48|
162 | |Plus30_Output_0 |TensorType|Float | 1x8x28x28| 6272| 25088| 41|
163 | |ReLU32_Output_0 |TensorType|Float | 1x8x28x28| 6272| 25088| 41|
164 | |Pooling66_Output_0 |TensorType|Float | 1x8x14x14| 1568| 6272| 44|
165 | |Convolution110_Output_0 |TensorType|Float |1x16x14x14| 3136| 12544| 49|
166 | |Plus112_Output_0 |TensorType|Float |1x16x14x14| 3136| 12544| 42|
167 | |ReLU114_Output_0 |TensorType|Float |1x16x14x14| 3136| 12544| 42|
168 | |Pooling160_Output_0 |TensorType|Float | 1x16x4x4| 256| 1024| 45|
169 | |Pooling160_Output_0_reshape0|TensorType|Float | 1x256| 256| 1024| 47|
170 | |Times212_Output_0 |TensorType|Float | 1x10| 10| 40| 35|
171 | ";
172 | }
173 | }
174 |
--------------------------------------------------------------------------------
/src/OnnxSharp/Formatting/ColumnSpecs.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Linq;
4 | using System.Runtime.InteropServices;
5 | using Google.Protobuf;
6 | using Onnx.Collections;
7 |
namespace Onnx.Formatting
{
    /// <summary>
    /// Column definitions used to render ONNX value infos and initializer tensors
    /// as Markdown table rows.
    /// </summary>
    /// <remarks>
    /// Floating-point values are formatted with the invariant culture, so that,
    /// e.g., "(Min,Mean,Max)" stays unambiguous in cultures that use ',' as the
    /// decimal separator. Cells that cannot be computed for a row show "N/A"
    /// instead of throwing.
    /// </remarks>
    internal static partial class ColumnSpecs
    {
        /// <summary>Cell text when a value cannot be computed for a row.</summary>
        const string NotAvailable = "N/A";

        /// <summary>Cell text for an unknown shape or dimension.</summary>
        const string Unknown = "?";

        internal static partial class ValueInfo
        {
            /// <summary>Columns for value infos of tensor type.</summary>
            internal static readonly IReadOnlyList<ColumnSpec<ValueInfoProto>> Tensor =
                new ColumnSpec<ValueInfoProto>[]
                {
                    new ("Name", Align.Left, i => i.Name),
                    new ("Type", Align.Left, i => i.Type.ValueCase.ToString()),
                    new ("ElemType", Align.Left, i => i.Type.TensorType.ElemType().ToString()),
                    new ("Shape", Align.Right, i => FormatShape(i.Type.TensorType.Shape)),
                    new ("Π(Shape)", Align.Right, i => FormatShapeProduct(i.Type.TensorType.Shape)),
                    new ("SizeInBytes", Align.Right, i => SizeInBytes(i.Type.TensorType)),
                    new ("SizeInFile", Align.Right, i => i.CalculateSize().ToString()),
                };

            /// <summary>Columns for value infos of sequence type.</summary>
            /// <remarks>
            /// A sequence value has no tensor type (the <c>TensorType</c> oneof accessor
            /// is null for it), so no static size exists and "N/A" is shown.
            /// </remarks>
            internal static readonly IReadOnlyList<ColumnSpec<ValueInfoProto>> Sequence =
                new ColumnSpec<ValueInfoProto>[]
                {
                    new ("Name", Align.Left, i => i.Name),
                    new ("Type", Align.Left, i => i.Type.ValueCase.ToString()),
                    new ("ElemType", Align.Left, i => i.Type.SequenceType.ElemType.ValueCase.ToString()),
                    new ("SizeInBytes", Align.Right, _ => NotAvailable),
                    new ("SizeInFile", Align.Right, i => i.CalculateSize().ToString()),
                };

            /// <summary>Columns for value infos of map type.</summary>
            /// <remarks>
            /// A map value has no tensor type, so no static size exists and "N/A" is shown.
            /// </remarks>
            internal static readonly IReadOnlyList<ColumnSpec<ValueInfoProto>> Map =
                new ColumnSpec<ValueInfoProto>[]
                {
                    new ("Name", Align.Left, i => i.Name),
                    new ("Type", Align.Left, i => i.Type.ValueCase.ToString()),
                    new ("KeyType", Align.Left, i => i.Type.MapType.KeyType().ToString()),
                    new ("ValueType", Align.Left, i => i.Type.MapType.ValueType.ValueCase.ToString()),
                    new ("SizeInBytes", Align.Right, _ => NotAvailable),
                    new ("SizeInFile", Align.Right, i => i.CalculateSize().ToString()),
                };

            /// <summary>Columns for value infos whose type is not set.</summary>
            internal static readonly IReadOnlyList<ColumnSpec<ValueInfoProto>> None =
                new ColumnSpec<ValueInfoProto>[]
                {
                    new ("Name", Align.Left, i => i.Name),
                    new ("Type", Align.Left, i => i.Type.ValueCase.ToString()),
                    new ("SizeInBytes", Align.Right, _ => NotAvailable),
                    new ("SizeInFile", Align.Right, i => i.CalculateSize().ToString()),
                };
        }

        /// <summary>Columns for initializer tensors.</summary>
        internal static readonly IReadOnlyList<ColumnSpec<TensorProto>> Tensor =
            new ColumnSpec<TensorProto>[]
            {
                new ("Name", Align.Left, t => t.Name),
                new ("DataType", Align.Left, t => t.DataType().ToString()),
                new ("Dims", Align.Right, t => string.Join("x", t.Dims)),
                new ("Π(Dims)", Align.Right, t => t.Dims.Product().ToString()),
                new ("[v0,v1..vN] or (Min,Mean,Max)", Align.Right, t => FormatValuesOrStats(t)),
                new ("SizeInBytes", Align.Right, t => SizeInBytes(t.DataType(), t.Dims)),
                new ("SizeInFile", Align.Right, t => t.CalculateSize().ToString()),
            };

        /// <summary>Formats a shape as "d0xd1x…", or "?" when the shape is unknown (null).</summary>
        static string FormatShape(TensorShapeProto shape)
        {
            if (shape is null) { return Unknown; }
            return string.Join("x", shape.Dim.Select(d => Format(d)));
        }

        /// <summary>
        /// Formats the symbolic dims followed by the product of all fixed dims
        /// (e.g. "Nx784" for N x 1 x 28 x 28), or "?" when the shape is unknown (null).
        /// </summary>
        static string FormatShapeProduct(TensorShapeProto shape)
        {
            if (shape is null) { return Unknown; }
            var dimValuesProduct = GetDimValuesProduct(shape);
            return FormatShapeProduct(shape, dimValuesProduct);
        }

        /// <summary>
        /// Size in bytes of a tensor value from its element type and fixed dims
        /// (symbolic dims are ignored). Returns "N/A" when the tensor type or its
        /// shape is missing, or when the element type has no whole-byte fixed size.
        /// </summary>
        static string SizeInBytes(TypeProto.Types.Tensor tensorType)
        {
            if (tensorType?.Shape is null) { return NotAvailable; }
            if (ByteCount(tensorType.ElemType()) is not int bytesPerElement) { return NotAvailable; }
            var sizeInBytes = bytesPerElement * GetDimValuesProduct(tensorType.Shape);
            return sizeInBytes.ToString();
        }

        /// <summary>
        /// Size in bytes of a tensor with the given element type and dims, or "N/A"
        /// when the element type has no whole-byte fixed size.
        /// </summary>
        static string SizeInBytes(TensorProto.Types.DataType dataType, IReadOnlyList<long> dims)
        {
            if (ByteCount(dataType) is not int bytesPerElement) { return NotAvailable; }
            var sizeInBytes = bytesPerElement * dims.Product();
            return sizeInBytes.ToString();
        }

        /// <summary>
        /// Bytes per element for fixed-size element types; null for types without a
        /// whole-byte fixed size (String, sub-byte types such as Int4) and Undefined.
        /// </summary>
        static int? ByteCount(TensorProto.Types.DataType dataType) => dataType switch
        {
            TensorProto.Types.DataType.Double => sizeof(double),
            TensorProto.Types.DataType.Float => sizeof(float),
            TensorProto.Types.DataType.Float16 => 2,
            TensorProto.Types.DataType.Bfloat16 => 2,
            TensorProto.Types.DataType.Float8E5M2 => 1,
            TensorProto.Types.DataType.Float8E5M2Fnuz => 1,
            TensorProto.Types.DataType.Float8E4M3Fn => 1,
            TensorProto.Types.DataType.Float8E4M3Fnuz => 1,
            // Complex numbers are a (real, imaginary) pair of floats/doubles.
            TensorProto.Types.DataType.Complex64 => 2 * sizeof(float),
            TensorProto.Types.DataType.Complex128 => 2 * sizeof(double),
            TensorProto.Types.DataType.Int64 => sizeof(long),
            TensorProto.Types.DataType.Uint64 => sizeof(ulong),
            TensorProto.Types.DataType.Int32 => sizeof(int),
            TensorProto.Types.DataType.Uint32 => sizeof(uint),
            TensorProto.Types.DataType.Int16 => sizeof(short),
            TensorProto.Types.DataType.Uint16 => sizeof(ushort),
            TensorProto.Types.DataType.Int8 => sizeof(sbyte),
            TensorProto.Types.DataType.Uint8 => sizeof(byte),
            TensorProto.Types.DataType.Bool => sizeof(bool),
            _ => null,
        };

        /// <summary>Product of all fixed (dim_value) dimensions; symbolic dims are skipped.</summary>
        static long GetDimValuesProduct(TensorShapeProto shape)
        {
            var dimValues = shape.Dim
                .Where(d => d.ValueCase == TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue)
                .Select(d => d.DimValue).ToArray();
            var dimValuesProduct = dimValues.Product();
            return dimValuesProduct;
        }

        /// <summary>Joins the symbolic dims and the given fixed-dim product with "x".</summary>
        static string FormatShapeProduct(TensorShapeProto shape, long product)
        {
            var dimParams = shape.Dim
                .Where(d => d.ValueCase == TensorShapeProto.Types.Dimension.ValueOneofCase.DimParam)
                .Select(d => d.DimParam).ToArray();
            var dimAll = dimParams.Concat([product.ToString()]);
            return string.Join("x", dimAll);
        }

        /// <summary>Formats one dimension: its parameter name, its value, or "?" when unset.</summary>
        static string Format(TensorShapeProto.Types.Dimension d) => d.ValueCase switch
        {
            TensorShapeProto.Types.Dimension.ValueOneofCase.DimParam => d.DimParam,
            TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue => d.DimValue.ToString(),
            TensorShapeProto.Types.Dimension.ValueOneofCase.None => Unknown,
            _ => throw new NotSupportedException(d.ValueCase.ToString()),
        };

        /// <summary>
        /// Lists small tensors' values as "[v0,v1..]" and summarizes larger ones as
        /// "(min,mean,max)"; "N/A" for element types not handled here.
        /// </summary>
        static unsafe string FormatValuesOrStats(TensorProto tensor) => tensor.DataType() switch
        {
            // NOTE: Long lines accepted below for structure
            TensorProto.Types.DataType.Float => FormatValuesOrStats<float, double>(tensor.FloatData, tensor.RawData, &Math.Min, (m, v) => m + v, (m, c) => m / c, &Math.Max),
            TensorProto.Types.DataType.Double => FormatValuesOrStats<double, double>(tensor.DoubleData, tensor.RawData, &Math.Min, (m, v) => m + v, (m, c) => m / c, &Math.Max),
            TensorProto.Types.DataType.Int32 => FormatValuesOrStats<int, double>(tensor.Int32Data, tensor.RawData, &Math.Min, (m, v) => m + v, (m, c) => m / c, &Math.Max),
            TensorProto.Types.DataType.Int64 => FormatValuesOrStats<long, double>(tensor.Int64Data, tensor.RawData, &Math.Min, (m, v) => m + v, (m, c) => m / c, &Math.Max),
            // TODO: StringData
            _ => NotAvailable,
        };

        // NOTE: Perf below is not great since function pointer and func calls cannot be inlined.
        //       If necessary refactor to use "value type functor"s.
        static unsafe string FormatValuesOrStats<T, TMean>(
            IReadOnlyList<T> values,
            ByteString rawData,
            delegate*<T, T, T> min,
            Func<TMean, T, TMean> add,
            Func<TMean, int, TMean> divide,
            delegate*<T, T, T> max)
            where T : struct
        {
            // Data may not be in typed part but in raw data
            // Unfortunately there is no common and efficient "ground" for
            // "IReadOnlyList<T> values" and "ByteString rawData",
            // so we have to go through hoops.
            // RawData and talk about Constant nodes
            // https://github.com/onnx/onnx/issues/2825#issuecomment-644334359
            var useRawData = values.Count == 0 && rawData.Length > 0;
            if (useRawData)
            {
                // Raw data is reinterpreted in place below (for both the value
                // listing and the stats), which is only valid on little-endian hosts.
                Thrower.EnsureLittleEndian();
            }
            var rawValues = MemoryMarshal.Cast<byte, T>(rawData.Span);
            var count = useRawData ? rawValues.Length : values.Count;

            const int MaxValueCountToShow = 4;
            if (count <= MaxValueCountToShow)
            {
                // Small tensors, including empty ones ("[]"), are listed in full.
                return useRawData
                    ? FormatValues(rawValues.ToArray())
                    : FormatValues(values);
            }

            // count > MaxValueCountToShow here, so GetStats always sees a non-empty input.
            var stats = useRawData
                ? GetStats<T, TMean>(rawValues, min, add, divide, max)
                : GetStats<T, TMean>(values, min, add, divide, max);

            return FormattableString.Invariant($"({stats.min:E3},{stats.mean:E3},{stats.max:E3})");
        }

        /// <summary>Formats values as "[v0,v1,...]" using the invariant culture.</summary>
        static string FormatValues<T>(IReadOnlyList<T> values)
        {
            var formatted = values.Select(v => FormattableString.Invariant($"{v}"));
            return $"[{string.Join(",", formatted)}]";
        }

        /// <summary>Min, mean and max of a non-empty span of values.</summary>
        static unsafe (T min, TMean mean, T max) GetStats<T, TMean>(
            ReadOnlySpan<T> values,
            delegate*<T, T, T> min,
            Func<TMean, T, TMean> add,
            Func<TMean, int, TMean> divide,
            delegate*<T, T, T> max)
            where T : struct
        {
            T minValue = values[0];
            T maxValue = values[0];
            TMean sum = add(default, values[0]);
            for (int i = 1; i < values.Length; i++)
            {
                var value = values[i];
                minValue = min(minValue, value);
                maxValue = max(maxValue, value);
                sum = add(sum, value);
            }
            var mean = divide(sum, values.Length);
            return (minValue, mean, maxValue);
        }

        /// <summary>Min, mean and max of a non-empty list of values.</summary>
        static unsafe (T min, TMean mean, T max) GetStats<T, TMean>(
            IReadOnlyList<T> values,
            delegate*<T, T, T> min,
            Func<TMean, T, TMean> add,
            Func<TMean, int, TMean> divide,
            delegate*<T, T, T> max)
            where T : struct
        {
            T minValue = values[0];
            T maxValue = values[0];
            TMean sum = add(default, values[0]);
            for (int i = 1; i < values.Count; i++)
            {
                var value = values[i];
                minValue = min(minValue, value);
                maxValue = max(maxValue, value);
                sum = add(sum, value);
            }
            var mean = divide(sum, values.Count);
            return (minValue, mean, maxValue);
        }
    }
}
245 |
--------------------------------------------------------------------------------
/src/OnnxSharp/onnx.proto3:
--------------------------------------------------------------------------------
1 | //
2 | // WARNING: This file is automatically generated! Please edit onnx.in.proto.
3 | //
4 |
5 |
6 | // SPDX-License-Identifier: Apache-2.0
7 |
8 |
9 | syntax = "proto3";
10 |
11 | package onnx;
12 |
13 | // Overview
14 | //
15 | // ONNX is an open specification that is comprised of the following components:
16 | //
17 | // 1) A definition of an extensible computation graph model.
18 | // 2) Definitions of standard data types.
19 | // 3) Definitions of built-in operators.
20 | //
21 | // This document describes the syntax of models and their computation graphs,
22 | // as well as the standard data types. Together, they are referred to as the ONNX
23 | // Intermediate Representation, or 'IR' for short.
24 | //
25 | // The normative semantic specification of the ONNX IR is found in docs/IR.md.
26 | // Definitions of the built-in neural network operators may be found in docs/Operators.md.
27 |
28 | // Notes
29 | //
30 | // Protobuf compatibility
31 | //
32 | // To simplify framework compatibility, ONNX is defined using the subset of protobuf
33 | // that is compatible with both protobuf v2 and v3. This means that we do not use any
34 | // protobuf features that are only available in one of the two versions.
35 | //
36 | // Here are the most notable contortions we have to carry out to work around
37 | // these limitations:
38 | //
39 | // - No 'map' (added protobuf 3.0). We instead represent mappings as lists
40 | // of key-value pairs, where order does not matter and duplicates
41 | // are not allowed.
42 |
43 |
44 | // Versioning
45 | //
46 | // ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
47 | //
48 | // To be compatible with both proto2 and proto3, we will use a version number
49 | // that is not defined by the default value but an explicit enum number.
50 | enum Version {
51 | // proto3 requires the first enum value to be zero.
52 | // We add this just to appease the compiler.
53 | _START_VERSION = 0;
54 | // The version field is always serialized and we will use it to store the
55 | // version that the graph is generated from. This helps us set up version
56 | // control.
57 | // For the IR, we are using simple numbers starting with 0x00000001,
58 | // which was the version we published on Oct 10, 2017.
59 | IR_VERSION_2017_10_10 = 0x0000000000000001;
60 |
61 | // IR_VERSION 2 published on Oct 30, 2017
62 | // - Added type discriminator to AttributeProto to support proto3 users
63 | IR_VERSION_2017_10_30 = 0x0000000000000002;
64 |
65 | // IR VERSION 3 published on Nov 3, 2017
66 | // - For operator versioning:
67 | // - Added new message OperatorSetIdProto
68 | // - Added opset_import in ModelProto
69 | // - For vendor extensions, added domain in NodeProto
70 | IR_VERSION_2017_11_3 = 0x0000000000000003;
71 |
72 | // IR VERSION 4 published on Jan 22, 2019
73 | // - Relax constraint that initializers should be a subset of graph inputs
74 | // - Add type BFLOAT16
75 | IR_VERSION_2019_1_22 = 0x0000000000000004;
76 |
77 | // IR VERSION 5 published on March 18, 2019
78 | // - Add message TensorAnnotation.
79 | // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
80 | IR_VERSION_2019_3_18 = 0x0000000000000005;
81 |
82 | // IR VERSION 6 published on Sep 19, 2019
83 | // - Add support for sparse tensor constants stored in model.
84 | // - Add message SparseTensorProto
85 | // - Add sparse initializers
86 | IR_VERSION_2019_9_19 = 0x0000000000000006;
87 |
88 | // IR VERSION 7 published on May 8, 2020
89 | // - Add support to allow function body graph to rely on multiple external operator sets.
90 | // - Add a list to promote inference graph's initializers to global and
91 | // mutable variables. Global variables are visible in all graphs of the
92 | // stored models.
93 | // - Add message TrainingInfoProto to store initialization
94 | // method and training algorithm. The execution of TrainingInfoProto
95 | // can modify the values of mutable variables.
96 | // - Implicitly add inference graph into each TrainingInfoProto's algorithm.
97 | IR_VERSION_2020_5_8 = 0x0000000000000007;
98 |
99 | // IR VERSION 8 published on July 30, 2021
100 | // Introduce TypeProto.SparseTensor
101 | // Introduce TypeProto.Optional
102 | // Added a list of FunctionProtos local to the model
103 | // Deprecated since_version and operator status from FunctionProto
104 | IR_VERSION_2021_7_30 = 0x0000000000000008;
105 |
106 | // IR VERSION 9 published on May 5, 2023
107 | // Added AttributeProto to FunctionProto so that default attribute values can be set.
108 | // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
109 | IR_VERSION_2023_5_5 = 0x0000000000000009;
110 |
111 | // IR VERSION 10 published on March 25, 2024
112 | // Added UINT4, INT4.
113 | IR_VERSION_2024_3_25 = 0x000000000000000A;
114 |
115 | // IR VERSION 11 published on TBD
116 | // Added FLOAT4E2M1.
// IR_VERSION always names the newest IR version this schema describes.
117 | IR_VERSION = 0x000000000000000B;
118 | }
119 |
120 | // Attributes
121 | //
122 | // A named attribute containing either singular float, integer, string, graph,
123 | // and tensor values, or repeated float, integer, string, graph, and tensor values.
124 | // An AttributeProto MUST contain the name field, and *only one* of the
125 | // following content fields, effectively enforcing a C/C++ union equivalent.
126 | message AttributeProto {
127 | reserved 12, 16 to 19;
128 | reserved "v";
129 |
130 | // Note: this enum is structurally identical to the OpSchema::AttrType
131 | // enum defined in schema.h. If you rev one, you likely need to rev the other.
132 | enum AttributeType {
133 | UNDEFINED = 0;
134 | FLOAT = 1;
135 | INT = 2;
136 | STRING = 3;
137 | TENSOR = 4;
138 | GRAPH = 5;
139 | SPARSE_TENSOR = 11;
140 | TYPE_PROTO = 13;
141 |
142 | FLOATS = 6;
143 | INTS = 7;
144 | STRINGS = 8;
145 | TENSORS = 9;
146 | GRAPHS = 10;
147 | SPARSE_TENSORS = 12;
148 | TYPE_PROTOS = 14;
149 | }
150 |
151 | // The name field MUST be present for this version of the IR.
152 | string name = 1; // namespace Attribute
153 |
154 | // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
155 | // In this case, this AttributeProto does not contain data, and it's a reference of attribute
156 | // in parent scope.
157 | // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
158 | string ref_attr_name = 21;
159 |
160 | // A human-readable documentation for this attribute. Markdown is allowed.
161 | string doc_string = 13;
162 |
163 | // The type field MUST be present for this version of the IR.
164 | // For 0.0.1 versions of the IR, this field was not defined, and
165 | // implementations needed to use has_field heuristics to determine
166 | // which value field was in use. For IR_VERSION 0.0.2 or later, this
167 | // field MUST be set and match the f|i|s|t|... field in use. This
168 | // change was made to accommodate proto3 implementations.
169 | AttributeType type = 20; // discriminator that indicates which field below is in use
170 |
171 | // Exactly ONE of the following fields must be present for this version of the IR
172 | float f = 2; // float
173 | int64 i = 3; // int
174 | bytes s = 4; // UTF-8 string
175 | TensorProto t = 5; // tensor value
176 | GraphProto g = 6; // graph
177 | SparseTensorProto sparse_tensor = 22; // sparse tensor value
178 | // Do not use field below, it's deprecated.
179 | // optional ValueProto v = 12; // value - subsumes everything but graph
180 | TypeProto tp = 14; // type proto
181 |
182 | repeated float floats = 7; // list of floats
183 | repeated int64 ints = 8; // list of ints
184 | repeated bytes strings = 9; // list of UTF-8 strings
185 | repeated TensorProto tensors = 10; // list of tensors
186 | repeated GraphProto graphs = 11; // list of graphs
187 | repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
188 | repeated TypeProto type_protos = 15;// list of type protos
189 | }
190 |
191 | // Defines information on value, including the name, the type, and
192 | // the shape of the value.
193 | message ValueInfoProto {
194 | // This field MUST be present in this version of the IR.
195 | string name = 1; // namespace Value
196 | // This field MUST be present in this version of the IR for
197 | // inputs and outputs of the top-level graph.
// It may therefore be absent for other values (e.g. entries in value_info);
// generated proto3 code exposes an unset message field as null, so consumers
// reading `type` on such values should null-check it.
198 | TypeProto type = 2;
199 | // A human-readable documentation for this value. Markdown is allowed.
200 | string doc_string = 3;
201 | // Named metadata values; keys should be distinct.
202 | repeated StringStringEntryProto metadata_props = 4;
203 | }
204 |
205 | // Nodes
206 | //
207 | // Computation graphs are made up of a DAG of nodes, which represent what is
208 | // commonly called a "layer" or "pipeline stage" in machine learning frameworks.
209 | //
210 | // For example, it can be a node of type "Conv" that takes in an image, a filter
211 | // tensor and a bias tensor, and produces the convolved output.
212 | message NodeProto {
213 | repeated string input = 1; // namespace Value; names of the values this node consumes
214 | repeated string output = 2; // namespace Value; names of the values this node produces
215 |
216 | // An optional identifier for this node in a graph.
217 | // This field MAY be absent in this version of the IR.
218 | string name = 3; // namespace Node
219 |
220 | // The symbolic identifier of the Operator to execute.
221 | string op_type = 4; // namespace Operator
222 | // The domain of the OperatorSet that specifies the operator named by op_type.
223 | string domain = 7; // namespace Domain
224 | // Overload identifier, used only to map this node to a model-local function (see FunctionProto.overload).
225 | string overload = 8;
226 |
227 | // Additional named attributes.
228 | repeated AttributeProto attribute = 5;
229 |
230 | // Human-readable documentation for this node. Markdown is allowed.
231 | string doc_string = 6;
232 |
233 | // Named metadata values; keys should be distinct.
234 | repeated StringStringEntryProto metadata_props = 9;
235 | }
236 |
237 | // Training information
238 | // TrainingInfoProto stores information for training a model.
239 | // In particular, this defines two functionalities: an initialization-step
240 | // and a training-algorithm-step. Initialization resets the model
241 | // back to its original state as if no training has been performed.
242 | // Training algorithm improves the model based on input data.
243 | //
244 | // The semantics of the initialization-step is that the initializers
245 | // in ModelProto.graph and in TrainingInfoProto.algorithm are first
246 | // initialized as specified by the initializers in the graph, and then
247 | // updated by the "initialization_binding" in every instance in
248 | // ModelProto.training_info.
249 | //
250 | // The field "algorithm" defines a computation graph which represents a
251 | // training algorithm's step. After the execution of a
252 | // TrainingInfoProto.algorithm, the initializers specified by "update_binding"
253 | // may be immediately updated. If the targeted training algorithm contains
254 | // consecutive update steps (such as block coordinate descent methods),
255 | // the user needs to create a TrainingInfoProto for each step.
256 | message TrainingInfoProto {
257 | // This field describes a graph to compute the initial tensors
258 | // upon starting the training process. The initialization graph has no inputs
259 | // and can have multiple outputs. Usually, trainable tensors in neural
260 | // networks are randomly initialized. To achieve that, for each tensor,
261 | // the user can put a random number operator such as RandomNormal or
262 | // RandomUniform in TrainingInfoProto.initialization.node and assign its
263 | // random output to the specific tensor using "initialization_binding".
264 | // This graph can also set the initializers in "algorithm" in the same
265 | // TrainingInfoProto; a use case is resetting the number of training
266 | // iterations to zero.
267 | //
268 | // By default, this field is an empty graph and its evaluation does not
269 | // produce any output. Thus, no initializer would be changed by default.
270 | GraphProto initialization = 1;
271 |
272 | // This field represents a training algorithm step. Given required inputs,
273 | // it computes outputs to update initializers in its own or the inference graph's
274 | // initializer lists. In general, this field contains a loss node, gradient nodes,
275 | // optimizer nodes, and an increment of the iteration count.
276 | //
277 | // An execution of the training algorithm step is performed by executing the
278 | // graph obtained by combining the inference graph (namely "ModelProto.graph")
279 | // and the "algorithm" graph. That is, the actual
280 | // input/initializer/output/node/value_info/sparse_initializer list of
281 | // the training graph is the concatenation of
282 | // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
283 | // and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
284 | // in that order. This combined graph must satisfy the normal ONNX conditions.
285 | // Now, let's provide a visualization of graph combination for clarity.
286 | // Let the inference graph (i.e., "ModelProto.graph") be
287 | // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
288 | // and the "algorithm" graph be
289 | // tensor_d -> Add -> tensor_e
290 | // The combination process results in
291 | // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
292 | //
293 | // Notice that an input of a node in the "algorithm" graph may reference the
294 | // output of a node in the inference graph (but not the other way round). Also, an inference
295 | // node cannot reference inputs of "algorithm". With these restrictions, the inference graph
296 | // can always be run independently without training information.
297 | //
298 | // By default, this field is an empty graph and its evaluation does not
299 | // produce any output. Evaluating the default training step never
300 | // updates any initializers.
301 | GraphProto algorithm = 2;
302 |
303 | // This field specifies the bindings from the outputs of "initialization" to
304 | // some initializers in "ModelProto.graph.initializer" and
305 | // the "algorithm.initializer" in the same TrainingInfoProto.
306 | // See "update_binding" below for details.
307 | //
308 | // By default, this field is empty and no initializer would be changed
309 | // by the execution of "initialization".
310 | repeated StringStringEntryProto initialization_binding = 3;
311 |
312 | // Gradient-based training is usually an iterative procedure. In one gradient
313 | // descent iteration, we apply
314 | //
315 | // x = x - r * g
316 | //
317 | // where "x" is the optimized tensor, "r" stands for the learning rate, and "g" is
318 | // the gradient of "x" with respect to a chosen loss. To avoid adding assignments
319 | // into the training graph, we split the update equation into
320 | //
321 | // y = x - r * g
322 | // x = y
323 | //
324 | // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
325 | // tell that "y" should be assigned to "x", the field "update_binding" may
326 | // contain a key-value pair of strings, "x" (key of StringStringEntryProto)
327 | // and "y" (value of StringStringEntryProto).
328 | // For a neural network with multiple trainable (mutable) tensors, there can
329 | // be multiple key-value pairs in "update_binding".
330 | //
331 | // Initializers that appear as keys in "update_binding" are considered
332 | // mutable variables. This implies the behaviors
333 | // described below.
334 | //
335 | // 1. Keys are unique across all "update_binding"s, so two
336 | // variables may not have the same name. This ensures that each
337 | // variable is assigned at most once.
338 | // 2. The keys must appear in names of "ModelProto.graph.initializer" or
339 | // "TrainingInfoProto.algorithm.initializer".
340 | // 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
341 | // 4. Mutable variables are initialized to the value specified by the
342 | // corresponding initializer, and then potentially updated by
343 | // "initialization_binding"s and "update_binding"s in "TrainingInfoProto"s.
344 | //
345 | // This field usually contains names of trainable tensors
346 | // (in ModelProto.graph), optimizer states such as momentums in advanced
347 | // stochastic gradient methods (in TrainingInfoProto.algorithm),
348 | // and the number of training iterations (in TrainingInfoProto.algorithm).
349 | //
350 | // By default, this field is empty and no initializer would be changed
351 | // by the execution of "algorithm".
352 | repeated StringStringEntryProto update_binding = 4;
353 | }
354 |
355 | // Models
356 | //
357 | // ModelProto is a top-level file/container format for bundling an ML model and
358 | // associating its computation graph with metadata.
359 | //
360 | // The semantics of the model are described by the associated GraphProto's.
361 | message ModelProto {
362 | // The version of the IR this model targets. See Version enum above.
363 | // This field MUST be present.
364 | int64 ir_version = 1;
365 |
366 | // The OperatorSets this model relies on.
367 | // All ModelProtos MUST have at least one entry that
368 | // specifies which version of the ONNX OperatorSet is
369 | // being imported.
370 | //
371 | // All nodes in the ModelProto's graph will bind against the operator
372 | // with the same domain and op_type that has the HIGHEST version
373 | // in the referenced operator sets.
374 | repeated OperatorSetIdProto opset_import = 8;
375 |
376 | // The name of the framework or tool used to generate this model.
377 | // This field SHOULD be present to indicate which implementation/tool/framework
378 | // emitted the model.
379 | string producer_name = 2;
380 |
381 | // The version of the framework or tool used to generate this model.
382 | // This field SHOULD be present to indicate which implementation/tool/framework
383 | // emitted the model.
384 | string producer_version = 3;
385 |
386 | // Domain name of the model.
387 | // We use reverse domain names as namespace indicators. For example:
388 | // `com.facebook.fair` or `com.microsoft.cognitiveservices`
389 | //
390 | // Together with `model_version` and GraphProto.name, this forms the unique identity of
391 | // the graph.
392 | string domain = 4;
393 |
394 | // The version of the graph encoded. See Version enum above.
395 | int64 model_version = 5;
396 |
397 | // Human-readable documentation for this model. Markdown is allowed.
398 | string doc_string = 6;
399 |
400 | // The parameterized graph that is evaluated to execute the model.
401 | GraphProto graph = 7;
402 |
403 | // Named metadata values; keys should be distinct.
404 | repeated StringStringEntryProto metadata_props = 14;
405 |
406 | // Training-specific information. Sequentially executing all stored
407 | // `TrainingInfoProto.algorithm`s and assigning their outputs following
408 | // the corresponding `TrainingInfoProto.update_binding`s is one training
409 | // iteration. Similarly, to initialize the model
410 | // (as if training hasn't happened), the user should sequentially execute
411 | // all stored `TrainingInfoProto.initialization`s and assign their outputs
412 | // using `TrainingInfoProto.initialization_binding`s.
413 | //
414 | // If this field is empty, the training behavior of the model is undefined.
415 | repeated TrainingInfoProto training_info = 20;
416 |
417 | // A list of function protos local to the model.
418 | //
419 | // The (domain, name, overload) tuple must be unique across the function protos in this list.
420 | // In case of any conflicts, the behavior (whether the model-local functions are given higher priority,
421 | // or standard operator sets are given higher priority, or this is treated as an error) is defined by
422 | // the runtimes.
423 | //
424 | // The operator sets imported by FunctionProto should be compatible with the ones
425 | // imported by ModelProto and other model-local FunctionProtos.
426 | // For example, if the same operator set, say 'A', is imported by a FunctionProto and ModelProto
427 | // or by 2 FunctionProtos, then the versions for the operator set may differ, but
428 | // the operator schema returned for the op_type, domain, version combination
429 | // for both versions should be the same for every node in the function body.
430 | //
431 | // One FunctionProto can reference other FunctionProtos in the model; however, recursive references
432 | // are not allowed.
433 | repeated FunctionProto functions = 25;
434 | };
435 |
436 | // StringStringEntryProto follows the pattern for cross-proto-version maps.
437 | // See https://developers.google.com/protocol-buffers/docs/proto3#maps
438 | message StringStringEntryProto { // one key/value pair of an emulated map<string, string>
439 | string key = 1; // entry key; uniqueness rules are set by each field that uses this message
440 | string value = 2; // value associated with key
441 | };
442 |
443 | message TensorAnnotation {
444 | string tensor_name = 1; // name of the tensor being annotated
445 | // Key-value pairs annotating the tensor named by tensor_name.
446 | // The keys used in the mapping below must be pre-defined in the ONNX spec.
447 | // For example, in the 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
448 | // quantization parameter keys.
449 | repeated StringStringEntryProto quant_parameter_tensor_names = 2; // values are names of the quantization parameter tensors
450 | }
451 |
452 |
453 |
454 | // Graphs
455 | //
456 | // A graph defines the computational logic of a model and is composed of a parameterized
457 | // list of nodes that form a directed acyclic graph based on their inputs and outputs.
458 | // This is the equivalent of the "network" or "graph" in many deep learning
459 | // frameworks.
460 | message GraphProto {
461 | // The nodes in the graph, sorted topologically.
462 | repeated NodeProto node = 1;
463 |
464 | // The name of the graph.
465 | string name = 2; // namespace Graph
466 |
467 | // A list of named tensor values, used to specify constant inputs of the graph.
468 | // Each initializer (both TensorProto as well as SparseTensorProto) MUST have a name.
469 | // The name MUST be unique across both initializer and sparse_initializer,
470 | // but the name MAY also appear in the input list.
471 | repeated TensorProto initializer = 5;
472 |
473 | // Initializers (see above) stored in sparse format.
474 | repeated SparseTensorProto sparse_initializer = 15;
475 |
476 | // Human-readable documentation for this graph. Markdown is allowed.
477 | string doc_string = 10;
478 |
479 | // The inputs and outputs of the graph.
480 | repeated ValueInfoProto input = 11;
481 | repeated ValueInfoProto output = 12;
482 |
483 | // Information for the values in the graph. The ValueInfoProto names
484 | // must be distinct. It is optional for a value to appear in the value_info list.
485 | repeated ValueInfoProto value_info = 13;
486 |
487 | // This field carries information to indicate the mapping between a tensor and its
488 | // quantization parameter tensors. For example:
489 | // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
490 | // which means tensor 'a_scale' and tensor 'a_zero_point' are the scale and zero point of tensor 'a' in the model.
491 | repeated TensorAnnotation quantization_annotation = 14;
492 |
493 | // Named metadata values; keys should be distinct.
494 | repeated StringStringEntryProto metadata_props = 16;
495 |
496 | reserved 3, 4, 6 to 9; // retired field numbers; must not be reused
497 | reserved "ir_version", "producer_version", "producer_tag", "domain"; // retired field names; must not be reused
498 | }
499 |
500 | // Tensors
501 | //
502 | // A serialized tensor value.
503 | message TensorProto {
504 | enum DataType {
505 | UNDEFINED = 0;
506 | // Basic types.
507 | FLOAT = 1; // float
508 | UINT8 = 2; // uint8_t
509 | INT8 = 3; // int8_t
510 | UINT16 = 4; // uint16_t
511 | INT16 = 5; // int16_t
512 | INT32 = 6; // int32_t
513 | INT64 = 7; // int64_t
514 | STRING = 8; // string
515 | BOOL = 9; // bool
516 |
517 | // IEEE754 half-precision floating-point format (16 bits wide).
518 | // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
519 | FLOAT16 = 10;
520 |
521 | DOUBLE = 11; // double
522 | UINT32 = 12; // uint32_t
523 | UINT64 = 13; // uint64_t
524 | COMPLEX64 = 14; // complex with float32 real and imaginary components
525 | COMPLEX128 = 15; // complex with float64 real and imaginary components
526 |
527 | // Non-IEEE floating-point format based on IEEE754 single-precision
528 | // floating-point number truncated to 16 bits.
529 | // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
530 | BFLOAT16 = 16;
531 |
532 | // Non-IEEE floating-point format based on papers
533 | // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
534 | // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
535 | // Operators that support FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
536 | // The computation usually happens inside a block quantize / dequantize
537 | // fused by the runtime.
538 | FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
539 | FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
540 | FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
541 | FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero
542 |
543 | // 4-bit integer data types
544 | UINT4 = 21; // Unsigned integer in range [0, 15]
545 | INT4 = 22; // Signed integer in range [-8, 7], using two's-complement representation
546 |
547 | // 4-bit floating point data types
548 | FLOAT4E2M1 = 23;
549 |
550 | // Future extensions go here.
551 | }
552 |
553 | // The shape of the tensor.
554 | repeated int64 dims = 1;
555 |
556 | // The data type of the tensor.
557 | // This field MUST have a valid TensorProto.DataType value
558 | int32 data_type = 2;
559 |
560 | // For very large tensors, we may want to store them in chunks, in which
561 | // case the following fields will specify the segment that is stored in
562 | // the current TensorProto.
563 | message Segment {
564 | int64 begin = 1;
565 | int64 end = 2;
566 | }
567 | Segment segment = 3;
568 |
569 | // Tensor content must be organized in row-major order.
570 | //
571 | // Depending on the data_type field, exactly one of the fields below with
572 | // name ending in _data is used to store the elements of the tensor.
573 |
574 | // For float and complex64 values
575 | // Complex64 tensors are encoded as a single array of floats,
576 | // with the real components appearing in odd numbered positions (counting from 1),
577 | // and the corresponding imaginary component appearing in the
578 | // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
579 | // is encoded as [1.0, 2.0, 3.0, 4.0])
580 | // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
581 | repeated float float_data = 4 [packed = true];
582 |
583 | // For int32, uint8, int8, uint16, int16, uint4, int4, bool, float8, float16 and bfloat16 values.
584 | // float16 and float8 values must be bit-wise converted to a uint16_t prior
585 | // to writing to the buffer.
586 | // uint4 and int4 values must be packed to 4bitx2 prior to writing to the buffer; the first element is stored in
587 | // the 4 LSB and the second element is stored in the 4 MSB.
588 | // When this field is present, the data_type field MUST be
589 | // INT32, INT16, INT8, INT4, UINT16, UINT8, UINT4, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
590 | repeated int32 int32_data = 5 [packed = true];
591 |
592 | // For strings.
593 | // Each element of string_data is a UTF-8 encoded Unicode
594 | // string. No trailing null, no leading BOM. The protobuf "string"
595 | // scalar type is not used to match ML community conventions.
596 | // When this field is present, the data_type field MUST be STRING
597 | repeated bytes string_data = 6;
598 |
599 | // For int64.
600 | // When this field is present, the data_type field MUST be INT64
601 | repeated int64 int64_data = 7 [packed = true];
602 |
603 | // Optionally, a name for the tensor.
604 | string name = 8; // namespace Value
605 |
606 | // Human-readable documentation for this tensor. Markdown is allowed.
607 | string doc_string = 12;
608 |
609 | // Serializations can either use one of the fields above, or use this
610 | // raw bytes field. The only exception is the string case, where one is
611 | // required to store the content in the repeated bytes string_data field.
612 | //
613 | // When this raw_data field is used to store the tensor value, elements MUST
614 | // be stored in fixed-width, little-endian order.
615 | // Floating-point data types MUST be stored in IEEE 754 format.
616 | // Complex64 elements must be written as two consecutive FLOAT values, real component first.
617 | // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
618 | // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
619 | // uint4 and int4 values must be packed to 4bitx2; the first element is stored in the 4 LSB and the second element is stored in the 4 MSB.
620 | //
621 | // Note: the advantage of a specific field rather than the raw_data field is
622 | // that in some cases (e.g. int data), protobuf does a better packing via
623 | // variable length storage, and may lead to a smaller binary footprint.
624 | // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
625 | bytes raw_data = 9;
626 |
627 | // Data can be stored inside the protobuf file using type-specific fields or raw_data.
628 | // Alternatively, raw bytes data can be stored in an external file, using the external_data field.
629 | // external_data stores key-value pairs describing data location. Recognized keys are:
630 | // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
631 | // protobuf model was stored
632 | // - "offset" (optional) - position of the byte at which stored data begins. Integer stored as string.
633 | // Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
634 | // - "length" (optional) - number of bytes containing data. Integer stored as string.
635 | // - "checksum" (optional) - SHA1 digest of the file specified under the 'location' key.
636 | repeated StringStringEntryProto external_data = 13;
637 |
638 | // Location of the data for this tensor. MUST be one of:
639 | // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set), otherwise in the type-specific field.
640 | // - EXTERNAL - data stored in an external location as described by the external_data field.
641 | enum DataLocation {
642 | DEFAULT = 0;
643 | EXTERNAL = 1;
644 | }
645 |
646 | // If the value is not set, data is stored in raw_data (if set), otherwise in the type-specific field.
647 | DataLocation data_location = 14;
648 |
649 | // For double
650 | // Complex128 tensors are encoded as a single array of doubles,
651 | // with the real components appearing in odd numbered positions (counting from 1),
652 | // and the corresponding imaginary component appearing in the
653 | // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
654 | // is encoded as [1.0, 2.0, 3.0, 4.0])
655 | // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
656 | repeated double double_data = 10 [packed = true];
657 |
658 | // For uint64 and uint32 values
659 | // When this field is present, the data_type field MUST be
660 | // UINT32 or UINT64
661 | repeated uint64 uint64_data = 11 [packed = true];
662 |
663 | // Named metadata values; keys should be distinct.
664 | repeated StringStringEntryProto metadata_props = 16;
665 | }
666 |
667 | // A serialized sparse-tensor value
668 | message SparseTensorProto {
669 | // The non-default values are encoded as a tensor of shape [NNZ], where NNZ is the number of non-default values.
670 | // The default-value is zero for numeric tensors, and empty-string for string tensors.
671 | // `values` must have a non-empty name, which serves as the name of this SparseTensorProto
672 | // when used in the sparse_initializer list.
673 | TensorProto values = 1;
674 |
675 | // The indices of the non-default values, which may be stored in one of two formats.
676 | // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
677 | // corresponding to the j-th index of the i-th value (in the values tensor).
678 | // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
679 | // must be the linearized-index of the i-th value (in the values tensor).
680 | // The linearized-index can be converted into an index tuple (k_1,...,k_rank)
681 | // using the shape provided below.
682 | // The indices must appear in ascending order without duplication.
683 | // In the first format, the ordering is lexicographic:
684 | // e.g., index-value [1,4] must appear before [2,1]
685 | TensorProto indices = 2;
686 |
687 | // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
688 | repeated int64 dims = 3;
689 | }
690 |
691 | // Defines a tensor shape. A dimension can be either an integer value
692 | // or a symbolic variable. A symbolic variable represents an unknown
693 | // dimension.
694 | message TensorShapeProto {
695 | message Dimension { // one axis: either a known integer size or a symbolic variable
696 | oneof value { // at most one of the following is set
697 | int64 dim_value = 1; // known integer size
698 | string dim_param = 2; // namespace Shape; symbolic variable for an unknown size
699 | };
700 | // Standard denotation can optionally be used to denote tensor
701 | // dimensions with standard semantic descriptions to ensure
702 | // that operations are applied to the correct axis of a tensor.
703 | // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
704 | // for pre-defined dimension denotations.
705 | string denotation = 3;
706 | };
707 | repeated Dimension dim = 1; // one entry per dimension of the tensor
708 | }
709 |
710 | // Types
711 | //
712 | // The standard ONNX data types.
713 | message TypeProto {
714 |
715 | message Tensor {
716 | // This field MUST NOT have the value of UNDEFINED
717 | // This field MUST have a valid TensorProto.DataType value
718 | // This field MUST be present for this version of the IR.
719 | int32 elem_type = 1;
720 | TensorShapeProto shape = 2; // shape of the tensor; see TensorShapeProto
721 | }
722 |
723 | // repeated T: a sequence whose elements all have type elem_type
724 | message Sequence {
725 | // The type and optional shape of each element of the sequence.
726 | // This field MUST be present for this version of the IR.
727 | TypeProto elem_type = 1;
728 | };
729 |
730 | // map<K, V>: K is given by key_type and V by value_type
731 | message Map {
732 | // This field MUST have a valid TensorProto.DataType value
733 | // This field MUST be present for this version of the IR.
734 | // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
735 | int32 key_type = 1;
736 | // This field MUST be present for this version of the IR.
737 | TypeProto value_type = 2;
738 | };
739 |
740 | // wrapper for Tensor, Sequence, or Map
741 | message Optional {
742 | // The type and optional shape of the element wrapped.
743 | // This field MUST be present for this version of the IR.
744 | // Possible values correspond to OptionalProto.DataType enum
745 | TypeProto elem_type = 1;
746 | };
747 |
748 |
749 | message SparseTensor { // sparse counterpart of Tensor; same fields and constraints
750 | // This field MUST NOT have the value of UNDEFINED
751 | // This field MUST have a valid TensorProto.DataType value
752 | // This field MUST be present for this version of the IR.
753 | int32 elem_type = 1;
754 | TensorShapeProto shape = 2; // shape of the sparse tensor; see TensorShapeProto
755 | }
756 |
757 |
758 | oneof value {
759 | // The type of a tensor.
760 | Tensor tensor_type = 1;
761 |
762 | // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
763 | // as input and output to graphs and nodes. These types are needed to naturally
764 | // support classical ML operators. DNN operators SHOULD restrict their input
765 | // and output types to tensors.
766 |
767 | // The type of a sequence.
768 | Sequence sequence_type = 4;
769 |
770 | // The type of a map.
771 | Map map_type = 5;
772 |
773 | // The type of an optional.
774 | Optional optional_type = 9;
775 |
776 |
777 | // The type of a sparse tensor.
778 | SparseTensor sparse_tensor_type = 8;
779 |
780 | }
781 |
782 | // An optional denotation can be used to denote the whole
783 | // type with a standard semantic description as to what is
784 | // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
785 | // for pre-defined type denotations.
786 | string denotation = 6;
787 | }
788 |
789 | // Operator Sets
790 | //
791 | // OperatorSets are uniquely identified by a (domain, opset_version) pair.
792 | message OperatorSetIdProto {
793 | // The domain of the operator set being identified.
794 | // The empty string ("") or absence of this field implies the operator
795 | // set that is defined as part of the ONNX specification.
796 | // This field MUST be present in this version of the IR when referring to any other operator set.
797 | string domain = 1;
798 |
799 | // The version of the operator set being identified.
800 | // This field MUST be present in this version of the IR.
801 | int64 version = 2; // the opset_version half of the (domain, opset_version) identity
802 | }
803 |
804 | // Operator/function status.
805 | enum OperatorStatus { // NOTE(review): FunctionProto.status (type OperatorStatus) is reserved/deprecated since IR v8; other uses may exist earlier in the file
806 | EXPERIMENTAL = 0; // proto3 zero value, i.e. the default when unset
807 | STABLE = 1;
808 | }
809 |
810 | message FunctionProto {
811 | // The name of the function, similar to op_type in NodeProto.
812 | // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
813 | string name = 1;
814 |
815 | // Deprecated since IR Version 8 (number and name are reserved below so they are never reused)
816 | // optional int64 since_version = 2;
817 | reserved 2;
818 | reserved "since_version";
819 |
820 | // Deprecated since IR Version 8 (number and name are reserved below so they are never reused)
821 | // optional OperatorStatus status = 3;
822 | reserved 3;
823 | reserved "status";
824 |
825 | // The inputs and outputs of the function.
826 | repeated string input = 4;
827 | repeated string output = 5;
828 |
829 | // The attribute parameters of the function.
830 | // Used for function attributes without default values.
831 | repeated string attribute = 6;
832 |
833 | // The attribute protos of the function.
834 | // Used for function attributes with default values.
835 | // A function attribute shall be represented either as
836 | // a string attribute or an AttributeProto, not both.
837 | repeated AttributeProto attribute_proto = 11;
838 |
839 | // The nodes in the function.
840 | repeated NodeProto node = 7;
841 | // Human-readable documentation for this function. Markdown is allowed.
842 | string doc_string = 8;
843 |
844 | // The OperatorSets this function body (graph) relies on.
845 | //
846 | // All nodes in the function body (graph) will bind against the operator
847 | // with the same domain and op_type that has the HIGHEST version
848 | // in the referenced operator sets. This means at most one version can be relied on
849 | // for each domain.
850 | //
851 | // The operator sets imported by FunctionProto should be compatible with the ones
852 | // imported by ModelProto. For example, if the same operator set, say 'A', is imported by FunctionProto
853 | // and ModelProto, then the versions for the operator set may differ, but
854 | // the operator schema returned for the op_type, domain, version combination
855 | // for both versions should be the same.
856 |
857 | repeated OperatorSetIdProto opset_import = 9;
858 |
859 | // The domain which this function belongs to.
860 | // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
861 | string domain = 10;
862 |
863 | // The overload identifier of the function.
864 | // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
865 | string overload = 13;
866 |
867 | // Information for the values in the function. The ValueInfoProto names
868 | // must be distinct and refer to names in the function (including inputs,
869 | // outputs, and intermediate values). It is optional for a value to appear
870 | // in the value_info list.
871 | repeated ValueInfoProto value_info = 12;
872 |
873 | // Named metadata values; keys should be distinct.
874 | repeated StringStringEntryProto metadata_props = 14;
875 | }
876 |
877 | // Request code generation for the protobuf-lite runtime. NOTE(review): not every protoc code generator honors this option (the C# generator may ignore it) — confirm.
878 | option optimize_for = LITE_RUNTIME;
879 |
880 |
--------------------------------------------------------------------------------