├── .gitignore
├── ComLightLib
├── ComLightLib.vcxproj
├── ComLightLib.vcxproj.filters
├── Exception.hpp
├── Readme.txt
├── client
│ └── CComPtr.hpp
├── comLightClient.h
├── comLightCommon.h
├── comLightServer.h
├── hresult.h
├── pal
│ ├── guiddef.h
│ └── hresult.h
├── server
│ ├── Object.hpp
│ ├── ObjectRoot.hpp
│ ├── RefCounter.hpp
│ ├── freeThreadedMarshaller.cpp
│ ├── freeThreadedMarshaller.h
│ └── interfaceMap.h
├── streams.h
├── unknwn.h
└── utils
│ ├── guid_parse.hpp
│ └── typeTraits.hpp
├── ComputeShaders
├── ComputeShaders.cpp
├── ComputeShaders.vcxproj
├── ComputeShaders.vcxproj.filters
├── Readme.txt
├── add.hlsl
├── addInPlace.hlsl
├── addRepeat.hlsl
├── addRepeat64.hlsl
├── addRepeatEx.hlsl
├── addRepeatGelu.hlsl
├── addRepeatGelu64.hlsl
├── addRepeatScale.hlsl
├── addRows.hlsl
├── componentwiseBinaryOp.hlsli
├── convolutionMain.hlsl
├── convolutionMain2.hlsl
├── convolutionMain2Fixed.hlsl
├── convolutionPrep1.hlsl
├── convolutionPrep2.hlsl
├── copyConvert.hlsl
├── copyTranspose.hlsl
├── dbgFindNaN.hlsl
├── diagMaskInf.hlsl
├── flashAttention.hlsl
├── flashAttentionCommon.hlsli
├── flashAttentionCompat1.hlsl
├── flashAttentionCompat2.hlsl
├── flashAttentionCompat3.hlsl
├── fmaRepeat1.hlsl
├── fmaRepeat164.hlsl
├── fmaRepeat2.hlsl
├── fp64Utils.hlsli
├── groupReduce.hlsli
├── groupReduce64.hlsli
├── matReshapePanels.hlsl
├── miscUtils.hlsli
├── mulMatByRow.hlsl
├── mulMatByRow64.hlsl
├── mulMatByRowTiled.hlsl
├── mulMatByRowTiledEx.hlsl
├── mulMatByScalar.hlsl
├── mulMatDotMain.hlsl
├── mulMatDotReshape.hlsl
├── mulMatMadMain.hlsl
├── mulMatTiled.hlsl
├── mulMatTiledEx.hlsl
├── norm.hlsl
├── normCompat.hlsl
├── normFixed.hlsl
├── normFixed64.hlsl
├── repeatUtils.hlsli
├── scaleInPlace.hlsl
├── softMax.hlsl
├── softMax64.hlsl
├── softMaxCompat.hlsl
├── softMaxFixed.hlsl
├── softMaxLong.hlsl
└── zeroMemory.hlsl
├── Examples
├── MicrophoneCS
│ ├── CaptureThread.cs
│ ├── CommandLineArgs.cs
│ ├── MicrophoneCS.cs
│ ├── MicrophoneCS.csproj
│ ├── Readme.txt
│ └── TranscribeCallbacks.cs
├── OldMain
│ ├── OldMain.vcxproj
│ ├── OldMain.vcxproj.filters
│ ├── Readme.txt
│ ├── Utils
│ │ ├── Logger.cpp
│ │ └── Logger.h
│ ├── dr_wav.h
│ └── main.cpp
├── TranscribeCS
│ ├── AnsiCodes.cs
│ ├── CommandLineArgs.cs
│ ├── Readme.txt
│ ├── Transcribe.cs
│ ├── TranscribeCS.cs
│ └── TranscribeCS.csproj
├── WhisperDesktop
│ ├── AppState.cpp
│ ├── AppState.h
│ ├── CaptureDlg.cpp
│ ├── CaptureDlg.h
│ ├── CircleIndicator.cpp
│ ├── CircleIndicator.h
│ ├── LoadModelDlg.cpp
│ ├── LoadModelDlg.h
│ ├── ModelAdvancedDlg.cpp
│ ├── ModelAdvancedDlg.h
│ ├── Readme.txt
│ ├── Resource.h
│ ├── TranscribeDlg.cpp
│ ├── TranscribeDlg.h
│ ├── Utils
│ │ ├── DebugConsole.cpp
│ │ ├── DebugConsole.h
│ │ ├── LanguageDropdown.cpp
│ │ ├── LanguageDropdown.h
│ │ ├── PendingState.cpp
│ │ ├── PendingState.h
│ │ ├── TranslateCheckbox.cpp
│ │ ├── TranslateCheckbox.h
│ │ ├── WTL
│ │ │ ├── MS-PL.txt
│ │ │ ├── ReadMe.html
│ │ │ ├── atlapp.h
│ │ │ ├── atlcrack.h
│ │ │ ├── atlctrls.h
│ │ │ ├── atlddx.h
│ │ │ ├── atlgdi.h
│ │ │ ├── atlres.h
│ │ │ ├── atluser.h
│ │ │ └── atlwinx.h
│ │ ├── logger.cpp
│ │ ├── logger.h
│ │ ├── miscUtils.cpp
│ │ └── miscUtils.h
│ ├── WhisperDesktop.cpp
│ ├── WhisperDesktop.manifest
│ ├── WhisperDesktop.rc
│ ├── WhisperDesktop.vcxproj
│ ├── WhisperDesktop.vcxproj.filters
│ ├── framework.h
│ ├── stdafx.cpp
│ ├── stdafx.h
│ ├── sunflower.ico
│ ├── targetver.h
│ └── useDiscreteGpu.c
└── main
│ ├── Readme.txt
│ ├── main.cpp
│ ├── main.vcxproj
│ ├── main.vcxproj.filters
│ ├── miscUtils.cpp
│ ├── miscUtils.h
│ ├── params.cpp
│ ├── params.h
│ ├── textWriter.cpp
│ └── textWriter.h
├── LICENSE
├── Readme.md
├── SampleClips
├── Readme.txt
├── columbia-large-1080ti.txt
├── columbia-large-1650.txt
├── columbia-large-vega7.txt
├── columbia-large-vega8.txt
├── columbia-medium-1080ti.txt
├── columbia-medium-1650.txt
├── columbia-medium-vega7.txt
├── columbia-medium-vega8.txt
├── columbia.wma
├── jfk-large-1080ti.txt
├── jfk-large-1650.txt
├── jfk-large-vega7.txt
├── jfk-large-vega8.txt
├── jfk-medium-1080ti.txt
├── jfk-medium-1650.txt
├── jfk-medium-vega7.txt
├── jfk-medium-vega8.txt
├── jfk.wav
└── summary.tsv
├── Tools
├── CompressShaders
│ ├── Cabinet.cs
│ ├── CompressShaders.cs
│ ├── CompressShaders.csproj
│ ├── DetectFp64.cs
│ ├── LZ4.cs
│ ├── LanguageCodes.cs
│ ├── Readme.txt
│ └── ShaderNames.cs
├── CompressTables
│ ├── CompressTables.cs
│ └── CompressTables.csproj
├── PerfSummary
│ ├── LogParser.cs
│ ├── PerfSummary.cs
│ ├── PerfSummary.csproj
│ └── Summary.cs
├── compareTraces
│ ├── CommandLineArgs.cpp
│ ├── CommandLineArgs.h
│ ├── Readme.txt
│ ├── TraceReader.cpp
│ ├── TraceReader.h
│ ├── compare.cpp
│ ├── compare.h
│ ├── compareTraces.cpp
│ ├── compareTraces.vcxproj
│ ├── compareTraces.vcxproj.filters
│ ├── stdafx.cpp
│ ├── stdafx.h
│ └── testUtils.cpp
└── copy-binaries.cmd
├── Whisper
├── API
│ ├── MfStructs.h
│ ├── Readme.txt
│ ├── SpecialTokens.h
│ ├── TranscribeStructs.h
│ ├── iContext.cl.h
│ ├── iContext.h
│ ├── iMediaFoundation.cl.h
│ ├── iMediaFoundation.h
│ ├── iTranscribeResult.cl.h
│ ├── iTranscribeResult.h
│ ├── loggerApi.h
│ ├── sFullParams.h
│ ├── sLanguageList.h
│ ├── sLoadModelCallbacks.h
│ ├── sModelSetup.h
│ ├── whisperComLight.h
│ └── whisperWindows.h
├── CPU
│ ├── BufferAllocator.cpp
│ ├── BufferAllocator.h
│ ├── DecoderTensors.cpp
│ ├── DecoderTensors.h
│ ├── HybridLoader.cpp
│ ├── HybridLoader.h
│ ├── KvTensors.h
│ ├── KvTensorsCpu.cpp
│ ├── LargeBuffer.cpp
│ ├── LargeBuffer.h
│ ├── MlContext.h
│ ├── MlContextCpu.cpp
│ ├── ParallelForRunner.cpp
│ ├── ParallelForRunner.h
│ ├── Readme.txt
│ ├── Tensor.h
│ ├── TensorCpu.cpp
│ ├── mulMat.cpp
│ ├── mulMat.h
│ ├── mulMat.kernel.hpp
│ ├── mulMatImpl.avx2.cpp
│ ├── mulMatImpl.cpp
│ ├── mulMatImpl.h
│ ├── mulMatImpl.panel.cpp
│ ├── mulMatUtils.hpp
│ ├── simdUtils.cpp
│ └── simdUtils.h
├── D3D
│ ├── Binder.cpp
│ ├── Binder.h
│ ├── MappedResource.cpp
│ ├── MappedResource.h
│ ├── RenderDoc
│ │ ├── renderDoc.cpp
│ │ ├── renderDoc.h
│ │ └── renderdoc_app.h
│ ├── createBuffer.cpp
│ ├── createBuffer.h
│ ├── createDevice.cpp
│ ├── createDevice.h
│ ├── device.h
│ ├── downloadBuffer.cpp
│ ├── downloadBuffer.h
│ ├── enums.cpp
│ ├── enums.h
│ ├── listGPUs.cpp
│ ├── listGPUs.h
│ ├── sGpuInfo.h
│ ├── shaderNames.cpp
│ ├── shaderNames.h
│ ├── shaders.cpp
│ └── shaders.h
├── DllMain.cpp
├── Hybrid
│ ├── HybridContext.cpp
│ ├── HybridContext.h
│ ├── KeyValueDownloader.cpp
│ ├── KeyValueDownloader.h
│ └── Readme.txt
├── MF
│ ├── AudioBuffer.cpp
│ ├── AudioBuffer.h
│ ├── AudioCapture.cpp
│ ├── AudioCapture.h
│ ├── MediaFoundation.cpp
│ ├── PcmReader.cpp
│ ├── PcmReader.h
│ ├── loadAudioFile.cpp
│ ├── loadAudioFile.h
│ ├── mfStartup.cpp
│ ├── mfStartup.h
│ ├── mfUtils.cpp
│ └── mfUtils.h
├── ML
│ ├── ConstantBuffer.cpp
│ ├── ConstantBuffer.h
│ ├── Context.ops.cpp
│ ├── DbgNanTest.cpp
│ ├── DbgNanTest.h
│ ├── Device.cpp
│ ├── Device.h
│ ├── LookupTables.cpp
│ ├── LookupTables.h
│ ├── LookupTablesData.cpp
│ ├── LookupTablesData.h
│ ├── LookupTablesData.inl
│ ├── MlContext.cpp
│ ├── MlContext.dbg.cpp
│ ├── MlContext.h
│ ├── Reshaper.cpp
│ ├── Reshaper.h
│ ├── TempBuffers.cpp
│ ├── TempBuffers.h
│ ├── Tensor.cpp
│ ├── Tensor.h
│ ├── TensorEx.cpp
│ ├── TensorEx.h
│ ├── TensorGpuViews.cpp
│ ├── TensorGpuViews.h
│ ├── TensorShape.cpp
│ ├── TensorShape.h
│ ├── TensorsArena.cpp
│ ├── TensorsArena.h
│ ├── mlUtils.cpp
│ ├── mlUtils.h
│ ├── reshapedMultiply.h
│ ├── tensorOpsTests.cpp
│ ├── tensorOpsTests.h
│ ├── testUtils.cpp
│ ├── testUtils.h
│ └── testUtilsC.h
├── Readme.txt
├── Resource.rc
├── Utils
│ ├── CpuProfiler.cpp
│ ├── CpuProfiler.h
│ ├── DelayExecution.cpp
│ ├── DelayExecution.h
│ ├── GpuProfiler.cpp
│ ├── GpuProfiler.h
│ ├── GpuProfilerSimple.h
│ ├── LZ4
│ │ ├── LICENSE
│ │ ├── lz4.c
│ │ └── lz4.h
│ ├── Logger.cpp
│ ├── Logger.h
│ ├── MurmurHash3.cpp
│ ├── MurmurHash3.h
│ ├── ProfileCollection.cpp
│ ├── ProfileCollection.h
│ ├── ReadStream.h
│ ├── Trace
│ │ ├── TraceStructures.cpp
│ │ ├── TraceStructures.h
│ │ ├── TraceWriter.cpp
│ │ ├── TraceWriter.h
│ │ ├── tracing.cpp
│ │ └── tracing.h
│ ├── miscUtils.cpp
│ ├── miscUtils.h
│ ├── parallelFor.cpp
│ └── parallelFor.h
├── Whisper.vcxproj
├── Whisper.vcxproj.filters
├── Whisper
│ ├── ContextImpl.capture.cpp
│ ├── ContextImpl.cpp
│ ├── ContextImpl.diarize.cpp
│ ├── ContextImpl.h
│ ├── ContextImpl.misc.cpp
│ ├── DecoderInputBuffers.cpp
│ ├── DecoderInputBuffers.h
│ ├── DecoderResultBuffer.cpp
│ ├── DecoderResultBuffer.h
│ ├── KeyValueBuffers.cpp
│ ├── KeyValueBuffers.h
│ ├── Languages.cpp
│ ├── Languages.h
│ ├── MelInputTensor.cpp
│ ├── MelInputTensor.h
│ ├── MelStreamer.cpp
│ ├── MelStreamer.h
│ ├── ModelBuffers.clone.cpp
│ ├── ModelBuffers.cpp
│ ├── ModelBuffers.h
│ ├── ModelImpl.cpp
│ ├── ModelImpl.h
│ ├── ModelLoader.h
│ ├── Spectrogram.cpp
│ ├── Spectrogram.h
│ ├── TranscribeResult.h
│ ├── Vocabulary.cpp
│ ├── Vocabulary.h
│ ├── WhisperContext.cpp
│ ├── WhisperContext.h
│ ├── WhisperModel.cpp
│ ├── WhisperModel.h
│ ├── audioConstants.h
│ ├── iSpectrogram.h
│ ├── languageCodez.inl
│ ├── languageCodez.tsv
│ ├── loaderUtils.h
│ ├── melSpectrogram.cpp
│ ├── melSpectrogram.h
│ ├── sEncodeParams.h
│ ├── sModelParams.h
│ ├── sTokenData.h
│ ├── voiceActivityDetection.cpp
│ └── voiceActivityDetection.h
├── misc.natvis
├── modelFactory.cpp
├── modelFactory.h
├── resource.h
├── source.compat
│ ├── Readme.txt
│ ├── convertThings.cpp
│ ├── convertThings.h
│ └── ggmlMsvc.c
├── source
│ ├── LICENSE
│ ├── Readme.txt
│ ├── ggml.c
│ ├── ggml.h
│ ├── whisper.cpp
│ └── whisper.h
├── stdafx.cpp
├── stdafx.h
├── whisper.def
└── whisperCom.cpp
├── WhisperCpp.sln
├── WhisperNet
├── API
│ ├── CaptureDeviceId.cs
│ ├── Parameters.cs
│ ├── SpecialTokens.cs
│ ├── eCaptureStatus.cs
│ ├── eGpuModelFlags.cs
│ ├── eLanguage.cs
│ ├── eLogLevel.cs
│ ├── eModelImplementation.cs
│ ├── eResultFlags.cs
│ ├── eSpeakerChannel.cs
│ ├── iAudioBuffer.cs
│ ├── iAudioReader.cs
│ ├── iMediaFoundation.cs
│ ├── iModel.cs
│ └── sCaptureParams.cs
├── AssemblyInfo.cs
├── AssemblyTitle.cs
├── Callbacks.cs
├── CaptureCallbacks.cs
├── Context.cs
├── ExtensionMethods.cs
├── Internal
│ ├── NativeLogger.cs
│ ├── iContext.cs
│ ├── iTranscribeResult.cs
│ ├── sCaptureCallbacks.cs
│ ├── sCaptureDevice.cs
│ ├── sFullParams.cs
│ ├── sLoadModelCallbacks.cs
│ ├── sLoggerSetup.cs
│ ├── sModelSetup.cs
│ └── sProgressSink.cs
├── Library.cs
├── Readme.md
├── WhisperNet.csproj
└── WhisperNet.nuspec
├── WhisperPS
├── Commands
│ ├── ExportBase.cs
│ ├── ExportSubrip.cs
│ ├── ExportText.cs
│ ├── ExportWebVtt.cs
│ ├── FormatSegments.cs
│ ├── ListAdapters.cs
│ ├── LoadModel.cs
│ ├── TranscribeBase.cs
│ └── TranscribeFile.cs
├── Internal
│ ├── MarshalEx.cs
│ ├── NativeLogger.cs
│ ├── iTranscribeResult.cs
│ ├── sCaptureDevice.cs
│ ├── sFullParams.cs
│ ├── sLoadModelCallbacks.cs
│ ├── sModelSetup.cs
│ └── sProgressSink.cs
├── Library.cs
├── Properties
│ └── AssemblyTitle.cs
├── Readme.md
├── Types
│ ├── Model.cs
│ ├── Segment.cs
│ └── Transcription.cs
├── Utils
│ ├── CommandLogger.cs
│ └── MiscUtils.cs
├── WhisperPS.csproj
├── WhisperPS.psd1
├── app.config
└── packages.config
├── gui-capture.png
├── gui-load-model.png
└── gui-transcribe.png
/.gitignore:
--------------------------------------------------------------------------------
1 | .vs/
2 | ComLightLib/x64/
3 | Whisper/x64/
4 | x64/
5 | Tools/CompressShaders/bin/
6 | Tools/CompressShaders/obj/
7 | Whisper/D3D/shaderData-Debug.inl
8 | Whisper/D3D/shaderData-Release.inl
9 | WhisperNet/bin/
10 | WhisperNet/obj/
11 | Examples/TranscribeCS/bin/
12 | Examples/TranscribeCS/obj/
13 | *.aps
14 | *.json
15 | *.user
16 | Examples/MicrophoneCS/obj/
17 | Examples/MicrophoneCS/bin/
18 | Tools/PerfSummary/bin/
19 | Tools/PerfSummary/obj/
20 | packages/
21 | WhisperPS/obj/
22 | WhisperPS/bin/
23 | Tools/CompressTables/bin/
24 | Tools/CompressTables/obj/
--------------------------------------------------------------------------------
/ComLightLib/ComLightLib.vcxproj.filters:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/ComLightLib/Exception.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace ComLight
4 | {
5 | class Exception : public std::runtime_error
6 | {
7 |     // Exception which carries an HRESULT status code. C++ exceptions are used sparingly in this library, yet they are handy in a few places.
8 |     // It is fine to throw ComLight::Exception from a constructor or from the FinalConstruct() method: the library catches it and returns the code from the class factory function.
9 |     // That does not work for interface methods: the C++ parts of the library can't catch exceptions there without very complex trickery like code generation.
10 |     // The class is still usable inside methods, but then you have to catch it manually near the API boundary, or the app will crash.
11 |     // C++ doesn't have an ABI, so the framework can't catch C++ exceptions across modules.
12 |     const HRESULT m_status;
13 |
14 | public:
15 |
16 |     Exception( HRESULT hr ) : std::runtime_error( "ComLight HRESULT exception" ), m_status( hr ) { }
17 |
18 |     HRESULT code() const { return m_status; } // The status code passed to the constructor
19 | };
20 | }
--------------------------------------------------------------------------------
/ComLightLib/Readme.txt:
--------------------------------------------------------------------------------
1 | Copied from here:
2 | https://github.com/Const-me/ComLightInterop/tree/master/ComLightLib
3 | Only a few minor changes were made.
--------------------------------------------------------------------------------
/ComLightLib/comLightClient.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "comLightCommon.h"
3 | #include "client/CComPtr.hpp"
4 | #include "utils/typeTraits.hpp"
5 |
6 | namespace ComLight
7 | {
8 | namespace details
9 | {
10 | template // NOTE(review): the template parameter list is missing from this listing; the signature below uses a type named T
11 | inline constexpr void** castDoublePointerToVoid( T** pp ) // Converts a typed double pointer into void**; the IID_PPV_ARGS macro defined in this header expands to a call of this function
12 | {
13 | static_assert( pointersAssignable(), "IID_PPV_ARGS macro should be used with IUnknown interfaces" ); // Compile-time guard; the message tells what is expected from the argument type
14 | return reinterpret_cast( pp ); // NOTE(review): the target type of the cast is missing from this listing, presumably void** to match the return type
15 | }
16 | }
17 | }
18 |
19 | #ifdef IID_PPV_ARGS
20 | #undef IID_PPV_ARGS
21 | #endif
22 |
23 | #define IID_PPV_ARGS( pp ) decltype( **pp )::iid, ::ComLight::details::castDoublePointerToVoid( pp )
--------------------------------------------------------------------------------
/ComLightLib/comLightCommon.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "hresult.h"
3 |
4 | #ifdef _MSC_VER
5 | #include
6 | #else
7 | #include "pal/guiddef.h"
8 | using LPCTSTR = const char*;
9 | #endif
10 |
11 | #include "unknwn.h"
--------------------------------------------------------------------------------
/ComLightLib/comLightServer.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "comLightCommon.h"
3 | #include "client/CComPtr.hpp"
4 |
5 | #include "server/ObjectRoot.hpp"
6 | #include "server/interfaceMap.h"
7 | #include "server/Object.hpp"
8 | #include "server/freeThreadedMarshaller.h"
9 |
10 | #ifdef _MSC_VER
11 | // On Windows, exports are controlled by the library.def module definition file. __declspec(dllexport) would also work, but it adds an underscore to the exported names, and I don't like that.
12 | #define DLLEXPORT extern "C"
13 | #else
14 | #define DLLEXPORT extern "C" __attribute__((visibility("default")))
15 | #endif
--------------------------------------------------------------------------------
/ComLightLib/hresult.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #ifdef _MSC_VER
4 | #include
5 | #include
6 | #else
7 | #include "pal/hresult.h"
8 | #endif
9 |
10 | #define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; }
11 |
12 | #ifndef _MSC_VER
13 | inline constexpr HRESULT HRESULT_FROM_WIN32( int c )
14 | {
15 | return c < 0 ? c : ( ( 0xFFFF & c ) | 0x80070000 );
16 | }
17 |
18 | constexpr HRESULT OLE_E_BLANK = _HRESULT_TYPEDEF_( 0x80040007 );
19 | constexpr HRESULT E_BOUNDS = _HRESULT_TYPEDEF_( 0x8000000BL );
20 |
21 | constexpr int ERROR_HANDLE_EOF = 38;
22 | constexpr int ERROR_ALREADY_INITIALIZED = 1247;
23 | #endif
24 |
25 | constexpr HRESULT E_EOF = HRESULT_FROM_WIN32( ERROR_HANDLE_EOF );
26 | constexpr HRESULT E_ALREADY_INITIALIZED = HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED );
--------------------------------------------------------------------------------
/ComLightLib/pal/guiddef.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 | #ifndef GUID_DEFINED
5 | #define GUID_DEFINED
6 | #endif
7 |
8 | struct GUID // GUID structure for builds without MSVC: comLightCommon.h includes this header only in the non-_MSC_VER branch
9 | {
10 | uint32_t Data1;
11 | uint16_t Data2;
12 | uint16_t Data3;
13 | std::array Data4; // NOTE(review): the element type and size of the array are missing from this listing; presumably 8 bytes, as in the Windows GUID - confirm
14 |
15 | constexpr inline bool operator==( const GUID& that ) const // Two GUIDs are equal when all 4 fields are equal
16 | {
17 | return Data1 == that.Data1 && Data2 == that.Data2 && Data3 == that.Data3 && Data4 == that.Data4;
18 | }
19 | };
20 |
21 | using REFIID = const GUID&;
--------------------------------------------------------------------------------
/ComLightLib/server/ObjectRoot.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "RefCounter.hpp"
3 | #include "../comLightCommon.h"
4 | #include "../utils/typeTraits.hpp"
5 |
6 | namespace ComLight
7 | {
8 | // Base class of objects, implements reference counting, also a few lifetime methods.
9 | // The template argument is the interface you want clients to get when they ask for IID_IUnknown. By convention, that pointer defines object's identity.
10 | template // NOTE(review): the template parameter list is missing from this listing; the class body refers to the parameter as I
11 | class ObjectRoot : public RefCounter, public I
12 | {
13 | protected:
14 |
15 | inline HRESULT internalFinalConstruct() // Default implementation, returns S_FALSE. The DECLARE_FREE_THREADED_MARSHALLER() macro declares a replacement with the same name.
16 | {
17 | return S_FALSE;
18 | }
19 |
20 | inline HRESULT FinalConstruct() // Default implementation, returns S_FALSE
21 | {
22 | return S_FALSE;
23 | }
24 |
25 | inline void FinalRelease() { } // Default implementation, does nothing
26 |
27 | IUnknown* getUnknown() // Returns this object converted to an IUnknown pointer
28 | {
29 | static_assert( details::pointersAssignable(), "The interface doesn't derive from IUnknown" );
30 | return static_cast( this ); // NOTE(review): the target type of the cast is missing from this listing
31 | }
32 |
33 | bool queryExtraInterfaces( REFIID riid, void **ppvObject ) const // Default implementation: no extra interfaces, always returns false
34 | {
35 | return false;
36 | }
37 |
38 | // Implement query interface with 2 entries, IUnknown and I.
39 | bool implQueryInterface( REFIID riid, void** ppvObject ) // Returns true when riid is one of the two supported IDs, false otherwise; *ppvObject is only written on success
40 | {
41 | if( riid == I::iid() || riid == IUnknown::iid() )
42 | {
43 | I* const result = this; // Both IDs resolve to the same I* pointer
44 | result->AddRef(); // AddRef() is called before the pointer is handed to the caller
45 | *ppvObject = result;
46 | return true;
47 | }
48 | return false;
49 | }
50 | };
51 | }
--------------------------------------------------------------------------------
/ComLightLib/server/RefCounter.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 | #include
5 |
6 | namespace ComLight
7 | {
8 | // Very base class of objects, implements reference counting.
9 | class RefCounter
10 | {
11 |     std::atomic_uint referenceCounter; // Starts at zero, see the constructor
12 |
13 | public:
14 |
15 |     RefCounter() : referenceCounter( 0u ) { }
16 |
17 |     virtual ~RefCounter() { }
18 |
19 |     RefCounter( const RefCounter& ) = delete;
20 |     RefCounter( RefCounter&& ) = delete;
21 |
22 | protected:
23 |
24 |     uint32_t implAddRef() // Increments the counter, returns the new value
25 |     {
26 |         return referenceCounter.fetch_add( 1u ) + 1u;
27 |     }
28 |
29 |     uint32_t implRelease() // Decrements the counter, returns the new value
30 |     {
31 |         // Locks might be a better choice, at least in debug builds. They are much slower than atomics, but they make it possible to detect 2 threads releasing an object with counter = 1 at the same time.
32 |         // That is a memory management bug, but it would be nice for debug builds to handle the case gracefully.
33 |         const uint32_t remaining = referenceCounter.fetch_sub( 1u ) - 1u;
34 |         assert( remaining != UINT_MAX ); // Fires when the counter was already zero and wrapped around
35 |         return remaining;
36 |     }
37 | };
38 | }
--------------------------------------------------------------------------------
/ComLightLib/server/freeThreadedMarshaller.cpp:
--------------------------------------------------------------------------------
1 | #include "freeThreadedMarshaller.h"
2 | #ifdef _MSC_VER
3 | #include
4 |
5 | HRESULT ComLight::details::createFreeThreadedMarshaller( IUnknown* pUnkOuter, IUnknown** ppUnkMarshal )
6 | {   // Forwards the call to CoCreateFreeThreadedMarshaler, casting both arguments to the pointer types that API expects
7 |     return ::CoCreateFreeThreadedMarshaler( reinterpret_cast<LPUNKNOWN>( pUnkOuter ), reinterpret_cast<LPUNKNOWN*>( ppUnkMarshal ) );
8 | }
9 |
10 | bool ComLight::details::queryMarshallerInterface( REFIID riid, void **ppvObject, IUnknown* marshaller )
11 | {
12 |     // Only requests for IMarshal are served here, and only when a marshaller object exists
13 |     if( riid != IID_IMarshal ) return false;
14 |     if( nullptr == marshaller ) return false;
15 |     return SUCCEEDED( marshaller->QueryInterface( IID_IMarshal, ppvObject ) ); // true when the marshaller returned the interface
16 | }
17 | #endif
--------------------------------------------------------------------------------
/ComLightLib/server/freeThreadedMarshaller.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #ifdef _MSC_VER
3 | #include "../comLightCommon.h"
4 |
5 | namespace ComLight
6 | {
7 | namespace details
8 | {
9 | HRESULT createFreeThreadedMarshaller( IUnknown* pUnkOuter, IUnknown** ppUnkMarshal );
10 | bool queryMarshallerInterface( REFIID riid, void **ppvObject, IUnknown* marshaller );
11 | }
12 | }
13 |
14 | #define DECLARE_FREE_THREADED_MARSHALLER() \
15 | private: \
16 | ComLight::CComPtr m_freeThreadedMarshaller; \
17 | protected: \
18 | HRESULT internalFinalConstruct() \
19 | { \
20 | return ComLight::details::createFreeThreadedMarshaller( getUnknown(), &m_freeThreadedMarshaller ); \
21 | } \
22 | bool queryExtraInterfaces( REFIID riid, void **ppvObject ) const \
23 | { \
24 | return ComLight::details::queryMarshallerInterface( riid, ppvObject, m_freeThreadedMarshaller ); \
25 | }
26 |
27 | #else
28 | #define DECLARE_FREE_THREADED_MARSHALLER()
29 | #endif
--------------------------------------------------------------------------------
/ComLightLib/server/interfaceMap.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../utils/typeTraits.hpp"
3 |
4 | // Unlike ATL, the interface map is optional for ComLight.
5 | // If you don't declare a map, the object will support 2 interfaces: IUnknown, and whichever interface was passed as the template argument to the ObjectRoot class.
6 | #define BEGIN_COM_MAP() \
7 | protected: \
8 | bool implQueryInterface( REFIID iid, void** ppvObject ) {
9 |
10 | #define END_COM_MAP() return false; }
11 |
12 | namespace ComLight
13 | {
14 | namespace details
15 | {
16 | template // NOTE(review): the template parameter list is missing from this listing; the body uses the type names I and C
17 | inline bool tryReturnInterface( REFIID iid, C* pThis, void** ppvResult ) // Called by the COM_INTERFACE_ENTRY macro; returns true and stores the interface pointer when iid equals I::iid()
18 | {
19 | static_assert( pointersAssignable(), "Trying to implement an interface that doesn't derive from IUnknown" );
20 | static_assert( pointersAssignable(), "Declared support for an interface, but the class doesn't implement it" );
21 | if( I::iid() != iid )
22 | return false; // A different interface was requested, *ppvResult is left untouched
23 | I* const result = pThis; // Implicit conversion from the object pointer to the interface pointer
24 | result->AddRef(); // AddRef() is called before the pointer is handed to the caller
25 | *ppvResult = result;
26 | return true;
27 | }
28 | }
29 | }
30 |
31 | #define COM_INTERFACE_ENTRY( I ) if( ComLight::details::tryReturnInterface( iid, this, ppvObject ) ) return true;
--------------------------------------------------------------------------------
/ComLightLib/unknwn.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | // Calling conventions
5 | #ifdef _MSC_VER
6 | #define COMLIGHTCALL __stdcall
7 | #define DECLSPEC_NOVTABLE __declspec(novtable)
8 | #elif defined(__GNUC__) || defined(__clang__)
9 | #if defined(__i386__)
10 | #define COMLIGHTCALL __attribute__((stdcall))
11 | #else
12 | #define COMLIGHTCALL
13 | #endif
14 | #define DECLSPEC_NOVTABLE
15 | #else
16 | #error Unsupported C++ compiler
17 | #endif
18 |
19 | #include "utils/guid_parse.hpp"
20 |
21 | #define DEFINE_INTERFACE_ID( guidString ) static constexpr GUID iid() { return ::ComLight::make_guid( guidString ); }
22 |
23 | namespace ComLight
24 | {
25 | // This thing is binary compatible with IUnknown from Windows SDK. See DesktopClient demo project, it uses normal COM interop in .NET framework 4.7 to call my implementation.
26 | struct DECLSPEC_NOVTABLE IUnknown // Interface with three pure virtual methods and a static interface ID
27 | {
28 | DEFINE_INTERFACE_ID( "00000000-0000-0000-c000-000000000046" ); // Expands to a static constexpr iid() function which returns this GUID
29 |
30 | virtual HRESULT COMLIGHTCALL QueryInterface( REFIID riid, void **ppvObject ) = 0;
31 | // NOTE(review): the order of these three methods presumably has to stay as is, to keep the binary compatibility with the Windows SDK version of the interface
32 | virtual uint32_t COMLIGHTCALL AddRef() = 0;
33 |
34 | virtual uint32_t COMLIGHTCALL Release() = 0;
35 | };
36 | }
--------------------------------------------------------------------------------
/ComputeShaders/ComputeShaders.cpp:
--------------------------------------------------------------------------------
1 | void fnComputeShaders( void )
2 | {   // Empty body: the function does nothing
3 | }
--------------------------------------------------------------------------------
/ComputeShaders/Readme.txt:
--------------------------------------------------------------------------------
1 | This project compiles all the compute shaders which implement the model.
2 |
3 | Many shaders come in 2 versions, something.hlsl and something64.hlsl
4 |
5 | The version with the `64` suffix is used on AMD GPUs, the version without suffix is used on nVidia and Intel GPUs.
6 |
7 | Not all of these shaders are actually used for anything.
8 | Some of them are implementing binary compatibility for the reference CPU version, and not used unless messing with the `constexpr` flags in MlContext C++ class.
9 | Such shaders often require FP64 support, which is an optional feature in D3D11.
10 | CompressShaders tool detects such shaders by looking at the SFI0 chunk in the binary, and outputs a bitmap of the FP64 shaders.
11 | This way, missing FP64 hardware support shouldn't break the library.
--------------------------------------------------------------------------------
/ComputeShaders/add.hlsl:
--------------------------------------------------------------------------------
1 | inline float compute( float a, float b ) // The binary operation of this shader: the sum of the two arguments
2 | {
3 |     const float sum = a + b; return sum;
4 | }
5 |
6 | #include "componentwiseBinaryOp.hlsli"
--------------------------------------------------------------------------------
/ComputeShaders/addInPlace.hlsl:
--------------------------------------------------------------------------------
1 | #ifndef THREADS
2 | #define THREADS 512
3 | #endif
4 |
5 | Buffer arg0: register( t0 );
6 | RWBuffer result: register( u0 );
7 |
8 | cbuffer Constants: register( b0 )
9 | {
10 | uint4 size: packoffset( c0 );
11 | uint4 strides: packoffset( c1 );
12 | uint4 argStrides: packoffset( c3 );
13 | }
14 |
15 | inline uint rowOffset( uint3 idx, uint4 strides ) // Offset of a row: the 3 index components are scaled by strides.yzw, while strides.x is not used here
16 | {
17 |     return idx.x * strides.y + idx.y * strides.z + idx.z * strides.w;
18 | }
19 |
20 | [ numthreads( THREADS, 1, 1 ) ]
21 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
22 | {
23 |     // The group index selects the row; threads of the group walk over the row, THREADS elements apart
24 |     const uint dstStep = strides.x;
25 |     const uint srcStep = argStrides.x;
26 |     const uint dstBegin = rowOffset( group, strides );
27 |     const uint dstEnd = dstBegin + size.x * dstStep;
28 |
29 |     uint dst = dstBegin + thread * dstStep;
30 |     uint src = rowOffset( group, argStrides ) + thread * srcStep;
31 |     const uint dstInc = THREADS * dstStep;
32 |     const uint srcInc = THREADS * srcStep;
33 |     while( dst < dstEnd )
34 |     {
35 |         result[ dst ] = result[ dst ] + arg0[ src ]; // In-place update: the argument is added to the element already in the output
36 |         dst += dstInc;
37 |         src += srcInc;
38 |     }
39 | }
--------------------------------------------------------------------------------
/ComputeShaders/addRepeat64.hlsl:
--------------------------------------------------------------------------------
1 | #define THREADS 64
2 | #include "addRepeat.hlsl"
--------------------------------------------------------------------------------
/ComputeShaders/addRepeatGelu64.hlsl:
--------------------------------------------------------------------------------
1 | #define THREADS 64
2 | #include "addRepeatGelu.hlsl"
--------------------------------------------------------------------------------
/ComputeShaders/addRows.hlsl:
--------------------------------------------------------------------------------
1 | #ifndef THREADS
2 | #define THREADS 256
3 | #endif
4 |
5 | // dec.tokenEmbedding tensor
6 | Buffer tokenEmbedding: register( t0 );
7 | // dec.positionalEmbedding tensor
8 | Buffer positionalEmbedding: register( t1 );
9 | // R32_UINT buffer with the input tokens
10 | Buffer embd: register( t2 );
11 | // Output tensor
12 | RWBuffer result: register( u0 );
13 |
14 | cbuffer Constants: register( b0 )
15 | {
16 | uint rowLength: packoffset( c0.x );
17 | uint pastTokensCount: packoffset( c0.y );
18 | uint outputRowStride: packoffset( c0.z );
19 | uint2 embStrides: packoffset( c1.x );
20 | uint2 posStrides: packoffset( c1.z );
21 | }
22 |
23 | [ numthreads( THREADS, 1, 1 ) ]
24 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
25 | {
26 |     const uint outRow = group.x; // X of the group index selects the output row
27 |     // The input token selects the row of the token embedding; the positional row is offset by the count of past tokens
28 |     const uint token = embd[ outRow ];
29 |     const uint position = outRow + pastTokensCount;
30 |
31 |     const uint dstBegin = outRow * outputRowStride;
32 |     const uint dstEnd = dstBegin + rowLength;
33 |
34 |     uint srcTok = token * embStrides.y + thread * embStrides.x;
35 |     uint srcPos = position * posStrides.y + thread * posStrides.x;
36 |     const uint tokInc = THREADS * embStrides.x;
37 |     const uint posInc = THREADS * posStrides.x;
38 |
39 |     // Threads of the group walk over the row, THREADS elements apart; the output row has unit element stride
40 |     for( uint dst = dstBegin + thread; dst < dstEnd; dst += THREADS )
41 |     {
42 |         result[ dst ] = tokenEmbedding[ srcTok ] + positionalEmbedding[ srcPos ];
43 |         srcTok += tokInc;
44 |         srcPos += posInc;
45 |     }
46 | }
--------------------------------------------------------------------------------
/ComputeShaders/componentwiseBinaryOp.hlsli:
--------------------------------------------------------------------------------
1 | Buffer arg0: register( t0 );
2 | Buffer arg1: register( t1 );
3 | RWBuffer result: register( u0 );
4 |
5 | cbuffer Constants: register( b0 )
6 | {
7 | uint4 src0_elements: packoffset( c0 );
8 | uint4 src0_strides: packoffset( c1 );
9 | uint4 src1_elements: packoffset( c2 );
10 | uint4 src1_strides: packoffset( c3 );
11 | uint4 result_elements: packoffset( c4 );
12 | uint4 result_strides: packoffset( c5 );
13 | }
14 |
15 | [ numthreads( 32, 1, 1 ) ]
16 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
17 | {
18 |     // Count of threads in the group, equal to the value in the numthreads attribute above
19 |     const uint groupSize = 32;
20 |
21 |     // X of the group index selects the row
22 |     const uint row = group.x;
23 |     const uint rowLength = src0_elements.x;
24 |     const uint stride1 = src1_strides.x;
25 |
26 |     // Start of the row in the first argument, and the position one past its last element
27 |     const uint src0Begin = row * src0_strides.y;
28 |     const uint src0End = src0Begin + rowLength;
29 |
30 |     // The first argument and the result are addressed with unit element stride, the second argument with its own element stride
31 |     uint src0 = src0Begin + thread;
32 |     uint src1 = row * src1_strides.y + thread * stride1;
33 |     uint dst = row * result_strides.y + thread;
34 |     const uint src1Inc = groupSize * stride1;
35 |
36 |     while( src0 < src0End )
37 |     {
38 |         result[ dst ] = compute( arg0[ src0 ], arg1[ src1 ] );
39 |         src0 += groupSize;
40 |         src1 += src1Inc;
41 |         dst += groupSize;
42 |     }
43 | }
--------------------------------------------------------------------------------
/ComputeShaders/convolutionPrep1.hlsl:
--------------------------------------------------------------------------------
1 | // ggml_compute_forward_conv_1d_1s_f16_f32, prepare kernel data (src0)
2 | // Dispatch [ ne01, ne02, 1 ] thread groups
3 | Buffer arg0: register( t0 );
4 | RWBuffer result: register( u0 );
5 |
6 | cbuffer Constants: register( b0 )
7 | {
8 | uint4 src0_elements: packoffset( c0 );
9 | uint4 src0_strides: packoffset( c1 );
10 | }
11 |
12 | inline uint roundUp32( uint x ) // Rounds the argument up to the next multiple of 32; multiples of 32 are returned unchanged
13 | {
14 |     return ( ( x + 31u ) >> 5 ) << 5;
15 | }
16 |
17 | [ numthreads( 32, 1, 1 ) ]
18 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
19 | {
20 |     // Length of a source row, and the second dimension of the source rounded up to a multiple of 32
21 |     const uint rowLength = src0_elements.x;
22 |     const uint paddedRows = roundUp32( src0_elements.y );
23 |
24 |     // X of the group index selects the row, Y selects the slice
25 |     const uint row = group.x;
26 |     const uint slice = group.y;
27 |
28 |     // Elements of the source row are read with unit stride
29 |     const uint srcBegin = slice * src0_strides.z + row * src0_strides.y;
30 |     const uint srcEnd = srcBegin + rowLength;
31 |
32 |     // Consecutive source elements are written paddedRows elements apart
33 |     uint dst = slice * paddedRows * rowLength + row;
34 |     dst += thread * paddedRows;
35 |     const uint dstInc = 32 * paddedRows;
36 |
37 |     for( uint src = srcBegin + thread; src < srcEnd; src += 32, dst += dstInc )
38 |         result[ dst ] = arg0[ src ];
39 | }
--------------------------------------------------------------------------------
/ComputeShaders/convolutionPrep2.hlsl:
--------------------------------------------------------------------------------
1 | // ggml_compute_forward_conv_1d_1s_f16_f32, prepare source data (src1)
2 | // Dispatch [ ne11, 1, 1 ] thread groups
3 | Buffer arg1: register( t0 );
4 | RWBuffer result: register( u0 );
5 |
6 | cbuffer Constants: register( b0 )
7 | {
8 | uint4 src0_elements: packoffset( c0 );
9 | uint4 src1_elements: packoffset( c2 );
10 | uint4 src1_strides: packoffset( c3 );
11 | }
12 |
13 | #include "miscUtils.hlsli"
14 | // Each thread group copies one row i11 of src1; consecutive source elements land ew0 elements apart in the destination, after nh leading rows
15 | [ numthreads( 32, 1, 1 ) ]
16 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
17 | {
18 | const uint i11 = group.x;
19 |
20 | const uint ne00 = src0_elements[ 0 ];
21 | const uint ne01 = src0_elements[ 1 ];
22 | const uint ne10 = src1_elements[ 0 ];
23 | const uint nb11 = src1_strides[ 1 ];
24 |
25 | const uint nk = ne00;
26 | const uint nh = nk / 2;
27 | const uint ew0 = roundUp32( ne01 ); // Unsigned like every other index here; roundUp32() presumably comes from miscUtils.hlsli - confirm
28 |
29 | uint rsi = i11 * nb11; // Start of the source row
30 | uint rdi = nh * ew0 + i11; // The first nh destination rows are skipped by this shader
31 | const uint rdiInc = ew0 * 32;
32 | const uint rsiEnd = rsi + ne10; // End of the source row
33 |
34 | rsi += thread;
35 | rdi += thread * ew0;
36 | // The 32 threads interleave across the source row
37 | for( ; rsi < rsiEnd; rsi += 32, rdi += rdiInc )
38 | {
39 | float f = arg1[ rsi ];
40 | f = adjustFp16( f ); // Not defined in this file; presumably rounds the value to FP16 precision - confirm in miscUtils.hlsli
41 | result[ rdi ] = f;
42 | }
43 | }
--------------------------------------------------------------------------------
/ComputeShaders/copyConvert.hlsl:
--------------------------------------------------------------------------------
1 | // ggml_compute_forward_dup_f32 when we only need to convert types, but not reshape the tensor
2 | // Dispatch [ ne01, ne02, ne03 ] thread groups of this shader
3 | Buffer arg0: register( t0 );
4 | RWBuffer result: register( u0 );
5 |
6 | cbuffer Constants: register( b0 )
7 | {
8 | uint4 src0_elements: packoffset( c0 );
9 | uint4 src0_strides: packoffset( c1 );
10 | bool downcastFp32 : packoffset( c2.x );
11 | }
12 |
13 | #include "miscUtils.hlsli"
14 | // Each thread group copies one source row ( i01, i02, i03 ) into a densely packed destination
15 | [ numthreads( 32, 1, 1 ) ]
16 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
17 | {
18 | const uint nb00 = src0_strides[ 0 ];
19 | const uint nb01 = src0_strides[ 1 ];
20 | const uint nb02 = src0_strides[ 2 ];
21 | const uint nb03 = src0_strides[ 3 ];
22 |
23 | const uint ne00 = src0_elements[ 0 ];
24 | const uint ne01 = src0_elements[ 1 ];
25 | const uint ne02 = src0_elements[ 2 ];
26 | const uint ne03 = src0_elements[ 3 ];
27 |
28 | const uint i01 = group.x;
29 | const uint i02 = group.y;
30 | const uint i03 = group.z;
31 |
32 | const uint rs = ne00 * nb00; // Length of one row
33 | // Sequential index of the row in the dense destination, same as a counter in nested loops over i03, i02, i01 with i01 innermost
34 | const uint id = ( i03 * ne02 + i02 ) * ne01 + i01;
35 |
36 | uint rsi = i01 * nb01 + i02 * nb02 + i03 * nb03;
37 | uint rdi = id * rs;
38 |
39 | const uint rsiEnd = rsi + rs;
40 | rsi += thread;
41 | rdi += thread;
42 | for( ; rsi < rsiEnd; rsi += 32, rdi += 32 )
43 | {
44 | float f = arg0[ rsi ];
45 | [branch]
46 | if( downcastFp32 )
47 | f = adjustFp16( f ); // Not defined in this file; presumably rounds the value to FP16 precision - confirm in miscUtils.hlsli
48 | result[ rdi ] = f;
49 | }
50 | }
--------------------------------------------------------------------------------
/ComputeShaders/diagMaskInf.hlsl:
--------------------------------------------------------------------------------
1 | // ggml_compute_forward_diag_mask_inf_f32
2 | RWBuffer result: register( u0 );
3 | // In row j, elements at positions above ( n_past + j ) are overwritten with negative infinity; the rest is left intact
4 | cbuffer Constants: register( b0 )
5 | {
6 | uint4 elements: packoffset( c0 );
7 | uint4 strides: packoffset( c1 );
8 | uint n_past : packoffset( c2.x );
9 | }
10 |
11 | static const float negativeInfinity = asfloat( 0xff800000 ); // Bit pattern of FP32 negative infinity
12 | // One thread group per row: group.x is the row index, group.y the index along the next dimension
13 | [numthreads( 32, 1, 1 )]
14 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
15 | {
16 | const uint k = group.y;
17 | const uint j = group.x;
18 |
19 | // Start of the row
20 | uint rdi = k * strides[ 2 ] + j * strides[ 1 ];
21 | // End of the row
22 | const uint rdiEnd = rdi + elements[ 0 ] * strides[ 0 ];
23 | // First index to write in this thread
24 | rdi += ( n_past + j + thread + 1 ) * strides[ 0 ];
25 | // Index increment
26 | const uint rdiInc = 32 * strides[ 0 ];
27 | // Threads whose first index is already past the end of the row write nothing
28 | for( ; rdi < rdiEnd; rdi += rdiInc )
29 | result[ rdi ] = negativeInfinity;
30 | }
--------------------------------------------------------------------------------
/ComputeShaders/fmaRepeat164.hlsl:
--------------------------------------------------------------------------------
1 | #define THREADS 64
2 | #include "fmaRepeat1.hlsl"
--------------------------------------------------------------------------------
/ComputeShaders/fmaRepeat2.hlsl:
--------------------------------------------------------------------------------
1 | // Implementation of fmaRepeat() when source arguments have different shape or VRAM layout
2 | // Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor
3 | RWBuffer tensor: register( u0 ); // Updated in place
4 | Buffer patternMul: register( t0 ); // Multipliers
5 | Buffer patternAdd: register( t1 ); // Values added after the multiplication
6 |
7 | cbuffer Constants: register( b0 )
8 | {
9 | uint4 tensorSize: packoffset( c0 );
10 | uint4 tensorStrides: packoffset( c1 );
11 | uint4 patternSizeMul: packoffset( c2 );
12 | uint4 patternStridesMul: packoffset( c3 );
13 | uint4 patternSizeAdd: packoffset( c4 );
14 | uint4 patternStridesAdd: packoffset( c5 );
15 | }
16 | // Count of threads per group; a wrapper file can define another value before including this shader
17 | #ifndef THREADS
18 | #define THREADS 32
19 | #endif
20 |
21 | #include "repeatUtils.hlsli"
22 | // Load element #i of a pattern row which starts at rowStart; the index wraps around the first dimension of the pattern
23 | inline float loadPattern( Buffer buffer, uint rowStart, uint i, uint4 size, uint4 stride )
24 | {
25 | const uint wrapped = i % size.x;
26 | return buffer[ rowStart + wrapped * stride.x ];
27 | }
28 | // Compute tensor = tensor * patternMul + patternAdd in place; both patterns are repeated ( indices taken modulo the pattern size ) to cover the tensor
29 | [ numthreads( THREADS, 1, 1 ) ]
30 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
31 | {
32 | uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides );
33 | const uint rsiMul = rowOffset( group % patternSizeMul.yzw, patternStridesMul );
34 | const uint rsiAdd = rowOffset( group % patternSizeAdd.yzw, patternStridesAdd );
35 | // `i` is the element index within the row; it advances by THREADS, in lockstep with it.x which advances by THREADS elements per iteration
36 | for( uint i = thread; it.x < it.z; it.x += it.y, i += THREADS )
37 | {
38 | precise float f = tensor[ it.x ];
39 | float mul = loadPattern( patternMul, rsiMul, i, patternSizeMul, patternStridesMul );
40 | float add = loadPattern( patternAdd, rsiAdd, i, patternSizeAdd, patternStridesAdd );
41 | f *= mul;
42 | f += add;
43 | tensor[ it.x ] = f;
44 | }
45 | }
--------------------------------------------------------------------------------
/ComputeShaders/fp64Utils.hlsli:
--------------------------------------------------------------------------------
1 | // TODO: compile another version of these shaders, and use it on GPUs with the ExtendedDoublesShaderInstructions flag; it will be slightly faster
2 | // https://learn.microsoft.com/en-us/windows/win32/api/d3d11/ns-d3d11-d3d11_feature_data_d3d11_options
3 | #ifndef ExtendedDoublesShaderInstructions
4 | #define ExtendedDoublesShaderInstructions 0
5 | #endif
6 |
7 | // Compute num/den in FP64 precision
8 | inline double div64( double num, double den )
9 | {
10 | #if ExtendedDoublesShaderInstructions
11 | return num / den;
12 | #else
13 | // https://en.wikipedia.org/wiki/Division_algorithm#Newton%E2%80%93Raphson_division
14 | double x = 1.0f / (float)den; // Initial estimate of 1/den, computed in FP32
15 | x += x * ( 1.0 - den * x ); // Newton-Raphson refinement step, in FP64
16 | x += x * ( 1.0 - den * x ); // Second refinement step
17 | return num * x;
18 | #endif
19 | }
20 |
21 | // Compute sqrt(x) in FP64 precision
22 | // Starts from the FP32 square root, then refines it with two iterations of root = ( root + x / root ) / 2
23 | inline double sqrt64( double x )
24 | {
25 | double root = sqrt( (float)x );
26 | // When the FP32 estimate is zero, x / root is a division by zero which would turn the result into NaN; return the zero instead
27 | if( root == 0.0 )
28 | return root;
29 | root = 0.5 * ( root + div64( x, root ) );
30 | root = 0.5 * ( root + div64( x, root ) );
31 | return root;
32 | }
--------------------------------------------------------------------------------
/ComputeShaders/groupReduce64.hlsli:
--------------------------------------------------------------------------------
1 | groupshared float sharedAccumulators[ 64 ]; // Scratch buffer in group shared memory, one slot per thread of a 64-thread group
2 |
3 | // Compute horizontal sum of the numbers. The result is only correct on the thread #0 of the group.
4 | void horizontalSum( const uint thread, inout float sum )
5 | {
6 | sharedAccumulators[ thread ] = sum;
7 | for( uint i = 32; i > 1; i /= 2 ) // Each step folds the upper half of the active slots into the lower half
8 | {
9 | GroupMemoryBarrierWithGroupSync();
10 | if( thread < i )
11 | {
12 | sum += sharedAccumulators[ thread + i ];
13 | sharedAccumulators[ thread ] = sum;
14 | }
15 | }
16 | GroupMemoryBarrierWithGroupSync();
17 | if( 0 == thread ) // The last step, slot #1 into slot #0, only runs on thread #0 and skips the store
18 | sum += sharedAccumulators[ 1 ];
19 | }
20 |
21 | // Compute horizontal sum of the numbers, and broadcast to all threads of the group.
22 | void horizontalSumBroadcast( const uint thread, inout float sum )
23 | {
24 | horizontalSum( thread, sum );
25 | if( 0 == thread )
26 | sharedAccumulators[ 0 ] = sum; // Thread #0 has the total, publish it through the shared memory
27 | GroupMemoryBarrierWithGroupSync();
28 | sum = sharedAccumulators[ 0 ];
29 | }
30 |
31 | // Compute horizontal maximum of the numbers, and broadcast to all threads of the group.
32 | void horizontalMaxBroadcast( const uint thread, inout float ax )
33 | {
34 | sharedAccumulators[ thread ] = ax;
35 | for( uint i = 32; i > 0; i /= 2 ) // Each step folds the upper half of the active slots into the lower half; runs all the way down to slot #0
36 | {
37 | GroupMemoryBarrierWithGroupSync();
38 | if( thread < i )
39 | {
40 | ax = max( ax, sharedAccumulators[ thread + i ] );
41 | sharedAccumulators[ thread ] = ax;
42 | }
43 | }
44 | GroupMemoryBarrierWithGroupSync();
45 | ax = sharedAccumulators[ 0 ]; // Slot #0 has the maximum of the complete group
46 | }
--------------------------------------------------------------------------------
/ComputeShaders/mulMatByRow.hlsl:
--------------------------------------------------------------------------------
1 | // Matrix * row product, like [ E0, E1, E2, E3 ] * [ E0, 1, E2, E3 ] = [ E1, 1, E2, E3 ]
2 | // Dispatch [ E1, E2, E3 ] groups of this shader
3 | Buffer arg0: register( t0 ); // The matrix
4 | Buffer arg1: register( t1 ); // The row
5 | RWBuffer result: register( u0 );
6 |
7 | cbuffer Constants: register( b0 )
8 | {
9 | uint4 arg0Size: packoffset( c0 );
10 | uint4 arg0Strides: packoffset( c1 );
11 | uint4 arg1Size: packoffset( c2 ); // Not read by main()
12 | uint4 arg1Strides: packoffset( c3 );
13 | uint4 resultSize: packoffset( c4 ); // Not read by main()
14 | uint4 resultStrides: packoffset( c5 );
15 | }
16 |
17 | #include "groupReduce.hlsli"
18 | // Horizontal add: sum of all lanes of an integer vector
19 | inline uint hadd( uint3 vec )
20 | {
21 | return vec[ 0 ] + vec[ 1 ] + vec[ 2 ];
22 | }
23 | inline uint hadd( uint2 vec )
24 | {
25 | return vec[ 0 ] + vec[ 1 ];
26 | }
27 | // Each thread group computes one element of the result: the dot product of one row of arg0 with the matching row of arg1
28 | [ numthreads( 32, 1, 1 ) ]
29 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
30 | {
31 | uint s0 = hadd( group * arg0Strides.yzw ); // Start of the arg0 row selected by all 3 group coordinates
32 | uint s1 = hadd( group.yz * arg1Strides.zw ); // Start of the arg1 row; group.x is not used because arg1 has a single row per [ E2, E3 ] slice
33 | const uint s0End = s0 + arg0Size.x * arg0Strides.x;
34 | const uint s0Inc = 32 * arg0Strides.x;
35 | const uint s1Inc = 32 * arg1Strides.x;
36 |
37 | s0 += thread * arg0Strides.x;
38 | s1 += thread * arg1Strides.x;
39 | float dp = 0; // Partial dot product accumulated by this thread
40 | for( ; s0 < s0End; s0 += s0Inc, s1 += s1Inc )
41 | dp = mad( arg0[ s0 ], arg1[ s1 ], dp );
42 | // Combine the 32 partial sums; horizontalSum() comes from groupReduce.hlsli
43 | horizontalSum( thread, dp );
44 | if( 0 != thread )
45 | return;
46 | // Only thread #0 stores the result
47 | const uint rdi = group.x + hadd( group.yz * resultStrides.zw );
48 | result[ rdi ] = dp;
49 | }
--------------------------------------------------------------------------------
/ComputeShaders/mulMatByScalar.hlsl:
--------------------------------------------------------------------------------
1 | // Matrix * scalar product, like [ 1, E1, E2, E3 ] * [ 1, 1, E2, E3 ] = [ E1, 1, E2, E3 ]
2 | // Dispatch [ E2, E3, 1 ] thread groups of this shader
3 | Buffer arg0: register( t0 ); // The matrix
4 | Buffer arg1: register( t1 ); // The scalars, one per thread group
5 | RWBuffer result: register( u0 );
6 |
7 | cbuffer Constants: register( b0 )
8 | {
9 | uint4 arg0Size: packoffset( c0 );
10 | uint4 arg0Strides: packoffset( c1 );
11 | uint4 arg1Size: packoffset( c2 ); // Not read by main()
12 | uint4 arg1Strides: packoffset( c3 );
13 | uint4 resultSize: packoffset( c4 ); // Not read by main()
14 | uint4 resultStrides: packoffset( c5 );
15 | }
16 | // Horizontal add: sum of both lanes of an integer vector
17 | inline uint hadd( uint2 vec )
18 | {
19 | return vec[ 0 ] + vec[ 1 ];
20 | }
21 | // Each thread group scales arg0Size.y elements of arg0 by a single value loaded from arg1
22 | [ numthreads( 32, 1, 1 ) ]
23 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
24 | {
25 | const float scalarValue = arg1[ hadd( group.xy * arg1Strides.zw ) ];
26 |
27 | uint s0 = hadd( group.xy * arg0Strides.zw );
28 | const uint s0Inc = 32 * arg0Strides.y; // Source elements are arg0Strides.y apart
29 | s0 += thread * arg0Strides.y;
30 |
31 | uint rdi = hadd( group.xy * resultStrides.zw );
32 | const uint rdiEnd = rdi + arg0Size.y; // Destination elements are written next to each other
33 | rdi += thread;
34 |
35 | for( ; rdi < rdiEnd; rdi += 32, s0 += s0Inc )
36 | {
37 | float f = arg0[ s0 ];
38 | f *= scalarValue;
39 | result[ rdi ] = f;
40 | }
41 | }
--------------------------------------------------------------------------------
/ComputeShaders/mulMatDotReshape.hlsl:
--------------------------------------------------------------------------------
1 | // GGML_TASK_INIT step for matrix*matrix product, where nb01 >= nb00;
2 | // Dispatch with [ ne11, ne12 ] groups
3 | Buffer arg0: register( t0 );
4 | RWBuffer result: register( u0 );
5 |
6 | cbuffer Constants: register( b0 )
7 | {
8 | uint4 src0_elements: packoffset( c0 );
9 | uint4 src0_strides: packoffset( c1 );
10 | }
11 |
12 | #include "miscUtils.hlsli"
13 |
14 | // Each thread group of this shader copies a single row of the matrix
15 | [ numthreads( 32, 1, 1 ) ]
16 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
17 | {
18 | const uint i12 = group.y;
19 | const uint i11 = group.x;
20 | const uint ne10 = src0_elements.x;
21 | const uint ne11 = src0_elements.y;
22 | const uint nb12 = src0_strides.z;
23 | const uint nb11 = src0_strides.y;
24 |
25 | uint rdi = i11 * ne10 + i12 * ne10 * ne11; // The destination is dense: rows of ne10 elements, ne11 rows per slice
26 | const uint rdiEnd = rdi + ne10;
27 | uint rsi = i12 * nb12 + i11 * nb11; // The source row is located with the strides from the constant buffer
28 | rdi += thread;
29 | rsi += thread;
30 |
31 | for( ; rdi < rdiEnd; rdi += 32, rsi += 32 )
32 | result[ rdi ] = adjustFp16( arg0[ rsi ] ); // adjustFp16() is not defined in this file; presumably rounds to FP16 precision - confirm in miscUtils.hlsli
33 | }
--------------------------------------------------------------------------------
/ComputeShaders/normFixed64.hlsl:
--------------------------------------------------------------------------------
1 | #define THREADS 64
2 | #include "normFixed.hlsl"
--------------------------------------------------------------------------------
/ComputeShaders/repeatUtils.hlsli:
--------------------------------------------------------------------------------
1 | inline uint rowOffset( uint3 idx, uint4 strides ) // Offset of the row: the 3 outer indices multiplied by the last 3 strides
2 | {
3 | return idx.x * strides.y + idx.y * strides.z + idx.z * strides.w;
4 | }
5 |
6 | // Initial iterator state for a row of the output tensor
7 | // x = current index, y = index increment, z = end of the index
8 | inline uint3 tensorIteratorState( uint3 group, uint thread, uint4 size, uint4 stride )
9 | {
10 | uint3 res;
11 | res.x = rowOffset( group, stride ); // Start of the row
12 | res.y = THREADS * stride[ 0 ]; // THREADS must be defined before this file is included
13 | res.z = res.x + size[ 0 ] * stride[ 0 ]; // End of the row, computed before the per-thread offset is applied
14 | res.x += thread * stride[ 0 ]; // First element handled by this thread
15 | return res;
16 | }
17 |
18 | // Handle a complete row of output tensor, using the iterator made by tensorIteratorState() function
19 | #define ROW_LOOP( ts ) for( ; ( ts ).x < ( ts ).z; ( ts ).x += ( ts ).y )
20 | // Same as above, but the index advances by `len` elements of `stride[ 0 ]`; arguments are parenthesized so that expressions can be passed safely
21 | #define ROW_LOOP_EX( ts, len, stride ) for( ; ( ts ).x < ( ts ).z; ( ts ).x += ( len ) * ( stride )[ 0 ] )
--------------------------------------------------------------------------------
/ComputeShaders/scaleInPlace.hlsl:
--------------------------------------------------------------------------------
1 | RWBuffer buffer: register( u0 ); // The tensor, updated in place
2 |
3 | cbuffer Constants: register( b0 )
4 | {
5 | uint4 src0_elements: packoffset( c0 ); // main() reads the first component, the row length
6 | uint4 src0_strides: packoffset( c1 ); // main() reads the second component, the distance between rows
7 | float multiplier: packoffset( c2.x ); // Every element is multiplied by this value
8 | }
9 | // Multiply all elements of the row group.x by the multiplier from the constant buffer, in place
10 | [ numthreads( 32, 1, 1 ) ]
11 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
12 | {
13 | const uint rowStart = group.x * src0_strides.y;
14 | const uint rowEnd = rowStart + src0_elements.x;
15 | // The 32 threads of the group interleave across the row
16 | uint idx = rowStart + thread;
17 | while( idx < rowEnd )
18 | {
19 | const float scaled = buffer[ idx ] * multiplier;
20 | buffer[ idx ] = scaled;
21 | idx += 32;
22 | }
23 | }
--------------------------------------------------------------------------------
/ComputeShaders/softMax64.hlsl:
--------------------------------------------------------------------------------
1 | #define THREADS 64
2 | #include "softMax.hlsl"
--------------------------------------------------------------------------------
/ComputeShaders/softMaxCompat.hlsl:
--------------------------------------------------------------------------------
1 | // ggml_compute_forward_soft_max_f32
2 | // Dispatch [ ( nr + 31 ) / 32, 1, 1 ] thread groups of this shader
3 | RWBuffer result: register( u0 );
4 |
5 | // table_exp_f16
6 | Buffer lookupTable: register( t0 );
7 |
8 | cbuffer Constants: register( b0 )
9 | {
10 | uint4 elements: packoffset( c0 );
11 | uint4 strides: packoffset( c1 );
12 | uint nr: packoffset( c2.x ); // Count of rows to process
13 | }
14 |
15 | #include "miscUtils.hlsli"
16 | #include "fp64Utils.hlsli"
17 |
18 | static const float negativeInfinity = asfloat( 0xff800000 ); // Bit pattern of FP32 negative infinity
19 | // Unlike the other shaders, here every thread handles a complete row on its own, sequentially
20 | [ numthreads( 32, 1, 1 ) ]
21 | void main( uint3 dtid: SV_DispatchThreadID )
22 | {
23 | if( dtid.x >= nr ) // The dispatch is rounded up to groups of 32 threads, the extra threads have nothing to do
24 | return;
25 |
26 | const uint p = dtid.x * strides[ 1 ];
27 | const uint nc = elements[ 0 ];
28 | const uint pEnd = p + nc;
29 | uint i;
30 | // Pass 1: maximum of the row
31 | float m = negativeInfinity;
32 | for( i = p; i < pEnd; i++ )
33 | m = max( m, result[ i ] );
34 | // Pass 2: replace every element with the lookup table value for ( f - m ), and accumulate the sum in FP64
35 | double sum = 0;
36 | for( i = p; i < pEnd; i++ )
37 | {
38 | float f = result[ i ];
39 |
40 | [branch]
41 | if( f != negativeInfinity )
42 | {
43 | uint s = fp16Rounded( f - m ); // Not defined in this file; presumably the FP16 bit pattern of the difference, used as the table index - confirm in miscUtils.hlsli
44 | s = lookupTable[ s ];
45 | f = f16tof32( s );
46 | sum += f;
47 | }
48 | else
49 | f = 0; // Negative infinity inputs become zeros without touching the table
50 |
51 | result[ i ] = f;
52 | }
53 | // Pass 3: divide the row by the sum
54 | const float scale = (float)div64( 1.0, sum );
55 | // ggml_vec_scale_f32
56 | for( i = p; i < pEnd; i++ )
57 | {
58 | float f = result[ i ];
59 | f *= scale;
60 | result[ i ] = f;
61 | }
62 | }
--------------------------------------------------------------------------------
/ComputeShaders/softMaxLong.hlsl:
--------------------------------------------------------------------------------
1 | // This version is for the "dec.probs" shader tag
2 | // The input tensor has a size [ 51865, 3 ], a very long tensor with just 3 rows.
3 | // Even though the shader only runs on 3 GPU cores, a large count of threads helps substantially: this shader is about 50% faster.
4 | #define THREADS 1024
5 |
6 | #include "softMax.hlsl"
--------------------------------------------------------------------------------
/ComputeShaders/zeroMemory.hlsl:
--------------------------------------------------------------------------------
1 | RWBuffer result: register( u0 );
2 |
3 | cbuffer Constants: register( b0 )
4 | {
5 | uint elements: packoffset( c0.x ); // Count of elements to fill
6 | bool writeNan: packoffset( c0.y ); // When true, the buffer is filled with NaN values instead of zeros
7 | }
8 |
9 | // Thread group index is 16 bits per coordinate:
10 | // https://learn.microsoft.com/en-us/windows/win32/api/d3d11/nf-d3d11-id3d11devicecontext-dispatch
11 | // We want this shader to support buffers up to 2 GB.
12 | #ifndef THREADS
13 | static const uint THREADS = 512;
14 | #endif
15 | #ifndef ITERATIONS
16 | static const uint ITERATIONS = 128;
17 | #endif
18 |
19 | static const uint itemsPerGroup = THREADS * ITERATIONS; // Count of elements written by every thread group
20 | // Only group.x is used, so the shader expects a one-dimensional dispatch
21 | [numthreads( THREADS, 1, 1 )]
22 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
23 | {
24 | uint rdi = group.x * itemsPerGroup;
25 | const uint rdiEnd = min( rdi + itemsPerGroup, elements ); // The last group is clipped to the size of the buffer
26 | // https://www.h-schmidt.net/FloatConverter/IEEE754.html
27 | const float pattern = writeNan ? asfloat( 0x7FFFFFFFu ) : 0.0; // 0x7FFFFFFF has all exponent and mantissa bits set, which is a NaN
28 | for( rdi += thread; rdi < rdiEnd; rdi += THREADS )
29 | result[ rdi ] = pattern;
30 | }
--------------------------------------------------------------------------------
/Examples/MicrophoneCS/CaptureThread.cs:
--------------------------------------------------------------------------------
1 | using System.Runtime.ExceptionServices;
2 | using Whisper;
3 |
4 | namespace MicrophoneCS
5 | {
6 | sealed class CaptureThread: CaptureCallbacks // Runs Context.runCapture on a dedicated thread, and requests cancellation once a key is pressed
7 | {
8 | public CaptureThread( CommandLineArgs args, Context context, iAudioCapture source )
9 | {
10 | callbacks = new TranscribeCallbacks( args );
11 | this.context = context;
12 | this.source = source;
13 | // The capture starts right away, from the constructor
14 | thread = new Thread( threadMain ) { Name = "Capture Thread" };
15 | Console.WriteLine( "Press any key to quit" );
16 | thread.Start();
17 | }
18 | // Thread pool callback: blocks until a key is pressed, then raises the shouldQuit flag
19 | static void readKeyCallback( object? state )
20 | {
21 | CaptureThread ct = ( state as CaptureThread ) ?? throw new ApplicationException();
22 | Console.ReadKey();
23 | ct.shouldQuit = true;
24 | }
25 | // Wait for the capture thread to exit; an exception captured by threadMain is re-thrown on the calling thread
26 | public void join()
27 | {
28 | ThreadPool.QueueUserWorkItem( readKeyCallback, this );
29 | thread.Join();
30 | edi?.Throw();
31 | }
32 | // Written by readKeyCallback on a thread pool thread, read by shouldCancel
33 | volatile bool shouldQuit = false;
34 | // Presumably polled by the library while the capture is running - confirm against CaptureCallbacks
35 | protected override bool shouldCancel( Context sender ) =>
36 | shouldQuit;
37 |
38 | protected override void captureStatusChanged( Context sender, eCaptureStatus status )
39 | {
40 | Console.WriteLine( $"CaptureStatusChanged: {status}" );
41 | }
42 |
43 | readonly TranscribeCallbacks callbacks;
44 | readonly Thread thread;
45 | readonly Context context;
46 | readonly iAudioCapture source;
47 | ExceptionDispatchInfo? edi = null; // Exception captured on the capture thread, if any
48 | // Entry point of the capture thread; exceptions are stored for join() instead of escaping the thread
49 | void threadMain()
50 | {
51 | try
52 | {
53 | context.runCapture( source, callbacks, this );
54 | }
55 | catch( Exception ex )
56 | {
57 | edi = ExceptionDispatchInfo.Capture( ex );
58 | }
59 | }
60 | }
61 | }
--------------------------------------------------------------------------------
/Examples/MicrophoneCS/MicrophoneCS.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Exe
5 | net6.0-windows
6 | enable
7 | enable
8 | true
9 | false
10 | x64
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | PreserveNewest
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
--------------------------------------------------------------------------------
/Examples/MicrophoneCS/Readme.txt:
--------------------------------------------------------------------------------
1 | This example builds a .NET 6 console application which shows how to use the audio capture API of the .NET wrapper.
--------------------------------------------------------------------------------
/Examples/OldMain/OldMain.vcxproj.filters:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/Examples/OldMain/Readme.txt:
--------------------------------------------------------------------------------
1 | This project builds the original whisper.cpp command-line sample
--------------------------------------------------------------------------------
/Examples/OldMain/Utils/Logger.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include "Logger.h"
5 | // Shared implementation of the logging functions: prints "<level>: <formatted message>" and a newline to stderr
6 | namespace
7 | {
8 | void logMessage( const char* lvl, const char8_t* pszFormat, std::va_list va )
9 | {
10 | fprintf( stderr, "%s: ", lvl );
11 | vfprintf( stderr, (const char*)pszFormat, va ); // The char8_t format string is passed to printf as is
12 | fprintf( stderr, "\n" );
13 | }
14 | }
15 | // Body of a variadic logging function: forwards the arguments after `pszFormat` to logMessage(); only usable inside a function with a `pszFormat` parameter followed by `...`
16 | #define LOG_MESSAGE_IMPL( lvl ) \
17 | std::va_list args; \
18 | va_start( args, pszFormat ); \
19 | logMessage( lvl, pszFormat, args ); \
20 | va_end( args );
21 | // printf-style logging functions; they only differ in the level prefix printed before the message
22 | void logError( const char8_t* pszFormat, ... )
23 | {
24 | LOG_MESSAGE_IMPL( "Error" );
25 | }
26 |
27 | void logWarning( const char8_t* pszFormat, ... )
28 | {
29 | LOG_MESSAGE_IMPL( "Warning" );
30 | }
31 |
32 | void logInfo( const char8_t* pszFormat, ... )
33 | {
34 | LOG_MESSAGE_IMPL( "Info" );
35 | }
36 |
37 | void logDebug( const char8_t* pszFormat, ... )
38 | {
39 | LOG_MESSAGE_IMPL( "Debug" );
40 | }
--------------------------------------------------------------------------------
/Examples/OldMain/Utils/Logger.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #ifdef __cplusplus
4 | extern "C" {
5 | #endif
6 |
7 | struct ggml_tensor;
8 | // Logging functions with printf-style format strings, one per message level
9 | void logError( const char8_t* pszFormat, ... );
10 | void logWarning( const char8_t* pszFormat, ... );
11 | void logInfo( const char8_t* pszFormat, ... );
12 | void logDebug( const char8_t* pszFormat, ... );
13 |
14 | #ifdef __cplusplus
15 | }
16 | // Stubs of the tracing API: every constructor and function below has an empty body
17 | namespace Tracing
18 | {
19 | struct ItemName
20 | {
21 | ItemName( const char* str ) { }
22 | ItemName( const char* str, uint32_t a0 ) { }
23 | ItemName( const char* str, int a0 ) { }
24 | };
25 |
26 | inline void tensor( const ItemName& name, const ggml_tensor* tensor ) { }
27 | inline void delayTensor( const ItemName& name, const ggml_tensor* tensor ) { }
28 | inline void vector( const ItemName& name, const std::vector& vec ) { }
29 | inline void writeDelayedTensors() { }
30 | }
31 | #endif
--------------------------------------------------------------------------------
/Examples/TranscribeCS/Readme.txt:
--------------------------------------------------------------------------------
1 | This example builds a .NET 6 console application which shows how to transcribe or translate audio files with the .NET wrapper.
--------------------------------------------------------------------------------
/Examples/TranscribeCS/TranscribeCS.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 | Exe
4 | net6.0-windows
5 | enable
6 | enable
7 | true
8 | false
9 | x64
10 |
11 |
12 |
13 | PreserveNewest
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/AppState.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "Utils/DebugConsole.h"
3 | // State of the application which is passed by reference to the dialogs and controls; the methods are implemented in another file
4 | class AppState
5 | {
6 | bool coInit = false;
7 | CRegKey registryKey; // Presumably the storage behind the load / store methods below - confirm in the implementation
8 | CIcon appIcon;
9 | public:
10 |
11 | struct ModelSource
12 | {
13 | CString path;
14 | bool found = false;
15 | Whisper::eModelImplementation impl = (Whisper::eModelImplementation)0;
16 | uint64_t sizeInBytes = 0;
17 | };
18 | ModelSource source;
19 |
20 | DebugConsole console;
21 | CComPtr mediaFoundation;
22 | CComPtr model;
23 |
24 | ~AppState();
25 |
26 | // Setup the initial things
27 | HRESULT startup();
28 |
29 | HRESULT findModelSource();
30 |
31 | HRESULT saveModelSource();
32 |
33 | uint32_t languageRead();
34 | void languageWrite( uint32_t key );
35 | // Typed accessors for named values
36 | CString stringLoad( LPCTSTR name );
37 | void stringStore( LPCTSTR name, LPCTSTR value );
38 | uint32_t dwordLoad( LPCTSTR name, uint32_t fallback );
39 | void dwordStore( LPCTSTR name, uint32_t value );
40 | bool boolLoad( LPCTSTR name );
41 | void boolStore( LPCTSTR name, bool val );
42 |
43 | bool automaticallyLoadModel = true;
44 |
45 | void lastScreenSave( HRESULT code );
46 | HRESULT lastScreenLoad();
47 |
48 | void setupIcon( CWindow* wnd );
49 |
50 | uint32_t gpuFlagsLoad();
51 | void gpuFlagsStore( uint32_t flags );
52 | };
53 | // Screen identifiers; presumably the codes used with AppState::lastScreenSave() and lastScreenLoad() - TODO confirm
54 | constexpr HRESULT SCREEN_MODEL = 1;
55 | constexpr HRESULT SCREEN_TRANSCRIBE = 2;
56 | constexpr HRESULT SCREEN_CAPTURE = 3;
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/CaptureDlg.cpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/Examples/WhisperDesktop/CaptureDlg.cpp
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/CircleIndicator.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "Utils/miscUtils.h"
3 | #include "Utils/WTL/atlcrack.h"
4 |
5 | // This control renders a black circle, and in the active state, the circle is filled with a bright color.
6 | class CircleIndicator: public CWindowImpl
7 | {
8 | public:
9 | static ATL::CWndClassInfo& GetWndClassInfo();
10 | // Handles the paint and destroy messages of the window
11 | BEGIN_MSG_MAP( CircleIndicator )
12 | MSG_WM_PAINT( onPaint )
13 | MSG_WM_DESTROY( onDestroy )
14 | END_MSG_MAP()
15 |
16 | // Class registration
17 | static HRESULT registerClass();
18 |
19 | void setActive( bool nowActive );
20 | // Only stores the color; the method does not request a repaint
21 | void setActiveColor( uint32_t col )
22 | {
23 | activeColor = col;
24 | }
25 | CircleIndicator();
26 |
27 | private:
28 | bool isActive = false;
29 | uint32_t activeColor; // No default value here; presumably set by the constructor - confirm in the implementation
30 | int fontHeight = 0;
31 | CFont font;
32 | HRESULT createFont( int height );
33 |
34 | void onDestroy();
35 | void onPaint( CDCHandle dc );
36 | };
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/ModelAdvancedDlg.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "AppState.h"
3 | #include "Utils/WTL/atlddx.h"
4 | #include "Utils/miscUtils.h"
5 | // Dialog implemented on top of the IDD_MODEL_ADV dialog resource, with 3 combo boxes and OK / Cancel buttons
6 | class ModelAdvancedDlg :
7 | public CDialogImpl
8 | {
9 | CComboBox cbWave, cbReshapedMatMul, cbAdapter;
10 | AppState& appState;
11 |
12 | public:
13 | static constexpr UINT IDD = IDD_MODEL_ADV;
14 | // The dialog keeps a reference to the application state, so the state must outlive the dialog
15 | ModelAdvancedDlg( AppState& app ) : appState( app ) { }
16 |
17 | BEGIN_MSG_MAP( ModelAdvancedDlg )
18 | MESSAGE_HANDLER( WM_INITDIALOG, onInitDialog )
19 | ON_BUTTON_CLICK( IDOK, onOk )
20 | ON_BUTTON_CLICK( IDCANCEL, onCancel )
21 | END_MSG_MAP()
22 |
23 | bool show( HWND owner );
24 |
25 | private:
26 |
27 | LRESULT onInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled );
28 |
29 | void onOk();
30 | // Cancel closes the dialog with the IDCANCEL result
31 | void onCancel()
32 | {
33 | EndDialog( IDCANCEL );
34 | }
35 | };
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/Readme.txt:
--------------------------------------------------------------------------------
1 | This example shows how to consume the DLL from a C++ GUI application.
2 |
3 | The GUI is implemented with ATL and WTL libraries.
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/Utils/DebugConsole.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 | #include
5 | // Console window with the log messages of the application; the methods are implemented in another file
6 | class AppState;
7 | class DebugConsole
8 | {
9 | using eLogLevel = Whisper::eLogLevel;
10 | // A log message kept in the buffer
11 | struct Entry
12 | {
13 | eLogLevel level;
14 | CStringA message;
15 | HRESULT print( HANDLE hConsole, CString& tempString ) const;
16 | };
17 |
18 | CComAutoCriticalSection critSec;
19 | std::deque buffer;
20 | CString tempString;
21 | CHandle output; // A valid handle means the console is visible, see isVisible()
22 |
23 | inline void logSink( eLogLevel lvl, const char* message );
24 | static void __stdcall logSinkStatic( void* context, eLogLevel lvl, const char* message ); // Presumably a callback which forwards to logSink() of the instance passed in `context` - confirm
25 |
26 | static BOOL __stdcall consoleHandlerRoutine( DWORD dwCtrlType );
27 |
28 | static DebugConsole* pGlobalInstance;
29 | void windowClosed();
30 |
31 | std::unordered_set checkboxes;
32 |
33 | CStringA tempStringA;
34 | void log( eLogLevel lvl, const char* pszFormat, va_list args );
35 |
36 | public:
37 | HRESULT initialize( eLogLevel level = eLogLevel::Debug );
38 | ~DebugConsole();
39 |
40 | HRESULT show();
41 | HRESULT hide();
42 | bool isVisible() const { return output; }
43 | // Checkboxes registered here are unregistered by ~ConsoleCheckbox()
44 | void addCheckbox( CButton& cb );
45 | void removeCheckbox( CButton& cb );
46 |
47 | static void logMessage( eLogLevel lvl, const char* pszFormat, va_list args );
48 | };
49 | // Checkbox control associated with a debug console; unregisters itself from the console when destroyed
50 | class ConsoleCheckbox
51 | {
52 | CButton control;
53 | DebugConsole* console = nullptr; // Stays nullptr until set elsewhere, presumably by initialize() - confirm
54 |
55 | public:
56 | HRESULT initialize( HWND dialog, int idc, AppState& state );
57 | void click();
58 | ~ConsoleCheckbox()
59 | {
60 | if( nullptr != console )
61 | console->removeCheckbox( control );
62 | }
63 | void ensureChecked();
64 | };
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/Utils/LanguageDropdown.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../AppState.h"
3 |
4 | // Dropdown list which implements language selector control
5 | class LanguageDropdown
6 | {
7 | HWND m_hWnd = nullptr;
8 | std::vector keys;
9 | int getInitialSelection( AppState& state ) const;
10 |
11 | public:
12 | operator HWND() const
13 | {
14 | return m_hWnd;
15 | }
16 |
17 | // Query language list from the native library, populate the combo box
18 | // Then load the last saved language selection from registry, and preselect an item.
19 | void initialize( HWND owner, int idc, AppState& state );
20 |
21 | // Get the ID of the currently selected language, or UINT_MAX if nothing's selected
22 | uint32_t selectedLanguage();
23 |
24 | // Get the ID of the currently selected language, and store in registry
25 | void saveSelection( AppState& state );
26 | };
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/Utils/PendingState.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "PendingState.h"
3 |
4 | void PendingState::initialize( std::initializer_list editors, std::initializer_list pending )
5 | {
6 | editorsWindows = editors;
7 | wasEnabled.resize( editorsWindows.size() );
8 | pendingWindows = pending;
9 | }
10 |
11 | void PendingState::setPending( bool nowPending )
12 | {
13 | if( nowPending )
14 | {
15 | for( size_t i = 0; i < editorsWindows.size(); i++ )
16 | {
17 | BOOL e = IsWindowEnabled( editorsWindows[ i ] );
18 | if( e )
19 | {
20 | wasEnabled[ i ] = true;
21 | EnableWindow( editorsWindows[ i ], FALSE );
22 | }
23 | else
24 | wasEnabled[ i ] = false;
25 | }
26 | }
27 | else
28 | {
29 | for( size_t i = 0; i < editorsWindows.size(); i++ )
30 | {
31 | if( !wasEnabled[ i ] )
32 | continue;
33 | EnableWindow( editorsWindows[ i ], TRUE );
34 | }
35 | }
36 |
37 | const int show = nowPending ? SW_NORMAL : SW_HIDE;
38 | for( HWND w : pendingWindows )
39 | ::ShowWindow( w, show );
40 | }
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/Utils/PendingState.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | // Utility class to switch visual state of dialog controls between idle and pending
4 | class PendingState
5 | {
6 | std::vector editorsWindows;
7 | std::vector wasEnabled;
8 | std::vector pendingWindows;
9 | public:
10 | void initialize( std::initializer_list editors, std::initializer_list pending );
11 | void setPending( bool nowPending );
12 | };
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/Utils/TranslateCheckbox.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "TranslateCheckbox.h"
3 |
4 | static const LPCTSTR regValTranslate = L"translate";
5 |
6 | void TranslateCheckbox::initialize( HWND owner, int idc, AppState& state )
7 | {
8 | m_hWnd = GetDlgItem( owner, idc );
9 | assert( nullptr != m_hWnd );
10 |
11 | if( state.boolLoad( regValTranslate ) )
12 | ::SendMessage( m_hWnd, BM_SETCHECK, BST_CHECKED, 0L );
13 | }
14 |
15 | bool TranslateCheckbox::checked()
16 | {
17 | assert( nullptr != m_hWnd );
18 | const int state = ( int )::SendMessage( m_hWnd, BM_GETCHECK, 0, 0 );
19 | return state == BST_CHECKED;
20 | }
21 |
22 | void TranslateCheckbox::saveSelection( AppState& state )
23 | {
24 | state.boolStore( regValTranslate, checked() );
25 | }
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/Utils/TranslateCheckbox.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../AppState.h"
3 |
4 | class TranslateCheckbox
5 | {
6 | HWND m_hWnd = nullptr;
7 | public:
8 | operator HWND() const
9 | {
10 | return m_hWnd;
11 | }
12 |
13 | void initialize( HWND owner, int idc, AppState& state );
14 |
15 | bool checked();
16 |
17 | void saveSelection( AppState& state );
18 | };
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/Utils/logger.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | void logMessage( Whisper::eLogLevel lvl, const char8_t* pszFormat, va_list args );
6 |
7 | #define LOG_MESSAGE_IMPL( lvl ) \
8 | std::va_list args; \
9 | va_start( args, pszFormat ); \
10 | logMessage( lvl, pszFormat, args ); \
11 | va_end( args )
12 |
13 | inline void logError( const char8_t* pszFormat, ... )
14 | {
15 | LOG_MESSAGE_IMPL( Whisper::eLogLevel::Error );
16 | }
17 | inline void logWarning( const char8_t* pszFormat, ... )
18 | {
19 | LOG_MESSAGE_IMPL( Whisper::eLogLevel::Warning );
20 | }
21 | inline void logInfo( const char8_t* pszFormat, ... )
22 | {
23 | LOG_MESSAGE_IMPL( Whisper::eLogLevel::Info );
24 | }
25 | inline void logDebug( const char8_t* pszFormat, ... )
26 | {
27 | LOG_MESSAGE_IMPL( Whisper::eLogLevel::Debug );
28 | }
29 |
30 | #undef LOG_MESSAGE_IMPL
31 |
32 | HRESULT logNewSegments( const Whisper::iTranscribeResult* results, size_t newSegments, bool printSpecial = false );
33 |
34 | void clearLastError();
35 | bool getLastError( CString& rdi );
36 |
37 | void printTime( CStringA& rdi, Whisper::sTimeSpan time, bool comma = false );
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/WhisperDesktop.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "AppState.h"
3 | #include "Utils/miscUtils.h"
4 | #include "LoadModelDlg.h"
5 | #include "TranscribeDlg.h"
6 | #include "CaptureDlg.h"
7 |
8 | static HRESULT dialogLoadModel( AppState& appState )
9 | {
10 | LoadModelDlg loadDialog{ appState };
11 | HRESULT hr = loadDialog.show();
12 | if( FAILED( hr ) )
13 | {
14 | reportFatalError( "Error loading the model", hr );
15 | return hr;
16 | }
17 | appState.automaticallyLoadModel = false;
18 | return hr;
19 | }
20 |
21 | static HRESULT dialogTranscribe( AppState& appState )
22 | {
23 | TranscribeDlg dialog{ appState };
24 | return dialog.show();
25 | }
26 |
27 | static HRESULT dialogCapture( AppState& appState )
28 | {
29 | CaptureDlg dialog{ appState };
30 | return dialog.show();
31 | }
32 |
33 | using pfnDialog = HRESULT( * )( AppState& appState );
34 | static const std::array s_dialogs =
35 | {
36 | nullptr, // S_OK
37 | &dialogLoadModel, // SCREEN_MODEL
38 | &dialogTranscribe, // SCREEN_TRANSCRIBE
39 | &dialogCapture, // SCREEN_CAPTURE
40 | };
41 |
42 | int __stdcall wWinMain( HINSTANCE hInstance, HINSTANCE hPrevInstance, LPWSTR lpCmdLine, int nCmdShow )
43 | {
44 | AppState appState;
45 | HRESULT hr = appState.startup();
46 | if( FAILED( hr ) )
47 | return hr;
48 |
49 | appState.findModelSource();
50 |
51 | hr = SCREEN_MODEL;
52 | while( true )
53 | {
54 | pfnDialog pfn = s_dialogs[ hr ];
55 | if( nullptr == pfn )
56 | return S_OK;
57 | hr = pfn( appState );
58 | if( FAILED( hr ) )
59 | return hr;
60 | if( hr == SCREEN_MODEL )
61 | appState.model = nullptr;
62 | }
63 | }
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/WhisperDesktop.manifest:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Your application description here.
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | true
13 | PerMonitorV2
14 |
15 |
16 |
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/WhisperDesktop.rc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/Examples/WhisperDesktop/WhisperDesktop.rc
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/framework.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers
3 | #define NOMINMAX
4 | // Windows Header Files
5 | #include "targetver.h"
6 | #include
7 | // ATL header files
8 | #include
9 | #include
10 | #include
11 | #include
12 |
13 | // C RunTime Header Files
14 | #include
15 | #include
16 | #include
17 | #include
18 | #include
19 | // C++ headers
20 | #include
21 | #include
22 | #include
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/stdafx.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/stdafx.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "framework.h"
3 |
4 | #include
5 |
6 | #include "resource.h"
7 | #include "Utils/WTL/atlapp.h"
8 | #include "Utils/WTL/atlctrls.h"
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/sunflower.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/Examples/WhisperDesktop/sunflower.ico
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/targetver.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | // Setup Windows SDK to only enable features available since Windows 8.0
3 | #include
4 | #define _WIN32_WINNT _WIN32_WINNT_WIN8
5 | #define NTDDI_VERSION NTDDI_WIN8
6 | #include
7 |
--------------------------------------------------------------------------------
/Examples/WhisperDesktop/useDiscreteGpu.c:
--------------------------------------------------------------------------------
1 | __declspec( dllexport ) int NvOptimusEnablement = 1;
2 | __declspec( dllexport ) int AmdPowerXpressRequestHighPerformance = 1;
--------------------------------------------------------------------------------
/Examples/main/Readme.txt:
--------------------------------------------------------------------------------
1 | This example shows how to consume the DLL from a C++ console application.
2 |
3 | The command-line interface matches the corresponding example from whisper.cpp project.
--------------------------------------------------------------------------------
/Examples/main/main.vcxproj.filters:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/Examples/main/miscUtils.cpp:
--------------------------------------------------------------------------------
1 | #include "miscUtils.h"
2 | #define WIN32_LEAN_AND_MEAN
3 | #include
4 |
5 | std::string utf8( const std::wstring& utf16 )
6 | {
7 | int count = WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), (int)utf16.length(), nullptr, 0, nullptr, nullptr );
8 | std::string str( count, 0 );
9 | WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), -1, &str[ 0 ], count, nullptr, nullptr );
10 | return str;
11 | }
12 |
13 | std::wstring utf16( const std::string& u8 )
14 | {
15 | int count = MultiByteToWideChar( CP_UTF8, 0, u8.c_str(), (int)u8.length(), nullptr, 0 );
16 | std::wstring str( count, 0 );
17 | MultiByteToWideChar( CP_UTF8, 0, u8.c_str(), (int)u8.length(), &str[ 0 ], count );
18 | return str;
19 | }
20 |
21 | namespace
22 | {
23 | wchar_t* formatMessage( HRESULT hr )
24 | {
25 | wchar_t* err;
26 | if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM,
27 | NULL,
28 | hr,
29 | MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ),
30 | (LPTSTR)&err,
31 | 0,
32 | NULL ) )
33 | return err;
34 | return nullptr;
35 | }
36 | }
37 |
38 | void printError( const char* what, HRESULT hr )
39 | {
40 | const wchar_t* err = formatMessage( hr );
41 | if( nullptr != err )
42 | {
43 | fwprintf( stderr, L"%S: %s\n", what, err );
44 | LocalFree( (HLOCAL)err );
45 | }
46 | else
47 | fprintf( stderr, "%s: error code %i (0x%08X)\n", what, hr, hr );
48 | }
--------------------------------------------------------------------------------
/Examples/main/miscUtils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | std::string utf8( const std::wstring& utf16 );
5 |
6 | std::wstring utf16( const std::string& u8 );
7 |
8 | using HRESULT = long;
9 | void printError( const char* what, HRESULT hr );
--------------------------------------------------------------------------------
/Examples/main/params.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | // command-line parameters
6 | struct whisper_params
7 | {
8 | uint32_t n_threads;
9 | uint32_t n_processors = 1;
10 | uint32_t offset_t_ms = 0;
11 | uint32_t offset_n = 0;
12 | uint32_t duration_ms = 0;
13 | uint32_t max_context = UINT_MAX;
14 | uint32_t max_len = 0;
15 |
16 | float word_thold = 0.01f;
17 |
18 | bool speed_up = false;
19 | bool translate = false;
20 | bool diarize = false;
21 | bool output_txt = false;
22 | bool output_vtt = false;
23 | bool output_srt = false;
24 | bool output_wts = false;
25 | bool print_special = false;
26 | bool print_colors = true;
27 | bool no_timestamps = false;
28 |
29 | std::string language = "en";
30 | std::wstring model = L"models/ggml-base.en.bin";
31 | std::wstring gpu;
32 | std::string prompt;
33 | std::vector fname_inp;
34 |
35 | whisper_params();
36 |
37 | bool parse( int argc, wchar_t* argv[] );
38 | };
39 |
40 | void whisper_print_usage( int argc, wchar_t** argv, const whisper_params& params );
--------------------------------------------------------------------------------
/Examples/main/textWriter.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../../Whisper/API/iContext.cl.h"
3 |
4 | // These functions print output segments into text files of various formats
5 | HRESULT writeText( Whisper::iContext* context, LPCTSTR audioPath, bool timestamps );
6 | HRESULT writeSubRip( Whisper::iContext* context, LPCTSTR audioPath );
7 | HRESULT writeWebVTT( Whisper::iContext* context, LPCTSTR audioPath );
--------------------------------------------------------------------------------
/SampleClips/columbia.wma:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/SampleClips/columbia.wma
--------------------------------------------------------------------------------
/SampleClips/jfk.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/SampleClips/jfk.wav
--------------------------------------------------------------------------------
/SampleClips/summary.tsv:
--------------------------------------------------------------------------------
1 | Audio Clip Model GPU Total, sec Relative speed Encode, sec Decode, sec RAM, MB VRAM, MB
2 | jfk.wav medium GeForce 1080Ti 1.13909 9.6568 0.599964 0.451832 2.84049 2185.90208
3 | jfk.wav medium GeForce 1650 3.16174 3.4791 1.95373 0.987441 2.84049 2185.90208
4 | jfk.wav medium Ryzen 5 5600U 8.79853 1.2502 7.34021 1.30925 2.84049 2233.35424
5 | jfk.wav medium Ryzen 7 5700G 4.95485 2.2200 3.82894 1.03424 2.84049 2233.35424
6 | jfk.wav large GeForce 1080Ti 2.19628 5.0085 1.29615 0.7739170000000001 2.85543 4052.89984
7 | jfk.wav large GeForce 1650 8.33686 1.3194 3.98133 4.07331 2.85543 4052.89984
8 | jfk.wav large Ryzen 5 5600U 14.2729 0.7707 11.9404 2.12941 2.85543 4112.35328
9 | jfk.wav large Ryzen 7 5700G 9.46787 1.1618 7.35144 2.01739 2.85543 4112.35328
10 | columbia.wma medium GeForce 1080Ti 14.9475 13.2973 6.01034 8.78676 91.929 2247.34208
11 | columbia.wma medium GeForce 1650 48.7479 4.0773 19.2258 29.233 91.929 2247.34208
12 | columbia.wma medium Ryzen 5 5600U 81.256 2.4461 51.1656 29.6502 91.929 2295.52128
13 | columbia.wma medium Ryzen 7 5700G 62.1145 3.1999 37.8702 23.9262 91.929 2295.52128
14 | columbia.wma large GeForce 1080Ti 27.5329 7.2191 11.3967 15.9412 93.1329 4118.28224
15 | columbia.wma large GeForce 1650 109.423 1.8165 36.3141 72.8459 93.1329 4118.28224
16 | columbia.wma large Ryzen 5 5600U 140.747 1.4122 88.7441 51.4306 93.1329 4178.78016
17 | columbia.wma large Ryzen 7 5700G 110.474 1.7992 65.7998 44.3232 93.1329 4178.78016
18 |
--------------------------------------------------------------------------------
/Tools/CompressShaders/CompressShaders.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 | Exe
4 | net6.0
5 | enable
6 | enable
7 | true
8 | false
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/Tools/CompressShaders/DetectFp64.cs:
--------------------------------------------------------------------------------
1 | #pragma warning disable CS0649
2 | using System.Runtime.InteropServices;
3 |
4 | namespace CompressShaders
5 | {
6 | static class DetectFp64
7 | {
8 | struct DXBCHeader
9 | {
10 | public uint FourCC; // Four character code "DXBC"
11 | public uint Hash0; // 32-bit hash of the DXBC file
12 | public uint Hash1; // 32-bit hash of the DXBC file
13 | public uint Hash2; // 32-bit hash of the DXBC file
14 | public uint Hash3; // 32-bit hash of the DXBC file
15 | public uint unknownOne;
16 | public uint TotalSize; // Total size of the DXBC file in bytes
17 | public int NumChunks; // Number of chunks in the DXBC file
18 | };
19 |
20 | public static bool usesFp64( ReadOnlySpan dxbc )
21 | {
22 | ReadOnlySpan dxbcHeaderSpan = MemoryMarshal.Cast( dxbc );
23 | DXBCHeader dxbcHeader = dxbcHeaderSpan[ 0 ];
24 |
25 | int cbHeader = Marshal.SizeOf();
26 | int nChunks = dxbcHeader.NumChunks;
27 | ReadOnlySpan chunkOffsets = MemoryMarshal.Cast( dxbc.Slice( cbHeader, nChunks * 4 ) );
28 | foreach( int off in chunkOffsets )
29 | {
30 | uint id = MemoryMarshal.Cast( dxbc.Slice( off, 4 ) )[ 0 ];
31 | const uint SFI0 = 0x30494653;
32 | if( id != SFI0 )
33 | continue;
34 | int size = MemoryMarshal.Cast( dxbc.Slice( off + 4, 4 ) )[ 0 ];
35 | if( size < 4 )
36 | throw new ApplicationException();
37 | uint data = MemoryMarshal.Cast( dxbc.Slice( off + 8, 4 ) )[ 0 ];
38 | return 0 != ( data & 1u );
39 | }
40 | return false;
41 | }
42 | }
43 | }
--------------------------------------------------------------------------------
/Tools/CompressShaders/LZ4.cs:
--------------------------------------------------------------------------------
1 | namespace CompressShaders;
2 | using K4os.Compression.LZ4;
3 |
4 | /// Lossless data compressor which uses LZ4-HC compressor
5 | ///
6 | ///
static class LZ4
{
	// Outside of the FAST mode, compression speed drops rapidly, while decompression speed stays the same.
	// Decompression is usually even faster for high compression levels, because there is less data to process.
	// https://github.com/MiloszKrajewski/K4os.Compression.LZ4#compression-levels
	const LZ4Level compressionLevel = LZ4Level.L12_MAX;

	/// <summary>Compress the buffer, and return a new array trimmed to the compressed size</summary>
	public static byte[] compressBuffer( byte[] src )
	{
		byte[] buffer = new byte[ LZ4Codec.MaximumOutputSize( src.Length ) ];
		int written = LZ4Codec.Encode( src, buffer, compressionLevel );
		if( written <= 0 )
			throw new ApplicationException( $"LZ4Codec.Encode failed with status {written}" );

		Array.Resize( ref buffer, written );
		return buffer;
	}
}
--------------------------------------------------------------------------------
/Tools/CompressShaders/Readme.txt:
--------------------------------------------------------------------------------
1 | This project builds a C# console app which serves as a code generator for a few pieces of Whisper.dll and WhisperNet.dll.
2 |
3 | Specifically, it generates two things.
4 |
5 | 1. It compresses the compiled DXBC shaders into a blob of bytes, and prints std::array with these bytes into shaderData-Release.inl and shaderData-Debug.inl C++ files.
6 |
7 | 2. It parses the `languageCodez.tsv`, and generates both C++ and C# code with the data from that table.
8 |
9 | The tool uses relative paths across source files.
10 | These paths will break if you move the source of the tool, or the source data of the tool.
--------------------------------------------------------------------------------
/Tools/CompressShaders/ShaderNames.cs:
--------------------------------------------------------------------------------
1 | static class ShaderNames
2 | {
3 | public static void write( string path, IEnumerable names )
4 | {
5 | string[] arr = names.ToArray();
6 | using var stream = File.CreateText( path );
7 | stream.WriteLine( @"// This source file is generated by a tool
8 | #include ""stdafx.h""
9 | #include ""shaderNames.h""
10 | " );
11 |
12 | stream.WriteLine( "static const std::array s_shaderNames = ", arr.Length );
13 | stream.WriteLine( "{" );
14 | foreach( string name in arr )
15 | stream.WriteLine( @" ""{0}"",", name );
16 |
17 | stream.Write( @"};
18 |
19 | const char* DirectCompute::computeShaderName( eComputeShader cs )
20 | {
21 | const uint16_t i = (uint16_t)cs;
22 | if( i < s_shaderNames.size() )
23 | return s_shaderNames[ i ];
24 | return nullptr;
25 | }" );
26 | }
27 | }
--------------------------------------------------------------------------------
/Tools/CompressTables/CompressTables.cs:
--------------------------------------------------------------------------------
1 | namespace CompressTables;
2 | using CompressShaders;
3 | using System.IO;
4 | using System.Runtime.CompilerServices;
5 |
6 | /// Utility app to compress lookup tables data with LZ4-HC, and generate C++ source with the data
7 | internal class Program
8 | {
9 | static string getSolutionRoot( [CallerFilePath] string? path = null )
10 | {
11 | string? dir = Path.GetDirectoryName( path );
12 | dir = Path.GetDirectoryName( dir );
13 | dir = Path.GetDirectoryName( dir );
14 | return dir ?? throw new ApplicationException();
15 | }
16 |
17 | static void writeArray( byte[] compressed, string path )
18 | {
19 | using var stream = File.CreateText( path );
20 | stream.WriteLine( "// This source file is generated by a tool" );
21 | stream.Write( @"static const std::array s_tableData = {{", compressed.Length );
22 |
23 | for( int i = 0; i < compressed.Length; i++ )
24 | {
25 | if( 0 == i % 32 )
26 | stream.Write( "\r\n\t" );
27 | else
28 | stream.Write( ' ' );
29 | stream.Write( "0x{0:X02},", compressed[ i ] );
30 | }
31 | stream.Write( @"
32 | };" );
33 | }
34 |
35 | static void Main( string[] args )
36 | {
37 | byte[] source = File.ReadAllBytes( @"C:\Temp\2remove\Whisper\tables.bin" );
38 | byte[] result = LZ4.compressBuffer( source );
39 |
40 | string root = getSolutionRoot();
41 | string path = Path.Combine( root, "Whisper", "ML", "LookupTablesData.inl" );
42 | writeArray( result, path );
43 | }
44 | }
--------------------------------------------------------------------------------
/Tools/CompressTables/CompressTables.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 | Exe
4 | net7.0
5 | enable
6 | enable
7 | true
8 | false
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/Tools/PerfSummary/PerfSummary.cs:
--------------------------------------------------------------------------------
1 | using System.Runtime.CompilerServices;
2 |
namespace PerfSummary
{
	internal class Program
	{
		// Take the path of this source file supplied by the compiler, and go three levels up
		static string getSolutionRoot( [CallerFilePath] string? path = null )
		{
			string? dir = path;
			for( int level = 0; level < 3; level++ )
				dir = Path.GetDirectoryName( dir );
			return dir ?? throw new ApplicationException();
		}

		// Parse the logs in the SampleClips folder, sort them by clip, model and GPU, and print the summary
		static void Main( string[] args )
		{
			string clips = Path.Combine( getSolutionRoot(), "SampleClips" );

			var sorted =
				from log in LogParser.parse( clips )
				orderby log.name.clip, log.name.model, log.name.gpu
				select log;
			LogData[] logs = sorted.ToArray();

			Summary.print( logs, clips );
		}
	}
}
--------------------------------------------------------------------------------
/Tools/PerfSummary/PerfSummary.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 | Exe
4 | net6.0
5 | enable
6 | enable
7 | true
8 | false
9 |
10 |
--------------------------------------------------------------------------------
/Tools/compareTraces/CommandLineArgs.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "CommandLineArgs.h"
3 | #include
4 |
5 | static bool printUsage()
6 | {
7 | fprintf( stderr, "Usage: compareTraces.exe trace1.bin trace2.bin [-diff N]\n" );
8 | return false;
9 | }
10 |
11 | bool CommandLineArgs::parse( int argc, wchar_t* argv[] )
12 | {
13 | size_t idx = 0;
14 | CString sw;
15 | CStringA tmp;
16 | for( int i = 1; i < argc; i++ )
17 | {
18 | if( argv[ i ][ 0 ] != L'-' )
19 | {
20 | if( idx >= 2 )
21 | return printUsage();
22 | inputs[ idx ] = argv[ i ];
23 | idx++;
24 | continue;
25 | }
26 | sw = argv[ i ];
27 | if( 0 == sw.CompareNoCase( L"-diff" ) )
28 | {
29 | i++;
30 | if( i >= argc )
31 | return printUsage();
32 | tmp.Format( "%S", argv[ i ] );
33 | tmp.Trim();
34 | uint64_t v;
35 | auto res = std::from_chars( tmp, cstr( tmp ) + tmp.GetLength(), v );
36 | if( res.ec != (std::errc)0 )
37 | {
38 | fprintf( stderr, "Unable to parse string into number\n" );
39 | return false;
40 | }
41 | printDiff = v;
42 | continue;
43 | }
44 | return printUsage();
45 | }
46 |
47 | if( idx != 2 )
48 | return printUsage();
49 |
50 | return true;
51 | }
--------------------------------------------------------------------------------
/Tools/compareTraces/CommandLineArgs.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | struct CommandLineArgs
4 | {
5 | int64_t printDiff = -1;
6 | std::array inputs;
7 |
8 | bool parse( int argc, wchar_t* argv[] );
9 | };
--------------------------------------------------------------------------------
/Tools/compareTraces/Readme.txt:
--------------------------------------------------------------------------------
1 | This project builds a C++ console tool which compares debug traces of the model.
2 |
3 | Tracing files easily exceed 1GB, and by default they’re disabled with a preprocessor macro in stdafx.h of the Whisper project.
4 |
5 | When enabled, the main GPU implementation saves a trace into C:\Temp\2remove\Whisper\gpu.bin
6 |
7 | The reference CPU implementation saves a trace into C:\Temp\2remove\Whisper\ref.bin
8 |
9 | The code in this project is optimized for development speed. For this reason it requires an AVX2-capable CPU, uses memory-mapped IO instead of proper parsing, and checks little to no errors.
--------------------------------------------------------------------------------
/Tools/compareTraces/TraceReader.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "TraceReader.h"
3 | using namespace Tracing;
4 |
5 | const sTraceItem& TraceReader::operator[]( size_t idx ) const
6 | {
7 | if( idx >= countItems )
8 | throw E_BOUNDS;
9 | return items[ idx ];
10 | }
11 |
12 | CStringA TraceReader::getName( const sTraceItem& item ) const
13 | {
14 | const size_t idx = item.stringIndex;
15 | if( idx >= countStrings )
16 | throw E_BOUNDS;
17 | const char* const source = stringData + stringIndex[ idx ];
18 | CStringA res;
19 | res.Format( source, item.formatArgs[ 0 ], item.formatArgs[ 1 ], item.formatArgs[ 2 ], item.formatArgs[ 3 ] );
20 | return res;
21 | }
22 |
23 | HRESULT TraceReader::open( LPCTSTR path )
24 | {
25 | CHECK( file.Create( path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING ) );
26 | CHECK( mapping.MapFile( file ) );
27 |
28 | const uint8_t* rsi = mapping;
29 | const sFileHeader& header = *(const sFileHeader*)rsi;
30 | if( header.magic != header.correctMagic )
31 | return E_INVALIDARG;
32 | countItems = header.countItems;
33 | countStrings = header.countStrings;
34 |
35 | rsi += sizeof( sFileHeader );
36 | payloadPointer = rsi;
37 |
38 | rsi += header.bytesPayload;
39 | stringIndex = (const uint32_t*)( rsi );
40 | stringData = (const char*)( rsi + countStrings * 4 );
41 |
42 | rsi += header.bytesStrings;
43 | items = (const sTraceItem*)rsi;
44 |
45 | return S_OK;
46 | }
--------------------------------------------------------------------------------
/Tools/compareTraces/TraceReader.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../../Whisper/Utils/Trace/TraceStructures.h"
3 | #include
4 | #include
5 |
6 | namespace Tracing
7 | {
8 | class TraceReader
9 | {
10 | const uint8_t* payloadPointer = nullptr;
11 | const sTraceItem* items = nullptr;
12 | size_t countItems = 0;
13 | size_t countStrings = 0;
14 | const uint32_t* stringIndex = nullptr;
15 | const char* stringData = nullptr;
16 |
17 | CAtlFile file;
18 | CAtlFileMapping mapping;
19 |
20 | public:
21 |
22 | TraceReader() = default;
23 | ~TraceReader() = default;
24 |
25 | HRESULT open( LPCTSTR path );
26 | size_t size() const { return countItems; }
27 | const sTraceItem& operator[]( size_t idx ) const;
28 | CStringA getName( const sTraceItem& item ) const;
29 |
30 | const void* payload( const sTraceItem& item ) const
31 | {
32 | return payloadPointer + item.payloadOffset;
33 | }
34 | };
35 | }
--------------------------------------------------------------------------------
/Tools/compareTraces/compare.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "CommandLineArgs.h"
3 |
4 | HRESULT compareTraces( const CommandLineArgs& arguments );
--------------------------------------------------------------------------------
/Tools/compareTraces/compareTraces.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include
3 | #include "compare.h"
4 | #include "CommandLineArgs.h"
5 |
6 | int wmain( int argc, wchar_t* argv[] )
7 | {
8 | CommandLineArgs cla;
9 | if( !cla.parse( argc, argv ) )
10 | return 1;
11 |
12 | HRESULT hr = compareTraces( cla );
13 | if( SUCCEEDED( hr ) )
14 | return 0;
15 | return hr;
16 | }
--------------------------------------------------------------------------------
/Tools/compareTraces/compareTraces.vcxproj.filters:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/Tools/compareTraces/stdafx.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 |
3 | namespace
4 | {
5 | wchar_t* formatMessage( HRESULT hr )
6 | {
7 | wchar_t* err;
8 | if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM,
9 | NULL,
10 | hr,
11 | MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ),
12 | (LPTSTR)&err,
13 | 0,
14 | NULL ) )
15 | return err;
16 | return nullptr;
17 | }
18 | }
19 |
20 | void printError( HRESULT hr )
21 | {
22 | const wchar_t* err = formatMessage( hr );
23 | if( nullptr != err )
24 | {
25 | fwprintf( stderr, L"%s\n", err );
26 | LocalFree( (HLOCAL)err );
27 | }
28 | else
29 | fprintf( stderr, "Error code %i (0x%08X)\n", hr, hr );
30 | }
--------------------------------------------------------------------------------
/Tools/compareTraces/stdafx.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | #define WIN32_LEAN_AND_MEAN
6 | #define NOMINMAX
7 | #include
8 | #include
9 | #include
10 |
11 | #include
12 | #include
13 | #include
14 | #include
15 |
16 | #define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; }
17 |
18 | inline __m128i load16( const int* rsi )
19 | {
20 | return _mm_loadu_si128( ( const __m128i* )rsi );
21 | }
22 | inline __m128i load16( const uint32_t* rsi )
23 | {
24 | return _mm_loadu_si128( ( const __m128i* )rsi );
25 | }
26 | inline __m128i load( const std::array& arr )
27 | {
28 | return load16( arr.data() );
29 | }
30 |
31 | inline bool vectorEqual( __m128i a, __m128i b )
32 | {
33 | __m128i xx = _mm_xor_si128( a, b );
34 | return (bool)_mm_testz_si128( xx, xx );
35 | }
36 |
37 | void printError( HRESULT hr );
38 |
39 | inline const char* cstr( const CStringA& s ) { return s; }
40 | inline const wchar_t* cstr( const CString& s ) { return s; }
--------------------------------------------------------------------------------
/Whisper/API/MfStructs.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace Whisper
4 | {
5 | struct sCaptureDevice
6 | {
7 | // The display name is suitable for showing to the user, but might not be unique.
8 | const wchar_t* displayName;
9 |
10 | // Endpoint ID for an audio capture device
11 | // It uniquely identifies the device on the system, but is not a readable string.
12 | const wchar_t* endpoint;
13 | };
14 |
15 | using pfnFoundCaptureDevices = HRESULT( __stdcall* )( int len, const sCaptureDevice* buffer, void* pv );
16 |
17 | // Flags for the audio capture
18 | enum struct eCaptureFlags : uint32_t
19 | {
20 | // When the capture device supports stereo, keep stereo PCM samples in addition to mono
21 | Stereo = 1,
22 | };
23 |
24 | // Parameters for audio capture
25 | struct sCaptureParams
26 | {
27 | float minDuration = 2.0f;
28 | float maxDuration = 3.0f;
29 | float dropStartSilence = 0.25f;
30 | float pauseDuration = 0.333f;
31 | // Flags for the audio capture
32 | uint32_t flags = 0;
33 | };
34 |
35 | enum struct eCaptureStatus : uint8_t
36 | {
37 | Listening = 1,
38 | Voice = 2,
39 | Transcribing = 4,
40 | Stalled = 0x80,
41 | };
42 |
43 | // Return S_OK to continue, or S_FALSE to stop the capture session
44 | using pfnShouldCancel = HRESULT( __stdcall* )( void* pv ) noexcept;
45 |
46 | using pfnCaptureStatus = HRESULT( __stdcall* )( void* pv, eCaptureStatus status ) noexcept;
47 |
48 | struct sCaptureCallbacks
49 | {
50 | pfnShouldCancel shouldCancel;
51 | pfnCaptureStatus captureStatus;
52 | void* pv;
53 | };
54 | }
--------------------------------------------------------------------------------
/Whisper/API/Readme.txt:
--------------------------------------------------------------------------------
1 | The headers in this folder define the complete public API of Whisper.dll.
2 |
3 | To consume the library in your C++ software, include exactly one of the following headers.
4 |
5 | 1. If you’re building a Windows app, include the whisperWindows.h header, and you’ll get the traditional Win32 COM projection of the API.
6 |
7 | 2. If you’re porting to another OS, porting to a different C++ compiler, or already using the ComLight support library, include the whisperComLight.h header.
8 | If you do that, in addition to this "Whisper/API" folder you will also need the "ComLightLib" dependency.
9 | This will get you the ComLight flavor of these COM interfaces.
10 |
11 | Internally, the actual implementation uses the ComLight flavor of the interfaces, but that’s fine because the two flavors are binary compatible.
12 |
13 | The two flavors differ because Visual Studio’s CComPtr and other related utilities expect interface IDs to be specified with the __declspec(uuid) directive.
14 |
15 | That language extension is specific to Visual C++, and is not supported by the GCC or Clang compilers.
--------------------------------------------------------------------------------
/Whisper/API/SpecialTokens.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace Whisper
4 | {
5 | struct SpecialTokens
6 | {
7 | // The end of a transcription, token_eot
8 | int TranscriptionEnd;
9 | // Start of a transcription, token_sot
10 | int TranscriptionStart;
11 | // Represents the previous word in the transcription. It is used to help the model predict the current word based on the context of the words that came before it.
12 | int PreviousWord; // token_prev
13 | // Start of a sentence
14 | int SentenceStart; // token_solm
15 | // NOTE(review): in upstream whisper.cpp, token_not appears to be the "no timestamps" token rather than the word "not" — confirm
16 | int Not; // token_not
17 | //New transcription
18 | int TranscriptionBegin; // token_beg
19 |
20 | // token_translate
21 | int TaskTranslate;
22 | // token_transcribe
23 | int TaskTranscribe;
24 | };
25 | }
--------------------------------------------------------------------------------
/Whisper/API/iTranscribeResult.cl.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "TranscribeStructs.h"
3 | #include "../../ComLightLib/comLightCommon.h"
4 |
5 | namespace Whisper
6 | {
7 | struct iTranscribeResult : public ComLight::IUnknown
8 | {
9 | DEFINE_INTERFACE_ID( "{2871a73f-5ce3-48f8-8779-6582ee11935e}" );
10 |
11 | virtual HRESULT COMLIGHTCALL getSize( sTranscribeLength& rdi ) const = 0;
12 | virtual const sSegment* COMLIGHTCALL getSegments() const = 0;
13 | virtual const sToken* COMLIGHTCALL getTokens() const = 0;
14 | };
15 | }
--------------------------------------------------------------------------------
/Whisper/API/iTranscribeResult.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "TranscribeStructs.h"
3 |
4 | namespace Whisper
5 | {
6 | __interface __declspec( novtable, uuid( "2871a73f-5ce3-48f8-8779-6582ee11935e" ) ) iTranscribeResult : public IUnknown
7 | {
8 | HRESULT __stdcall getSize( sTranscribeLength& rdi ) const;
9 | const sSegment* __stdcall getSegments() const;
10 | const sToken* __stdcall getTokens() const;
11 | };
12 | }
--------------------------------------------------------------------------------
/Whisper/API/loggerApi.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | namespace Whisper
5 | {
6 | // Log level for messages
7 | enum struct eLogLevel : uint8_t
8 | {
9 | Error = 0,
10 | Warning = 1,
11 | Info = 2,
12 | Debug = 3
13 | };
14 | enum struct eLoggerFlags : uint8_t
15 | {
16 | UseStandardError = 1,
17 | SkipFormatMessage = 2,
18 | };
19 |
20 | // C function pointer to receive log messages from the library. The messages are encoded in UTF-8.
21 | using pfnLoggerSink = void( __stdcall* )( void* context, eLogLevel lvl, const char* message );
22 |
23 | // A sink to receive log messages produced by Whisper.dll
24 | struct sLoggerSetup
25 | {
26 | // C function pointer to receive log messages from the library
27 | pfnLoggerSink sink = nullptr;
28 | // Optional context parameter for the sink function; when consuming from C# you don't need that, pass IntPtr.Zero, delegates can capture things.
29 | void* context = nullptr;
30 | // Maximum log level to produce
31 | eLogLevel level;
32 | // Flags about the logger
33 | eLoggerFlags flags = (eLoggerFlags)0;
34 | };
35 | }
--------------------------------------------------------------------------------
/Whisper/API/sLanguageList.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | namespace Whisper
5 | {
6 | struct sLanguageEntry
7 | {
8 | uint32_t key;
9 | int id;
10 | const char* name;
11 | };
12 |
13 | struct sLanguageList
14 | {
15 | uint32_t length;
16 | const sLanguageEntry* pointer;
17 | };
18 | }
--------------------------------------------------------------------------------
/Whisper/API/sLoadModelCallbacks.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace Whisper
4 | {
5 | using pfnLoadProgress = HRESULT( __stdcall* )( double val, void* pv ) noexcept;
6 | // Return S_OK to continue, or S_FALSE to fail with "The operation was canceled by the user" status code
7 | using pfnCancel = HRESULT( __stdcall* )( void* pv ) noexcept;
8 |
9 | struct sLoadModelCallbacks
10 | {
11 | pfnLoadProgress progress;
12 | pfnCancel cancel;
13 | void* pv;
14 | };
15 | }
--------------------------------------------------------------------------------
/Whisper/API/sModelSetup.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | namespace Whisper
5 | {
6 | enum struct eModelImplementation : uint32_t
7 | {
8 | // GPGPU implementation based on Direct3D 11.0 compute shaders
9 | GPU = 1,
10 |
11 | // A hybrid implementation which uses DirectCompute for encode, and decodes on CPU
12 | // Not implemented in the published builds of the DLL. To enable, change BUILD_HYBRID_VERSION macro to 1
13 | Hybrid = 2,
14 |
15 | // A reference implementation which uses the original GGML CPU-running code
16 | // Not implemented in the published builds of the DLL. To enable, change BUILD_BOTH_VERSIONS macro to 1
17 | Reference = 3,
18 | };
19 |
20 | enum struct eGpuModelFlags : uint32_t
21 | {
22 | Wave32 = 1,
23 | Wave64 = 2,
24 | NoReshapedMatMul = 4,
25 | UseReshapedMatMul = 8,
26 | Cloneable = 0x10,
27 | };
28 |
29 | struct sModelSetup
30 | {
31 | eModelImplementation impl = eModelImplementation::GPU;
32 | uint32_t flags = 0;
33 | const wchar_t* adapter = nullptr;
34 | };
35 |
36 | // Function pointer to enumerate GPUs
37 | using pfnListAdapters = void( __stdcall* )( const wchar_t* name, void* pv );
38 |
39 | // Function pointer to receive array of tokens from iModel.tokenize() API method
40 | using pfnDecodedTokens = void( __stdcall* )( const int* tokens, int tokensLength, void* pv );
41 | }
--------------------------------------------------------------------------------
/Whisper/API/whisperComLight.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "iMediaFoundation.cl.h"
3 | #include "iContext.cl.h"
4 | #include "iTranscribeResult.cl.h"
--------------------------------------------------------------------------------
/Whisper/API/whisperWindows.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "iMediaFoundation.h"
3 | #include "iContext.h"
4 | #include "iTranscribeResult.h"
--------------------------------------------------------------------------------
/Whisper/CPU/HybridLoader.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "DecoderTensors.h"
3 | #include
4 | #include
5 | #include "../../ComLightLib/streams.h"
6 |
7 | namespace CpuCompute
8 | {
9 | __interface iLoaderProgressSink
10 | {
11 | HRESULT gotBytes( int64_t cb );
12 | };
13 |
14 | class HybridLoader
15 | {
16 | DecoderTensors& destination;
17 | CAtlMap map;
18 | size_t bufferBytes = 0;
19 |
20 | struct alignas( 32 ) PendingTensor
21 | {
22 | Tensor* destPointer = nullptr;
23 | int64_t streamOffset = 0;
24 | size_t bufferOffset = 0;
25 | size_t payloadBytes = 0;
26 | };
27 | std::vector pending;
28 |
29 | public:
30 |
31 | HybridLoader( DecoderTensors& m, int countLayers );
32 |
33 | HRESULT setupTensor( const CStringA& name, int n_dims, int ftype, const std::array& ne, ComLight::iReadStream* stream, int64_t& postponedBytes );
34 |
35 | HRESULT completeLoad( ComLight::iReadStream* stream, iLoaderProgressSink& progressSink );
36 | };
37 | }
--------------------------------------------------------------------------------
/Whisper/CPU/KvTensors.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "Tensor.h"
3 | #include "LargeBuffer.h"
4 | #include "../Whisper/sModelParams.h"
5 |
6 | namespace CpuCompute
7 | {
8 | class KvTensors
9 | {
10 | uint16_t* keys = nullptr;
11 | uint16_t* values = nullptr;
12 | uint32_t size = 0;
13 |
14 | CpuCompute::LargeBuffer memory;
15 |
16 | public:
17 | // Create these two large tensors, FP16 precision
18 | HRESULT create( const Whisper::sModelParams& mp );
19 |
20 | // A slice of model.memory_cross_k tensor
21 | Tensor keysView( uint32_t len, uint32_t off ) const
22 | {
23 | if( len + off <= size )
24 | return Tensor::fromData( keys + off, eDataType::FP16, len );
25 | throw E_BOUNDS;
26 | }
27 |
28 | // A slice of model.memory_cross_v tensor
29 | Tensor valuesView( uint32_t len, uint32_t off ) const
30 | {
31 | if( len + off <= size )
32 | return Tensor::fromData( values + off, eDataType::FP16, len );
33 | throw E_BOUNDS;
34 | }
35 | };
36 | }
--------------------------------------------------------------------------------
/Whisper/CPU/KvTensorsCpu.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "KvTensors.h"
3 | using namespace CpuCompute;
4 |
5 | // Create these two large tensors, FP16 precision
6 | HRESULT KvTensors::create( const Whisper::sModelParams& mp )
7 | {
8 | const uint32_t n_mem = mp.n_text_layer * mp.n_text_ctx;
9 | const uint32_t n_elements = mp.n_text_state * n_mem;
10 |
11 | const size_t cb = sizeof( uint16_t ) * (size_t)n_elements * 2;
12 | CHECK( memory.allocate( cb ) );
13 |
14 | uint16_t* pointer = (uint16_t*)memory.pointer();
15 | keys = pointer;
16 | values = pointer + n_elements;
17 | size = n_elements;
18 | return S_OK;
19 | }
--------------------------------------------------------------------------------
/Whisper/CPU/LargeBuffer.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "LargeBuffer.h"
3 | using namespace CpuCompute;
4 |
5 | void LargeBuffer::deallocate()
6 | {
7 | if( nullptr == pv )
8 | return;
9 | VirtualFree( pv, 0, MEM_RELEASE );
10 | pv = nullptr;
11 | }
12 |
13 | HRESULT LargeBuffer::allocate( size_t cb )
14 | {
15 | deallocate();
16 |
17 | pv = VirtualAlloc( nullptr, cb, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE );
18 | if( nullptr != pv )
19 | return S_OK;
20 | return HRESULT_FROM_WIN32( GetLastError() );
21 | }
22 |
23 | HRESULT LargeBuffer::setReadOnly( size_t cb )
24 | {
25 | if( nullptr != pv )
26 | {
27 | DWORD op = 0;
28 | if( VirtualProtect( pv, cb, PAGE_READONLY, &op ) )
29 | return S_OK;
30 | return HRESULT_FROM_WIN32( GetLastError() );
31 | }
32 | else
33 | return OLE_E_BLANK;
34 | }
--------------------------------------------------------------------------------
/Whisper/CPU/LargeBuffer.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace CpuCompute
4 | {
// A large memory buffer allocated with VirtualAlloc kernel API, bypassing the heap.
// The object owns the memory: movable, not copyable, and the destructor releases the block.
class LargeBuffer
{
	void* pv = nullptr;
public:
	LargeBuffer() = default;
	LargeBuffer( const LargeBuffer& ) = delete;
	// Take the block from another buffer, leaving that one empty
	LargeBuffer( LargeBuffer&& that ) noexcept
	{
		pv = that.pv;
		that.pv = nullptr;
	}
	~LargeBuffer()
	{
		deallocate();
	}
	// Move assignment is implemented as a swap: the block previously owned by this object is released when `that` is destroyed.
	// Returns a reference to this object, like assignment operators conventionally do.
	LargeBuffer& operator=( LargeBuffer&& that ) noexcept
	{
		std::swap( pv, that.pv );
		return *this;
	}
	LargeBuffer& operator=( const LargeBuffer& that ) = delete;

	// Allocate buffer with specified count of bytes, and read+write memory protection
	// The OS kernel guarantees zero-initialization of that memory.
	HRESULT allocate( size_t cb );

	// Change memory protection of the buffer to read only
	HRESULT setReadOnly( size_t cb );

	// Unless the pointer is nullptr, deallocate the buffer
	void deallocate();

	// Pointer to the start of the buffer, aligned by memory page = 4 kilobytes
	// Asserts the buffer is allocated.
	uint8_t* pointer() const
	{
		assert( nullptr != pv );
		return (uint8_t*)pv;
	}
};
44 | }
--------------------------------------------------------------------------------
/Whisper/CPU/ParallelForRunner.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "LargeBuffer.h"
3 |
4 | namespace CpuCompute
5 | {
6 | // Callback interface for the parallel `for`
7 | __interface iComputeRange
8 | {
9 | // The implementation calls this method on multiple thread pool threads in parallel, and aggregates status codes.
10 | HRESULT __stdcall compute( size_t begin, size_t end ) const;
11 | };
12 |
13 | // Similar to ThreadPoolWork in parallelFor.h, optimized to be used as a direct replacement of OpenMP pool.
14 | class alignas( 64 ) ParallelForRunner
15 | {
16 | public:
17 | ParallelForRunner( int threads );
18 | ~ParallelForRunner();
19 |
20 | HRESULT setThreadsCount( int threads );
21 |
22 | HRESULT parallelFor( iComputeRange& compute, size_t length, size_t minBatch = 1 );
23 |
24 | // Allocate a temporary buffer for the calling thread.
25 | // The pointer is guaranteed to be aligned by page size = 4kb
26 | void* threadLocalBuffer( size_t cb );
27 |
28 | private:
29 |
30 | int maxThreads;
31 | PTP_WORK work = nullptr;
32 | iComputeRange* computeRange = nullptr;
33 | size_t countItems = 0;
34 | size_t countThreads = 0;
35 |
36 | // Aligning by cache lines.
37 | // Avoiding cache line sharing between CPU cores improves performance, despite wasting a few bytes of memory.
38 | struct alignas( 64 ) ThreadBuffer
39 | {
40 | LargeBuffer memory;
41 | size_t cb = 0;
42 | };
43 | std::vector threadBuffers;
44 |
45 | alignas( 64 ) volatile long threadIndex = 0;
46 | volatile HRESULT status = S_OK;
47 |
48 | void runBatch( size_t ith ) noexcept;
49 |
50 | static void __stdcall workCallbackStatic( PTP_CALLBACK_INSTANCE Instance, void* pv, PTP_WORK Work ) noexcept;
51 | };
52 | }
--------------------------------------------------------------------------------
/Whisper/CPU/Readme.txt:
--------------------------------------------------------------------------------
1 | The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change the BUILD_HYBRID_VERSION macro in stdafx.h.
--------------------------------------------------------------------------------
/Whisper/CPU/mulMat.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "mulMat.h"
3 | #include "mulMatImpl.h"
4 | using namespace CpuCompute;
5 |
6 | namespace
7 | {
8 | template
9 | static HRESULT mulMatImpl( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor )
10 | {
11 | MulMatImpl impl{ result, a, b, pfor };
12 | return impl.run( pfor );
13 | }
14 | }
15 |
16 | HRESULT CpuCompute::mulMat( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor )
17 | {
18 | if( a.type() != eDataType::FP16 )
19 | return E_NOTIMPL;
20 | if( b.type() != eDataType::FP32 )
21 | return E_NOTIMPL;
22 |
23 | // return mulMatImpl<1, 1>( result, a, b, pfor );
24 |
25 | if( b.ne[ 1 ] == 1 )
26 | {
27 | // Multiplying by a single row
28 | if( a.ne[ 1 ] >= 32 )
29 | return mulMatImpl<4, 1>( result, a, b, pfor );
30 | else
31 | return mulMatImpl<1, 1>( result, a, b, pfor );
32 | }
33 | else if( b.ne[ 1 ] == 2 )
34 | {
35 | if( a.ne[ 1 ] >= 32 )
36 | return mulMatImpl<4, 2>( result, a, b, pfor );
37 | else
38 | return mulMatImpl<1, 2>( result, a, b, pfor );
39 | }
40 | else if( b.ne[ 1 ] == 3 )
41 | {
42 | if( a.ne[ 1 ] >= 16 )
43 | return mulMatImpl<2, 3>( result, a, b, pfor );
44 | else
45 | return mulMatImpl<1, 3>( result, a, b, pfor );
46 | }
47 | else
48 | {
49 | if( a.ne[ 1 ] >= 16 )
50 | return mulMatImpl<2, 4>( result, a, b, pfor );
51 | else
52 | return mulMatImpl<1, 4>( result, a, b, pfor );
53 | }
54 | }
--------------------------------------------------------------------------------
/Whisper/CPU/mulMat.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "ParallelForRunner.h"
3 | #include "Tensor.h"
4 |
5 | namespace CpuCompute
6 | {
7 | HRESULT mulMat( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor );
8 | }
9 |
10 | #if TENSOR_GGML_COMPAT
11 | #include "../source/ggml.h"
12 | inline HRESULT mulMat( ggml_tensor* result, const ggml_tensor* a, const ggml_tensor* b, CpuCompute::ParallelForRunner& pfor )
13 | {
14 | CpuCompute::Tensor r{ result }, lhs{ a }, rhs{ b };
15 | return CpuCompute::mulMat( r, lhs, rhs, pfor );
16 | }
17 | #endif
--------------------------------------------------------------------------------
/Whisper/CPU/mulMatImpl.cpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/Whisper/CPU/mulMatImpl.cpp
--------------------------------------------------------------------------------
/Whisper/D3D/Binder.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "device.h"
3 |
4 | namespace DirectCompute
5 | {
6 | class Binder
7 | {
8 | uint8_t maxSrv = 0;
9 | uint8_t maxUav = 0;
10 |
11 | public:
12 | Binder() = default;
13 | Binder( const Binder& ) = delete;
14 |
15 | void bind( ID3D11ShaderResourceView* srv0, ID3D11UnorderedAccessView* uav0 );
16 | void bind( ID3D11ShaderResourceView* srv0, ID3D11ShaderResourceView* srv1, ID3D11UnorderedAccessView* uav0 );
17 | void bind( std::initializer_list srvs, std::initializer_list uavs );
18 | void bind( ID3D11UnorderedAccessView* uav0 );
19 | ~Binder();
20 | };
21 | }
--------------------------------------------------------------------------------
/Whisper/D3D/MappedResource.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "MappedResource.h"
3 | using namespace DirectCompute;
4 | #define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; }
5 |
6 | MappedResource::MappedResource()
7 | {
8 | mapped.pData = nullptr;
9 | mapped.RowPitch = mapped.DepthPitch = 0;
10 | resource = nullptr;
11 | }
12 |
13 | HRESULT MappedResource::map( ID3D11Resource* res, bool reading )
14 | {
15 | if( nullptr == resource )
16 | {
17 | D3D11_MAP mt = reading ? D3D11_MAP_READ : D3D11_MAP_WRITE_DISCARD;
18 | CHECK( context()->Map( res, 0, mt, 0, &mapped ) );
19 | resource = res;
20 | return S_OK;
21 | }
22 | return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED );
23 | }
24 |
25 | MappedResource::~MappedResource()
26 | {
27 | if( nullptr != resource )
28 | {
29 | context()->Unmap( resource, 0 );
30 | resource = nullptr;
31 | mapped.pData = nullptr;
32 | }
33 | }
--------------------------------------------------------------------------------
/Whisper/D3D/MappedResource.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "device.h"
3 | #include
4 |
5 | namespace DirectCompute
6 | {
7 | class MappedResource
8 | {
9 | D3D11_MAPPED_SUBRESOURCE mapped;
10 | ID3D11Resource* resource;
11 | public:
12 | MappedResource();
13 | HRESULT map( ID3D11Resource* res, bool reading );
14 | ~MappedResource();
15 |
16 | void* data() const
17 | {
18 | assert( nullptr != mapped.pData );
19 | return mapped.pData;
20 | }
21 | };
22 | }
--------------------------------------------------------------------------------
/Whisper/D3D/RenderDoc/renderDoc.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "renderDoc.h"
3 | #include "renderdoc_app.h"
4 | #include "../device.h"
5 |
6 | #define ENABLE_RENDERDOC_DEBUGGER 1
7 |
8 | #if ENABLE_RENDERDOC_DEBUGGER
9 | namespace
10 | {
11 | static HMODULE hmRenderDoc = nullptr;
12 | static RENDERDOC_API_1_6_0* api = nullptr;
13 | }
14 |
15 | bool DirectCompute::initializeRenderDoc()
16 | {
17 | hmRenderDoc = GetModuleHandleW( L"renderdoc.dll" );
18 | if( nullptr == hmRenderDoc )
19 | return false;
20 |
21 | pRENDERDOC_GetAPI getApi = (pRENDERDOC_GetAPI)GetProcAddress( hmRenderDoc, "RENDERDOC_GetAPI" );
22 | if( nullptr == getApi )
23 | return false;
24 | if( 1 != getApi( eRENDERDOC_API_Version_1_6_0, (void**)&api ) )
25 | return false;
26 | if( nullptr == api )
27 | return false;
28 |
29 | return true;
30 | }
31 |
32 | namespace
33 | {
34 | using namespace DirectCompute;
35 | inline bool isKeyPressed( int vKey )
36 | {
37 | return 0 != ( GetAsyncKeyState( vKey ) & 0x8000 );
38 | }
39 | }
40 |
41 | CaptureRaii::CaptureRaii() : capturing( false )
42 | {
43 | if( nullptr == api )
44 | return;
45 | if( !isKeyPressed( VK_F12 ) )
46 | return;
47 | ID3D11Device* const dev = device();
48 | if( nullptr == dev )
49 | return;
50 |
51 | api->StartFrameCapture( dev, nullptr );
52 | capturing = true;
53 | }
54 |
55 | CaptureRaii::~CaptureRaii()
56 | {
57 | if( !capturing )
58 | return;
59 | api->EndFrameCapture( device(), nullptr );
60 | }
61 | #else // !ENABLE_RENDERDOC_DEBUGGER
62 | bool DirectCompute::initializeRenderDoc()
63 | {
64 | return false;
65 | }
66 | DirectCompute::CaptureRaii::CaptureRaii() : capturing( false )
67 | {
68 | }
69 | DirectCompute::CaptureRaii::~CaptureRaii()
70 | {
71 | }
72 | #endif
--------------------------------------------------------------------------------
/Whisper/D3D/RenderDoc/renderDoc.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace DirectCompute
4 | {
5 | bool initializeRenderDoc();
6 |
7 | class CaptureRaii
8 | {
9 | bool capturing;
10 | public:
11 | CaptureRaii();
12 | CaptureRaii( const CaptureRaii& ) = delete;
13 | ~CaptureRaii();
14 | };
15 | }
--------------------------------------------------------------------------------
/Whisper/D3D/createBuffer.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "enums.h"
3 | #include "device.h"
4 |
5 | namespace DirectCompute
6 | {
7 | HRESULT createBuffer( eBufferUse use, size_t totalBytes, ID3D11Buffer** ppGpuBuffer, const void* rsi, ID3D11Buffer** ppStagingBuffer, bool shared = false );
8 | }
--------------------------------------------------------------------------------
/Whisper/D3D/createDevice.h:
--------------------------------------------------------------------------------
1 | // Low-level functions to create and initialize D3D11 device
2 | #pragma once
3 | #include
4 | #include "sGpuInfo.h"
5 |
6 | namespace DirectCompute
7 | {
8 | HRESULT createDevice( const std::wstring& adapter, ID3D11Device** dev, ID3D11DeviceContext** context );
9 |
10 | HRESULT validateFlags( uint32_t flags );
11 |
12 | HRESULT queryDeviceInfo( sGpuInfo& rdi, ID3D11Device* dev, uint32_t flags );
13 |
14 | // Create another device and context, on the same hardware GPU
15 | HRESULT cloneDevice( ID3D11Device* source, ID3D11Device** dev, ID3D11DeviceContext** context );
16 | }
--------------------------------------------------------------------------------
/Whisper/D3D/device.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 | #include "sGpuInfo.h"
5 |
6 | namespace DirectCompute
7 | {
8 | ID3D11Device* device();
9 | ID3D11DeviceContext* context();
10 | const sGpuInfo& gpuInfo();
11 |
12 | inline void csSetCB( ID3D11Buffer* cb )
13 | {
14 | context()->CSSetConstantBuffers( 0, 1, &cb );
15 | }
16 |
17 | __m128i bufferMemoryUsage( ID3D11Buffer* buffer );
18 |
19 | __m128i resourceMemoryUsage( ID3D11ShaderResourceView* srv );
20 | }
--------------------------------------------------------------------------------
/Whisper/D3D/downloadBuffer.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace DirectCompute
4 | {
5 | // Download a buffer from VRAM into std::vector
6 | // The function is relatively expensive, creates a temporary staging buffer on each call, and only used to test things.
7 | template
8 | HRESULT downloadBuffer( ID3D11ShaderResourceView* srv, std::vector& vec );
9 | }
--------------------------------------------------------------------------------
/Whisper/D3D/enums.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "enums.h"
3 |
4 | static const alignas( 16 ) std::array s_tensorViewFormats = { DXGI_FORMAT_R16_FLOAT, DXGI_FORMAT_R32_FLOAT, DXGI_FORMAT_R32_UINT };
5 |
6 | DXGI_FORMAT DirectCompute::viewFormat( eDataType dt )
7 | {
8 | return s_tensorViewFormats[ (uint8_t)dt ];
9 | }
--------------------------------------------------------------------------------
/Whisper/D3D/enums.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | namespace DirectCompute
6 | {
// Type of the elements in a tensor
enum struct eDataType : uint8_t
{
	FP16,
	FP32,
	U32,
};

// Size of a single element in bytes: 2 for FP16, 4 for the other types
inline size_t elementSize( eDataType dt )
{
	assert( dt == eDataType::FP16 || dt == eDataType::FP32 || dt == eDataType::U32 );

	if( dt == eDataType::FP16 )
		return 2;
	return 4;
}
20 |
21 | DXGI_FORMAT viewFormat( eDataType dt );
22 |
23 | enum struct eBufferUse : uint8_t
24 | {
25 | // Immutable tensor, readable from GPU
26 | Immutable,
27 | // Read+write tensor, readable and writable on GPU
28 | ReadWrite,
29 | // Read+write tensor, readable and writable on GPU, which supports downloads from GPU
30 | ReadWriteDownload,
31 | // The tensor is accessible by both GPU (read only) and CPU (write only). Optimized for resources frequently updated from CPU.
32 | Dynamic,
33 | };
34 | }
--------------------------------------------------------------------------------
/Whisper/D3D/listGPUs.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | namespace DirectCompute
5 | {
6 | CComPtr selectAdapter( const std::wstring& requestedName );
7 | }
--------------------------------------------------------------------------------
/Whisper/D3D/sGpuInfo.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
namespace DirectCompute
{
	// Magic numbers for DXGI_ADAPTER_DESC.VendorId field; the values are from this database: https://pcisig.com/membership/member-companies
	enum struct eGpuVendor : uint16_t
	{
		AMD = 0x1002,
		NVidia = 0x10de,
		Intel = 0x8086,
		VMWare = 0x15ad,
	};

	// Bit flags stored in sGpuInfo.flags field
	enum struct eGpuEffectiveFlags : uint8_t
	{
		Wave64 = 1,
		ReshapedMatMul = 2,
		Cloneable = 4,
	};

	// Information about a GPU, and the flags which apply to it
	struct sGpuInfo
	{
		eGpuEffectiveFlags flags;
		eGpuVendor vendor;
		uint16_t device, revision;
		uint32_t subsystem;
		size_t vramDedicated, ramDedicated, ramShared;
		std::wstring description;

		// True when the Wave64 bit is set in the flags
		inline bool wave64() const
		{
			return hasFlag( eGpuEffectiveFlags::Wave64 );
		}

		// True when the ReshapedMatMul bit is set in the flags.
		// On nVidia 1080Ti that approach is much slower, by a factor of 2.4
		// On AMD Cezanne that approach is faster by a factor of 0.69, i.e. 30% faster.
		// Dunno why that is, maybe 'coz on that AMD complete panels fit in L3 cache.
		// Anyway, we do want extra 30% perf on AMD Cezanne, so only using that code on AMD GPUs.
		// Dunno how it gonna behave on other GPUs, need to test.
		inline bool useReshapedMatMul() const
		{
			return hasFlag( eGpuEffectiveFlags::ReshapedMatMul );
		}

		// True when the Cloneable bit is set in the flags
		inline bool cloneableModel() const
		{
			return hasFlag( eGpuEffectiveFlags::Cloneable );
		}

	private:
		// Test a single bit of the flags field
		bool hasFlag( eGpuEffectiveFlags bit ) const
		{
			const uint8_t masked = (uint8_t)flags & (uint8_t)bit;
			return 0 != masked;
		}
	};
}
--------------------------------------------------------------------------------
/Whisper/D3D/shaderNames.cpp:
--------------------------------------------------------------------------------
1 | // This source file is generated by a tool
2 | #include "stdafx.h"
3 | #include "shaderNames.h"
4 | // Shader names, indexed by the numeric value of eComputeShader; see computeShaderName() below
5 | static const std::array s_shaderNames =
6 | {
7 | "add",
8 | "addInPlace",
9 | "addRepeat",
10 | "addRepeatEx",
11 | "addRepeatGelu",
12 | "addRepeatScale",
13 | "addRows",
14 | "convolutionMain",
15 | "convolutionMain2",
16 | "convolutionMain2Fixed",
17 | "convolutionPrep1",
18 | "convolutionPrep2",
19 | "copyConvert",
20 | "copyTranspose",
21 | "dbgFindNaN",
22 | "diagMaskInf",
23 | "flashAttention",
24 | "flashAttentionCompat1",
25 | "flashAttentionCompat2",
26 | "flashAttentionCompat3",
27 | "fmaRepeat1",
28 | "fmaRepeat2",
29 | "matReshapePanels",
30 | "mulMatByRow",
31 | "mulMatByRowTiled",
32 | "mulMatByRowTiledEx",
33 | "mulMatByScalar",
34 | "mulMatDotMain",
35 | "mulMatDotReshape",
36 | "mulMatMadMain",
37 | "mulMatTiled",
38 | "mulMatTiledEx",
39 | "norm",
40 | "normCompat",
41 | "normFixed",
42 | "scaleInPlace",
43 | "softMax",
44 | "softMaxCompat",
45 | "softMaxFixed",
46 | "softMaxLong",
47 | "zeroMemory",
48 | };
49 |
50 | const char* DirectCompute::computeShaderName( eComputeShader cs )
51 | {
52 | // Values outside of the names table have no name, nullptr is returned for them
53 | const size_t index = (size_t)(uint16_t)cs;
54 | const bool known = index < s_shaderNames.size();
55 | return known ? s_shaderNames[ index ] : nullptr;
56 | }
--------------------------------------------------------------------------------
/Whisper/D3D/shaderNames.h:
--------------------------------------------------------------------------------
1 | // This header is generated by a tool
2 | #pragma once
3 | #include
4 |
5 | namespace DirectCompute
6 | {
7 | enum struct eComputeShader: uint16_t // The values are indices into the names table in shaderNames.cpp
8 | {
9 | add = 0,
10 | addInPlace = 1,
11 | addRepeat = 2,
12 | addRepeatEx = 3,
13 | addRepeatGelu = 4,
14 | addRepeatScale = 5,
15 | addRows = 6,
16 | convolutionMain = 7,
17 | convolutionMain2 = 8,
18 | convolutionMain2Fixed = 9,
19 | convolutionPrep1 = 10,
20 | convolutionPrep2 = 11,
21 | copyConvert = 12,
22 | copyTranspose = 13,
23 | dbgFindNaN = 14,
24 | diagMaskInf = 15,
25 | flashAttention = 16,
26 | flashAttentionCompat1 = 17,
27 | flashAttentionCompat2 = 18,
28 | flashAttentionCompat3 = 19,
29 | fmaRepeat1 = 20,
30 | fmaRepeat2 = 21,
31 | matReshapePanels = 22,
32 | mulMatByRow = 23,
33 | mulMatByRowTiled = 24,
34 | mulMatByRowTiledEx = 25,
35 | mulMatByScalar = 26,
36 | mulMatDotMain = 27,
37 | mulMatDotReshape = 28,
38 | mulMatMadMain = 29,
39 | mulMatTiled = 30,
40 | mulMatTiledEx = 31,
41 | norm = 32,
42 | normCompat = 33,
43 | normFixed = 34,
44 | scaleInPlace = 35,
45 | softMax = 36,
46 | softMaxCompat = 37,
47 | softMaxFixed = 38,
48 | softMaxLong = 39,
49 | zeroMemory = 40,
50 | };
51 | // Name of the shader, or nullptr when the value is outside of the names table
52 | const char* computeShaderName( eComputeShader cs );
53 | }
--------------------------------------------------------------------------------
/Whisper/D3D/shaders.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "shaderNames.h"
3 |
4 | namespace DirectCompute
5 | {
6 | HRESULT createComputeShaders( std::vector>& shaders ); // NOTE(review): presumably fills the vector with one shader per eComputeShader value — confirm
7 | // NOTE(review): presumably binds the specified compute shader for the following dispatches — confirm in the implementation
8 | void bindShader( eComputeShader shader );
9 | }
--------------------------------------------------------------------------------
/Whisper/DllMain.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 |
3 | // Entry point of the DLL, called by the OS loader with the notification code in fdwReason
4 | BOOL __stdcall DllMain( HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved )
5 | {
6 | if( fdwReason == DLL_PROCESS_ATTACH )
7 | {
8 | // Initialize once for each new process: opt out of the per-thread notifications
9 | DisableThreadLibraryCalls( (HMODULE)hinstDLL );
10 | }
11 | else if( fdwReason == DLL_PROCESS_DETACH && lpvReserved == nullptr )
12 | {
13 | // The DLL is being unloaded, as opposed to process termination: the place for cleanup, none is needed currently
14 | }
15 | // DLL_THREAD_ATTACH and DLL_THREAD_DETACH need no thread-specific work
16 |
17 | // Always report success; returning FALSE for DLL_PROCESS_ATTACH would fail the DLL load
18 | return TRUE;
19 | }
--------------------------------------------------------------------------------
/Whisper/Hybrid/HybridContext.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../Whisper/WhisperModel.h"
3 | #include "../CPU/MlContext.h"
4 | #include "../CPU/BufferAllocator.h"
5 | #include "KeyValueDownloader.h"
6 | #include "../CPU/KvTensors.h"
7 |
8 | // This version of the hybrid context uses the new, custom-built kernels
9 | class HybridContext
10 | {
11 | CpuCompute::MlContext ml;
12 | CpuCompute::VirtualAllocator allocCompute, allocComputeLayer;
13 | // Arena allocator; NOTE(review): the name and the `allocated` flag suggest it serves a single allocation between resets — confirm in the .cpp
14 | class AllocSingle : public CpuCompute::iArenaAllocator
15 | {
16 | CpuCompute::LargeBuffer buffer;
17 | size_t capacity = 0;
18 | bool allocated = false;
19 | // Inherited via iArenaAllocator
20 | virtual void* allocate( size_t cb, size_t align ) override final;
21 |
22 | public:
23 | virtual void resetArena() override final;
24 | };
25 | AllocSingle allocLayerOutput;
26 |
27 | const CpuCompute::DecoderTensors& model;
28 | const Whisper::WhisperModel& whisperModel;
29 | KeyValueDownloader kvCross;
30 | CpuCompute::KvTensors kv;
31 |
32 | class SetAllocatorRaii;
33 |
34 | public:
35 |
36 | HybridContext( const Whisper::WhisperModel& wm );
37 |
38 | HRESULT create();
39 | // Forward the request to the kvCross downloader
40 | HRESULT downloadKeyValues( const DirectCompute::KeyValueBuffers& source )
41 | {
42 | return kvCross.download( source );
43 | }
44 | // Extra parameters for the decode() method
45 | struct sDecParams
46 | {
47 | int n_threads;
48 | int M;
49 | };
50 | // NOTE(review): presumably runs the decoder on the tokens and writes the probabilities into probs_out — confirm in the .cpp
51 | HRESULT decode( const int* tokens, const int n_tokens, const int n_past, const sDecParams& dp, std::vector& probs_out );
52 | };
--------------------------------------------------------------------------------
/Whisper/Hybrid/KeyValueDownloader.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "KeyValueDownloader.h"
3 |
4 | HRESULT KeyValueDownloader::create( const Whisper::sModelParams& mp )
5 | {
6 | const uint32_t n_audio_ctx = mp.n_audio_ctx;
7 | const uint32_t n_mem = mp.n_text_layer * n_audio_ctx;
8 | const uint32_t n_elements = mp.n_text_state * n_mem;
9 | // Two CPU-readable staging buffers of the same size, 2 bytes per element
10 | CD3D11_BUFFER_DESC desc{ n_elements * 2, 0, D3D11_USAGE_STAGING, D3D11_CPU_ACCESS_READ };
11 | ID3D11Device* dev = DirectCompute::device();
12 | CHECK( dev->CreateBuffer( &desc, nullptr, &keys ) );
13 | CHECK( dev->CreateBuffer( &desc, nullptr, &values ) );
14 | // The length is in elements, not bytes
15 | length = n_elements;
16 | return S_OK;
17 | }
18 |
19 | HRESULT KeyValueDownloader::download( const DirectCompute::KeyValueBuffers& source )
20 | {
21 | ID3D11DeviceContext* const deviceContext = DirectCompute::context();
22 | deviceContext->CopyResource( keys, source.keys.getBuffer() ); // source keys into the staging buffer
23 | deviceContext->CopyResource( values, source.values.getBuffer() ); // same for the values
24 | return S_OK; // CopyResource has no return value, there is nothing to check
25 | }
26 |
27 | KeyValueDownloader::ReadMap::ReadMap( KeyValueDownloader& owner ) : // Maps both buffers of the owner
28 | length( owner.length )
29 | {
30 | check( mappedKeys.map( owner.keys, true ) ); // NOTE(review): check() presumably throws on a failed HRESULT — confirm
31 | check( mappedValues.map( owner.values, true ) );
32 | }
--------------------------------------------------------------------------------
/Whisper/Hybrid/Readme.txt:
--------------------------------------------------------------------------------
1 | The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change the BUILD_HYBRID_VERSION macro in stdafx.h
--------------------------------------------------------------------------------
/Whisper/MF/AudioBuffer.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | namespace Whisper
5 | {
6 | // A pair of sample buffers, mono and stereo, filled by the append* methods
7 | struct AudioBuffer
8 | {
9 | std::vector mono;
10 | std::vector stereo;
11 |
12 | void appendMono( const float* rsi, size_t countFloats );
13 | void appendDownmixedStereo( const float* rsi, size_t countFloats );
14 | void appendStereo( const float* rsi, size_t countFloats );
15 |
16 | // Pointer to one of the append* methods above
17 | using pfnAppendSamples = void( AudioBuffer::* )( const float* rsi, size_t countFloats );
18 |
19 | // Select the append method for the layout of the source, and the requested output
20 | inline static pfnAppendSamples appendSamplesFunc( bool sourceMono, bool wantStereo )
21 | {
22 | // Mono sources always use appendMono, wantStereo is ignored for them
23 | if( sourceMono )
24 | return &AudioBuffer::appendMono;
25 | return wantStereo ? &AudioBuffer::appendStereo : &AudioBuffer::appendDownmixedStereo;
26 | }
27 |
28 | // Empty both vectors
29 | void clear()
30 | {
31 | stereo.clear();
32 | mono.clear();
33 | }
34 |
35 | // Exchange the content with another buffer
36 | void swap( AudioBuffer& that )
37 | {
38 | stereo.swap( that.stereo );
39 | mono.swap( that.mono );
40 | }
41 |
42 | // Truncate to `len` mono samples; the stereo vector, when in use, gets 2 values per mono sample
43 | void resize( size_t len )
44 | {
45 | assert( len <= mono.size() );
46 | mono.resize( len );
47 | if( stereo.empty() )
48 | return;
49 | stereo.resize( len * 2 );
50 | }
51 | };
52 | }
--------------------------------------------------------------------------------
/Whisper/MF/AudioCapture.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../API/MfStructs.h"
3 |
4 | namespace Whisper
5 | {
6 | struct iAudioCapture;
7 | struct iMediaFoundation;
8 | // NOTE(review): presumably enumerates the audio capture devices and reports them to the pfn callback, with pv passed through as the context — confirm
9 | HRESULT __stdcall captureDeviceList( pfnFoundCaptureDevices pfn, void* pv );
10 | // NOTE(review): presumably opens the capture device identified by the endpoint string, and returns the new object in *pp — confirm
11 | HRESULT __stdcall captureOpen( iMediaFoundation* owner, const wchar_t* endpoint, const sCaptureParams& captureParams, iAudioCapture** pp ) noexcept;
12 | }
--------------------------------------------------------------------------------
/Whisper/MF/loadAudioFile.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../API/iMediaFoundation.cl.h"
3 |
4 | namespace Whisper
5 | {
6 | HRESULT COMLIGHTCALL loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp ); // NOTE(review): presumably decodes the file into a new audio buffer returned in *pp — confirm
7 | }
--------------------------------------------------------------------------------
/Whisper/MF/mfStartup.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace Whisper
4 | {
5 | // RAII helper for the Media Foundation startup.
6 | // NOTE(review): presumably startup() records what it initialized in successFlags, and the destructor undoes exactly that — confirm in the .cpp
7 | class MfStartupRaii
8 | {
9 | uint8_t successFlags = 0;
10 | public:
11 | MfStartupRaii() = default;
12 | ~MfStartupRaii();
13 | // Non-copyable, for both construction and assignment: a copy would duplicate successFlags, and two destructors would then act on them
14 | MfStartupRaii( const MfStartupRaii& ) = delete;
15 | MfStartupRaii& operator=( const MfStartupRaii& ) = delete;
16 |
17 | HRESULT startup();
18 | };
19 | }
--------------------------------------------------------------------------------
/Whisper/MF/mfUtils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include "../Whisper/audioConstants.h"
7 |
8 | namespace Whisper
9 | {
10 | HRESULT createMediaType( bool stereo, IMFMediaType** pp ); // NOTE(review): presumably the media type requested from the MF source reader; see audioConstants.h — confirm
11 | // Get the duration of the stream; NOTE(review): the unit is presumably 100-nanosecond ticks, as usual for Media Foundation — confirm
12 | HRESULT getStreamDuration( IMFSourceReader* reader, int64_t& duration );
13 | // NOTE(review): presumably fails unless the current media type of the reader has the expected count of channels — confirm
14 | HRESULT validateCurrentMediaType( IMFSourceReader* reader, uint32_t expectedChannels );
15 |
16 | struct iAudioReader;
17 | void setPreciseSamplesCount( const iAudioReader* ar, int64_t count ); // NOTE(review): takes a const pointer yet the name implies mutation — confirm what is updated
18 | }
--------------------------------------------------------------------------------
/Whisper/ML/ConstantBuffer.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../D3D/device.h"
3 | #include "TensorShape.h"
4 |
5 | namespace DirectCompute
6 | {
7 | // 96 bytes dynamic constant buffers, with dimensions and VRAM layout of 2-3 tensors
8 | class ConstantBuffer
9 | {
10 | CComPtr buffer;
11 |
12 | public:
13 | HRESULT create();
14 | HRESULT update( const TensorShape& t0 ); // The 3 overloads upload the shapes of 1, 2 or 3 tensors
15 | HRESULT update( const TensorShape& t0, const TensorShape& t1 );
16 | HRESULT update( const TensorShape& t0, const TensorShape& t1, const TensorShape& t2 );
17 | // NOTE(review): presumably binds the buffer to the compute shader stage — confirm in the .cpp
18 | void bind() const;
19 | // Memory usage of the buffer, as reported by bufferMemoryUsage()
20 | __m128i getMemoryUse() const
21 | {
22 | return bufferMemoryUsage( buffer );
23 | }
24 | };
25 | }
--------------------------------------------------------------------------------
/Whisper/ML/DbgNanTest.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "DbgNanTest.h"
3 | #include "../D3D/MappedResource.h"
4 | using namespace DirectCompute;
5 |
6 | HRESULT DbgNanTest::create()
7 | {
8 | ID3D11Device* const d3dDevice = DirectCompute::device();
9 | // 4 bytes buffer, bindable as an unordered access view
10 | CD3D11_BUFFER_DESC bufferDesc{ 4, D3D11_BIND_UNORDERED_ACCESS };
11 | CHECK( d3dDevice->CreateBuffer( &bufferDesc, nullptr, &bufferDefault ) );
12 | // Staging buffer of the same size, readable by CPU
13 | bufferDesc.Usage = D3D11_USAGE_STAGING;
14 | bufferDesc.BindFlags = 0;
15 | bufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_READ;
16 | CHECK( d3dDevice->CreateBuffer( &bufferDesc, nullptr, &bufferStaging ) );
17 | // The view exposes the first buffer as a single R32_UINT element
18 | CD3D11_UNORDERED_ACCESS_VIEW_DESC uavDesc{ D3D11_UAV_DIMENSION_BUFFER, DXGI_FORMAT_R32_UINT, 0, 1 };
19 | CHECK( d3dDevice->CreateUnorderedAccessView( bufferDefault, &uavDesc, &uav ) );
20 |
21 | return S_OK;
22 | }
23 |
24 | void DbgNanTest::destroy() // Release everything made by create(): the view first, then the two buffers
25 | {
26 | uav = nullptr;
27 | bufferStaging = nullptr;
28 | bufferDefault = nullptr;
29 | }
30 |
31 | bool DbgNanTest::test() const
32 | {
33 | // Copy the value into the staging buffer, then map that buffer to read it on CPU
34 | context()->CopyResource( bufferStaging, bufferDefault );
35 | MappedResource staging;
36 | check( staging.map( bufferStaging, true ) );
37 | // The buffer holds a single 32-bit value; any non-zero value means true
38 | return 0 != *(const BOOL*)staging.data();
39 | }
--------------------------------------------------------------------------------
/Whisper/ML/DbgNanTest.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace DirectCompute
4 | {
5 | class DbgNanTest // Debug helper: a 4-byte GPU buffer with an UAV, plus a staging copy to read the value on CPU
6 | {
7 | CComPtr bufferDefault, bufferStaging;
8 | CComPtr uav;
9 | public:
10 | HRESULT create();
11 | void destroy();
12 | operator ID3D11UnorderedAccessView* ( ) const // The UAV of the GPU buffer, to bind it to a shader
13 | {
14 | return uav;
15 | }
16 | bool test() const; // Downloads the value from GPU; true when it's non-zero
17 | };
18 |
19 | #if DBG_TEST_NAN
20 | const DbgNanTest& getNanTestBuffers();
21 | #endif
22 | }
--------------------------------------------------------------------------------
/Whisper/ML/Device.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include "../D3D/sGpuInfo.h"
4 | #include "LookupTables.h"
5 | #include "DbgNanTest.h"
6 |
7 | namespace DirectCompute
8 | {
9 | struct Device // The D3D device with its context, and the resources created on that device
10 | {
11 | CComPtr device;
12 | CComPtr context;
13 | // NOTE(review): the shaders are presumably indexed by eComputeShader values — confirm
14 | std::vector> shaders;
15 | CComPtr smallCb;
16 | sGpuInfo gpuInfo;
17 | LookupTables lookupTables;
18 | #if DBG_TEST_NAN
19 | DbgNanTest nanTestBuffers;
20 | #endif
21 |
22 | HRESULT create( uint32_t flags, const std::wstring& adapter );
23 | HRESULT createClone( const Device& source );
24 | void destroy();
25 | // Movable but not copyable guard object; NOTE(review): presumably the destructor undoes the thread setup when `setup` is true — confirm in the .cpp
26 | class ThreadSetupRaii
27 | {
28 | bool setup;
29 | public:
30 | ThreadSetupRaii( const Device* dev );
31 | ~ThreadSetupRaii();
32 | ThreadSetupRaii( ThreadSetupRaii&& that ) noexcept
33 | {
34 | setup = that.setup;
35 | that.setup = false; // The moved-from object no longer owns the setup
36 | }
37 | ThreadSetupRaii( const ThreadSetupRaii& ) = delete;
38 | void operator=( const ThreadSetupRaii& ) = delete;
39 | };
40 | // Make the guard object for this device; keep the returned value alive while the device is in use on the thread
41 | ThreadSetupRaii setForCurrentThread() const
42 | {
43 | return ThreadSetupRaii{ this };
44 | }
45 | };
46 | }
--------------------------------------------------------------------------------
/Whisper/ML/LookupTables.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../D3D/device.h"
3 |
4 | namespace DirectCompute
5 | {
6 | class LookupTables // Owns two shader resource views, the GELU and the exponent lookup tables
7 | {
8 | CComPtr m_gelu, m_exponent;
9 |
10 | public:
11 |
12 | HRESULT create();
13 | HRESULT createClone( const LookupTables& source );
14 | void clear();
15 | ID3D11ShaderResourceView* gelu() const { return m_gelu; }
16 | ID3D11ShaderResourceView* exponent() const { return m_exponent; }
17 |
18 | __m128i getMemoryUsage() const;
19 | };
20 | // NOTE(review): presumably returns the tables of the device set up for the current thread — confirm
21 | const LookupTables& lookupTables();
22 | }
--------------------------------------------------------------------------------
/Whisper/ML/LookupTablesData.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | namespace DirectCompute
6 | {
7 | struct LookupTablesData // Two fixed-size arrays; NOTE(review): presumably the CPU-side source data of LookupTables — confirm
8 | {
9 | std::array gelu;
10 | std::array exponent;
11 | // NOTE(review): the constructor presumably computes both tables — confirm in the .cpp
12 | LookupTablesData();
13 | };
14 | }
--------------------------------------------------------------------------------
/Whisper/ML/Reshaper.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "Tensor.h"
3 |
4 | namespace DirectCompute
5 | {
6 | // This class reshapes some of the model’s tensors, immediately after they’re loaded.
7 | // That feature is used on all AMD GPUs.
8 | class Reshaper
9 | {
10 | CComPtr constantBuffer;
11 | HRESULT createConstants();
12 |
13 | public:
14 | ~Reshaper();
15 | HRESULT makePanels( Tensor& tensor, eDataType dataType ); // Reshape the tensor; the argument is a non-const reference and is modified
16 | };
17 | }
--------------------------------------------------------------------------------
/Whisper/ML/TempBuffers.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "TensorGpuViews.h"
3 |
4 | namespace DirectCompute
5 | {
6 | class TempBuffers // Three temporary GPU buffers: two for FP16 data and one for FP32
7 | {
8 | class Buffer : public TensorGpuViews
9 | {
10 | size_t capacity = 0;
11 |
12 | public:
13 | // Release the views, and reset the capacity
14 | void clear()
15 | {
16 | TensorGpuViews::clear();
17 | capacity = 0;
18 | }
19 | // NOTE(review): presumably only re-creates the buffer when the capacity is insufficient — confirm in the .cpp
20 | HRESULT resize( DXGI_FORMAT format, size_t elements, size_t cbElement, bool zeroMemory );
21 |
22 | size_t getCapacity() const { return capacity; }
23 | };
24 |
25 | Buffer m_fp16;
26 | Buffer m_fp16_2;
27 | Buffer m_fp32;
28 |
29 | public:
30 | // The 3 methods below return the views of the corresponding buffer, for at least the requested count of elements
31 | const TensorGpuViews& fp16( size_t countElements, bool zeroMemory = false );
32 | const TensorGpuViews& fp16_2( size_t countElements, bool zeroMemory = false );
33 | const TensorGpuViews& fp32( size_t countElements, bool zeroMemory = false );
34 | // Release all 3 buffers
35 | void clear()
36 | {
37 | m_fp16.clear();
38 | m_fp16_2.clear();
39 | m_fp32.clear();
40 | }
41 |
42 | __m128i getMemoryUse() const;
43 | };
44 | }
--------------------------------------------------------------------------------
/Whisper/ML/TensorEx.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "Tensor.h"
3 |
4 | namespace DirectCompute
5 | {
6 | // A tensor which supports dynamic updates from CPU, or downloads from VRAM to system RAM
7 | class TensorEx : public Tensor
8 | {
9 | protected:
10 | CComPtr buffer;
11 | CComPtr stagingBuffer;
12 | // Get the size of one element in bytes, and the count of elements
13 | HRESULT getViewSize( uint32_t& cbElement, uint32_t& countElements ) const;
14 |
15 | public:
16 |
17 | HRESULT create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData );
18 | HRESULT create( eDataType type, eBufferUse usage, const std::array& sizeElements );
19 | // Download `cb` bytes into the memory at rdi
20 | HRESULT download( void* rdi, size_t cb ) const;
21 |
22 | HRESULT download( void* rdi ) const;
23 | // Download into the vector, resized to the count of elements; E_OUTOFMEMORY when the resize fails
24 | template
25 | HRESULT download( std::vector& vec ) const
26 | {
27 | uint32_t cbElement, numElements;
28 | CHECK( getViewSize( cbElement, numElements ) );
29 | // NOTE(review): nothing here verifies cbElement equals the size of the vector's element; a mismatch would overrun or under-fill the vector
30 | try
31 | {
32 | vec.resize( numElements );
33 | }
34 | catch( const std::bad_alloc& )
35 | {
36 | return E_OUTOFMEMORY;
37 | }
38 |
39 | return download( vec.data(), (size_t)cbElement * numElements );
40 | }
41 | };
42 | }
--------------------------------------------------------------------------------
/Whisper/ML/TensorGpuViews.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "TensorGpuViews.h"
3 | using namespace DirectCompute;
4 |
5 | HRESULT TensorGpuViews::create( ID3D11Buffer* gpuBuffer, DXGI_FORMAT format, size_t countElements, bool makeUav )
6 | {
7 | // Release the previous views, if any
8 | srv = nullptr;
9 | uav = nullptr;
10 |
11 | // The view descriptors only support 32-bit element counts
12 | if( countElements > UINT_MAX )
13 | return DISP_E_OVERFLOW;
14 | const UINT length = (UINT)countElements;
15 |
16 | // The read-only view is always created
17 | CD3D11_SHADER_RESOURCE_VIEW_DESC srvDesc{ D3D11_SRV_DIMENSION_BUFFER, format, 0, length };
18 | CHECK( device()->CreateShaderResourceView( gpuBuffer, &srvDesc, &srv ) );
19 | if( !makeUav )
20 | return S_OK;
21 |
22 | // The writable view is optional
23 | CD3D11_UNORDERED_ACCESS_VIEW_DESC uavDesc{ D3D11_UAV_DIMENSION_BUFFER, format, 0, length };
24 | CHECK( device()->CreateUnorderedAccessView( gpuBuffer, &uavDesc, &uav ) );
25 | return S_OK;
26 | }
--------------------------------------------------------------------------------
/Whisper/ML/TensorGpuViews.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include "../D3D/device.h"
4 |
5 | namespace DirectCompute
6 | {
7 | class TensorGpuViews // A pair of views of a GPU buffer: read-only SRV, and an optional writable UAV
8 | {
9 | protected:
10 | CComPtr srv;
11 | CComPtr uav;
12 |
13 | public:
14 | // Implicit conversions to the raw view pointers; the UAV is nullptr when it was not created
15 | operator ID3D11ShaderResourceView* ( ) const { return srv; }
16 | operator ID3D11UnorderedAccessView* ( ) const { return uav; }
17 | // Create the views for the buffer; fails with DISP_E_OVERFLOW when countElements exceeds UINT_MAX
18 | HRESULT create( ID3D11Buffer* buffer, DXGI_FORMAT format, size_t countElements, bool makeUav );
19 | // Release both views
20 | void clear()
21 | {
22 | srv = nullptr;
23 | uav = nullptr;
24 | }
25 | // Replace both views with the supplied ones
26 | void setGpuViews( ID3D11ShaderResourceView* read, ID3D11UnorderedAccessView* write = nullptr )
27 | {
28 | srv = read;
29 | uav = write;
30 | }
31 | };
32 | }
--------------------------------------------------------------------------------
/Whisper/ML/mlUtils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace DirectCompute
4 | {
5 | // Update the small dynamic constant buffer with the 16 bytes in the argument, and return that buffer
6 | ID3D11Buffer* __vectorcall updateSmallCb( __m128i cbData );
7 |
8 | // Fill the first `length` elements of the tensor with either 0.0 or NaN values
9 | void zeroMemory( ID3D11UnorderedAccessView* uav, uint32_t length, bool fillWithNaN = false );
10 |
11 | // Fill the complete UAV with NaN values
12 | void fillTensorWithNaN( ID3D11UnorderedAccessView* uav );
13 |
14 | // true when the tensor contains at least 1 NaN value
15 | bool scanTensorForNaN( ID3D11ShaderResourceView* tensor, uint32_t length );
16 |
17 | // Create SRV on another device, reusing the resource; rsi is the source view, rdi receives the new one
18 | HRESULT cloneResourceView( ID3D11ShaderResourceView* rsi, ID3D11ShaderResourceView** rdi );
19 | }
--------------------------------------------------------------------------------
/Whisper/ML/reshapedMultiply.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | namespace DirectCompute
5 | {
6 | namespace ReshapedMultiply
7 | {
8 | constexpr uint32_t TILE_SIZE = 32; // NOTE(review): presumably must match the tile size in the matching HLSL shaders — confirm
9 | }
10 | }
--------------------------------------------------------------------------------
/Whisper/ML/tensorOpsTests.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../source/ggml.h"
3 |
4 | namespace DirectCompute
5 | {
6 | // void testMulMatReshape( const ggml_tensor* src1, const void* tempBuffer );
7 | void testMulMat( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst, const void* tempBuffer );
8 | void computeMulMat( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst );
9 | // NOTE(review): the test* functions take a const dst, presumably to compare it with the result; compute* write into dst — confirm
10 | void testFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, const ggml_tensor* dst );
11 | void computeFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, ggml_tensor* dst );
12 |
13 | void testConvolution( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst );
14 | void computeConvolution( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst );
15 | }
--------------------------------------------------------------------------------
/Whisper/ML/testUtils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../D3D/downloadBuffer.h"
3 | #include "../D3D/RenderDoc/renderDoc.h"
4 | #include
5 | #include
6 |
7 | // Fun fact: this code was written by ChatGPT. Hash function for arrays of integers, required by the unordered set in PrintUniqueTensorSizes
8 | namespace std
9 | {
10 | template<>
11 | struct hash>
12 | {
13 | size_t operator()( const array& arr ) const
14 | {
15 | size_t result = 0;
16 | for( uint32_t element : arr )
17 | result = ( result * 31 ) ^ element; // Multiply by 31, then XOR with the next element
18 | return result;
19 | }
20 | };
21 | }
22 |
23 | namespace DirectCompute
24 | {
25 | struct sTensorDiff // Summary of the difference between two tensors, a and b
26 | {
27 | // maximum( absolute( a - b ) )
28 | float maxAbsDiff;
29 | // average( ( a - b )^2 )
30 | float avgDiffSquared;
31 | size_t length;
32 |
33 | void print() const;
34 | void print( const char* what ) const;
35 | };
36 |
37 | // Compute difference between 2 FP32 vectors
38 | sTensorDiff computeDiff( const float* a, const float* b, size_t length );
39 |
40 | // Compute difference between 2 FP16 vectors
41 | sTensorDiff computeDiff( const uint16_t* a, const uint16_t* b, size_t length );
42 |
43 | class Tensor;
44 | sTensorDiff computeDiff( const Tensor& a, const Tensor& b );
45 | // Debug helper: write `cb` bytes from rsi into the file
46 | HRESULT dbgWriteBinaryFile( LPCTSTR fileName, const void* rsi, size_t cb );
47 |
48 | // Print unique sizes of the two tensors
49 | class PrintUniqueTensorSizes
50 | {
51 | std::unordered_set> set;
52 | const char* const what;
53 | void printImpl( const std::array& a );
54 |
55 | public:
56 | PrintUniqueTensorSizes( const char* w ) : what( w ) { } // The string is not copied, the pointer must stay valid
57 |
58 | void print( const Tensor& lhs, const Tensor& rhs );
59 | void print( const Tensor& lhs );
60 | void print( const int* lhs, const int* rhs );
61 | };
62 | }
--------------------------------------------------------------------------------
/Whisper/ML/testUtilsC.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #ifdef __cplusplus
4 | extern "C"
5 | {
6 | #endif
7 | void printUniqueTensorSize( const char* name, const int* lhs, const int* rhs ); // C linkage, so the function is callable from C sources
8 | #ifdef __cplusplus
9 | }
10 | #endif
--------------------------------------------------------------------------------
/Whisper/Readme.txt:
--------------------------------------------------------------------------------
1 | This C++ project builds a DLL which actually does the heavy lifting of this project.
2 |
3 | It implements the ML model, handles multimedia files with Media Foundation, captures audio (also with MF), does voice activity detection (custom code running on CPU), and a few smaller things.
4 |
5 | The code requires C++20, and has only been tested with Visual Studio 2022.
6 |
7 | When running the pure GPGPU model, the DLL requires the SSE 4.1 instruction set.
8 |
9 | When running a hybrid model, the DLL requires AVX1, FMA3, F16C, and BMI1 instruction set extensions.
--------------------------------------------------------------------------------
/Whisper/Resource.rc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/Whisper/Resource.rc
--------------------------------------------------------------------------------
/Whisper/Utils/CpuProfiler.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "CpuProfiler.h"
3 |
4 | namespace
5 | {
6 | using namespace Whisper;
7 |
8 | inline int64_t qpcNow()
9 | {
10 | LARGE_INTEGER counter; // Current value of the performance counter
11 | QueryPerformanceCounter( &counter );
12 | return counter.QuadPart;
13 | }
14 |
15 | // Converts time stamp counter intervals into ticks; the TSC frequency is measured lazily, on the first use
16 | class CpuTimescale
17 | {
18 | // TSC ticks per second; 0 means not measured yet
19 | uint64_t frequency = 0;
20 | // Both clocks, captured by the constructor
21 | const int64_t tscStart, qpcStart;
22 | uint64_t computeTscFrequency();
23 |
24 | public:
25 | CpuTimescale() :
26 | tscStart( tscNow() ),
27 | qpcStart( qpcNow() )
28 | { }
29 |
30 | inline uint64_t computeTicks( uint64_t tsc )
31 | {
32 | // The first call measures the frequency, computeTscFrequency() caches it in the field
33 | if( 0 == frequency )
34 | return makeTime( tsc, computeTscFrequency() );
35 |
36 | return makeTime( tsc, frequency );
37 | }
38 | };
39 |
40 | uint64_t __declspec( noinline ) CpuTimescale::computeTscFrequency()
41 | {
42 | // Time elapsed since the constructor, measured with both clocks
43 | const int64_t tsc = tscNow() - tscStart;
44 | const int64_t qpc = qpcNow() - qpcStart;
45 | // Clamp the interval to at least 1 QPC tick, to never divide by zero
46 | const double qpcElapsed = (double)( ( qpc > 0 ) ? qpc : 1 );
47 | uint64_t qpcFreq;
48 | QueryPerformanceFrequency( (LARGE_INTEGER*)&qpcFreq );
49 | // Seconds = qpc / qpcFreq
50 | // ticks per second = tsc / seconds = tsc * qpcFreq / qpc
51 | // Computed in FP64 and rounded to nearest: the integer product tsc * qpcFreq overflows uint64_t
52 | // when the first call happens late, e.g. after about 10 minutes with a 3 GHz TSC and a 10 MHz QPC
53 | const uint64_t res = (uint64_t)( (double)tsc * (double)qpcFreq / qpcElapsed + 0.5 );
54 | frequency = res;
55 | const double GHz = (double)(int64_t)res * 1.0E-9;
56 | logDebug( u8"Computed CPU base frequency: %g GHz", GHz );
57 | return res;
58 | }
58 |
59 | static CpuTimescale timescale; // Constructed during static initialization of the module, which captures the starting values of both clocks
60 | }
61 |
62 | uint64_t Whisper::ticksFromTsc( uint64_t tscDiff ) // Public entry point, forwards to the file-local timescale object
63 | {
64 | return timescale.computeTicks( tscDiff );
65 | }
--------------------------------------------------------------------------------
/Whisper/Utils/CpuProfiler.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace Whisper
4 | {
5 | // Get current time in CPU clock
6 | // More specifically, each CPU core has a timestamp counter which runs at CPU's base frequency, regardless of the frequency scaling of that core.
7 | inline int64_t tscNow()
8 | {
9 | return __rdtsc();
10 | }
11 |
12 | // Scale the time interval from CPU time stamp counter clock into 100-nanosecond ticks, rounding to nearest
13 | uint64_t ticksFromTsc( uint64_t tscDiff );
14 | // Measures the time elapsed since the object was constructed
15 | class CpuProfiler
16 | {
17 | const int64_t started = tscNow();
18 |
19 | public:
20 | // Time elapsed since construction, converted with ticksFromTsc()
21 | uint64_t elapsed() const
22 | {
23 | return ticksFromTsc( (uint64_t)( tscNow() - started ) );
24 | }
25 | };
26 | }
--------------------------------------------------------------------------------
/Whisper/Utils/DelayExecution.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "DelayExecution.h"
3 |
4 | namespace
5 | {
6 | constexpr bool useHighRezTimer = false;
7 | // Duration of one delay when sleeping on the timer
8 | constexpr int64_t sleepMicroseconds = 200;
9 |
10 | inline HRESULT sleepImpl( HANDLE timer )
11 | {
12 | // The due time is expressed in 100-nanosecond intervals, i.e. 10 per microsecond
13 | constexpr int64_t sleepTicks = sleepMicroseconds * 10;
14 | LARGE_INTEGER dueTime;
15 | // Negative values indicate relative time
16 | dueTime.QuadPart = -sleepTicks;
17 | if( !SetWaitableTimerEx( timer, &dueTime, 0, nullptr, nullptr, nullptr, 0 ) )
18 | return getLastHr();
19 | switch( WaitForSingleObject( timer, 50 ) )
20 | {
21 | case WAIT_OBJECT_0: return S_OK;
22 | case WAIT_FAILED: return getLastHr();
23 | default: return E_FAIL; // Any other result, including the 50 ms timeout
24 | }
25 | }
26 | }
27 |
28 | void DelayExecution::sleepOnTheTimer( const DelayExecution& delay )
29 | {
30 | // Failures are logged as warnings, and otherwise ignored
31 | const HRESULT status = sleepImpl( delay.timer );
32 | if( FAILED( status ) )
33 | logWarningHr( status, u8"DelayExecution.sleepOnTheTimer" );
34 | }
35 |
36 | void DelayExecution::spinWait( const DelayExecution& )
37 | {
38 | for( size_t left = 1024; left > 0; left-- ) // Busy wait: 1024 pause instructions
39 | _mm_pause();
40 | }
41 |
42 | void DelayExecution::sleep( const DelayExecution& ) // Currently unused alternative; the constructor has it commented out
43 | {
44 | Sleep( 0 );
45 | }
46 |
47 | DelayExecution::DelayExecution()
48 | {
49 | // Spin-wait is the default, and the fallback when the timer can't be created
50 | pfn = &spinWait;
51 | // pfn = &sleep;
52 |
53 | if constexpr( useHighRezTimer )
54 | {
55 | HANDLE h = CreateWaitableTimerEx( nullptr, nullptr, CREATE_WAITABLE_TIMER_HIGH_RESOLUTION, TIMER_ALL_ACCESS );
56 | if( nullptr == h )
57 | {
58 | logWarningHr( getLastHr(), u8"CreateWaitableTimerEx" );
59 | return;
60 | }
61 |
62 | // The timer is owned by this object from now on
63 | timer.Attach( h );
64 | pfn = &sleepOnTheTimer;
65 | }
66 | }
--------------------------------------------------------------------------------
/Whisper/Utils/DelayExecution.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | // Utility class implementing a high-resolution Sleep() function; the constructor selects the implementation
5 | class DelayExecution
6 | {
7 | using pfnDelay = void( * )( const DelayExecution& de );
8 | pfnDelay pfn = nullptr; // The selected implementation, one of the 3 static methods below
9 | CHandle timer; // Only created when the timer-based implementation is selected
10 |
11 | static void sleepOnTheTimer( const DelayExecution& delay );
12 | static void spinWait( const DelayExecution& );
13 | static void sleep( const DelayExecution& );
14 |
15 | public:
16 | DelayExecution();
17 | DelayExecution( const DelayExecution& ) = delete;
18 | ~DelayExecution() = default;
19 | // Pause the calling thread for a short time
20 | void delay() const
21 | {
22 | pfn( *this );
23 | }
24 | };
--------------------------------------------------------------------------------
/Whisper/Utils/GpuProfiler.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/Whisper/Utils/GpuProfiler.h
--------------------------------------------------------------------------------
/Whisper/Utils/GpuProfilerSimple.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../D3D/device.h"
3 | #include "DelayExecution.h"
4 |
5 | namespace DirectCompute
6 | {
7 | // A simple profiler which doesn't collect anything, used to measure time it took to load the model
8 | class GpuProfilerSimple
9 | {
10 | DelayExecution delay;
11 | CComPtr disjoint, begin, end; // NOTE(review): presumably D3D11 queries, one disjoint and two timestamps — confirm
12 | public:
13 | HRESULT create();
14 | HRESULT time( uint64_t& rdi ) const; // NOTE(review): the unit of the result is not visible here — confirm in the .cpp
15 | };
16 | }
--------------------------------------------------------------------------------
/Whisper/Utils/LZ4/LICENSE:
--------------------------------------------------------------------------------
1 | LZ4 Library
2 | Copyright (c) 2011-2020, Yann Collet
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without modification,
6 | are permitted provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | * Redistributions in binary form must reproduce the above copyright notice, this
12 | list of conditions and the following disclaimer in the documentation and/or
13 | other materials provided with the distribution.
14 |
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
22 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 |
--------------------------------------------------------------------------------
/Whisper/Utils/Logger.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../API/loggerApi.h"
3 |
4 | #ifdef __cplusplus
5 | extern "C" {
6 | #endif
7 |
8 | void logError( const char8_t* pszFormat, ... );
9 | void logError16( const wchar_t* pszFormat, ... );
10 | void logErrorHr( long hr, const char8_t* pszFormat, ... );
11 | void logWarning( const char8_t* pszFormat, ... );
12 | void logWarning16( const wchar_t* pszFormat, ... );
13 | void logWarningHr( long hr, const char8_t* pszFormat, ... );
14 | void logInfo( const char8_t* pszFormat, ... );
15 | void logInfo16( const wchar_t* pszFormat, ... );
16 | void logDebug( const char8_t* pszFormat, ... );
17 | void logDebug16( const wchar_t* pszFormat, ... );
18 |
19 | bool willLogMessage( Whisper::eLogLevel lvl );
20 |
21 | #ifdef __cplusplus
22 | }
23 | #endif
24 |
--------------------------------------------------------------------------------
/Whisper/Utils/MurmurHash3.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | void MurmurHash3_x86_32( const void* key, int len, uint32_t seed, void* out );
5 | void MurmurHash3_x86_128( const void* key, int len, uint32_t seed, void* out );
6 | void MurmurHash3_x64_128( const void* key, int len, uint32_t seed, void* out );
7 |
8 | #include
9 |
10 | // Traits class for `CAtlMap` which does not copy nor owns these strings
11 | struct StringPtrTraits : public ATL::CDefaultElementTraits
12 | {
13 | using INARGTYPE = const char*;
14 |
15 | static inline bool CompareElements( const char* a, const char* b )
16 | {
17 | return 0 == strcmp( a, b );
18 | }
19 |
20 | static inline int CompareElementsOrdered( const char* a, const char* b )
21 | {
22 | return strcmp( a, b );
23 | }
24 |
25 | static inline ULONG Hash( const char* ptr )
26 | {
27 | uint32_t hash = UINT_MAX;
28 | if( nullptr != ptr )
29 | {
30 | const int len = (int)strlen( ptr );
31 | constexpr uint32_t seed = 0;
32 | MurmurHash3_x86_32( ptr, len, seed, &hash );
33 | }
34 | return hash;
35 | }
36 | };
--------------------------------------------------------------------------------
/Whisper/Utils/ReadStream.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../ComLightLib/streams.h"
3 | #include "../ComLightLib/comLightServer.h"
4 | #define WIN32_LEAN_AND_MEAN
5 | #include
6 |
7 | class ReadStream : public ComLight::ObjectRoot
8 | {
9 | CAtlFile file;
10 | // TODO: implement a buffer in this class, at least 256kb
11 |
12 | HRESULT COMLIGHTCALL read( void* lpBuffer, int nNumberOfBytesToRead, int& lpNumberOfBytesRead ) override final
13 | {
14 | return file.Read( lpBuffer, (DWORD)nNumberOfBytesToRead, *(DWORD*)&lpNumberOfBytesRead );
15 | }
16 | HRESULT COMLIGHTCALL seek( int64_t offset, ComLight::eSeekOrigin origin ) override final
17 | {
18 | return file.Seek( offset, (uint8_t)origin );
19 | }
20 | HRESULT COMLIGHTCALL getPosition( int64_t& position ) override final
21 | {
22 | return file.GetPosition( *(ULONGLONG*)&position );
23 | }
24 | HRESULT COMLIGHTCALL getLength( int64_t& length ) override final
25 | {
26 | return file.GetSize( *(ULONGLONG*)&length );
27 | }
28 |
29 | public:
30 |
31 | HRESULT open( const wchar_t* path )
32 | {
33 | if( file )
34 | return HRESULT_CODE( ERROR_ALREADY_INITIALIZED );
35 | return file.Create( path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN );
36 | }
37 | };
--------------------------------------------------------------------------------
/Whisper/Utils/Trace/TraceStructures.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "TraceStructures.h"
3 | using namespace Tracing;
4 |
5 | uint64_t sTraceItem::buffer( uint64_t off, size_t length, eDataType type )
6 | {
7 | payloadOffset = off;
8 | payloadSize = length * DirectCompute::elementSize( type );
9 | *(uint64_t*)( &size[ 0 ] ) = length;
10 | *(uint64_t*)( &size[ 2 ] ) = 0;
11 | _mm_storeu_si128( ( __m128i* )stride.data(), _mm_setzero_si128() );
12 | itemType = eItemType::Buffer;
13 | dataType = type;
14 | return payloadSize;
15 | }
16 |
17 | uint64_t sTraceItem::tensor( uint64_t off, __m128i ne, __m128i nb, eDataType type )
18 | {
19 | payloadOffset = off;
20 | _mm_storeu_si128( ( __m128i* )size.data(), ne );
21 | _mm_storeu_si128( ( __m128i* )stride.data(), nb );
22 | uint64_t count = 1;
23 | for( uint32_t i : size )
24 | if( i != 0 )
25 | count *= i;
26 |
27 | payloadSize = count * DirectCompute::elementSize( type );
28 | itemType = eItemType::Tensor;
29 | dataType = type;
30 | return payloadSize;
31 | }
--------------------------------------------------------------------------------
/Whisper/Utils/parallelFor.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace Whisper
4 | {
5 | // A callback to offload to the thread pool
6 | using pfnParallelForCallback = HRESULT( * )( int ith, void* ctx ) noexcept;
7 |
8 | // A simple parallel for implementation; Windows includes a decent thread pool since Vista (2006)
9 | HRESULT parallelFor( pfnParallelForCallback pfn, int threadsCount, void* ctx );
10 |
11 | // Use this version when you wanna use the thread pool repeatedly, for the same work.
12 | // This class caches native work handle, saving a couple of WinAPI calls.
13 | class alignas( 64 ) ThreadPoolWork
14 | {
15 | PTP_WORK work = nullptr;
16 |
17 | // We want these volatile fields in another cache line from the rest of the data of this class.
18 | // threadIndex field is concurrently modified by different CPU cores, and these cache coherency protocols are slow.
19 | // OTOH, work and callback fields of this class only change when created / destroyed, that cache line is shared by CPU cores without any performance penalty.
20 | alignas( 64 ) volatile long threadIndex = 0;
21 | volatile HRESULT status = E_UNEXPECTED;
22 |
23 | static void __stdcall callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work );
24 |
25 | protected:
26 | virtual HRESULT threadPoolCallback( int ith ) noexcept = 0;
27 |
28 | public:
29 | ThreadPoolWork() = default;
30 | ThreadPoolWork( const ThreadPoolWork& ) = delete;
31 |
32 | ~ThreadPoolWork();
33 |
34 | HRESULT create();
35 |
36 | HRESULT parallelFor( int threadsCount ) noexcept;
37 | };
38 | }
--------------------------------------------------------------------------------
/Whisper/Whisper/DecoderInputBuffers.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../ML/Tensor.h"
3 |
4 | namespace DirectCompute
5 | {
6 | // A dynamic buffer
7 | class DecoderInputBuffers
8 | {
9 | CComPtr embd;
10 | uint32_t m_size = 0;
11 | uint32_t m_capacity = 0;
12 |
13 | public:
14 |
15 | void resize( uint32_t size );
16 |
17 | // Create 1D tensor with R32_UINT elements, upload the source data
18 | Tensor embedding( const int* rsi ) const;
19 |
20 | void clear();
21 |
22 | __m128i getMemoryUse() const
23 | {
24 | size_t i = m_capacity;
25 | i *= sizeof( uint32_t );
26 | return _mm_set_epi64x( (int64_t)i, 0 );
27 | }
28 |
29 | HRESULT zeroMemory() const;
30 | };
31 | }
--------------------------------------------------------------------------------
/Whisper/Whisper/DecoderResultBuffer.cpp:
--------------------------------------------------------------------------------
1 | #include "stdafx.h"
2 | #include "DecoderResultBuffer.h"
3 | #include "../D3D/MappedResource.h"
4 | using namespace DirectCompute;
5 |
6 | void DecoderResultBuffer::copyFromVram( const Tensor& rsi )
7 | {
8 | ID3D11ShaderResourceView* srv = rsi;
9 | if( nullptr == srv )
10 | throw OLE_E_BLANK;
11 | if( !rsi.isContinuous() )
12 | throw E_INVALIDARG;
13 |
14 | const uint32_t len = rsi.countElements();
15 | if( len > m_capacity )
16 | {
17 | buffer = nullptr;
18 | CD3D11_BUFFER_DESC desc{ len * 4, 0, D3D11_USAGE_STAGING, D3D11_CPU_ACCESS_READ };
19 | check( device()->CreateBuffer( &desc, nullptr, &buffer ) );
20 | m_capacity = len;
21 | }
22 |
23 | CComPtr source;
24 | srv->GetResource( &source );
25 | // Coordinates of a box are in bytes for buffers
26 | D3D11_BOX box;
27 | store16( &box, _mm_setr_epi32( 0, 0, 0, (int)( len * 4 ) ) );
28 | *(uint64_t*)&box.bottom = 0x100000001ull;
29 | context()->CopySubresourceRegion( buffer, 0, 0, 0, 0, source, 0, &box );
30 | m_size = len;
31 | }
32 |
33 | void DecoderResultBuffer::copyToVector( std::vector& vec ) const
34 | {
35 | vec.resize( m_size );
36 | if( vec.empty() )
37 | throw OLE_E_BLANK;
38 |
39 | MappedResource mapped;
40 | check( mapped.map( buffer, true ) );
41 | memcpy( vec.data(), mapped.data(), (size_t)4 * m_size );
42 | }
43 |
44 | void DecoderResultBuffer::clear()
45 | {
46 | buffer = nullptr;
47 | m_size = m_capacity = 0;
48 | }
--------------------------------------------------------------------------------
/Whisper/Whisper/DecoderResultBuffer.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../ML/Tensor.h"
3 |
4 | namespace DirectCompute
5 | {
6 | class DecoderResultBuffer
7 | {
8 | CComPtr buffer;
9 | uint32_t m_size = 0;
10 | uint32_t m_capacity = 0;
11 |
12 | public:
13 |
14 | void copyFromVram( const Tensor& rsi );
15 |
16 | void copyToVector( std::vector& vec ) const;
17 |
18 | uint32_t size() const
19 | {
20 | return m_size;
21 | }
22 | void clear();
23 |
24 | __m128i getMemoryUse() const
25 | {
26 | return bufferMemoryUsage( buffer );
27 | }
28 | };
29 | }
--------------------------------------------------------------------------------
/Whisper/Whisper/KeyValueBuffers.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../ML/Tensor.h"
3 |
4 | namespace DirectCompute
5 | {
6 | // FP16 buffer for self-attention and cross-attention layers
7 | class AttentionBuffer
8 | {
9 | CComPtr buffer;
10 | uint32_t m_size = 0;
11 |
12 | public:
13 | // Create buffer for the specified count of elements
14 | void resize( uint32_t size );
15 |
16 | // Create an 1D tensor which references a slice of that buffer
17 | Tensor view( uint32_t length, uint32_t offset ) const;
18 |
19 | void clear()
20 | {
21 | buffer = nullptr;
22 | m_size = 0;
23 | }
24 |
25 | ID3D11Buffer* getBuffer() const { return buffer; }
26 |
27 | uint32_t getSize() const { return m_size; }
28 |
29 | HRESULT zeroMemory() const;
30 | };
31 |
32 | struct KeyValueBuffers
33 | {
34 | AttentionBuffer keys, values;
35 |
36 | void resize( uint32_t size );
37 |
38 | void clear()
39 | {
40 | keys.clear();
41 | values.clear();
42 | }
43 |
44 | __m128i getMemoryUse() const
45 | {
46 | size_t i = keys.getSize();
47 | i += values.getSize();
48 | i *= sizeof( uint16_t );
49 | return setHigh_size( (int64_t)i ); // They both are in VRAM
50 | }
51 |
52 | HRESULT zeroMemory() const;
53 | };
54 | }
--------------------------------------------------------------------------------
/Whisper/Whisper/Languages.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../../ComLightLib/comLightCommon.h"
3 |
4 | namespace Whisper
5 | {
6 | int lookupLanguageId( const char* code );
7 | int lookupLanguageId( uint32_t key );
8 |
9 | const char* lookupLanguageName( const char* code );
10 |
11 | int COMLIGHTCALL getLanguageId( const char* lang );
12 | }
--------------------------------------------------------------------------------
/Whisper/Whisper/MelInputTensor.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "../ML/TensorEx.h"
3 | #include "sEncodeParams.h"
4 | #include "iSpectrogram.h"
5 |
6 | namespace DirectCompute
7 | {
8 | // Input tensor in VRAM, in a dynamic FP32 buffer
9 | class MelInputTensor : public TensorEx
10 | {
11 | uint32_t capacity;
12 |
13 | public:
14 |
15 | HRESULT create( Whisper::iSpectrogram& spectrogram, const sEncodeParams& encParams );
16 |
17 | __m128i getMemoryUse() const
18 | {
19 | return setHigh_size( (size_t)capacity * 4 );
20 | }
21 | };
22 | }
--------------------------------------------------------------------------------
/Whisper/Whisper/ModelLoader.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "ModelBuffers.h"
3 | #include